import { useReducer, useRef, useCallback, useEffect, DependencyList, useMemo } from 'react'
import React from 'react'

/**
 * 2D affine transform stored as the six components `[a, b, c, d, e, f]`,
 * in the same order `matrixToCssString` writes them into a CSS
 * `matrix(a,b,c,d,e,f)` value.
 */
export type Matrix = [number, number, number, number, number, number]

// Identity matrix: a = d = 1 and every other component 0, i.e. the
// transform that leaves coordinates unchanged.
const identity: Matrix = [1, 0, 0, 1, 0, 0]

// Helpers to create transform matrices
/**
 * Builds a pure translation matrix.
 *
 * @param tx - Horizontal offset, stored in the `e` component.
 * @param ty - Vertical offset, stored in the `f` component.
 * @returns A matrix whose linear part is the identity and whose offset is `(tx, ty)`.
 */
function translateMatrix(tx: number, ty: number): Matrix {
  const translation: Matrix = [1, 0, 0, 1, tx, ty]
  return translation
}

/**
 * Builds a pure axis-aligned scaling matrix with no translation.
 *
 * @param sx - Horizontal scale, stored in the `a` component.
 * @param sy - Vertical scale, stored in the `d` component.
 * @returns A matrix that scales x by `sx` and y by `sy`.
 */
function scaleMatrix(sx: number, sy: number): Matrix {
  const scaling: Matrix = [sx, 0, 0, sy, 0, 0]
  return scaling
}

/**
 * Builds a pure rotation matrix with no translation.
 *
 * @param deg - Rotation angle in degrees.
 * @returns `[cos, sin, -sin, cos, 0, 0]` for the given angle.
 */
function rotateMatrix(deg: number): Matrix {
  // Math.cos / Math.sin work in radians, so convert first.
  const radians = (deg * Math.PI) / 180
  const c = Math.cos(radians)
  const s = Math.sin(radians)
  const rotation: Matrix = [c, s, -s, c, 0, 0]
  return rotation
}

/**
 * Composes two affine matrices and returns the product `m1 * m2`.
 *
 * The linear part of `m1` is applied to each column pair of `m2`; the
 * offset of `m1` is then added to the transformed offset of `m2`.
 *
 * @param m1 - Left operand.
 * @param m2 - Right operand.
 * @returns A new matrix; neither input is modified.
 */
function multiplyMatrix(m1: Matrix, m2: Matrix): Matrix {
  const [a1, b1, c1, d1, e1, f1] = m1

  // Applies only the linear (non-translating) part of m1 to the pair (x, y).
  const applyLinear = (x: number, y: number): [number, number] => [
    a1 * x + c1 * y,
    b1 * x + d1 * y
  ]

  const [a, b] = applyLinear(m2[0], m2[1])
  const [c, d] = applyLinear(m2[2], m2[3])
  const [offsetX, offsetY] = applyLinear(m2[4], m2[5])

  return [a, b, c, d, offsetX + e1, offsetY + f1]
}

/**
 * Serializes a matrix as a CSS `matrix(a,b,c,d,e,f)` function string.
 *
 * @param m - Matrix to serialize.
 * @returns The six components, comma-separated with no spaces, wrapped in `matrix(...)`.
 */
export function matrixToCssString(m: Matrix): string {
  const components = [m[0], m[1], m[2], m[3], m[4], m[5]].map(String)
  return 'matrix(' + components.join(',') + ')'
}

/**
 * Wraps `callback` so that it runs at most once per animation frame.
 *
 * The first call in a frame is forwarded immediately and its result is
 * returned. Any further calls made before the next `requestAnimationFrame`
 * tick are dropped (not queued) and yield `undefined`, so callers should
 * only rely on the return value when `R` is `void` or tolerates `undefined`.
 *
 * @param callback - Function to throttle.
 * @param deps - Dependency list forwarded to `useCallback`; the wrapper's
 *   identity changes only when these change.
 * @returns The throttled wrapper.
 */
function useRafCallback<A extends unknown[], R>(callback: (...args: A) => R, deps: DependencyList): (...args: A) => R {
  // True from the moment a call is let through until the next frame fires.
  const frameBusy = useRef(false)

  return useCallback((...args: A): R => {
    if (!frameBusy.current) {
      frameBusy.current = true
      requestAnimationFrame(() => {
        frameBusy.current = false
      })
      return callback(...args)
    }

    // Dropped call: nothing meaningful to hand back.
    return undefined as R
  }, deps)
}

/**
 * Options accepted by `useTransformMatrix`. Every field is optional; the
 * defaults are the ones given in the hook's parameter destructuring.
 */
interface UseTransformMatrixOptions {
  /** Starting horizontal scale (hook default: 1). */
  initialScaleX?: number
  /** Starting vertical scale (hook default: 1). */
  initialScaleY?: number
  /**
   * Step used by the zoom buttons (hook default: 0.1): `zoomIn` scales by
   * `1 + scaleStep`, `zoomOut` by `1 - scaleStep`.
   */
  scaleStep?: number
  /**
   * NOTE(review): declared here but not read by `useTransformMatrix` in this
   * file; the reducer uses a fixed lower bound of 0.1 instead. Confirm
   * whether this option is meant to be wired through.
   */
  minScale?: number
  /** Starting rotation, in degrees (hook default: 0). */
  initialRotation?: number // in degrees
  /** Angle applied by each rotate call, in degrees (hook default: 90). */
  rotationStep?: number // in degrees
  /** Starting horizontal translation (hook default: 0). */
  initialTranslateX?: number
  /** Starting vertical translation (hook default: 0). */
  initialTranslateY?: number
}

/**
 * Reducer state: the transform kept as separate components. The hook
 * combines them into a single matrix as translate * rotate * scale.
 */
type State = {
  /** Horizontal translation. */
  translateX: number
  /** Vertical translation. */
  translateY: number
  /** Horizontal scale; its sign is inverted by a horizontal flip. */
  scaleX: number
  /** Vertical scale. */
  scaleY: number
  /** Accumulated rotation in degrees. */
  rotation: number
}

/**
 * Actions understood by `reducer`.
 *
 * - `TRANSLATE`: add `(dx, dy)` to the current translation.
 * - `ROTATE`: add `deltaDeg` degrees to the current rotation.
 * - `SCALE`: multiply both scale components by `factor`, zooming about the
 *   point `(centerX, centerY)`; each coordinate defaults to 0 when omitted.
 * - `RESET`: replace the whole state with `initial`.
 * - `FLIP_HORIZONTAL`: negate the horizontal scale.
 */
type Action =
  | { type: 'TRANSLATE'; dx: number; dy: number }
  | { type: 'ROTATE'; deltaDeg: number }
  | { type: 'SCALE'; factor: number; centerX?: number; centerY?: number }
  | { type: 'RESET'; initial: State }
  | { type: 'FLIP_HORIZONTAL' }

/** Smallest scale magnitude the reducer stores on either axis. */
const MIN_SCALE_MAGNITUDE = 0.1

/**
 * Scales `current` by a positive `factor`, keeping the sign of `current`
 * (so a flipped axis stays flipped) and never letting the magnitude drop
 * below `MIN_SCALE_MAGNITUDE`.
 */
function scaleWithFloor(current: number, factor: number): number {
  return Math.sign(current) * Math.max(MIN_SCALE_MAGNITUDE, Math.abs(current * factor))
}

/**
 * Pure state transition for the transform.
 *
 * @param state - Current transform components.
 * @param action - Requested change.
 * @returns The next state. The same `state` object is returned when the
 *   action is unknown or when a `SCALE` request cannot be applied.
 */
function reducer(state: State, action: Action): State {
  switch (action.type) {
    case 'TRANSLATE': {
      const { dx, dy } = action
      return { ...state, translateX: state.translateX + dx, translateY: state.translateY + dy }
    }
    case 'ROTATE': {
      const { deltaDeg } = action
      return { ...state, rotation: state.rotation + deltaDeg }
    }
    case 'SCALE': {
      const { factor, centerX = 0, centerY = 0 } = action
      const { scaleX, scaleY, translateX, translateY } = state

      // A zero, negative or non-finite factor would collapse the scale to 0
      // or silently mirror the content, and a zero current scale makes the
      // division below blow up. Ignore such requests instead of corrupting
      // the state.
      if (!Number.isFinite(factor) || factor <= 0 || scaleX === 0 || scaleY === 0) {
        return state
      }

      // Compute the point in image coordinates that corresponds to the cursor
      const imageX = (centerX - translateX) / scaleX
      const imageY = (centerY - translateY) / scaleY

      // New scale, already limited to the minimum magnitude
      const newScaleX = scaleWithFloor(scaleX, factor)
      const newScaleY = scaleWithFloor(scaleY, factor)

      // New translate so that (imageX, imageY) still maps to (centerX, centerY).
      // This must use the scale that is actually stored; otherwise the content
      // drifts on every zoom-out once the minimum scale has been reached.
      return {
        ...state,
        translateX: centerX - imageX * newScaleX,
        translateY: centerY - imageY * newScaleY,
        scaleX: newScaleX,
        scaleY: newScaleY
      }
    }
    case 'FLIP_HORIZONTAL': {
      return { ...state, scaleX: -state.scaleX }
    }
    case 'RESET': {
      return action.initial
    }
    default:
      return state
  }
}

/** Value returned by `useTransformMatrix`. */
interface UseTransformMatrixReturn {
  /** Current transform as a single matrix (translate * rotate * scale). */
  transform: Matrix
  /** Rotates by `+rotationStep` degrees. */
  rotateClockwise: () => void
  /** Rotates by `-rotationStep` degrees. */
  rotateCounterClockwise: () => void
  /** Scales by `1 + scaleStep`. */
  zoomIn: () => void
  /** Scales by `1 - scaleStep`. */
  zoomOut: () => void
  /** Negates the horizontal scale. */
  flipHorizontal: () => void
  /** Restores the initial translation, scale and rotation passed to the hook. */
  resetTransform: () => void
  /** Ref for the element on which the hook listens for `pointerdown` and `wheel` events. */
  wrapperRef: React.RefObject<HTMLDivElement>
}

/**
 * React hook that maintains a pan / zoom / rotate / flip transform and
 * exposes it as a single matrix.
 *
 * Attach the returned `wrapperRef` to a `<div>`; while that element is
 * mounted the hook listens on it for:
 * - `pointerdown`, which starts a drag that pans by the pointer movement
 *   until `pointerup` or `pointercancel`;
 * - `wheel`, which zooms about the cursor position measured from the
 *   wrapper's centre and suppresses the default scroll.
 *
 * The element is looked up once when the effect runs, so the ref must
 * already be attached at that point.
 *
 * @param options - Initial values and step sizes; see `UseTransformMatrixOptions`.
 * @returns The current matrix, imperative controls and the wrapper ref.
 */
export function useTransformMatrix({
  initialScaleX = 1,
  initialScaleY = 1,
  initialTranslateX = 0,
  initialTranslateY = 0,
  initialRotation = 0,
  scaleStep = 0.1,
  rotationStep = 90
}: UseTransformMatrixOptions = {}): UseTransformMatrixReturn {
  const [state, dispatch] = useReducer(reducer, {
    translateX: initialTranslateX,
    translateY: initialTranslateY,
    scaleX: initialScaleX,
    scaleY: initialScaleY,
    rotation: initialRotation
  })

  const wrapperRef = useRef<HTMLDivElement>(null)

  // Pointer position at the last processed move; deltas are measured from it,
  // so moves dropped by the frame throttle are folded into the next one.
  const lastPointer = useRef({ x: 0, y: 0 })

  const rotateClockwise = useCallback(() => {
    dispatch({ type: 'ROTATE', deltaDeg: rotationStep })
  }, [rotationStep])

  const rotateCounterClockwise = useCallback(() => {
    dispatch({ type: 'ROTATE', deltaDeg: -rotationStep })
  }, [rotationStep])

  const zoomIn = useCallback(() => {
    dispatch({ type: 'SCALE', factor: 1 + scaleStep })
  }, [scaleStep])

  const zoomOut = useCallback(() => {
    dispatch({ type: 'SCALE', factor: 1 - scaleStep })
  }, [scaleStep])

  const flipHorizontal = useCallback(() => {
    dispatch({ type: 'FLIP_HORIZONTAL' })
  }, [])

  const resetTransform = useCallback(() => {
    dispatch({
      type: 'RESET',
      initial: {
        translateX: initialTranslateX,
        translateY: initialTranslateY,
        scaleX: initialScaleX,
        scaleY: initialScaleY,
        rotation: initialRotation
      }
    })
  }, [initialRotation, initialScaleX, initialScaleY, initialTranslateX, initialTranslateY])

  const onPointerMove = useRafCallback((e: PointerEvent) => {
    const dx = e.clientX - lastPointer.current.x
    const dy = e.clientY - lastPointer.current.y
    lastPointer.current = { x: e.clientX, y: e.clientY }
    dispatch({ type: 'TRANSLATE', dx, dy })
  }, [])

  // Ends a drag. Also used for `pointercancel` (no `pointerup` follows a
  // cancelled pointer) and on unmount, so the window listeners never outlive
  // the drag or the component. Removing listeners that are not registered is
  // a no-op, so calling this at any time is safe.
  const onPointerUp = useCallback(() => {
    window.removeEventListener('pointermove', onPointerMove)
    window.removeEventListener('pointerup', onPointerUp)
    window.removeEventListener('pointercancel', onPointerUp)
  }, [onPointerMove])

  const onPointerDown = useCallback(
    (e: PointerEvent) => {
      e.preventDefault()
      lastPointer.current = { x: e.clientX, y: e.clientY }
      // Listen on window so the drag continues when the pointer leaves the wrapper.
      window.addEventListener('pointermove', onPointerMove)
      window.addEventListener('pointerup', onPointerUp)
      window.addEventListener('pointercancel', onPointerUp)
    },
    [onPointerMove, onPointerUp]
  )

  const onWheel = useRafCallback((e: WheelEvent) => {
    const rect = (e.currentTarget as HTMLDivElement).getBoundingClientRect()
    // Cursor position relative to the centre of the wrapper.
    const cx = e.clientX - rect.left - rect.width / 2
    const cy = e.clientY - rect.top - rect.height / 2

    // Zoom-in keeps the linear `1 + |deltaY| * 0.01` mapping; zoom-out is its
    // reciprocal. The plain `1 - deltaY * 0.01` form reaches 0 at
    // deltaY === 100 and goes negative beyond that, which collapses or
    // mirrors the content.
    const magnitude = 1 + Math.abs(e.deltaY) * 0.01
    const factor = e.deltaY < 0 ? magnitude : 1 / magnitude

    dispatch({
      type: 'SCALE',
      factor,
      centerX: cx,
      centerY: cy
    })
  }, [])

  useEffect(() => {
    if (!wrapperRef.current) return
    const wrapper = wrapperRef.current
    // prevent default scroll because onWheel is throttled and
    // does not prevent all scroll events
    const preventDefault = (e: Event) => e.preventDefault()

    wrapper.addEventListener('pointerdown', onPointerDown)
    // Explicitly non-passive: preventDefault() is ignored in passive listeners.
    wrapper.addEventListener('wheel', preventDefault, { passive: false })
    wrapper.addEventListener('wheel', onWheel)

    return () => {
      wrapper.removeEventListener('pointerdown', onPointerDown)
      wrapper.removeEventListener('wheel', preventDefault)
      wrapper.removeEventListener('wheel', onWheel)
      // Drop any drag listeners still attached to window (unmount mid-drag).
      onPointerUp()
    }
  }, [onPointerDown, onPointerUp, onWheel])

  const transform = useMemo(() => {
    // Compute final matrix: Translate * Rotate * Scale
    const translateM = translateMatrix(state.translateX, state.translateY)
    const rotateM = rotateMatrix(state.rotation)
    const scaleM = scaleMatrix(state.scaleX, state.scaleY)

    let finalMatrix = identity
    finalMatrix = multiplyMatrix(finalMatrix, translateM)
    finalMatrix = multiplyMatrix(finalMatrix, rotateM)
    finalMatrix = multiplyMatrix(finalMatrix, scaleM)
    return finalMatrix
  }, [state])

  return {
    transform,
    rotateClockwise,
    rotateCounterClockwise,
    zoomIn,
    zoomOut,
    flipHorizontal,
    resetTransform,
    wrapperRef
  }
}
