import classNames from 'classnames';
import { sankey, sankeyLinkHorizontal } from 'd3-sankey';
import { cloneDeep } from 'lodash';
import { useRef } from 'react';

import { useResizeObserver } from '@revelio/core';

import { CommonPlotProps } from '../../types';
import { getMaxTextWidth } from '../../utilities';
import { AxisLabel } from '../axis/axis-label';
import { PlotLoadWrapper } from '../plot-loader/plot-loader';
import styles from './sankey-chart.module.css';
import { SankeyNodeTooltip } from './sankey-node-tooltip';
import { validateLinks, validateNodes } from './sankey-node-validator';
import { SankeyPath } from './sankey-path';
import {
  Link,
  LinkPath,
  Node,
  SankeyLinkValidated,
  SankeyNodeValidated,
} from './types';
import { SankeyData } from './types';

/** Upper bound (px) on each label column beside the plot: small plots vs. regular plots. */
const Y_AXIS_MAX_WIDTH_SMALL = 30;
const Y_AXIS_MAX_WIDTH = 115;

/** Padding (px) separating the plot area from the label columns and the bottom edge. */
const PLOT_LEFT_PADDING = 8;
const PLOT_RIGHT_PADDING = 8;
const PLOT_BOTTOM_PADDING = 4;

/** Width (px) of a drawn node bar, and the extra horizontal space (px) reserved beside it. */
const NODE_WIDTH = 6;
const NODE_GAP = 4;

/** Vertical overlap (px) introduced between successive links at the primary node. */
const NEGATIVE_LINK_OFFSET = 2;

/** Props for {@link SankeyChart}: the shared plot props plus the Sankey-specific inputs. */
type SankeyChartProps = CommonPlotProps & {
  /** Nodes and links to lay out; link `source`/`target` values refer to node ids. */
  data: SankeyData;
  /** `true` to centre on the depth-0 (left) node, `false` to centre on the depth-1 (right) node. */
  isOutflows: boolean;
};

/**
 * Renders a Sankey diagram with node labels on the left (sources) and right (targets).
 *
 * The layout is computed with d3-sankey on every render, on deep clones of `data.nodes` and
 * `data.links` because the generator mutates its input. Afterwards the "primary" node(s) —
 * depth 0 when `isOutflows` is true, depth 1 otherwise — are shifted so the primary node is
 * vertically centred in the plot, and the links meeting it are pulled together by
 * `NEGATIVE_LINK_OFFSET` per link so neighbouring links overlap slightly.
 *
 * @param data - Nodes and links to plot; nodes are identified by `id`, links reference them
 *   via `source` / `target`.
 * @param isOutflows - Whether the primary node is the depth-0 (left) node (`true`) or the
 *   depth-1 (right) node (`false`).
 * @param loading - Forwarded to `PlotLoadWrapper`; defaults to `false`.
 * @param className - Extra class names merged onto the container `div`.
 */
export const SankeyChart = ({
  data,
  isOutflows,
  loading = false,
  className,
}: SankeyChartProps) => {
  const containerRef = useRef<HTMLDivElement>(null);

  const { width, height } = useResizeObserver(containerRef);
  const isSmallPlot = width < 250;
  const yAxisMaxWidth = isSmallPlot ? Y_AXIS_MAX_WIDTH_SMALL : Y_AXIS_MAX_WIDTH;

  /** ================ Labels ================ */
  // Keep the Sets for O(1) membership checks when splitting nodes by axis further down.
  const sourceLabelSet = new Set(data.links.map((link) => link.source));
  const sourceLabels = Array.from(sourceLabelSet);
  const sourceAxisWidth = Math.min(
    getMaxTextWidth({ texts: sourceLabels, fontSize: 12 }),
    yAxisMaxWidth
  );

  const targetLabelSet = new Set(data.links.map((link) => link.target));
  const targetLabels = Array.from(targetLabelSet);
  const targetAxisWidth = Math.min(
    getMaxTextWidth({ texts: targetLabels, fontSize: 12 }),
    yAxisMaxWidth
  );

  /** ================ Dimensions ================ */
  const plotWidth =
    width -
    sourceAxisWidth -
    targetAxisWidth -
    PLOT_LEFT_PADDING -
    PLOT_RIGHT_PADDING;
  const plotHeight = height - PLOT_BOTTOM_PADDING;

  const isPlotSizeValid = plotWidth > 0 && plotHeight > 0;

  /** ================ Sankey Generator ================ */
  // Node padding as a whole is ~40% of the height of the chart, split across the gaps between
  // nodes. Clamp the divisor to 1 so a single-node or empty dataset cannot produce an
  // Infinite (division by zero) or negative padding.
  const numOfNodes = Math.max(data.nodes.length - 1, 1);
  const nodePadding = Math.floor((height * 0.4) / numOfNodes);

  const sankeyGenerator = sankey<Node, Link>()
    .nodeId((d) => d.id)
    .nodeWidth(NODE_WIDTH + NODE_GAP)
    .nodePadding(nodePadding)
    .size([plotWidth, plotHeight]);

  /** Sankey generator mutates input data. clone nodes and links so the original data is not mutated */
  const clonedNodes = cloneDeep(data.nodes);
  const clonedLinks = cloneDeep(data.links);

  const { nodes, links } =
    clonedNodes.length > 0 && clonedLinks.length > 0
      ? sankeyGenerator({
          nodes: clonedNodes,
          links: clonedLinks,
        })
      : { nodes: [], links: [] };

  /** validate nodes and links for easier typing */
  const nodesValidated = validateNodes(nodes);
  const linksValidated = validateLinks(links);

  /** The primary node sits at depth 0 for outflows and at depth 1 for inflows. */
  const isPrimaryNode = (node: (typeof nodesValidated)[number]) =>
    isOutflows ? node.depth === 0 : node.depth === 1;

  /** Get the primary node and calculate where its position should be from the total plot height
   * and the primary node's height. By default, the primary node is not centered, so we place it
   * at the halfway point of the plot height minus half the node's height. */
  const primaryNode = nodesValidated.find(isPrimaryNode);

  /** Every link after the first is pulled NEGATIVE_LINK_OFFSET closer to its neighbour, so the
   * primary node shrinks by the accumulated overlap. Assumes one link per non-primary node —
   * TODO confirm for datasets with several links per node. Clamped at 0 so tiny or empty
   * datasets never make the node taller. */
  const totalNodes = nodesValidated.length - 1;
  const offsetHeight = Math.max(totalNodes - 1, 0) * NEGATIVE_LINK_OFFSET;

  const primaryNodeY0 = primaryNode?.y0 ?? 0;
  const primaryNodeHeight =
    (primaryNode?.y1 ?? 0) - primaryNodeY0 - offsetHeight;
  const primaryNodeOffset =
    plotHeight / 2 - primaryNodeHeight / 2 - primaryNodeY0;

  /** Offset the primary node by its true center position and trim the overlap from its bottom. */
  for (const node of nodesValidated) {
    if (isPrimaryNode(node)) {
      node.y0 += primaryNodeOffset;
      node.y1 += primaryNodeOffset - offsetHeight;
    }
  }

  /** Top-to-bottom stacking rank (0 = top-most) of each link where it meets the primary node.
   * d3-sankey's `link.index` is the link's position in the input array, which need not match
   * the order in which the layout stacks links at a node, so rank by the link's y-position at
   * the primary node instead: `y0` (source end) for outflows, `y1` (target end) for inflows. */
  const stackedLinks: SankeyLinkValidated<SankeyNodeValidated, Link>[] = [
    ...linksValidated,
  ];
  stackedLinks.sort((a, b) => (isOutflows ? a.y0 - b.y0 : a.y1 - b.y1));
  const linkStackRank = new Map(
    stackedLinks.map((link, rank) => [link, rank] as const)
  );
  /** Upward shift applied to a link's end at the primary node to create the overlap. */
  const linkOverlap = (link: SankeyLinkValidated<SankeyNodeValidated, Link>) =>
    (linkStackRank.get(link) ?? 0) * NEGATIVE_LINK_OFFSET;

  const sourceNodes = nodesValidated.filter((node) =>
    sourceLabelSet.has(node.id)
  );
  const targetNodes = nodesValidated.filter((node) =>
    targetLabelSet.has(node.id)
  );

  const sankeyLinkHorizontalGenerator = sankeyLinkHorizontal<
    SankeyNodeValidated,
    SankeyLinkValidated<SankeyNodeValidated, Link>
  >()
    .source((d) => {
      /** For outflows, shift the link starts down to match the centered primary node,
       * and offset the links to add some overlap. */
      return isOutflows
        ? [d.source?.x1, d?.y0 + primaryNodeOffset - linkOverlap(d)]
        : [d.source?.x1, d?.y0];
    })
    .target((d) => {
      /** For inflows, shift the link ends down to match the centered primary node,
       * and offset the links to add some overlap. */
      return !isOutflows
        ? [d?.target?.x0, d?.y1 + primaryNodeOffset - linkOverlap(d)]
        : [d?.target?.x0, d?.y1];
    });

  const linkPaths = linksValidated
    .map((link) => ({
      ...link,
      path: sankeyLinkHorizontalGenerator(link),
    }))
    .filter((link): link is LinkPath => link.path !== null);

  /** ================ Render ================ */
  return (
    <div
      ref={containerRef}
      className={classNames(styles.container, className)}
      data-testid="plot-SankeyDiagram"
    >
      <PlotLoadWrapper
        loading={loading}
        noData={nodesValidated.length === 0 || linksValidated.length === 0}
      >
        <svg width="100%" height="100%">
          {isPlotSizeValid && (
            <g id="render-container">
              <g id="source-axis">
                {sourceNodes.map((node) => (
                  <AxisLabel
                    key={node.id}
                    y={(node.y1 + node.y0) / 2}
                    width={sourceAxisWidth}
                    label={node.id}
                    availableHeight={node.y1 - node.y0}
                    fontSize={12}
                    testId={'sankey-source-node'}
                  />
                ))}
              </g>
              <g
                id="chart"
                transform={`translate(${sourceAxisWidth + PLOT_LEFT_PADDING}, 0)`}
                className={styles.chart}
              >
                {nodesValidated.map((node) => (
                  <SankeyNodeTooltip key={node.id} node={node}>
                    {/* Nodes with incoming links are nudged right by NODE_GAP so the bar
                        does not sit directly on the link ends. */}
                    <rect
                      x={
                        node.x0 +
                        ((node?.targetLinks?.length ?? 0) > 0 ? NODE_GAP : 0)
                      }
                      y={node.y0}
                      width={node.x1 - node.x0 - NODE_GAP}
                      height={node.y1 - node.y0}
                      fill="#2d426a"
                      className={styles.node}
                    />
                  </SankeyNodeTooltip>
                ))}
                {linkPaths.map((linkPath) => (
                  <SankeyPath
                    key={`${linkPath?.source?.id}_${linkPath?.target?.id}`}
                    link={linkPath}
                    isOutflows={isOutflows}
                  />
                ))}
              </g>
              <g
                id="target-axis"
                transform={`translate(${sourceAxisWidth + PLOT_LEFT_PADDING + PLOT_RIGHT_PADDING + plotWidth}, 0)`}
              >
                {targetNodes.map((node) => (
                  <AxisLabel
                    key={node.id}
                    y={(node.y1 + node.y0) / 2}
                    width={targetAxisWidth}
                    label={node.id}
                    availableHeight={node.y1 - node.y0}
                    align="left"
                    fontSize={12}
                    testId={'sankey-target-node'}
                  />
                ))}
              </g>
            </g>
          )}
        </svg>
      </PlotLoadWrapper>
    </div>
  );
};
