import { Button, Dropdown, Typography } from '@clef/client-library';
import { Box, Divider, MenuItem, MenuList, makeStyles } from '@material-ui/core';
import React, { useCallback, useState } from 'react';
import { Provider } from 'jotai';
import AdjustThresholdDropdown from './AdjustThresholdDropdown';
import ModelPerformance from '../ModelPerformance';
import ModelConfusionMatrix from './ModelConfusionMatrix';
import {
  useGetBatchModelMetricsQuery,
  useGetModelEvaluationReportsQuery,
} from '@/serverStore/modelAnalysis';
import { LabelType, ModelEvaluationReportStatus, RegisteredModel } from '@clef/shared/types';
import { EvaluationSetItem } from '@/api/evaluation_set_api';
import { getTrainDevTestEvaluationSetName, getEvaluationSetName } from '../utils';
import ModelImageList from '../ModelImageList/ModelImageList';
import { isEmpty } from 'lodash';
import ExpandMore from '@material-ui/icons/ExpandMore';
import { useModelAnalysisCreateBundleMutation } from '@/serverStore/modelAnalysis/mutations';
import LoadingProgress from '../LoadingProgress';
import { useGetSelectedProjectQuery } from '@/serverStore/projects';
import { BundleInfo } from '@/api/model_api';
import { DownloadIcon } from '@/images/media_details/ToolIcons';
import model_analysis_api from '@/api/model_analysis_api';
import { useSnackbar } from 'notistack';
import { isModelTrainingSuccessful } from '@/store/projectModelInfoState/utils';

/** Props accepted by `PerformanceReportPanel`. */
type PerformanceReportPanelProps = {
  /** Model whose evaluation reports and metrics are displayed. */
  model?: RegisteredModel & { bundles?: BundleInfo[] };
  /**
   * Preferred evaluation set. When it is absent from the available options,
   * the panel falls back to the first available option.
   */
  evaluationSet?: EvaluationSetItem;
  /** Initial confidence threshold; seeds the panel's local threshold state. */
  threshold: number;
  /** Train/dev/test evaluation sets (for this model and for other models). */
  trainDevTestColumnEvaluationSets?: EvaluationSetItem[];
  /** Any additional evaluation sets; hidden ones are not offered. */
  otherEvaluationSets?: EvaluationSetItem[];
  /** Invoked with the evaluation set the user picks from the dropdown. */
  onChangeEvaluationSet?: (e: EvaluationSetItem) => void;
};

/** Style hook for the panel; `thresholdText` is applied to the confidence-threshold row. */
const useStyles = makeStyles(() => ({
  thresholdText: {
    fontSize: 14,
  },
}));

/**
 * Centered hint rendered in place of the confusion matrix and image list
 * when no metrics exist for the current selection.
 */
const AfterAdjustThresholdTip = () => (
  <Box
    display="flex"
    flexDirection="column"
    alignItems="center"
    justifyContent="center"
    height="100%"
    width="100%"
  >
    <Typography variant="body_regular">
      {t('You need to save the threshold to evaluate the model')}
    </Typography>
  </Box>
);

/** Display rank of the current model's split evaluation sets (lower comes first). */
const SPLIT_DISPLAY_ORDER: Record<string, number> = { train: 0, dev: 1, test: 2 };

/**
 * Rank used to sort split evaluation sets. Splits whose name is missing or
 * not listed in `SPLIT_DISPLAY_ORDER` get the largest rank so they sort last
 * instead of producing a `NaN` comparator result.
 */
const getSplitDisplayRank = (evaluationSet: EvaluationSetItem): number =>
  SPLIT_DISPLAY_ORDER[evaluationSet.split?.splitSetName ?? ''] ?? Number.MAX_SAFE_INTEGER;

/** Extracts a human-readable message from an arbitrary thrown value. */
const getThrownMessage = (error: unknown): string => {
  if (error instanceof Error) {
    return error.message;
  }
  if (typeof error === 'object' && error !== null && 'message' in error) {
    return String((error as { message: unknown }).message);
  }
  return String(error);
};

/**
 * Performance report for a model on a selectable evaluation set.
 *
 * Renders, in order: an evaluation-set dropdown, the current confidence
 * threshold (with an optional "Adjust Threshold" control), the overall
 * performance summary, a CSV download button, and - when metrics exist for the
 * current model / evaluation set / threshold - the confusion matrix and image
 * list. While evaluation reports are loading a progress indicator is shown;
 * when no evaluation set has a completed report a hint is shown instead.
 */
const PerformanceReportPanel = (props: PerformanceReportPanelProps) => {
  const {
    model,
    evaluationSet,
    threshold: initialThreshold,
    trainDevTestColumnEvaluationSets,
    otherEvaluationSets,
    onChangeEvaluationSet,
  } = props;
  const styles = useStyles();
  const { enqueueSnackbar } = useSnackbar();

  const { data: evaluationReports, isLoading: isEvaluationReportsLoading } =
    useGetModelEvaluationReportsQuery(model?.id);
  // Only completed reports evaluated at the initial threshold are considered.
  const completedReports = (evaluationReports ?? []).filter(
    r => r.status === ModelEvaluationReportStatus.COMPLETED && r.threshold === initialThreshold,
  );
  const isCurrentModel = (candidate: EvaluationSetItem) =>
    candidate.datasetVersionId === model?.datasetVersionId;
  const evaluationOptions = [
    // train dev test for this model, should be displayed with simplified names
    ...(trainDevTestColumnEvaluationSets ?? [])
      .filter(isCurrentModel)
      .sort((s1, s2) => getSplitDisplayRank(s1) - getSplitDisplayRank(s2))
      .map(evaluationSet => ({
        evaluationSet,
        name: getTrainDevTestEvaluationSetName(evaluationSet),
      })),
    // train dev test for other models, display full names
    ...(trainDevTestColumnEvaluationSets ?? [])
      .filter(e => !isCurrentModel(e) && !e.hidden)
      .map(evaluationSet => ({
        evaluationSet,
        name: getEvaluationSetName(evaluationSet),
      })),
    // other evaluation sets, display full names
    ...(otherEvaluationSets ?? [])
      .filter(e => !e.hidden)
      .map(evaluationSet => ({
        evaluationSet,
        name: getEvaluationSetName(evaluationSet),
      })),
  ].filter(o => {
    // show only those eval sets that have completed evaluation
    return completedReports.some(r => r.evaluationSetId === o.evaluationSet.id);
  });

  // Prefer the evaluation set requested by the parent; otherwise fall back to
  // the first available option so the panel always has something to show.
  const selectedOption =
    evaluationOptions.find(option => option.evaluationSet.id === evaluationSet?.id) ??
    (evaluationOptions.length ? evaluationOptions[0] : null);
  const selectedEvaluationSet = selectedOption?.evaluationSet ?? null;

  const [threshold, setThreshold] = useState(initialThreshold);
  const { data: batchModelMetrics, isLoading: isModelMetricsLoading } =
    useGetBatchModelMetricsQuery();
  const modelMetrics = batchModelMetrics?.find(
    pred =>
      pred.evaluationSetId === selectedEvaluationSet?.id &&
      pred.threshold === threshold &&
      pred.modelId === model?.id,
  );

  const createBundle = useModelAnalysisCreateBundleMutation(model?.modelName || undefined);
  const overallPerformance = modelMetrics?.metrics?.all?.performance;
  const [adjustThresholdDropdownAnchorEl, setAdjustThresholdDropdownAnchorEl] =
    useState<null | HTMLElement>(null);
  // A single project query supplies both the label type and the project id.
  const { data: project } = useGetSelectedProjectQuery();
  const { labelType, id: projectId } = project ?? {};
  const hideAdjustThresholdButton = labelType === LabelType.Classification;

  /**
   * Persists a new threshold by creating a bundle for the selected model and
   * evaluation set, then updates the local threshold and closes the dropdown.
   * Bails out with a warning (leaving state untouched) while the model has
   * not finished training.
   */
  const onThresholdSave = async (newThreshold: number) => {
    if (model && selectedEvaluationSet && newThreshold !== undefined) {
      if (!isModelTrainingSuccessful(model.status, model.metricsReady)) {
        enqueueSnackbar(t('The model has not yet finish training, Please try again later.'), {
          variant: 'warning',
          autoHideDuration: 12000,
        });
        return;
      }
      await createBundle.mutateAsync({
        modelId: model.id,
        evaluationSetId: selectedEvaluationSet.id,
        threshold: newThreshold,
      });
    }
    setThreshold(newThreshold);
    setAdjustThresholdDropdownAnchorEl(null);
  };

  /** Closes the adjust-threshold dropdown without changing the threshold. */
  const onThresholdCancel = () => {
    setAdjustThresholdDropdownAnchorEl(null);
  };

  // NOTE(review): `completedReports` is filtered by `initialThreshold`, so this
  // lookup can only match while `threshold === initialThreshold`; after the
  // threshold is adjusted the CSV button disappears. Confirm this is intended.
  const modelEvaluationReport = completedReports.find(report => {
    return (
      report.evaluationSetId === selectedEvaluationSet?.id &&
      report.threshold === threshold &&
      report.modelId === model?.id
    );
  });

  const [downloadCsvLoading, setDownloadCsvLoading] = useState(false);

  /** Requests the CSV for the current evaluation report and opens the result in a new window. */
  const handleDownloadCsvClicked = useCallback(async () => {
    if (!projectId) return;
    if (!modelEvaluationReport) return;
    try {
      setDownloadCsvLoading(true);
      const res = await model_analysis_api.getModelEvaluationReportCsv(
        projectId,
        modelEvaluationReport.id,
      );
      window.open(res);
    } catch (e: unknown) {
      enqueueSnackbar(
        t(`Failed to download CSV. {{errorMessage}}`, {
          // The thrown value is not guaranteed to be an Error; narrow it first.
          errorMessage: getThrownMessage(e),
        }),
        {
          variant: 'error',
          autoHideDuration: 3000,
        },
      );
    } finally {
      setDownloadCsvLoading(false);
    }
  }, [enqueueSnackbar, projectId, modelEvaluationReport?.id]);

  return isEvaluationReportsLoading ? (
    <LoadingProgress />
  ) : evaluationOptions.length === 0 ? (
    <Typography>
      {t(
        'There is no evaluation set for this model, please run evaluation first on models table page.',
      )}
    </Typography>
  ) : (
    <>
      <Box display="flex" alignItems="center" id="evaluation-set-and-performance-row">
        <Box id="evaluation-set-and-performance-row-left">
          {/* evaluation set */}
          <Box display="flex" alignItems="center" id="evaluation-set-section" marginBottom={4}>
            <Box marginRight={3}>
              <Typography>{t('Evaluation set:')}</Typography>
            </Box>
            <Dropdown
              dropdown={toggleDropdown => (
                <MenuList>
                  {evaluationOptions.map(option => (
                    <MenuItem
                      key={option.evaluationSet.id}
                      // Highlight the option actually displayed, including the
                      // fallback used when the `evaluationSet` prop is absent.
                      selected={option.evaluationSet.id === selectedEvaluationSet?.id}
                      onClick={() => {
                        onChangeEvaluationSet?.(option.evaluationSet);
                        toggleDropdown(false);
                      }}
                    >
                      {t(option.name)}
                    </MenuItem>
                  ))}
                </MenuList>
              )}
            >
              <Box display="flex" alignItems="center">
                <Typography>{selectedOption?.name}</Typography>
                <ExpandMore fontSize="small" />
              </Box>
            </Dropdown>
          </Box>
          {/* threshold */}
          <Box
            display="flex"
            alignItems="center"
            id="confidence-threshold-section"
            marginBottom={5}
            className={styles.thresholdText}
          >
            {t('Confidence threshold:{{thresholdValue}}', {
              thresholdValue: (
                <Box marginLeft={3} component={'span'}>
                  <Typography display="inline">{threshold}</Typography>
                </Box>
              ),
            })}
          </Box>
          {!hideAdjustThresholdButton && (
            <Box marginBottom={7}>
              <Button
                id="adjust-threshold"
                variant="outlined"
                onClick={event => {
                  setAdjustThresholdDropdownAnchorEl(event.currentTarget);
                }}
              >
                {t('Adjust Threshold')}
              </Button>
              <AdjustThresholdDropdown
                model={model}
                initialThreshold={threshold}
                anchorEl={adjustThresholdDropdownAnchorEl}
                handleClose={onThresholdCancel}
                onSave={onThresholdSave}
                onCancel={onThresholdCancel}
              />
            </Box>
          )}
        </Box>
        <Box id="evaluation-set-and-performance-row-right" paddingLeft={30}>
          {/* performance */}
          <ModelPerformance
            performance={overallPerformance}
            modelId={model?.id}
            threshold={threshold}
            evaluationSet={selectedEvaluationSet}
          />
        </Box>
        {modelEvaluationReport && labelType !== LabelType.Segmentation && (
          <Box display="flex" flex={1} flexDirection={'row'} justifyContent={'flex-end'}>
            <Button
              variant="outlined"
              id={'download-model-evaluation-report-csv'}
              startIcon={<DownloadIcon />}
              onClick={handleDownloadCsvClicked}
              disabled={!projectId || !modelEvaluationReport || downloadCsvLoading}
            >
              {t('Download CSV')}
            </Button>
          </Box>
        )}
      </Box>
      <Box marginBottom={7}>
        <Divider />
      </Box>
      {isModelMetricsLoading ? (
        <LoadingProgress />
      ) : isEmpty(modelMetrics) ? (
        <AfterAdjustThresholdTip />
      ) : (
        selectedEvaluationSet && (
          <Provider>
            <Box display="flex">
              <ModelConfusionMatrix
                model={model}
                evaluationSet={selectedEvaluationSet}
                threshold={threshold}
              />
              <ModelImageList
                model={model}
                evaluationSet={selectedEvaluationSet}
                threshold={threshold}
              />
            </Box>
          </Provider>
        )
      )}
    </>
  );
};

export default PerformanceReportPanel;
