import React, { useCallback, useMemo, useState } from 'react';
import { useAtom } from 'jotai';
import { Box } from '@material-ui/core';
import { AnnotationInstance, Media } from '@clef/shared/types';
import { modelListFilterOptionsAtom } from '../atoms';
import {
  useGetConfusionMatrixCountsQuery,
  useGetModelMediaListInfiniteQuery,
} from '@/serverStore/modelAnalysis';
import LoadingProgress from '../LoadingProgress';
import { EvaluationSetItem } from '@/api/evaluation_set_api';
import ModelImageDetailDialog from '../ModelImageDetail/ModelImageDetailDialog';
import { ComparisonObjectDetectionImageDiffView } from '../ModelImageDetail/ObjectDetectionImageDiffView';
import ModelImageVirtualListWrapper from './ModelImageVirtualListWrapper';
import { SortOrder } from '@/api/model_analysis_api';

/**
 * Props for the object-detection model comparison list.
 *
 * The "baseline" side is described by `modelId`/`threshold`, the "candidate" side by
 * `candidateModelId`/`candidateThreshold`.
 */
export type ModelComparisonImageListObjectDetectionProps = {
  // --- baseline model ---
  /** Baseline model id (forwarded to each row as `baselineModelId`). */
  modelId?: string;
  /** Baseline threshold (forwarded to each row as `baselineThreshold`). */
  threshold?: number;

  // --- candidate model ---
  /** Candidate model id compared against the baseline. */
  candidateModelId: string;
  /** Threshold applied when querying the candidate model's results. */
  candidateThreshold?: number;

  // --- data & layout ---
  /** Evaluation set whose media are listed; its `id` scopes the count queries. */
  evaluationSet?: EvaluationSetItem;
  /** Available width, forwarded to the virtual list wrapper. */
  containerWidth: number;
  /** Direction used by the row comparator; also forwarded to the list wrapper. */
  sortOrder: SortOrder;
};

/**
 * Builds a `mediaId -> incorrectCount` lookup from a confusion-matrix counts result.
 *
 * Only the first entry seen for a media id is kept (mirroring the `Array.prototype.find`
 * lookup this replaces), and a missing/null `incorrectCount` is stored as 0.
 */
const indexIncorrectCountsByMediaId = (
  counts?: ReadonlyArray<{ mediaId: number; incorrectCount?: number | null }> | null,
): Map<number, number> => {
  const byMediaId = new Map<number, number>();
  for (const { mediaId, incorrectCount } of counts ?? []) {
    if (!byMediaId.has(mediaId)) {
      byMediaId.set(mediaId, incorrectCount ?? 0);
    }
  }
  return byMediaId;
};

/**
 * Lists media of an evaluation set with baseline vs. candidate object-detection predictions
 * side by side.
 *
 * Each row is a `ComparisonObjectDetectionImageDiffView`; clicking a row opens
 * `ModelImageDetailDialog` for that media, navigable over the loaded, sorted media list.
 * Rows are ordered by the change in incorrect counts (candidate minus baseline), with the
 * direction controlled by `sortOrder`.
 *
 * When `filterOptions` is truthy, the ground-truth column is omitted (2 columns instead of 3)
 * and every row's difference is reported as 0.
 */
const ModelComparisonImageListObjectDetection: React.FC<
  ModelComparisonImageListObjectDetectionProps
> = props => {
  const {
    modelId,
    threshold,
    candidateModelId,
    candidateThreshold,
    evaluationSet,
    containerWidth,
    sortOrder,
  } = props;

  const [filterOptions] = useAtom(modelListFilterOptionsAtom);
  const [selectedImageId, setSelectedImageId] = useState<number>();

  const { data: mediaListPages } = useGetModelMediaListInfiniteQuery(
    modelId,
    threshold,
    evaluationSet,
    filterOptions,
    candidateModelId,
    candidateThreshold,
    undefined,
  );
  const { data: baselineCounts, isLoading: baselineCountsLoading } =
    useGetConfusionMatrixCountsQuery(modelId, evaluationSet?.id, threshold);
  const { data: candidateCounts, isLoading: candidateCountsLoading } =
    useGetConfusionMatrixCountsQuery(candidateModelId, evaluationSet?.id, candidateThreshold);

  // Index the counts once per query result so the sort comparator below does O(1) lookups
  // instead of two linear `find` scans per comparison.
  const baselineIncorrectByMedia = useMemo(
    () => indexIncorrectCountsByMediaId(baselineCounts),
    [baselineCounts],
  );
  const candidateIncorrectByMedia = useMemo(
    () => indexIncorrectCountsByMediaId(candidateCounts),
    [candidateCounts],
  );

  /** Candidate incorrect count minus baseline incorrect count for a media (0 if unknown). */
  const getDifferences = useCallback(
    (mediaId: number) => {
      if (filterOptions) {
        return 0;
      }
      const baselineCount = baselineIncorrectByMedia.get(mediaId) ?? 0;
      const candidateCount = candidateIncorrectByMedia.get(mediaId) ?? 0;
      return candidateCount - baselineCount;
    },
    [baselineIncorrectByMedia, candidateIncorrectByMedia, filterOptions],
  );

  const sortFunctions = useCallback(
    (a: Media, b: Media) =>
      (getDifferences(a.id) - getDifferences(b.id)) * (sortOrder === SortOrder.DESC ? 1 : -1),
    [getDifferences, sortOrder],
  );

  // Flatten all loaded pages, drop empty entries and sort. Memoized so the sort does not rerun
  // (and the dialog does not receive a new array) on unrelated re-renders. `sort` mutates, but
  // only the fresh array produced by flatMap/filter.
  const allMedias = useMemo(
    () =>
      (
        (mediaListPages?.pages.flatMap(page => page?.mediaList).filter(media => !!media) ??
          []) as Media[]
      ).sort(sortFunctions),
    [mediaListPages, sortFunctions],
  );

  // NOTE(review): the loader is only shown when `filterOptions` is truthy, yet the counts are
  // only consumed by `getDifferences` when it is falsy; with no filter, loading falls through to
  // the `return null` below. Confirm this gating is intended.
  if (filterOptions && (baselineCountsLoading || candidateCountsLoading)) {
    return (
      <Box>
        <LoadingProgress size={24} />
      </Box>
    );
  }

  if (!baselineCounts || !candidateCounts) {
    return null;
  }

  // NOTE(review): `t` is not imported in this module; presumably a global i18n helper — confirm.
  return (
    <>
      <ModelImageVirtualListWrapper
        titles={[
          ...(filterOptions ? [] : [t('Ground truth')]),
          t('Baseline model'),
          t('Candidate model'),
        ]}
        evaluationSet={evaluationSet}
        modelId={modelId}
        sortOrder={sortOrder}
        threshold={threshold}
        candidate={candidateModelId}
        candidateThreshold={candidateThreshold}
        containerWidth={containerWidth}
        sortFunctions={sortFunctions}
        columns={filterOptions ? 2 : 3}
        rowRender={(
          media: Media,
          rowWidth: number,
          allInstances?: AnnotationInstance[],
          candidateAllInstances?: AnnotationInstance[],
        ) => {
          const differences = getDifferences(media.id);
          return (
            <ComparisonObjectDetectionImageDiffView
              key={media.id}
              media={media}
              baselineAllInstances={allInstances}
              candidateAllInstances={candidateAllInstances}
              baselineModelId={modelId}
              baselineThreshold={threshold}
              candidateModelId={candidateModelId}
              candidateThreshold={candidateThreshold}
              rowWidth={rowWidth}
              version={evaluationSet?.datasetVersion.version}
              differences={differences}
              onImageClick={() => setSelectedImageId(media.id)}
            />
          );
        }}
      />
      <ModelImageDetailDialog
        modelId={modelId}
        threshold={threshold}
        candidateModelId={candidateModelId}
        candidateThreshold={candidateThreshold}
        evaluationSet={evaluationSet}
        mediaId={selectedImageId}
        mediaList={allMedias}
        onClose={() => setSelectedImageId(undefined)}
      />
    </>
  );
};

export default ModelComparisonImageListObjectDetection;
