import { useEffect, useMemo, useState } from 'react'
import { blue } from '@mui/material/colors'
import { ChevronLeft, ChevronRight } from '@mui/icons-material'
import {
  Box,
  Divider,
  IconButton,
  Stack,
  Typography,
  useTheme,
} from '@mui/material'

import {
  MetricsProjectModelVersion,
  ProjectModelVersion,
} from '@/types/project-models'
import { useGetProjectModelVersionCategories } from '@/service-library/hooks/project-model-version-categories'
import BarChart from '@/components/charts/BarChart'

/** Props accepted by `CategoryModelItemCountChart`. */
type CategoryModelItemCountChartProps = {
  /** Model version whose categories are charted; its `id` is used as the query filter. */
  modelVersion: MetricsProjectModelVersion | ProjectModelVersion
  /** Chart height forwarded to the bar chart (the component defaults it to 330). */
  height?: number
  /** When false, the category data is not fetched from the backend. */
  enabled?: boolean
  /**
   * Set when the chart is rendered on the version metrics page: `total_count`
   * is then not requested and refetching on window focus / reconnect is
   * turned off.
   */
  isForVersionMetricsPage?: boolean
}

/**
 * Paginated bar chart of per-category item counts for a project model version.
 *
 * Every category contributes two series: `trained_count` and the remainder
 * `total_count - trained_count` (labelled "Training" while the version is
 * training, "New" otherwise). Categories are requested 50 at a time; the
 * chevron buttons move between already-loaded pages and request the next page
 * on demand.
 *
 * @param modelVersion Model version whose categories are displayed.
 * @param height Chart height; defaults to 330.
 * @param enabled When false the categories query is disabled.
 * @param isForVersionMetricsPage When true, `total_count` is not requested and
 *   refetching on window focus / reconnect is disabled.
 */
export default function CategoryModelItemCountChart({
  modelVersion,
  height = 330,
  enabled = true,
  isForVersionMetricsPage = false,
}: CategoryModelItemCountChartProps) {
  const theme = useTheme()
  const [pageIndex, setPageIndex] = useState(0)

  const { data, hasNextPage, fetchNextPage, isFetching, isError } =
    useGetProjectModelVersionCategories({
      filters: {
        project_model_version_id: modelVersion.id,
        limit: '50',
        // including project_content_category_item and data_list_entry_cell_value for backwards compatibility
        fields__include: `project_content_category_item,data_list_entry_cell_value${
          isForVersionMetricsPage ? '' : ',total_count'
        }`,
        project_content_category_item__fields__only: 'id,name',
        data_list_entry_cell_value__fields__only: 'id,value',
        // ordering: isForVersionMetricsPage ? '' : '-total_count', // TODO: Uncomment when bug in backend is fixed (Ordering takes too long to load, so it never loads.)
      },
      refetchOnWindowFocus: !isForVersionMetricsPage, // if in the metrics page, the response shouldn't change
      refetchOnReconnect: !isForVersionMetricsPage, // if in the metrics page, the response shouldn't change
      enabled,
      retry: 1,
    })

  // Error flags keyed by page index, so a failed page keeps showing its error
  // while pages that did load remain viewable.
  const [isErrorPerPage, setIsErrorPerPage] = useState<Record<number, boolean>>(
    {},
  )

  const projectModelVersionCategories = useMemo(
    () => data?.pages?.[pageIndex]?.results || [],
    [data?.pages, pageIndex],
  )

  // Only show the loading state when the page being viewed has no data yet.
  const isFetchingCurrentPage = isFetching && !data?.pages?.[pageIndex]
  // Moving forward requires the current page to be loaded and a following page
  // to either exist on the backend or already be cached.
  const canGoToNextPage =
    data?.pages?.[pageIndex] && (hasNextPage || data?.pages?.[pageIndex + 1])
  const showArrows = hasNextPage || (data?.pages?.length || 0) > 1
  const isTraining =
    'training_status' in modelVersion // checking like this so TS doesn't complain
      ? modelVersion.training_status === 'training'
      : false

  const trainedColor = theme.palette.primary.main

  // Memoized so the chart receives a stable array between unrelated re-renders.
  const categoryCountsWithNames = useMemo(
    () =>
      projectModelVersionCategories.map((category) => {
        const linkedName = category.project_content_category_item_id
          ? category.project_content_category_item?.name
          : category.data_list_entry_cell_value?.value
        return {
          // Fall back to the raw trained value whenever the linked record (or
          // its label) is missing, whichever kind of category this is.
          name: linkedName || category.trained_with_value,
          trained_count: category.trained_count,
          // The trailing `|| 0` also covers a NaN result from missing counts.
          training_batch_count:
            (category.total_count
              ? category.total_count - category.trained_count
              : 0) || 0,
          trained_countColor: trainedColor,
          training_batch_countColor: blue['A700'],
        }
      }),
    [projectModelVersionCategories, trainedColor],
  )

  useEffect(() => {
    if (isError && !data?.pages?.[pageIndex]) {
      // Flag the page being viewed; keep the same object if already flagged so
      // no extra render is triggered.
      setIsErrorPerPage((prev) =>
        prev[pageIndex] ? prev : { ...prev, [pageIndex]: true },
      )
      return
    }

    // Clear the flag of every page that has data now.
    setIsErrorPerPage((prev) => {
      const recoveredKeys = Object.keys(prev).filter(
        (key) => data?.pages?.[Number.parseInt(key, 10)],
      )
      if (recoveredKeys.length === 0) return prev

      const newErrorState = { ...prev }
      recoveredKeys.forEach((key) => {
        delete newErrorState[Number.parseInt(key, 10)]
      })
      return newErrorState
    })
  }, [data?.pages, isError, pageIndex])

  return (
    <>
      <BarChart
        data={categoryCountsWithNames}
        isError={isErrorPerPage[pageIndex]}
        isLoading={isFetchingCurrentPage}
        colors={({ id, data }) => String(data[`${id}Color`])}
        xKey="name"
        keys={['trained_count', 'training_batch_count']}
        yKey="trained_count"
        height={height}
        label="Categories"
        padding={0.85}
        axisBottom={{
          // Truncate long labels; non-string labels are passed through
          // untouched instead of throwing on `.length`.
          format: (name) =>
            typeof name === 'string' && name.length > 12
              ? `${name.substring(0, 10)}...`
              : name,
          tickRotation: 90,
        }}
        margin={{
          bottom:
            categoryCountsWithNames.length > 0 ? (showArrows ? 100 : 92) : 28,
        }}
        yLegend="Count"
        tooltipContent={(datum) => (
          <Stack
            sx={{
              p: 1,
            }}
          >
            <Typography noWrap>{datum.name}</Typography>
            <Divider sx={{ my: 0.5 }} />
            <Typography noWrap variant="caption">
              Trained:{' '}
              <Box
                component="span"
                sx={{
                  fontWeight: 'bold',
                  color: theme.palette.primary.main,
                }}
              >
                {datum.trained_count || 0}{' '}
              </Box>
            </Typography>
            <Typography noWrap variant="caption">
              {isTraining ? 'Training' : 'New'}:{' '}
              <Box
                component="span"
                sx={{
                  fontWeight: 'bold',
                  color: blue['A700'],
                }}
              >
                {datum.training_batch_count || 0}
              </Box>
            </Typography>
            <Typography noWrap variant="caption">
              Total:{' '}
              <Box component="span" sx={{ fontWeight: 'bold' }}>
                {(datum.trained_count || 0) + (datum.training_batch_count || 0)}
              </Box>
            </Typography>
          </Stack>
        )}
      />
      {showArrows && (
        <Stack direction="row" justifyContent="flex-end">
          <IconButton
            disabled={pageIndex === 0}
            onClick={() => {
              setPageIndex((prev) => prev - 1)
            }}
          >
            <ChevronLeft fontSize="inherit" />
          </IconButton>
          <IconButton
            disabled={!canGoToNextPage}
            onClick={() => {
              setPageIndex((prev) => prev + 1)
              // Only hit the backend when the following page is not cached yet.
              if (!data?.pages?.[pageIndex + 1] && hasNextPage) {
                fetchNextPage()
              }
            }}
          >
            <ChevronRight fontSize="inherit" />
          </IconButton>
        </Stack>
      )}
    </>
  )
}
