import { ProjectGridField } from '@/types/fields'
import {
  ProjectLinkedModel,
  ProjectLinkedModelInfoType,
} from '@/types/project-linked-models'
import { ProjectModelVersionInfoType } from '@/types/project-models'
import { useMemo } from 'react'
import { Edge } from '@xyflow/react'
import { useTheme } from '@mui/material'

/** Inputs consumed by {@link useModelVersionEdges} to derive the graph edges. */
type UseModelVersionEdgesOptions = {
  /** The linked model whose node every input/output edge attaches to. */
  projectLinkedModel: ProjectLinkedModel
  /** The model version's info types; each one's `flow` marks it as input (`'in'`) or output (`'out'`). */
  projectModelVersionInfoTypes: ProjectModelVersionInfoType[]
  /** Grid fields (including nested table columns in `fields`) that outputs may be wired to. */
  projectGridFields: ProjectGridField[]
  /** Per-linked-model info type settings; an entry with `use === false` is drawn as an unused edge. */
  projectLinkedModelInfoTypes: ProjectLinkedModelInfoType[]
}

/**
 * Builds the edges between the linked model node and its input/output info
 * type nodes.
 *
 * - `flow === 'in'`: edge from the info type's node to the model.
 * - any other flow: edge from the model to the info type's node.
 *
 * Info types outside the `'base'` namespace are drawn on a shared namespace
 * node. For outputs, only one model -> namespace edge is emitted per
 * namespace. Inputs from such a namespace each keep their own edge and leave
 * the namespace node through a handle named after the version info type id.
 */
function buildInputOutputToModelEdges(
  versionInfoTypes: ProjectModelVersionInfoType[],
  linkedModelId: ProjectLinkedModel['id'],
): Edge[] {
  const edges: Edge[] = []
  // Table namespaces that already have a model -> namespace output edge.
  // Tracked explicitly (instead of matching edge-id suffixes) so that an input
  // edge from the same namespace can no longer suppress the output edge.
  const namespacesWithOutputEdge = new Set<unknown>()

  for (const versionInfoType of versionInfoTypes) {
    const { namespace } = versionInfoType.info_type
    const isInTableNamespace = namespace.name !== 'base'
    const isInput = versionInfoType.flow === 'in'

    if (versionInfoType.flow === 'out' && isInTableNamespace) {
      if (namespacesWithOutputEdge.has(namespace.id)) continue
      namespacesWithOutputEdge.add(namespace.id)
    }

    const infoTypeNodeId = isInTableNamespace
      ? namespace.id
      : versionInfoType.id
    // Id format: `<model>-<versionInfoType>`, with the namespace id appended
    // for table namespaces (kept identical to the previous implementation).
    const edgeId = isInTableNamespace
      ? `${linkedModelId}-${versionInfoType.id}${namespace.id}`
      : `${linkedModelId}-${versionInfoType.id}`

    edges.push({
      id: edgeId,
      source: isInput ? infoTypeNodeId : linkedModelId,
      target: isInput ? linkedModelId : infoTypeNodeId,
      sourceHandle:
        isInTableNamespace && isInput ? versionInfoType.id : undefined,
      type: 'default',
      animated: true,
      selectable: false,
    })
  }

  return edges
}

/**
 * Finds the grid field bound to `infoTypeId`. Fields are scanned in order.
 * For each one, its own `info_type` is checked first, then any nested table
 * columns in `fields`.
 *
 * @returns the top-level field to connect to, plus the matching column when
 *   the match was inside a table field. Returns `undefined` when no field is
 *   bound to the info type.
 */
function findTargetField(
  gridFields: ProjectGridField[],
  infoTypeId: ProjectModelVersionInfoType['info_type']['id'],
): { field: ProjectGridField; tableField?: ProjectGridField } | undefined {
  for (const field of gridFields) {
    if (field.info_type?.id === infoTypeId) return { field }
    const tableField = field.fields?.find(
      (column) => column.info_type?.id === infoTypeId,
    )
    if (tableField) return { field, tableField }
  }
  return undefined
}

/**
 * Builds the edges from each output info type (`flow === 'out'`) to the grid
 * field it populates. Outputs without a matching field produce no edge.
 * Outputs explicitly disabled for this linked model (`use === false`) get a
 * static, dashed edge stroked with `unusedStroke`.
 */
function buildOutputToFieldEdges(
  versionInfoTypes: ProjectModelVersionInfoType[],
  gridFields: ProjectGridField[],
  linkedModelInfoTypes: ProjectLinkedModelInfoType[],
  unusedStroke: string,
): Edge[] {
  return versionInfoTypes.flatMap((versionInfoType): Edge[] => {
    if (versionInfoType.flow !== 'out') return []

    const match = findTargetField(gridFields, versionInfoType.info_type.id)
    if (!match) return []
    const { field, tableField } = match

    const { namespace } = versionInfoType.info_type
    const isInTableNamespace = namespace.name !== 'base'
    const sourceHandle = isInTableNamespace ? versionInfoType.id : undefined
    // A column inside a table field is addressed through a handle on the table node.
    const targetHandle = tableField?.id
    const edgeId = `${sourceHandle || versionInfoType.id}-${
      targetHandle || field.id
    }`

    // Info types count as used unless explicitly switched off.
    const isUsed =
      linkedModelInfoTypes.find(
        (linkedInfoType) =>
          linkedInfoType.info_type_id === versionInfoType.info_type.id,
      )?.use !== false

    return [
      {
        id: edgeId,
        // A table-namespace output starts at the namespace node; a base output
        // starts at its own info type node.
        source: isInTableNamespace ? namespace.id : versionInfoType.id,
        // The target is always the field node, even when it is a table node.
        target: field.id,
        sourceHandle,
        targetHandle,
        type: 'default',
        animated: isUsed,
        style: {
          stroke: isUsed ? undefined : unusedStroke,
          strokeWidth: isUsed ? undefined : 1,
          strokeDasharray: isUsed ? undefined : '5',
        },
        selectable: false,
      },
    ]
  })
}

/**
 * React hook returning every edge for a linked model's flow graph:
 * info type <-> model edges, output -> grid field edges, and, for NER
 * (labeling) models, an extra edge from the `'ocrResults'` node to the model.
 * Each group is memoized on the inputs it reads.
 */
export default function useModelVersionEdges({
  projectModelVersionInfoTypes,
  projectLinkedModel,
  projectGridFields,
  projectLinkedModelInfoTypes,
}: UseModelVersionEdgesOptions): Edge[] {
  const theme = useTheme()
  const linkedModelId = projectLinkedModel.id
  const unusedStroke = theme.palette.divider
  const isLabelingModel =
    projectLinkedModel.project_model?.project_model_type?.code === 'NER'

  const inputOutputToModelEdges = useMemo(
    () =>
      buildInputOutputToModelEdges(projectModelVersionInfoTypes, linkedModelId),
    [linkedModelId, projectModelVersionInfoTypes],
  )

  const outputToFieldEdges = useMemo(
    () =>
      buildOutputToFieldEdges(
        projectModelVersionInfoTypes,
        projectGridFields,
        projectLinkedModelInfoTypes,
        unusedStroke,
      ),
    [
      projectGridFields,
      projectLinkedModelInfoTypes,
      projectModelVersionInfoTypes,
      unusedStroke,
    ],
  )

  return useMemo(() => {
    const allEdges: Edge[] = [...inputOutputToModelEdges, ...outputToFieldEdges]
    if (isLabelingModel) {
      allEdges.push({
        id: `ocrResults-${linkedModelId}`,
        source: 'ocrResults',
        target: linkedModelId,
        type: 'default',
        animated: true,
        selectable: false,
      })
    }
    return allEdges
  }, [inputOutputToModelEdges, isLabelingModel, outputToFieldEdges, linkedModelId])
}
