import { ProjectModelVersion } from '@/types/project-models'

// Fixed UUID strings, one per model type (OCR, NER, category).
// NOTE(review): presumably these match model-type records defined outside
// this file (e.g. backend data) — confirm before changing any value.
export const OCR_MODEL_TYPE_ID = '234e0789-4b35-4fa0-8140-e0d0f682e9c2'
export const NER_MODEL_TYPE_ID = '60b499f1-49ef-4986-92c4-4095797c0492'
export const CATEGORY_MODEL_TYPE_ID = '4263c7ff-d64e-46a1-9e8a-28ece0e8ea1e'

/**
 * `training_status` values that `getTrainingVersion` accepts when looking for
 * a version that counts as "training".
 *
 * Note that `'canceling'` is not in this list; that status is matched
 * separately by `getVersionInCancelingState`.
 */
export const validTrainingStatuses = [
  'training',
  'starting',
  'starting_gpu',
  'evaluating',
  'terminating_gpu',
]

/**
 * Looks up the first version flagged as current.
 *
 * @param versions - Versions to scan, in order; defaults to an empty list
 *   when omitted or `undefined`.
 * @returns The first entry whose `is_current` is truthy, or `undefined` when
 *   no entry qualifies.
 */
export function getCurrentVersion(versions: ProjectModelVersion[] = []) {
  for (const candidate of versions) {
    if (candidate.is_current) {
      return candidate
    }
  }

  return undefined
}

/**
 * Looks up the first version whose training is being canceled.
 *
 * @param versions - Versions to scan, in order; defaults to an empty list
 *   when omitted or `undefined`.
 * @returns The first entry whose `training_status` is exactly `'canceling'`,
 *   or `undefined` when there is none.
 */
export function getVersionInCancelingState(
  versions: ProjectModelVersion[] = [],
) {
  const cancelingStatus = 'canceling'
  const isCanceling = ({ training_status }: ProjectModelVersion) =>
    training_status === cancelingStatus

  return versions.find(isCanceling)
}

/**
 * Looks up the first version whose `training_status` is one of the values in
 * `validTrainingStatuses`.
 *
 * @param versions - Versions to scan, in order; defaults to an empty list
 *   when omitted or `undefined`.
 * @returns The first matching entry, or `undefined` when none matches.
 */
export function getTrainingVersion(versions: ProjectModelVersion[] = []) {
  for (const candidate of versions) {
    if (validTrainingStatuses.includes(candidate.training_status)) {
      return candidate
    }
  }

  return undefined
}

/**
 * Computes how many occurrences in the training batch exceed the trained
 * count recorded in the given model version's `field_stats`.
 *
 * @param fieldCountsForTrainingBatch - Map whose values are summed to get the
 *   batch total.
 * @param currentModelVersion - Version whose
 *   `field_stats.total_trained_count` is subtracted from the batch total; a
 *   missing or falsy count is treated as 0.
 * @returns The batch total minus the trained count, floored at 0. Returns 0
 *   immediately when the batch total is 0.
 */
export function getNewFieldOccurrencesCount(
  fieldCountsForTrainingBatch: Record<string, number>,
  currentModelVersion: ProjectModelVersion,
) {
  let batchSum = 0
  for (const occurrences of Object.values(fieldCountsForTrainingBatch)) {
    batchSum = batchSum + occurrences
  }

  // An empty batch short-circuits before the model version is read at all.
  if (batchSum === 0) {
    return 0
  }

  const alreadyTrained =
    currentModelVersion.field_stats?.total_trained_count || 0
  const surplus = batchSum - alreadyTrained

  return Math.max(surplus, 0)
}

/**
 * Computes how many occurrences in the training batch exceed the trained
 * count recorded in the given model version's `category_stats`.
 *
 * @param itemCountsForTrainingBatch - Map whose values are summed to get the
 *   batch total.
 * @param currentModelVersion - Version whose
 *   `category_stats.total_trained_count` is subtracted from the batch total;
 *   a missing or falsy count is treated as 0.
 * @returns The batch total minus the trained count, floored at 0. Returns 0
 *   immediately when the batch total is 0.
 */
export function getNewItemOccurrencesCount(
  itemCountsForTrainingBatch: Record<string, number>,
  currentModelVersion: ProjectModelVersion,
) {
  let batchSum = 0
  for (const occurrences of Object.values(itemCountsForTrainingBatch)) {
    batchSum = batchSum + occurrences
  }

  // An empty batch short-circuits before the model version is read at all.
  if (batchSum === 0) {
    return 0
  }

  const alreadyTrained =
    currentModelVersion.category_stats?.total_trained_count || 0
  const surplus = batchSum - alreadyTrained

  return Math.max(surplus, 0)
}
