import ConfusionMatrixGraphs from '@/components/ProjectModel/ConfusionMatrixGraphs';
import { IconButton } from '@clef/client-library';
import { ConfusionMatrixPerThreshold, LabelType, PerformanceMetrics } from '@clef/shared/types';
import { Box, makeStyles, Slider, TextField, Tooltip, Typography } from '@material-ui/core';
import KeyboardArrowDown from '@material-ui/icons/KeyboardArrowDown';
import KeyboardArrowUp from '@material-ui/icons/KeyboardArrowUp';
import { range } from 'lodash';
import React, { useState } from 'react';
import { useGetSelectedProjectQuery } from '@/serverStore/projects';
import { useDataBrowserState } from '../dataBrowserState';
import { PerformanceType } from '../ProjectModelDetails/states';
import { CONFIDENCE_THRESHOLD_OPTIONS } from '@clef/shared/constants';

import InfoIcon from '../icon/InfoIcon';

/** Dark slate shared by the slider's filled track and the thumb outline. */
const SLIDER_ACCENT = '#4B5565';

/** Spread-only box-shadow around the slider thumb, in the accent hue at 16% opacity. */
const sliderThumbHalo = (spreadPx: number) => `0px 0px 0px ${spreadPx}px rgb(75, 85, 101, .16)`;

/** Geometry shared by the slider rail and track: 4px tall with rounded ends. */
const SLIDER_BAR_SHAPE = { borderRadius: 4, height: 4 };

/**
 * JSS classes for the confidence-threshold controls: the title row, the compact
 * numeric input (with the browser's number spinners hidden), and Slider overrides.
 */
const useStyles = makeStyles(theme => ({
  titleWrapper: {
    display: 'flex',
    alignItems: 'center',
    lineHeight: '20px',
    whiteSpace: 'nowrap',
    fontSize: 14,
    marginBottom: theme.spacing(1),
  },
  title: { color: theme.palette.greyModern[500], fontWeight: 700 },
  thresholdWrapper: { display: 'flex', columnGap: '12px' },
  thresholdInputRoot: { width: 44, height: 36 },
  thresholdInput: {
    textAlign: 'center',
    padding: 0,
    fontFamily: theme.typography.body1.fontFamily,
    width: '100%',
    height: '100%',
    boxSizing: 'border-box',
    fontStyle: 'normal',
    fontWeight: 400,
    fontSize: '14px',
    lineHeight: '20px',
    // Hide the native number-input spin buttons (WebKit/Blink selectors, then Firefox).
    '&::-webkit-outer-spin-button, &::-webkit-inner-spin-button': {
      '-webkit-appearance': 'none',
      margin: 0,
    },
    '&[type=number]': { '-moz-appearance': 'textfield' },
  },
  sliderRoot: { margin: 'auto 0', flex: 1 },
  sliderRail: { color: theme.palette.greyModern[400], ...SLIDER_BAR_SHAPE },
  sliderTrack: { color: SLIDER_ACCENT, ...SLIDER_BAR_SHAPE },
  sliderMark: {
    borderRadius: 2,
    height: 4,
    width: 4,
    // `!important` presumably needed to beat MUI's own mark / active-mark styling.
    color: `${theme.palette.common.white} !important`,
    backgroundColor: `${theme.palette.common.white} !important`,
    opacity: `1 !important`,
  },
  sliderMarkLabel: { top: 20 },
  sliderThumb: {
    width: 16,
    height: 16,
    background: '#FFFFFF',
    border: `1.5px solid ${SLIDER_ACCENT}`,
    borderRadius: 12,
    marginTop: -7,
    '&:hover': { boxShadow: sliderThumbHalo(8) },
    '&.MuiSlider-active': { boxShadow: sliderThumbHalo(14) },
  },
  sliderActive: { boxShadow: sliderThumbHalo(14) },
  // Invisible hit area under each slider mark that anchors the mark's tooltip.
  markLabelTooltipLayer: {
    width: 8,
    height: 8,
    borderRadius: 4,
    backgroundColor: 'transparent',
    transform: 'translate(25%, -100%)',
  },
}));

/** A tick on the threshold slider; `label` is shown in a tooltip over the tick. */
type ThresholdSliderMark = { value: number; label: string };

/** Props for the `ControlledPerformanceCharts` component. */
export type ControlledPerformanceChartsProps = {
  /** Model whose confusion-matrix charts are shown; nothing renders while this is undefined. */
  modelId: string | undefined;
  /** Currently selected confidence threshold; nothing renders while this is undefined. */
  threshold: number | undefined;
  /** Metrics that supply the confusion matrix handed to the charts. */
  modelMetrics: PerformanceMetrics;
  /** Called with the new threshold from the slider, the numeric input, or the charts. */
  onThresholdChange: (newThreshold: number) => void | Promise<void>;
  /** When true, charts are always visible and the toggle button is hidden (default `false`). */
  alwaysShowChart?: boolean;
  /** Optional ticks to draw on the slider. */
  marks?: Array<ThresholdSliderMark>;
};

/**
 * Clamps a threshold typed into the numeric input and snaps it onto the slider grid.
 *
 * - Non-finite values and values at or below `min` map to `min`.
 * - Values at or above `maxSelectable` map to `maxSelectable`.
 * - Anything in between maps to the nearest entry of `grid`; ties keep the lower entry.
 *
 * Unlike a fixed-tolerance lookup with a hard-coded fallback, the result is always
 * within `[min, maxSelectable]`.
 */
const snapThresholdToGrid = (
  raw: number,
  grid: readonly number[],
  min: number,
  maxSelectable: number,
): number => {
  if (!Number.isFinite(raw) || raw <= min) {
    return min;
  }
  if (raw >= maxSelectable) {
    return maxSelectable;
  }
  let nearest = min;
  let nearestDistance = raw - min;
  for (const candidate of grid) {
    const distance = Math.abs(raw - candidate);
    // Strict `<` so that, on a tie, the earlier (lower) grid entry wins.
    if (distance < nearestDistance) {
      nearest = candidate;
      nearestDistance = distance;
    }
  }
  return nearest;
};

/**
 * Confidence-threshold picker (slider + numeric input) with optional
 * confusion-matrix charts for the given model.
 *
 * The threshold is fully controlled by the parent. Every change from the slider,
 * the numeric input or the charts is reported via `onThresholdChange`, and the UI
 * shows whatever `threshold` is passed back. Nothing renders until both `modelId`
 * and `threshold` are defined.
 *
 * Unless `alwaysShowChart` is set, a toggle button controls chart visibility.
 * Until the user clicks it, charts default to open for segmentation projects and
 * closed otherwise.
 */
export const ControlledPerformanceCharts: React.FC<ControlledPerformanceChartsProps> = ({
  modelId,
  threshold,
  modelMetrics,
  onThresholdChange,
  alwaysShowChart = false,
  marks,
}) => {
  const styles = useStyles();
  const { labelType } = useGetSelectedProjectQuery().data ?? {};
  // Holds only an explicit user choice (`undefined` until the toggle is clicked).
  // The default is derived on every render rather than seeded once into useState,
  // so it stays correct when the project query resolves after the first render.
  const [showChartsOverride, setShowChartsOverride] = useState<boolean | undefined>(undefined);
  const showCharts = showChartsOverride ?? labelType === LabelType.Segmentation;
  const {
    state: { viewMode },
  } = useDataBrowserState();
  const { thresholdsMin, thresholdsMax, thresholdsStep } = CONFIDENCE_THRESHOLD_OPTIONS;
  // `range` excludes its end value, so the largest selectable threshold is one step below the max.
  const maxSelectableThreshold = thresholdsMax - thresholdsStep;
  // Selectable thresholds, rounded to 3 decimals to absorb floating-point drift from `range`.
  const thresholds = React.useMemo(
    () =>
      range(thresholdsMin, thresholdsMax, thresholdsStep).map(x => Math.round(x * 1000) / 1000),
    [thresholdsMin, thresholdsMax, thresholdsStep],
  );

  if (modelId === undefined || threshold === undefined) {
    return null;
  }

  return (
    <>
      <Box display="flex" flexDirection="column" marginBottom={0.5}>
        <div className={styles.titleWrapper}>
          <Typography display="inline" className={styles.title}>
            {t('Confidence Threshold')}
          </Typography>
          <Tooltip
            arrow
            placement="top"
            title={t(
              'Your model makes several predictions. You can adjust the Confidence Threshold to filter out the less confident predictions so you can focus on the most confident ones.',
            )}
          >
            <IconButton id="confidence-threshold-info" size="small">
              <InfoIcon />
            </IconButton>
          </Tooltip>
        </div>
        <Box className={styles.thresholdWrapper}>
          <Slider
            min={thresholdsMin}
            max={maxSelectableThreshold}
            step={thresholdsStep}
            value={threshold}
            marks={marks?.map(mark => ({
              ...mark,
              label: (
                <Tooltip arrow={true} title={mark.label} placement="bottom">
                  <Box className={styles.markLabelTooltipLayer} />
                </Tooltip>
              ),
            }))}
            onChange={(_, newValue) => {
              // MUI types the value as `number | number[]`. The array form is only
              // produced by range sliders, and this slider takes a single number.
              if (typeof newValue === 'number') {
                onThresholdChange(newValue);
              }
            }}
            classes={{
              root: styles.sliderRoot,
              rail: styles.sliderRail,
              track: styles.sliderTrack,
              thumb: styles.sliderThumb,
              active: styles.sliderActive,
              mark: styles.sliderMark,
              markActive: styles.sliderMark,
              markLabel: styles.sliderMarkLabel,
            }}
          />
          <TextField
            variant="outlined"
            type="number"
            value={threshold}
            InputProps={{
              classes: {
                root: styles.thresholdInputRoot,
                input: styles.thresholdInput,
              },
            }}
            inputProps={{
              className: styles.thresholdInput,
              step: thresholdsStep,
              min: thresholdsMin,
              max: maxSelectableThreshold,
              'data-testid': 'threshold-input',
            }}
            onChange={e => {
              // An emptied field reads as 0 and is then clamped up to the minimum.
              const typed = Number(e.target.value || 0);
              onThresholdChange(
                snapThresholdToGrid(typed, thresholds, thresholdsMin, maxSelectableThreshold),
              );
            }}
          />
          {!alwaysShowChart && (
            <Tooltip placement="top" arrow title={t('Toggle charts')}>
              <IconButton size="small" onClick={() => setShowChartsOverride(!showCharts)}>
                {showCharts ? <KeyboardArrowUp /> : <KeyboardArrowDown />}
              </IconButton>
            </Tooltip>
          )}
        </Box>
        <div style={{ flex: 1 }} />
      </Box>
      {(alwaysShowChart || showCharts) && (
        <Box marginLeft={-3} marginRight={-3}>
          <ConfusionMatrixGraphs
            split=""
            performanceType={PerformanceType.AnnotationLevel}
            confusionMatrix={
              viewMode === 'instance'
                ? modelMetrics.instanceParingConfusionMatrix
                : (modelMetrics.confusionMatrix as ConfusionMatrixPerThreshold)
            }
            threshold={threshold}
            labelType={labelType as LabelType}
            modelId={modelId}
            onUpdateThreshold={newThreshold => {
              onThresholdChange(newThreshold);
            }}
          />
        </Box>
      )}
    </>
  );
};

export default ControlledPerformanceCharts;
