import { Box, Divider, Stack, Typography, useTheme } from '@mui/material'
import { blue } from '@mui/material/colors'
import {
  MetricsProjectModelVersion,
  ProjectModelVersion,
} from '@/types/project-models'
import { useGetProjectModelVersionFields } from '@/service-library/hooks/project-model-version-fields'
import BarChart from '@/components/charts/BarChart'

/** Props for {@link FieldCountComparisonChart}. */
interface FieldCountComparisonChartProps {
  /** Model version whose per-field counts are charted; its `id` scopes the fields query. */
  modelVersion: ProjectModelVersion | MetricsProjectModelVersion

  /** Chart height forwarded to `BarChart`; the component defaults it to 330. */
  height?: number

  /**
   * Forwarded as the `enabled` option of the fields query hook. When enabled
   * (the component's default), the chart fetches its data from the backend.
   */
  enabled?: boolean
}

/** Color for the second ("New"/"Training") series, shared by the bars and the tooltip. */
const NEW_COUNT_COLOR = blue['A700']

/** One bar group in the chart: a single model-version field and its counts. */
type FieldCountRow = {
  name: string
  trained_count: number
  training_batch_count: number
  trained_countColor: string
  training_batch_countColor: string
  project_grid_field_id: string
}

/**
 * Bar chart comparing, for each field of a model version, the already-trained
 * count against the remainder (`total_count - trained_count`, clamped at 0).
 *
 * Requests up to 50 fields of the model version, ordered by `total_count`
 * descending.
 *
 * @param props.modelVersion - Model version whose fields are charted. If it has a
 *   `training_status` of `'training'`, the second series is labelled "Training"
 *   in the tooltip; otherwise it is labelled "New".
 * @param props.height - Height passed to `BarChart` (default 330).
 * @param props.enabled - Forwarded to the fields query hook as `enabled` (default true).
 */
export default function FieldCountComparisonChart({
  modelVersion,
  height = 330,
  enabled = true,
}: FieldCountComparisonChartProps) {
  const theme = useTheme()
  const trainedColor = theme.palette.primary.main

  const { projectModelVersionFields } = useGetProjectModelVersionFields({
    filters: {
      limit: '50',
      project_model_version_id: modelVersion.id,
      fields__include: 'project_grid_field,total_count',
      project_grid_field__fields__only: 'id,name',
      ordering: '-total_count',
    },
    enabled,
  })

  // Checked with `in` so TS narrows the union before reading training_status.
  const isTraining =
    'training_status' in modelVersion &&
    modelVersion.training_status === 'training'

  // Bars are indexed by `name`, so repeated field names would share one key.
  // Appending spaces keeps each repeat's key unique.
  const seenNames = new Set<string>()
  const data: FieldCountRow[] = []

  projectModelVersionFields.forEach((field) => {
    let name = field.project_grid_field?.name || ''
    while (seenNames.has(name)) {
      name += ' '
    }
    seenNames.add(name)

    data.push({
      name,
      trained_count: field.trained_count,
      // Untrained remainder: 0 when total_count is missing/0, never negative.
      training_batch_count: Math.max(
        field.total_count ? field.total_count - field.trained_count : 0,
        0,
      ),
      trained_countColor: trainedColor,
      training_batch_countColor: NEW_COUNT_COLOR,
      project_grid_field_id: field.project_grid_field_id,
    })
  })

  return (
    <BarChart
      data={data}
      indexBy="name"
      keys={['trained_count', 'training_batch_count']}
      colorBy="id"
      // Each row carries `<key>Color`, so a bar's color is looked up by its series id.
      colors={({ id, data: row }) => String(row[`${id}Color`])}
      xKey="name"
      yKey="trained_count"
      height={height}
      label="Fields"
      padding={0.5}
      yLegend="Count"
      axisBottom={{
        truncateTickAt: 10,
        tickRotation: 90,
      }}
      margin={{
        // Presumably extra room for the 90°-rotated tick labels when bars exist.
        bottom: data.length > 0 ? 90 : 48,
      }}
      tooltipContent={(row) => (
        <Stack
          sx={{
            p: 1,
          }}
        >
          <Typography noWrap>{row.name}</Typography>
          <Divider sx={{ my: 0.5 }} />
          <Typography noWrap variant="caption">
            Trained:{' '}
            <Box
              component="span"
              sx={{
                fontWeight: 'bold',
                color: trainedColor,
              }}
            >
              {row.trained_count || 0}
            </Box>
          </Typography>
          <Typography noWrap variant="caption">
            {isTraining ? 'Training' : 'New'}:{' '}
            <Box
              component="span"
              sx={{
                fontWeight: 'bold',
                color: NEW_COUNT_COLOR,
              }}
            >
              {row.training_batch_count || 0}
            </Box>
          </Typography>
          <Typography noWrap variant="caption">
            Total:{' '}
            <Box component="span" sx={{ fontWeight: 'bold' }}>
              {(row.trained_count || 0) + (row.training_batch_count || 0)}
            </Box>
          </Typography>
        </Stack>
      )}
    />
  )
}
