import React, { useMemo, useState } from 'react';
import cx from 'classnames';
import { useAtom } from 'jotai';
import {
  Box,
  makeStyles,
  MenuItem,
  Select,
  Table,
  TableBody,
  TableCell,
  TableContainer,
  TableHead,
  TableRow,
  Tooltip,
} from '@material-ui/core';
import { Button, Typography } from '@clef/client-library';
import { useGetConfusionMatrixQuery } from '@/serverStore/modelAnalysis';
import { DefectColorChip } from '@/pages/DataBrowser/ModelPerformance/ConfusionMatrix';
import { EvaluationSetItem } from '@/api/evaluation_set_api';
import { modelListFilterOptionsAtom } from '../atoms';
import ModelImageList from '../ModelImageList/ModelImageList';
import { RegisteredModelWithThreshold } from '@/api/model_api';
import useGetDefectNameById from '@/hooks/defect/useGetDefectNameById';
import LoadingProgress from '../LoadingProgress';
import InfoOutlined from '@material-ui/icons/InfoOutlined';
import { AggregatedConfusionMatrix } from '@clef/shared/types';

// Value of the "All Classes" option in the class filter dropdown; when selected, rows are not
// filtered by class id.
const ALL_CLASSES_MENU_VALUE = -1;

/**
 * Classes for the `Percentage` cell: right alignment plus one colour per outcome
 * (green = candidate better, red = candidate worse, grey = neutral / not applicable).
 */
const usePercentageStyles = makeStyles(({ palette }) => ({
  green: {
    color: palette.green[500],
  },
  red: {
    color: palette.red[500],
  },
  grey: {
    color: palette.grey[600],
  },
  textAlignRight: {
    textAlign: 'right',
  },
}));

/**
 * Right-aligned, colour-coded cell showing how the candidate count differs from the baseline.
 *
 * - Both counts are zero: grey "--".
 * - Exactly one count is zero: "New Error" (red) when only the candidate has a count, "Fixed"
 *   (green) when only the baseline has one; with `isCorrectMapping` both cases show grey "--".
 * - Otherwise: the signed change `((candidate - baseline) / baseline) * 100` with one decimal and a
 *   tooltip. For error cells a decrease is green and an increase red; with `isCorrectMapping`
 *   the colours are reversed.
 */
const Percentage: React.FC<{
  baseline: number;
  candidate: number;
  isCorrectMapping?: boolean;
}> = ({ baseline, candidate, isCorrectMapping }) => {
  const styles = usePercentageStyles();
  const correct = !!isCorrectMapping;

  // Plain label used by every outcome that has no tooltip.
  const label = (toneClass: string, text: React.ReactNode) => (
    <Typography className={cx(styles.textAlignRight, toneClass)}>{text}</Typography>
  );

  // A relative change cannot be computed when either side is zero, so show a status instead.
  if (baseline === 0 || candidate === 0) {
    if (correct || baseline === candidate) {
      return label(styles.grey, '--');
    }
    return baseline === 0 ? label(styles.red, t('New Error')) : label(styles.green, t('Fixed'));
  }

  const delta = ((candidate - baseline) / baseline) * 100;
  if (delta === 0) {
    return label(styles.grey, '0%');
  }

  // Error cells improve when they shrink; correct-mapping cells improve when they grow.
  const isGreen = correct ? delta > 0 : delta < 0;
  const isRed = correct ? delta < 0 : delta > 0;
  const magnitude = `${Math.abs(delta).toFixed(1)}%`;

  const tooltip = correct
    ? isGreen
      ? t('Candidate has {{fixedNumber}} more correct pixels', { fixedNumber: magnitude })
      : t('Candidate has {{fixedNumber}} less correct pixels', { fixedNumber: magnitude })
    : isRed
    ? t('Candidate has {{fixedNumber}} more errors', { fixedNumber: magnitude })
    : t('Candidate has {{fixedNumber}} less errors', { fixedNumber: magnitude });

  return (
    <Tooltip title={tooltip} placement="top">
      <Box>
        <Typography
          className={cx(styles.textAlignRight, {
            [styles.green]: isGreen,
            [styles.red]: isRed,
          })}
        >
          {delta > 0 ? '+' : '-'}
          {magnitude}
        </Typography>
      </Box>
    </Tooltip>
  );
};

/** One (ground truth, prediction) cell, with its count under the baseline and the candidate. */
type ComparisonMatrix = {
  /** Ground-truth class id of the cell (nullable). */
  gtDefectId: number | null;
  /** Predicted class id of the cell (nullable). */
  predDefectId: number | null;
  /** Count for this cell from the baseline model; 0 when the baseline has no such cell. */
  baseline: number;
  /** Count for this cell from the candidate model; 0 when the candidate has no such cell. */
  candidate: number;
};

/**
 * Classes for one ConfusionMatrixTable section: the bordered container, the shaded header row,
 * clickable/selected body rows, the summed totals and the title-with-info-icon layout.
 */
const useTableStyles = makeStyles(({ palette, spacing }) => ({
  container: {
    marginBottom: spacing(5),
    borderRadius: '10px',
    border: `1px solid ${palette.grey[300]}`,
  },
  headerRow: {
    backgroundColor: palette.grey[50],
  },
  pointerCursor: {
    cursor: 'pointer',
  },
  selected: {
    backgroundColor: palette.blue[50],
  },
  textAlignRight: {
    textAlign: 'right',
  },
  sumText: {
    fontWeight: 500,
    color: palette.grey[600],
  },
  infoTitle: {
    display: 'flex',
    alignItems: 'center',
    gap: spacing(1),
  },
}));

/**
 * One section of the comparison (e.g. "False Positive").
 *
 * - The header row shows the title with an info tooltip, the summed baseline/candidate counts,
 *   and their relative change.
 * - The body has one row per (ground truth, prediction) pair.
 * - Clicking a row, or its "View" button, stores the pair in `modelListFilterOptionsAtom`.
 * - While a pair is stored, the row is highlighted, the columns are narrower and the "View"
 *   column is hidden.
 *
 * `isCorrectMapping` marks the correct-mapping section: only the ground-truth chip is shown and
 * `Percentage` treats increases as improvements. The `comparsionMatrices` prop keeps its
 * historical spelling so existing callers keep working.
 */
const ConfusionMatrixTable: React.FC<{
  title: string;
  titleTooltip: string;
  baseSum: number;
  candidateSum: number;
  isCorrectMapping?: boolean;
  comparsionMatrices: Array<ComparisonMatrix>;
}> = ({ title, titleTooltip, baseSum, candidateSum, isCorrectMapping, comparsionMatrices }) => {
  const styles = useTableStyles();
  const [filterOptions, setFilterOptions] = useAtom(modelListFilterOptionsAtom);
  // Any stored selection switches the table to its narrower layout.
  const isCompact = !!filterOptions;

  // The totals of the correct-mapping section are not errors, so it gets its own tooltip text.
  const errorType = title.toLowerCase();
  const baselineSumTooltip = isCorrectMapping
    ? t('Total number of correct predictions the baseline model made.')
    : t('Total number of errors the baseline model made for {{errorType}} type.', { errorType });
  const candidateSumTooltip = isCorrectMapping
    ? t('Total number of correct predictions the candidate model made.')
    : t('Total number of errors the candidate model made for {{errorType}} type.', {
        errorType,
      });

  // Without a prediction chip, the ground-truth chip takes the whole first column.
  const gtChipWidth = isCompact ? (isCorrectMapping ? 250 : 130) : isCorrectMapping ? 350 : 175;
  const predChipWidth = isCompact ? 130 : 175;

  return (
    <TableContainer className={styles.container}>
      <Table>
        <TableHead>
          <TableRow className={styles.headerRow}>
            <TableCell width={isCompact ? 265 : 345}>
              <Box className={styles.infoTitle}>
                <Typography>{title}</Typography>
                <Tooltip placement="top" arrow={true} title={titleTooltip}>
                  <InfoOutlined fontSize="small" />
                </Tooltip>
              </Box>
            </TableCell>
            <TableCell width={isCompact ? 100 : 145}>
              <Tooltip placement="top-start" title={baselineSumTooltip}>
                <span>
                  <Typography className={cx(styles.sumText, styles.textAlignRight)}>
                    {baseSum}
                  </Typography>
                </span>
              </Tooltip>
            </TableCell>
            <TableCell width={isCompact ? 100 : 145}>
              <Tooltip placement="top-start" title={candidateSumTooltip}>
                <span>
                  <Typography className={cx(styles.sumText, styles.textAlignRight)}>
                    {candidateSum}
                  </Typography>
                </span>
              </Tooltip>
            </TableCell>
            <TableCell width={isCompact ? 100 : 120}>
              <Percentage
                baseline={baseSum}
                candidate={candidateSum}
                isCorrectMapping={isCorrectMapping}
              />
            </TableCell>
            {!isCompact && <TableCell width={90} />}
          </TableRow>
        </TableHead>
        <TableBody>
          {comparsionMatrices.map(({ gtDefectId, predDefectId, baseline, candidate }) => {
            // The filter atom stores numeric ids; a null id is stored as 0.
            const gtClassId = gtDefectId ?? 0;
            const predClassId = predDefectId ?? 0;
            const isSelected =
              filterOptions?.gtClassId === gtClassId &&
              filterOptions?.predClassId === predClassId;
            const selectPair = () => setFilterOptions({ gtClassId, predClassId });
            return (
              <TableRow
                key={`${gtDefectId}-${predDefectId}`}
                className={cx(styles.pointerCursor, { [styles.selected]: isSelected })}
                onClick={selectPair}
              >
                <TableCell>
                  <Box display="flex" alignItems="center">
                    <DefectColorChip
                      width={gtChipWidth}
                      defectId={gtDefectId}
                      size={12}
                      maxWidth={isCorrectMapping ? 150 : 100}
                    />
                    {!isCorrectMapping && (
                      <DefectColorChip
                        width={predChipWidth}
                        isPrediction
                        defectId={predDefectId}
                        size={12}
                        maxWidth={100}
                      />
                    )}
                  </Box>
                </TableCell>
                <TableCell>
                  <Typography className={styles.textAlignRight}>{baseline}</Typography>
                </TableCell>
                <TableCell>
                  <Typography className={styles.textAlignRight}>{candidate}</Typography>
                </TableCell>
                <TableCell>
                  <Percentage
                    baseline={baseline}
                    candidate={candidate}
                    isCorrectMapping={isCorrectMapping}
                  />
                </TableCell>
                {!isCompact && (
                  <TableCell>
                    {/* NOTE(review): this id repeats on every row; kept in case tests or
                        analytics select on it. */}
                    <Button
                      id="view-specific-comparison-images-button"
                      color="primary"
                      variant="text"
                      size="small"
                      onClick={selectPair}
                    >
                      {t('View')}
                    </Button>
                  </TableCell>
                )}
              </TableRow>
            );
          })}
        </TableBody>
      </Table>
    </TableContainer>
  );
};

/** Classes for the column-title row rendered above the comparison tables. */
const useMatrixStyles = makeStyles(({ palette, spacing }) => ({
  title: {
    fontWeight: 700,
    color: palette.grey[900],
  },
  textAlignRight: {
    textAlign: 'right',
  },
  // The "16px" in these names presumably assumes a 4px theme spacing unit (spacing(4)).
  paddingLeft16px: {
    paddingLeft: spacing(4),
  },
  paddingRight16px: {
    paddingRight: spacing(4),
  },
}));

/** Props of the `ModelComparisonConfusionMatrix` component. */
export type ModelComparisonConfusionMatrixProps = {
  /** Model shown in the "Baseline" column. */
  baseline: RegisteredModelWithThreshold;
  /** Threshold passed when querying the baseline model's confusion matrix. */
  baselineThreshold: number;
  /** Model shown in the "Candidate" column. */
  candidate: RegisteredModelWithThreshold;
  /** Threshold passed when querying the candidate model's confusion matrix. */
  candidateThreshold: number;
  /** Evaluation set both confusion matrices are queried for. */
  evaluationSet: EvaluationSetItem;
};

/** Map key shared by baseline and candidate cells that describe the same (gt, pred) pair. */
const toPairKey = (cell: AggregatedConfusionMatrix) => `${cell.gtClassId}-${cell.predClassId}`;

/**
 * Indexes the cells of one confusion-matrix bucket that have a positive count, by (gt, pred) pair.
 * A later cell with the same pair replaces the earlier value but keeps the earlier position
 * (`Map` preserves first-insertion order). Building it this way is linear, unlike re-spreading an
 * object per cell.
 */
const indexNonEmptyCells = (cells: AggregatedConfusionMatrix[]) => {
  const byPair = new Map<string, AggregatedConfusionMatrix>();
  cells.forEach(cell => {
    if (cell.count > 0) {
      byPair.set(toPairKey(cell), cell);
    }
  });
  return byPair;
};

/**
 * Joins the baseline and candidate cells of one bucket into comparison rows.
 *
 * - Row order: baseline pairs in baseline order, then candidate-only pairs in candidate order.
 * - A pair missing on one side counts as 0 there.
 * - Every class id seen is added to `defectIds` before filtering (a null id is recorded as 0), so
 *   the class dropdown still lists classes that the current filter hides.
 *
 * @param defectFilter Class id that a row's gt or pred id must equal, or ALL_CLASSES_MENU_VALUE
 *   to keep every row.
 * @param excludeNoLabelGroundTruth Drop rows whose ground-truth id is 0 or null; used for the
 *   correct bucket, where "No Label" should not be listed.
 * @returns The rows that pass the filters. `defectIds` is mutated in place.
 */
const buildComparisonRows = (
  baselineCells: AggregatedConfusionMatrix[],
  candidateCells: AggregatedConfusionMatrix[],
  defectIds: Set<number>,
  defectFilter: number,
  excludeNoLabelGroundTruth = false,
): ComparisonMatrix[] => {
  const baselineByPair = indexNonEmptyCells(baselineCells);
  const candidateByPair = indexNonEmptyCells(candidateCells);
  const rows: ComparisonMatrix[] = [];

  const addRow = (key: string, cell: AggregatedConfusionMatrix) => {
    const row: ComparisonMatrix = {
      gtDefectId: cell.gtClassId,
      predDefectId: cell.predClassId,
      baseline: baselineByPair.get(key)?.count ?? 0,
      candidate: candidateByPair.get(key)?.count ?? 0,
    };
    defectIds.add(row.gtDefectId ?? 0);
    defectIds.add(row.predDefectId ?? 0);

    const passesClassFilter =
      defectFilter === ALL_CLASSES_MENU_VALUE ||
      defectFilter === row.gtDefectId ||
      defectFilter === row.predDefectId;
    const isExcludedNoLabel = excludeNoLabelGroundTruth && !row.gtDefectId;
    if (passesClassFilter && !isExcludedNoLabel) {
      rows.push(row);
    }
  };

  baselineByPair.forEach((cell, key) => addRow(key, cell));
  candidateByPair.forEach((cell, key) => {
    if (!baselineByPair.has(key)) {
      addRow(key, cell);
    }
  });
  return rows;
};

/**
 * Compares the confusion matrices of a baseline and a candidate model on one evaluation set.
 *
 * - Rows are grouped into false positive, false negative, misclassified and correct sections;
 *   sections with no rows are not rendered.
 * - A dropdown restricts the rows to one class.
 * - While a pair is stored in `modelListFilterOptionsAtom`, the table narrows and
 *   `ModelImageList` is shown next to it.
 */
const ModelComparisonConfusionMatrix: React.FC<ModelComparisonConfusionMatrixProps> = props => {
  const { baseline, candidate, evaluationSet, baselineThreshold, candidateThreshold } = props;
  const { data: baselineConfusionMatrixData, isLoading: isLoadingBaseline } =
    useGetConfusionMatrixQuery(baseline.id, evaluationSet.id, baselineThreshold);
  const { data: candidateConfusionMatrixData, isLoading: isLoadingCandidate } =
    useGetConfusionMatrixQuery(candidate.id, evaluationSet.id, candidateThreshold);
  const [filterOptions] = useAtom(modelListFilterOptionsAtom);
  const [filteredDefect, setFilteredDefect] = useState<number>(ALL_CLASSES_MENU_VALUE);
  const getDefectNameById = useGetDefectNameById();
  const styles = useMatrixStyles();

  const {
    defectSets,
    correctConfusionMatrix,
    misClassificationConfusionMatrix,
    falseNegativeConfusionMatrix,
    falsePositiveConfusionMatrix,
  } = useMemo(() => {
    const defectSets = new Set<number>();
    const baselineSplit = baselineConfusionMatrixData?.splitConfusionMatrices;
    const candidateSplit = candidateConfusionMatrixData?.splitConfusionMatrices;

    // The build order below sets the order of ids in `defectSets`, and therefore in the dropdown.
    const correctConfusionMatrix = buildComparisonRows(
      baselineSplit?.correct?.data ?? [],
      candidateSplit?.correct?.data ?? [],
      defectSets,
      filteredDefect,
      true, // No Label should not be included on correct mapping
    );
    const falsePositiveConfusionMatrix = buildComparisonRows(
      baselineSplit?.falsePositive?.data ?? [],
      candidateSplit?.falsePositive?.data ?? [],
      defectSets,
      filteredDefect,
    );
    const falseNegativeConfusionMatrix = buildComparisonRows(
      baselineSplit?.falseNegative?.data ?? [],
      candidateSplit?.falseNegative?.data ?? [],
      defectSets,
      filteredDefect,
    );
    const misClassificationConfusionMatrix = buildComparisonRows(
      baselineSplit?.misClassification?.data ?? [],
      candidateSplit?.misClassification?.data ?? [],
      defectSets,
      filteredDefect,
    );
    return {
      correctConfusionMatrix,
      falsePositiveConfusionMatrix,
      falseNegativeConfusionMatrix,
      misClassificationConfusionMatrix,
      defectSets,
    };
  }, [filteredDefect, baselineConfusionMatrixData, candidateConfusionMatrixData]);

  if (isLoadingBaseline || isLoadingCandidate) {
    return <LoadingProgress size={24} />;
  }

  // 0 stands in for a null class id (see buildComparisonRows) and is not offered as a filter.
  // Filtering it out up front avoids rendering a bare `0` child via `id && ...`.
  const selectableDefectIds = Array.from(defectSets).filter(id => id !== 0);

  // Sections in display order; sections without rows are skipped.
  const sections = [
    {
      key: 'false-positive',
      title: t('False Positive'),
      tooltip: t(
        'The model predicted that an object of interest was present, but the model was incorrect.',
      ),
      rows: falsePositiveConfusionMatrix,
      isCorrectMapping: false,
    },
    {
      key: 'false-negative',
      title: t('False Negative'),
      tooltip: t(
        'The model predicted that an object of interest was not present, but the model was incorrect.',
      ),
      rows: falseNegativeConfusionMatrix,
      isCorrectMapping: false,
    },
    {
      key: 'mis-classified',
      title: t('Misclassified'),
      tooltip: t(
        'The model correctly predicted that an object of interest was present, but it predicted the wrong class.',
      ),
      rows: misClassificationConfusionMatrix,
      isCorrectMapping: false,
    },
    {
      key: 'correct',
      title: t('Correct'),
      tooltip: t('The model’s prediction was correct.'),
      rows: correctConfusionMatrix,
      isCorrectMapping: true,
    },
  ];

  return (
    <Box display="flex">
      <Box
        position="absolute"
        top={4}
        left={filterOptions ? 400 : 650}
        width={200}
        textAlign="right"
      >
        <Select
          value={filteredDefect}
          onChange={(e: React.ChangeEvent<{ value: unknown }>) =>
            setFilteredDefect(e.target.value as number)
          }
        >
          <MenuItem value={ALL_CLASSES_MENU_VALUE}>{t('All Classes')}</MenuItem>
          {selectableDefectIds.map(id => (
            <MenuItem key={id} value={id}>
              {getDefectNameById(id)}
            </MenuItem>
          ))}
        </Select>
      </Box>
      <Box
        width={filterOptions ? 600 : 850}
        flexShrink={0}
        flexGrow={0}
        style={{
          transition: 'width 0.3s ease-in-out',
        }}
      >
        <Box display="flex" marginBottom={3} alignItems="center">
          <Typography className={cx(styles.title, styles.paddingLeft16px)} width={175}>
            {t('Ground Truth')}
          </Typography>
          <Typography className={cx(styles.title, styles.paddingLeft16px)} width={175}>
            {t('Prediction')}
          </Typography>
          <Typography className={cx(styles.title, styles.textAlignRight)} width={145}>
            {filterOptions ? t('Baseline') : t('Baseline model')}
          </Typography>
          <Typography
            className={cx(styles.title, styles.textAlignRight, {
              [styles.paddingRight16px]: !filterOptions,
            })}
            width={145}
          >
            {filterOptions ? t('Candidate') : t('Candidate model')}
          </Typography>
          <Box display="flex" alignItems="center">
            <Typography className={cx(styles.title, styles.textAlignRight)} width={90}>
              {t('Differences')}
            </Typography>
            <Box marginLeft={1} />
            <Tooltip
              placement="top"
              arrow={true}
              title={
                <>
                  {/* NOTE(review): "baseline candidate" probably means "baseline model"; the
                      text is also the i18n key, so it is left unchanged here. */}
                  <Typography>
                    {t(
                      `Shows if the candidate model performed better (Fixed) or worse (New Error) than the baseline candidate for each error type. For differences shown as a percentage, the percentage is calculated as:`,
                    )}
                  </Typography>
                  <Typography>{t('((candidate - baseline) / baseline) * 100')}</Typography>
                </>
              }
            >
              <InfoOutlined fontSize="small" />
            </Tooltip>
          </Box>
        </Box>
        {sections
          .filter(section => section.rows.length > 0)
          .map(section => (
            <ConfusionMatrixTable
              key={section.key}
              title={section.title}
              titleTooltip={section.tooltip}
              isCorrectMapping={section.isCorrectMapping}
              baseSum={section.rows.reduce((accum, val) => accum + val.baseline, 0)}
              candidateSum={section.rows.reduce((accum, val) => accum + val.candidate, 0)}
              comparsionMatrices={section.rows}
            />
          ))}
      </Box>
      {/* Use the thresholds the confusion matrices were queried with, so the images listed for
          a selected cell correspond to the counts shown in that cell. */}
      {!!filterOptions && (
        <ModelImageList
          model={baseline}
          candidate={candidate}
          threshold={baselineThreshold}
          candidateThreshold={candidateThreshold}
          evaluationSet={evaluationSet}
        />
      )}
    </Box>
  );
};

export default ModelComparisonConfusionMatrix;
