import { ReactNode, useRef } from 'react'
import { Stack } from '@mui/material'
import {
  ResponsiveScatterPlot,
  ScatterPlotDatum,
  ScatterPlotNodeData,
  ScatterPlotRawSerie,
  ScatterPlotSvgProps,
  ScatterPlotValue,
} from '@nivo/scatterplot'
import TooltipCard from './TooltipCard'
import constants from './chart-constants'

/**
 * Props for `ScatterPlotChart`: the chart-specific options below plus every
 * nivo scatter plot SVG prop except `data`, `width` and `height`.
 */
type ScatterPlotChartProps = {
  /** Series to plot; each series may carry an optional `color`. */
  data: (ScatterPlotRawSerie<ScatterPlotDatum> & { color?: string })[]
  /**
   * When truthy, each series is colored with its own `color` field (falling
   * back to an empty string); otherwise the `colors` prop is left undefined.
   */
  dataHasColor?: boolean
  /** Height applied to the wrapping container (the component defaults it to 420). */
  height?: number | string
  /** Optional text rendered in a row above the chart. */
  label?: string
  /**
   * Custom tooltip body. Called with the hovered node and with every point,
   * across all series, whose `x` and `y` loosely equal the hovered node's.
   * The returned content is wrapped in `TooltipCard`. When omitted, the
   * `tooltip` prop is passed to the chart instead.
   */
  tooltipContent?: (
    node: ScatterPlotNodeData<ScatterPlotDatum>,
    stackedNodes: {
      id: string | number
      color?: string
      x: ScatterPlotValue
      y: ScatterPlotValue
    }[],
  ) => ReactNode
  /** Bottom-axis legend; when set, the "with legend" bottom margin constant is used. */
  xLegend?: string
  /** Left-axis legend; when set, the "with legend" left margin constant is used. */
  yLegend?: string
} & Omit<ScatterPlotSvgProps<ScatterPlotDatum>, 'data' | 'width' | 'height'>

/**
 * Wraps a flat record into a series object that holds exactly one point.
 *
 * @param item - Flat record to convert; its `color` entry is copied to the series.
 * @param idKey - Key of `item` whose value becomes the series `id`.
 * @param xKey - Key of `item` whose value becomes the point's `x`.
 * @param yKey - Key of `item` whose value becomes the point's `y`.
 * @returns An object `{ id, color, data }` where `data` is a one-element array
 *   containing the `{ x, y }` point.
 */
export const formatGroupWithOnlyOnePoint = (
  item: Record<string, string | number>,
  idKey: string,
  xKey: string,
  yKey: string,
) => {
  const id = item[idKey]
  // `color` is typed as string even when the record has no such entry.
  const color = item.color as string
  const onlyPoint = { x: item[xKey], y: item[yKey] }

  return { id, color, data: [onlyPoint] }
}

/**
 * Scatter plot based on nivo's `ResponsiveScatterPlot`, wrapped in a MUI
 * `Stack` that gives the chart its height and optionally shows a label row
 * above it.
 *
 * When `data` is empty the chart is given placeholder settings: linear 0–1
 * scales, `[0, 1]` axis ticks and a transparent horizontal marker whose
 * legend reads "No Data". Remaining props are spread last onto the chart, so
 * they take precedence over both the defaults and the placeholder settings.
 */
export default function ScatterPlotChart({
  data,
  dataHasColor,
  height = 420,
  label,
  tooltip,
  tooltipContent,
  axisBottom,
  axisLeft,
  margin,
  xFormat,
  xLegend,
  yFormat,
  yLegend,
  ...props
}: ScatterPlotChartProps) {
  const containerRef = useRef<HTMLDivElement>(null)
  const hasData = data.length > 0

  // Marker with an invisible line: only its legend text is meant to show.
  const noDataMarker = {
    axis: 'y' as const,
    value: 0.5,
    lineStyle: {
      stroke: 'transparent',
    },
    legend: 'No Data',
    textStyle: {
      fill: '#ffffff',
    },
    legendPosition: 'top' as const,
    legendOrientation: 'horizontal' as const,
  }

  // Extra chart props applied only when there are no series at all.
  const emptyStateProps = hasData
    ? {}
    : {
        gridXValues: 0,
        gridYValues: 1,
        xScale: { min: 0, max: 1, type: 'linear' as const },
        yScale: { min: 0, max: 1, type: 'linear' as const },
        markers: [noDataMarker],
      }

  // Margin defaults; the caller's `margin` entries are merged over these.
  const defaultMargin = {
    top: constants.MARGIN_TOP,
    right: constants.MARGIN_RIGHT,
    bottom: xLegend
      ? constants.MARGIN_BOTTOM_WITH_LEGEND
      : constants.MARGIN_BOTTOM,
    left: yLegend ? constants.MARGIN_LEFT_WITH_LEGEND : constants.MARGIN_LEFT,
  }

  // Caller-supplied axis props override the defaults listed before the
  // spread; `tickValues` comes after it, so it is always decided here.
  const bottomAxis = {
    format: xFormat,
    legend: xLegend,
    legendPosition: 'middle' as const,
    legendOffset: 40,
    ...axisBottom,
    tickValues: hasData ? axisBottom?.tickValues : [0, 1],
  }
  const leftAxis = {
    format: yFormat,
    legend: yLegend,
    legendPosition: 'middle' as const,
    legendOffset: -50,
    ...axisLeft,
    tickValues: hasData ? axisLeft?.tickValues : [0, 1],
  }

  const chartTheme = {
    axis: {
      legend: {
        text: {
          fontSize: 14,
          fill: '#d2d2d2',
        },
      },
      ticks: {
        text: {
          fill: '#b3b3b3',
        },
      },
    },
    grid: {
      line: {
        stroke: '#686868',
      },
    },
  }

  return (
    <Stack sx={{ height }} ref={containerRef}>
      {label && (
        <Stack direction="row" spacing={0.5}>
          {label}
        </Stack>
      )}
      {/* ResponsiveScatterPlot must live inside a container that has a set height */}
      <ResponsiveScatterPlot
        data={data}
        colors={
          dataHasColor
            ? ({ serieId }) => {
                // Use the color of the first series whose id matches.
                const owner = data.find((serie) => serie.id === serieId)
                return owner?.color || ''
              }
            : undefined
        }
        margin={{ ...defaultMargin, ...margin }}
        xFormat={xFormat}
        axisBottom={bottomAxis}
        yFormat={yFormat}
        axisLeft={leftAxis}
        tooltip={
          tooltipContent
            ? ({ node }) => {
                // Collect every point, from any series, sitting on the same
                // coordinates as the hovered node, tagged with its series'
                // id and color. Loose equality is kept on purpose.
                const stackedNodes = data.flatMap((serie) =>
                  serie.data
                    .filter(
                      (point) =>
                        point.x == node.data.x && point.y == node.data.y,
                    )
                    .map((point) => ({
                      ...point,
                      id: serie.id,
                      color: serie.color,
                    })),
                )
                return (
                  <TooltipCard>
                    {tooltipContent(node, stackedNodes)}
                  </TooltipCard>
                )
              }
            : tooltip
        }
        theme={chartTheme}
        useMesh={false}
        {...emptyStateProps}
        {...props}
      />
    </Stack>
  )
}
