import { useEffect, useMemo, useState } from 'react'
import { Link, useNavigate } from 'react-router-dom'
import {
  Box,
  Button,
  Card,
  CardActionArea,
  MenuItem,
  Select,
  Stack,
  Typography,
} from '@mui/material'
import Grid2 from '@mui/material/Unstable_Grid2/Grid2'
import { useQueryClient } from '@tanstack/react-query'
import { ProjectModelVersion } from '@/types/project-models'
import useIsSuperUser from '@/services/hooks/useIsSuperUser'
import {
  useTrainModel,
  useCancelTrainingModel,
} from '@/service-library/hooks/model-train'
import { useGetProjectModel } from '@/service-library/hooks/project-models'
import { useUpdateProjectModelVersions } from '@/service-library/hooks/project-model-versions'
import { useGetWorkflows } from '@/service-library/hooks/workflows'
import {
  useGetWorkflowStateFieldCounts,
  useGetWorkflowStatePickerItemCounts,
} from '@/service-library/hooks/workflow-states'
import {
  convertSecondsToHours,
  prettifyDate,
} from '@/utils/getFormattedDateTimeString'
import { showErrorSnackbar } from '@/utils/snackbars'
import CopyIDButton from '@/components/copy-id-button/CopyIDButton'
import GeneralInfoBox, {
  GeneralInfoBoxContent,
} from '@/components/data-visualization/GeneralInfoBox'
import { useDemoModeContext } from '@/components/demo-mode-provider/DemoModeProvider'
import { useNotifications } from '@/components/notifications/NotificationProvider'
import ProjectBreadcrumb from '@/components/project-dashboard/ProjectBreadcrumb'
import { useProjectContext } from '@/components/project-dashboard/ProjectProvider'
import { useSelectedWorkflowContext } from '@/components/workflows/SelectedWorkflowProvider'
import AccuracyColorBar from './AccuracyColorBar'
import CategoryItemCountChartCard from './CategoryItemCountChartCard'
import FieldAccuracyChartCard from './FieldAccuracyChartCard'
import FieldCountComparisonChartCard from './FieldCountComparisonChartCard'
import HandleTrainingButton from './HandleTrainingButton'
import ModelVersionAccuracyCard from './ModelVersionAccuracyCard'
import ModelVersionFieldAccuracyCard from './ModelVersionFieldAccuracyCard'
import NotTrainedBanner from './NotTrainedBanner'
import TrainInfoBanner from './TrainInfoBanner'
import {
  CATEGORY_MODEL_TYPE_ID,
  getCurrentVersion,
  getNewFieldOccurrencesCount,
  getNewItemOccurrencesCount,
  getTrainingVersion,
  getVersionInCancelingState,
  validTrainingStatuses,
} from './helpers'

/**
 * MUI theme color token used to tint the status caption next to each version
 * in the version picker, keyed by the version's `training_status`.
 */
const trainingStatusColorMap: Record<
  ProjectModelVersion['training_status'],
  string
> = {
  // In progress
  training: 'blue.main',

  // Stopped by the user
  canceling: 'text.secondary',
  canceled: 'text.secondary',

  // Failed
  error: 'red.main',

  // Pipeline stages and successful outcomes
  starting: 'green.main',
  starting_gpu: 'green.main',
  evaluating: 'green.main',
  terminating_gpu: 'green.main',
  trained: 'green.main',
  completed: 'green.main',
}

/**
 * Dashboard for the project model attached to the currently selected workflow.
 *
 * Shows a version picker, the train/cancel control, headline stats for the
 * selected version (accuracy, version, documents used, training duration,
 * training-batch counts) and charts for that version. If the selected workflow
 * has no `project_model_id`, the user is redirected to the parent route.
 */
export default function ModelDashboard() {
  const navigate = useNavigate()
  const { selectedWorkflow } = useSelectedWorkflowContext()
  const queryClient = useQueryClient()
  const isSuperUser = useIsSuperUser()
  const [demoMode] = useDemoModeContext()
  const { project } = useProjectContext()

  const { projectModel, refetch, queryKey } = useGetProjectModel({
    id: selectedWorkflow.project_model_id || '',
    filters: {
      fields__include:
        'doc_counts,project_model_versions,project_grid_fields,has_children',
      project_model_versions__fields__include:
        'field_stats,category_stats,project_model_version_fields',
    },
  })
  const { projectModel: possibleRootModel } = useGetProjectModel({
    id: projectModel?.parent_model_id || '',
    filters: {
      fields__only: 'project_grid_fields',
    },
  })

  const rootModel = projectModel?.parent_model_id
    ? possibleRootModel
    : projectModel

  const { updateProjectModelVersions } = useUpdateProjectModelVersions({
    onMutate: (updatedVersions) => {
      const updatedVersionsMap = updatedVersions.reduce<
        Record<string, ProjectModelVersion>
      >((acc, version) => {
        acc[version.id] = version
        return acc
      }, {})

      // We update projectModel cache to show "active" on the new version
      queryClient.setQueryData(queryKey, {
        ...projectModel,
        project_model_versions: projectModel?.project_model_versions.map(
          (version) => updatedVersionsMap[version.id] || version,
        ),
      })
    },
    onError: () => {
      queryClient.invalidateQueries({ queryKey })
      showErrorSnackbar(
        'Failed to update projectModel version. Please try again.',
      )
    },
  })

  const field = rootModel?.project_grid_fields?.[0] // Only the root model has a relationship to the field
  const isCategoryModelUsingLists =
    !!field?.metadata.data_list_id || !!field?.params?.data_list_id

  const { workflows } = useGetWorkflows({
    filters: {
      limit: '1000',
      project_id: project.id,
      project_model_id: projectModel?.id || '',
      fields__include: 'workflow_states',
    },
    enabled: !!projectModel,
  })

  const availableForTrainingState = workflows[0]?.workflow_states?.find(
    ({ code }) => code === 'available_for_training',
  )
  const trainingBatchState = workflows[0]?.workflow_states?.find(
    ({ code }) => code === 'training_batch',
  )

  const isCategoryModel =
    projectModel?.project_model_type_id === CATEGORY_MODEL_TYPE_ID

  const { fieldCounts: fieldCountsForTrainingBatch = {} } =
    useGetWorkflowStateFieldCounts({
      id: trainingBatchState?.id || '',
      enabled: !!trainingBatchState && !isCategoryModel,
    })

  const { itemCounts: itemCountsForTrainingBatch = {} } =
    useGetWorkflowStatePickerItemCounts({
      id: trainingBatchState?.id || '',
      pickerFieldId: field?.id || '',
      enabled:
        !!trainingBatchState &&
        isCategoryModel &&
        isCategoryModelUsingLists &&
        !!field,
    })

  // Sort a copy. `project_model_versions` is the array held in the
  // react-query cache, and `Array.prototype.sort` sorts in place, so sorting
  // it directly would mutate cached server state on every render. All
  // consumers below (helpers, picker, charts) see versions in ascending
  // `version` order.
  const sortedVersions = useMemo(
    () =>
      [...(projectModel?.project_model_versions || [])].sort(
        (a, b) => a.version - b.version,
      ),
    [projectModel?.project_model_versions],
  )

  // Holds a version id, '' before initialization, or the sentinel 'training'
  // while a newly started training run has no version id yet (see effect below).
  const [selectedVersionId, setSelectedVersionId] = useState<string>('')

  const { trainModel } = useTrainModel({
    detailQueryKey: queryKey,
    onMutate: () => {
      setSelectedVersionId('training')
    },
  })

  const { cancelTrainingModel } = useCancelTrainingModel({
    detailQueryKey: queryKey,
    onError: () => {
      showErrorSnackbar('Failed to cancel training. Please try again later.')
    },
  })

  const selectedVersion = useMemo(
    () => sortedVersions.find(({ id }) => id === selectedVersionId),
    [sortedVersions, selectedVersionId],
  )

  const currentVersion = useMemo(
    () => getCurrentVersion(sortedVersions),
    [sortedVersions],
  )

  const lastVersion = sortedVersions.at(-1)

  const trainingVersion = useMemo(
    () => getTrainingVersion(sortedVersions),
    [sortedVersions],
  )

  const versionInCancelingState = useMemo(
    () => getVersionInCancelingState(sortedVersions),
    [sortedVersions],
  )

  /** Cancels the in-flight training run if there is one, otherwise starts training. */
  function handleCancelOrTrainModel() {
    if (trainingVersion) {
      cancelTrainingModel({ project_model_version_id: trainingVersion.id })
    } else if (projectModel) {
      trainModel(projectModel)
    }
  }

  /** Marks the selected version as current and demotes the previous current version. */
  function updateCurrentVersion() {
    if (projectModel && selectedVersion) {
      const versionsToUpdate = [
        {
          ...selectedVersion,
          training_status: 'completed' as const, // We need to use "completed" since v2 API doesn't support "trained"
          is_current: true,
        },
      ]
      if (currentVersion) {
        versionsToUpdate.unshift({
          ...currentVersion,
          training_status: 'completed',
          is_current: false,
        })
      }
      updateProjectModelVersions(versionsToUpdate)
    }
  }

  // Default the picker to the current version, or to the newest version when
  // none is current, whenever nothing valid is selected.
  useEffect(() => {
    if (sortedVersions.length === 0 || selectedVersion) return
    setSelectedVersionId(currentVersion?.id || lastVersion?.id || '')
  }, [
    currentVersion?.id,
    lastVersion?.id,
    sortedVersions.length,
    selectedVersion,
  ])

  // Swap the 'training' sentinel for the real id once the new version appears.
  useEffect(() => {
    if (selectedVersionId === 'training' && trainingVersion?.id) {
      setSelectedVersionId(trainingVersion?.id)
    }
  }, [selectedVersionId, trainingVersion?.id])

  // This page only applies to workflows backed by a model; leave otherwise.
  useEffect(() => {
    if (selectedWorkflow && !selectedWorkflow?.project_model_id) {
      navigate('..', { relative: 'path' })
    }
  }, [navigate, selectedWorkflow])

  useNotifications({
    keys: ['training_status'],
    callback: ({ model_version }) => {
      if (
        projectModel &&
        model_version &&
        model_version.project_model_id === projectModel.id
      ) {
        if (
          model_version.training_status === 'trained' ||
          model_version.training_status === 'error' ||
          model_version.training_status === 'canceled'
        )
          refetch()
        if (model_version.training_status === 'error')
          showErrorSnackbar('Training Failed')
      }
    },
  })

  const largeNumberOfFields = selectedVersion
    ? (selectedVersion.field_stats?.total_count || 0) > 50
    : false

  const largeNumberOfCategories = selectedVersion
    ? (selectedVersion.category_stats?.total_count || 0) > 50
    : false

  // The primary chart (field counts, or category item counts for category
  // models) shares a row with the accuracy card. When the primary chart needs
  // the full width, the accuracy card takes the full next row as well.
  const primaryChartIsFullWidth = isCategoryModel
    ? largeNumberOfCategories
    : largeNumberOfFields

  const canTrain =
    projectModel &&
    // Not currently training
    !trainingVersion &&
    // Not currently canceling training
    !versionInCancelingState &&
    // Must have at least 10 documents in the training batch
    projectModel.doc_counts.by_workflow_state_code.training_batch >= 10 &&
    // Is current version, or...
    (selectedVersion?.is_current ||
      // There is no current version and this is the most recent version
      (!currentVersion && selectedVersion?.id === lastVersion?.id))

  const estimatedAccuracy = selectedVersion?.accuracy
    ? `${Math.floor(+selectedVersion.accuracy * 100)}%`
    : '0%'

  const newFieldOccurrencesCount =
    selectedVersion &&
    selectedVersion.training_status !== 'training' &&
    !isCategoryModel
      ? getNewFieldOccurrencesCount(
          fieldCountsForTrainingBatch,
          selectedVersion,
        )
      : 0

  const newItemOccurrencesCount =
    selectedVersion &&
    selectedVersion.training_status !== 'training' &&
    isCategoryModel
      ? getNewItemOccurrencesCount(itemCountsForTrainingBatch, selectedVersion)
      : 0

  return projectModel ? (
    <Box sx={{ p: 4, maxWidth: 2000, m: '0 auto' }}>
      <ProjectBreadcrumb label="Dashboard" url="." />

      <Stack spacing={2} sx={{ zIndex: 1 }}>
        <Grid2 container spacing={2} columns={12}>
          <Grid2 xs={12}>
            <Stack
              direction="row"
              spacing={2}
              alignItems="center"
              justifyContent="space-between"
            >
              <Stack direction="row" spacing={2}>
                <Typography component="h1" variant="h6">
                  {projectModel.name}
                </Typography>
                {(selectedVersion?.training_status === 'trained' ||
                  selectedVersion?.training_status === 'completed') &&
                  !selectedVersion.is_current && (
                    <Button
                      variant="text"
                      onClick={() => {
                        updateCurrentVersion()
                      }}
                    >
                      Set as Active Version
                    </Button>
                  )}
              </Stack>

              <Stack direction="row" spacing={2}>
                {isSuperUser && !demoMode && (
                  <>
                    <CopyIDButton
                      stringToCopy={projectModel.id}
                      isSuperUser
                      label="Model ID"
                    />
                    {selectedVersion && (
                      <CopyIDButton
                        stringToCopy={selectedVersion.id}
                        isSuperUser
                        label="Version ID"
                      />
                    )}
                  </>
                )}
                {sortedVersions.length > 0 && (
                  <Select
                    value={selectedVersionId}
                    size="small"
                    sx={{ px: 1 }}
                    onChange={(event) => {
                      setSelectedVersionId(event.target.value)
                    }}
                  >
                    {sortedVersions.map((version) => (
                      <MenuItem key={version.id} value={version.id}>
                        <Stack direction="row" spacing={1} alignItems="center">
                          <Box>v{version.version}.0</Box>
                          <Typography
                            variant="caption"
                            sx={{
                              color:
                                trainingStatusColorMap[version.training_status],
                            }}
                          >
                            {version.is_current && '(active)'}
                            {validTrainingStatuses.includes(
                              version.training_status,
                            ) && '(training)'}
                            {version.training_status === 'error' && '(failed)'}
                            {version.training_status === 'canceling' &&
                              '(canceling)'}
                            {version.training_status === 'canceled' &&
                              '(canceled)'}
                          </Typography>
                        </Stack>
                      </MenuItem>
                    ))}
                  </Select>
                )}
              </Stack>
            </Stack>
          </Grid2>

          {sortedVersions.length === 0 && (
            <Grid2 xs={12}>
              <NotTrainedBanner />
            </Grid2>
          )}
          {(selectedVersion?.training_status === 'error' ||
            selectedVersion?.training_status === 'canceled') && (
            <Grid2 xs={12}>
              <TrainInfoBanner modelVersion={selectedVersion} />
            </Grid2>
          )}

          {/* ------------------------------- */}
          {/* ----------INFO BOXES----------- */}
          {/* ------------------------------- */}

          {/* Training Button */}
          <Grid2 xs={12} md={4}>
            <HandleTrainingButton
              onClick={handleCancelOrTrainModel}
              status={
                trainingVersion
                  ? 'training'
                  : versionInCancelingState
                  ? 'canceling'
                  : undefined
              }
              disabled={
                trainingVersion
                  ? selectedVersion?.id !== trainingVersion.id
                  : !canTrain
              }
              newOccurrencesCount={
                newFieldOccurrencesCount || newItemOccurrencesCount
              }
              model={projectModel}
            />
          </Grid2>

          {/* Estimated Accuracy */}
          <Grid2 xs={12} sm={6} md={2}>
            <GeneralInfoBox
              label="Estimated Accuracy"
              value={
                selectedVersion?.id === trainingVersion?.id
                  ? '---'
                  : estimatedAccuracy
              }
            >
              {selectedVersion && (
                <AccuracyColorBar currentVersion={selectedVersion} />
              )}
            </GeneralInfoBox>
          </Grid2>

          {/* Model Version */}
          <Grid2 xs={12} sm={6} md={2}>
            <GeneralInfoBox
              label={
                <Stack>
                  <Typography component="h2" variant="body2">
                    Model Version
                  </Typography>
                  <Typography variant="subtitle1">
                    {selectedVersion
                      ? prettifyDate(
                          selectedVersion.training_completed_at ||
                            selectedVersion.training_started_at,
                        )
                      : '---'}
                  </Typography>
                </Stack>
              }
              value={selectedVersion ? `${selectedVersion.version}.0` : '0.0'}
            />
          </Grid2>
          <Grid2 xs={12} sm={6} md={2}>
            <Card
              elevation={0}
              sx={{
                p: 1.5,
                display: 'flex',
                flexDirection: 'column',
                rowGap: 1.5,
                height: '100%',
              }}
            >
              <GeneralInfoBoxContent
                label="Documents Used to Train"
                shrinkValueSize
                value={
                  selectedVersion?.id === trainingVersion?.id
                    ? '---'
                    : selectedVersion?.document_count || '---'
                }
              />
              <GeneralInfoBoxContent
                label="Training Duration"
                shrinkValueSize
                value={
                  selectedVersion &&
                  selectedVersion.id !== trainingVersion?.id &&
                  selectedVersion.training_started_at &&
                  selectedVersion.training_completed_at
                    ? convertSecondsToHours(
                        Math.round(
                          (new Date(
                            selectedVersion.training_completed_at,
                          ).getTime() -
                            new Date(
                              selectedVersion.training_started_at,
                            ).getTime()) /
                            1000,
                        ),
                      )
                    : '---'
                }
              />
            </Card>
          </Grid2>
          <Grid2 xs={12} sm={6} md={2}>
            <Card
              elevation={0}
              sx={{
                height: '100%',
              }}
            >
              {/* Number in Available to Train */}
              <CardActionArea
                component={Link}
                to={`../data?workflow_state=${availableForTrainingState?.id}&workflow=${workflows[0]?.id}`}
                sx={{ px: 1.5, pt: 1.5, pb: 0.75 }}
              >
                <GeneralInfoBoxContent
                  label="Available to Train"
                  shrinkValueSize
                  value={
                    projectModel.doc_counts.by_workflow_state_code
                      .available_for_training
                  }
                />
              </CardActionArea>

              {/* Number in Training Batch */}
              <CardActionArea
                component={Link}
                to={`../data?workflow_state=${trainingBatchState?.id}&workflow=${workflows[0]?.id}`}
                sx={{ px: 1.5, pb: 1.5, pt: 0.75 }}
              >
                <GeneralInfoBoxContent
                  label="Training Batch"
                  shrinkValueSize
                  value={
                    projectModel.doc_counts.by_workflow_state_code
                      .training_batch
                  }
                />
              </CardActionArea>
            </Card>
          </Grid2>
          {/* --------------------------- */}
          {/* ----------CHARTS----------- */}
          {/* --------------------------- */}

          {/* Field Count Comparison Chart */}
          {!isCategoryModel && selectedVersion && (
            <Grid2 xs={12} md={primaryChartIsFullWidth ? 12 : 7}>
              <FieldCountComparisonChartCard currentVersion={selectedVersion} />
            </Grid2>
          )}

          {isCategoryModel && selectedVersion && (
            <Grid2 xs={12} md={primaryChartIsFullWidth ? 12 : 7}>
              <CategoryItemCountChartCard modelVersion={selectedVersion} />
            </Grid2>
          )}

          {/* Model Versions Accuracy Chart */}
          {selectedVersion && (
            <Grid2 xs={12} md={primaryChartIsFullWidth ? 12 : 5}>
              <ModelVersionAccuracyCard modelVersions={sortedVersions} />
            </Grid2>
          )}

          {/* Model Version Field Accuracy Chart */}
          {selectedVersion && !isCategoryModel && (
            <Grid2 xs={12} md={6}>
              <ModelVersionFieldAccuracyCard modelVersions={sortedVersions} />
            </Grid2>
          )}

          {/* Field Accuracy Chart */}
          {selectedVersion && !isCategoryModel && (
            <Grid2 xs={12} md={6}>
              <FieldAccuracyChartCard modelVersion={selectedVersion} />
            </Grid2>
          )}
        </Grid2>
      </Stack>
    </Box>
  ) : null
}
