import {
  ProjectLinkedModel,
  ProjectLinkedModelInfoType,
} from '@/types/project-linked-models'
import { ProjectModelVersionInfoType } from '@/types/project-models'
import { useTheme } from '@mui/material'
import { useMemo } from 'react'
import { ProjectGridField } from '@/types/fields'
import { Node } from '@xyflow/react'
import { ocrResultsNode } from './static-nodes'
import {
  getInputOutputInfoTypes,
  getVersionInfoTypesByNamespace,
} from './helpers'

/**
 * Inputs accepted by the `useModelVersionNodes` hook defined in this file.
 */
type UseModelVersionNodesOptions = {
  /**
   * Version info types of the model. They are grouped by namespace; the
   * `base` group becomes `infotype` nodes and every other group becomes a
   * `tableNamespace` node. Also used to decide which grid fields have a mapping.
   */
  projectModelVersionInfoTypes: ProjectModelVersionInfoType[]
  /**
   * The linked model being displayed. Supplies the id and `project_model` of
   * the `model` node and is attached to info-type and table-namespace nodes.
   */
  projectLinkedModel: ProjectLinkedModel
  /**
   * Grid fields of the project. Each top-level entry becomes a `field` node;
   * nested `fields` are only inspected for info-type mappings.
   */
  projectGridFields: ProjectGridField[]
  /**
   * Linked-model info types. Matched to version info types through
   * `info_type_id` and passed through to table-namespace nodes.
   */
  projectLinkedModelInfoTypes: ProjectLinkedModelInfoType[]
}

// Inline style shared by the info-type, table-namespace and field nodes built
// below: transparent background, zero-width border, no padding and visible
// overflow. Each node spreads this object and adds its own text colour.
const commonStyle = {
  background: 'transparent',
  border: 'solid 0px transparent',
  padding: 0,
  overflow: 'visible',
}

/**
 * Builds the flow-diagram nodes that describe one linked model version.
 *
 * The returned list contains, in this order:
 * 1. input info-type nodes,
 * 2. output info-type nodes, ordered by the position of the grid field they
 *    are mapped to,
 * 3. one `tableNamespace` node per non-`base` namespace,
 * 4. one `field` node per top-level grid field,
 * 5. the `model` node itself,
 * 6. the static OCR-results node, only when the model type code is `'NER'`.
 *
 * Every node is created at position (0, 0) and is not draggable; this hook
 * does no layout (presumably positions are assigned by the consumer — confirm).
 *
 * @param options - See {@link UseModelVersionNodesOptions}.
 * @returns A memoised array of nodes; it is rebuilt only when one of the
 * inputs or the theme's primary text colour changes.
 */
export default function useModelVersionNodes({
  projectModelVersionInfoTypes,
  projectLinkedModel,
  projectGridFields,
  projectLinkedModelInfoTypes,
}: UseModelVersionNodesOptions): Node[] {
  const theme = useTheme()

  // Split the namespace groups into the single `base` group (plain info
  // types) and all remaining groups (rendered as tables).
  const { base, tables } = useMemo(() => {
    const namespaceGroups = Object.values(
      getVersionInfoTypesByNamespace(projectModelVersionInfoTypes),
    )
    return {
      base: namespaceGroups.find(
        (namespaceDetails) => namespaceDetails.namespace.name === 'base',
      ),
      tables: namespaceGroups.filter(
        (namespaceDetails) => namespaceDetails.namespace.name !== 'base',
      ),
    }
  }, [projectModelVersionInfoTypes])

  // A model whose type code is 'NER' is treated as a labeling model: it always
  // shows a target handle and gets the extra OCR-results node.
  const isLabelingModel =
    projectLinkedModel.project_model?.project_model_type?.code === 'NER'

  // Grid fields (top-level and nested) whose info type is also one of the
  // model version's info types, keyed by field id.
  // NOTE(review): consumers below rely on the key insertion order of this
  // object (parent field first, then its nested fields). That holds for
  // non-integer-like string ids — confirm the id format.
  const hasMappingMap = useMemo(() => {
    // Built once per run so each lookup is O(1) instead of a scan of
    // `projectModelVersionInfoTypes` for every field.
    const mappedInfoTypeIds = new Set<string>(
      projectModelVersionInfoTypes.map(
        (verInfoType) => verInfoType.info_type.id,
      ),
    )

    function hasMapping(fieldInfoTypeId?: string): boolean {
      if (!fieldInfoTypeId) return false
      return mappedInfoTypeIds.has(fieldInfoTypeId)
    }

    const mappedFields: Record<string, ProjectGridField> = {}
    for (const field of projectGridFields) {
      if (hasMapping(field.info_type?.id)) {
        mappedFields[field.id] = field
      }
      for (const tableField of field.fields ?? []) {
        if (hasMapping(tableField.info_type?.id)) {
          mappedFields[tableField.id] = tableField
        }
      }
    }
    return mappedFields
  }, [projectGridFields, projectModelVersionInfoTypes])

  return useMemo(() => {
    // One node per info type in the `base` namespace (none when there is no
    // `base` group).
    const infoTypeNodes =
      base?.versionInfoTypes.map((versionInfoType) => {
        return {
          id: versionInfoType.id,
          type: 'infotype',
          position: {
            x: 0,
            y: 0,
          },
          data: {
            flow: versionInfoType.flow,
            label: versionInfoType.info_type.name,
            showSourceHandle: true,
            // Only output info types can be the target of an edge.
            showTargetHandle: versionInfoType.flow === 'out',
            versionInfoType,
            projectLinkedModel,
            // Undefined when the linked model has no entry for this info type.
            projectLinkedModelInfoType: projectLinkedModelInfoTypes.find(
              (linkedModelInfoType) =>
                linkedModelInfoType.info_type_id ===
                versionInfoType.info_type.id,
            ),
          },
          draggable: false,
          style: {
            ...commonStyle,
            color: theme.palette.text.primary,
          },
        }
      }) ?? []

    // One node per non-`base` namespace.
    const tableNamespaceNodes = tables.map((namespaceDetails) => {
      return {
        id: namespaceDetails.namespace.id,
        type: 'tableNamespace',
        position: {
          x: 0,
          y: 0,
        },
        data: {
          namespaceDetails,
          // The table takes the flow of its first info type.
          // NOTE(review): assumes every namespace group holds at least one
          // version info type — confirm in getVersionInfoTypesByNamespace.
          flow: namespaceDetails.versionInfoTypes[0].flow,
          projectLinkedModel,
          projectLinkedModelInfoTypes,
        },
        draggable: false,
        style: {
          ...commonStyle,
          color: theme.palette.text.primary,
        },
      }
    })

    // One node per top-level grid field; nested fields get no node of their own.
    const fieldNodes = projectGridFields.map((field) => {
      return {
        id: field.id,
        type: 'field',
        position: {
          x: 0,
          y: 0,
        },
        data: {
          field,
          hasMappingMap,
        },
        draggable: false,
        style: {
          ...commonStyle,
          color: theme.palette.text.primary,
        },
      }
    })

    // The model accepts incoming edges when it has at least one input info
    // type, or unconditionally when it is a labeling model.
    const hasInputs =
      infoTypeNodes.some(({ data }) => data.flow === 'in') || isLabelingModel

    const modelNode = {
      id: projectLinkedModel.id,
      type: 'model',
      data: {
        showTargetHandle: hasInputs,
        showSourceHandle: true,
        projectModel: projectLinkedModel.project_model,
      },
      draggable: false,
      selectable: false,
      position: {
        x: 0,
        y: 0,
      },
      style: {
        background: 'transparent',
        color: theme.palette.text.primary,
        border: 'solid 0px transparent',
      },
    }

    const { inputInfoTypeNodes, outputInfoTypeNodes } =
      getInputOutputInfoTypes(infoTypeNodes)

    // Position of each mapped field keyed by its info type id. Only the first
    // field for a given info type is recorded, which matches what a
    // `findIndex` scan over the same list would return.
    const fieldRankByInfoTypeId = new Map<string, number>()
    Object.values(hasMappingMap).forEach((field, index) => {
      const infoTypeId = field.info_type?.id
      if (infoTypeId !== undefined && !fieldRankByInfoTypeId.has(infoTypeId)) {
        fieldRankByInfoTypeId.set(infoTypeId, index)
      }
    })

    // Rank of an output node = position of the mapped field sharing its info
    // type, or -1 when no mapped field matches (such nodes therefore sort first).
    const rankOf = (node: (typeof outputInfoTypeNodes)[number]): number => {
      // @ts-expect-error -- Not going to bother typing this
      const infoTypeId = node.data.versionInfoType?.info_type?.id
      return fieldRankByInfoTypeId.get(infoTypeId) ?? -1
    }

    // Sort a copy: `Array.prototype.sort` works in place and would otherwise
    // mutate the array handed back by the helper.
    const sortedOutputInfoTypeNodes = [...outputInfoTypeNodes].sort(
      (a, b) => rankOf(a) - rankOf(b),
    )

    const allNodes = [
      ...inputInfoTypeNodes,
      ...sortedOutputInfoTypeNodes,
      ...tableNamespaceNodes,
      ...fieldNodes,
      modelNode,
    ] as Node[]

    if (isLabelingModel) {
      allNodes.push(ocrResultsNode)
    }

    return allNodes
  }, [
    base?.versionInfoTypes,
    hasMappingMap,
    isLabelingModel,
    projectGridFields,
    projectLinkedModel,
    projectLinkedModelInfoTypes,
    tables,
    theme.palette.text.primary,
  ])
}
