import { quantile, scaleLog, scaleLinear } from "d3";
import {
  addAxisLabels,
  addBandAndGuideLegend,
  addChartBands,
  addChartGrid,
  addChartGuides,
  calculateBandAndGuideRows,
  calculatePlotHeight,
  calculateRightMargin,
  calculateRotatedLabelHeight,
  measureTextWidth,
} from "./utilities";
import { AxisSettings, Guide, Band, Series } from "./interfaces";
import { useD3 } from "../../hooks";
import Constants from "./Constants";
import React from "react";
import Defaults from "./Defaults";

/**
 * Whisker placement rule:
 * - `"MinMax"`: whiskers reach the smallest and largest value of the series.
 * - `"IQR"`: whiskers are bounded by Tukey's 1.5×IQR fences; values outside
 *   the fences are drawn as outliers.
 */
type WhiskerType = "MinMax" | "IQR";

/** Optional colour overrides for the box-and-whisker chart. */
interface BoxAndWhiskerChartStyles {
  /** Grid line colour, forwarded to the shared grid helper. */
  gridColor?: string;
  /** Box fill colour. */
  barColor?: string;
  /** Fill for outliers above the upper whisker. */
  outlierTopColor?: string;
  /** Fill for outliers below the lower whisker. */
  outlierBottomColor?: string;
  /** Stroke for boxes, medians and whiskers; fallback colour for outliers. */
  barStrokeColor?: string;
  /** Fill of the latest-value marker. */
  latestValueColor?: string;
}

/** Props accepted by the box-and-whisker chart component. */
interface BoxAndWhiskerChartProps<TData> {
  /** One array of raw points per series; each series is drawn as one box. */
  data: TData[][];
  /** Extracts the plotted number from a raw point. */
  getPointValue: (point: TData) => number;
  /** Draw the background grid when true. */
  showGrid?: boolean;
  /** Whisker placement rule; the component falls back to `"MinMax"`. */
  whiskerType?: WhiskerType;
  /**
   * Toggle for the latest-value marker (a circle at each series' last value).
   * NOTE(review): confirm the component honours this flag.
   */
  showLatestValue?: boolean;
  /** Axis configuration (labels, colours, scale, bounds, ticks). */
  axis?: AxisSettings;
  /** Horizontal guide lines. */
  guides?: Guide[];
  /** Shaded horizontal bands. */
  bands?: Band[];
  /** Series options such as `getSeriesName`. */
  series?: Series<TData>;
  /** Colour overrides. */
  styles?: BoxAndWhiskerChartStyles;
}

/**
 * Computes box-plot statistics for every series, plus the combined value
 * extent and the widest series label.
 *
 * Whisker placement depends on `whiskerType`:
 * - `"MinMax"`: whiskers end at the smallest and largest value, so no value
 *   is an outlier.
 * - `"IQR"`: Tukey fences at Q1 - 1.5·IQR and Q3 + 1.5·IQR. Whiskers end at
 *   the most extreme values lying inside the fences (never beyond the data).
 *   Values outside the fences are reported as outliers.
 *
 * @param data - One array of raw points per series.
 * @param getPointValue - Maps a raw point to its plotted number.
 * @param series - Series options; `getSeriesName` supplies labels, otherwise
 *   "Series N" (1-based) is used.
 * @param whiskerType - Whisker placement rule (see above).
 * @param canvas - Scratch canvas passed to `measureTextWidth` for labels.
 * @returns `seriesData` (per-series stats), `min`/`max` (extent over every
 *   value and whisker end, always including 0) and `maxLabelWidth`.
 */
function extractSeriesData<TData>(
  data: TData[][],
  getPointValue: (point: TData) => number,
  series: Series<TData>,
  whiskerType: WhiskerType,
  canvas: HTMLCanvasElement
) {
  // The extent starts at [0, 0], so the plotted range always includes zero.
  let plotMin = 0;
  let plotMax = 0;
  let maxLabelWidth = 0;

  const seriesData = data.map((s, i) => {
    const values = s.map((p) => getPointValue(p));
    const q0 = quantile(values, 0) ?? 0;
    const q1 = quantile(values, 0.25) ?? 0;
    const q2 = quantile(values, 0.5) ?? 0;
    const q3 = quantile(values, 0.75) ?? 0;
    const q4 = quantile(values, 1) ?? 0;
    const iqr = q3 - q1;

    let min = q0;
    let max = q4;
    if (whiskerType === "IQR") {
      const lowerFence = q1 - 1.5 * iqr;
      const upperFence = q3 + 1.5 * iqr;
      // Tukey whiskers stop at the last observation inside each fence rather
      // than at the fence itself, so they never extend past the data.
      let lowest = Infinity;
      let highest = -Infinity;
      for (const v of values) {
        if (v >= lowerFence && v <= upperFence) {
          if (v < lowest) lowest = v;
          if (v > highest) highest = v;
        }
      }
      // Fall back to the fences only when no value qualifies (e.g. an empty series).
      const hasInside = lowest <= highest;
      min = hasInside ? lowest : lowerFence;
      max = hasInside ? highest : upperFence;
    }

    // Everything beyond the whisker ends is an outlier (none in "MinMax" mode).
    const outliers = values.filter((v) => v < min || v > max);

    // Widen the overall extent to cover every value and both whisker ends.
    for (const v of values) {
      if (v < plotMin) plotMin = v;
      if (v > plotMax) plotMax = v;
    }
    if (min < plotMin) plotMin = min;
    if (max > plotMax) plotMax = max;

    const label = series.getSeriesName?.(s[0], i) ?? `Series ${i + 1}`;
    const labelWidth = measureTextWidth(label, canvas);
    if (labelWidth > maxLabelWidth) maxLabelWidth = labelWidth;

    return {
      label,
      labelWidth,
      min,
      max,
      median: q2,
      firstQuartile: q1,
      thirdQuartile: q3,
      iqr,
      values,
      outliers,
    };
  });

  // Read the accumulators only after every series has been processed.
  return {
    seriesData,
    min: plotMin,
    max: plotMax,
    maxLabelWidth,
  };
}

/**
 * Renders one box-and-whisker plot per entry in `data` into an SVG using D3.
 *
 * Each series gets:
 * - a box spanning Q1–Q3;
 * - a median line;
 * - whiskers with end caps;
 * - outlier dots;
 * - optionally, a marker at its last value (`showLatestValue`, default true).
 *
 * Axes, axis labels, bands, guides, their legend and the grid are drawn with
 * the shared chart utilities.
 *
 * NOTE(review): the render callback only appends elements and is re-run only
 * when `data` changes. Confirm that `useD3` clears the SVG between runs;
 * otherwise a data change leaves stale marks behind.
 */
export default function BoxAndWhiskerChart<TData>({
  data,
  getPointValue,
  showGrid,
  whiskerType,
  showLatestValue = true,
  axis,
  styles,
  series,
  guides,
  bands,
}: BoxAndWhiskerChartProps<TData>) {
  const ref = useD3(
    (svg) => {
      const {
        plotWidth,
        axisLabelHeight,
        axisLabelMarginBottom,
        axisLabelMarginTop,
      } = Constants;
      const canvas = document.createElement("canvas");
      const width = plotWidth;
      const margin = {
        top: 40,
        right: calculateRightMargin([], guides, series),
        left: 80,
      };
      const extracted = extractSeriesData(
        data,
        getPointValue,
        series ?? {},
        whiskerType ?? "MinMax",
        canvas
      );
      const { seriesData, maxLabelWidth } = extracted;
      let { min, max } = extracted;
      const barWidth = (width - margin.left - margin.right) / seriesData.length;
      const barGutter = barWidth / 2;
      const strokeWidth = 1;
      const defaultStrokeColor = styles?.barStrokeColor ?? "rgb(50,80,150)";
      const isLogScale = axis?.y?.scale === "Log";
      const rotate = maxLabelWidth > barWidth;

      // The data extent is anchored at 0, which scaleLog cannot map
      // (log(0) = -Infinity turns every mark into NaN). On a log axis, start
      // the domain at the smallest positive value drawn instead.
      if (isLogScale && min <= 0) {
        let smallestPositive = Infinity;
        for (const s of seriesData) {
          if (s.min > 0 && s.min < smallestPositive) smallestPositive = s.min;
          for (const v of s.values) {
            if (v > 0 && v < smallestPositive) smallestPositive = v;
          }
        }
        min = Number.isFinite(smallestPositive) ? smallestPositive : 1;
      }

      // Explicit axis bounds may only widen the domain. Compare against null
      // rather than truthiness so that 0 is treated as a real bound. Non-positive
      // lower bounds are ignored on a log axis.
      const yMinOverride = axis?.y?.min;
      const yMaxOverride = axis?.y?.max;
      if (
        yMinOverride != null &&
        yMinOverride < min &&
        (!isLogScale || yMinOverride > 0)
      )
        min = yMinOverride;
      if (yMaxOverride != null && yMaxOverride > max) max = yMaxOverride;

      const xAxisLabelHeight = calculateRotatedLabelHeight(
        maxLabelWidth,
        rotate ? 35 : 0
      );
      const bandAndGuideLegend = calculateBandAndGuideRows(
        guides ?? [],
        bands ?? [],
        canvas
      );
      let height = calculatePlotHeight(
        data,
        series,
        xAxisLabelHeight,
        bandAndGuideLegend.totalHeight,
        canvas
      );

      if (axis?.x?.label)
        height -= axisLabelHeight + axisLabelMarginBottom + axisLabelMarginTop;

      const yAxisScale = (isLogScale ? scaleLog() : scaleLinear())
        .domain([min, max])
        .range([height, margin.top])
        .nice();

      const plotGroups = svg
        .select(".plot-area")
        .selectAll("g")
        .data(seriesData)
        .enter();

      // Boxes
      plotGroups
        .append("polygon")
        .attr("points", (p, i) => {
          const x1 = margin.left + i * barWidth + barGutter / 2;
          const x2 = x1 + barWidth - barGutter;
          const y1 = yAxisScale(p.firstQuartile);
          const y2 = yAxisScale(p.thirdQuartile);

          return `${x1},${y1} ${x2},${y1} ${x2},${y2} ${x1},${y2}`;
        })
        .attr("stroke", defaultStrokeColor)
        .attr("stroke-width", strokeWidth)
        .attr("fill", styles?.barColor ?? "rgb(80,180,255)");

      // Medians
      plotGroups
        .append("line")
        .attr("x1", (_, i) => margin.left + i * barWidth + barGutter / 2)
        .attr("x2", (_, i) => margin.left + (i + 1) * barWidth - barGutter / 2)
        .attr("y1", (p) => yAxisScale(p.median))
        .attr("y2", (p) => yAxisScale(p.median))
        .attr("stroke", defaultStrokeColor)
        .attr("stroke-width", strokeWidth);

      // Whiskers: upper stem, upper cap, lower stem, lower cap
      plotGroups
        .append("line")
        .attr(
          "x1",
          (_, i) => margin.left + i * barWidth + barWidth / 2 - strokeWidth / 2
        )
        .attr(
          "x2",
          (_, i) => margin.left + i * barWidth + barWidth / 2 - strokeWidth / 2
        )
        .attr("y1", (p) => yAxisScale(p.max))
        .attr("y2", (p) => yAxisScale(p.thirdQuartile))
        .attr("stroke-width", strokeWidth)
        .attr("stroke", defaultStrokeColor);
      plotGroups
        .append("line")
        .attr("x1", (_, i) => margin.left + i * barWidth + barGutter / 2)
        .attr("x2", (_, i) => margin.left + (i + 1) * barWidth - barGutter / 2)
        .attr("y1", (p) => yAxisScale(p.max))
        .attr("y2", (p) => yAxisScale(p.max))
        .attr("stroke-width", strokeWidth)
        .attr("stroke", defaultStrokeColor);

      plotGroups
        .append("line")
        .attr(
          "x1",
          (_, i) => margin.left + i * barWidth + barWidth / 2 - strokeWidth / 2
        )
        .attr(
          "x2",
          (_, i) => margin.left + i * barWidth + barWidth / 2 - strokeWidth / 2
        )
        .attr("y1", (p) => yAxisScale(p.firstQuartile))
        .attr("y2", (p) => yAxisScale(p.min))
        .attr("stroke-width", strokeWidth)
        .attr("stroke", defaultStrokeColor);
      plotGroups
        .append("line")
        .attr("x1", (_, i) => margin.left + i * barWidth + barGutter / 2)
        .attr("x2", (_, i) => margin.left + (i + 1) * barWidth - barGutter / 2)
        .attr("y1", (p) => yAxisScale(p.min))
        .attr("y2", (p) => yAxisScale(p.min))
        .attr("stroke-width", strokeWidth)
        .attr("stroke", defaultStrokeColor);

      // Outliers: flatten per-series outliers into one list (linear, no
      // re-spreading of an accumulator), coloured by which side they fall on.
      const outlierData: { value: number; index: number; color: string }[] = [];
      seriesData.forEach((entry, seriesIndex) => {
        for (const value of entry.outliers) {
          outlierData.push({
            value,
            index: seriesIndex,
            color:
              (value > entry.max
                ? styles?.outlierTopColor
                : styles?.outlierBottomColor) ?? defaultStrokeColor,
          });
        }
      });

      svg
        .select(".plot-area")
        .selectAll("circle")
        .data(outlierData)
        .enter()
        .append("circle")
        .attr(
          "cx",
          (p) => margin.left + p.index * barWidth + barWidth / 2 - 0.5
        )
        .attr("cy", (p) => yAxisScale(p.value))
        .attr("r", 3)
        .attr("fill", (p) => p.color);

      // Latest Values (the last point of each series), unless disabled
      if (showLatestValue)
        plotGroups
          .append("circle")
          .attr("cx", (_, i) => margin.left + i * barWidth + barWidth / 2 - 0.5)
          .attr("cy", (p) => yAxisScale(p.values[p.values.length - 1]))
          .attr("r", 5)
          .attr("fill", styles?.latestValueColor ?? "rgb(255, 50, 50)")
          .attr("stroke", defaultStrokeColor)
          .attr("stroke-width", 1);

      // X Axis
      svg
        .select(".x-axis")
        .append("line")
        .attr("x1", margin.left)
        .attr("x2", width - margin.right)
        .attr("y1", height)
        .attr("y2", height)
        .style("stroke", axis?.x?.color ?? "#000000")
        .style("stroke-width", 1);

      const xAxisLabels = svg
        .select(".x-axis")
        .selectAll("g")
        .data(seriesData)
        .enter()
        .append("g");
      xAxisLabels
        .append("text")
        .attr("height", margin.top)
        .attr("text-anchor", rotate ? "end" : "middle")
        .attr("x", (_, i) => margin.left + i * barWidth + barWidth / 2)
        .attr("y", height + 20)
        .style("font-size", "12px")
        .style("transform", rotate ? "rotate(-35deg)" : "")
        .style("transform-box", "fill-box")
        .style("transform-origin", "right center")
        .text((p) => p.label);

      // Y Axis: tickCount applies to every non-log axis, including the
      // default (unset) linear scale.
      const yAxisTicks: { value: number; y: number }[] = yAxisScale
        .ticks(isLogScale ? undefined : axis?.y?.tickCount ?? 10)
        .map((t) => ({
          value: t,
          y: yAxisScale(t) - margin.top,
        }));
      // Fall back to the domain bounds if the scale produced no ticks.
      const firstTickValue = yAxisTicks[0]?.value ?? min;
      const lastTickValue = yAxisTicks[yAxisTicks.length - 1]?.value ?? max;

      svg
        .select(".y-axis")
        .append("line")
        .attr("x1", margin.left)
        .attr("x2", margin.left)
        .attr("y1", margin.top)
        .attr("y2", height)
        .style("stroke", axis?.y?.color ?? "#000000")
        .style("stroke-width", 1);

      svg
        .select(".y-axis")
        .selectAll("text")
        .data(yAxisTicks)
        .enter()
        .append("text")
        .attr("text-anchor", "end")
        .style("font-size", "12px")
        .attr("x", margin.left - 5)
        .attr("y", (p) => p.y + margin.top + 5)
        .text((p) => p.value.toFixed(axis?.y?.tickPrecision ?? 2));

      // Axis Labels
      if (axis)
        addAxisLabels(
          svg,
          canvas,
          axis,
          margin,
          width,
          height,
          xAxisLabelHeight
        );

      // NOTE(review): bands and guides are always positioned with a "Linear"
      // scale, even when the y axis is "Log". Confirm whether the helpers
      // should receive the axis scale instead.

      // Bands
      if (bands?.length)
        addChartBands(
          svg,
          bands,
          margin,
          firstTickValue,
          lastTickValue,
          width,
          height,
          "Linear"
        );

      // Guides
      if (guides?.length)
        addChartGuides(
          svg,
          guides,
          margin,
          firstTickValue,
          lastTickValue,
          width,
          height,
          "Linear"
        );

      // Guide and Band Legend
      if (guides?.length || bands?.length)
        addBandAndGuideLegend(
          svg,
          bandAndGuideLegend.labelData,
          bandAndGuideLegend.totalHeight
        );

      // Grid (the first tick is skipped)
      if (showGrid)
        addChartGrid(
          svg,
          margin,
          seriesData.map((_, i) => ({ x: barWidth * (i + 1) })),
          yAxisTicks.filter((_, i) => i > 0),
          height,
          width,
          styles?.gridColor
        );

      canvas.remove();
    },
    [data]
  );

  return (
    <svg
      ref={ref}
      className="h-app-chart-box-and-whisker"
      viewBox={`0 0 ${Constants.renderWidth} ${Constants.renderHeight}`}
      height={Constants.renderHeight}
      width="100%"
      preserveAspectRatio="xMidYMid meet"
    >
      <Defaults />
      <g className="bands" />
      <g className="grid" />
      <g className="plot-area" />
      <g className="guides" />
      <g className="x-axis" />
      <g className="y-axis" />
      <g className="series" />
    </svg>
  );
}
