import { useGetSelectedProjectQuery } from '@/serverStore/projects';
import { useRef, useCallback } from 'react';
import { useInstantLearningState } from './state';
import { useLocalStorage } from '@clef/client-library';
import { useQueryClient } from '@tanstack/react-query';
import { projectModelQueryKeys } from '@/serverStore/projectModels';
import { datasetQueryKeys } from '@/serverStore/dataset';
import { MediaDetailsWithPrediction } from '@clef/shared/types';

/**
 * Decides whether a single-media refresh should be performed.
 *
 * When there is more than one media, two SSEs are expected: the first refreshes the currently
 * viewed media and the second refreshes all media. A single-media refresh therefore only applies
 * when a media id and a model id are both present and the total media count is greater than 1.
 *
 * Presence is checked by truthiness, so `0` ids/counts and an empty `modelId` count as absent.
 *
 * @param mediaId - id of the media to refresh, if any
 * @param modelId - id of the model, if any
 * @param totalMediaCount - total number of media, if known
 * @returns a strict boolean, so callers never receive a stray `0`/`''`/`undefined`
 *   (e.g. `0` would be rendered as text by a JSX `cond && <X />` expression)
 */
export const shouldRefreshSingleMedia = (
  mediaId: number | undefined,
  modelId: string | undefined,
  totalMediaCount: number | undefined,
): boolean => Boolean(mediaId && modelId && totalMediaCount && totalMediaCount > 1);

// TODO: consider moving `openIterating` to jotai so that we can handle this inside model status polling.
/**
 * Returns a memoized async callback to invoke when training has completed.
 *
 * The callback:
 * - does nothing when no project is selected;
 * - when the `prev_register_model` local-storage value is falsy, sets it to `true` and sets
 *   `openIterating` to `true` in the instant-learning state;
 * - invalidates the project's model list and model info queries;
 * - when the project has a dataset, sets `predictionLabel` to `null` on every cached media-details
 *   entry for that dataset, then invalidates those media-details queries.
 *
 * The query invalidations are not awaited, so the returned promise resolves without waiting for
 * any refetch to finish.
 */
export const useHandleTrainingComplete = () => {
  const { id: projectId, datasetId } = useGetSelectedProjectQuery().data ?? {};
  const { dispatch } = useInstantLearningState();
  const [prevRegisterModel, setPrevRegisterModel] = useLocalStorage('prev_register_model');
  // Mirror the stored value in a ref so the callback always reads the latest value without
  // listing it as a dependency (which would recreate the callback whenever it changes).
  const prevRegisterModelRef = useRef(prevRegisterModel);
  prevRegisterModelRef.current = prevRegisterModel;
  const queryClient = useQueryClient();

  return useCallback(async () => {
    if (!projectId) return;

    if (!prevRegisterModelRef.current) {
      setPrevRegisterModel(true);
      dispatch(draft => {
        draft.openIterating = true;
      });
    }

    queryClient.invalidateQueries(projectModelQueryKeys.list(projectId));
    queryClient.invalidateQueries(projectModelQueryKeys.modelInfo(projectId));

    if (datasetId) {
      // Clear the cached prediction before invalidating; presumably so a stale label is not shown
      // while the refetch is in flight. Returning `undefined` leaves empty cache entries untouched.
      queryClient.setQueriesData<MediaDetailsWithPrediction>(
        datasetQueryKeys.mediaDetails(datasetId),
        prev => (prev ? ({ ...prev, predictionLabel: null } as typeof prev) : undefined),
      );
      queryClient.invalidateQueries(datasetQueryKeys.mediaDetails(datasetId));
    }
    // Every value from render scope that the callback reads must be listed, including `datasetId`,
    // otherwise the memoized callback can keep using a stale dataset id.
  }, [datasetId, dispatch, projectId, queryClient, setPrevRegisterModel]);
};

export default useHandleTrainingComplete;
