import { useEffect, useMemo, useState } from 'react'
import {
  Card,
  Container,
  Skeleton,
  Stack,
  Typography,
  useTheme,
} from '@mui/material'
import Grid from '@mui/material/Unstable_Grid2/Grid2'
import { LocalizationProvider } from '@mui/x-date-pickers'
import { AdapterDayjs } from '@mui/x-date-pickers/AdapterDayjs'
import { GroupOption } from '@/types/metrics'
import { Project } from '@/types/projects'
import {
  MetricsProjectModelVersion,
  ProjectModel,
} from '@/types/project-models'
import useOverlay from '@/hooks/useOverlay'
import useMetricsSearchParams from '@/hooks/useMetricsSearchParams'
import { useGetProjects } from '@/service-library/hooks/projects'
import { useGetProjectModels } from '@/service-library/hooks/project-models'
import { useGetTrainingMetrics } from '@/service-library/hooks/project-model-versions'
import { convertToDate } from '@/utils/date-metrics'
import GeneralInfoBox from '@/components/data-visualization/GeneralInfoBox'
import ListAutocomplete, {
  BaseDataType,
} from '@/components/list-autocomplete/ListAutocomplete'
import TimeBarChart from '@/components/charts/TimeBarChart'
import {
  CATEGORY_MODEL_TYPE_ID,
  OCR_MODEL_TYPE_ID,
  NER_MODEL_TYPE_ID,
} from '@/components/models-page/helpers'
import { useRootOrganization } from '@/components/organizations/RootOrganizationProvider'
import PageTitle from '@/components/PageTitle'
import MetricsCommonFilters from './MetricsCommonFilters'
import ModelVersionsDialog from './ModelVersionsDialog'
import ModelVersionsFieldsChart from './ModelVersionsFieldsChart'

/**
 * Empty stand-in used wherever a project or project model has not been
 * resolved yet; both fields are empty strings.
 */
const placeholderData = { id: '', name: '' }

/**
 * Metrics page for model versions.
 *
 * Lets the user pick a project and one of its training models, plus a date
 * range and grouping, and renders:
 *  - an info box with the number of model versions returned by the metrics
 *    request,
 *  - a time bar chart of model versions (clicking a bar opens a dialog listing
 *    that bar's model versions),
 *  - a per-field chart, shown only when the selected model is of the NER model
 *    type and at least one model version was returned.
 *
 * The selected project / project model and the filters are mirrored into the
 * metrics search params through `updateValues`.
 */
export default function ModelVersionsMetrics() {
  const theme = useTheme()
  const { rootOrganization } = useRootOrganization()
  const modelVersionsOverlay = useOverlay()
  const { values, updateValues } = useMetricsSearchParams()

  // Filter state is seeded from the search params. The whole [value, setter]
  // tuples are kept because MetricsCommonFilters receives them as-is.
  const startDateState = useState(convertToDate(values.from || 'start'))
  const endDateState = useState(convertToDate(values.to || 'end', false))
  const groupedByState = useState<GroupOption>(
    (values.grouped_by as GroupOption) || 'day',
  )
  const [startDate] = startDateState
  const [endDate] = endDateState
  const [groupedBy] = groupedByState

  // Model versions of the bar that was last clicked; rendered in the dialog.
  const [modelVersions, setModelVersions] =
    useState<MetricsProjectModelVersion[]>()

  const { projects, isLoading: projectsIsLoading } = useGetProjects({
    refetchOnWindowFocus: false,
    filters: {
      org_id: rootOrganization.id,
      limit: '1000',
    },
  })

  // Only projects whose setup is complete are offered. Memoized so that the
  // `defaultProject` memo below is not invalidated on every render.
  const setUpProjects = useMemo(
    () => projects.filter((project) => project.setup_state === 'complete'),
    [projects],
  )

  // Project from the search params when present (placeholder if it cannot be
  // found), otherwise the first set-up project, otherwise the placeholder.
  const defaultProject = useMemo(() => {
    if (setUpProjects.length && values.project_id) {
      return (
        setUpProjects.find(({ id }) => id === values.project_id) ||
        placeholderData
      )
    }
    return setUpProjects[0] || placeholderData
  }, [setUpProjects, values.project_id])

  const [selectedProject, setSelectedProject] = useState<
    Project | BaseDataType
  >(defaultProject)

  // NOTE(review): the `!` suffix presumably means "not equal" (i.e. OCR models
  // are excluded) and `parent_model_id__isnull` presumably restricts the list
  // to models without a parent - confirm against the API.
  const { projectModels: trainingModels, isLoading: projectModelsIsLoading } =
    useGetProjectModels({
      filters: {
        limit: '1000',
        project_id: selectedProject?.id,
        'project_model_type_id!': OCR_MODEL_TYPE_ID,
        parent_model_id__isnull: 'true',
      },
      // Skip the request until a real project (non-empty id) is selected.
      enabled: !!selectedProject?.id,
    })

  // Same resolution rule as `defaultProject`, applied to the training models.
  const defaultProjectModel = useMemo(() => {
    if (trainingModels.length && values.project_model_id) {
      return (
        trainingModels.find(({ id }) => id === values.project_model_id) ||
        placeholderData
      )
    }
    return trainingModels[0] || placeholderData
  }, [trainingModels, values.project_model_id])

  const [selectedProjectModel, setSelectedProjectModel] = useState<
    ProjectModel | BaseDataType
  >(defaultProjectModel)

  const { trainingMetrics, isLoading } = useGetTrainingMetrics({
    projectModelId: selectedProjectModel?.id,
    startDate,
    endDate,
    groupedBy,
    refetchOnWindowFocus: false,
  })

  const trainingData = useMemo(
    () => trainingMetrics?.results || [],
    [trainingMetrics?.results],
  )

  // Copy of `trainingData` that is only refreshed while the metrics hook is
  // not loading, so the bar chart keeps its previous data during a request.
  const [internalTrainingModelsData, setInternalTrainingModelsData] = useState<
    typeof trainingData
  >([])

  // Every model version across all returned buckets, flattened in order.
  const allModelVersions = useMemo<MetricsProjectModelVersion[]>(
    () => trainingData.flatMap(({ model_versions }) => model_versions),
    [trainingData],
  )

  // Once projects have loaded, sync the selection with the resolved default.
  useEffect(() => {
    if (!projectsIsLoading) {
      setSelectedProject(defaultProject)
    }
  }, [defaultProject, projectsIsLoading])

  // Likewise for the training model once the project's models have loaded.
  useEffect(() => {
    if (!projectModelsIsLoading) {
      setSelectedProjectModel(defaultProjectModel)
    }
  }, [defaultProjectModel, projectModelsIsLoading])

  useEffect(() => {
    if (!isLoading) {
      setInternalTrainingModelsData(trainingData)
    }
  }, [trainingData, isLoading])

  // The placeholder has no `project_model_type_id`, in which case this is
  // undefined and neither model-type comparison below matches.
  const selectedModelTypeId = (selectedProjectModel as ProjectModel)
    .project_model_type_id
  const showFieldsChart =
    selectedModelTypeId === NER_MODEL_TYPE_ID && allModelVersions.length > 0

  return (
    <Container sx={{ py: 4 }}>
      <PageTitle>Metrics - Model Versions</PageTitle>

      <Grid container spacing={2}>
        <Grid xs={12}>
          <LocalizationProvider dateAdapter={AdapterDayjs}>
            <Stack direction="row" spacing={2}>
              <ListAutocomplete
                autoSelect={false}
                options={setUpProjects}
                selected={selectedProject}
                setSelected={(project) => {
                  // Changing the project clears the model selection, both in
                  // local state and in the search params.
                  setSelectedProject(project)
                  setSelectedProjectModel(placeholderData)
                  updateValues({ project_id: project.id, project_model_id: '' })
                }}
                label="Project"
                fullWidth={false}
                sx={{ width: '180px' }}
              />
              <ListAutocomplete
                autoSelect={false}
                options={trainingModels}
                selected={selectedProjectModel}
                setSelected={(projectModel) => {
                  setSelectedProjectModel(projectModel)
                  updateValues({ project_model_id: projectModel.id })
                }}
                label="Training Model"
                fullWidth={false}
                sx={{ width: '180px' }}
              />
              <MetricsCommonFilters
                startDateState={startDateState}
                endDateState={endDateState}
                groupedByState={groupedByState}
                updateValues={updateValues}
              />
            </Stack>
          </LocalizationProvider>
        </Grid>

        {!trainingMetrics && isLoading && (
          <Grid xs={6}>
            <Skeleton height={100} />
          </Grid>
        )}
        {trainingMetrics && (
          <Grid xs={3}>
            <GeneralInfoBox
              label="Number of Model Versions"
              value={trainingMetrics.count || 0}
            />
          </Grid>
        )}

        <Grid xs={12}>
          <Card elevation={0} sx={{ borderRadius: 2, px: 2, py: 3 }}>
            <TimeBarChart
              height={300}
              from={new Date(startDate)}
              to={new Date(endDate)}
              data={internalTrainingModelsData}
              dataSum={trainingMetrics ? trainingMetrics.count : undefined}
              colors={theme.palette.primary.main}
              label={
                <Typography component="h2" variant="h5" sx={{ pl: 6.5 }}>
                  Model Versions
                </Typography>
              }
              groupedBy={groupedBy}
              margin={{ right: 40 }}
              yKey="count"
              yLegend="Count"
              onClick={(data) => {
                modelVersionsOverlay.open()
                setModelVersions(data.data.model_versions)
              }}
            />
          </Card>
        </Grid>

        {showFieldsChart && (
          <Grid xs={12}>
            <Card elevation={0} sx={{ borderRadius: 2, pl: 2, pr: 5.5, py: 3 }}>
              <ModelVersionsFieldsChart
                height={300}
                modelVersions={allModelVersions}
                projectId={selectedProject.id}
                labelSx={{ pl: 6.5 }}
              />
            </Card>
          </Grid>
        )}
      </Grid>

      {/* Mounted only after a bar has been clicked at least once. */}
      {modelVersions && selectedProject && (
        <ModelVersionsDialog
          project={selectedProject as Project}
          isCategoryModel={selectedModelTypeId === CATEGORY_MODEL_TYPE_ID}
          modelVersions={modelVersions}
          overlay={modelVersionsOverlay}
          projectId={selectedProject.id}
        />
      )}
    </Container>
  )
}
