import React, { useMemo, useState } from 'react';
import cx from 'classnames';
import { useAtom } from 'jotai';
import { modelListFilterOptionsAtom } from '../atoms';
import { Box, makeStyles, TextField, Tooltip } from '@material-ui/core';
import { Typography } from '@clef/client-library';
import Search from '@material-ui/icons/Search';
import { useGetConfusionMatrixQuery } from '@/serverStore/modelAnalysis';
import useGetDefectNameById from '@/hooks/defect/useGetDefectNameById';
import {
  PredictionMatrix,
  PredictionMatrixData,
} from '@/pages/DataBrowser/ModelPerformance/ConfusionMatrix';
import { EvaluationSetItem } from '@/api/evaluation_set_api';
import { AggregatedConfusionMatrix, LabelType, RegisteredModel } from '@clef/shared/types';
import { useGetSelectedProjectQuery } from '@/serverStore/projects';
import LoadingProgress from '../LoadingProgress';
import InfoOutlined from '@material-ui/icons/InfoOutlined';

// Class definitions for the confusion-matrix panel. Colours and spacing are
// read from the theme handed to the factory by `makeStyles`.
const useStyles = makeStyles(theme => {
  const { palette } = theme;
  // Spacing steps shared by several classes below.
  const singleStep = theme.spacing(1);
  const doubleStep = theme.spacing(2);

  return {
    // Header row: children pushed to opposite ends, vertically centred.
    HeaderWithClearButton: {
      display: 'flex',
      justifyContent: 'space-between',
      alignItems: 'center',
    },
    // Small bold clickable text in the theme's blue.
    clearFilterButton: {
      color: palette.blue[600],
      fontSize: 12,
      fontWeight: 700,
      lineHeight: '16px',
      cursor: 'pointer',
    },
    // Magnifier icon used as the search field's start adornment.
    searchIcon: {
      color: palette.grey[500],
      marginRight: singleStep,
    },
    // Applied to the search TextField's input wrapper.
    confusionMatrixSearchField: {
      borderRadius: 6,
      height: 36,
      background: palette.greyModern[25],
    },
    // Title colour for the "correct" section.
    correctColor: {
      color: palette.green[500],
    },
    // Title colour for the error sections.
    incorrectColor: {
      color: palette.error.main,
    },
    // Table layout helpers for the column-header row.
    matrixRow: {
      height: 36,
    },
    matrixLeftCell: {
      paddingLeft: doubleStep,
    },
    matrixRightCell: {
      paddingRight: doubleStep,
    },
    matrixCell: {
      verticalAlign: 'middle',
    },
    // Section title: text followed by an info icon.
    infoTitle: {
      display: 'flex',
      alignItems: 'center',
      gap: singleStep,
    },
  };
});

/**
 * Props for `ModelConfusionMatrix`. Every field is optional; the component
 * forwards them (or `undefined`) to `useGetConfusionMatrixQuery`.
 */
export type ModelConfusionMatrixProps = {
  /** Model whose `id` is passed to the confusion-matrix query. */
  model?: RegisteredModel;
  /** Evaluation set whose `id` is passed to the confusion-matrix query. */
  evaluationSet?: EvaluationSetItem;
  /** Threshold value passed as the third argument of the confusion-matrix query. */
  threshold?: number;
};

/** Describes one collapsible group of rows rendered through `PredictionMatrix`. */
type MatrixSection = {
  /** Used both as the React key and as the `data-testid` of the section. */
  testId: string;
  /** Rows to display; the section is skipped entirely when this is empty. */
  rows: PredictionMatrixData[];
  /** Colour class applied to the section title. */
  titleClassName: string;
  /** Builds the translated title from the already-formatted total. */
  getTitle: (count: string) => string;
  /** Translated explanation shown when hovering the info icon. */
  tooltip: string;
  /** Forwarded to `PredictionMatrix`; left undefined for sections that never set it. */
  onlyShowCorrectColumn?: boolean;
};

/** Adds up the `count` of every row in one matrix section. */
const sumMatrixCounts = (rows: PredictionMatrixData[]): number =>
  rows.reduce((total, row) => total + row.count, 0);

/**
 * Panel that lists a model's confusion matrix split into four sections:
 * false positives, mis-classifications, false negatives and correct
 * predictions.
 *
 * - Data comes from `useGetConfusionMatrixQuery(model?.id, evaluationSet?.id, threshold)`;
 *   a spinner is shown while that query is loading.
 * - Entries with a non-positive count are dropped, class ids are resolved to
 *   names, and the search box filters rows whose ground-truth or prediction
 *   caption contains the typed text (case-insensitive).
 * - Section totals are computed from the rows left after that filtering.
 * - Clicking a count writes the row's ground-truth / prediction class ids
 *   into `modelListFilterOptionsAtom`.
 */
const ModelConfusionMatrix: React.FC<ModelConfusionMatrixProps> = props => {
  const styles = useStyles();
  const { model, evaluationSet, threshold } = props;
  const { data: confusionMatrixData, isLoading: isConfusionMatrixDataLoading } =
    useGetConfusionMatrixQuery(model?.id, evaluationSet?.id, threshold);
  const [searchText, setSearchText] = useState('');
  const [filterOptions, setFilterOptions] = useAtom(modelListFilterOptionsAtom);

  const getDefectNameById = useGetDefectNameById();
  // Translated like every other user-facing string in this panel; resolved
  // here (not inside the memo) so it can be listed as a dependency.
  const noLabelCaption = t('No label');
  const {
    correctConfusionMatrix,
    falsePositiveConfusionMatrix,
    falseNegativeConfusionMatrix,
    misClassificationConfusionMatrix,
  } = useMemo(() => {
    const lowerSearchText = searchText.toLowerCase();

    // Converts raw aggregated entries into display rows and applies the search filter.
    const toDisplayRows = (confusionMatrix: AggregatedConfusionMatrix[]) => {
      const rows = confusionMatrix
        .filter(entry => entry.count > 0)
        .map(
          entry =>
            ({
              gtDefectId: entry.gtClassId,
              predDefectId: entry.predClassId,
              count: entry.count,
              // A falsy class id means that side of the pair has no class.
              gtCaption: entry.gtClassId ? getDefectNameById(entry.gtClassId) : noLabelCaption,
              predictionCaption: entry.predClassId
                ? getDefectNameById(entry.predClassId)
                : noLabelCaption,
            } as PredictionMatrixData),
        );
      if (!lowerSearchText) {
        return rows;
      }
      return rows.filter(
        row =>
          row.gtCaption.toLowerCase().includes(lowerSearchText) ||
          row.predictionCaption.toLowerCase().includes(lowerSearchText),
      );
    };

    const { correct, misClassification, falseNegative, falsePositive } =
      confusionMatrixData?.splitConfusionMatrices ?? {};
    return {
      correctConfusionMatrix: toDisplayRows(correct?.data ?? []),
      misClassificationConfusionMatrix: toDisplayRows(misClassification?.data ?? []),
      falseNegativeConfusionMatrix: toDisplayRows(falseNegative?.data ?? []),
      falsePositiveConfusionMatrix: toDisplayRows(falsePositive?.data ?? []),
    };
  }, [confusionMatrixData, searchText, getDefectNameById, noLabelCaption]);

  const { data: project } = useGetSelectedProjectQuery();
  const { labelType } = project ?? {};
  // Segmentation projects label the last column "Pixels" instead of "Count".
  const countTitle = labelType === LabelType.Segmentation ? t('Pixels') : t('Count');

  // Display order matters: errors first, correct predictions last.
  const sections: MatrixSection[] = [
    {
      testId: 'false-positive-predictions',
      rows: falsePositiveConfusionMatrix,
      titleClassName: styles.incorrectColor,
      getTitle: count => t('False Positive ({{count}})', { count }),
      tooltip: t(
        'The model predicted that an object of interest was present, but the model was incorrect.',
      ),
    },
    {
      testId: 'mis-classification-predictions',
      rows: misClassificationConfusionMatrix,
      titleClassName: styles.incorrectColor,
      getTitle: count => t('Mis-Classification ({{count}})', { count }),
      tooltip: t(
        'The model correctly predicted that an object of interest was present, but it predicted the wrong class.',
      ),
    },
    {
      testId: 'false-negative-predictions',
      rows: falseNegativeConfusionMatrix,
      titleClassName: styles.incorrectColor,
      getTitle: count => t('False Negative ({{count}})', { count }),
      tooltip: t(
        'The model predicted that an object of interest was not present, but the model was incorrect.',
      ),
    },
    {
      testId: 'correct-predictions',
      rows: correctConfusionMatrix,
      titleClassName: styles.correctColor,
      getTitle: count => t('Correct Predictions ({{count}})', { count }),
      tooltip: t('The model’s prediction was correct.'),
      onlyShowCorrectColumn: true,
    },
  ];

  return (
    <Box id="confusion-matrix-section" width={250} flexShrink={0} flexGrow={0}>
      {isConfusionMatrixDataLoading ? (
        <LoadingProgress size={24} />
      ) : (
        <>
          <Box className={styles.HeaderWithClearButton} marginBottom={4}>
            <Typography variant="body_bold">{t('Analyze')}</Typography>
          </Box>
          <Box marginBottom={5}>
            <TextField
              variant="outlined"
              placeholder={t('Search by class')}
              InputProps={{
                className: styles.confusionMatrixSearchField,
                startAdornment: <Search className={styles.searchIcon} />,
              }}
              value={searchText}
              onChange={e => setSearchText(e.target.value ?? '')}
            />
          </Box>
          <Box>
            {/* Column headers shared by every section below. */}
            <Box display="table" width="100%">
              <Box display="table-row" className={styles.matrixRow}>
                <Box
                  display="table-cell"
                  className={cx(styles.matrixCell, styles.matrixLeftCell)}
                  width="44%"
                >
                  <Typography variant="body2">
                    <strong>{t('Ground Truth')}</strong>
                  </Typography>
                </Box>
                <Box display="table-cell" className={styles.matrixCell} width="44%">
                  <Typography variant="body2">
                    <strong>{t('Prediction')}</strong>
                  </Typography>
                </Box>
                <Box display="table-cell" className={cx(styles.matrixCell, styles.matrixRightCell)}>
                  <Typography variant="body2">
                    <strong>{countTitle}</strong>
                  </Typography>
                </Box>
              </Box>
            </Box>
            {sections
              .filter(section => section.rows.length > 0)
              .map(section => (
                <PredictionMatrix
                  key={section.testId}
                  title={
                    <Box className={styles.infoTitle}>
                      <Typography
                        variant="subtitle1"
                        className={cx(section.titleClassName, styles.matrixLeftCell)}
                      >
                        {section.getTitle(sumMatrixCounts(section.rows).toLocaleString())}
                      </Typography>
                      <Tooltip placement="top" arrow={true} title={section.tooltip}>
                        <InfoOutlined fontSize="small" />
                      </Tooltip>
                    </Box>
                  }
                  filterOptions={filterOptions}
                  data={section.rows}
                  data-testid={section.testId}
                  onlyShowCorrectColumn={section.onlyShowCorrectColumn}
                  onCountClick={data => {
                    // Missing ids are stored as 0.
                    setFilterOptions({
                      gtClassId: data.gtDefectId || 0,
                      predClassId: data.predDefectId || 0,
                    });
                  }}
                />
              ))}
          </Box>
        </>
      )}
    </Box>
  );
};

export default ModelConfusionMatrix;
