import { AnnotationChangeType, MediaInteractiveCanvasProps } from '@clef/client-library';
import {
  LabelingType,
  Media,
  MediaDetailsWithPrediction,
  MediaLevelLabel,
} from '@clef/shared/types';
import { AnnotationInstance, Label, LabelSource } from '@clef/shared/types/basic';
import { Draft } from 'immer';
import { cloneDeep, isEqual, sortBy } from 'lodash';
import { useSnackbar } from 'notistack';
import React, { createContext, useCallback, useContext, useMemo } from 'react';
import { useImmer } from 'use-immer';
import LabelAPI from '../../api/label_api';
import { useTypedSelector } from '../../hooks/useTypedSelector';
import { useGetSelectedProjectQuery } from '@/serverStore/projects';
import { useDefectSelector } from '../../store/defectState/actions';
import { useCurrentProjectModelInfoQuery } from '@/serverStore/projectModels';
import {
  BitMapLabelingAnnotation,
  BoxLabelingAnnotation,
  ClassificationLabelingAnnotation,
  PureCanvasLabelingAnnotation,
  ToolMode,
  useColorToDefectIdMap,
  useLabelingState,
} from './labelingState';
import {
  canvasAnnotationToBoxAnnotation,
  convertToServerAnnotation,
  getAnnotationTypeByLabelType,
  layerToBitMapAnnotationsAsync,
} from './utils';
import { useQueryClient } from '@tanstack/react-query';
import { datasetQueryKeys } from '@/serverStore/dataset';

/** Any one of the per-labeling-type annotation lists. */
type LabelingAnnotationList =
  | BoxLabelingAnnotation[]
  | BitMapLabelingAnnotation[]
  | ClassificationLabelingAnnotation[]
  | PureCanvasLabelingAnnotation[];

/** Labeling state tracked per media by the image labeling context. */
export type MediaStates = {
  /** Set by annotation updates and cleared when a save for this media starts. */
  isChanged: boolean;
  /** Ground-truth annotations currently being edited. */
  annotations: LabelingAnnotationList;
  predictionAnnotations?: LabelingAnnotationList;
  mediaLevelLabel?: MediaLevelLabel | null;
  predictionMediaLevelLabel?: MediaLevelLabel | null;
  mediaDetails?: MediaDetailsWithPrediction;
};

/** Argument-less async save routine (not referenced within this module). */
export type CustomSaveAnnotationFunction = () => Promise<void>;

/** Shape of the state held by `ImageLabelingContext` and updated through immer drafts. */
export type ImageLabelingState = {
  /** Index into the media list with deleted media filtered out; starts at -1. */
  mediaIndex: number;
  mediaList: Media[];
  /** Ids hidden from the visible media list. */
  deletedMediaIds: number[];
  isLabelMode: boolean;
  hideGroundTruthLabels: boolean;
  hidePredictionLabels: boolean;
  /** Current labeling state per media id. */
  mediaStatesById: Record<number, MediaStates>;
  /** Media whose save is in flight or has failed; they are re-queued by the next save. */
  unsavedMediaStatesById: Record<number, MediaStates>;
  showHeatmap: boolean;
  annotationInstance?: AnnotationInstance;
  /** only used for UI display */
  saving?: boolean;
};

/**
 * Fallback returned by `useCurrentMediaStates` when the current media has no entry in
 * `mediaStatesById`: not changed and no annotations.
 */
export const defaultMediaStates: MediaStates = {
  isChanged: false,
  annotations: [],
};

/** Initial state used by `ImageLabelingContextProvider` (and the context's default value). */
export const defaultState: ImageLabelingState = {
  // -1 does not point at any media in the list.
  mediaIndex: -1,
  mediaList: [],
  deletedMediaIds: [],
  isLabelMode: false,
  hideGroundTruthLabels: false,
  hidePredictionLabels: false,
  mediaStatesById: {},
  unsavedMediaStatesById: {},
  showHeatmap: false,
};

/** Value exposed through `ImageLabelingContext`: the current state and its immer-style updater. */
type ImageLabelingContextValue = {
  state: ImageLabelingState;
  dispatch: (f: (state: Draft<ImageLabelingState>) => void | ImageLabelingState) => void;
};

// Consumers rendered outside a provider see the default state and a no-op dispatch.
const ImageLabelingContext = createContext<ImageLabelingContextValue>({
  state: defaultState,
  dispatch: () => {},
});

/** Reads the nearest `ImageLabelingContext` value. */
export const useImageLabelingContext = () => {
  const contextValue = useContext(ImageLabelingContext);
  return contextValue;
};

/**
 * Owns the image-labeling state (via `useImmer`, seeded with `defaultState`) and exposes it,
 * together with its updater, to descendants through `ImageLabelingContext`.
 */
export const ImageLabelingContextProvider: React.FC = ({ children }) => {
  const [state, dispatch] = useImmer(defaultState);
  // Memoize the context value so consumers only re-render when `state` actually changes,
  // instead of on every render of this provider (a fresh object literal would always differ).
  const contextValue = useMemo(() => ({ state, dispatch }), [state, dispatch]);
  return (
    <ImageLabelingContext.Provider value={contextValue}>{children}</ImageLabelingContext.Provider>
  );
};

/**
 * Reports whether anything still needs saving: either a media has unsaved edits (`isChanged`),
 * or a media is tracked in `unsavedMediaStatesById` (save in flight or previously failed).
 */
export const useHasUnsavedChanges = (): boolean => {
  const { state } = useImageLabelingContext();
  const hasEditedMedia = Object.values(state.mediaStatesById).some(({ isChanged }) => isChanged);
  if (hasEditedMedia) {
    return true;
  }
  return Object.keys(state.unsavedMediaStatesById).length > 0;
};

/**
 * Returns a setter that toggles label mode.
 *
 * Entering label mode un-hides ground-truth labels and, for bounding-box projects, pre-selects
 * the box tool. Leaving label mode clears the active tool.
 */
export const useSetIsLabelMode = () => {
  const { state: labelingState, dispatch: dispatchLabelingState } = useLabelingState();
  const { labelingType } = labelingState;
  const { dispatch: dispatchMediaDetailsState } = useImageLabelingContext();

  const setIsLabelMode = useCallback(
    (isLabelMode: boolean) => {
      dispatchMediaDetailsState(draft => {
        draft.isLabelMode = isLabelMode;
        if (isLabelMode) draft.hideGroundTruthLabels = false;
      });
      dispatchLabelingState(draft => {
        if (!isLabelMode) {
          draft.toolMode = undefined;
          return;
        }
        // Bounding box labeling has a single tool, so select it up front.
        if (labelingType === LabelingType.DefectBoundingBox) {
          draft.toolMode = ToolMode.Box;
        }
      });
    },
    [dispatchLabelingState, dispatchMediaDetailsState, labelingType],
  );

  return setIsLabelMode;
};

/**
 * Returns a callback that persists annotation edits for the selected project.
 *
 * The callback saves, in parallel:
 * - every media tracked in `unsavedMediaStatesById` (in flight earlier or failed), and
 * - every media in `mediaStatesById` with `isChanged` set. Annotations without a `defectId`
 *   are dropped first.
 *
 * `saving` stays true for the duration and is always reset, even if a save throws. When at least
 * one media saved successfully, the filtered dataset queries are invalidated. The callback
 * resolves to `undefined` when no dataset is selected, otherwise to `null`.
 */
export const useSaveAnnotations = () => {
  const { id: projectId, datasetId, labelType } = useGetSelectedProjectQuery().data ?? {};
  const { id: currentUserId } = useTypedSelector(state => state.login.user)!;
  const {
    state: { mediaStatesById, unsavedMediaStatesById },
    dispatch,
  } = useImageLabelingContext();
  const annotationType = getAnnotationTypeByLabelType(labelType);
  const { enqueueSnackbar } = useSnackbar();
  const { id: currentModelId } = useCurrentProjectModelInfoQuery();
  const queryClient = useQueryClient();

  /**
   * Saves one media's annotations.
   *
   * Resolves to `true` on success. On failure it shows an error snackbar, keeps the media in
   * `unsavedMediaStatesById` so the next save retries it, and resolves to `false`.
   */
  const saveAnnotationsForMedia = useCallback(
    async (
      mediaStates: MediaStates,
      options?: { skipRefreshMediaDetails: boolean },
    ): Promise<boolean> => {
      const { annotations, mediaDetails, mediaLevelLabel } = mediaStates;
      if (!projectId || !datasetId || !mediaDetails) return false;

      const mediaId = mediaDetails.id;
      const mediaDetailsQueryKey = datasetQueryKeys.mediaDetails(datasetId, {
        mediaId,
        modelId: currentModelId,
      });

      // Mark as in flight: clear the dirty flag and remember the pending states until the
      // request succeeds.
      dispatch(draft => {
        const currentStates = draft.mediaStatesById[mediaId];
        // An entry queued only from `unsavedMediaStatesById` may have no current states.
        if (currentStates) {
          currentStates.isChanged = false;
        }
        draft.unsavedMediaStatesById[mediaId] = { ...mediaStates };
      });

      try {
        // Conversion happens inside `try` so a failure is reported like a failed request.
        const serverAnnotations = annotations.map(ann =>
          convertToServerAnnotation(ann, annotationType!),
        );
        const newLabel = {
          ...(mediaDetails.label ?? { id: Date.now() }),
          labelerName: currentUserId,
          annotations: serverAnnotations,
        } as Label;

        // Optimistically show the new label in the cached media details.
        queryClient.setQueryData<MediaDetailsWithPrediction>(mediaDetailsQueryKey, prev =>
          prev ? ({ ...prev, label: newLabel } as MediaDetailsWithPrediction) : prev,
        );
        await LabelAPI.upsertLabels({
          projectId,
          mediaId,
          annotations: serverAnnotations,
          mediaLevelLabel: mediaLevelLabel || undefined,
          source: LabelSource.DirectLabeling,
          modelId: currentModelId,
        });
      } catch (e) {
        const errorMessage = e instanceof Error ? e.message : String(e);
        enqueueSnackbar(
          t('Failed to save annotations for media {{name}}: {{errorMessage}}', {
            name: mediaDetails.name,
            errorMessage,
          }),
          { variant: 'error' },
        );
        dispatch(draft => {
          draft.unsavedMediaStatesById[mediaId] = {
            ...mediaStates,
          };
        });
        return false;
      }

      if (!options?.skipRefreshMediaDetails) {
        queryClient.invalidateQueries(mediaDetailsQueryKey);
      }
      queryClient.invalidateQueries(datasetQueryKeys.modelMetrics(projectId));

      dispatch(draft => {
        delete draft.unsavedMediaStatesById[mediaId];
      });
      return true;
    },
    [
      dispatch,
      currentUserId,
      annotationType,
      projectId,
      enqueueSnackbar,
      datasetId,
      currentModelId,
      queryClient,
    ],
  );

  return useCallback(
    async (options?: { skipRefreshMediaDetails: boolean }) => {
      if (!datasetId) {
        return;
      }

      dispatch(draft => {
        draft.saving = true;
      });
      try {
        // Start with media whose earlier save is pending or failed; changed media override them.
        const saveQueue: { [mediaId: number]: MediaStates } = { ...unsavedMediaStatesById };

        for (const states of Object.values(mediaStatesById)) {
          // Entries without media details cannot be saved and cannot be keyed by media id.
          if (!states.isChanged || !states.mediaDetails) continue;
          // Deep-copy so the save works on a snapshot detached from the current state tree.
          const snapshot = cloneDeep(states);
          saveQueue[states.mediaDetails.id] = {
            ...snapshot,
            // Filter out annotations without defect id in case user does not complete the
            // whole quick create defect flow.
            annotations: snapshot.annotations.filter(annotation => !!annotation.defectId),
          };
        }

        const saveResults = await Promise.all(
          Object.values(saveQueue).map(mediaStates =>
            saveAnnotationsForMedia(mediaStates, options),
          ),
        );
        if (projectId && saveResults.some(Boolean)) {
          queryClient.invalidateQueries(datasetQueryKeys.allWithFilters(projectId));
        }
      } finally {
        // Always clear the UI flag, even if an unexpected error escapes a save.
        dispatch(draft => {
          draft.saving = false;
        });
      }
      return null;
    },
    [
      datasetId,
      dispatch,
      mediaStatesById,
      projectId,
      queryClient,
      saveAnnotationsForMedia,
      unsavedMediaStatesById,
    ],
  );
};

/** Returns the media list with deleted media removed (memoized on both inputs). */
export const useMediaList = () => {
  const { state } = useImageLabelingContext();
  const { mediaList: allMedia, deletedMediaIds } = state;
  return useMemo(
    () => allMedia.filter(({ id }) => !deletedMediaIds.includes(id)),
    [deletedMediaIds, allMedia],
  );
};

/**
 * Returns a handler that moves the selection to the previous media (no-op at the first one)
 * and then triggers a save of pending annotations.
 */
export const useGoToPreviousMedia = () => {
  const { state, dispatch } = useImageLabelingContext();
  const { mediaIndex } = state;
  const saveAnnotations = useSaveAnnotations();
  return () => {
    const canGoBack = mediaIndex > 0;
    if (!canGoBack) return;
    const previousIndex = mediaIndex - 1;
    dispatch(draft => {
      draft.mediaIndex = previousIndex;
    });
    // Fire-and-forget: navigation does not wait for the save to finish.
    void saveAnnotations();
  };
};

/**
 * Returns a handler that moves the selection to the next visible media (no-op at the last one)
 * and then triggers a save of pending annotations.
 */
export const useGoToNextMedia = () => {
  const { state, dispatch } = useImageLabelingContext();
  const { mediaIndex } = state;
  const saveAnnotations = useSaveAnnotations();
  const mediaList = useMediaList();
  return () => {
    const nextIndex = mediaIndex + 1;
    const hasNext = nextIndex < mediaList.length;
    if (!hasNext) return;
    dispatch(draft => {
      draft.mediaIndex = nextIndex;
    });
    // Fire-and-forget: navigation does not wait for the save to finish.
    void saveAnnotations();
  };
};

/**
 * Returns the labeling states of the media currently in focus, or `defaultMediaStates` when
 * there is no entry for it.
 */
export const useCurrentMediaStates = (): MediaStates => {
  const { state } = useImageLabelingContext();
  const { mediaStatesById, mediaIndex, annotationInstance } = state;
  const mediaList = useMediaList();
  return useMemo(() => {
    // An annotation instance's media id takes precedence over the navigation index.
    const currentMediaId = annotationInstance?.mediaId ?? mediaList[mediaIndex]?.id;
    const currentStates = mediaStatesById[currentMediaId];
    return currentStates ?? defaultMediaStates;
  }, [annotationInstance, mediaIndex, mediaList, mediaStatesById]);
};

/**
 * Compares two annotation lists by `(defectId, data)` pairs after a stable sort on `defectId`.
 * Returns true when they differ. Other annotation fields are ignored.
 */
export const isAnnotationsChanged = (
  prevAnnotations:
    | BoxLabelingAnnotation[]
    | BitMapLabelingAnnotation[]
    | ClassificationLabelingAnnotation[]
    | PureCanvasLabelingAnnotation[],
  newAnnotations:
    | BoxLabelingAnnotation[]
    | BitMapLabelingAnnotation[]
    | ClassificationLabelingAnnotation[]
    | PureCanvasLabelingAnnotation[],
) => {
  const toComparable = (annotations: typeof prevAnnotations) =>
    sortBy(annotations, ['defectId']).map(item => [item.defectId, item.data]);
  const unchanged = isEqual(toComparable(prevAnnotations), toComparable(newAnnotations));
  return !unchanged;
};

/**
 * Returns an `onAnnotationChanged` handler for the interactive canvas.
 *
 * For bounding-box labeling it converts canvas shapes to box annotations. For segmentation it
 * derives bitmap annotations from the canvas layer and skips updates that do not change the
 * `(defectId, data)` content. In both cases the result is stored on the current media, which is
 * marked changed, with its media-level label set to NG when any annotation remains.
 */
export const useUpdateAnnotations = () => {
  const { state: imageLabelingState, dispatch } = useImageLabelingContext();
  const { mediaDetails } = useCurrentMediaStates();
  const {
    state: { labelingType },
  } = useLabelingState();
  const colorToDefectIdMap = useColorToDefectIdMap();
  const allDefects = useDefectSelector();

  // Store the new annotations for `mediaId`, flag it as changed and derive its media-level label.
  const commitAnnotations = (mediaId: number, newAnnotations: MediaStates['annotations']) => {
    dispatch(draft => {
      const mediaStates = draft.mediaStatesById[mediaId] ?? {};
      mediaStates.annotations = newAnnotations;
      mediaStates.isChanged = true;
      mediaStates.mediaLevelLabel = newAnnotations.length > 0 ? MediaLevelLabel.NG : undefined;
      draft.mediaStatesById[mediaId] = mediaStates;
    });
  };

  return (async (canvasAnnotations, changeType, layer) => {
    // The media details dialog uses Reset to push annotations into the canvas; the labeling
    // state is already in sync, so nothing needs to be recorded.
    if (changeType === AnnotationChangeType.Reset) return;

    const mediaId = mediaDetails?.id ?? -1;
    const isDeleteAll = changeType === AnnotationChangeType.DeleteAll;

    switch (labelingType) {
      case LabelingType.DefectBoundingBox: {
        const newAnnotations = isDeleteAll
          ? []
          : canvasAnnotations.map(ann => canvasAnnotationToBoxAnnotation(ann, colorToDefectIdMap));
        commitAnnotations(mediaId, newAnnotations);
        break;
      }
      case LabelingType.DefectSegmentation: {
        const newAnnotations = isDeleteAll
          ? []
          : await layerToBitMapAnnotationsAsync(layer, mediaDetails?.properties!, allDefects);
        // Do not mark the media as changed when the segmentation content is identical.
        const mediaStates = imageLabelingState.mediaStatesById[mediaId] ?? {};
        if (!isAnnotationsChanged(mediaStates.annotations, newAnnotations)) return;
        commitAnnotations(mediaId, newAnnotations);
        break;
      }
      default:
        break;
    }
  }) as NonNullable<MediaInteractiveCanvasProps['onAnnotationChanged']>;
};

/**
 * Returns true when `mediaId` is the media at `state.mediaIndex` in the media list with
 * deleted media excluded. Returns false when the media is absent or deleted.
 */
export const isMediaSelected = (state: Draft<ImageLabelingState>, mediaId: number): boolean => {
  const { mediaList, deletedMediaIds, mediaIndex } = state;
  const visibleMedia = mediaList.filter(({ id }) => !deletedMediaIds.includes(id));
  const position = visibleMedia.findIndex(({ id }) => id === mediaId);
  return position !== -1 && position === mediaIndex;
};
