import { useCallback, useState } from 'react';
import { InferenceExecutorResult, ProjectId } from '@clef/shared/types';
import { uuid4 } from '@sentry/utils';
import { uploadMediaAndGetPrediction } from '../../../store/uploadState/utils';
import { useSnackbar } from 'notistack';
import { getPredictionFromCloudInference } from '@/utils';
import { useQueryClient } from '@tanstack/react-query';
import { usageQueryKey } from '@/serverStore/usage';
import { useTypedSelector } from '@/hooks/useTypedSelector';

/**
 * A file submitted for live prediction, optionally carrying an identifier and
 * an S3 location that is forwarded to the upload/prediction call.
 */
export interface ProcessedFile {
  /** The raw file whose contents are uploaded and previewed. */
  file: File;
  /** Used as the resulting prediction's `id`; a new uuid is generated when absent. */
  uuid?: string;
  /** Passed through to `uploadMediaAndGetPrediction`; presumably an existing upload location — confirm. */
  s3url?: string;
}

// Number of concurrent lanes used by getImageFilePredictions: files are dealt
// round-robin into this many lanes, each lane processes its files one at a time.
const maxParallelUploads = 5;

/**
 * Type guard distinguishing a `ProcessedFile` wrapper from a bare `File`.
 *
 * Only `file` is required by the `ProcessedFile` interface; `uuid` and `s3url`
 * are optional and may legitimately be absent, so requiring them would make a
 * valid `{ file }` wrapper be treated as a raw `File` (and wrapped again).
 * `Object.prototype.hasOwnProperty.call` is used so objects that do not inherit
 * from `Object.prototype` (e.g. `Object.create(null)`) do not throw.
 */
export const isProcessedFile = (obj: object): obj is ProcessedFile =>
  Object.prototype.hasOwnProperty.call(obj, 'file');

/**
 * Wraps a raw `File` into a `ProcessedFile` with a freshly generated uuid.
 * The `s3url` key is present but explicitly `undefined`.
 */
export const processFile = (file: File): ProcessedFile => {
  const uuid = uuid4();
  return { uuid, file, s3url: undefined };
};

/** Handler that receives a list of files to run predictions on; its return value is ignored. */
export type GetPredictions = (acceptedFiles: File[]) => void;

/** State for a single image submitted for live prediction. */
export interface LivePrediction {
  /** Object URL created from the local file with `URL.createObjectURL`. */
  imgSrc?: string;
  /** Inference output; undefined until a prediction request succeeds. */
  result?: InferenceExecutorResult;
  /** Unique id; updates are merged into the entry with the same id. */
  id: string;
  /** True while an upload/prediction request for this entry is in flight. */
  loading: boolean;
  /** S3 location of the uploaded image; required to re-run inference on refresh. */
  s3url?: string;
  /** Message from a failed request, if one failed. */
  error?: string;
  /** Name of the original file. */
  name: string;
}

const defaultPredictionErrorMessage = 'Something went wrong with the prediction';

/**
 * Extracts a user-facing message from a caught value. Falls back to a generic
 * message when the value has no non-empty string `message` (catch variables
 * are `unknown`, so the value is not assumed to be an `Error`).
 */
const getPredictionErrorMessage = (error: unknown): string => {
  const message = (error as { message?: unknown } | null | undefined)?.message;
  return typeof message === 'string' && message ? message : defaultPredictionErrorMessage;
};

/**
 * Manages live predictions for user-supplied images against a selected model.
 *
 * Uploads and (re-)predictions are skipped unless `projectId`, `datasetId` and
 * `selectedModelId` are all set.
 *
 * NOTE(review): object URLs created for `imgSrc` are never revoked here;
 * consider `URL.revokeObjectURL` when predictions are cleared or replaced.
 *
 * @param projectId - Project used for upload and inference.
 * @param datasetId - Dataset the media is uploaded into.
 * @param selectedModelId - Model used for inference.
 * @returns The current predictions plus helpers to set, clear, request and refresh them.
 */
export const useLivePredictions = (
  projectId: ProjectId | undefined,
  datasetId: number | undefined,
  selectedModelId?: string,
) => {
  const [livePredictions, setLivePredictions] = useState<LivePrediction[]>([]);
  const { enqueueSnackbar } = useSnackbar();
  const orgId = useTypedSelector(state => state.login.user)?.orgId;

  // Merge `prediction` into the existing entry with the same id, or prepend it if new.
  const updatePrediction = useCallback((prediction: LivePrediction) => {
    setLivePredictions((oldPredictions: LivePrediction[]) => {
      const index = oldPredictions.findIndex(oldPrediction => oldPrediction.id === prediction.id);

      if (index === -1) {
        return [prediction, ...oldPredictions];
      }
      return [
        ...oldPredictions.slice(0, index),
        { ...oldPredictions[index], ...prediction },
        ...oldPredictions.slice(index + 1),
      ];
    });
  }, []);

  // Surface a failed request in a snackbar and mark the entry as failed. Any
  // previous result is cleared so a stale result is never shown next to an error.
  const reportPredictionError = useCallback(
    (prediction: LivePrediction, error: unknown) => {
      const errorMessage = getPredictionErrorMessage(error);
      enqueueSnackbar(errorMessage, {
        variant: 'error',
      });

      updatePrediction({
        ...prediction,
        loading: false,
        result: undefined,
        error: errorMessage,
      });
    },
    [enqueueSnackbar, updatePrediction],
  );

  // Upload one file and run inference on it, tracking progress in the predictions list.
  const getLivePrediction = useCallback(
    async (imageFile: File | ProcessedFile) => {
      if (!projectId || !datasetId || !selectedModelId) {
        return;
      }

      const { uuid, file, s3url } = isProcessedFile(imageFile) ? imageFile : processFile(imageFile);
      const prediction: LivePrediction = {
        id: uuid || uuid4(),
        imgSrc: URL.createObjectURL(file),
        s3url,
        loading: true,
        result: undefined,
        error: undefined,
        name: file.name,
      };

      // Show the entry immediately (in loading state) before the request completes.
      updatePrediction(prediction);

      try {
        const predictionResult = await uploadMediaAndGetPrediction(
          projectId,
          datasetId,
          file,
          selectedModelId,
          s3url,
        );

        updatePrediction({
          ...prediction,
          loading: false,
          result: predictionResult.result,
          s3url: predictionResult.s3url,
        });
      } catch (error) {
        reportPredictionError(prediction, error);
      }
    },
    [datasetId, projectId, reportPredictionError, selectedModelId, updatePrediction],
  );

  const queryClient = useQueryClient();

  // Predict a single file or a batch. Files are dealt round-robin into
  // `maxParallelUploads` lanes; lanes run concurrently, files within a lane sequentially.
  const getImageFilePredictions = useCallback(
    async (imageFiles: (File | ProcessedFile)[] | (File | ProcessedFile)) => {
      const files = Array.isArray(imageFiles) ? imageFiles : [imageFiles];
      try {
        await Promise.all(
          Array.from({ length: maxParallelUploads }, async (_, lane) => {
            const laneFiles = files.filter((_, index) => index % maxParallelUploads === lane);
            for (const file of laneFiles) {
              await getLivePrediction(file);
            }
          }),
        );
      } finally {
        // Refresh the usage summary even if some lane failed, since earlier uploads may have succeeded.
        if (orgId != null) {
          void queryClient.invalidateQueries(usageQueryKey.summary(orgId));
        }
      }
    },
    [orgId, getLivePrediction, queryClient],
  );

  // Re-run inference on every current prediction using its uploaded S3 location.
  const refreshImageFilePredictions = useCallback(() => {
    if (!projectId || !datasetId || !selectedModelId) {
      return;
    }
    // Capture the narrowed values so the async closure below sees non-undefined types.
    const currentProjectId = projectId;
    const currentModelId = selectedModelId;

    const refreshPrediction = async (prediction: LivePrediction) => {
      updatePrediction({ ...prediction, loading: true });
      try {
        if (!prediction.s3url) {
          throw Error(`No associated s3 url for image ${prediction.name}`);
        }

        const predictionResult = await getPredictionFromCloudInference(
          prediction.name,
          prediction.s3url,
          currentProjectId,
          currentModelId,
        );

        // Clear any error left over from an earlier failed attempt.
        updatePrediction({
          ...prediction,
          loading: false,
          result: predictionResult,
          error: undefined,
        });
      } catch (error) {
        reportPredictionError(prediction, error);
      }
    };

    // Fire-and-forget: each refresh handles its own errors, so results are not awaited.
    livePredictions.forEach(prediction => {
      void refreshPrediction(prediction);
    });
  }, [
    datasetId,
    livePredictions,
    projectId,
    reportPredictionError,
    selectedModelId,
    updatePrediction,
  ]);

  const clearLivePredictions = useCallback(() => {
    setLivePredictions([]);
  }, []);

  return {
    livePredictions,
    setLivePredictions,
    clearLivePredictions,
    getImageFilePredictions,
    refreshImageFilePredictions,
  };
};
