import { useCallback, useEffect, useRef, useState } from 'react'
import { WorkflowStateDocumentCount } from '@/types/document-workflow-states'
import { ProjectModelVersion } from '@/types/project-models'
import { ExportFile } from '@/services/export'
import generateUuid from '@/utils/generate-uuid'
import { useAuthentication } from '@/components/auth/AuthProvider'

/**
 * Handle of the "no pong received" watchdog armed by the most recent `ping`
 * and cancelled by `pong`. It lives at module level, so every socket opened
 * through this module shares the same slot. Undefined until the first ping.
 */
let tm: ReturnType<typeof setTimeout> | undefined = undefined

/**
 * Shape of a JSON message received on the notification socket.
 *
 * `action` and `updated_entity_ids` are always present. Which of the optional
 * payload fields accompany a given message presumably depends on `action` —
 * TODO confirm against the server contract.
 */
export type WebsocketData = {
  /** Message discriminator; the hook itself consumes `'pong'` replies. */
  action: string
  /** Entity IDs carried by the message (presumably the entities it refers to). */
  updated_entity_ids: string[]
  /** Optional error text sent by the server. */
  error?: string
  /** Optional export payload. */
  export_response?: ExportFile
  /** Optional model-version payload. */
  model_version?: ProjectModelVersion
  /** Optional per-workflow-state document counts. */
  workflow_state_counts?: WorkflowStateDocumentCount
}

/**
 * Sends a heartbeat ping on `ws` and arms a watchdog that closes the socket
 * unless `pong` cancels it within `timeoutMs`. Does nothing when the socket
 * is not open.
 *
 * Any watchdog still pending from an earlier ping is cancelled first.
 * Overwriting `tm` without clearing it would orphan that timer: `pong` could
 * no longer cancel it, so it would close the socket even after a pong arrived.
 *
 * @param ws - Socket to ping.
 * @param timeoutMs - How long to wait for a pong before closing (default 15 s).
 */
const ping = (ws: WebSocket, timeoutMs = 15000) => {
  if (ws.readyState !== ws.OPEN) return
  ws.send('{"route": "/ping"}')
  clearTimeout(tm)
  tm = setTimeout(() => ws.close(), timeoutMs)
}

/** Cancels the watchdog armed by the latest `ping`; the server answered in time. */
const pong = (): void => clearTimeout(tm)

/**
 * Decodes the payload (second, base64url-encoded segment) of a JWT.
 * The signature is NOT verified; use the result only for client-side hints
 * such as the `exp` claim.
 *
 * The payload bytes are decoded as UTF-8. A plain `atob` yields a Latin-1
 * "binary string", which would garble any non-ASCII claim value.
 *
 * @param token - Compact-serialized JWT (`header.payload.signature`).
 * @returns The parsed payload object.
 * @throws Error if the token has no payload segment; `atob`/`JSON.parse`
 *   errors propagate for an invalid payload.
 */
const decodeJWT = (token: string) => {
  const base64Url = token.split('.')[1]
  if (!base64Url) {
    throw new Error('Malformed JWT: missing payload segment')
  }
  // base64url -> base64; forgiving-base64 decoding in `atob` accepts the
  // missing '=' padding.
  const base64 = base64Url.replace(/-/g, '+').replace(/_/g, '/')
  const binary = atob(base64)
  const bytes = Uint8Array.from(binary, (char) => char.charCodeAt(0))
  return JSON.parse(new TextDecoder().decode(bytes))
}

export type UseWebsocketOptions = {
  /**
   * Handler for incoming WebSocket messages.
   *
   * The hook parses each message's `data` as JSON. Messages whose `action`
   * is `'pong'` are heartbeat replies; they are consumed by the hook and not
   * forwarded here. Every other message event is passed through unchanged.
   *
   * Pass a stable reference (e.g. wrapped in `useCallback`). The hook removes
   * and re-adds its message listener whenever this function's identity
   * changes, so an inline function would re-subscribe on every render.
   *
   * @param event - The WebSocket message event.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  onMessage: (event: MessageEvent<any>) => void
}

/** Delay between heartbeat pings on an open socket. */
const PING_INTERVAL_MS = 20000
/** How long before the ID token's `exp` a socket with a fresh token is opened. */
const TOKEN_REFRESH_MARGIN_MS = 2 * 60 * 1000
/** Consecutive unclean closes that are retried before the hook stops reconnecting. */
const MAX_UNCLEAN_RECONNECTS = 2
/** Largest delay `setTimeout` honours; larger values overflow and fire immediately. */
const MAX_TIMEOUT_MS = 2 ** 31 - 1

/**
 * React hook that keeps an authenticated connection to the notification
 * WebSocket (`VITE_NOTIFICATION_SOCKET`) alive.
 *
 * - Once the user is authenticated, a socket is opened with the current ID
 *   token passed as the WebSocket subprotocol.
 * - About two minutes before that token's `exp`, a new socket is opened with a
 *   freshly fetched token. Older sockets are closed only after the newest one
 *   fires `open`, so delivery has no gap.
 * - While a socket is open it is pinged every 20 s (`ping`). `pong` replies
 *   are consumed here and never reach `onMessage`.
 * - A clean close always reconnects. An unclean close reconnects at most
 *   `MAX_UNCLEAN_RECONNECTS` times in a row; the counter resets whenever a
 *   socket opens successfully.
 * - Every socket the hook created is closed when the component unmounts.
 *
 * NOTE(review): sockets are not closed when `authenticated` turns false —
 * confirm that is intended.
 *
 * @param options - See {@link UseWebsocketOptions}.
 * @returns The most recently created socket once it has fired `open` (it may
 *   have closed since), otherwise `null`.
 */
export default function useWebsocket({ onMessage }: UseWebsocketOptions) {
  const { authenticated, getIdToken, getFreshIdToken } = useAuthentication()
  // Consecutive unclean closes since the last successful `open`.
  const attemptsCountRef = useRef(0)
  // Every socket this hook created that has not fired `close` yet.
  const createdSocketsRef = useRef(new Set<WebSocket>())
  // True while unmounted, so a token fetch that resolves late opens nothing.
  const disposedRef = useRef(false)

  const [websocket, setWebsocket] = useState<WebSocket | null>(null)
  const [websocketId, setWebsocketId] = useState<string | null>(null)
  const [tokenExpiration, setTokenExpiration] = useState<Date | null>(null)
  const [activeWebsockets, setActiveWebsockets] = useState<
    Record<string, WebSocket>
  >({})

  // Mount/unmount bookkeeping: close everything this hook opened on unmount.
  useEffect(() => {
    disposedRef.current = false
    const createdSockets = createdSocketsRef.current
    return () => {
      disposedRef.current = true
      createdSockets.forEach((ws) => ws.close())
      createdSockets.clear()
    }
  }, [])

  const openSocket = useCallback(
    async (useRefreshedToken = false) => {
      if (!authenticated) return

      const fetchToken = useRefreshedToken
        ? getFreshIdToken
        : () => new Promise(getIdToken)

      try {
        const token = await fetchToken()
        // The component may have unmounted while the token was being fetched.
        if (disposedRef.current) return

        const expiration = new Date(decodeJWT(token).exp * 1000)
        const socket = new WebSocket(
          import.meta.env.VITE_NOTIFICATION_SOCKET,
          token,
        )
        createdSocketsRef.current.add(socket)
        socket.addEventListener(
          'close',
          () => createdSocketsRef.current.delete(socket),
          { once: true },
        )

        setTokenExpiration(expiration)
        setWebsocket(socket)
        setWebsocketId(generateUuid())
      } catch (error) {
        // Every caller fires and forgets this promise; report failures here
        // instead of leaving an unhandled rejection.
        console.error('useWebsocket: failed to open notification socket', error)
      }
    },
    [authenticated, getFreshIdToken, getIdToken],
  )

  useEffect(() => {
    void openSocket()
  }, [openSocket])

  // Re-open with a fresh token shortly before the current one expires.
  useEffect(() => {
    if (!tokenExpiration) return
    const refreshIn = Math.min(
      Math.max(
        tokenExpiration.getTime() - Date.now() - TOKEN_REFRESH_MARGIN_MS,
        0,
      ),
      MAX_TIMEOUT_MS,
    )
    const timeout = setTimeout(() => {
      void openSocket(true)
    }, refreshIn)
    return () => clearTimeout(timeout)
  }, [openSocket, tokenExpiration])

  // Lifecycle listeners for the newest socket: heartbeat and reconnection.
  useEffect(() => {
    if (!websocket || !websocketId) return
    let interval: ReturnType<typeof setInterval> | undefined

    const startHeartbeat = () => {
      interval = setInterval(() => ping(websocket), PING_INTERVAL_MS)
    }

    const handleOpen = () => {
      attemptsCountRef.current = 0
      // Adding this to help with debugging for now
      websocket.send(
        JSON.stringify({
          route: '/v2/ws-connection/get-id',
        }),
      )
      startHeartbeat()
      setActiveWebsockets((prev) => ({ ...prev, [websocketId]: websocket }))
    }

    const handleClose = (event: CloseEvent) => {
      if (
        event.wasClean ||
        attemptsCountRef.current < MAX_UNCLEAN_RECONNECTS
      ) {
        void openSocket()
        if (!event.wasClean) attemptsCountRef.current++
      } else {
        attemptsCountRef.current = 0
      }
    }

    websocket.addEventListener('open', handleOpen)
    websocket.addEventListener('close', handleClose)
    // If this effect re-ran (e.g. `openSocket` changed) after `open` already
    // fired, `handleOpen` will never run again, so resume the heartbeat here.
    if (websocket.readyState === websocket.OPEN) startHeartbeat()

    return () => {
      if (interval !== undefined) clearInterval(interval)
      websocket.removeEventListener('open', handleOpen)
      websocket.removeEventListener('close', handleClose)
    }
  }, [openSocket, websocket, websocketId])

  // Filter heartbeat replies out before handing messages to the caller.
  useEffect(() => {
    if (!websocket) return

    const handleMessage = (event: MessageEvent) => {
      const { action } = JSON.parse(event.data) as WebsocketData

      if (action === 'pong') {
        pong()
        return
      }

      onMessage(event)
    }
    websocket.addEventListener('message', handleMessage)

    return () => {
      websocket.removeEventListener('message', handleMessage)
    }
  }, [onMessage, websocket])

  // Once the newest socket is open, close every other socket this hook
  // created (including ones superseded before they ever opened) and drop
  // them from state.
  useEffect(() => {
    if (!websocketId) return
    const current = activeWebsockets[websocketId]
    // Until the newest socket has opened, keep the older ones serving.
    if (!current) return

    createdSocketsRef.current.forEach((ws) => {
      if (ws !== current) ws.close()
    })
    // Pure updater: side effects stay outside it (React may call it twice).
    setActiveWebsockets((prev) =>
      Object.keys(prev).length > 1 && Object.hasOwn(prev, websocketId)
        ? { [websocketId]: current }
        : prev,
    )
  }, [activeWebsockets, websocketId])

  return websocketId ? (activeWebsockets[websocketId] ?? null) : null
}
