import { useEffect, useRef, useState } from 'react'
import * as d3 from 'd3'
import { createColorLegend } from './chart-components/ColorLegend'

/** One cell of the heatmap: a (time range, mass-flow range) pair and its value. */
interface HeatmapData {
  /** Column key; the x axis formats it with `new Date(...)`, so presumably epoch milliseconds — confirm with the data source. */
  timeRange: number
  /** Row key, drawn as a band on the y axis. */
  massFlowRange: string
  /** Cell value; the chart's color scale spans 0–100. */
  percentage: number
}

/** Props accepted by the HeatmapChart component. */
interface Props {
  /** Cells to draw; every distinct `timeRange` becomes one column. */
  data: HeatmapData[]
  /** Pixel margins around the plot area; the component supplies a default when omitted. */
  margin?: {
    top: number
    right: number
    bottom: number
    left: number
  }
}

/**
 * Default chart margins in pixels. Hoisted so the default keeps a stable identity across
 * renders; an object literal in the parameter list would re-run the drawing effect every render.
 */
const DEFAULT_MARGIN = { top: 20, right: 100, bottom: 80, left: 100 }

/** Maximum number of time-range columns shown in the main view before it scrolls horizontally. */
const COLUMNS_PER_VIEW = 13

/** Drawing height of the mini-map overview, in SVG user units. */
const MINI_MAP_HEIGHT = 100

/** Distinct time ranges in first-seen order, stringified for use as a band-scale domain. */
const getTimeRanges = (data: HeatmapData[]): string[] =>
  Array.from(new Set(data.map((d) => d.timeRange.toString())))

/** Clamps `value` into the closed interval [0, 1]. */
const clampUnit = (value: number): number => Math.min(1, Math.max(0, value))

/**
 * Horizontally scrollable heatmap of `percentage` by time range (x) and mass-flow range (y).
 *
 * At most `COLUMNS_PER_VIEW` columns are visible at once. A mini-map below the chart shows all
 * columns; dragging its brush scrolls the main view, and scrolling the main view moves the brush.
 *
 * @param data - One entry per cell; each distinct `timeRange` becomes a column and is labelled
 *   on the x axis via `new Date(timeRange)`.
 * @param margin - Pixel margins around the plot area reserved for axes, titles and the legend.
 */
const HeatmapChart = ({ data, margin = DEFAULT_MARGIN }: Props) => {
  const svgRef = useRef<SVGSVGElement>(null)
  const axesRef = useRef<SVGSVGElement>(null)
  const containerRef = useRef<HTMLDivElement>(null)
  const scrollContainerRef = useRef<HTMLDivElement>(null)
  const miniMapRef = useRef<SVGSVGElement>(null)
  const [dimensions, setDimensions] = useState({ width: 0, height: 0 })

  useEffect(() => {
    const svgEl = svgRef.current
    const axesEl = axesRef.current
    const miniMapEl = miniMapRef.current
    const scrollEl = scrollContainerRef.current
    if (!data || !svgEl || !axesEl || !miniMapEl || dimensions.width === 0) return

    const { width, height } = dimensions

    // Every run redraws from scratch, so clear whatever a previous run left behind.
    const svg = d3.select(svgEl)
    svg.selectAll('*').remove()
    const axesSvg = d3.select(axesEl)
    axesSvg.selectAll('*').remove()
    const miniMap = d3.select(miniMapEl)
    miniMap.selectAll('*').remove()

    const uniqueTimeRanges = getTimeRanges(data)
    const totalColumns = uniqueTimeRanges.length
    const availableWidth = width - margin.left - margin.right
    // With no columns or no horizontal room, every width below would be 0, NaN or Infinity.
    if (totalColumns === 0 || availableWidth <= 0) return

    const visibleColumns = Math.min(COLUMNS_PER_VIEW, totalColumns)
    const cellWidth = availableWidth / visibleColumns
    const totalWidth = cellWidth * totalColumns
    // How far the main view can scroll; 0 when every column fits.
    const maxScroll = Math.max(0, totalWidth - availableWidth)

    const miniMapWidth = availableWidth
    // The brush spans exactly the columns visible in the main view.
    const brushWidth = (visibleColumns * miniMapWidth) / totalColumns
    // How far the brush can travel inside the mini map; 0 when every column fits.
    const brushTravel = Math.max(0, miniMapWidth - brushWidth)

    const massFlowRanges = Array.from(new Set(data.map((d) => d.massFlowRange)))

    const xScale = d3
      .scaleBand()
      .domain(uniqueTimeRanges)
      .range([margin.left, totalWidth + margin.left])
      .padding(0)

    const yScale = d3
      .scaleBand()
      .domain([...massFlowRanges].reverse())
      .range([10, height - margin.bottom])
      .padding(0)

    const xScaleMini = d3.scaleBand().domain(uniqueTimeRanges).range([0, miniMapWidth]).padding(0)

    const yScaleMini = d3.scaleBand().domain(massFlowRanges).range([MINI_MAP_HEIGHT, 0]).padding(0)

    // White -> yellow over 0–50 %, yellow -> red over 50–100 %.
    const colorScale = d3
      .scaleSequential()
      .domain([0, 100])
      .interpolator((t: number) => {
        if (t < 0.5) {
          return d3.interpolateRgb('#ffffff', '#ffff00')(t * 2)
        } else {
          return d3.interpolateRgb('#ffff00', '#ff0000')((t - 0.5) * 2)
        }
      })

    const cells = svg.append('g').selectAll('g').data(data).join('g')

    cells
      .append('rect')
      .attr('x', (d) => xScale(String(d.timeRange)) ?? 0)
      .attr('y', (d) => yScale(d.massFlowRange) ?? 0)
      .attr('width', xScale.bandwidth())
      .attr('height', yScale.bandwidth())
      .attr('fill', (d) => colorScale(d.percentage))
      .attr('opacity', (d) => (d.percentage === 0 ? 0.1 : 1))

    // Only cells of at least 20 % carry a visible label, so create text nodes only for those.
    cells
      .filter((d) => d.percentage >= 20)
      .append('text')
      .attr('x', (d) => (xScale(String(d.timeRange)) ?? 0) + xScale.bandwidth() / 2)
      .attr('y', (d) => (yScale(d.massFlowRange) ?? 0) + yScale.bandwidth() / 2)
      .attr('text-anchor', 'middle')
      .attr('dominant-baseline', 'middle')
      .style('font-size', '10px')
      .style('fill', (d) => (d.percentage > 50 ? 'white' : 'black'))
      .text((d) => `${d.percentage.toFixed(1)}%`)

    miniMap
      .append('g')
      .selectAll('rect')
      .data(data)
      .join('rect')
      .attr('x', (d) => xScaleMini(String(d.timeRange)) ?? 0)
      .attr('y', (d) => yScaleMini(d.massFlowRange) ?? 0)
      .attr('width', xScaleMini.bandwidth())
      .attr('height', yScaleMini.bandwidth())
      .attr('fill', (d) => colorScale(d.percentage))
      .attr('opacity', (d) => (d.percentage === 0 ? 0.1 : 1))

    const brush = d3
      .brushX()
      .extent([
        [0, 0],
        [miniMapWidth, MINI_MAP_HEIGHT],
      ])
      .on('brush end', (event) => {
        // Programmatic brush.move calls carry no sourceEvent; only user drags should scroll.
        if (!event.sourceEvent || !event.selection || !scrollEl) return
        const [x0] = event.selection as [number, number]
        // Map the brush offset onto the scroll range (both are 0 when every column fits).
        const ratio = brushTravel > 0 ? clampUnit(x0 / brushTravel) : 0
        scrollEl.scrollLeft = ratio * maxScroll
      })

    const brushGroup = miniMap.append('g').attr('class', 'brush')
    brushGroup.call(brush)

    // Moves the brush to mirror the main view's current scroll offset.
    const syncBrushToScroll = () => {
      const ratio = scrollEl && maxScroll > 0 ? clampUnit(scrollEl.scrollLeft / maxScroll) : 0
      const brushX = ratio * brushTravel
      brushGroup.call(brush.move, [brushX, brushX + brushWidth])
    }

    // Place the brush where the view currently is (0 on the first draw), so a redraw after a
    // resize or data update does not leave the brush and the scroll position out of sync.
    syncBrushToScroll()

    // Only the selection itself stays draggable: no resize handles, no click-to-select overlay.
    brushGroup.call((g) => {
      g.selectAll('.handle').remove()
      g.selectAll('.overlay').remove()
    })

    // Mini-map border, drawn after the brush so it stays on top.
    miniMap
      .append('rect')
      .attr('x', 0)
      .attr('y', 0)
      .attr('width', miniMapWidth)
      .attr('height', MINI_MAP_HEIGHT)
      .attr('fill', 'none')
      .attr('stroke', 'black')
      .attr('stroke-width', 1)

    // X ticks are rendered as two lines: "H:MM" above "D/M/YYYY".
    const xAxis = (g: d3.Selection<SVGGElement, unknown, null, undefined>) =>
      g
        .attr('transform', `translate(0,${height - margin.bottom})`)
        .call(d3.axisBottom(xScale))
        .selectAll('text')
        .style('text-anchor', 'middle')
        .each(function (d) {
          const date = new Date(Number(d))
          const timeString = `${date.getHours()}:${String(date.getMinutes()).padStart(2, '0')}`
          const dateString = `${date.getDate()}/${date.getMonth() + 1}/${date.getFullYear()}`

          const label = d3.select(this).text(null)
          label.append('tspan').attr('x', 0).attr('dy', '1em').text(timeString)
          label.append('tspan').attr('x', 0).attr('dy', '1em').text(dateString)
        })

    // The y axis lives in the non-scrolling overlay so it stays put while the heatmap scrolls.
    const yAxis = (g: d3.Selection<SVGGElement, unknown, null, undefined>) =>
      g.attr('transform', `translate(${margin.left},0)`).call(d3.axisLeft(yScale))

    axesSvg.append('g').call(yAxis)

    svg.append('g').call(xAxis)

    axesSvg
      .append('text')
      .attr('x', width / 2)
      .attr('y', height - margin.bottom / 4)
      .style('text-anchor', 'middle')
      .text('Time Periods (2-hour intervals)')

    axesSvg
      .append('text')
      .attr('transform', 'rotate(-90)')
      .attr('x', -height / 2)
      .attr('y', margin.left / 3)
      .style('text-anchor', 'middle')
      .text('Mass Flow Rate Ranges')

    axesSvg
      .append('text')
      .attr('transform', 'rotate(90)')
      .attr('x', height / 2)
      .attr('y', -(width - margin.right / 2 + 30))
      .style('text-anchor', 'middle')
      .text('Percentage')

    createColorLegend({
      svg: axesSvg,
      width,
      height,
      margin,
      colorScale,
      gradientColors: ['#ffffff', '#ffff00', '#ff0000'],
      useCustomGradient: true,
    })

    scrollEl?.addEventListener('scroll', syncBrushToScroll)
    // Detach on redraw/unmount so listeners (and the stale scales they capture) do not pile up.
    return () => scrollEl?.removeEventListener('scroll', syncBrushToScroll)
  }, [data, dimensions, margin])

  useEffect(() => {
    const updateDimensions = () => {
      if (!containerRef.current) return
      const { width, height } = containerRef.current.getBoundingClientRect()
      setDimensions({ width, height })
    }

    updateDimensions()
    const resizeObserver = new ResizeObserver(updateDimensions)
    if (containerRef.current) {
      resizeObserver.observe(containerRef.current)
    }

    return () => resizeObserver.disconnect()
  }, [])

  // The scrollable SVG is one viewport wide per COLUMNS_PER_VIEW columns. The count is based on
  // distinct time ranges (the x-axis columns), not on the number of cells.
  const columnCount = getTimeRanges(data).length
  const visibleColumnCount = Math.min(COLUMNS_PER_VIEW, columnCount)
  const heatmapWidthPercent = visibleColumnCount > 0 ? (columnCount * 100) / visibleColumnCount : 100

  return (
    <div style={{ position: 'relative', width: '100%', height: '100%' }}>
      <div
        ref={containerRef}
        style={{
          width: '100%',
          height: 'calc(100% - 120px)',
          position: 'relative',
          paddingLeft: `${margin.left}px`,
          paddingRight: `${margin.right}px`,
        }}
      >
        <div
          ref={scrollContainerRef}
          style={{
            width: '100%',
            height: '100%',
            overflowX: 'hidden',
            position: 'relative',
          }}
        >
          {/* The extra margin.left px compensates for the negative left margin, so the last column is not clipped. */}
          <svg
            ref={svgRef}
            style={{
              display: 'block',
              width: `calc(${heatmapWidthPercent}% + ${margin.left}px)`,
              height: '100%',
              marginLeft: `-${margin.left}px`,
            }}
          ></svg>
        </div>
        <svg
          ref={axesRef}
          style={{
            position: 'absolute',
            top: 0,
            left: 0,
            pointerEvents: 'none',
            zIndex: 1,
            width: '100%',
            height: '100%',
          }}
        ></svg>
      </div>
      <svg
        ref={miniMapRef}
        style={{
          width: `calc(100% - ${margin.left + margin.right}px)`,
          height: '120px',
          display: 'block',
          marginTop: '10px',
          marginLeft: `${margin.left}px`,
        }}
      ></svg>
    </div>
  )
}

export default HeatmapChart
