import { ApiErrorType } from '@/api/base_api';
import { DatasetVersionId, ProjectId, RegisteredModel } from '@clef/shared/types';
import { useQuery } from '@tanstack/react-query';
import EvaluationSetAPI, { EvaluationSetItem } from '@/api/evaluation_set_api';
import { useGetSelectedProjectQuery } from '../projects';
import { MediaSplitName } from '@/constants/stats_card';
import { useGetModelMediaListInfiniteQuery } from '../modelAnalysis';

/**
 * Query-key factory for evaluation-set queries.
 *
 * `all` is the shared root segment. `list(projectId)` produces the key of a project's
 * evaluation-set list, prefixed by the project id.
 */
export const evaluationSetQueryKeys = {
  all: ['evaluationSet'] as const,
  list: (projectId: ProjectId) => {
    const { all: rootKey } = evaluationSetQueryKeys;
    return [projectId, ...rootKey, 'list'] as const;
  },
};

/**
 * Fetches the evaluation sets (with model info) of the currently selected project.
 *
 * Items come back ordered by `createdAt` ascending, so older sets are displayed first.
 * Sets with equal `createdAt` keep the order the API returned them in, because the
 * comparator reports ties and `Array.prototype.sort` is stable.
 *
 * The query is disabled until a project is selected (a missing id falls back to `0`,
 * which is falsy) and while `isEnabled` is false.
 *
 * @param isEnabled - lets callers defer the request; defaults to `true`.
 */
export const useGetProjectEvaluationSetsListQuery = (isEnabled = true) => {
  const { id: projectId = 0 } = useGetSelectedProjectQuery().data ?? {};
  return useQuery<EvaluationSetItem[], ApiErrorType>({
    queryKey: evaluationSetQueryKeys.list(projectId),
    queryFn: async () => {
      const response = await EvaluationSetAPI.getProjectEvaluationSetModelInfo(projectId);
      // Sort a copy so the API response object is not mutated in place. The comparator
      // must return 0 for equal timestamps. A `? 1 : -1` comparator is inconsistent
      // (it says a < b and b < a at once), which leaves the resulting order unspecified.
      return [...response.data].sort((a, b) => {
        if (a.createdAt > b.createdAt) return 1;
        if (a.createdAt < b.createdAt) return -1;
        return 0;
      }); // old set displayed first
    },
    enabled: !!projectId && isEnabled,
  });
};

/**
 * For each of the train/dev/test splits, selects the first evaluation set of the selected
 * project whose split name matches and whose `datasetVersionId` equals `datasetVersionId`.
 *
 * `data` is always an object. Each split entry is `undefined` while the list has no data
 * or when no set matches. `isLoading` mirrors the underlying list query.
 */
const useGetDefaultEvaluationSetsQuery = (datasetVersionId: DatasetVersionId) => {
  const evaluationSetsQuery = useGetProjectEvaluationSetsListQuery();
  const { data: evaluationSets } = evaluationSetsQuery;

  const findSetForSplit = (splitName: (typeof MediaSplitName)[keyof typeof MediaSplitName]) =>
    evaluationSets?.find(
      candidate =>
        candidate.split?.splitSetName === splitName &&
        candidate.datasetVersionId === datasetVersionId,
    );

  return {
    data: {
      [MediaSplitName.Train]: findSetForSplit(MediaSplitName.Train),
      [MediaSplitName.Dev]: findSetForSplit(MediaSplitName.Dev),
      [MediaSplitName.Test]: findSetForSplit(MediaSplitName.Test),
    },
    isLoading: evaluationSetsQuery.isLoading,
  };
};

/**
 * Reports, per train/dev/test split, the `total` from the first page of the model-media
 * list query run against that split's default evaluation set for `model`.
 *
 * A split's count is `undefined` until its first page has arrived. `isLoading` is true
 * while the default-set lookup or any of the three media-list queries is loading.
 *
 * NOTE(review): `model.datasetVersionId` and `model.confidence` are non-null asserted;
 * callers presumably only pass trained models where both are set — confirm.
 */
export const useGetDefaultEvaluationSetsCountQuery = (model: RegisteredModel) => {
  const { id: modelId } = model;
  const confidence = model.confidence!;

  // Hook order is fixed: default-set lookup first, then train, dev, test media lists.
  const { data: defaultSets, isLoading: isDefaultSetsLoading } = useGetDefaultEvaluationSetsQuery(
    model.datasetVersionId!,
  );
  const {
    [MediaSplitName.Train]: trainSet,
    [MediaSplitName.Dev]: devSet,
    [MediaSplitName.Test]: testSet,
  } = defaultSets;

  const { data: trainPages, isLoading: isTrainLoading } = useGetModelMediaListInfiniteQuery(
    modelId,
    confidence,
    trainSet,
  );
  const { data: devPages, isLoading: isDevLoading } = useGetModelMediaListInfiniteQuery(
    modelId,
    confidence,
    devSet,
  );
  const { data: testPages, isLoading: isTestLoading } = useGetModelMediaListInfiniteQuery(
    modelId,
    confidence,
    testSet,
  );

  const isLoading = [isDefaultSetsLoading, isTrainLoading, isDevLoading, isTestLoading].some(
    Boolean,
  );

  return {
    data: {
      [MediaSplitName.Train]: trainPages?.pages[0]?.total,
      [MediaSplitName.Dev]: devPages?.pages[0]?.total,
      [MediaSplitName.Test]: testPages?.pages[0]?.total,
    },
    isLoading,
  };
};
