import React, { useCallback, useRef } from 'react';
import { Box, makeStyles } from '@material-ui/core';
import { Button, useKeyPress, useLocalStorage } from '@clef/client-library';
import { KeyHandler } from 'hotkeys-js';
import { useSnackbar } from 'notistack';
import {
  useImageLabelingContext,
  useCurrentMediaStates,
} from '../../../components/Labeling/imageLabelingContext';
import { useVPLatestModel } from '@/serverStore/projectModels';
import { useInstantLearningState } from '../state';
import { useLabeledClassIds, useStartInstantLearningTrain } from '../utils';
import { PulseWrapper } from '@clef/client-library';
import classNames from 'classnames';
import { useDefectSelector } from '@/store/defectState/actions';
import experiment_report_api from '@/api/experiment_report_api';
import { useCountDown } from '@/hooks/useCountDown';

/**
 * Styles for the circular train ("Run") button, its absolutely-positioned container, and the
 * wrapper that surrounds the button.
 */
const useStyles = makeStyles(muiTheme => {
  // Radius that renders a square box as a circle; shared by the button and its wrapper.
  const circleRadius = '50%';
  // Width and height of the round button, in pixels.
  const buttonDiameter = 180;

  return {
    // Pins the container to the bottom-right corner of its positioned ancestor, then shifts it by
    // half of its own size so the button is centered on that corner.
    trainButtonContainer: {
      zIndex: 1,
      position: 'absolute',
      right: 0,
      bottom: 0,
      transform: 'translate(50%, 50%)',
    },
    // Applied on top of `trainButtonContainer`: moves the anchor to the horizontal middle.
    // Must stay declared after `trainButtonContainer` so that its `right` value wins.
    trainButtonContainerCentered: {
      right: circleRadius,
    },
    trainButton: {
      width: buttonDiameter,
      height: buttonDiameter,
      borderRadius: circleRadius,
      // Top padding pushes the label down from the top edge; content is aligned to the start.
      padding: muiTheme.spacing(10.5, 0, 0, 0),
      alignItems: 'flex-start!important',
      boxShadow: '0px 4px 8px rgba(0, 0, 0, 0.15)',
      fontWeight: 500,
      fontSize: 24,
      // Colors used while the contained button is disabled.
      '&.MuiButton-contained.Mui-disabled': {
        color: 'white',
        backgroundColor: muiTheme.palette.grey[500],
      },
    },
    pulseWrapper: {
      borderRadius: circleRadius,
    },
  };
});

/** Props accepted by `TrainButton`. Currently empty: the component reads no props. */
export type TrainButtonProps = {};

/**
 * Round "Run" button that starts an instant-learning training run for the current media.
 *
 * A run can be started either by clicking the button or by pressing Enter. Both entry points go
 * through `handleStartTrain`, which refuses to start whenever the button is disabled (cool-down
 * countdown active, a run already in progress, or not enough classes for a first run).
 *
 * Starting a run:
 *  1. marks the training state as `'training-in-progress'`;
 *  2. unless the `skip_health_check` local-storage flag is set, asks the backend for train health
 *     and backs off with a warning and a 5 second countdown when the worker is inactive or its
 *     occupation is above 0.95;
 *  3. clears the current media's prediction annotations and calls `startInstantLearningTrain`.
 *
 * On any failure the previous predictions are restored, the training state is reset to `null`,
 * and an error snackbar is shown.
 */
const TrainButton: React.FC<TrainButtonProps> = () => {
  const styles = useStyles();
  const { latestModel } = useVPLatestModel();
  const { id: latestModelId } = latestModel ?? {};
  const { state, dispatch } = useInstantLearningState();
  const isTraining = state.trainingState === 'training-in-progress';

  const { state: imageLabelingState, dispatch: dispatchImageLabelingState } =
    useImageLabelingContext();
  // NOTE(review): this ref is kept up to date but is never read inside this component.
  const imageLabelingStateRef = useRef(imageLabelingState);
  imageLabelingStateRef.current = imageLabelingState;

  const startInstantLearningTrain = useStartInstantLearningTrain();
  const { mediaDetails, predictionAnnotations } = useCurrentMediaStates();

  const { enqueueSnackbar } = useSnackbar();

  const { countDown, startCountDown } = useCountDown();
  const [skipHealthCheck] = useLocalStorage('skip_health_check');

  const validClasses = useDefectSelector(false).length;
  const labeledClassIdsSet = useLabeledClassIds();

  // Computed before the handler so that the click path and the Enter hotkey share one gate.
  // NOTE(review): the original also tested `isTraining === undefined`, which can never be true
  // because `isTraining` is the result of a `===` comparison. If the intent was to disable the
  // button while `state.trainingState` is still uninitialized, that check needs to be added
  // against `state.trainingState` itself.
  const disabled =
    // Disabled if we are counting down, asking user to wait for some time
    countDown > 0 ||
    // Disabled while a run is in progress
    isTraining ||
    // Otherwise, disabled if there is no model yet and there are not enough saved classes or not
    // enough labeled classes for the first time labeling and training
    (!latestModelId && (validClasses < 2 || labeledClassIdsSet.size < 2));

  const handleStartTrain = useCallback(async () => {
    // The Enter hotkey calls this handler directly, so the button's `disabled` attribute alone
    // does not protect it; enforce the same condition here.
    if (disabled) {
      return;
    }

    // Snapshot the predictions so they can be restored if the run fails to start.
    const prevPredictions = predictionAnnotations?.slice();
    try {
      dispatch(draft => {
        draft.trainingState = 'training-in-progress';
      });

      if (!skipHealthCheck) {
        // check for worker availability
        const { instantLearning } = await experiment_report_api.getTrainHealth();
        if (!instantLearning.active || instantLearning.occupation > 0.95) {
          enqueueSnackbar(
            t('We are experiencing exceptionally high demand. Please try Visual Prompting later.'),
            { variant: 'warning', autoHideDuration: 6000 },
          );

          dispatch(draft => {
            draft.trainingState = null;
          });

          // Keep the button disabled for a few seconds before allowing another attempt.
          startCountDown(5);
          return;
        }
      }

      // clear predictions
      const mediaId = mediaDetails?.id ?? -1;
      dispatchImageLabelingState(draft => {
        if (mediaId in draft.mediaStatesById) {
          draft.mediaStatesById[mediaId].predictionAnnotations = undefined;
        }
      });

      // start training
      await startInstantLearningTrain(mediaId);
    } catch (e) {
      // The thrown value is not guaranteed to be an object (it may be null/undefined or a
      // primitive); read `message` defensively so this handler cannot itself throw and leave
      // the training state stuck at 'training-in-progress'.
      const errorMessage = (e as { message?: string } | null | undefined)?.message;
      enqueueSnackbar(errorMessage ?? t('Run failed for unknown reason'), {
        variant: 'error',
        autoHideDuration: 12000,
      });
      // Restore the predictions that were cleared before the run was started.
      dispatchImageLabelingState(draft => {
        const mediaId = mediaDetails?.id ?? -1;
        if (mediaId in draft.mediaStatesById) {
          draft.mediaStatesById[mediaId].predictionAnnotations = prevPredictions;
        }
      });
      dispatch(draft => {
        draft.trainingState = null;
      });
    }
  }, [
    disabled,
    dispatchImageLabelingState,
    enqueueSnackbar,
    mediaDetails?.id,
    dispatch,
    predictionAnnotations,
    skipHealthCheck,
    startCountDown,
    startInstantLearningTrain,
  ]);

  // The handler is async, so its return type does not match `KeyHandler`; hence the cast.
  useKeyPress('enter', handleStartTrain as unknown as KeyHandler);

  return (
    <Box
      className={classNames(
        styles.trainButtonContainer,
        state.labelAndTrainMode === 'single' && styles.trainButtonContainerCentered,
        disabled && 'disabled',
      )}
    >
      <PulseWrapper enabled={!disabled} className={styles.pulseWrapper}>
        <Button
          id={'instant-learning-train'}
          disabled={disabled}
          variant="contained"
          color="primary"
          className={styles.trainButton}
          tooltip={disabled ? '' : t('Run model {{key}}', { key: <code>Enter</code> })}
          onClick={handleStartTrain}
        >
          <Box display="flex" justifyContent="center">
            {t('Run')}
            {countDown > 0 && t('({{countDown}})', { countDown })}
          </Box>
        </Button>
      </PulseWrapper>
    </Box>
  );
};

export default TrainButton;
