import { Container, Stack } from '@mui/material'
import { useQueryClient } from '@tanstack/react-query'
import { ProjectModel, ProjectModelVersion } from '@/types/project-models'
import { useGetProjectModels } from '@/service-library/hooks/project-models'
import { showErrorSnackbar } from '@/utils/snackbars'
import { useNotifications } from '@/components/notifications/NotificationProvider'
import { useProjectContext } from '@/components/project-dashboard/ProjectProvider'
import ModelCard from './ModelCard'
import { OCR_MODEL_TYPE_ID } from './helpers'

/**
 * Returns a copy of `models` in which the version matching `updatedVersion`
 * (located by `project_model_id`, then by `id`) is replaced.
 *
 * Only the touched model and its versions array are copied; every other model
 * keeps its original reference. Cached query data is never mutated in place,
 * so React Query's structural sharing and memoized children can detect the
 * change. If no matching model or version exists, `models` is returned as-is.
 */
function replaceModelVersion(
  models: ProjectModel[],
  updatedVersion: ProjectModelVersion,
): ProjectModel[] {
  const modelIndex = models.findIndex(
    ({ id }) => id === updatedVersion.project_model_id,
  )
  if (modelIndex === -1) return models

  const model = models[modelIndex]
  const versionIndex = model.project_model_versions.findIndex(
    ({ id }) => id === updatedVersion.id,
  )
  if (versionIndex === -1) return models

  const updatedVersions = [...model.project_model_versions]
  updatedVersions[versionIndex] = updatedVersion

  const updatedModels = [...models]
  updatedModels[modelIndex] = {
    ...model,
    project_model_versions: updatedVersions,
  }
  return updatedModels
}

/**
 * Lists the project's top-level models (those without a parent model),
 * excluding the OCR model type, as a stack of `ModelCard`s.
 *
 * Subscribes to `training_status` notifications. Each incoming model version
 * is patched into the cached model list, and an error snackbar is shown when
 * training fails.
 */
export default function ModelsList() {
  const { project } = useProjectContext()
  const queryClient = useQueryClient()

  const { projectModels, queryKey } = useGetProjectModels({
    filters: {
      limit: '1000',
      project_id: project.id,
      fields__include: 'project_model_versions',
      project_model_versions__fields__only: 'id,training_status,is_current',
      parent_model_id__isnull: 'true',
      'project_model_type_id!': OCR_MODEL_TYPE_ID, // Filter out the OCR model type since we don't let the user manage it
    },
  })

  // Use the compact card layout once the list gets long.
  const showDenseCard = projectModels.length > 5

  /** Patches one model version into the cached model list, immutably. */
  async function updateModelVersionInCache(
    updatedVersion: ProjectModelVersion,
  ) {
    // Cancel any in-flight fetch for this key so it can't overwrite the
    // manual update with older data when it resolves.
    await queryClient.cancelQueries({ queryKey })
    queryClient.setQueryData<ProjectModel[]>(queryKey, (previous) =>
      previous ? replaceModelVersion(previous, updatedVersion) : previous,
    )
  }

  useNotifications({
    keys: ['training_status'],
    callback: ({ model_version }) => {
      if (!model_version) return
      // Fire-and-forget: the notification callback does not wait for the cache update.
      void updateModelVersionInCache(model_version)
      if (model_version.training_status === 'error')
        showErrorSnackbar('Training Failed')
    },
  })

  return (
    <Container maxWidth="sm" sx={{ pt: 2 }}>
      <Stack spacing={showDenseCard ? 1 : 2}>
        {projectModels.map((projectModel) => (
          <ModelCard
            key={projectModel.id}
            projectModel={projectModel}
            dense={showDenseCard}
          />
        ))}
      </Stack>
    </Container>
  )
}
