import React, { useMemo } from "react";
import Constants from "./Constants";
import { useD3 } from "../../hooks";
import { scaleLinear } from "d3";
import { getSeriesColor, measureTextWidth } from "./utilities";
import Defaults from "./Defaults";

/** Optional colour overrides for the Stiff chart. */
interface Styles {
  /** Stroke of the axes and polygon outlines; `#000000` when omitted. */
  axisColor?: string;
  /** Stroke of the dashed element grid lines; `#aaaaaa` when omitted. */
  gridColor?: string;
  /** Fill applied to every polygon; when omitted each series gets `getSeriesColor`. */
  polygonColor?: string;
}

/**
 * Element groups plotted on each side of a Stiff diagram.
 *
 * Each inner array is one row of the diagram. The values of the elements in
 * that row are summed, and the row is labelled with the names joined by " + ".
 */
interface Elements {
  /** Rows drawn to the left of the central axis. */
  cations: Array<string[]>;
  /** Rows drawn to the right of the central axis. */
  anions: Array<string[]>;
}

/** Props for {@link StiffChart}. */
interface StiffChartProps<TData> {
  /** One inner array per diagram; the whole inner array is handed to `getElementValue`. */
  data: TData[][];
  /** Returns the value of a single element for one series (summed per row). */
  getElementValue: (pointGroup: TData[], element: string) => number;
  /** Cation/anion rows shared by every diagram. */
  elements: Elements;
  /** Optional colour overrides. */
  styles?: Styles;
  /** Per-series labelling. */
  series: {
    /** Unit text drawn centred under the X axis; when set, the centre tick label is hidden. */
    unitLabel?: string;
    /** Title from the series' first point; defaults to `Series N`. */
    getSeriesTitle?: (point: TData) => string;
    /** Subtitle from the series' first point; providing it reserves subtitle space. */
    getSeriesSubtitle?: (point: TData) => string;
  };
}

/**
 * Outer padding of the whole chart (`top`/`right`/`bottom`/`left`) plus
 * `chartGutter`: horizontal space reserved inside each diagram cell, half on
 * each side, between the cell edge and the row labels.
 */
const margin = {
  top: 20,
  right: 20,
  bottom: 20,
  left: 20,
  chartGutter: 20,
};

// Vertical layout of one diagram, in SVG user units.
/** Band reserved for the bold series title. */
const mainTitleHeight = 30;
/** Extra band added when a subtitle getter is supplied. */
const subTitleHeight = 20;
/** Band reserved for the X axis and its tick labels. */
const axisHeight = 20;
/** Height allotted to each element row. */
const elementHeight = 40;
/** Vertical gap between consecutive element rows. */
const gutterHeight = 40;

/**
 * Sums every cation/anion row for each series and gathers the statistics
 * needed to size the shared X axis.
 *
 * @param data - One array of points per series; each array is passed whole
 *   to `getElementValue`.
 * @param elements - Cation and anion rows; each row is a list of element
 *   names whose values are added together.
 * @param getElementValue - Value of a single element for one series.
 * @param labelFontSize - Font size (px) used when measuring row labels.
 * @returns `min`/`max` over all row sums (both start at 0, so the range
 *   always contains 0), the widest row-label width, and per series the
 *   `{ elements, value, label }` of every cation and anion row.
 */
function extractSeriesData<TData>(
  data: TData[][],
  elements: Elements,
  getElementValue: (pointGroup: TData[], element: string) => number,
  labelFontSize: number
) {
  // Scratch canvas handed to measureTextWidth for label measurement.
  const measureCanvas = document.createElement("canvas");
  const font = `${labelFontSize}px`;
  let min = 0;
  let max = 0;
  let maxLabelWidth = 0;

  // Sums one row for one series and folds the result into the running extrema.
  function summarizeGroup(group: string[], pointGroup: TData[]) {
    let value = 0;
    for (const element of group) {
      value = value + getElementValue(pointGroup, element);
    }
    const label = group.join(" + ");
    const labelWidth = measureTextWidth(label, measureCanvas, font);

    if (value < min) min = value;
    if (value > max) max = value;
    if (labelWidth > maxLabelWidth) maxLabelWidth = labelWidth;

    return { elements: group, value, label };
  }

  const grouped = data.map((pointGroup) => {
    const cations = elements.cations.map((group) =>
      summarizeGroup(group, pointGroup)
    );
    const anions = elements.anions.map((group) =>
      summarizeGroup(group, pointGroup)
    );
    return { cations, anions };
  });

  // NOTE(review): the canvas is never attached here, so this is presumably a
  // no-op unless measureTextWidth inserts it into the DOM — confirm.
  measureCanvas.remove();

  return { min, max, maxLabelWidth, data: grouped };
}

/**
 * Renders one Stiff diagram per series in `data`, arranged in a grid of at
 * most three columns.
 *
 * Cation rows are plotted to the left of the central vertical axis and anion
 * rows to the right. Each series is drawn as a filled polygon whose vertices
 * are the summed row values. All diagrams share one X scale, computed across
 * every series, so they are directly comparable.
 *
 * React renders only the outer `<svg>` and an empty `.plot-area` group. The
 * diagrams are drawn into that group with d3 from the `useD3` callback.
 */
export default function StiffChart<TData>({
  data,
  elements,
  getElementValue,
  series,
  styles,
}: StiffChartProps<TData>) {
  // Rows per diagram: the longer of the cation and anion lists.
  const elementCount = useMemo(
    () => Math.max(elements.cations.length, elements.anions.length),
    [elements]
  );
  // Height of one diagram: titles, element rows, gaps between rows, X axis.
  const stiffHeight = useMemo(
    () =>
      mainTitleHeight +
      (series.getSeriesSubtitle ? subTitleHeight : 0) +
      axisHeight +
      elementHeight * elementCount +
      // Clamp so an empty element list does not subtract a gutter.
      gutterHeight * Math.max(elementCount - 1, 0),
    [elementCount, series]
  );
  // At most three diagrams per row. With no data there are zero rows rather
  // than 0 / 0 = NaN, which would otherwise leak into the SVG height/viewBox.
  const grid = useMemo(() => {
    const columns = Math.min(data.length, 3);
    const rows = columns === 0 ? 0 : Math.ceil(data.length / columns);

    return { columns, rows };
  }, [data.length]);

  const ref = useD3(
    (svg) => {
      const plotArea = svg.select(".plot-area");
      // Clear the previous drawing first. If this callback re-runs (for
      // example when a dependency changes), new diagrams must not be stacked
      // on top of the old ones.
      plotArea.selectAll("*").remove();

      if (data.length === 0) return;

      const labelFontSize = 10;
      const axisColor = styles?.axisColor ?? "#000000";
      const gridColor = styles?.gridColor ?? "#aaaaaa";
      const plotWidth =
        (Constants.plotWidth - margin.left - margin.right) / grid.columns;
      // Height of a diagram above its tick-label row.
      const plotHeight = stiffHeight - axisHeight;
      const {
        min,
        max,
        maxLabelWidth,
        data: seriesData,
      } = extractSeriesData(data, elements, getElementValue, labelFontSize);
      // One scale shared by every diagram. It maps a value to a distance from
      // the central axis: leftwards for cations, rightwards for anions.
      const xAxisScale = scaleLinear()
        .domain([min, max])
        .range([0, (plotWidth - margin.chartGutter) / 2 - maxLabelWidth])
        .nice();
      const xAxisTicks = xAxisScale.ticks();
      // The anion side skips ticks <= 0, so the centre tick is drawn only
      // once (by the cation side).
      const anionTicks = xAxisTicks.filter((t) => t > 0);
      // Vertical offset of the i-th element row from the first row.
      const rowOffset = (i: number) => elementHeight * i + gutterHeight * i;

      seriesData.forEach((stiffData, index) => {
        const row = Math.floor(index / grid.columns);
        const column = index % grid.columns;
        const baseX = margin.left + column * plotWidth;
        const centerX = baseX + plotWidth / 2;
        // Inner edges of the label columns on the left and right.
        const leftX = baseX + margin.chartGutter / 2 + maxLabelWidth;
        const rightX =
          baseX + plotWidth - margin.chartGutter / 2 - maxLabelWidth;
        // baseY is the tick-label row; the axis line sits one axisHeight above it.
        const baseY = margin.top + stiffHeight * (row + 1) - axisHeight;
        const axisY = baseY - axisHeight;
        const topY = baseY - plotHeight;
        const stiff = plotArea.append("g");

        // X Axis
        stiff
          .append("line")
          .attr("x1", leftX)
          .attr("x2", rightX)
          .attr("y1", axisY)
          .attr("y2", axisY)
          .attr("stroke", axisColor)
          .attr("stroke-width", 1);

        if (series.unitLabel)
          stiff
            .append("text")
            .attr("x", centerX)
            .attr("y", baseY)
            .attr("font-size", labelFontSize)
            .attr("text-anchor", "middle")
            .text(series.unitLabel);

        // Cations X Axis: label the centre tick (unless the unit label sits
        // there) and the outermost tick.
        stiff
          .append("g")
          .selectAll("text")
          .data(xAxisTicks)
          .enter()
          .append("text")
          .attr("x", (t) => centerX - xAxisScale(t))
          .attr("y", baseY)
          .attr("font-size", labelFontSize)
          .attr("text-anchor", "middle")
          .text((t, i) =>
            (i === 0 && !series.unitLabel) || i === xAxisTicks.length - 1
              ? `${t}`
              : ""
          );
        stiff
          .append("g")
          .selectAll("line")
          .data(xAxisTicks)
          .enter()
          .append("line")
          .attr("x1", (t) => centerX - xAxisScale(t) - 0.5)
          .attr("x2", (t) => centerX - xAxisScale(t) - 0.5)
          .attr("y1", axisY)
          .attr("y2", axisY + 5)
          .attr("stroke", axisColor)
          .attr("stroke-width", 1);

        // Anions X Axis: label only the outermost tick. The index is compared
        // against the filtered array, which stays correct even when ticks
        // below zero were dropped.
        stiff
          .append("g")
          .selectAll("text")
          .data(anionTicks)
          .enter()
          .append("text")
          .attr("x", (t) => centerX + xAxisScale(t))
          .attr("y", baseY)
          .attr("font-size", labelFontSize)
          .attr("text-anchor", "middle")
          .text((t, i) => (i === anionTicks.length - 1 ? `${t}` : ""));
        stiff
          .append("g")
          .selectAll("line")
          .data(anionTicks)
          .enter()
          .append("line")
          .attr("x1", (t) => centerX + xAxisScale(t) - 0.5)
          .attr("x2", (t) => centerX + xAxisScale(t) - 0.5)
          .attr("y1", axisY)
          .attr("y2", axisY + 5)
          .attr("stroke", axisColor)
          .attr("stroke-width", 1);

        // Title & Subtitle
        const title =
          series.getSeriesTitle?.(data[index]?.[0]) ?? `Series ${index + 1}`;
        const subTitle = series.getSeriesSubtitle?.(data[index]?.[0]) ?? "";

        stiff
          .append("text")
          .attr("x", centerX)
          .attr("y", topY + (subTitle ? 20 : 15))
          .attr("text-anchor", "middle")
          .attr("font-size", 14)
          .attr("font-weight", "bold")
          .text(title);

        if (subTitle)
          stiff
            .append("text")
            .attr("x", centerX)
            .attr("y", topY + mainTitleHeight + 5)
            .attr("text-anchor", "middle")
            .attr("font-size", 12)
            .text(subTitle);

        const titleHeight = mainTitleHeight + (subTitle ? subTitleHeight : 0);
        // Y of the first element row (just below the titles).
        const axisBaseY = topY + titleHeight;

        // Cations (labels right-aligned just left of the plotting area)
        stiff
          .append("g")
          .selectAll("text")
          .data(elements.cations)
          .enter()
          .append("text")
          .attr("x", leftX - 5)
          .attr("y", (_, i) => axisBaseY + rowOffset(i) - 1)
          .attr("text-anchor", "end")
          .attr("alignment-baseline", "middle")
          .attr("font-size", labelFontSize)
          .text((e) => e.join(" + "));

        // Anions (labels left-aligned just right of the plotting area)
        stiff
          .append("g")
          .selectAll("text")
          .data(elements.anions)
          .enter()
          .append("text")
          .attr("x", rightX + 5)
          .attr("y", (_, i) => axisBaseY + rowOffset(i) - 1)
          .attr("text-anchor", "start")
          .attr("alignment-baseline", "middle")
          .attr("font-size", labelFontSize)
          .text((e) => e.join(" + "));

        // Grid: one dashed line per element row
        stiff
          .append("g")
          .selectAll("line")
          .data(new Array(elementCount))
          .enter()
          .append("line")
          .attr("x1", leftX)
          .attr("x2", rightX)
          .attr("y1", (_, i) => axisBaseY + rowOffset(i))
          .attr("y2", (_, i) => axisBaseY + rowOffset(i))
          .attr("stroke", gridColor)
          .attr("stroke-width", 1)
          .attr("stroke-dasharray", "5,5");

        // Polygon outline: start at the top of the centre axis, go down the
        // cation side, cross the centre at the last cation and anion rows,
        // then climb the anion side. <polygon> closes back to the start.
        const cationMaxIndex = stiffData.cations.length - 1;
        const anionMaxIndex = stiffData.anions.length - 1;
        const polyPoints: { x: number; y: number }[] = [
          { x: centerX, y: axisBaseY },
          ...stiffData.cations.map((e, i) => ({
            x: centerX - xAxisScale(e.value),
            y: axisBaseY + rowOffset(i),
          })),
          { x: centerX, y: axisBaseY + rowOffset(cationMaxIndex) },
          { x: centerX, y: axisBaseY + rowOffset(anionMaxIndex) },
          // Reverse the freshly mapped array (bottom-up order) instead of
          // mutating stiffData.anions with an in-place reverse().
          ...stiffData.anions
            .map((e, i) => ({
              x: centerX + xAxisScale(e.value),
              y: axisBaseY + rowOffset(i),
            }))
            .reverse(),
        ];

        stiff
          .append("polygon")
          .attr("points", polyPoints.map((p) => `${p.x},${p.y}`).join(" "))
          .attr("stroke", axisColor)
          .attr("stroke-width", 1)
          .attr(
            "fill",
            styles?.polygonColor ?? getSeriesColor(data[index][0], index)
          );

        // Y Axis
        stiff
          .append("line")
          .attr("x1", centerX - 0.5)
          .attr("x2", centerX - 0.5)
          .attr("y1", axisY)
          .attr("y2", axisBaseY)
          .attr("stroke", axisColor)
          .attr("stroke-width", 1);
      });
    },
    // Every value the draw callback reads is listed, so the chart redraws
    // when any of them changes. Callers passing inline functions or objects
    // will cause a redraw on each render; clearing .plot-area keeps that correct.
    [
      data,
      elements,
      elementCount,
      getElementValue,
      series,
      styles,
      stiffHeight,
      grid,
    ]
  );

  const height = grid.rows * stiffHeight + margin.top + margin.bottom;

  return (
    <svg
      ref={ref}
      className="h-app-chart-stiff"
      viewBox={`0 0 ${Constants.renderWidth} ${height}`}
      height={height}
      width="100%"
      preserveAspectRatio="xMidYMid meet"
    >
      <Defaults />
      <g className="plot-area" />
    </svg>
  );
}
