import React, { useCallback, useState } from 'react';
import { useAtom } from 'jotai';
import { Box, makeStyles, Theme } from '@material-ui/core';
import { Media } from '@clef/shared/types';
import MediaContainer from '@/pages/DataBrowser/MediaContainer';
import { calcOptimalRatio } from '@/pages/DataBrowser/MediaGrid/MediaGrid';
import { Typography, VirtualList } from '@clef/client-library';
import { modelListFilterOptionsAtom } from '../atoms';
import { useComparisonSegmentationInfoWithFilters } from '@/pages/DataBrowser/utils';
import { EvaluationSetItem } from '@/api/evaluation_set_api';
import { useGetConfusionMatrixCountsQuery } from '@/serverStore/modelAnalysis';
import { SortOrder } from '@/api/model_analysis_api';
import StripeSvg from '@/images/model-iteration/stripe.svg';
import ModelImageDetailDialog from '../ModelImageDetail/ModelImageDetailDialog';
import LoadingProgress from '../LoadingProgress';

/**
 * Style hook for a single comparison row (`MediaDiffView`).
 *
 * Takes `{ differences }` as style props: a negative value selects the red
 * palette entry, any other value (zero included) selects the green one. That
 * colour is used for the candidate viewer border and the difference label.
 */
const useStyles = makeStyles<Theme, { differences: number }>(theme => {
  // Single source of truth for the red/green choice used by two rules below.
  const trendColor = (styleProps: { differences: number }) =>
    styleProps.differences < 0 ? theme.palette.red[500] : theme.palette.green[500];

  // Builds one keyframe step of the `flash` animation.
  const brightness = (level: number) => ({ filter: `brightness(${level})` });

  return {
    // Coloured frame applied to the candidate media viewer.
    mediaViewerRoot: {
      border: props => `3px solid ${trendColor(props)} !important`,
      borderRadius: '6px',
    },
    // Small coloured tab that carries the increase/decrease text.
    differenceText: {
      color: 'white',
      padding: '2px 4px',
      borderRadius: '3px 3px 0 0',
      fontWeight: 700,
      lineHeight: '16px',
      display: 'inline-block',
      marginRight: theme.spacing(4),
      backgroundColor: props => trendColor(props),
    },
    mediaNameAndMaskSwitchTabs: {
      display: 'flex',
      alignItems: 'center',
      justifyContent: 'space-between',
      padding: theme.spacing(4, 0),
    },
    maskSwitchTabContainer: {
      display: 'flex',
      alignItems: 'center',
      gap: theme.spacing(2),
    },
    errorRegionTab: {
      display: 'flex',
      alignItems: 'center',
      gap: theme.spacing(1),
    },
    maskSwitchTab: {
      cursor: 'pointer',
      padding: theme.spacing(1),
    },
    tabText: {
      lineHeight: '20px',
    },
    activeTab: {
      backgroundColor: theme.palette.blue[50],
      borderRadius: '5px',
    },
    // Pixel-count badge pinned to the bottom-right corner of a media column.
    errorPixelsCount: {
      background: 'rgba(0, 0, 0, 0.50)',
      padding: '0 2px',
      position: 'absolute',
      right: 24,
      bottom: 24,
      color: 'white',
    },
    diffViewContainer: {
      position: 'relative',
      // NOTE(review): the transition targets `outline`, but the hover rule
      // changes `border` and `padding` — confirm which property was intended.
      transition: 'outline 0.3s',
      borderRadius: 8,
      // Transparent border reserves the 2px that the hover border occupies.
      border: '2px solid transparent',
      '&:hover': {
        border: `2px solid ${theme.palette.greyModern[400]}`,
        padding: theme.spacing(1),
      },
      // Invisible full-size layer with a pointer cursor stacked above the
      // content; presumably so the whole row reads as one clickable target.
      '&::after': {
        display: 'block',
        content: '""',
        position: 'absolute',
        left: 0,
        top: 0,
        width: '100%',
        height: '100%',
        zIndex: 106,
        cursor: 'pointer',
      },
    },
    // Plays the `flash` keyframes twice.
    filteredMaskAnimation: {
      animation: `$flash 1.5s ${theme.transitions.easing.easeInOut} 2`,
    },
    // Brightens, returns to normal, darkens, then returns to normal again.
    '@keyframes flash': {
      '0%': brightness(1),
      '25%': brightness(1.5),
      '50%': brightness(1),
      '75%': brightness(0.5),
      '100%': brightness(1),
    },
  };
});

/** Props accepted by `MediaDiffView`, which renders one comparison row. */
interface DiffViewProps {
  /**
   * Media item shown in the row. The optional `count` / `candidateCount`
   * values are displayed as pixel-count badges on the baseline and candidate
   * columns when a filter is active and the value is a number.
   */
  media: Media & { count?: number; candidateCount?: number };
  /** Evaluation set; its `datasetVersion.version` is passed to the viewers. */
  evaluationSet?: EvaluationSetItem;
  /**
   * Not destructured by `MediaDiffView`; the component reads the version from
   * `evaluationSet` instead.
   */
  version?: number;

  /** Model id forwarded to the ground-truth and baseline viewers. */
  baselineModelId?: string;
  /** Threshold forwarded together with `baselineModelId`. */
  baselineThreshold?: number;
  /** Model id forwarded to the candidate viewer. */
  candidateModelId?: string;
  /** Threshold forwarded together with `candidateModelId`. */
  candidateThreshold?: number;

  /** Combined with `imageRatio` and the column count to compute the row height. */
  containerWidth: number;
  /** Height/width factor applied to `containerWidth` when sizing the row. */
  imageRatio: number;

  /**
   * Signed difference shown in the label: positive renders the "increase"
   * text, negative the "decrease" text, and zero hides the label.
   */
  differences: number;
  /** Invoked when the row container is clicked. */
  onImageClick: () => void;
}

/**
 * Renders one comparison row for a single media item.
 *
 * Without an active filter the row has three equal columns (ground truth,
 * baseline predictions, candidate predictions). With a filter active the
 * ground-truth column is dropped, the two prediction columns receive the
 * filtered mask URLs, and optional pixel-count badges are overlaid.
 *
 * A coloured label above the row reports `differences` as a percentage with
 * one decimal place when it is a finite, non-zero number.
 */
const MediaDiffView: React.FC<DiffViewProps> = props => {
  const {
    media,
    baselineModelId,
    baselineThreshold,
    containerWidth,
    imageRatio,
    candidateModelId,
    candidateThreshold,
    evaluationSet,
    differences,
    onImageClick,
  } = props;
  // The dataset version is taken from the evaluation set, not from props.
  const version = evaluationSet?.datasetVersion.version;
  const [filterOptions] = useAtom(modelListFilterOptionsAtom);
  const { baseline, candidate } =
    useComparisonSegmentationInfoWithFilters(
      media.id,
      baselineModelId,
      baselineThreshold,
      candidateModelId,
      candidateThreshold,
      version,
      filterOptions,
    ) ?? {};
  const [baselineUrl, candidateUrl] = [baseline?.dataUrl, candidate?.dataUrl];
  const columns = filterOptions ? 2 : 3;
  const columnWidth = 100 / columns;

  const styles = useStyles({ differences });
  const [onMediaHovered, setOnMediaHovered] = useState<boolean>(false);

  // `NaN !== 0` is true, so a plain non-zero check would render a "NaN%" label
  // and the coloured frame for an undefined difference. Require a real number.
  const hasDifference = Number.isFinite(differences) && differences !== 0;

  return (
    <Box
      textAlign="right"
      // Hover state only drives the filtered-mask flash, so it is tracked only
      // while a filter is active.
      onMouseOver={() => filterOptions && setOnMediaHovered(true)}
      onMouseOut={() => filterOptions && setOnMediaHovered(false)}
    >
      {hasDifference && (
        <Typography variant="body2" className={styles.differenceText}>
          {differences > 0
            ? t('Performance IOU {{diff}}% increase', {
                diff: differences.toFixed(1),
              })
            : t('Performance IOU {{diff}}% decrease', {
                // `toFixed` returns a string; negating that string coerced it
                // back to a number and dropped the fixed decimal ("2" instead
                // of "2.0"). Take the magnitude first, then format.
                diff: Math.abs(differences).toFixed(1),
              })}
        </Typography>
      )}
      <Box
        display="flex"
        key={media.id}
        className={styles.diffViewContainer}
        height={(containerWidth * imageRatio) / columns}
        onClick={() => onImageClick()}
      >
        {!filterOptions && (
          <Box width={columnWidth + '%'}>
            <MediaContainer
              media={media}
              showGroundTruth
              modelId={baselineModelId}
              versionId={version}
              threshold={baselineThreshold}
            />
          </Box>
        )}
        <Box width={columnWidth + '%'} position="relative">
          <MediaContainer
            media={media}
            showPredictions
            modelId={baselineModelId}
            versionId={version}
            threshold={baselineThreshold}
            maskHitFilterUrl={filterOptions ? baselineUrl : undefined}
            filterOptions={filterOptions}
            classes={
              onMediaHovered ? { filteredMaskImage: styles.filteredMaskAnimation } : undefined
            }
          />
          {filterOptions && typeof media.count === 'number' && (
            <Typography className={styles.errorPixelsCount}>
              {filterOptions.gtClassId === filterOptions.predClassId
                ? t('{{count}} correct pixels', { count: media.count })
                : t('{{count}} error pixels', { count: media.count })}
            </Typography>
          )}
        </Box>
        <Box width={columnWidth + '%'} position="relative">
          <MediaContainer
            media={media}
            showPredictions
            modelId={candidateModelId}
            versionId={version}
            threshold={candidateThreshold}
            maskHitFilterUrl={filterOptions ? candidateUrl : undefined}
            filterOptions={filterOptions}
            classes={
              hasDifference
                ? {
                    mediaViewerRoot: styles.mediaViewerRoot,
                  }
                : {
                    filteredMaskImage: onMediaHovered ? styles.filteredMaskAnimation : undefined,
                  }
            }
          />
          {filterOptions && typeof media.candidateCount === 'number' && (
            <Typography className={styles.errorPixelsCount}>
              {filterOptions.gtClassId === filterOptions.predClassId
                ? t('{{count}} correct pixels', { count: media.candidateCount })
                : t('{{count}} error pixels', { count: media.candidateCount })}
            </Typography>
          )}
        </Box>
      </Box>
    </Box>
  );
};

/** Style hook for the comparison list wrapper and its filter banner. */
const useListStyles = makeStyles(theme => {
  return {
    mediaListRoot: {
      width: '100%',
    },
    // Centered, rounded, light-blue banner shown while a filter is active.
    filterOptionsDisplay: {
      // Layout of the banner content.
      display: 'flex',
      alignItems: 'center',
      justifyContent: 'center',
      gap: theme.spacing(1),
      // Spacing around and inside the banner.
      margin: theme.spacing(5, 0),
      padding: theme.spacing(4),
      // Appearance.
      borderRadius: theme.spacing(5),
      background: theme.palette.blue[50],
    },
  };
});

/** Props for the segmentation model-comparison image list. */
export type ModelImageListSegmentationProps = {
  /** Media items to list; the component renders nothing when this is undefined. */
  mediaList?: Media[];
  /** Evaluation set whose id is used when querying confusion-matrix counts. */
  evaluationSet?: EvaluationSetItem;

  /** Baseline model id. */
  modelId?: string;
  /** Threshold used together with `modelId`. */
  threshold?: number;
  /** Candidate model id. */
  candidateModelId?: string;
  /** Threshold used together with `candidateModelId`. */
  candidateThreshold?: number;

  /** Ordering by per-media difference; `SortOrder.DESC` reverses the ascending order. */
  sortOrder: SortOrder;
  /** Available width, used with the image ratio to derive the row height. */
  containerWidth: number;
};

/**
 * Lists media items for a baseline-vs-candidate segmentation comparison.
 *
 * Fetches confusion-matrix counts for both models, derives a per-media
 * difference (candidate ratio minus baseline ratio, in percentage points,
 * where ratio = correctCount / (correctCount + incorrectCount)) and renders
 * the items ordered by that difference in a virtualized list. Clicking a row
 * opens the image detail dialog for that media.
 *
 * While a filter is active every difference is 0, so the incoming order of
 * `mediaList` is kept. The `mediaList` prop itself is never mutated.
 */
const ModelComparisonImageListSegmentation: React.FC<ModelImageListSegmentationProps> = props => {
  const styles = useListStyles();
  const {
    mediaList,
    modelId,
    threshold,
    containerWidth,
    candidateModelId,
    candidateThreshold,
    evaluationSet,
    sortOrder,
  } = props;
  const [filterOptions] = useAtom(modelListFilterOptionsAtom);
  const [selectedImageId, setSelectedImageId] = useState<number>();

  const { data: baselineConfusionMatrixCounts, isLoading: baselineLoading } =
    useGetConfusionMatrixCountsQuery(modelId, evaluationSet?.id, threshold);
  const { data: candidateConfusionMatrixCounts, isLoading: candidateLoading } =
    useGetConfusionMatrixCountsQuery(candidateModelId, evaluationSet?.id, candidateThreshold);

  /**
   * Returns the candidate-minus-baseline difference for one media item in
   * percentage points, or 0 when it cannot be computed (filter active, counts
   * missing for either model, or no counted pixels).
   */
  const getDifferences = useCallback(
    (media: Media) => {
      if (filterOptions) return 0;
      const baselineCount = baselineConfusionMatrixCounts?.find(obj => obj.mediaId === media.id);
      const candidateCount = candidateConfusionMatrixCounts?.find(obj => obj.mediaId === media.id);
      if (!baselineCount || !candidateCount) return 0;
      const baselineTotal = baselineCount.correctCount + baselineCount.incorrectCount;
      const candidateTotal = candidateCount.correctCount + candidateCount.incorrectCount;
      // 0 / 0 is NaN, which would make the sort comparator inconsistent and
      // surface as a "NaN%" label; treat an empty total as "no difference".
      if (baselineTotal === 0 || candidateTotal === 0) return 0;
      const baselineIOU = baselineCount.correctCount / baselineTotal;
      const candidateIOU = candidateCount.correctCount / candidateTotal;
      return (candidateIOU - baselineIOU) * 100;
    },
    [baselineConfusionMatrixCounts, candidateConfusionMatrixCounts, filterOptions],
  );

  // Declared before the early returns so the hook order stays stable.
  const sortedMediaList = React.useMemo<Media[]>(() => {
    if (!mediaList) return [];
    const direction = sortOrder === SortOrder.DESC ? -1 : 1;
    // `map` yields a fresh array, so the in-place `sort` never touches the
    // caller's prop. Each difference is computed once up front instead of on
    // every comparison.
    return mediaList
      .map(media => ({ media, difference: getDifferences(media) }))
      .sort((a, b) => (a.difference - b.difference) * direction)
      .map(entry => entry.media);
  }, [mediaList, getDifferences, sortOrder]);

  if (!mediaList) {
    return null;
  }
  if (baselineLoading || candidateLoading) {
    return <LoadingProgress size={24} />;
  }

  const imageRatio = calcOptimalRatio(mediaList);
  const columns = filterOptions ? 2 : 3;
  // Row height: the image area plus a fixed allowance (60 with a filter
  // active, 16 otherwise).
  const height = (containerWidth * imageRatio) / columns + (filterOptions ? 60 : 16);

  return (
    <>
      {filterOptions && (
        <Box className={styles.filterOptionsDisplay}>
          <Typography>
            {filterOptions.gtClassId === filterOptions.predClassId
              ? t('Correct regions is highlighted with')
              : t('Error regions is highlighted with')}
          </Typography>
          <img src={StripeSvg} />
        </Box>
      )}
      <Box display="flex" alignItems="center" marginBottom={3} id="model-list-titles">
        {!filterOptions && (
          <Box flex={1}>
            <Typography variant="body_bold">{t('Ground truth')}</Typography>
          </Box>
        )}
        <Box flex={1}>
          <Typography variant="body_bold">{t('Baseline model')}</Typography>
        </Box>
        <Box flex={1}>
          <Typography variant="body_bold">{t('Candidate model')}</Typography>
        </Box>
      </Box>
      <Box maxHeight={1000}>
        {baselineConfusionMatrixCounts && candidateConfusionMatrixCounts && (
          <VirtualList dataList={sortedMediaList} itemHeight={height} containerMaxHeight={1500}>
            {media => {
              const differences = getDifferences(media);
              return (
                <MediaDiffView
                  differences={differences}
                  key={media.id}
                  media={media}
                  baselineModelId={modelId}
                  baselineThreshold={threshold}
                  candidateModelId={candidateModelId}
                  candidateThreshold={candidateThreshold}
                  imageRatio={imageRatio}
                  containerWidth={containerWidth}
                  evaluationSet={evaluationSet}
                  onImageClick={() => setSelectedImageId(media.id)}
                />
              );
            }}
          </VirtualList>
        )}
      </Box>
      <ModelImageDetailDialog
        modelId={modelId}
        threshold={threshold}
        candidateModelId={candidateModelId}
        candidateThreshold={candidateThreshold}
        evaluationSet={evaluationSet}
        mediaId={selectedImageId}
        mediaList={sortedMediaList}
        onClose={() => setSelectedImageId(undefined)}
      />
    </>
  );
};

export default ModelComparisonImageListSegmentation;
