import React from 'react';
import { useQuery, useQueryClient } from '@tanstack/react-query';
import ProjectModelAPI from '@/api/project_model_api';
import {
  ProjectId,
  CurrentModelInfo,
  RegisteredModel,
  ModelStatus,
  MediaDetailsWithPrediction,
  LabelType,
} from '@clef/shared/types';
import project_model_api, { ModelStatusResponse } from '@/api/project_model_api';
import { isModelInEndState, isModelTrainingSuccessful } from '@/store/projectModelInfoState/utils';
import { useGetSelectedProjectQuery } from '../projects';
import { useSnackbar } from 'notistack';
import { ApiErrorType } from '@/api/base_api';
import { datasetQueryKeys } from '../dataset';
import { modelAnalysisQueryKeys } from '../modelAnalysis';
import { layoutQueryKeys } from '../layout';
import { evaluationSetQueryKeys } from '../evaluationSets';
import { jobQueryKeys } from '../jobs';
import { Typography, makeStyles } from '@material-ui/core';
import CLEF_PATH from '@/constants/path';
import { useHistory } from 'react-router';
import SnackbarCloseButton from '@/components/SnackbarCloseButton';
import { useModels } from '@/hooks/useModels';
import { Button, useLocalStorage } from '@clef/client-library';
import { endpointQueryKeys } from '../endpoints';
import { useModelAnalysisEnabled } from '@/hooks/useFeatureGate';
import { getDateNumber } from '@clef/shared/utils';
import { useDatasetExportedWithVersionsQuery } from '@/serverStore/dataset';
import { SelectedModelIdForProjectsStorageKey } from '@/constants/data_browser';
import { useSetAtom } from 'jotai';
import { thresholdForPredictAtom } from '@/uiStates/projectModels/pageUIStates';

/** Root segment shared by every project-model query key. */
const projectModelsBaseKey = ['projectModels'] as const;

/**
 * React Query key factory for project-model queries.
 *
 * Every generated key starts with the project id followed by `'projectModels'`, so all
 * model queries of one project share the `[projectId, 'projectModels']` prefix.
 * NOTE(review): `all` itself is not a prefix of these keys (they start with the project id),
 * so invalidating `projectModelQueryKeys.all` would not match them.
 */
export const projectModelQueryKeys = {
  all: projectModelsBaseKey,
  list: (projectId: ProjectId) => [projectId, ...projectModelsBaseKey, 'list'] as const,
  modelInfo: (projectId: ProjectId) =>
    [projectId, ...projectModelsBaseKey, 'modelInfo'] as const,
  modelStatus: (projectId?: ProjectId, modelId?: string) =>
    [projectId, ...projectModelsBaseKey, 'status', modelId] as const,
};

/**
 * Loads the model info of `projectId` through `ProjectModelAPI.getProjectModelInfo` and
 * resolves to the response's `data` payload.
 *
 * @param projectId - Project to query; the query stays idle while it is falsy (including the
 *   default `0`).
 */
export const useProjectModelInfoQuery = (projectId: ProjectId = 0) => {
  const fetchModelInfo = async () => {
    const { data } = await ProjectModelAPI.getProjectModelInfo(projectId);
    return data;
  };

  return useQuery({
    queryKey: projectModelQueryKeys.modelInfo(projectId),
    queryFn: fetchModelInfo,
    enabled: Boolean(projectId),
  });
};

/**
 * Fetches the model list of the currently selected project.
 *
 * @param isEnabled - Extra switch to keep the query idle. The query is also idle until a
 *   project is selected (project id falls back to `0`).
 */
export const useGetProjectModelListQuery = (isEnabled = true) => {
  const { data: selectedProject } = useGetSelectedProjectQuery();
  const { id: projectId = 0 } = selectedProject ?? {};
  // Same key, fetcher and `enabled` rule as the explicit-project variant, so reuse it.
  return useGetProjectModelListWithProjectIdQuery(projectId, isEnabled);
};

/**
 * Fetches the model list of an explicitly given project.
 *
 * Uses the same `projectModelQueryKeys.list` key as `useGetProjectModelListQuery`, so both
 * hooks share one cache entry for the same project.
 *
 * @param projectId - Project whose models to load; the query stays idle while it is falsy.
 * @param isEnabled - Extra switch to keep the query idle.
 */
const useGetProjectModelListWithProjectIdQuery = (projectId: ProjectId, isEnabled = true) => {
  const enabled = Boolean(projectId) && isEnabled;
  return useQuery<RegisteredModel[], ApiErrorType>({
    queryKey: projectModelQueryKeys.list(projectId),
    queryFn: () => ProjectModelAPI.getModels(projectId),
    enabled,
  });
};

/**
 * Picks the most recently created model (largest `createdAt`) without reordering `models`.
 *
 * The list handed in comes straight from the react-query cache, so sorting it in place would
 * silently reorder the cached data seen by every other consumer of the list query. A single
 * linear scan also avoids an O(n log n) sort when only the maximum is needed. On `createdAt`
 * ties the earliest entry in list order wins, which is what a stable descending sort followed
 * by `[0]` returns.
 *
 * @param models - Model list, possibly not loaded yet.
 * @returns The latest model, or `undefined` when the list is missing or empty.
 */
const findLatestCreatedModel = (
  models: readonly RegisteredModel[] | undefined,
): RegisteredModel | undefined =>
  models?.reduce<RegisteredModel | undefined>(
    (latest, model) =>
      latest === undefined || getDateNumber(model.createdAt) > getDateNumber(latest.createdAt)
        ? model
        : latest,
    undefined,
  );

/**
 * This query is used for VP projects default selected model in build page.
 *
 * @returns Loading/error state of the selected project's model list, plus the most recently
 *   created model in it (`null` when the list is unavailable or empty).
 */
export const useVPLatestModel = (): {
  loading: boolean;
  latestModel: RegisteredModel | null;
  error: ApiErrorType | null;
} => {
  const {
    data: projectModelsData,
    isLoading: projectModelsLoading,
    error: projectModelsError,
  } = useGetProjectModelListQuery();

  return {
    loading: projectModelsLoading,
    error: projectModelsError,
    latestModel: findLatestCreatedModel(projectModelsData) ?? null,
  };
};

/**
 * Returns the most recently created model of an explicitly given project.
 *
 * @param projectId - Project whose model list is loaded (query idle while falsy).
 * @returns The latest model, or `undefined` while loading or when the project has no models.
 */
export const useExampleProjectModelInfoQuery = (
  projectId: ProjectId,
): RegisteredModel | undefined => {
  const { data: projectModelsData } = useGetProjectModelListWithProjectIdQuery(projectId);
  return findLatestCreatedModel(projectModelsData);
};

/**
 * Returns the model currently selected for the selected project.
 *
 * - For `LabelType.SegmentationInstantLearning` projects this is the most recently created
 *   model (via `useVPLatestModel`), reduced to its `id`.
 * - Otherwise it is the model whose id is stored in local storage under
 *   `SelectedModelIdForProjectsStorageKey` for this project. The model is extended with
 *   `versionedDatasetContentId`, the `version` of the dataset version matching its
 *   `datasetVersionId`. An empty object is returned when no stored id matches a listed model.
 *
 * Every hook is called unconditionally so the hook order is stable regardless of label type.
 */
export const useCurrentProjectModelInfoQuery = (): CurrentModelInfo => {
  const { id: projectId, labelType } = useGetSelectedProjectQuery().data ?? {};

  const [selectedModelIdForProjects] = useLocalStorage<{
    [projectId: number]: string;
  }>(SelectedModelIdForProjectsStorageKey);
  const { data: projectModelsData } = useGetProjectModelListQuery();
  const currentModelInfo = projectId
    ? projectModelsData?.find(item => item.id === selectedModelIdForProjects?.[projectId])
    : undefined;

  const { data: datasetExported } = useDatasetExportedWithVersionsQuery({
    includeNotCompleted: true,
    includeFastEasy: true,
  });
  const projectModelInfo = currentModelInfo
    ? ({
        ...currentModelInfo,
        // `undefined` when the model's dataset version is not in the exported list.
        versionedDatasetContentId: datasetExported?.datasetVersions?.find(
          e => e.id === currentModelInfo.datasetVersionId,
        )?.version,
      } as CurrentModelInfo)
    : {};

  const { latestModel: vpLatestModel } = useVPLatestModel();
  // An object literal is never nullish, so no `?? {}` fallback is needed on the VP branch.
  return labelType === LabelType.SegmentationInstantLearning
    ? { id: vpLatestModel?.id }
    : projectModelInfo;
};

/**
 * JSS classes for this module.
 * `link` is a primary-colored inline link that darkens and underlines on hover;
 * `snackbarButtonText` renders snackbar action buttons in white.
 */
const useStyles = makeStyles(({ palette }) => {
  const { primary, common } = palette;
  return {
    link: {
      verticalAlign: 'baseline',
      color: primary.main,
      '&:hover': {
        color: primary.dark,
        textDecoration: 'underline',
        cursor: 'pointer',
      },
    },
    snackbarButtonText: {
      color: common.white,
    },
  };
});

/**
 * Polls the training status of `modelId` in `projectId` and reacts to status transitions.
 *
 * The query is enabled only when both ids are set. It refetches every 1ms until the response
 * reaches an end state per `isModelInEndState`, or until the query errors.
 *
 * A transition is detected when the new response's `updatedAt` differs from the `updatedAt` of
 * the response already cached under this query key. On a transition to:
 * - `ModelStatus.Failed`: invalidates the model list and shows an error snackbar. Its
 *   "View Log" action stores the current time for this project in the `failedJobLastSeenDate`
 *   local-storage map and navigates to the failed-jobs page for the model.
 * - `ModelStatus.Stopped`: shows a warning snackbar.
 * - a successful training (`isModelTrainingSuccessful`): shows a success snackbar, with a
 *   "View Report" action when model analysis is enabled. It then stores `modelId` as the
 *   selected model for `projectId` in local storage and sets `predictionLabel` to null on cached
 *   media details. Finally it invalidates the dataset, model, analysis, layout, evaluation-set,
 *   job and endpoint queries listed below and resets the predict threshold atom.
 *
 * @param projectId - Project that owns the model; the query is disabled while it is falsy.
 * @param modelId - Model to poll; the query is disabled while it is falsy.
 * @returns The `useQuery` result; its data is the latest `ModelStatusResponse`.
 */
export const useModelStatusQuery = (projectId?: ProjectId, modelId?: string) => {
  const queryClient = useQueryClient();
  const { datasetId, orgId } = useGetSelectedProjectQuery().data ?? {};
  const enableModelAnalysis = useModelAnalysisEnabled();
  // Per-project timestamps; written below when the user clicks "View Log" on a failure snackbar.
  const [failedJobLastSeenDateMap, setFailedJobLastSeenDateMap] =
    useLocalStorage<Record<ProjectId, string>>('failedJobLastSeenDate');
  const { findModels } = useModels();
  // Per-project selected model id; updated below when training succeeds.
  const [selectedModelIdForProjects, setSelectedModelIdForProjects] = useLocalStorage<{
    [projectId: number]: string;
  }>(SelectedModelIdForProjectsStorageKey);
  const setThresholdForPredict = useSetAtom(thresholdForPredictAtom);

  // Only the model name is used, for the success snackbar text.
  const model = findModels(modelId);
  const modelName = model?.modelName;
  const { enqueueSnackbar, closeSnackbar } = useSnackbar();
  const history = useHistory();
  const styles = useStyles();
  const res = useQuery({
    queryKey: projectModelQueryKeys.modelStatus(projectId, modelId),
    queryFn: async ctx => {
      // Not expected while the query is disabled (see `enabled`); bail out if run anyway.
      if (!projectId || !modelId) {
        return;
      }
      // pass last updatedAt for backend polling
      // NOTE(review): together with the 1ms refetch interval this suggests the backend holds the
      // request until `updatedAt` changes (long polling) — confirm against the API.
      const lastResponse = queryClient.getQueryData<ModelStatusResponse>(ctx.queryKey);
      const { updatedAt: lastUpdatedAt } = lastResponse ?? {};
      const newResponse = await project_model_api.getModelStatus(projectId, modelId, lastUpdatedAt);

      // handling status change
      const { status, metricsReady, updatedAt } = newResponse ?? {};
      // React only to real transitions: skip when nothing was cached before (e.g. the first
      // fetch) or when `updatedAt` did not change since the cached response.
      if (lastUpdatedAt && updatedAt !== lastUpdatedAt) {
        if (status === ModelStatus.Failed) {
          queryClient.invalidateQueries(projectModelQueryKeys.list(projectId));
          enqueueSnackbar(t('Model training failed.'), {
            variant: 'error',
            autoHideDuration: 12000,
            action: key => {
              return (
                <>
                  <Button
                    id="failed-view-log-button"
                    className={styles.snackbarButtonText}
                    onClick={() => {
                      projectId &&
                        setFailedJobLastSeenDateMap({
                          ...failedJobLastSeenDateMap,
                          [projectId]: new Date().toString(),
                        });
                      history.push(`${CLEF_PATH.modelsV2.failedJobs}?modelId=${modelId}`);
                      // retrainModel.mutate();
                      closeSnackbar(key);
                    }}
                  >
                    {t('View Log')}
                  </Button>
                  <SnackbarCloseButton snackbarKey={key} />
                </>
              );
            },
          });
        } else if (status === ModelStatus.Stopped) {
          enqueueSnackbar(t('Model training stopped.'), {
            variant: 'warning',
            autoHideDuration: 12000,
          });
        } else if (isModelTrainingSuccessful(status, metricsReady)) {
          enableModelAnalysis
            ? enqueueSnackbar(
                <Typography>
                  {t('{{modelName}} training is completed.', {
                    modelName: modelName || 'Model',
                  })}
                </Typography>,
                {
                  variant: 'success',
                  autoHideDuration: 12000,
                  action: key => {
                    return (
                      <>
                        <Button
                          id="succeed-view-report-button"
                          className={styles.snackbarButtonText}
                          onClick={() => {
                            history.push(`${CLEF_PATH.modelsV2.list}?modelId=${modelId}`);
                            closeSnackbar(key);
                          }}
                        >
                          {t('View Report')}
                        </Button>
                        <SnackbarCloseButton snackbarKey={key} />
                      </>
                    );
                  },
                },
              )
            : enqueueSnackbar(<Typography>{t('Model training is completed.')}</Typography>, {
                variant: 'success',
                autoHideDuration: 12000,
              });

          // Make the newly trained model the selected model for this project.
          setSelectedModelIdForProjects({
            ...selectedModelIdForProjects,
            [projectId]: modelId,
          });
          // Clear predictions on cached media details right away, then refetch them.
          datasetId &&
            queryClient.setQueriesData<MediaDetailsWithPrediction>(
              datasetQueryKeys.mediaDetails(datasetId),
              prev => (prev ? { ...prev, predictionLabel: null } : prev),
            );
          datasetId && queryClient.invalidateQueries(datasetQueryKeys.mediaDetails(datasetId));
          projectId && queryClient.invalidateQueries(datasetQueryKeys.allWithFilters(projectId));
          queryClient.invalidateQueries(datasetQueryKeys.exportedWithVersions(projectId));
          queryClient.invalidateQueries(projectModelQueryKeys.list(projectId));
          queryClient.invalidateQueries(projectModelQueryKeys.modelInfo(projectId));
          queryClient.invalidateQueries(modelAnalysisQueryKeys.metrics(projectId, modelId));
          queryClient.invalidateQueries(modelAnalysisQueryKeys.mediaList(projectId, modelId));
          queryClient.invalidateQueries(layoutQueryKeys.list(projectId));
          queryClient.invalidateQueries(modelAnalysisQueryKeys.modelList(projectId));
          queryClient.invalidateQueries(evaluationSetQueryKeys.list(projectId));
          orgId && queryClient.invalidateQueries(jobQueryKeys.jobDetail(orgId, projectId, modelId));
          queryClient.invalidateQueries(endpointQueryKeys.bundleList(projectId));
          // Reset the predict threshold. NOTE(review): presumably `undefined` means "fall back to
          // the new model's default threshold" — confirm in pageUIStates.
          setThresholdForPredict(undefined);
        }
      }
      // end of handling status change

      return newResponse;
    }, // end of queryFn
    // polling
    refetchInterval: (data, query) => {
      // stop polling on any error
      if (query.state.error) {
        return false;
      }
      // stop polling if the model training ended
      const { status, metricsReady } = data ?? {};
      if (isModelInEndState(status, metricsReady)) {
        return false;
      }
      // we expected to refetch after 0ms, but 0 means stop polling in react query.
      // so we set 1ms here.
      return 1;
    },
    enabled: Boolean(projectId && modelId),
  });

  return res;
};

/**
 * Exposes the selected project's model list together with its "saved" subset.
 *
 * @returns `loading`/`error` of the list query, `models` (the full list, possibly undefined)
 *   and `savedModels`, which contains only the models with a truthy `modelName`.
 */
export const useValidProjectModels = () => {
  const { data: models, isLoading, error } = useGetProjectModelListQuery();

  return {
    loading: isLoading,
    error,
    models,
    savedModels: models?.filter(({ modelName }) => modelName),
  };
};

/**
 * Exposes the selected project's model list together with its successfully trained subset.
 *
 * @returns `models`, the full list (possibly undefined), and `succeedModels`, the models for
 *   which `isModelTrainingSuccessful` holds, ordered by `updatedAt` ascending (oldest first).
 */
export const useSucceedProjectModels = () => {
  const { data: projectModelsData } = useGetProjectModelListQuery();

  // `filter` already yields a fresh array, so sorting it does not touch the cached list.
  // `Array.prototype.sort` needs a two-argument comparator; a one-argument key function would be
  // read as "a > b" for every pair and impose no order.
  // NOTE(review): ascending order mirrors the key-function (`sortBy`) intent — confirm with callers.
  const succeedModels = projectModelsData
    ?.filter(model => isModelTrainingSuccessful(model.status, model.metricsReady))
    .sort((a, b) => getDateNumber(a.updatedAt) - getDateNumber(b.updatedAt));

  return {
    models: projectModelsData,
    succeedModels,
  };
};
