import { Tooltip } from '@chakra-ui/react';
import classNames from 'classnames';
import { extent, scaleBand, scaleLinear, scaleOrdinal } from 'd3';
import { clamp } from 'lodash';
import { useState } from 'react';

import { CommonPlotProps, FormatType } from '../../types';
import { getFormatter, plotColors, useResizeObserver } from '../../utilities';
import { AxisLabel } from '../axis/axis-label';
import { BarData } from '../bar/bar-chart';
import { GroupedBarTooltip } from '../grouped-bar/grouped-bar-tooltip';
import { PlotLoadWrapper } from '../plot-loader/plot-loader';
import styles from './grouped-bar-chart.module.css';

/**
 * Layout constants for the grouped bar chart. Sizes are in the same units as
 * the measured container size; "Factor"/"Ratio" entries are fractions.
 */
const Dimensions = {
  /* y-axis label column */
  MaxLabelWidthRatio: 0.1, // preferred label width as a share of total width
  MinLabelWidth: 80, // lower clamp for the label column
  MaxLabelWidth: 120, // upper clamp for the label column

  /* bar area */
  BarLeftPadding: 8, // gap between the label column and the bars
  MinBarHeight: 6, // used to derive a minimum plot height

  /* group band padding (fractions of a band step) */
  InnerSpacingFactor: 0.3, // between adjacent groups
  OuterSpacingFactor: 0.3, // before the first and after the last group
};

/**
 * One category (y-axis band) of a grouped bar chart.
 */
export type GroupedBarData<TData extends object = object> = {
  /** Category label: shown on the y axis and used as the band's key. */
  group: string;
  /** Bars drawn inside this band, top to bottom in array order. */
  values: Array<BarData<TData>>;
};

/**
 * Props for the grouped bar chart. Shared plot options (`className`,
 * `loading`, `format`) come from `CommonPlotProps`.
 */
export type GroupedBarChartProps<TData extends object = object> =
  CommonPlotProps & {
    /** Groups to plot; the first entry is drawn at the top. */
    data: Array<GroupedBarData<TData>>;
    /**
     * Series label -> fill color. Labels missing here fall back to the shared
     * palette; a bar's own `color` takes precedence over both.
     */
    colorMap: Record<string, string>;
    /**
     * Custom tooltip body, given the group label and the hovered bar. When
     * omitted a default tooltip with the formatted value is shown.
     */
    tooltipContent?: (group: string, bar: BarData<TData>) => JSX.Element;
    /** When true, empty bars are dropped before layout so they reserve no slot. */
    hideEmptyBars?: boolean;
  };

/**
 * Horizontal grouped bar chart.
 *
 * Every entry of `data` becomes one band on the y axis, labelled with its
 * `group`. Inside a band, each bar in `values` is drawn as a horizontal rect
 * spanning from the value-axis zero line to `bar.value`, so negative values
 * extend to the left. Hovering a bar applies the `dimmed` style to every bar
 * with a different `label` (i.e. the other series).
 *
 * Defaults: `loading = false`, `format = FormatType.PERCENTAGE`,
 * `hideEmptyBars = false`.
 */
export const GroupedBarChart = <TData extends object>({
  className,
  data,
  loading = false,
  colorMap,
  format = FormatType.PERCENTAGE,
  tooltipContent,
  hideEmptyBars = false,
}: GroupedBarChartProps<TData>) => {
  const { containerRef, width, height } = useResizeObserver();
  const [hoveredSeries, setHoveredSeries] = useState<string | null>(null);

  /* Data */
  // With `hideEmptyBars`, drop bars that would draw nothing (null or zero) so
  // they do not reserve an empty slot inside their group.
  const processedData = data.map((group) => ({
    ...group,
    values: hideEmptyBars
      ? group.values.filter((bar) => bar.value !== null && bar.value !== 0)
      : group.values,
  }));

  // Distinct series labels across all groups, in first-seen order. Each one
  // gets a slot in `barScale` and an entry in `colorScale`.
  const barLabels = Array.from(
    new Set(processedData.flatMap((d) => d.values.map((v) => v.label)))
  );

  /* Plot dimensions */
  // Reserve roughly MinBarHeight per series slot in every group (band padding
  // ignored). Derived from the distinct-series count, which is what
  // `barScale` divides each band by, so it stays finite for empty `data`,
  // handles groups of differing length, and reflects `hideEmptyBars`.
  const minPlotHeight =
    processedData.length * barLabels.length * Dimensions.MinBarHeight;
  const plotHeight = Math.max(height, minPlotHeight);
  // Label column: a share of the total width, clamped to a readable range.
  const labelWidth = clamp(
    width * Dimensions.MaxLabelWidthRatio,
    Dimensions.MinLabelWidth,
    Dimensions.MaxLabelWidth
  );
  const plotWidth = Math.max(0, width - labelWidth - Dimensions.BarLeftPadding);
  const isPlotSizeValid = plotWidth > 0 && plotHeight > 0;

  /* Scales */
  // One band per group, top to bottom in `data` order.
  const groupScale = scaleBand()
    .domain(processedData.map((d) => d.group))
    .range([0, plotHeight])
    .paddingInner(Dimensions.InnerSpacingFactor)
    .paddingOuter(Dimensions.OuterSpacingFactor);

  // Each group band is split into one slot per series.
  const barScale = scaleBand()
    .domain(barLabels)
    .range([0, groupScale.bandwidth()]);

  // Value axis: extent of all non-null values, widened to include 0 so every
  // bar is drawn from a shared zero baseline. `null` when there are no values.
  const allValues = processedData
    .flatMap((d) => d.values.map((v) => v.value))
    .filter((v): v is number => v !== null);
  const [minValue, maxValue] = extent(allValues);
  const valueScale =
    minValue === undefined || maxValue === undefined
      ? null
      : scaleLinear(
          [Math.min(minValue, 0), Math.max(maxValue, 0)],
          [0, plotWidth]
        );

  /** ======== Colors ======== */
  // Explicit colorMap entries win; otherwise cycle through the shared palette
  // so series beyond the palette length still get a defined color.
  const colorScale = scaleOrdinal<string>()
    .domain(barLabels)
    .range(
      barLabels.map(
        (label, index) =>
          colorMap?.[label] || plotColors[index % plotColors.length]
      )
    );

  /** ======== Helper functions ======== */
  /**
   * Height of a group's stacked bars and the offset that centers them inside
   * the group band (a group may hold fewer bars than there are series).
   */
  const calculateGroupDimensions = (groupValues: BarData<TData>[]) => {
    const barCount = groupValues.length;
    // No gaps for an empty group (avoids a negative spacing term).
    const gapCount = Math.max(0, barCount - 1);
    const totalBarHeight = barCount * barScale.bandwidth();
    const totalSpacingHeight =
      gapCount * (barScale.step() - barScale.bandwidth());
    const groupHeight = totalBarHeight + totalSpacingHeight;
    const verticalOffset = (groupScale.bandwidth() - groupHeight) / 2;
    return { groupHeight, verticalOffset, barCount };
  };

  /** Renders one bar (with its tooltip); nothing for a null value. */
  const renderBar = (bar: BarData<TData>, index: number, group: string) => {
    if (bar.value === null || !valueScale) return null;

    // Bars are laid out top to bottom in array order within the group.
    const barY = index * barScale.step();
    // Span from the zero line to the value, on whichever side of zero it is.
    const x0 = valueScale(0);
    const x1 = valueScale(bar.value);
    const x = x1 < x0 ? x1 : x0;
    const barWidth = Math.abs(x1 - x0);

    return (
      <g
        key={bar.label}
        className={classNames(styles.barGroup, {
          [styles.dimmed]:
            hoveredSeries !== null && hoveredSeries !== bar.label,
        })}
        onMouseEnter={() => setHoveredSeries(bar.label)}
        onMouseLeave={() => setHoveredSeries(null)}
      >
        <Tooltip
          hasArrow
          label={
            tooltipContent ? (
              tooltipContent(group, bar)
            ) : (
              <GroupedBarTooltip
                title={bar.label}
                label={group}
                value={getFormatter(format)(bar.value)}
              />
            )
          }
          padding={'8px 12px'}
          placement="top"
          bg="#2d426a"
          borderRadius="4px"
          data-testid="grouped-bar-tooltip"
        >
          <rect
            x={x}
            y={barY}
            width={barWidth}
            height={barScale.bandwidth()}
            rx={2}
            fill={bar.color || colorScale(bar.label)}
            className={classNames(styles.bar, 'bar')}
            data-testid={`grouped-bar-chart-bar`}
          />
        </Tooltip>
      </g>
    );
  };

  /** Renders all bars of one group, vertically centered in its band. */
  const renderGroup = (group: GroupedBarData<TData>) => {
    const groupY = groupScale(group.group);
    if (groupY === undefined) return null;

    const { verticalOffset } = calculateGroupDimensions(group.values);

    return (
      <g
        key={group.group}
        transform={`translate(0, ${groupY + verticalOffset})`}
        data-testid="grouped-bar-chart-group"
      >
        {group.values.map((bar, index) => renderBar(bar, index, group.group))}
      </g>
    );
  };

  /** Renders the y-axis label for one group, centered on its bars. */
  const renderAxisLabel = (domain: string) => {
    const y = groupScale(domain);
    if (y === undefined) return null;

    const group = processedData.find((g) => g.group === domain);
    if (!group) return null;

    const { groupHeight, verticalOffset } = calculateGroupDimensions(
      group.values
    );

    return (
      <AxisLabel
        key={domain}
        y={y + verticalOffset + groupHeight / 2}
        width={labelWidth}
        fontSize={11}
        label={domain}
        availableHeight={groupHeight}
        testId={`grouped-bar-chart-axis-label`}
      />
    );
  };

  // NOTE(review): the svg is sized to 100% of its container while layout uses
  // `plotHeight`, which exceeds the measured `height` whenever minPlotHeight
  // wins; whether the overflowing bars are visible depends on the CSS
  // module — confirm.
  return (
    <div
      ref={containerRef}
      className={classNames(styles.container, className)}
      data-testid="plot-GroupedBarChartHorizontal"
    >
      <PlotLoadWrapper loading={loading} noData={processedData.length === 0}>
        <svg width="100%" height="100%">
          <g id="y-axis">
            {isPlotSizeValid && groupScale.domain().map(renderAxisLabel)}
          </g>
          <g
            id="chart"
            transform={`translate(${labelWidth + Dimensions.BarLeftPadding}, 0)`}
          >
            {isPlotSizeValid && valueScale && processedData.map(renderGroup)}
          </g>
        </svg>
      </PlotLoadWrapper>
    </div>
  );
};
