import React from "react";
import { useD3 } from "../../hooks";
import {
  addAxisLabels,
  addBandAndGuideLegend,
  addChartBands,
  addChartGrid,
  addChartGuides,
  calculateBandAndGuideRows,
  calculatePlotHeight,
  calculateRightMargin,
  calculateRotatedLabelHeight,
  getSeriesColor,
  measureTextWidth,
} from "./utilities";
import { AxisSettings, Band, Guide, Series } from "./interfaces";
import { Selection, max as d3max } from "d3";
import { addSeriesLegend } from "./utilities";
import Constants from "./Constants";
import Defaults from "./Defaults";

/**
 * Props for the `BarChart` component.
 *
 * @typeParam TData - Shape of a single data point; the chart only reads it
 *   through the accessor callbacks below.
 */
interface BarChartProps<TData> {
  /**
   * One array of points per series. The longest series determines the number
   * of category slots on the x-axis; more than one series adds a series legend.
   */
  data: TData[][];
  /** Numeric value of a point; drives bar height, the scale maximum and stacked totals. */
  getBarValue: (point: TData) => number;
  /** Text drawn above each bar when `showBarValues` is set and the chart is not stacked. */
  formatBarValue: (point: TData) => string;
  /** Category label drawn on the x-axis (taken from the longest series' points). */
  getBarLabel: (point: TData) => string;
  /**
   * Draw value labels above bars: one per bar when side by side, or one per
   * category total when `stacked` (totals are formatted with
   * `axis.y.tickPrecision`, default 2).
   */
  showBarValues?: boolean;
  /** Draw grid lines via `addChartGrid`. */
  showGrid?: boolean;
  /** Stack series on top of each other instead of placing them side by side. */
  stacked?: boolean;
  /**
   * Axis configuration. Read here: `x.label`, `x.color`, `y.max` (starting
   * value for the scale maximum; larger values or stack totals extend it),
   * `y.tickCount` (default 10), `y.tickPrecision` (default 2), `y.color`.
   * Also passed to `addAxisLabels`.
   */
  axis?: AxisSettings;
  /** Guides passed to `addChartGuides` and listed in the band/guide legend. */
  guides?: Guide[];
  /** Bands passed to `addChartBands` and listed in the band/guide legend. */
  bands?: Band[];
  /** Per-series settings passed to `getSeriesColor`, the series legend and layout helpers. */
  series?: Series<TData>;
  styles?: {
    /** Grid line colour forwarded to `addChartGrid`. */
    gridColor?: string;
    /** Bar outline colour; when set, bars get a 1px stroke. */
    barStrokeColor?: string;
  };
}

/**
 * Appends value labels above a set of bars.
 *
 * Each datum gets a `<g>` containing a transparent `<rect>` that covers the
 * `margin.top` strip above the bar, plus a 12px `<text>` showing
 * `formattedValue`. When the widest measured label is wider than the drawable
 * bar (`barWidth - barGutter`), every label from this call is rotated -90deg
 * to read vertically; otherwise the labels are centred horizontally.
 *
 * @param svg - Group the label groups are appended to.
 * @param canvas - Scratch canvas passed to `measureTextWidth`.
 * @param widthDivisor - How many bars share one category slot.
 * @param margin - Only `top` is read: the height of the transparent rect.
 * @param barWidth - Width available to one bar before the gutter is removed.
 * @param barGutter - Horizontal padding subtracted from each bar.
 * @param data - Per-label slot width, left x, bar top y and display text.
 */
function addBarLabels(
  svg: Selection<SVGGElement, unknown, null, undefined>,
  canvas: HTMLCanvasElement,
  widthDivisor: number,
  margin: { top: number },
  barWidth: number,
  barGutter: number,
  data: {
    width: number;
    x: number;
    y: number;
    formattedValue: string;
  }[]
) {
  type LabelDatum = (typeof data)[number];

  const widestLabel =
    d3max(data.map((item) => measureTextWidth(item.formattedValue, canvas))) ??
    0;
  // One decision for the whole set so labels stay uniformly oriented.
  const rotate = widestLabel > barWidth - barGutter;
  const slotWidth = (item: LabelDatum) => item.width / widthDivisor;

  // Data labels
  const groups = svg.selectAll("text").data(data).enter().append("g");

  // Invisible strip above each bar.
  groups
    .append("rect")
    .attr("fill", "transparent")
    .attr("height", margin.top)
    .attr("width", slotWidth)
    .attr("x", (item) => item.x)
    .attr("y", (item) => item.y - margin.top);

  // Rotation pivots on the text's own left edge (fill-box) so vertical
  // labels rise straight up from the bar top.
  groups
    .append("text")
    .attr("height", margin.top)
    .attr("text-anchor", rotate ? "start" : "middle")
    .attr("x", (item) => item.x + slotWidth(item) / 2)
    .attr("y", (item) => item.y - 4)
    .style("font-size", "12px")
    .style("transform", rotate ? "rotate(-90deg)" : "")
    .style("transform-box", rotate ? "fill-box" : "")
    .style("transform-origin", rotate ? "left center" : "")
    .text((item) => item.formattedValue);
}

/** A single rendered bar: its source value plus computed layout in SVG units. */
interface BarDatum {
  /** Raw value from `getBarValue`. */
  value: number;
  /** Text from `formatBarValue`, shown above the bar when enabled. */
  formattedValue: string;
  /** Category label from `getBarLabel`. */
  label: string;
  /** Measured label width plus padding; decides x-axis label rotation. */
  labelWidth: number;
  /** Bar height; filled in once the plot height and scale maximum are known. */
  height: number;
  /** Width of the whole category slot (split between series when side by side). */
  width: number;
  /** Left edge of the bar's slot. */
  x: number;
  /** Top edge of the bar. */
  y: number;
}

/**
 * Index of the series with the most points; the first one wins ties.
 * Returns 0 for an empty `data` array.
 */
function findLargestSeriesIndex<TData>(data: TData[][]): number {
  let largest = 0;
  data.forEach((points, index) => {
    if (points.length > data[largest].length) largest = index;
  });
  return largest;
}

/**
 * Evenly spaced y-axis ticks from 0 to `max` (`tickCount + 1` entries).
 * `y` is an offset within the plot area: `plotHeight` for 0, and 0 for `max`.
 */
function buildYAxisTicks(
  max: number,
  tickCount: number,
  plotHeight: number
): { value: number; y: number }[] {
  const increment = max / tickCount;
  const ticks: { value: number; y: number }[] = [];
  for (let i = 0; i <= tickCount; i++)
    ticks.push({
      value: i * increment,
      y: plotHeight - (i / tickCount) * plotHeight,
    });
  return ticks;
}

/**
 * Bar chart rendered with d3 into a fixed-viewBox SVG.
 *
 * Series are drawn side by side by default, or stacked when `stacked` is set.
 * The value scale starts at `axis.y.max ?? 0` and grows to fit the largest
 * value (or the largest category total when stacked). Axes, bands, guides,
 * grid and legends are delegated to the shared chart utilities.
 */
export default function BarChart<TData>({
  data,
  getBarValue,
  formatBarValue,
  getBarLabel,
  showBarValues,
  showGrid,
  stacked,
  axis,
  styles,
  series,
  guides,
  bands,
}: BarChartProps<TData>) {
  const ref = useD3(
    (svg) => {
      // Nothing to plot; without this guard data[0].length below throws.
      if (data.length === 0) return;

      const {
        plotWidth,
        axisLabelHeight,
        axisLabelMarginBottom,
        axisLabelMarginTop,
      } = Constants;
      const canvas = document.createElement("canvas");
      // Top of the value scale: starts at axis.y.max and grows to fit the data.
      let max = axis?.y?.max ?? 0;
      let maxLabelWidth = 0;
      const largestSeriesIndex = findLargestSeriesIndex(data);
      const categoryCount = data[largestSeriesIndex].length;

      const labelMargin = 10;
      const width = plotWidth;
      const margin = {
        top: 40,
        right: calculateRightMargin(data, guides, series),
        left: 80,
      };
      const barGutter = 10;
      const itemSums = new Array<number>(categoryCount).fill(0);
      // Side-by-side bars split each category slot between the series.
      const widthDivisor = stacked ? 1 : data.length;
      const barWidth =
        (1 / categoryCount) * (width - margin.left - margin.right);

      const graphData: BarDatum[][] = data.map((points) =>
        points.map((p, i) => {
          const v = getBarValue(p);
          const label = getBarLabel(p);
          const labelWidth = measureTextWidth(label, canvas) + labelMargin;

          if (v > max) max = v;
          if (labelWidth > maxLabelWidth) maxLabelWidth = labelWidth;

          itemSums[i] += v;

          // When stacked, the scale must fit the whole category total.
          if (stacked && itemSums[i] > max) max = itemSums[i];

          return {
            value: v,
            formattedValue: formatBarValue(p),
            label,
            labelWidth,
            height: 0,
            width: barWidth,
            x: margin.left + i * barWidth,
            y: 0,
          };
        })
      );

      // All-zero data with no axis.y.max would otherwise scale every bar
      // height and tick value by 0 / 0 = NaN.
      if (max === 0) max = 1;

      const xAxisLabelHeight = calculateRotatedLabelHeight(
        maxLabelWidth,
        maxLabelWidth > barWidth ? 35 : 0
      );
      const bandAndGuideLegend = calculateBandAndGuideRows(
        guides ?? [],
        bands ?? [],
        canvas
      );
      let height = calculatePlotHeight(
        data,
        series,
        xAxisLabelHeight,
        bandAndGuideLegend.totalHeight,
        canvas
      );

      if (axis?.x?.label)
        height -= axisLabelHeight + axisLabelMarginBottom + axisLabelMarginTop;

      // `height` is the y of the x-axis line, i.e. the baseline bars grow from.
      // stackTops: running top of each category's stack (stacked mode).
      // labelYs: highest bar top per category, where stacked totals are drawn.
      const stackTops = new Array<number>(categoryCount).fill(height);
      const labelYs = new Array<number>(categoryCount).fill(height);

      graphData.forEach((seriesData, seriesIndex) => {
        seriesData.forEach((p, i) => {
          p.height = (p.value / max) * (height - margin.top);

          if (stacked) {
            // Stack on everything already drawn in this category, even when
            // an earlier series is too short to have a point at index i.
            p.y = stackTops[i] - p.height;
            stackTops[i] = p.y;
          } else {
            // Shift into this series' sub-slot within the category.
            p.x += (p.width / data.length) * seriesIndex;
            p.y = height - p.height;
          }

          if (p.y < labelYs[i]) labelYs[i] = p.y;
        });

        const seriesGroup = svg.select(".plot-area").append("g");

        // Data bars
        seriesGroup
          .selectAll("rect")
          .data(seriesData)
          .enter()
          .append("rect")
          .attr("x", (p) => p.x + barGutter / 2)
          .attr("y", (p) => p.y)
          .attr("width", (p) => p.width / widthDivisor - barGutter)
          .attr("height", (p) => p.height)
          .attr("fill", (_, i) =>
            getSeriesColor(data[seriesIndex][i], seriesIndex, series)
          )
          .attr("stroke", styles?.barStrokeColor ?? "")
          .attr("stroke-width", styles?.barStrokeColor ? 1 : 0);

        if (!stacked && showBarValues)
          addBarLabels(
            seriesGroup,
            canvas,
            widthDivisor,
            margin,
            // Every bar's width is barWidth; reading seriesData[0] would
            // throw for an empty series.
            barWidth / widthDivisor,
            barGutter,
            seriesData
          );
      });

      if (stacked && showBarValues) {
        const seriesGroup = svg.select(".plot-area").append("g");
        const labelData = itemSums.map((sum, i) => ({
          width: barWidth,
          x: margin.left + i * barWidth,
          y: labelYs[i],
          formattedValue: sum.toFixed(axis?.y?.tickPrecision ?? 2),
        }));

        addBarLabels(
          seriesGroup,
          canvas,
          widthDivisor,
          margin,
          barWidth / widthDivisor,
          barGutter,
          labelData
        );
      }

      // X Axis
      svg
        .select(".x-axis")
        .append("line")
        .attr("x1", margin.left)
        .attr("x2", width - margin.right)
        .attr("y1", height)
        .attr("y2", height)
        .style("stroke", axis?.x?.color ?? "#000000")
        .style("stroke-width", 1);

      const xAxisLabels = svg
        .select(".x-axis")
        .selectAll("g")
        .data(graphData[largestSeriesIndex])
        .enter()
        .append("g");
      xAxisLabels
        .append("text")
        .attr("height", margin.top)
        .attr("text-anchor", (p) =>
          maxLabelWidth > p.width ? "end" : "middle"
        )
        .attr("x", (p) => p.x + p.width / 2 + 5)
        .attr("y", height + 20)
        .style("font-size", "12px")
        .style("transform", (p) =>
          maxLabelWidth > p.width ? "rotate(-35deg)" : ""
        )
        .style("transform-box", "fill-box")
        .style("transform-origin", "right center")
        .text((p) => p.label);

      // Y Axis
      // Clamp to at least one interval so the tick spacing never divides by 0.
      const tickCount = Math.max(1, axis?.y?.tickCount ?? 10);
      const plotHeight = height - margin.top;
      const yAxisTicks = buildYAxisTicks(max, tickCount, plotHeight);

      svg
        .select(".y-axis")
        .append("line")
        .attr("x1", margin.left)
        .attr("x2", margin.left)
        .attr("y1", margin.top)
        .attr("y2", height)
        .style("stroke", axis?.y?.color ?? "#000000")
        .style("stroke-width", 1);

      svg
        .select(".y-axis")
        .selectAll("text")
        .data(yAxisTicks)
        .enter()
        .append("text")
        .attr("text-anchor", "end")
        .style("font-size", "12px")
        .attr("x", margin.left - 5)
        .attr("y", (p) => p.y + margin.top + 5)
        .text((p) => p.value.toFixed(axis?.y?.tickPrecision ?? 2));

      // Axis Labels
      if (axis)
        addAxisLabels(
          svg,
          canvas,
          axis,
          margin,
          width,
          height,
          xAxisLabelHeight
        );

      // Bands
      if (bands?.length)
        addChartBands(svg, bands, margin, 0, max, width, height, "Linear");

      // Guides
      if (guides?.length)
        addChartGuides(svg, guides, margin, 0, max, width, height, "Linear");

      // Guide and Band Legend
      if (guides?.length || bands?.length)
        addBandAndGuideLegend(
          svg,
          bandAndGuideLegend.labelData,
          bandAndGuideLegend.totalHeight
        );

      // Grid (pass a copy so the helper cannot mutate our margin)
      if (showGrid)
        addChartGrid(
          svg,
          { ...margin },
          graphData[largestSeriesIndex].map((p) => ({
            x: p.x - margin.left + p.width,
          })),
          yAxisTicks.filter((_, i) => i > 0),
          height,
          width,
          styles?.gridColor
        );

      // Series Legend
      if (data.length > 1)
        addSeriesLegend(
          svg,
          series,
          data,
          width,
          bandAndGuideLegend.totalHeight
        );

      canvas.remove();
    },
    // NOTE(review): only `data` is listed, so changing other props (e.g.
    // `stacked`, `showGrid`, `axis`, `styles`) does not redraw the chart.
    // Whether adding them is safe depends on whether useD3 clears previous
    // output before re-running and on caller prop stability — confirm.
    [data]
  );

  return (
    <svg
      ref={ref}
      className="h-app-chart-bar"
      viewBox={`0 0 ${Constants.renderWidth} ${Constants.renderHeight}`}
      height={Constants.renderHeight}
      width="100%"
      preserveAspectRatio="xMidYMid meet"
    >
      <Defaults />
      <g className="bands" />
      <g className="grid" />
      <g className="plot-area" />
      <g className="guides" />
      <g className="x-axis" />
      <g className="y-axis" />
      <g className="series" />
    </svg>
  );
}
