import { AxiosError } from 'axios'
import {
  useQueryClient,
  useMutation as useMutationHook,
  UseMutationOptions as UseMutationBaseOptions,
  QueryKey,
} from '@tanstack/react-query'
import { useAuthentication } from '@/components/auth/AuthProvider'

/**
 * Options accepted by this module's `useMutation`: everything TanStack's
 * `useMutation` accepts, plus `sideEffectQueryKeys`.
 */
type UseMutationOptions<
  TData,
  TError = unknown,
  TVariables = void,
  TContext = unknown,
> = UseMutationBaseOptions<TData, TError, TVariables, TContext> & {
  /**
   * Query keys whose data this mutation affects. Matching queries are
   * cancelled when the mutation starts (in `onMutate`) and are cancelled and
   * then invalidated when it settles. Only read, never mutated, so readonly
   * arrays (e.g. `as const`) are accepted.
   */
  sideEffectQueryKeys?: readonly QueryKey[]
}

/**
 * Wrapper around TanStack Query's `useMutation` that adds two behaviors:
 *
 * - `sideEffectQueryKeys`: in-flight queries matching these keys are cancelled
 *   (and the cancellation awaited) before the caller's `onMutate` runs, and
 *   are cancelled and then invalidated once the mutation settles.
 * - When the mutation fails with an error whose `response.status` is 401
 *   (Axios-style error shape), a fresh ID token is requested from the auth
 *   provider. This is best effort; a failed refresh is ignored.
 *
 * All other options are passed through unchanged. The caller's `onMutate`,
 * `onError` and `onSettled` still run, with all of their arguments, and their
 * return values are forwarded. A promise they return is therefore awaited by
 * TanStack Query exactly as it would be without this wrapper.
 *
 * @param options - TanStack `useMutation` options plus `sideEffectQueryKeys`.
 * @returns The mutation result from TanStack's `useMutation`.
 */
export default function useMutation<
  TData,
  TError = unknown,
  TVariables = void,
  TContext = unknown,
>({
  sideEffectQueryKeys = [],
  onMutate,
  onError,
  onSettled,
  ...options
}: UseMutationOptions<TData, TError, TVariables, TContext>) {
  const { getFreshIdToken } = useAuthentication()
  const queryClient = useQueryClient()

  // Resolves only after every side-effect query has been cancelled. Each
  // callback must return the cancelQueries promise; otherwise Promise.all
  // resolves immediately and awaiting it waits for nothing.
  function cancelSideEffectQueries() {
    return Promise.all(
      sideEffectQueryKeys.map((queryKey) =>
        queryClient.cancelQueries({ queryKey }),
      ),
    )
  }

  async function invalidateSideEffectQueries() {
    // Cancel queries first so we don't clobber any optimistic updates made by
    // hooks that build on this one.
    await cancelSideEffectQueries()
    await Promise.all(
      sideEffectQueryKeys.map((queryKey) =>
        queryClient.invalidateQueries({ queryKey }),
      ),
    )
  }

  const mutation = useMutationHook<TData, TError, TVariables, TContext>({
    ...options,
    onMutate: async (...args) => {
      // These queries are about to change anyway. Wait until in-flight fetches
      // are cancelled so they cannot overwrite optimistic updates written by
      // the caller's onMutate.
      await cancelSideEffectQueries()
      return onMutate?.(...args)
    },
    onError: (...args) => {
      const [error] = args
      // On a 401, ask the auth provider for a fresh ID token. Best effort:
      // a failed refresh is deliberately ignored.
      if ((error as AxiosError)?.response?.status === 401) {
        void getFreshIdToken()?.catch(() => {})
      }
      return onError?.(...args)
    },
    onSettled: (...args) => {
      // Cancel, then invalidate, the side-effect queries (fire-and-forget, so
      // the mutation does not wait for the refetch). This runs for every
      // settled mutation; there is no check for other in-flight mutations
      // that touch the same queries.
      void invalidateSideEffectQueries()
      return onSettled?.(...args)
    },
  })

  return mutation
}
