import { Dispatch, useEffect, useMemo, useRef, useState } from 'react';
import { TreeApi as ArboristTreeApi, NodeApi } from 'react-arborist';
import { TreeData, TreeSelectionProps } from './tree-selection';
import { getNode, getPathToNode, updateDescendantSelections } from './utils';
import {
  Actions,
  rootReducer,
} from 'react-arborist/dist/module/state/root-reducer';

/**
 * Store actions understood by the reducer installed in `useTreeApi`:
 * react-arborist's built-in actions plus two custom expand/collapse-all ones.
 */
type ExtendedActions =
  | Actions
  | { type: 'OPEN_ALL' }
  | { type: 'CLOSE_ALL' };

/**
 * Recursive map type: every value is itself a `SelectionDescendantMap`.
 * NOTE(review): presumably keyed by node id and maintained by
 * `updateDescendantSelections` in ./utils — confirm.
 */
export type SelectionDescendantMap = Map<string, SelectionDescendantMap>;

/** Options accepted by `useTreeApi`. */
export type TreeApiProps = {
  /** Tree data; also used to work out which nodes can be expanded. */
  data: TreeSelectionProps['data'];
  /** Initial value of the hook's `isExpanded` state. */
  openByDefault: boolean;
  /**
   * Ids to select. The hook re-applies them whenever the array identity
   * changes, so callers should pass a stable array.
   */
  defaultSelections?: string[];
};

/**
 * Shared fallback for `defaultSelections`. Being module-level, it keeps the
 * same identity across renders, so an effect keyed on it does not re-run.
 */
const EMPTY_ARRAY: Array<string> = [];
/**
 * Collects, depth-first, the ids of every node that has a `children` array
 * (i.e. every node that can be opened or closed). Leaves contribute nothing.
 */
const collectNonLeafIds = (nodes: TreeData[]): string[] =>
  nodes.flatMap((node) =>
    node.children ? [node.id, ...collectNonLeafIds(node.children)] : []
  );

/**
 * State and handlers behind the tree-selection component.
 *
 * The returned `treeRef` must be attached to the react-arborist `<Tree>`: the
 * hook resolves ids to nodes through its `root`, and replaces the tree
 * store's reducer with one that also handles `OPEN_ALL` / `CLOSE_ALL`.
 *
 * @param props.data - Tree data; its non-leaf ids drive expand/collapse-all.
 * @param props.defaultSelections - Ids selected on mount and again whenever
 *   the array identity changes (the previous selection is cleared first).
 * @param props.openByDefault - Initial value of `isExpanded`.
 * @returns The tree ref plus selection, expansion and search state/handlers.
 */
export const useTreeApi = ({
  data,
  defaultSelections = EMPTY_ARRAY,
  openByDefault,
}: TreeApiProps) => {
  const treeRef = useRef<ArboristTreeApi<TreeData>>();

  /** ================================
   * Selection
   ================================ */
  const [selectedIds, setSelectedIds] = useState<string[]>([]);
  const [indeterminateSelectedIds, setIndeterminateSelectedIds] = useState<
    string[]
  >([]);

  const selectionDescendantMap = useRef<SelectionDescendantMap>(new Map());

  /** Resolves `id` to its node, or `undefined` if the tree isn't mounted or the id isn't found. */
  const findNode = (id: string) => {
    const rootNode = treeRef.current?.root;
    if (!rootNode) return undefined;
    return getNode({ nodes: rootNode.children, pathToNode: getPathToNode(id) });
  };

  const selectNode = (node: NodeApi<TreeData>) => {
    const { id } = node;

    // A linked node or a repeated default id may already be selected;
    // don't record it twice.
    setSelectedIds((prevSelectedIds) =>
      prevSelectedIds.includes(id) ? prevSelectedIds : [...prevSelectedIds, id]
    );

    // Mutate the descendant map exactly once, here. Doing it inside a state
    // updater is unsafe because React may invoke updaters more than once.
    const nextIndeterminate = updateDescendantSelections({
      operation: 'INSERT',
      node,
      selectionDescendantMap: selectionDescendantMap.current,
    });
    setIndeterminateSelectedIds(nextIndeterminate);
  };

  const deselectNode = (node: NodeApi<TreeData>) => {
    const { id } = node;

    setSelectedIds((prevSelectedIds) =>
      prevSelectedIds.filter((selectedId) => selectedId !== id)
    );

    // NOTE(review): `selectedIds` is this render's snapshot. When several
    // nodes are deselected in one event (the linked-id loop in
    // `toggleNodeSelection`), it still contains the earlier ones — confirm
    // `updateDescendantSelections` tolerates that.
    const nextIndeterminate = updateDescendantSelections({
      operation: 'REMOVE',
      node,
      selectionDescendantMap: selectionDescendantMap.current,
      selectedIds,
    });
    setIndeterminateSelectedIds(nextIndeterminate);
  };

  const clearAll = () => {
    setSelectedIds([]);
    setIndeterminateSelectedIds([]);
    selectionDescendantMap.current.clear();
  };

  // Re-apply the default selection whenever the `defaultSelections` array
  // identity changes. The callbacks used here only touch state setters and
  // refs, so they are deliberately not listed as dependencies.
  useEffect(() => {
    // Reset the descendant map and indeterminate ids as well as `selectedIds`;
    // otherwise state from the previous selection leaks into the new one.
    clearAll();
    defaultSelections.forEach((id) => {
      const node = findNode(id);
      if (node) selectNode(node);
    });
  }, [defaultSelections]);

  /**
   * Flips the selection state of `node`. Linked nodes (`data.linkedIds`) are
   * set to the clicked node's new state rather than toggled individually.
   */
  const toggleNodeSelection = (node: NodeApi<TreeData>) => {
    const isSelected = selectedIds.includes(node.id);
    const applyToggle = isSelected ? deselectNode : selectNode;

    /** ======== Select/Deselect Node ======== */
    applyToggle(node);

    /** ======== Select LinkedIDs ======== */
    node.data.linkedIds?.forEach((linkedId) => {
      const linkedNode = findNode(linkedId);
      if (linkedNode) applyToggle(linkedNode);
    });
  };

  /** ================================
   * Visibility
   ================================ */
  const allNonLeafNodeIds = useMemo(() => collectNonLeafIds(data), [data]);

  /** ======== Extend Reducer ======== */
  useEffect(() => {
    const treeApi = treeRef.current;
    if (!treeApi) return;

    treeApi.store.replaceReducer((state, action) => {
      if (!state) return rootReducer(state, action);

      const { type } = action as ExtendedActions;
      const isOpen = type === 'OPEN_ALL';
      if (!isOpen && type !== 'CLOSE_ALL') return rootReducer(state, action);

      const isFiltered = treeApi.isFiltered;
      // NOTE(review): an undefined `openByDefault` is treated as `false` here;
      // confirm that matches how react-arborist interprets a missing prop.
      const treeOpenByDefault = Boolean(treeApi.props.openByDefault);

      // This logic treats filtered views as open-by-default and unfiltered
      // views as following `openByDefault`. When the requested state equals
      // that default, the view's open map is reset to `{}`. Otherwise every
      // non-leaf node is pinned to the requested state.
      const matchesViewDefault = isFiltered
        ? isOpen
        : treeOpenByDefault === isOpen;

      const openMap = matchesViewDefault
        ? {}
        : allNonLeafNodeIds.reduce<Record<string, boolean>>((acc, id) => {
            acc[id] = isOpen;
            return acc;
          }, {});

      return {
        ...state,
        nodes: {
          ...state.nodes,
          open: {
            ...state.nodes.open,
            ...(isFiltered ? { filtered: openMap } : { unfiltered: openMap }),
          },
        },
      };
    });
  }, [allNonLeafNodeIds]);

  /** ======== Expand/Collapse ======== */
  const [isExpanded, setIsExpanded] = useState<boolean>(openByDefault);

  /**
   * Dispatches a custom expand/collapse action. This is a no-op while the
   * tree is not mounted; previously it called `undefined` and threw.
   */
  const dispatchExtended = (action: ExtendedActions) => {
    const store = treeRef.current?.store;
    if (!store) return;
    (store.dispatch as Dispatch<ExtendedActions>)(action);
  };

  const openAll = () => dispatchExtended({ type: 'OPEN_ALL' });
  const closeAll = () => dispatchExtended({ type: 'CLOSE_ALL' });

  const expand = () => {
    setIsExpanded(true);
    openAll();
  };
  const collapse = () => {
    setIsExpanded(false);
    closeAll();
  };

  /** ================================
   * Search
   ================================ */
  const [search, setSearch] = useState<string>('');
  /** Stores the query; a non-empty query also marks the tree as expanded. */
  const onSearch = (query: string) => {
    setSearch(query);
    setIsExpanded(query.length > 0);
  };

  return {
    treeRef,
    selectedIds,
    indeterminateSelectedIds,
    toggleNodeSelection,
    isExpanded,
    expand,
    collapse,
    search,
    onSearch,
    clearAll,
  };
};

/**
 * Shape of the object returned by `useTreeApi`. It is derived via
 * `ReturnType` so it cannot drift from the implementation. This is distinct
 * from react-arborist's own `TreeApi`, which this file imports as
 * `ArboristTreeApi`.
 */
export type TreeApi = ReturnType<typeof useTreeApi>;
