import { useJobDetailForCurrentProject } from '@/serverStore/jobs';
import {
  ClassificationConfusionMatrix,
  ConfusionMatrixPerThreshold,
  LabelType,
  RegisteredModelId,
} from '@clef/shared/types';
import { useMemo } from 'react';
import { PerformanceType } from './states';

/**
 * Hex colors keyed by confusion-matrix outcome.
 *
 * TN / FP / FN / TP presumably stand for true negative / false positive /
 * false negative / true positive. MC presumably stands for "misclassified",
 * matching the `misclassified` field handled by
 * `useConfusionMatrixPerThreshold`. Verify against the chart components that
 * consume this palette.
 */
export const Palette = {
  TN: '#CDD0D3',
  FP: '#F6D05F',
  FN: '#FF5C9A',
  TP: '#1E6091',
  MC: '#8670FF',
};

/**
 * Builds a symmetric logarithmic transform: `sign(x) * log_base(|x| + 1)`.
 *
 * The `+ 1` offset makes the curve pass through 0 and stay defined for
 * negative inputs. `exponentScale(base)` is its inverse.
 *
 * @param base Logarithm base (defaults to e). A base of 1 divides by
 *   `Math.log(1) === 0`, so results become ±Infinity/NaN; non-positive bases
 *   yield NaN.
 * @returns A function mapping a raw value to its log-scaled value.
 */
export const logScale =
  (base: number = Math.E) =>
  (x: number) =>
    // `Math.log1p(v)` computes `ln(1 + v)` without first rounding `1 + v`,
    // so tiny |x| no longer collapse to exactly 0 as `Math.log(x + 1)` did.
    (x >= 0 ? Math.log1p(x) : -Math.log1p(-x)) / Math.log(base);

/**
 * Inverse of `logScale(base)`: maps `y` to `sign(y) * (base^|y| - 1)`.
 *
 * @param base Exponent base (defaults to e). Should match the base passed to
 *   the paired `logScale`.
 * @returns A function mapping a log-scaled value back to the raw value.
 */
export const exponentScale =
  (base = Math.E) =>
  (x: number): number => {
    if (x < 0) {
      // Mirror the positive branch so the transform is odd-symmetric.
      return 1 - base ** -x;
    }
    // Covers x >= 0 (including -0) and NaN, which propagates to NaN.
    return base ** x - 1;
  };

/**
 * Selects the confusion-matrix data to display for a model's job.
 *
 * Reads `datasetLevelMetrics` from the job details of `modelId` and picks a
 * matrix based on the metrics `version`:
 * - `'1.0'`: `MediaLevel` uses the binarized matrix and `AnnotationLevel` uses
 *   the regular matrix. When `split` is non-empty, the per-split entry is used
 *   instead.
 * - any other version: only `AnnotationLevel` with an empty `split` yields data.
 *   That data is an object with the five outcome lists, and any missing list
 *   falls back to `[]`.
 *
 * @param performanceType Which level of performance is being viewed.
 * @param split Dataset split key; an empty string means "all splits".
 * @param modelId Model whose job details are read; may be undefined.
 * @returns The selected matrix, or `undefined` when metrics are not loaded or
 *   the combination is not supported.
 */
export const useConfusionMatrixPerThreshold = (
  performanceType: PerformanceType,
  split: string,
  modelId: RegisteredModelId | undefined,
): ConfusionMatrixPerThreshold | ClassificationConfusionMatrix | undefined => {
  const jobDetailQuery = useJobDetailForCurrentProject(modelId);
  const metrics = jobDetailQuery.data?.datasetLevelMetrics;

  return useMemo(() => {
    // Job details (or their metrics) are not available yet.
    if (!metrics) return undefined;

    if (metrics.version === '1.0') {
      switch (performanceType) {
        case PerformanceType.MediaLevel:
          return split
            ? metrics.binarizedConfusionMatrixPerSplit?.[split]
            : metrics.binarizedConfusionMatrix;
        case PerformanceType.AnnotationLevel:
          return split ? metrics.confusionMatrixPerSplit?.[split] : metrics.confusionMatrix;
        default:
          return undefined;
      }
    }

    // Non-1.0 metrics only support the annotation-level, all-splits view.
    if (performanceType !== PerformanceType.AnnotationLevel || split) return undefined;

    return {
      truePositives: metrics.truePositives || [],
      falsePositives: metrics.falsePositives || [],
      trueNegatives: metrics.trueNegatives || [],
      falseNegatives: metrics.falseNegatives || [],
      misclassified: metrics.misclassified || [],
    };
  }, [metrics, performanceType, split]);
};

/**
 * Returns the translated axis/legend title describing what confusion-matrix
 * counts measure.
 *
 * Media-level performance always counts media. Otherwise the unit depends on
 * the label type: bounding boxes, pixels (segmentation), or classes for any
 * other label type.
 *
 * NOTE(review): `t` is not imported in this file. It presumably is a global
 * i18n helper; confirm.
 */
export const getConfusionMatrixCountTitle = (
  performanceType: PerformanceType,
  labelType: LabelType,
) => {
  if (performanceType === PerformanceType.MediaLevel) return t('# of media');

  switch (labelType) {
    case LabelType.BoundingBox:
      return t('# of bounding boxes');
    case LabelType.Segmentation:
      return t('# of pixels');
    default:
      // TODO: wording for the remaining label types is not finalized yet.
      return t('# of classes');
  }
};
