import experiment_report_api from '@/api/experiment_report_api';
import zero_auth_api from '@/api/zero_auth_api';
import {
  BitMapLabelingAnnotation,
  BoxLabelingAnnotation,
  ClassificationLabelingAnnotation,
  PureCanvasLabelingAnnotation,
  useLabelingState,
} from '@/components/Labeling/labelingState';
import { layerToBitMapAnnotationsAsync } from '@/components/Labeling/utils';
import { refreshZeroAuthProjectApi, useZeroAuthProjectApi } from '@/hooks/api/useZeroAuthApi';
import { useCountDown } from '@/hooks/useCountDown';
import { useSetTimeout } from '@/hooks/useTimeout';
import { getDefectColor } from '@/utils';
import {
  AnnotationChangeType,
  defectColors,
  MediaInteractiveCanvasProps,
  useLocalStorage,
} from '@clef/client-library';
import {
  Defect,
  MediaDetails,
  MediaDetailsWithPrediction,
  Position,
  SegmentationAnnotationData,
  ZeroAuthLabel,
  ZeroAuthTrainAnnotation,
} from '@clef/shared/types';
import { Link } from '@material-ui/core';
import { isEqual } from 'lodash';
import { SnackbarKey, useSnackbar } from 'notistack';
import React, { useRef } from 'react';
import { createContext, useCallback, useContext, useEffect, useMemo } from 'react';
import { Updater } from 'use-immer';

/** Any of the annotation list shapes the labeling tools can store for one image. */
type ZeroAuthAnnotationList =
  | BoxLabelingAnnotation[]
  | BitMapLabelingAnnotation[]
  | ClassificationLabelingAnnotation[]
  | PureCanvasLabelingAnnotation[];

/**
 * Client-side state kept for a single image of the zero-auth project, keyed by its index in
 * `ZeroAuthInstantLearningState.mediaStatesByIndex`.
 */
export type ZeroAuthMediaStates = {
  /** Set to true whenever the user's annotations for this image are edited. */
  isChanged: boolean;
  /** User-drawn (ground-truth) annotations for this image. */
  annotations?: ZeroAuthAnnotationList;
  /** Annotations derived from model predictions for this image. */
  predictionAnnotations?: ZeroAuthAnnotationList;
  mediaDetails?: MediaDetailsWithPrediction;

  // NOTE(review): presumably paths to ground-truth / prediction segmentation outputs — confirm.
  gtSegPath?: string;
  gtInitialized?: boolean;
  predSegPath?: string;
  /** Media properties forwarded to bitmap conversion when annotations change. */
  properties?: MediaDetails['properties'];
};

/** Shared page state for zero-auth instant learning, provided through `ZeroAuthInstantLearningContext`. */
export type ZeroAuthInstantLearningState = {
  userId?: string;
  /** Project whose data and classes the hooks in this module read. */
  currentProjectId?: number;
  /** Index of the image currently being labeled (initially -1). */
  mediaIndex: number;
  /** Per-image labeling state, keyed by image index. */
  mediaStatesByIndex: Record<number, ZeroAuthMediaStates>;
  mousePos?: Position;
  mousePosInfo?: { predictClass?: Defect };
  /** True while a train request is in flight. */
  training?: boolean;
  /** Remaining seconds of the cool-down countdown started when the train service is busy. */
  countDown: number;
  /** Visibility toggles for the ground-truth and prediction label layers. */
  labelingWrapper: { showGroundTruthLabels: boolean; showPredictionLabels: boolean };
};

/** Default label-layer visibility: predictions shown, ground truth hidden. */
const initialLabelingWrapper: ZeroAuthInstantLearningState['labelingWrapper'] = {
  showGroundTruthLabels: false,
  showPredictionLabels: true,
};

/**
 * Starting state for the instant learning page: no per-image state, no countdown, and a
 * `mediaIndex` of -1 (presumably "no image selected yet" — confirm with the page component).
 */
export const initialState: ZeroAuthInstantLearningState = {
  mediaIndex: -1,
  mediaStatesByIndex: {},
  labelingWrapper: initialLabelingWrapper,
  countDown: 0,
};

/** Value carried by `ZeroAuthInstantLearningContext`: the state plus its immer-style updater. */
type ZeroAuthInstantLearningContextValue = {
  state: ZeroAuthInstantLearningState;
  dispatch: Updater<ZeroAuthInstantLearningState>;
};

/**
 * React context for the instant learning page state. Without a provider, consumers receive
 * `initialState` and a `dispatch` that does nothing.
 */
export const ZeroAuthInstantLearningContext = createContext<ZeroAuthInstantLearningContextValue>({
  state: initialState,
  dispatch: () => {},
});

/** Reads `{ state, dispatch }` from the nearest `ZeroAuthInstantLearningContext` provider. */
export const useZeroAuthInstantLearningState = () => {
  const contextValue = useContext(ZeroAuthInstantLearningContext);
  return contextValue;
};

/** Returns the zero-auth project for `state.currentProjectId`, as loaded by `useZeroAuthProjectApi`. */
export const useCurrentProject = () => {
  const {
    state: { currentProjectId },
  } = useZeroAuthInstantLearningState();
  return useZeroAuthProjectApi(currentProjectId)[0];
};

/**
 * Lists the current project's defect classes as `Defect` objects, skipping class index 0
 * (the OK class). Each class uses its numeric index as both `id` and `indexId`, and takes
 * its color from the shared defect palette by position (`index - 1`).
 */
export const useAllClasses = () => {
  const {
    state: { currentProjectId },
  } = useZeroAuthInstantLearningState();
  const [project] = useZeroAuthProjectApi(currentProjectId);
  const defectMap = project?.defectMap;

  return useMemo(() => {
    const classes: Defect[] = [];
    for (const [classKey, classInfo] of Object.entries(defectMap ?? {})) {
      // index 0 is the OK class, which is not a labelable defect
      if (classKey === '0') {
        continue;
      }
      const classId = Number(classKey);
      classes.push({
        id: classId,
        name: classInfo.name,
        // NOTE(review): yields undefined once classId exceeds defectColors.length — confirm
        // getDefectColor handles a missing color.
        color: defectColors[classId - 1],
        indexId: classId,
      } as unknown as Defect);
    }
    return classes;
  }, [defectMap]);
};

/**
 * Maps each class's display color (via `getDefectColor`) to its class id.
 * If two classes resolve to the same color, the later class in `useAllClasses` order wins.
 */
export const useColorToClassIdMap = () => {
  const allClasses = useAllClasses();
  return useMemo(() => {
    // Single pass into one object; spreading the accumulator inside reduce copied it on every
    // step (O(n²)).
    const classIdByColor: Record<string, number> = {};
    for (const oneClass of allClasses) {
      classIdByColor[getDefectColor(oneClass)] = oneClass.id;
    }
    return classIdByColor;
  }, [allClasses]);
};

/**
 * Indexes the current project's classes by their `id`.
 * Keys are the stringified ids, as with any plain-object key.
 */
export const useClassesById = () => {
  const allClasses = useAllClasses();
  return useMemo(() => {
    // Single pass into one object; spreading the accumulator inside reduce copied it on every
    // step (O(n²)).
    const classesById: Record<string, Defect> = {};
    for (const oneClass of allClasses) {
      classesById[oneClass.id] = oneClass;
    }
    return classesById;
  }, [allClasses]);
};

/**
 * Returns the stored state of the image at `state.mediaIndex`, or an empty object when that
 * image has no entry yet (so callers can always destructure the result).
 */
export const useCurrentImageState = () => {
  const {
    state: { mediaIndex, mediaStatesByIndex },
  } = useZeroAuthInstantLearningState();
  return mediaStatesByIndex[mediaIndex] ?? {};
};

/**
 * Builds an `onAnnotationChanged` handler for the interactive canvas.
 *
 * - `Reset` changes are ignored.
 * - `DeleteAll` stores an empty annotation list.
 * - Any other change converts the canvas layer to bitmap annotations, using the current image's
 *   properties and the project's classes.
 *
 * The result is written to the image that was current when the handler was created. It is
 * marked `isChanged` only if it differs (deep equality) from what was stored before.
 */
export const useUpdateAnnotations = () => {
  const {
    state: { mediaIndex },
    dispatch,
  } = useZeroAuthInstantLearningState();
  const allClasses = useAllClasses();
  const { properties } = useCurrentImageState();

  return useCallback(
    async (_canvasAnnotations, changeType, layer) => {
      if (changeType === AnnotationChangeType.Reset) return;

      const nextAnnotations =
        changeType !== AnnotationChangeType.DeleteAll
          ? await layerToBitMapAnnotationsAsync(layer, properties ?? undefined, allClasses)
          : [];

      dispatch(draft => {
        const entry = draft.mediaStatesByIndex[mediaIndex] ?? {};
        // Unchanged content must not flag the image as edited.
        if (isEqual(entry.annotations, nextAnnotations)) return;
        entry.isChanged = true;
        entry.annotations = nextAnnotations;
        draft.mediaStatesByIndex[mediaIndex] = entry;
      });
    },
    [allClasses, dispatch, mediaIndex, properties],
  ) as NonNullable<MediaInteractiveCanvasProps['onAnnotationChanged']>;
};

/**
 * Selects which images `useLabeledClassIds` collects labeled class IDs from.
 *
 * current_image:
 *  all class IDs labeled on the currently selected image.
 *  Used for auto-switching classes.
 *
 * new_labels_only:
 *  class IDs labeled on every image whose annotations were edited (`isChanged`),
 *  i.e. the new labels since the last train.
 *  Used for the step guide and the train button status.
 *
 * all:
 *  class IDs labeled on any image.
 *  Used for the train button status.
 */
export type LabeledClassIdsSource = 'current_image' | 'new_labels_only' | 'all';
/**
 * Collects the class IDs (`defectId`s) present in stored annotations.
 *
 * @param source Which images to scan: only the current image, only images with edited
 *   annotations, or all images. Defaults to `'all'`.
 * @returns A `Set` of class IDs. A new Set is produced only when the inputs relevant to
 *   `source` change.
 */
export const useLabeledClassIds = (source: LabeledClassIdsSource = 'all') => {
  const { state } = useZeroAuthInstantLearningState();
  const { mediaIndex, mediaStatesByIndex } = state;
  // Only 'current_image' depends on which image is selected. Keying the memo on mediaIndex for
  // the other sources returned a fresh Set on every image navigation even though its contents
  // could not change, re-firing effects that depend on it (e.g. useAutoSwitchToNextClass,
  // which then overrode the user's class selection).
  const scopedMediaIndex = source === 'current_image' ? mediaIndex : undefined;

  return useMemo(() => {
    let candidates: ZeroAuthMediaStates[];
    if (scopedMediaIndex !== undefined) {
      const currentMediaState = mediaStatesByIndex[scopedMediaIndex];
      candidates = currentMediaState ? [currentMediaState] : [];
    } else if (source === 'new_labels_only') {
      candidates = Object.values(mediaStatesByIndex).filter(mediaState => mediaState.isChanged);
    } else {
      candidates = Object.values(mediaStatesByIndex);
    }

    const labeledClassIds = new Set<number>();
    candidates.forEach(mediaState => {
      mediaState.annotations?.forEach(annotation => {
        labeledClassIds.add(annotation.defectId);
      });
    });
    return labeledClassIds;
  }, [mediaStatesByIndex, scopedMediaIndex, source]);
};

/**
 * True when at least one image has a non-empty `gtSegPath`.
 * (Presumably set once the image has been through training — confirm with the data source.)
 */
export const useHasTrained = () => {
  const {
    state: { mediaStatesByIndex },
  } = useZeroAuthInstantLearningState();
  return useMemo(
    () => !Object.values(mediaStatesByIndex).every(mediaState => !mediaState.gtSegPath),
    [mediaStatesByIndex],
  );
};

/**
 * Before the first train, steers the labeling tool toward unlabeled classes.
 *
 * While fewer than two classes (and fewer than all classes) have been labeled, it selects the
 * first class in `useAllClasses` order that has no labels yet.
 */
export const useAutoSwitchToNextClass = () => {
  const hasTrained = useHasTrained();
  const labeledClassIds = useLabeledClassIds('all');
  const allClasses = useAllClasses();
  const { dispatch } = useLabelingState();

  useEffect(() => {
    if (hasTrained) return;
    const labeledCount = labeledClassIds.size;
    // Stop once two classes are labeled, or when every class already has labels.
    if (labeledCount >= 2 || labeledCount >= allClasses.length) return;

    const nextClass = allClasses.find(({ id }) => !labeledClassIds.has(id));
    if (!nextClass) return;

    dispatch(draft => {
      draft.selectedDefect = nextClass;
    });
  }, [allClasses, allClasses.length, dispatch, hasTrained, labeledClassIds]);
};

/**
 * Extracts a displayable message from anything thrown. Values carrying a string `message`
 * (Error instances or error-like payloads) keep it; everything else is stringified, so the
 * error snackbar is never blank and a thrown `null` cannot crash the catch handler.
 */
const getTrainErrorMessage = (error: unknown): string => {
  if (typeof error === 'object' && error !== null) {
    const { message } = error as { message?: unknown };
    if (typeof message === 'string') {
      return message;
    }
  }
  return String(error);
};

/**
 * Builds the train payload entry for one dataset image from its local labeling state.
 * Images with no state or no annotations are sent as `{ annotations: [] }` without `isChanged`.
 * NOTE(review): an image whose labels were all deleted (annotations [] but isChanged true) is
 * therefore sent without `isChanged` — confirm the backend expects this.
 */
const toZeroAuthLabel = (mediaStates: ZeroAuthMediaStates | undefined): ZeroAuthLabel => {
  if (!mediaStates || !mediaStates.annotations || mediaStates.annotations.length === 0) {
    return { annotations: [] } as ZeroAuthLabel;
  }
  // Annotation data is read as segmentation data: the payload carries the range box and the
  // encoded bitmap per class.
  const annotations = mediaStates.annotations.map(
    annotation =>
      ({
        defectId: annotation.defectId,
        rangeBox: (annotation.data as SegmentationAnnotationData).rangeBox,
        segmentationBitmapEncoded: (annotation.data as SegmentationAnnotationData).bitMap,
      } as ZeroAuthTrainAnnotation),
  );
  return { annotations, isChanged: mediaStates.isChanged } as ZeroAuthLabel;
};

/**
 * Returns an async callback that trains the current zero-auth project on the local labels.
 *
 * The callback:
 * 1. does nothing when no project is loaded;
 * 2. unless the `skip_health_check` local-storage flag is set, checks train-worker health and,
 *    if inactive or above 95% occupation, shows a warning and starts a 5s countdown instead;
 * 3. sends one label entry per dataset image to `zero_auth_api.train`, then refreshes the
 *    project data;
 * 4. after 30s of waiting, offers a "reload the page" snackbar.
 *
 * Errors are shown as a snackbar. `state.training` is true only while the callback runs.
 */
export const useStartTrain = () => {
  const { state, dispatch } = useZeroAuthInstantLearningState();
  const { currentProjectId, mediaStatesByIndex } = state;
  const project = useCurrentProject();
  const [skipHealthCheck] = useLocalStorage('skip_health_check');

  const { enqueueSnackbar, closeSnackbar } = useSnackbar();

  const { timerIdRef, setTimeout } = useSetTimeout();
  const waitingTooLongSnackBarKey = useRef<SnackbarKey>();

  const { startCountDown } = useCountDown(count => {
    dispatch(draft => {
      draft.countDown = count;
    });
  });

  return useCallback(async () => {
    if (!project || !currentProjectId) {
      return;
    }

    try {
      dispatch(draft => {
        draft.training = true;
      });

      if (!skipHealthCheck) {
        // train worker availability check
        const { instantLearning } = await experiment_report_api.getTrainHealth();
        if (!instantLearning.active || instantLearning.occupation > 0.95) {
          dispatch(draft => {
            draft.training = false;
          });
          startCountDown(5);
          enqueueSnackbar(
            t('We are experiencing exceptionally high demand. Please try Visual Prompting later.'),
            { variant: 'warning', autoHideDuration: 6000 },
          );
          return;
        }
      }

      // one label entry per dataset image, in dataset order
      const imageLabels = project.dataset.map((_, index) =>
        toZeroAuthLabel(mediaStatesByIndex[index]),
      );

      setTimeout(() => {
        waitingTooLongSnackBarKey.current = enqueueSnackbar(
          t('Still loading? You can {{reload}} to try again.', {
            reload: (
              <Link
                color="inherit"
                id="reload-page"
                onClick={() => window.location.reload()}
                underline="always"
                style={{ cursor: 'pointer' }}
              >
                {t('reload the page')}
              </Link>
            ),
          }),
          { variant: 'info', persist: true },
        );
      }, 30_000);

      // start train (currentProjectId is narrowed to a truthy value by the guard above)
      await zero_auth_api.train({
        zeroAuthProjectId: currentProjectId,
        zeroAuthLabels: imageLabels,
      });
      // refresh the project data
      refreshZeroAuthProjectApi({ keys: 'refresh-all', swr: true });
    } catch (e) {
      enqueueSnackbar(getTrainErrorMessage(e), { variant: 'error', autoHideDuration: 12000 });
    } finally {
      dispatch(draft => {
        draft.training = false;
      });
      if (timerIdRef.current) {
        clearTimeout(timerIdRef.current);
      }
      if (waitingTooLongSnackBarKey.current !== undefined) {
        closeSnackbar(waitingTooLongSnackBarKey.current);
        // Forget the key so a later run does not try to close an already-dismissed snackbar.
        waitingTooLongSnackBarKey.current = undefined;
      }
    }
  }, [
    closeSnackbar,
    currentProjectId,
    dispatch,
    enqueueSnackbar,
    mediaStatesByIndex,
    project,
    setTimeout,
    skipHealthCheck,
    startCountDown,
    timerIdRef,
  ]);
};
