import { getMaxCeil, LineChart, useThrottle } from '@clef/client-library';
import { ConfusionMatrixPerThreshold, LabelType } from '@clef/shared/types';
import { range, round } from 'lodash';
import React, { useCallback, useMemo, useRef } from 'react';
import SliderIcon from '../../pages/DataBrowser/icon/SliderIcon';
import Legends from './Legends';
import { PerformanceType } from '../../pages/DataBrowser/ProjectModelDetails/states';
import {
  getConfusionMatrixCountTitle,
  Palette,
} from '../../pages/DataBrowser/ProjectModelDetails/utils';
import { CONFIDENCE_THRESHOLD_OPTIONS } from '@clef/shared/constants';

/** Props for the `ConfusionMatrixCumulativeGraph` component. */
export type ConfusionMatrixCumulativeGraphProps = {
  /**
   * Chooses the fourth plotted series (TN for media-level performance, MC otherwise)
   * and is used when building the y-axis title.
   */
  performanceType: PerformanceType;
  /** Per-threshold confusion-matrix counts to plot; also forwarded to the legend. */
  confusionMatrixPerThreshold: ConfusionMatrixPerThreshold;
  /** Currently selected confidence threshold, drawn as a draggable vertical line. */
  threshold: number;
  /** Label type used when building the y-axis title. */
  labelType: LabelType | null | undefined;
  /** Whether the y-axis maximum is capped at a computed value. The component defaults it to `true`. */
  enableCap?: boolean;
  /** Transform applied to every plotted y-value and to the cap. The component defaults it to identity. */
  valueScaleFunction?: (value: number) => number;
  /**
   * Maps y-axis values back for tick labels (presumably the inverse of `valueScaleFunction`).
   * The component defaults it to identity.
   */
  valueDescaleFunction?: (value: number) => number;
  /** Invoked with the new threshold, rounded to 2 decimals, while the threshold line is dragged. */
  onUpdateThreshold?: (newThreshold: number) => void;
};

/**
 * Identity transform used as the default for the scale/descale props.
 * Defined at module level so its reference is stable across renders; an inline `x => x`
 * default would be a new function every render and invalidate the memoized chart data.
 */
const identity = (value: number): number => value;

/**
 * Returns the largest value across all given series, or 0 when every series is empty.
 *
 * Seeding the reduction with 0 (instead of spreading into `Math.max`) avoids
 * `Math.max()` returning `-Infinity` for empty input. `-Infinity` is truthy and would
 * silently skip an `||` fallback. NOTE(review): assumes the series hold non-negative
 * counts; a series of only negative values would also yield 0.
 */
const maxAcrossSeries = (...series: ReadonlyArray<readonly number[]>): number =>
  series.reduce(
    (currentMax, values) => values.reduce((m, value) => Math.max(m, value), currentMax),
    0,
  );

/**
 * Cumulative confusion-matrix line chart. It plots one series per confusion-matrix
 * bucket against the confidence threshold and draws a draggable vertical line at
 * `threshold`.
 *
 * - TP, FP and FN are always plotted. The fourth series is TN for
 *   `PerformanceType.MediaLevel` and MC (misclassified) otherwise.
 * - When `enableCap` is true, the y-axis maximum is capped (see `cappedMaxValue`).
 * - While the threshold line is dragged, the threshold at the tick index reported by
 *   `getXTickIndex` is sent to `onUpdateThreshold`. It is rounded to 2 decimals and
 *   throttled through `useThrottle` (48 ms).
 */
const ConfusionMatrixCumulativeGraph: React.FC<ConfusionMatrixCumulativeGraphProps> = ({
  performanceType,
  confusionMatrixPerThreshold,
  enableCap = true,
  valueScaleFunction = identity,
  valueDescaleFunction = identity,
  threshold,
  onUpdateThreshold,
  labelType,
}) => {
  const isDraggingThreshold = useRef(false);

  const { thresholdsMin, thresholdsMax, thresholdsStep } = CONFIDENCE_THRESHOLD_OPTIONS;

  // Memoized so that `getDataPoints` and `chartData` below keep stable identities between renders.
  const thresholds = useMemo(
    () => range(thresholdsMin, thresholdsMax, thresholdsStep),
    [thresholdsMin, thresholdsMax, thresholdsStep],
  );

  // NOTE(review): `labelType` may be null/undefined per the prop type. The non-null
  // assertion assumes `getConfusionMatrixCountTitle` tolerates that; confirm.
  const confusionMatrixCountTitle = getConfusionMatrixCountTitle(performanceType, labelType!);

  // The prop is typed non-nullable, but guard against a missing value anyway.
  const { truePositives, falseNegatives, misclassified, trueNegatives, falsePositives } =
    confusionMatrixPerThreshold ?? {};

  // Better to use TP, FN and MC for calculating the y-axis cap.
  // If their max is 0 (or they are all empty), fall back to the max of TN and FP.
  const cappedMaxValue = valueScaleFunction(
    getMaxCeil(
      maxAcrossSeries(truePositives ?? [], falseNegatives ?? [], misclassified ?? []) ||
        maxAcrossSeries(trueNegatives ?? [], falsePositives ?? []),
    ),
  );

  /** Pairs each (scaled) count with the threshold at the same index. */
  const getDataPoints = useCallback(
    (values: number[]) =>
      values.map((value, index) => ({
        y: valueScaleFunction(value),
        x: thresholds[index],
      })),
    [thresholds, valueScaleFunction],
  );

  const chartData = useMemo(() => {
    const commonData = [
      {
        name: 'TP',
        color: Palette.TP,
        values: getDataPoints(truePositives ?? []),
      },
      {
        name: 'FP',
        color: Palette.FP,
        values: getDataPoints(falsePositives ?? []),
      },
      {
        name: 'FN',
        color: Palette.FN,
        values: getDataPoints(falseNegatives ?? []),
      },
    ];
    if (performanceType === PerformanceType.MediaLevel) {
      commonData.push({
        name: 'TN',
        color: Palette.TN,
        values: getDataPoints(trueNegatives ?? []),
      });
    } else {
      commonData.push({
        name: 'MC',
        color: Palette.MC,
        values: getDataPoints(misclassified ?? []),
      });
    }
    return commonData;
  }, [
    falseNegatives,
    falsePositives,
    misclassified,
    trueNegatives,
    truePositives,
    getDataPoints,
    performanceType,
  ]);

  const updateThreshold = useThrottle((newThreshold: number) => {
    onUpdateThreshold?.(round(newThreshold, 2));
  }, 48);

  // Ends a drag. Also wired to cancel / lost-capture so the flag can never stay stuck at
  // `true`, which would make plain hovering move the threshold.
  const stopDragging = () => {
    isDraggingThreshold.current = false;
  };

  return (
    <div style={{ position: 'relative' }}>
      <LineChart
        id="confusion-matrix-graphs-main"
        data-testid="confusion-matrix-graphs-main"
        labelFormatterX={(v: number) => (v === 1 ? v : '')}
        labelFormatterY={(v: number) => Math.round(valueDescaleFunction(v))}
        axisTitleX={''}
        axisTitleY={confusionMatrixCountTitle}
        minX={0}
        maxX={1}
        data={chartData}
        showDataPoints={false}
        maxY={enableCap ? cappedMaxValue : undefined}
        aspectRatio={1.6}
      >
        {({ getRelativeCoordX, getRelativeCoordY, maxGridLineY = 0, getXTickIndex }) => {
          return (
            <g
              cursor="col-resize"
              onPointerDown={e => {
                isDraggingThreshold.current = true;
                // Capture keeps move/up events flowing to this element while the pointer
                // is outside it. The guard skips environments that report no pointer id.
                if (e.pointerId) {
                  (e.target as SVGElement).setPointerCapture(e.pointerId);
                }
              }}
              onPointerUp={e => {
                stopDragging();
                if (e.pointerId) {
                  (e.target as SVGElement).releasePointerCapture(e.pointerId);
                }
              }}
              onPointerCancel={stopDragging}
              onLostPointerCapture={stopDragging}
              onMouseMove={e => {
                if (isDraggingThreshold.current) {
                  const index = getXTickIndex(e);
                  const newThreshold = thresholds[index];
                  if (newThreshold !== undefined) {
                    updateThreshold(newThreshold);
                  }
                }
              }}
            >
              <line
                x1={getRelativeCoordX(threshold)}
                y1={getRelativeCoordY(0)}
                x2={getRelativeCoordX(threshold)}
                y2={getRelativeCoordY(maxGridLineY)}
                stroke="#37414D"
                strokeWidth="1"
              />
              <SliderIcon
                x={getRelativeCoordX(threshold) - 4}
                y={getRelativeCoordY(maxGridLineY / 2) - 7}
              />
            </g>
          );
        }}
      </LineChart>
      <Legends threshold={threshold} confusionMatrixPerThreshold={confusionMatrixPerThreshold} />
    </div>
  );
};

export default ConfusionMatrixCumulativeGraph;
