import React, { useMemo } from 'react';
import {
  PerformanceMetrics,
  ConfusionMatrixPerThreshold,
  LabelType,
  ClassificationConfusionMatrix,
} from '@clef/shared/types/basic';
import { CONFIDENCE_THRESHOLD_OPTIONS } from '@clef/shared/constants';
import { Box, makeStyles, Tooltip, TooltipProps, Typography } from '@material-ui/core';
import { range, round, sum } from 'lodash';
import { IconButton } from '@clef/client-library';
import InfoIcon from '../icon/InfoIcon';
import { useDataBrowserState } from '../dataBrowserState';
import { classificationConfusionMatrixToParings } from '../utils';

/** DOM id of the `Typography` element that renders the precision percentage inside `Rate`. */
export const PrecisionValueId = 'model-performance-precision-value';
/** DOM id of the `Typography` element that renders the recall percentage inside `Rate`. */
export const RecallValueId = 'model-performance-recall-value';

/**
 * Styles for the precision / recall cards: a flex row of two equal-width cells,
 * an optional 1px vertical divider between them, and a top border between stacked rows.
 */
const useStyles = makeStyles(theme => {
  // Light grey shared by the vertical divider and the row separator.
  const dividerColor = theme.palette.grey[200];
  return {
    verticalLine: {
      flex: '0 0 1px',
      backgroundColor: dividerColor,
    },
    rateRow: {
      alignItems: 'stretch',
      display: 'flex',
      // Only rows that follow another row get a separator line.
      '& + &': {
        borderTop: `1px solid ${dividerColor}`,
      },
    },
    rate: {
      display: 'flex',
      flex: '1 0 50%',
      flexDirection: 'column',
      justifyContent: 'center',
      alignItems: 'center',
      '& svg': {
        fontSize: '1rem',
      },
    },
    rateName: {
      fontSize: 12,
      marginTop: theme.spacing(1),
      color: theme.palette.grey[500],
    },
  };
});

/**
 * A single metric card: the rate as a percentage with one decimal place (or "N/A"),
 * followed by the metric name and an info icon that shows `tooltipTitle` on hover.
 */
const Rate: React.FC<{
  rate: number | 'N/A';
  name: 'Precision' | 'Recall';
  tooltipTitle: TooltipProps['title'];
  id?: string;
  'data-testid'?: string;
}> = ({ rate, name, tooltipTitle, id, 'data-testid': dataTestId }) => {
  const styles = useStyles();
  const isPrecision = name === 'Precision';
  // A fraction such as 0.9234 is shown as "92.3%".
  const formattedRate =
    rate === 'N/A' ? 'N/A' : t('{{metric}}%', { metric: Math.round(rate * 1000) / 10 });
  const label = isPrecision ? t('Precision') : t('Recall');

  return (
    <Box paddingY={7} className={styles.rate} id={id} data-testid={dataTestId}>
      <Typography variant="h1" id={isPrecision ? PrecisionValueId : RecallValueId}>
        {formattedRate}
      </Typography>
      <Box display="flex" alignItems="center" className={styles.rateName}>
        <Typography variant="body1" noWrap>
          {label}
        </Typography>
        <Box marginLeft={1} />
        <Tooltip placement="top" arrow title={tooltipTitle}>
          {/* NOTE(review): every Rate instance renders this same id, so it is not unique in the DOM. */}
          <IconButton id="performance-rates-info" size="small">
            <InfoIcon />
          </IconButton>
        </Tooltip>
      </Box>
    </Box>
  );
};

/** Props for `PerformanceRates`. */
export type PerformanceRatesProps = {
  /** Supplies `confusionMatrix` and `instanceParingConfusionMatrix` used to compute the rates. */
  modelMetrics: PerformanceMetrics;
  /**
   * Confidence threshold used to select an entry from the per-threshold confusion matrices.
   * Not used for classification / anomaly-detection label types.
   */
  threshold: number;
  /** Selects the computation: weighted multi-class rates vs. per-threshold counts. */
  labelType: LabelType | null | undefined;
  /**
   * Render the divider between the precision and recall cells (the component defaults this
   * to `true`). The name is misspelled but is part of the public props, so it is kept.
   */
  showVeriticalLine?: boolean;
};

/** Precision and recall as fractions in [0, 1]. */
type PrecisionRecall = { precision: number; recall: number };

/**
 * Weighted-average precision and recall for a classification / anomaly-detection confusion
 * matrix. Per-class precision and recall are computed first, then averaged with each class
 * weighted by its number of ground truths. Background:
 * https://towardsdatascience.com/multi-class-metrics-made-simple-part-i-precision-and-recall-9250280bddc2
 * (A macro average, the unweighted mean of the per-class values, was the considered alternative.)
 *
 * @returns null when the matrix has no classes, or when it contains no ground truths at all
 *   (the weighted average would otherwise be 0 / 0 = NaN and render as "NaN%").
 */
const weightedClassificationRates = (
  matrix: ClassificationConfusionMatrix,
): PrecisionRecall | null => {
  const defectIds = Object.keys(matrix).map(Number);
  if (!defectIds.length) {
    return null;
  }
  const parings = classificationConfusionMatrixToParings(matrix);
  const addCount = (total: number, { count }: { count: number }) => total + count;

  const perDefect = defectIds.map(defectId => {
    const predictedAsDefect = parings
      .filter(({ predDefectId }) => predDefectId === defectId)
      .reduce(addCount, 0);
    const groundTruth = parings
      .filter(({ gtDefectId }) => gtDefectId === defectId)
      .reduce(addCount, 0);
    const correctlyPredicted = matrix[defectId]?.[defectId] ?? 0;
    // How many predictions of this class are correct?
    const precision = correctlyPredicted / predictedAsDefect;
    // How many ground truths of this class were predicted correctly?
    const recall = correctlyPredicted / groundTruth;
    return {
      groundTruth,
      // `|| 0` turns the NaN of an empty class (0 / 0) into a zero contribution.
      weightedPrecision: precision * groundTruth || 0,
      weightedRecall: recall * groundTruth || 0,
    };
  });

  const totalGroundTruth = sum(perDefect.map(({ groundTruth }) => groundTruth));
  if (totalGroundTruth === 0) {
    return null;
  }
  return {
    precision: sum(perDefect.map(({ weightedPrecision }) => weightedPrecision)) / totalGroundTruth,
    recall: sum(perDefect.map(({ weightedRecall }) => weightedRecall)) / totalGroundTruth,
  };
};

/**
 * Precision and recall from per-threshold counts. Misclassified instances count against both
 * precision and recall, so they appear in both denominators.
 *
 * When a denominator is zero (nothing was predicted / there was nothing to find), the
 * corresponding rate is reported as 1.0 (100%).
 * NOTE(review): an earlier comment claimed 0% for this case; 1.0 is what the code has always
 * returned — confirm the intended product behaviour.
 */
const ratesFromCounts = ({
  truePositive,
  falsePositive,
  falseNegative,
  misclassified,
}: {
  truePositive: number;
  falsePositive: number;
  falseNegative: number;
  misclassified: number;
}): PrecisionRecall => {
  // The zero-checks must use the same sums as the divisions. Checking only TP + FP / TP + FN
  // reported 100% when every prediction was misclassified.
  const predicted = truePositive + falsePositive + misclassified;
  const expected = truePositive + falseNegative + misclassified;
  return {
    precision: predicted === 0 ? 1.0 : truePositive / predicted,
    recall: expected === 0 ? 1.0 : truePositive / expected,
  };
};

/**
 * Shows the model's precision and recall side by side.
 *
 * - Classification / anomaly detection: weighted multi-class averages of the confusion matrix.
 * - Other label types: counts at `threshold`, taken from the server confusion matrix in image
 *   view or from the instance-paring confusion matrix in instance view.
 *
 * Renders nothing when the required matrix is missing or no rates can be computed.
 */
const PerformanceRates: React.FC<PerformanceRatesProps> = ({
  modelMetrics,
  threshold: localThreshold,
  labelType,
  showVeriticalLine = true,
}) => {
  const styles = useStyles();
  const {
    state: { viewMode },
  } = useDataBrowserState();

  const { thresholdsMin, thresholdsMax, thresholdsStep } = CONFIDENCE_THRESHOLD_OPTIONS;
  const { confusionMatrix, instanceParingConfusionMatrix } = modelMetrics;

  const metrics = useMemo((): PrecisionRecall | null => {
    if (labelType === LabelType.Classification || labelType === LabelType.AnomalyDetection) {
      if (!confusionMatrix) {
        return null;
      }
      return weightedClassificationRates(confusionMatrix as ClassificationConfusionMatrix);
    }

    // Other label types: the counts depend on image view vs. instance view.
    if (!localThreshold && localThreshold !== 0) {
      // Rejects NaN (and null / undefined if a caller bypasses the prop type).
      return null;
    }

    const thresholds = range(thresholdsMin, thresholdsMax, thresholdsStep);
    // Compare with a tolerance of 10% of a step, presumably to absorb floating-point drift in
    // the fractional-step range.
    // NOTE(review): a threshold above the last generated value gives -1 here. Every lookup
    // below then falls back to 0 and both rates render as 100% — confirm whether to clamp.
    const thresholdIndex = thresholds.findIndex(n => n >= localThreshold - thresholdsStep * 0.1);

    if (viewMode === 'image') {
      // Image view: use the confusion matrix calculated by the server.
      if (!confusionMatrix) {
        return null;
      }
      const perThreshold = confusionMatrix as ConfusionMatrixPerThreshold;
      return ratesFromCounts({
        truePositive: perThreshold.truePositives[thresholdIndex] ?? 0,
        falsePositive: perThreshold.falsePositives[thresholdIndex] ?? 0,
        falseNegative: perThreshold.falseNegatives[thresholdIndex] ?? 0,
        misclassified: perThreshold.misclassified?.[thresholdIndex] ?? 0,
      });
    }

    // Instance view: use the confusion matrix accumulated in the UI from instance parings.
    if (!instanceParingConfusionMatrix) {
      return null;
    }
    return ratesFromCounts({
      truePositive: instanceParingConfusionMatrix.truePositives[thresholdIndex] ?? 0,
      falsePositive: instanceParingConfusionMatrix.falsePositives[thresholdIndex] ?? 0,
      falseNegative: instanceParingConfusionMatrix.falseNegatives[thresholdIndex] ?? 0,
      misclassified: instanceParingConfusionMatrix.misclassified?.[thresholdIndex] ?? 0,
    });
  }, [
    confusionMatrix,
    instanceParingConfusionMatrix,
    labelType,
    localThreshold,
    thresholdsMin,
    thresholdsMax,
    thresholdsStep,
    viewMode,
  ]);

  if (!metrics) {
    return null;
  }

  const { precision, recall } = metrics;

  return (
    <Box marginTop={2} id="performance-rates" data-testid="performance-rates" width="100%">
      {/* Precision / recall */}
      <Box className={styles.rateRow}>
        {/* TODO: tooltip copy */}
        <Rate
          id="performance-rates-precision"
          data-testid="performance-rates-precision"
          rate={round(precision, 3)}
          name="Precision"
          tooltipTitle={t(
            'Precision is the percentage of accurate predictions out of the total number of predictions. E.g. if 100 instances were detected as cats and 90 of them were actually cats, the precision rate is 90%. High precision means few false positive detections, low precision means many false positive detections.',
          )}
        />
        {showVeriticalLine && <div className={styles.verticalLine} />}
        {/* TODO: tooltip copy */}
        {/* NOTE(review): recall is measured against actual (ground-truth) instances, not "actual
            predictions"; the copy below (also an i18n key) likely needs revisiting. */}
        <Rate
          id="performance-rates-recall"
          data-testid="performance-rates-recall"
          rate={round(recall, 3)}
          name="Recall"
          tooltipTitle={t(
            'Recall is the percentage of accurate predictions out of the total number of actual predictions. E.g. if there were 100 actual cats in an image and the model detected 80 of them, the recall rate is 80%. High recall means most actual cats were detected, low recall means some were missed.',
          )}
        />
      </Box>
    </Box>
  );
};

export default PerformanceRates;
