import model_analysis_api from '@/api/model_analysis_api';
import { useMutation, useQueryClient } from '@tanstack/react-query';
import { useSnackbar } from 'notistack';
import { useGetSelectedProjectQuery } from '../projects';
import { modelAnalysisQueryKeys } from './queries';
import { useTypedSelector } from '@/hooks/useTypedSelector';
import project_model_api from '@/api/project_model_api';
import { RegisteredModelWithBundles } from '@/api/model_api';
import { useAtom } from 'jotai';
import {
  batchModelMetricsParamAtom,
  modelMetricsMapAtom,
} from '@/pages/model_iteration/componentsV2/atoms';
import { LabelType } from '@clef/shared/types';

/**
 * Mutation hook that asks the backend to evaluate a model on an evaluation set.
 *
 * When no project is selected the mutation resolves to `undefined` without
 * calling the API, and the success handler does nothing. Otherwise, on success:
 * a snackbar is shown, the `modelId-evaluationSetId-threshold` entry is dropped
 * from the metrics map atom, and the metrics, batch-metrics and
 * evaluation-report queries of the selected project are invalidated.
 *
 * @param modelName - Model name interpolated into the success snackbar.
 * @param evaluationSetName - Evaluation set name interpolated into the success snackbar.
 * @returns The react-query mutation object.
 */
export const useRunEvaluationMutation = (modelName: string, evaluationSetName: string) => {
  const queryClient = useQueryClient();
  const { enqueueSnackbar } = useSnackbar();
  const { labelType, id: projectId } = useGetSelectedProjectQuery().data ?? {};
  const [batchModelMetricParam] = useAtom(batchModelMetricsParamAtom);
  const [, setModelMetricsMap] = useAtom(modelMetricsMapAtom);

  return useMutation({
    mutationFn: async (params: { modelId: string; evaluationSetId: number; threshold: number }) => {
      // Without a selected project there is nothing to evaluate against.
      if (!projectId) {
        return;
      }
      await model_analysis_api.runEvaluation({
        projectId,
        modelId: params.modelId,
        evaluationSetId: params.evaluationSetId,
        threshold: params.threshold,
      });
      return params;
    },
    onSuccess: result => {
      // `undefined` means the mutation was skipped because no project was selected.
      if (!result) {
        return;
      }
      const { modelId, evaluationSetId, threshold } = result;

      // Classification messages do not mention the threshold.
      let message: string;
      if (labelType === LabelType.Classification) {
        message = t(`Evaluating model {{modelName}} on evaluation set {{evaluationSetName}}`, {
          modelName,
          evaluationSetName,
        });
      } else {
        message = t(
          `Evaluating model {{modelName}} with threshold {{threshold}} on evaluation set {{evaluationSetName}}`,
          { modelName, threshold, evaluationSetName },
        );
      }
      enqueueSnackbar(message, { variant: 'success' });

      // Drop the locally stored metrics for this combination, if there are any.
      setModelMetricsMap(current => {
        const metricsKey = `${modelId}-${evaluationSetId}-${threshold}`;
        if (current[metricsKey]) {
          const remaining = { ...current };
          delete remaining[metricsKey];
          return remaining;
        }
        return current;
      });

      if (projectId) {
        queryClient.invalidateQueries(
          modelAnalysisQueryKeys.metrics(projectId, modelId, evaluationSetId, threshold),
        );
        queryClient.invalidateQueries(
          modelAnalysisQueryKeys.batchMetrics(projectId, batchModelMetricParam),
        );
        queryClient.invalidateQueries(
          modelAnalysisQueryKeys.modelEvaluationReports(projectId, modelId),
        );
      }
    },
    onError: (e: Error) => {
      enqueueSnackbar(e.message, { variant: 'error' });
    },
  });
};

/**
 * Mutation hook that creates a bundle for a model / evaluation set / threshold.
 *
 * The mutation rejects (and therefore reports through `onError`) when no
 * project is selected, instead of sending a request with an undefined
 * `projectId`. On success it shows a snackbar (only when `modelName` is set),
 * drops the `modelId-evaluationSetId-threshold` entry from the metrics map
 * atom, removes the cached metrics query for that combination and invalidates
 * the batch-metrics, evaluation-report and model-list queries.
 *
 * @param modelName - Model name interpolated into the success snackbar; when
 *   falsy no success snackbar is shown.
 * @returns The react-query mutation object.
 */
export const useModelAnalysisCreateBundleMutation = (modelName: string | undefined) => {
  const queryClient = useQueryClient();
  const { enqueueSnackbar } = useSnackbar();
  const { labelType, id: selectedProjectId } = useGetSelectedProjectQuery().data ?? {};
  const [batchModelMetricParam] = useAtom(batchModelMetricsParamAtom);
  const [, setModelMetricsMap] = useAtom(modelMetricsMapAtom);
  return useMutation({
    mutationFn: async (params: { modelId: string; evaluationSetId: number; threshold: number }) => {
      // Fail loudly rather than asserting the id is present: the selected
      // project query may not have resolved yet.
      if (!selectedProjectId) {
        throw new Error(t('No project is selected'));
      }
      await model_analysis_api.createBundle({
        ...params,
        projectId: selectedProjectId,
      });
      return { ...params };
    },
    onSuccess: res => {
      const { modelId, evaluationSetId, threshold } = res;
      if (modelName && threshold !== undefined) {
        enqueueSnackbar(
          // Classification messages do not mention the threshold.
          labelType === LabelType.Classification
            ? t('Successfully created bundle for model {{modelName}}', {
                modelName,
              })
            : t(
                'Successfully created bundle for model {{modelName}} with threshold {{threshold}}',
                {
                  modelName,
                  threshold,
                },
              ),
          { variant: 'success' },
        );
      }
      // Drop the locally stored metrics for this combination, if there are any.
      setModelMetricsMap(prev => {
        const key = `${modelId}-${evaluationSetId}-${threshold}`;
        if (!prev[key]) {
          return prev;
        }
        const newMap = { ...prev };
        delete newMap[key];
        return newMap;
      });
      if (selectedProjectId) {
        queryClient.invalidateQueries(
          modelAnalysisQueryKeys.batchMetrics(selectedProjectId, batchModelMetricParam),
        );
        // The metrics query for this combination is removed outright, not just invalidated.
        queryClient.removeQueries(
          modelAnalysisQueryKeys.metrics(selectedProjectId, modelId, evaluationSetId, threshold),
        );
        queryClient.invalidateQueries(
          modelAnalysisQueryKeys.modelEvaluationReports(selectedProjectId, modelId),
        );
        queryClient.invalidateQueries(modelAnalysisQueryKeys.modelList(selectedProjectId));
      }
    },
    // `any` is kept so the mutation's inferred error type stays unchanged for callers.
    onError: (err: any) => {
      enqueueSnackbar(
        (err?.body || err)?.message, // this could be an api error or frontend error
        { variant: 'error' },
      );
    },
  });
};

/**
 * Mutation hook that marks a model bundle as favorite or removes that mark.
 *
 * The cached model list is updated in `onMutate`, before the request settles:
 * the bundle matching `modelId` and the mutation's `threshold` gets the new
 * `isFav` value. If the request fails, the model-list query is invalidated and
 * an error snackbar is shown whose wording matches the attempted action.
 *
 * @param modelName - Model name used in the snackbar messages.
 * @param threshold - Threshold used in the success snackbar messages.
 * @returns The react-query mutation object.
 */
export const useUpdateBundleIsFavMutation = (modelName: string, threshold: number) => {
  const { enqueueSnackbar } = useSnackbar();
  const queryClient = useQueryClient();
  // Falls back to project id 0 when the store has no selected project.
  const selectedProjectId = useTypedSelector(state => state.project.selectedProjectId) ?? 0;
  return useMutation({
    mutationFn: async (params: { modelId: string; threshold: number; isFav: boolean }) => {
      await project_model_api.updateModelBundleIsFav(
        params.modelId,
        params.threshold,
        params.isFav,
      );
      return params;
    },
    // `bundleThreshold` avoids shadowing the hook-level `threshold` argument.
    onMutate: ({ modelId, threshold: bundleThreshold, isFav }) => {
      queryClient.setQueryData(
        modelAnalysisQueryKeys.modelList(selectedProjectId),
        (prev?: RegisteredModelWithBundles[]) =>
          prev?.map(model =>
            model.id !== modelId
              ? model
              : {
                  ...model,
                  bundles: model.bundles.map(bundle =>
                    bundle.threshold !== bundleThreshold ? bundle : { ...bundle, isFav },
                  ),
                },
          ),
      );
    },
    onSuccess: params => {
      enqueueSnackbar(
        params.isFav
          ? t(`${modelName},${threshold} is successfully marked as favorite.`)
          : t(`${modelName},${threshold} is successfully removed from favorite.`),
        { variant: 'success' },
      );
    },
    onError: (e: Error, params) => {
      const { threshold: bundleThreshold, isFav } = params;
      // Invalidate the model list that `onMutate` modified ahead of the response.
      queryClient.invalidateQueries(modelAnalysisQueryKeys.modelList(selectedProjectId));
      enqueueSnackbar(
        // Word the failure according to the action that was attempted.
        isFav
          ? t(`Failed to mark ${modelName},${bundleThreshold} as favorite, {{errorMessage}}`, {
              errorMessage: e.message,
            })
          : t(`Failed to remove ${modelName},${bundleThreshold} from favorite, {{errorMessage}}`, {
              errorMessage: e.message,
            }),
        {
          variant: 'error',
          autoHideDuration: 12000,
        },
      );
    },
  });
};

/**
 * Mutation hook that sets the `isDeleted` flag of a model bundle.
 *
 * The cached model list is updated in `onMutate`, before the request settles:
 * the bundle matching `modelId` and `threshold` gets the requested `isDeleted`
 * value. A success snackbar is shown only when the bundle was deleted
 * (`isDeleted === true`). If the request fails, the model-list query is
 * invalidated and an error snackbar is shown.
 *
 * @param modelName - Model name used in the snackbar messages.
 * @param thresholdForMessage - Threshold shown in the error snackbar; when
 *   `undefined`, the threshold passed to the mutation is shown instead.
 * @returns The react-query mutation object.
 */
export const useUpdateBundleIsDeletedMutation = (
  modelName: string,
  thresholdForMessage: number | undefined,
) => {
  const { enqueueSnackbar } = useSnackbar();
  const queryClient = useQueryClient();
  // Falls back to project id 0 when the store has no selected project.
  const selectedProjectId = useTypedSelector(state => state.project.selectedProjectId) ?? 0;
  return useMutation({
    mutationFn: async (params: { modelId: string; threshold: number; isDeleted: boolean }) => {
      const { modelId, threshold, isDeleted } = params;
      await project_model_api.updateModelBundleIsDeleted(
        selectedProjectId,
        modelId,
        threshold,
        isDeleted,
      );
      return params;
    },
    onMutate: params => {
      const { modelId, threshold, isDeleted } = params;
      queryClient.setQueryData(
        modelAnalysisQueryKeys.modelList(selectedProjectId),
        (prev?: RegisteredModelWithBundles[]) => {
          return prev?.map(model => {
            if (model.id !== modelId) {
              return { ...model };
            } else {
              const updatedBundles = model.bundles.map(bundle => {
                if (bundle.threshold !== threshold) {
                  return { ...bundle };
                } else {
                  // Mirror the requested flag so a restore (`isDeleted: false`)
                  // is not shown as a deletion.
                  return { ...bundle, isDeleted };
                }
              });
              return { ...model, bundles: updatedBundles };
            }
          });
        },
      );
    },
    onSuccess: params => {
      const { threshold, isDeleted } = params;
      if (isDeleted) {
        enqueueSnackbar(t(`${modelName},${threshold} is successfully deleted.`), {
          variant: 'success',
        });
      }
    },
    onError: (e: Error, params) => {
      // Avoid printing "undefined" when the caller supplied no display threshold.
      const failedThreshold = thresholdForMessage ?? params.threshold;
      // Invalidate the model list that `onMutate` modified ahead of the response.
      queryClient.invalidateQueries(modelAnalysisQueryKeys.modelList(selectedProjectId));
      enqueueSnackbar(
        t(`Failed to delete ${modelName},${failedThreshold}, {{errorMessage}}`, {
          errorMessage: e.message,
        }),
        {
          variant: 'error',
          autoHideDuration: 12000,
        },
      );
    },
  });
};
