import { Dispatch, SetStateAction } from 'react'
import { Node } from '@xyflow/react'
import distributeNodesVertically from '../react-flow/distributeNodesVertically'
import {
  getDynamicDistanceBetweenGroups,
  getNodesByType,
  getWidestNodeFromGroup,
} from '../react-flow/helpers'

// Vertical gap passed as `yGap` to distributeNodesVertically when stacking the
// input, output and field columns (presumably pixels in flow coordinates — confirm
// against the helper).
const Y_GAP = 12

/** Options accepted by `useAutoLayout`. */
type UseAutoLayoutOptions = {
  /**
   * State setter for the flow's node list. The layout is applied through its
   * functional-updater form, so it always works from the latest nodes.
   */
  setNodes: Dispatch<SetStateAction<Node[]>>
}

/**
 * Creates a callback that auto-arranges the flow's nodes into columns:
 * inputs to the left of the model, then the model, outputs and fields to the right.
 *
 * The returned callback schedules the layout with a zero-delay `setTimeout` and
 * applies it through `setNodes`' updater form. Within a column, nodes are stacked by
 * `distributeNodesVertically`; input (except OCR results), output and field nodes are
 * given the measured width of the widest node in their group.
 *
 * @param options.setNodes - State setter for the flow's node list.
 * @returns A function that, when called, queues one layout pass.
 */
export default function useAutoLayout({ setNodes }: UseAutoLayoutOptions) {
  return () => {
    // Deferred to a later task (presumably so node measurements are available — confirm).
    setTimeout(() => {
      setNodes((prev) => {
        // Lay out shallow copies: state updaters must stay pure, and the previous
        // node objects must not be mutated in place.
        const nodesByType = getNodesByType(
          prev.map((node) => ({ ...node, position: { ...node.position } })),
        )
        nodesByType.infotype ??= [] // If there aren't any info types in the project
        nodesByType.field ??= [] // If there aren't any fields in the project
        nodesByType.tableNamespace ??= [] // If there aren't any table namespaces in the project
        nodesByType.ocrResults ??= [] // If there aren't any OCR results in the project
        nodesByType.model ??= [] // If there isn't any model in the project

        const inputInfoTypeNodes = nodesByType.infotype.filter(
          ({ data }) => data.flow === 'in',
        )
        const outputInfoTypeNodes = nodesByType.infotype.filter(
          ({ data }) => data.flow === 'out',
        )
        const inputTableNamespaceNodes = nodesByType.tableNamespace.filter(
          ({ data }) => data.flow === 'in',
        )
        const outputTableNamespaceNodes = nodesByType.tableNamespace.filter(
          ({ data }) => data.flow === 'out',
        )

        const allInputNodes = [
          ...nodesByType.ocrResults,
          ...inputTableNamespaceNodes,
          ...inputInfoTypeNodes,
        ]

        const allOutputNodes = [
          ...outputInfoTypeNodes,
          ...outputTableNamespaceNodes,
        ]

        // OCR result nodes are laid out with the inputs but keep their own width
        const inputNodesToMatchWidth = [
          ...inputTableNamespaceNodes,
          ...inputInfoTypeNodes,
        ]

        // Make the nodes of each group as wide as the widest node in that group.
        // The widest-node lookups are optional-chained so an empty group cannot
        // cause a crash while another group is being positioned.
        const widestInputNode = getWidestNodeFromGroup(inputNodesToMatchWidth)
        const widestOutputNode = getWidestNodeFromGroup(allOutputNodes)
        const widestModelNode = getWidestNodeFromGroup(nodesByType.model)
        const widestFieldNode = getWidestNodeFromGroup(nodesByType.field)

        const inputGroupWidth = widestInputNode?.measured?.width
        const outputGroupWidth = widestOutputNode?.measured?.width
        const fieldGroupWidth = widestFieldNode?.measured?.width
        const modelWidth = widestModelNode?.measured?.width || 0

        inputNodesToMatchWidth.forEach((node) => {
          node.width = inputGroupWidth
        })
        allOutputNodes.forEach((node) => {
          node.width = outputGroupWidth
        })
        nodesByType.field.forEach((node) => {
          node.width = fieldGroupWidth
        })

        // Distribute every column vertically
        const inputs = distributeNodesVertically({
          nodes: allInputNodes,
          yGap: Y_GAP,
          startingY: 0,
        })
        const outputs = distributeNodesVertically({
          nodes: allOutputNodes,
          yGap: Y_GAP,
          startingY: 0,
        })
        const fields = distributeNodesVertically({
          nodes: nodesByType.field,
          yGap: Y_GAP,
          startingY: 0,
        })
        // No yGap is passed for the model column (the helper's own default applies)
        const modelNodes = distributeNodesVertically({
          nodes: nodesByType.model,
          startingY: 0,
        })

        // MARK: Update distance between the model and outputs
        // (must be set before outputs-to-fields, which reads the outputs' x position)
        const modelToOutputsDistance = getDynamicDistanceBetweenGroups(
          nodesByType.model,
          outputs,
        )
        const outputsX = modelWidth + modelToOutputsDistance
        outputs.forEach((node) => {
          node.position.x = outputsX
        })

        // MARK: Update distance between outputs and fields
        const outputsToFieldDistance = getDynamicDistanceBetweenGroups(
          outputs,
          nodesByType.field,
        )
        // In case we don't have any outputs for some reason, use the model width * 2
        const outputsColumnX = outputs[0]?.position.x || modelWidth * 2 || 0
        const fieldsX =
          (outputGroupWidth || 0) + outputsColumnX + outputsToFieldDistance
        fields.forEach((node) => {
          node.position.x = fieldsX
        })

        // MARK: Update distance between inputs and the model
        if (inputs.length > 0) {
          const inputsToModelDistance = getDynamicDistanceBetweenGroups(
            inputs,
            nodesByType.model,
            0,
          )
          // Each input uses its own width, so the column's right edges line up
          inputs.forEach((node) => {
            node.position.x =
              0 - (node.measured?.width || 0) - inputsToModelDistance
          })
        }

        // NOTE(review): only the groups handled above are returned. Nodes of any other
        // type, or info type / table namespace nodes whose flow is neither 'in' nor
        // 'out', are not part of the new state — confirm this is intended.
        return [...inputs, ...outputs, ...fields, ...modelNodes]
      })
    }, 0)
  }
}
