import { Stack, Typography, useTheme } from '@mui/material'
import ScatterPlotChart, {
  formatGroupWithOnlyOnePoint,
} from '../charts/ScatterPlotChart'
import {
  MetricsProjectModelVersion,
  ProjectModelVersion,
  ProjectModelVersionField,
} from '@/types/project-models'
import { getColor } from '@/utils/chart-utils'
import { Project } from '@/types/projects'
import { useGetProjectGridFields } from '@/service-library/hooks/project-grid-fields'

/** Props accepted by {@link FieldAccuracyChart}. */
interface FieldAccuracyChartProps {
  /** Model version whose `project_model_version_fields` are plotted. */
  modelVersion: ProjectModelVersion | MetricsProjectModelVersion
  /** Project whose grid fields supply the display names of the plotted fields. */
  project: Project
  /** Chart height, forwarded to `ScatterPlotChart`. */
  height?: number
  /**
   * Forwarded as `enabled` to the project grid fields query
   * (`useGetProjectGridFields`); presumably the backend fetch is skipped when
   * this is false — confirm against the hook.
   */
  enabled?: boolean
}

/**
 * Scatter plot of per-field accuracy for a single model version.
 *
 * Each entry in `modelVersion.project_model_version_fields` becomes one
 * single-point series built by `formatGroupWithOnlyOnePoint` from its
 * `name`, `trained_count` and `accuracy` (presumably series id, x and y
 * respectively — the tooltip labels x as "Fields Trained" and y as
 * "Accuracy"). The y axis is a 0–1 linear scale formatted as a percentage.
 *
 * Display names come from the project's grid fields, matched on
 * `project_grid_field_id`. Version fields with no matching grid field are
 * left out of the chart.
 *
 * @param modelVersion - Model version whose fields are plotted.
 * @param project - Project whose grid fields supply the field names.
 * @param height - Height forwarded to `ScatterPlotChart` (default 240).
 * @param enabled - Forwarded to the grid-fields query hook as `enabled`
 *   (default true).
 */
export default function FieldAccuracyChart({
  modelVersion,
  project,
  height = 240,
  enabled = true,
}: FieldAccuracyChartProps) {
  const theme = useTheme()

  const { projectGridFields = [] } = useGetProjectGridFields({
    filters: {
      limit: '1000',
      project_id: project.id,
      fields__only: 'id,name,project_grid_field_type',
    },
    enabled,
  })

  // The palette shade depends only on the theme mode, so compute it once per
  // render rather than once per field inside the reduce below.
  const colorShade = theme.palette.mode === 'dark' ? 200 : 400

  const projectModelVersionFieldsWithNames =
    modelVersion?.project_model_version_fields?.reduce<
      (Omit<
        ProjectModelVersionField,
        'id' | 'project_model_version_id' | 'project_grid_field'
      > & {
        name: string
        color: string
      })[]
    >((acc, versionField, index) => {
      const field = projectGridFields.find(
        ({ id }) => id === versionField.project_grid_field_id,
      )
      // Version fields whose grid field is not (yet) loaded are skipped.
      if (field) {
        acc.push({
          ...versionField,
          name: field.name,
          // `index` is the position in the unfiltered list, so skipping a
          // field does not shift the colors of the fields after it.
          color: getColor(index, colorShade),
        })
      }
      return acc
    }, []) ?? []

  return (
    <ScatterPlotChart
      data={projectModelVersionFieldsWithNames.map((group) =>
        formatGroupWithOnlyOnePoint(group, 'name', 'trained_count', 'accuracy'),
      )}
      dataHasColor
      height={height}
      yFormat=".0%"
      yLegend="Accuracy"
      label="Field Accuracy"
      yScale={{
        min: 0,
        max: 1,
        type: 'linear',
      }}
      // One column per stacked node. Every column shows the hovered node's
      // x/y values. NOTE(review): this assumes stacked nodes share the same
      // coordinates — confirm against ScatterPlotChart.
      tooltipContent={(node, stackedNodes) => (
        <Stack direction="row">
          {stackedNodes.map(({ id, color }, index) => (
            <Stack key={index} p={1} maxWidth={170}>
              <Typography
                noWrap
                variant="body2"
                sx={{ color: color || node.color }}
              >
                <b>{id}</b>
              </Typography>
              <Typography component="p" noWrap variant="caption">
                <b>Fields Trained:</b> {node.formattedX}
              </Typography>
              <Typography component="p" noWrap variant="caption">
                <b>Accuracy:</b> {node.formattedY}
              </Typography>
            </Stack>
          ))}
        </Stack>
      )}
    />
  )
}
