import { useEffect, useMemo, useState } from 'react'
import { ScatterPlotDatum, ScatterPlotValue } from '@nivo/scatterplot'
import { FilterList } from '@mui/icons-material'
import {
  IconButton,
  Stack,
  Tooltip,
  Typography,
  TypographyProps,
  useTheme,
} from '@mui/material'
import { ProjectGridField } from '@/types/fields'
import { MetricsProjectModelVersion } from '@/types/project-models'
import useOverlay from '@/hooks/useOverlay'
import { getColor } from '@/utils/chart-utils'
import { getNonGridFields, sortBySortOrder } from '@/utils/field-utils'
import LineChart from '@/components/charts/LineChart'
import VisibleFieldsDialog from './VisibleFieldsDialog'
import { useGetProjectGrids } from '@/service-library/hooks/project-grids'

/** Props accepted by the `ModelVersionsFieldsChart` component. */
interface ModelVersionsFieldsChartProps {
  /** Model versions whose per-field accuracy is plotted. */
  modelVersions: MetricsProjectModelVersion[]
  /** Project whose grids are fetched; the grid query is disabled while this is falsy. */
  projectId: string | undefined
  /** Chart height handed to `LineChart` (the component defaults it to 350). */
  height?: number
  /** Typography variant for the chart title (the component defaults it to 'h5'). */
  titleVariant?: TypographyProps['variant']
  /** `sx` applied to the title/filter-button row. */
  labelSx?: TypographyProps['sx']
}

/** Number of most recent model versions plotted on the chart. */
const MAX_PLOTTED_VERSIONS = 20

/**
 * Returns the x-axis value used for a model version.
 *
 * When every model version number is 0 (so all points would share the same
 * x value), the 1-based position within the plotted list is used instead.
 * Otherwise integer versions are formatted with one decimal (`3` -> `'3.0'`)
 * and non-integer versions are returned unchanged.
 *
 * @param version - The model version number.
 * @param position - Zero-based index of the version in the plotted list.
 * @param usePosition - Label by position instead of by version number.
 */
function getVersionAxisValue(
  version: number,
  position: number,
  usePosition: boolean,
): ScatterPlotValue {
  if (usePosition) return (position + 1).toFixed(1)
  return Number.isInteger(version) ? version.toFixed(1) : version
}

/**
 * Line chart of per-field accuracy across a project's model versions.
 *
 * Draws one series per visible non-grid project grid field, using the last
 * `MAX_PLOTTED_VERSIONS` model versions. Versions are sorted by version number
 * unless every version is 0, in which case the incoming order is kept and
 * points are labelled by position. Once the project grids finish loading,
 * every field is made visible; the filter button opens `VisibleFieldsDialog`
 * to change the selection.
 */
export default function ModelVersionsFieldsChart({
  modelVersions,
  projectId,
  height = 350,
  titleVariant = 'h5',
  labelSx,
}: ModelVersionsFieldsChartProps) {
  const theme = useTheme()
  const [visibleFieldIds, setVisibleFieldIds] = useState<string[]>([])
  const fieldsFilterOverlay = useOverlay()
  const { projectGrids, isLoading } = useGetProjectGrids({
    filters: {
      limit: '1000',
      project_id: projectId,
      fields__include: 'project_grid_fields,sub_project_grid_fields',
      project_grid_fields__fields__include: 'project_grid_field_type',
    },
    enabled: !!projectId,
    refetchOnWindowFocus: true,
  })

  const allVersionsAreZero = modelVersions.every(
    (version) => version.version === 0,
  )

  // Copy before sorting: Array.prototype.sort works in place, and
  // `modelVersions` is a prop owned by the parent component.
  const sortedModelVersions = useMemo(
    () =>
      allVersionsAreZero
        ? modelVersions.slice(-MAX_PLOTTED_VERSIONS)
        : [...modelVersions]
            .sort((versionA, versionB) => versionA.version - versionB.version)
            .slice(-MAX_PLOTTED_VERSIONS),
    [allVersionsAreZero, modelVersions],
  )

  const projectGridFields = useMemo(
    () =>
      projectGrids.reduce<ProjectGridField[]>((acc, grid) => {
        //TODO: Remove sorting for when the backend returns data sorted
        const tempProjectGridFields = sortBySortOrder(
          getNonGridFields(grid.project_grid_fields),
        )
        return tempProjectGridFields.length > 0
          ? [...acc, ...tempProjectGridFields]
          : acc
      }, []),
    [projectGrids],
  )

  // One series per visible field that has at least one data point.
  const data = useMemo(() => {
    const visibleIdSet = new Set(visibleFieldIds)
    const visibleFields = projectGridFields.filter(({ id }) =>
      visibleIdSet.has(id),
    )
    // The shade is the same for every series, so compute it once.
    const colorIndex = theme.palette.mode === 'dark' ? 200 : 400
    const usedFieldNames = new Set<string>()
    const series: {
      id: string
      color: string
      data: ScatterPlotDatum[]
    }[] = []

    visibleFields.forEach((field, fieldIndex) => {
      const dataPoints: ScatterPlotDatum[] = []
      sortedModelVersions.forEach((modelVersion, versionIndex) => {
        const fieldInModelVersion =
          modelVersion.project_model_version_fields?.find(
            ({ project_grid_field_id }) => project_grid_field_id === field.id,
          )
        if (!fieldInModelVersion) return
        dataPoints.push({
          x: getVersionAxisValue(
            modelVersion.version,
            versionIndex,
            allVersionsAreZero,
          ),
          y: fieldInModelVersion.accuracy,
        })
      })
      if (!dataPoints.length) return

      let fieldName = field.name || ''
      while (usedFieldNames.has(fieldName)) {
        fieldName += ' ' // Putting a space makes the id unique [sc-11584]
      }
      usedFieldNames.add(fieldName)

      series.push({
        id: fieldName,
        // Color is keyed by position among the visible fields.
        color: getColor(fieldIndex, colorIndex),
        data: dataPoints,
      })
    })

    return series
  }, [
    allVersionsAreZero,
    projectGridFields,
    sortedModelVersions,
    theme.palette.mode,
    visibleFieldIds,
  ])

  // Make every field visible once the grids have loaded.
  // NOTE(review): this also runs whenever `projectGridFields` changes (e.g. a
  // window-focus refetch that returns different grids), which discards the
  // user's current filter selection — confirm that is intended.
  useEffect(() => {
    if (!isLoading) {
      const projectGridFieldIds = projectGridFields.map(({ id }) => id)
      setVisibleFieldIds(projectGridFieldIds)
    }
  }, [isLoading, projectGridFields])

  return (
    <>
      <LineChart
        label={
          <Stack direction="row" spacing={1} alignItems="center" sx={labelSx}>
            <Typography component="h2" variant={titleVariant}>
              Model Version Field Accuracy
            </Typography>
            <Tooltip title="Filter" enterDelay={1000}>
              <IconButton onClick={fieldsFilterOverlay.open}>
                <FilterList />
              </IconButton>
            </Tooltip>
          </Stack>
        }
        data={data}
        dataHasColor
        height={height}
        xFormat={(x: ScatterPlotValue) => `v${x}`}
        yFormat=".0%"
        yLegend="Accuracy"
        yScale={{
          min: 0,
          max: 1,
          type: 'linear',
        }}
        tooltipContent={(point, stackedPoints) => {
          return (
            <Stack direction="row">
              {stackedPoints.map((stackedPoint, index) => (
                <Stack key={index} p={1} maxWidth={170}>
                  <Typography
                    noWrap
                    variant="body2"
                    sx={{ color: stackedPoint.color || point.color }}
                  >
                    <b>{stackedPoint.id}</b>
                  </Typography>
                  <Typography component="p" noWrap variant="caption">
                    <b>Version:</b> {point.data.xFormatted}
                  </Typography>
                  <Typography component="p" noWrap variant="caption">
                    <b>Accuracy:</b> {point.data.yFormatted}
                  </Typography>
                </Stack>
              ))}
            </Stack>
          )
        }}
      />
      <VisibleFieldsDialog
        projectGrids={projectGrids}
        overlay={fieldsFilterOverlay}
        visibleFieldIds={visibleFieldIds}
        setVisibleFieldIds={setVisibleFieldIds}
      />
    </>
  )
}
