import { useKeyPress, useThrottle } from '@clef/client-library';
import { ConfusionMatrixPerThreshold, LabelType } from '@clef/shared/types';
import React from 'react';
import ConfusionMatrixCumulativeGraph from './ConfusionMatrixCumulativeGraph';
import {
  PerformanceType,
  useProjectModelDetailsState,
} from '../../pages/DataBrowser/ProjectModelDetails/states';
import {
  exponentScale,
  logScale,
  useConfusionMatrixPerThreshold,
} from '../../pages/DataBrowser/ProjectModelDetails/utils';
import cx from 'classnames';
import { round } from 'lodash';
import { makeStyles } from '@material-ui/core';
import { CONFIDENCE_THRESHOLD_OPTIONS } from '@clef/shared/constants';

/**
 * JSS classes for the graphs wrapper. The base class draws a light grey rounded border with
 * padding; adding the `compact` modifier class removes both border and padding.
 */
const useStyles = makeStyles(theme => ({
  confusionMatrixGraphs: {
    borderRadius: 3,
    border: `1px solid ${theme.palette.grey[200]}`,
    padding: theme.spacing(4, 6),
    '&.compact': {
      border: 'none',
      padding: theme.spacing(0),
    },
  },
}));

/** Props accepted by the `ConfusionMatrixGraphs` component. */
export type ConfusionMatrixGraphsProps = {
  /** Passed to `useConfusionMatrixPerThreshold` and forwarded to the cumulative graph. */
  performanceType: PerformanceType;
  /** Split identifier passed to `useConfusionMatrixPerThreshold`. */
  split: string;
  /** Model id passed to `useConfusionMatrixPerThreshold`; may be undefined. */
  modelId: string | undefined;
  /** When provided, used instead of the hook-derived matrix (new unified flow). */
  confusionMatrix?: ConfusionMatrixPerThreshold;
  /** Currently selected confidence threshold. */
  threshold: number;
  /** Forwarded to the cumulative graph. */
  labelType: LabelType | null | undefined;
  /** Invoked with the next threshold when the user steps it; also forwarded to the graph. */
  onUpdateThreshold?: (newThreshold: number) => void;
};

/**
 * Renders histogram and cumulative graph for confusion matrix.
 * https://docs.google.com/document/d/1pUGxJfQi6WfqCgH44_co_oAtaC9lI_1WAyM11U7eCWU
 *
 * Data source: when the `confusionMatrix` prop is supplied (new unified flow) it takes
 * precedence; otherwise the matrix returned by `useConfusionMatrixPerThreshold` is used.
 * Nothing is rendered until a matrix is available.
 *
 * The threshold can be stepped by `thresholdsStep` through handlers registered with
 * `useKeyPress('left' | 'right', …)`; each step reports the new value via `onUpdateThreshold`.
 */
const ConfusionMatrixGraphs: React.FC<ConfusionMatrixGraphsProps> = ({
  performanceType,
  split,
  confusionMatrix,
  threshold,
  modelId,
  onUpdateThreshold,
  labelType,
}) => {
  const styles = useStyles();
  const { thresholdsMin, thresholdsMax, thresholdsStep } = CONFIDENCE_THRESHOLD_OPTIONS;
  // Highest threshold reachable by stepping up. Rounded to the same 2 decimals as the candidate
  // values below, so floating-point drift in `max - step` (e.g. 0.3 - 0.1 === 0.19999999999999998)
  // cannot block the final step.
  // NOTE(review): the top of the range excludes `thresholdsMax` itself, while the bottom includes
  // `thresholdsMin` — preserved from the original; confirm this asymmetry is intended.
  const highestSteppableThreshold = round(thresholdsMax - thresholdsStep, 2);

  const {
    state: { enableLogScale },
  } = useProjectModelDetailsState();
  // Always call the hook (hooks must run unconditionally), even if props already carry the data.
  const fetchedConfusionMatrix = useConfusionMatrixPerThreshold(performanceType, split, modelId);
  // for new unified flow, the metrics come from the props
  const confusionMatrixPerThreshold = confusionMatrix ?? fetchedConfusionMatrix;

  // Step handlers wrapped in useThrottle with a 64 ms argument; candidate values are rounded to
  // 2 decimals to avoid accumulating floating-point error across repeated steps.
  const decreaseThreshold = useThrottle(() => {
    const newThreshold = round(threshold - thresholdsStep, 2);
    if (newThreshold >= thresholdsMin) {
      onUpdateThreshold?.(newThreshold);
    }
  }, 64);
  useKeyPress('left', decreaseThreshold);
  const increaseThreshold = useThrottle(() => {
    const newThreshold = round(threshold + thresholdsStep, 2);
    if (newThreshold <= highestSteppableThreshold) {
      onUpdateThreshold?.(newThreshold);
    }
  }, 64);
  useKeyPress('right', increaseThreshold);

  // for log scale
  const {
    truePositives = [],
    trueNegatives = [],
    falsePositives = [],
    falseNegatives = [],
    misclassified = [],
  } = (confusionMatrixPerThreshold as ConfusionMatrixPerThreshold) ?? {};
  // Largest value across all series, or 0 when there is no data. A reduce is used instead of
  // Math.max(...values) so long series cannot exceed the engine's argument-count limit.
  const maxValue = [truePositives, trueNegatives, falsePositives, falseNegatives, misclassified]
    .flat()
    .reduce((currentMax, value) => Math.max(currentMax, value), 0);
  // we want the max Y to be maxY = base^4, so base = maxY^(1/4)
  const scaleBase = Math.ceil(maxValue ** 0.25);
  // A base of 0 or 1 cannot define a logarithmic scale, so fall back to a linear axis.
  const logScaleFunction = scaleBase > 1 ? logScale(scaleBase) : undefined;
  const exponentScaleFunction = scaleBase > 1 ? exponentScale(scaleBase) : undefined;

  const valueScaleFunction = enableLogScale ? logScaleFunction : undefined;
  const valueDescaleFunction = enableLogScale ? exponentScaleFunction : undefined;

  if (!confusionMatrixPerThreshold) {
    return null;
  }

  return (
    <div
      className={cx(styles.confusionMatrixGraphs, 'compact')}
      data-testid="confusion-matrix-graphs"
    >
      <ConfusionMatrixCumulativeGraph
        performanceType={performanceType}
        confusionMatrixPerThreshold={confusionMatrixPerThreshold as ConfusionMatrixPerThreshold}
        valueScaleFunction={valueScaleFunction}
        valueDescaleFunction={valueDescaleFunction}
        // no need to cap when log scale is enabled
        enableCap={!enableLogScale}
        threshold={threshold}
        onUpdateThreshold={onUpdateThreshold}
        labelType={labelType}
      />
    </div>
  );
};

export default ConfusionMatrixGraphs;
