import React from "react";
import { useD3 } from "../../hooks";
import {
  addAxisLabels,
  addBandAndGuideLegend,
  addChartBands,
  addChartGrid,
  addChartGuides,
  calculateBandAndGuideRows,
  calculatePlotHeight,
  calculateRightMargin,
  calculateRotatedLabelHeight,
  getSeriesColor,
} from "./utilities";
import { AxisSettings, Band, Guide, Series } from "./interfaces";
import addSeriesLegend from "./utilities/addSeriesLegend";
import Constants from "./Constants";
import { groupArray } from "../../utilities";
import { scaleLinear } from "d3";
import Defaults from "./Defaults";

/**
 * Series settings for the scatter chart: the shared `Series` options plus a
 * key function used to group a series' elements. Elements are grouped by this
 * key (via `groupArray`) and each resulting group is plotted as one point.
 */
type ScatterSeries<TData> = Series<TData> & {
  /** Returns the group key for a single element of a series. */
  getElementGroup: (point: TData) => string;
};

/** Props for {@link ScatterChart}. */
interface ScatterChartProps<TData> {
  /** One inner array per series; a series legend is drawn when there is more than one. */
  data: TData[][];
  /** Computes the x value of one plotted point from the elements of its group. */
  getPointXValue: (pointGroup: TData[]) => number;
  /** Computes the y value of one plotted point from the elements of its group. */
  getPointYValue: (pointGroup: TData[]) => number;
  /** When true, grid lines are drawn at the axis ticks. */
  showGrid?: boolean;
  /** Axis limits, labels, colours and tick settings. */
  axis?: AxisSettings;
  /** Guide overlays positioned against the y axis; also listed in the band/guide legend. */
  guides?: Guide[];
  /** Band overlays positioned against the y axis; also listed in the band/guide legend. */
  bands?: Band[];
  /** Grouping and colouring options for the plotted series. */
  series?: ScatterSeries<TData>;
  styles?: {
    /** Stroke colour passed to the grid renderer. */
    gridColor?: string;
  };
}

/**
 * Scatter chart drawn with D3 into an SVG.
 *
 * Each inner array of `data` is one series. Within a series, elements are
 * grouped with `series.getElementGroup` (or all placed in a single "ALL"
 * group when no series settings are given). Each group becomes one plotted
 * point, positioned by `getPointXValue` / `getPointYValue`.
 *
 * Axis domains start at `axis.{x,y}.{min,max}` (0 when omitted), are widened
 * to include every plotted value, and are then rounded with d3 `nice()`.
 */
export default function ScatterChart<TData>({
  data,
  getPointXValue,
  getPointYValue,
  showGrid,
  axis,
  styles,
  guides,
  bands,
  series,
}: ScatterChartProps<TData>) {
  const ref = useD3(
    (svg) => {
      const {
        plotWidth,
        axisLabelHeight,
        axisLabelMarginBottom,
        axisLabelMarginTop,
      } = Constants;
      // Detached canvas handed to the layout helpers (presumably for text
      // measurement); it is never attached to the document.
      const canvas = document.createElement("canvas");
      // Domain bounds start from the configured axis limits and are widened
      // below so that every plotted value fits.
      let xMax: number = axis?.x?.max ?? 0;
      let yMax: number = axis?.y?.max ?? 0;
      let xMin = axis?.x?.min ?? 0;
      let yMin = axis?.y?.min ?? 0;
      const width = plotWidth;
      const margin = {
        top: 40,
        right: calculateRightMargin(data, guides, series),
        left: 80,
      };

      // One plotted point per element group. `group` keeps the group's source
      // elements so styling can be derived from the group itself. Pixel
      // coordinates are filled in once the scales exist.
      const graphData: {
        group: TData[];
        xValue: number;
        yValue: number;
        x: number;
        y: number;
      }[][] = data.map((seriesData) =>
        groupArray(seriesData, series?.getElementGroup ?? (() => "ALL")).map(
          (p) => {
            const vX = getPointXValue(p.data);
            const vY = getPointYValue(p.data);

            if (vX > xMax) xMax = vX;
            if (vX < xMin) xMin = vX;
            if (vY > yMax) yMax = vY;
            if (vY < yMin) yMin = vY;

            return {
              group: p.data,
              xValue: vX,
              yValue: vY,
              x: 0,
              y: 0,
            };
          }
        )
      );

      const xAxisLabelHeight = calculateRotatedLabelHeight(0, 0);
      const bandAndGuideLegend = calculateBandAndGuideRows(
        guides ?? [],
        bands ?? [],
        canvas
      );
      let height = calculatePlotHeight(
        data,
        series,
        xAxisLabelHeight,
        bandAndGuideLegend.totalHeight,
        canvas
      );

      // Reserve vertical room for the x axis title when one is shown.
      if (axis?.x?.label)
        height -= axisLabelHeight + axisLabelMarginBottom + axisLabelMarginTop;

      const xAxisScale = scaleLinear()
        .domain([xMin, xMax])
        .range([margin.left, width - margin.right])
        .nice();
      // SVG y grows downward, so larger values map to smaller pixel y.
      const yAxisScale = scaleLinear()
        .domain([yMin, yMax])
        .range([height, margin.top])
        .nice();
      // `nice()` may extend the domain beyond [yMin, yMax]. Overlays (bands,
      // guides) must use the rounded domain so they line up with the drawn
      // y axis and its ticks.
      const [yDomainMin, yDomainMax] = yAxisScale.domain();

      graphData.forEach((seriesData, seriesIndex) => {
        seriesData.forEach((p) => {
          p.x = xAxisScale(p.xValue);
          p.y = yAxisScale(p.yValue);
        });

        const seriesGroup = svg.select(".plot-area").append("g");

        // Data points. The colour is derived from the point's own group:
        // indexing `data[seriesIndex]` by the group index would pick an
        // unrelated element whenever grouping merges several elements.
        seriesGroup
          .selectAll("circle")
          .data(seriesData)
          .enter()
          .append("circle")
          .attr("cx", (p) => p.x)
          .attr("cy", (p) => p.y)
          .attr("r", 2)
          .attr("fill", (p) =>
            getSeriesColor(p.group[0], seriesIndex, series)
          );
      });

      // X Axis
      const xAxisTicks: { value: number; x: number }[] = xAxisScale
        .ticks(axis?.x?.tickCount ?? 10)
        .map((t) => ({
          value: t,
          x: xAxisScale(t),
        }));
      const xPlotWidth = width - margin.left - margin.right;

      svg
        .select(".x-axis")
        .append("line")
        .attr("x1", margin.left)
        .attr("x2", margin.left + xPlotWidth)
        .attr("y1", height)
        .attr("y2", height)
        .style("stroke", axis?.x?.color ?? "#000000")
        .style("stroke-width", 1);
      svg
        .select(".x-axis")
        .selectAll("text")
        .data(xAxisTicks)
        .enter()
        .append("text")
        .attr("text-anchor", "middle")
        .style("font-size", "12px")
        .attr("x", (p) => p.x)
        .attr("y", height + 20)
        .text((p) => p.value.toFixed(axis?.x?.tickPrecision ?? 2));

      // Y Axis
      const yAxisTicks: { value: number; y: number }[] = yAxisScale
        .ticks(axis?.y?.tickCount ?? 10)
        .map((t) => ({
          value: t,
          y: yAxisScale(t),
        }));

      svg
        .select(".y-axis")
        .append("line")
        .attr("x1", margin.left)
        .attr("x2", margin.left)
        .attr("y1", margin.top)
        .attr("y2", height)
        .style("stroke", axis?.y?.color ?? "#000000")
        .style("stroke-width", 1);
      svg
        .select(".y-axis")
        .selectAll("text")
        .data(yAxisTicks)
        .enter()
        .append("text")
        .attr("text-anchor", "end")
        .style("font-size", "12px")
        .attr("x", margin.left - 5)
        .attr("y", (p) => p.y + 5)
        .text((p) => p.value.toFixed(axis?.y?.tickPrecision ?? 2));

      // Axis Labels
      if (axis)
        addAxisLabels(
          svg,
          canvas,
          axis,
          margin,
          width,
          height,
          xAxisLabelHeight
        );

      // Bands (positioned against the niced y domain used by the axis)
      if (bands?.length)
        addChartBands(
          svg,
          bands,
          margin,
          yDomainMin,
          yDomainMax,
          width,
          height,
          "Linear"
        );

      // Guides (positioned against the niced y domain used by the axis)
      if (guides?.length)
        addChartGuides(
          svg,
          guides,
          margin,
          yDomainMin,
          yDomainMax,
          width,
          height,
          "Linear"
        );

      // Guide and Band Legend
      if (guides?.length || bands?.length)
        addBandAndGuideLegend(
          svg,
          bandAndGuideLegend.labelData,
          bandAndGuideLegend.totalHeight
        );

      // Grid: skip the first tick on each axis, which coincides with the
      // axis line itself; offsets are relative to the plot's top-left corner.
      if (showGrid)
        addChartGrid(
          svg,
          margin,
          xAxisTicks
            .filter((_, i) => i > 0)
            .map((v) => ({ x: v.x - margin.left })),
          yAxisTicks
            .filter((_, i) => i > 0)
            .map((v) => ({ y: v.y - margin.top })),
          height,
          width,
          styles?.gridColor
        );

      // Series Legend (only meaningful with more than one series)
      if (data.length > 1)
        addSeriesLegend(
          svg,
          series,
          data,
          width,
          bandAndGuideLegend.totalHeight
        );

      canvas.remove();
    },
    // Every prop read by the render callback is listed so the chart is
    // redrawn when any of them changes, not only when `data` does.
    [
      data,
      axis,
      bands,
      guides,
      series,
      showGrid,
      styles,
      getPointXValue,
      getPointYValue,
    ]
  );

  // NOTE(review): the key changes on every render, so React mounts a fresh
  // <svg> each time. Presumably this discards elements appended by the
  // previous D3 pass; confirm against the useD3 implementation.
  return (
    <svg
      key={new Date().getTime()}
      ref={ref}
      className="h-app-chart-scatter"
      viewBox={`0 0 ${Constants.renderWidth} ${Constants.renderHeight}`}
      height={Constants.renderHeight}
      width="100%"
      preserveAspectRatio="xMidYMid meet"
    >
      <Defaults />
      <g className="bands" />
      <g className="grid" />
      <g className="guides" />
      <g className="plot-area" />
      <g className="x-axis" />
      <g className="y-axis" />
      <g className="series" />
    </svg>
  );
}
