import { createDetail } from '../request-wrappers'
import { QueryKey, useQueryClient } from '@tanstack/react-query'
import { ProjectModel, ProjectModelVersion } from '@/types/project-models'
import useCreateDetail, {
  UseCreateDetailOptions,
} from '../core-hooks/useCreateDetail'
import { DetailQueryKeyOption } from '../query-types'

// TODO: update when pd-predict takes care of this call, too
/**
 * MARK: Action Hook: Train Model
 * Triggers training of a new version for a model via
 * `/v2/pd-predict/models/:id/start-training-new-version`.
 *
 * When `detailQueryKey` is provided, the cached `ProjectModel` under that key
 * is optimistically extended with a placeholder version (`id: 'training'`,
 * `training_status: 'training'`). The key is invalidated once the mutation
 * settles, so the placeholder gets replaced by server data.
 *
 * Caller-supplied `onMutate` and `onSettled` are composed with the hook's own
 * cache handling rather than replacing it.
 *
 * @param detailQueryKey - Query key of the model detail to update optimistically.
 * @param onMutate - Runs after the optimistic update; its return value is
 *   forwarded as the mutation context.
 * @param onSettled - Runs after the detail query has been invalidated.
 * @returns The mutation object plus `trainModel`, an alias of `mutateAsync`.
 */
export function useTrainModel({
  detailQueryKey,
  onMutate,
  onSettled,
  ...options
}: UseCreateDetailOptions<ProjectModel> & { detailQueryKey?: QueryKey } = {}) {
  const queryClient = useQueryClient()

  const mutation = useCreateDetail<ProjectModel>({
    serviceFn: ({ item }) => {
      return createDetail<ProjectModel>({
        url: `/v2/pd-predict/models/${item.id}/start-training-new-version`,
        item: {} as unknown as ProjectModel, // Only the id in the URL is needed; no body is sent
      })
    },
    onMutate: async (item) => {
      if (detailQueryKey) {
        // Stop in-flight fetches so they cannot overwrite the optimistic update
        await queryClient.cancelQueries({ queryKey: detailQueryKey })

        queryClient.setQueryData<ProjectModel>(detailQueryKey, (previous) => {
          if (!previous) return previous

          const versions = previous.project_model_versions ?? []
          return {
            ...previous,
            project_model_versions: [
              ...versions,
              {
                id: 'training',
                training_status: 'training',
                // NOTE(review): assumes versions are ordered so the last one has the highest number
                version: (versions.at(-1)?.version ?? 0) + 1,
                training_started_at: new Date().toUTCString(),
                project_model_version_fields: [],
                project_model_version_categories: [],
              } as unknown as ProjectModelVersion,
            ],
          }
        })
      }
      // Gotta call this after the optimistic update since the call depends on the changes made to the cache.
      // Returning its result forwards the caller's mutation context to onError/onSettled.
      return onMutate?.(item)
    },
    onSettled: (...args) => {
      if (detailQueryKey) {
        // Refetch so the optimistic placeholder is replaced by the server's data
        void queryClient.invalidateQueries({ queryKey: detailQueryKey })
      }
      // Compose rather than let `...options` override the invalidation above
      return onSettled?.(...args)
    },
    ...options,
  })

  return {
    trainModel: mutation.mutateAsync,
    ...mutation,
  }
}

/** Request payload for cancelling a training run. */
type CancelTrainingModelPayload = {
  project_model_version_id: string
}

/**
 * MARK: Action Hook: Cancel Training Model
 * Cancels training for a model version via `/v2/pd-predict/model_train/cancel`.
 *
 * When `detailQueryKey` is provided, the matching version in the cached
 * `ProjectModel` is optimistically marked `training_status: 'canceling'`.
 * The key is invalidated once the mutation settles.
 *
 * Caller-supplied `onMutate` and `onSettled` are composed with the hook's own
 * cache handling rather than replacing it.
 *
 * @param detailQueryKey - Query key of the model detail to update optimistically.
 * @param onMutate - Runs before the optimistic update; its return value is
 *   forwarded as the mutation context.
 * @param onSettled - Runs after the detail query has been invalidated.
 * @returns The mutation object plus `cancelTrainingModel`, an alias of `mutateAsync`.
 */
export function useCancelTrainingModel({
  detailQueryKey,
  onMutate,
  onSettled,
  ...options
}: UseCreateDetailOptions<CancelTrainingModelPayload> &
  DetailQueryKeyOption = {}) {
  const queryClient = useQueryClient()

  const mutation = useCreateDetail<CancelTrainingModelPayload>({
    serviceFn: ({ item }) => {
      return createDetail<CancelTrainingModelPayload>({
        url: '/v2/pd-predict/model_train/cancel',
        item,
      })
    },
    onMutate: async (item) => {
      // The caller's hook runs before the optimistic update (original ordering kept)
      const context = onMutate?.(item)

      if (detailQueryKey) {
        // Stop in-flight fetches so they cannot overwrite the optimistic update
        await queryClient.cancelQueries({ queryKey: detailQueryKey })
        // Optimistically update the cache so the UI updates immediately with the new training status
        queryClient.setQueryData<ProjectModel>(detailQueryKey, (previous) => {
          if (!previous) return previous
          return {
            ...previous,
            project_model_versions: previous.project_model_versions?.map(
              (modelVersion) =>
                modelVersion.id === item.project_model_version_id
                  ? { ...modelVersion, training_status: 'canceling' }
                  : modelVersion,
            ),
          }
        })
      }

      // Forward the caller's mutation context to onError/onSettled
      return context
    },
    onSettled: (...args) => {
      if (detailQueryKey) {
        // Refetch so the optimistic 'canceling' status is replaced by the server's data
        void queryClient.invalidateQueries({ queryKey: detailQueryKey })
      }
      // Compose rather than let `...options` override the invalidation above
      return onSettled?.(...args)
    },
    ...options,
  })

  return {
    cancelTrainingModel: mutation.mutateAsync,
    ...mutation,
  }
}
