import React, { useRef, useState, useEffect, useMemo } from 'react'
import * as d3 from 'd3'
import { createChartAxes } from '../shared/ChartAxes'
import { createChartSections } from '../shared/ChartSections'
import { createColorLegend } from '../shared/ColorLegend'
import { useD3 } from '../hooks/useD3'
import deepMerge from 'shared/utils/core/helpers/objects/deepObjectsMerge'

/**
 * One raw sample passed to the scatter chart.
 *
 * In this file `valueX` and `valueY` are mapped through the x/y scales to position the sample,
 * and `valueZ` is averaged per plotted point and fed to the color scale. The `name*` fields are
 * copied from the first sample of each aggregated group. `timestamp` is not read by the code
 * visible in this file.
 */
interface DataPoint {
  nameX: string
  nameY: string
  nameZ: string
  timestamp: number | null
  valueX: number
  valueY: number
  valueZ: number
}

/**
 * Fully-resolved configuration of the scatter chart (every option present).
 * Callers normally pass the partial `ScatterChartCustomConfig`, which is merged over
 * `defaultConfig` to produce a value of this shape.
 */
export interface ScatterChartConfig {
  layout: {
    /** Space around the plot area; subtracted from the container size to get the plot size. */
    margin: {
      top: number
      right: number
      bottom: number
      left: number
    }
    /** Intended lower/upper bound for the radius of a plotted point. */
    pointSize: {
      min: number
      max: number
    }
  }
  interactions: {
    /** Options forwarded to `createColorLegend`; the legend is only created when `enabled`. */
    colorLegend: {
      enabled: boolean
      title: string
      titlePosition: 'left' | 'right'
      titleFontSize: number
    }
  }
  axes: {
    x: {
      label: string
      /** Tick-count hint passed to the x scale's `.nice()`. */
      ticks: number
    }
    y: {
      label: string
      /** Tick-count hint passed to the y scale's `.nice()` and to `createChartAxes`. */
      ticks: number
    }
    z: {
      /** Passed to `createChartAxes` as `y1AxisLabel`. */
      label: string
    }
  }
  /** Options forwarded to `createChartSections` (together with the y scale) when `enabled`. */
  chartSections: {
    enabled: boolean
    massFlowRateLowerThreshold: number
    massFlowRateUpperThreshold: number
  }
}

/**
 * Baseline configuration used by `ScatterChart`; caller overrides are deep-merged on top of it.
 * Labels and the legend title default to empty strings, and chart sections are off by default.
 */
export const defaultConfig: ScatterChartConfig = {
  layout: {
    // NOTE(review): the wide right margin presumably leaves room for the color legend — confirm
    // against createColorLegend.
    margin: { top: 5, right: 100, bottom: 60, left: 60 },
    pointSize: { min: 1.5, max: 5 },
  },
  interactions: {
    colorLegend: {
      enabled: true,
      title: '',
      titlePosition: 'right',
      titleFontSize: 12,
    },
  },
  axes: {
    x: {
      label: '',
      ticks: 10,
    },
    y: {
      label: '',
      ticks: 10,
    },
    z: {
      label: '',
    },
  },
  chartSections: {
    enabled: false,
    massFlowRateLowerThreshold: 0,
    massFlowRateUpperThreshold: 0,
  },
}

/**
 * Caller-facing override shape: mirrors `ScatterChartConfig` with every group and every leaf
 * option made optional, so callers only specify the values they want to change.
 */
export type ScatterChartCustomConfig = {
  layout?: {
    margin?: Partial<ScatterChartConfig['layout']['margin']>
    pointSize?: Partial<ScatterChartConfig['layout']['pointSize']>
  }
  interactions?: {
    colorLegend?: Partial<ScatterChartConfig['interactions']['colorLegend']>
  }
  axes?: {
    x?: Partial<ScatterChartConfig['axes']['x']>
    y?: Partial<ScatterChartConfig['axes']['y']>
    z?: Partial<ScatterChartConfig['axes']['z']>
  }
  chartSections?: Partial<ScatterChartConfig['chartSections']>
}

/** Props accepted by `ScatterChart`. */
interface Props {
  /** Samples to plot. */
  data: DataPoint[]
  /** Optional overrides merged over `defaultConfig`. */
  config?: ScatterChartCustomConfig
}

/**
 * Reports the content-box size of the element held by `containerRef`.
 *
 * A `ResizeObserver` is attached once on mount and disconnected on unmount; each notification
 * stores the first entry's `contentRect` width/height in state.
 *
 * @param containerRef - Ref to the element whose size should be tracked.
 * @returns The most recently observed `{ width, height }`, `{ width: 0, height: 0 }` until the
 *          first notification arrives.
 */
const useChartDimensions = (containerRef: React.RefObject<HTMLDivElement>) => {
  const [dimensions, setDimensions] = useState({ width: 0, height: 0 })

  useEffect(() => {
    const element = containerRef.current
    if (!element) return

    const observer = new ResizeObserver(([firstEntry]) => {
      if (!firstEntry) return
      const { width, height } = firstEntry.contentRect
      setDimensions({ width, height })
    })
    observer.observe(element)

    // Stop observing when the component using this hook unmounts.
    return () => {
      observer.disconnect()
    }
  }, [])

  return dimensions
}

/**
 * Scatter chart drawn with D3 into an SVG that fills its container.
 *
 * Samples are positioned by `valueX`/`valueY`. Samples whose pixel coordinates coincide after
 * rounding to 0.1px are merged into a single circle: its radius grows logarithmically with the
 * number of merged samples (bounded by `config.layout.pointSize`) and its fill encodes the mean
 * `valueZ` of the group. The chart is rebuilt from scratch whenever `data`, the measured
 * container size, or the resolved config changes.
 *
 * @param data - Samples to plot.
 * @param config - Optional overrides merged over `defaultConfig`.
 */
const ScatterChart: React.FC<Props> = ({ data, config: customConfig }) => {
  const containerRef = useRef<HTMLDivElement>(null)
  const dimensions = useChartDimensions(containerRef)

  // NOTE(review): `{ ...defaultConfig }` is only a shallow copy. If `deepMerge` mutates its first
  // argument in place, nested default objects would be modified for every chart — confirm the
  // helper's contract.
  const config = useMemo(
    () => deepMerge({ ...defaultConfig }, customConfig || {}) as ScatterChartConfig,
    [customConfig]
  )

  /** Renders the whole chart into `svg`; does nothing until the container has a non-zero size. */
  const createChart = (svg: d3.Selection<SVGSVGElement, unknown, null, undefined>) => {
    if (!data || !dimensions || dimensions.width === 0 || dimensions.height === 0) return

    // Clear previous chart
    svg.selectAll('*').remove()

    const { width, height } = dimensions
    const { margin, pointSize } = config.layout
    // Clamp to zero: a container smaller than the margins would otherwise yield negative sizes,
    // which are invalid for the clip rectangle's width/height attributes.
    const chartWidth = Math.max(0, width - margin.left - margin.right)
    const chartHeight = Math.max(0, height - margin.top - margin.bottom)

    // Set up SVG attribute
    svg.attr('width', width).attr('height', height)

    // Create scales; both axes start at 0 and get 10% headroom above the largest value
    const maxX = d3.max(data, (d) => d.valueX) || 0
    const maxY = d3.max(data, (d) => d.valueY) || 0
    const xAxisPadding = maxX * 0.1 // Add 10% padding for X-axis

    const xScale = d3
      .scaleLinear()
      .domain([0, maxX + xAxisPadding])
      .range([0, chartWidth])
      .nice(config.axes.x.ticks)

    const yScale = d3
      .scaleLinear()
      .domain([0, maxY * 1.1])
      .range([chartHeight, 0])
      .nice(config.axes.y.ticks)

    // Bucket samples by their pixel position rounded to 0.1px (first by x, then by y) so that
    // overlapping samples are drawn as one circle.
    const groupedData = d3.group(
      data,
      (d) => Math.round(xScale(d.valueX) * 10) / 10,
      (d) => Math.round(yScale(d.valueY) * 10) / 10
    )

    const aggregatedData = Array.from(groupedData, ([x, yGroups]) =>
      Array.from(yGroups, ([y, points]) => ({
        x: +x,
        y: +y,
        count: points.length,
        avgZ: d3.mean(points, (d) => d.valueZ) || 0,
        nameX: points[0].nameX,
        nameY: points[0].nameY,
        nameZ: points[0].nameZ,
      }))
    ).flat()

    // The color domain runs from the smallest valueZ to the 99th percentile (falling back to the
    // maximum, then 0) — presumably so that a few extreme values do not stretch the scale.
    const sortedZValues = data.map((d) => d.valueZ).sort(d3.ascending)
    const cutoff = d3.quantile(sortedZValues, 0.99) || d3.max(data, (d) => d.valueZ) || 0
    const colorScale = d3
      .scaleSequential()
      .domain([d3.min(data, (d) => d.valueZ) || 0, cutoff])
      .interpolator(d3.interpolateViridis)

    if (config.chartSections.enabled) {
      createChartSections({
        svg,
        width,
        height,
        margin,
        yScale,
        massFlowRateLowerThreshold: config.chartSections.massFlowRateLowerThreshold,
        massFlowRateUpperThreshold: config.chartSections.massFlowRateUpperThreshold,
      })
    }

    const chartGroup = createChartAxes({
      svg,
      width,
      height,
      margin,
      xScale,
      yScale,
      xAxisLabel: config.axes.x.label,
      yAxisLabel: config.axes.y.label,
      y1AxisLabel: config.axes.z.label,
      yAxisTickCount: config.axes.y.ticks,
    })

    if (config.interactions.colorLegend.enabled) {
      createColorLegend({
        svg,
        width,
        height,
        margin,
        colorScale,
        title: config.interactions.colorLegend.title,
        titlePosition: config.interactions.colorLegend.titlePosition,
        titleFontSize: config.interactions.colorLegend.titleFontSize,
      })
    }

    // Create a unique identifier for the clipPath (9 random base-36 characters)
    const clipPathId = `scatter-chart-clip-${Math.random().toString(36).slice(2, 11)}`

    // Add clipPath to crop points that extend beyond the chart boundaries
    chartGroup
      .append('defs')
      .append('clipPath')
      .attr('id', clipPathId)
      .append('rect')
      .attr('width', chartWidth)
      .attr('height', chartHeight)

    // Create a group for points with clipPath applied
    const pointsGroup = chartGroup.append('g').attr('clip-path', `url(#${clipPathId})`)

    // Radius starts at the configured minimum, grows with ln(count) and is capped at the
    // configured maximum (a single sample has ln(1) = 0, i.e. the minimum radius).
    const { min: minRadius, max: maxRadius } = pointSize

    // Add points to the group with clipPath
    pointsGroup
      .selectAll('circle')
      .data(aggregatedData)
      .enter()
      .append('circle')
      .attr('cx', (d) => d.x)
      .attr('cy', (d) => d.y)
      .attr('r', (d) => Math.max(minRadius, Math.min(maxRadius, minRadius + Math.log(d.count))))
      .attr('fill', (d) => colorScale(d.avgZ))
      .attr('opacity', 1)
  }

  const svgRef = useD3(createChart, [data, dimensions, config])

  return (
    <div
      ref={containerRef}
      style={{ width: '100%', height: '100%', userSelect: 'none', WebkitUserSelect: 'none' }}
    >
      <svg
        ref={svgRef}
        style={{ width: '100%', height: '100%', userSelect: 'none', WebkitUserSelect: 'none' }}
      />
    </div>
  )
}

export default ScatterChart
