import React, { createContext, useContext, useEffect } from 'react';
import { Updater, useImmer } from 'use-immer';
import { useGetSelectedProjectQuery } from '@/serverStore/projects';
import {
  DatasetGroupOptions,
  MediaStatusType,
  Pipeline,
  SelectMediaOption,
  ClientTrainingState,
  TrainMode,
  ClientAdvancedTrainingState,
  LabelType,
  ClientTrainLimits,
} from '@clef/shared/types';
import { useGetDatasetStatsQuery, datasetQueryKeys } from '@/serverStore/dataset';
import {
  useAutoSplitMutation,
  useGetDatasetFilterOptionsQuery,
  useDatasetExportedWithVersionsQuery,
} from '@/serverStore/dataset';
import { useDefectSelector } from '../../../store/defectState/actions';
import {
  autoSplitCoreAlgorithm,
  DefectDistribution,
  DefectDistributionWithAssignment,
  assignmentToDistributionToAssignSplitMapping,
  splitValueToNoDefectAssignment,
} from '@clef/shared/utils/auto_split_core_algorithm';
import { useWorkflowAssistantState } from '@/components/WorkflowAssistant/state';
import TransformsApi from '../../../api/transforms_api';
import DatasetAPI from '../../../api/dataset_api';
import { useDataBrowserState } from '../dataBrowserState';
import { useSnackbar } from 'notistack';
import { useProjectModelInfoQuery } from '@/serverStore/projectModels';
import { useCreateFastTrainingJobMutation, useGetModelArchSchemas } from '@/serverStore/train';
import { useQueryClient } from '@tanstack/react-query';
import { convertLabelTypeToExperimentType } from '@/utils/project_utils';

/**
 * Initial training state used both to seed the provider's immer state and as
 * the fallback value of `TrainingStateContext`.
 *
 * It holds one entry per train mode (`TrainMode.Default` and
 * `TrainMode.Advanced`) plus `pipelineSections`, which starts out undefined.
 */
export const defaultState = {
  [TrainMode.Default]: {
    // By default, train:dev:test = 80:20:0
    splitRadio: [80, 20, 0],
    // Starts undefined; TrainingStateProvider fills it in from the first model arch schema.
    limits: undefined,
  } as ClientTrainingState,
  [TrainMode.Advanced]: {
    // By default, train:dev:test = 80:20:0
    splitRadio: [80, 20, 0],
    transforms: [],
    augmentations: [],
    defaultParameters: {},
    // Wait for useMaglevFastModelClassConfigApi to fulfill
    trainingParams: undefined,
    limits: undefined,
    currentSchema: undefined,
  } as ClientAdvancedTrainingState,
  // Not set initially; the cast widens the type so sections can be assigned later.
  pipelineSections: undefined as undefined | Pipeline['sections'],
};

/** Shape of the training state, derived from `defaultState` so the two cannot drift apart. */
export type TrainModelState = typeof defaultState;

/** Value carried by `TrainingStateContext`: the current state and its immer updater. */
type TrainingStateContextValue = {
  state: TrainModelState;
  dispatch: Updater<TrainModelState>;
};

/** Fallback updater used when no provider is mounted; it does nothing. */
const noopTrainingDispatch: Updater<TrainModelState> = () => {};

/**
 * Context
 *
 * Without an enclosing provider, consumers receive `defaultState` and a
 * dispatcher that ignores every update.
 */
export const TrainingStateContext = createContext<TrainingStateContextValue>({
  state: defaultState,
  dispatch: noopTrainingDispatch,
});

/** Returns the `{ state, dispatch }` pair exposed by the nearest `TrainingStateContext`. */
export const useTrainingState = () => {
  const trainingContext = useContext(TrainingStateContext);
  return trainingContext;
};

/**
 * TrainingStateProvider component with initial setup
 *
 * Owns the immer-backed training state, publishes it (with its updater)
 * through `TrainingStateContext`, and renders `children` as a render prop that
 * receives the current state.
 *
 * Once the selected project's label type, the project model info and the model
 * arch schemas are all available, the dataset limits of the first schema are
 * copied into the default-mode state, and their `minLabeledMedia` is pushed to
 * the workflow assistant as `labelingRequiredPoints`.
 */
export const TrainingStateProvider: React.FC<{
  children: (state: TrainModelState) => React.ReactNode;
}> = ({ children }) => {
  const [state, dispatch] = useImmer<TrainModelState>(defaultState);

  const { data: selected } = useGetSelectedProjectQuery();
  const { data: modelInfo } = useProjectModelInfoQuery(selected?.id);
  const { data: modelArchSchemas } = useGetModelArchSchemas();

  const { dispatch: workflowAssistantDispatch } = useWorkflowAssistantState();

  // default train initialization
  useEffect(() => {
    if (!modelInfo || !selected?.labelType) {
      return;
    }

    // The first schema supplies the limits. A schema list that is missing,
    // empty, or lacks limits leaves the state untouched instead of throwing
    // on `modelArchSchemas[0]`.
    const datasetLimits = modelArchSchemas?.[0]?.datasetLimits;
    if (!datasetLimits) {
      return;
    }

    dispatch(draft => {
      draft.default.limits = datasetLimits;
    });

    workflowAssistantDispatch(assistantDraft => {
      assistantDraft.labelingRequiredPoints = datasetLimits.minLabeledMedia;
    });
    // NOTE(review): `state` is not read inside the effect but is kept in the
    // dependency list, so the effect re-runs after every state change —
    // confirm whether that re-sync is intentional before removing it.
  }, [dispatch, selected, state, modelInfo, workflowAssistantDispatch, modelArchSchemas]);

  return (
    <TrainingStateContext.Provider value={{ state, dispatch }}>
      {children(state)}
    </TrainingStateContext.Provider>
  );
};

/**
 * Creates a pipeline through `TransformsApi.createPipeline` and resolves to
 * the `data` field of its response.
 *
 * When the `train` section is empty no request is made and the promise
 * resolves to `undefined`.
 *
 * @param apiPipelineSections sections to store in the pipeline
 * @param projectId project the pipeline is created for
 * @param jobName prefix of the pipeline file name (`<jobName>-pipeline.yaml`)
 */
const getPipelineId = async (
  apiPipelineSections: Pipeline['sections'],
  projectId: number,
  jobName: string,
) => {
  const hasTrainSteps = apiPipelineSections.train.length > 0;
  if (!hasTrainSteps) {
    return undefined;
  }

  const pipelineFileName = `${jobName}-pipeline.yaml`;
  const { data } = await TransformsApi.createPipeline(
    projectId,
    pipelineFileName,
    apiPipelineSections,
  );
  return data;
};

/**
 * Train function
 *
 * Hook that returns an async callback which starts a training job for the
 * selected project. The callback:
 *  1. optionally auto-splits the approved media that are not yet in a split,
 *     using the `splitRadio` from the given state;
 *  2. checks that the train split holds at least 2 images;
 *  3. in advanced mode, creates the preprocessing and augmentation pipelines;
 *  4. creates the training job and invalidates the affected dataset queries.
 *
 * Callback parameters:
 *  - `state`: source of `splitRadio`, `trainingParams`, `transforms` and
 *    `augmentations`.
 *  - `mode`: in `TrainMode.Default` outcomes are reported through snackbars
 *    and errors are swallowed; in any other mode errors are rethrown.
 *  - `enableAutoSplit`: when true (default) and no `version` is given, the
 *    auto split step runs.
 *  - `version`: dataset version used for the stats check and to look up the
 *    `datasetVersionId` sent in advanced mode.
 *  - `modelName`, `modelDetails`: forwarded to the job in advanced mode only.
 *
 * The callback resolves without doing anything when the selected project has
 * no `datasetId`.
 */
export const useTrainModel = () => {
  const { data: selectedProject } = useGetSelectedProjectQuery();
  const { data: allFilters } = useGetDatasetFilterOptionsQuery();
  const {
    state: { rightSidebar },
    dispatch: dispatchDataBrowserState,
  } = useDataBrowserState();
  const postCreateFastTrainingJob = useCreateFastTrainingJobMutation();

  const queryClient = useQueryClient();

  // 'Split' is in columnFilterMap, the project has split in datasetContent's splitSet column
  // 'split' is in fieldFilterMap, the project has split metadata
  // Always prefer to use 'Split' column in case user manually creates 'split' or 'Split' metadata
  const splitFilterOption = allFilters
    ? allFilters.find(value => value.filterName === 'Split' && value.filterType === 'column') ||
      allFilters.find(value => value.filterName === 'split')
    : undefined;
  const splitFilter = splitFilterOption?.value ? Object.keys(splitFilterOption.value) : [];
  const splitFilterByIds: number[] =
    splitFilterOption?.filterType === 'column'
      ? splitFilter.map(s => splitFilterOption.value![s] as number)
      : [];

  // Selects approved media that do not carry any of the known split values.
  // Stays undefined (which also disables the stats queries below) until a
  // split filter option is available.
  const selectMediaOption: SelectMediaOption | undefined = splitFilterOption
    ? {
        fieldFilterMap:
          splitFilterOption.filterType === 'field' && splitFilter?.length
            ? { [splitFilterOption.fieldId!]: { NOT_CONTAIN_ANY: splitFilter } }
            : {},
        columnFilterMap: {
          datasetContent: {
            mediaStatus: { CONTAINS_ANY: [MediaStatusType.Approved] },
            ...(splitFilterOption.filterType === 'column' &&
              splitFilterByIds.length && { splitSet: { NOT_CONTAIN_ANY: splitFilterByIds } }),
          },
        },
        selectedMedia: [],
        unselectedMedia: [],
        isUnselectMode: true,
      }
    : undefined;

  const { data: mediaGroupByDefectDistribution } = useGetDatasetStatsQuery(
    {
      selectOptions: selectMediaOption,
      groupOptions: [DatasetGroupOptions.DEFECT_DISTRIBUTION],
    },
    !!selectMediaOption,
  );

  const { data: mediaGroupByDefectType } = useGetDatasetStatsQuery(
    {
      selectOptions: selectMediaOption,
      groupOptions: [DatasetGroupOptions.DEFECT_TYPE],
    },
    !!selectMediaOption,
  );

  const allDefects = useDefectSelector();

  // filter out defects that do not have a count
  const defectList =
    allDefects && mediaGroupByDefectType
      ? allDefects.filter(({ id }) =>
          mediaGroupByDefectType.some(defectType => defectType.defect_id === id),
        )
      : undefined;

  const { enqueueSnackbar } = useSnackbar();
  const autoSplit = useAutoSplitMutation();

  const { data: datasetExported } = useDatasetExportedWithVersionsQuery({
    withCount: true,
    includeNotCompleted: true,
    includeFastEasy: true,
  });

  return async (
    state: ClientTrainingState,
    mode: TrainMode,
    enableAutoSplit: boolean = true,
    version?: number,
    modelName?: string,
    modelDetails?: string,
  ) => {
    if (!selectedProject?.datasetId) return;
    const { splitRadio, trainingParams, transforms, augmentations } = state;

    const currentSnapshot = datasetExported?.datasetVersions?.find(v => v.version === version);
    const datasetVersionId = currentSnapshot?.id;

    try {
      /**
       * Step 1: Auto split with splitRadio in state
       */
      if (defectList && mediaGroupByDefectDistribution && enableAutoSplit && !version) {
        const defectIdList: number[] = defectList.map(defect => defect.id);
        const defectDistributions: DefectDistribution[] = (
          mediaGroupByDefectDistribution as DefectDistribution[]
        ).filter(stats => stats.count && stats.defect_distribution);
        const numberOfMediaWithNoDefects =
          mediaGroupByDefectDistribution.find(stats => !stats.defect_distribution)?.count || 0;

        // Every defect gets the same target ratio. The map is filled in place
        // rather than re-spread on each iteration.
        const splitRatioByDefectId: { [defectId: number]: [number, number, number] } = {};
        for (const defectId of defectIdList) {
          splitRatioByDefectId[defectId] = splitRadio;
        }

        const distributionWithDefectResult = autoSplitCoreAlgorithm(
          defectIdList,
          defectDistributions,
          splitRatioByDefectId,
        );
        // Media without any defect are assigned separately and appended as
        // one extra entry with a null distribution.
        const defectDistributionWithAssignment: DefectDistributionWithAssignment[] = [
          ...distributionWithDefectResult,
          ...(numberOfMediaWithNoDefects
            ? [
                {
                  count: numberOfMediaWithNoDefects,
                  defect_distribution: null,
                  assignment: splitValueToNoDefectAssignment(
                    numberOfMediaWithNoDefects,
                    splitRadio,
                  ),
                },
              ]
            : []),
        ];

        await autoSplit.mutateAsync({
          selectOptions: selectMediaOption!,
          splitByDefectDistribution: assignmentToDistributionToAssignSplitMapping(
            defectDistributionWithAssignment,
          ),
        });
      }

      /**
       * Step 1.1: check image count in train set
       */
      const stats = await DatasetAPI.getStats(
        selectedProject.datasetId,
        selectedProject.id,
        {},
        [DatasetGroupOptions.SPLIT],
        version,
      );

      const trainSplit = stats.find(split => split.split === 'train');

      if (!trainSplit || trainSplit.count < 2) {
        enqueueSnackbar(
          t(
            `At least 2 images are required in the train split; there are ${
              trainSplit?.count || 0
            } images.`,
          ),
          {
            variant: 'error',
            autoHideDuration: 12000,
          },
        );
        return;
      }

      /**
       * Step 2: TODO: Create transform and augmentation pipeline
       */
      let updatedTrainingParams = {};
      const projectId = selectedProject.id;
      if (mode === TrainMode.Advanced) {
        const apiTransformPipelineSections = {
          train: transforms ?? [],
          valid: [],
        };
        const apiAugmentationsPipelineSections = {
          train: augmentations ?? [],
          valid: [],
        };
        // The two pipelines are independent, so they are created in parallel.
        const [preprocessingPipelineId, augmentationPipelineId] = await Promise.all([
          getPipelineId(apiTransformPipelineSections, projectId, 'preprocessing'),
          getPipelineId(apiAugmentationsPipelineSections, projectId, 'augmentation'),
        ]);
        updatedTrainingParams = {
          // ...trainingParams,
          projectId,
          hyperParams: {
            ...trainingParams?.hyperParams,
            preprocessingPipelineId,
            augmentationPipelineId,
          },
          datasetVersionId,
          modelName,
          modelDetails,
          experimentType: convertLabelTypeToExperimentType(selectedProject.labelType!),
          archName: trainingParams?.hyperParams?.model.archName,
        };
      } else {
        updatedTrainingParams = {
          projectId,
          experimentType: convertLabelTypeToExperimentType(selectedProject.labelType!),
        };
      }

      /**
       * Step 3: Kickoff train model and update projectModelInfo
       */
      await postCreateFastTrainingJob.mutateAsync(updatedTrainingParams);
      if (rightSidebar === undefined) {
        dispatchDataBrowserState(draft => {
          draft.rightSidebar = 'model_performance';
        });
      }
      if (mode === TrainMode.Default) {
        enqueueSnackbar(t('Start Training.'), { variant: 'success' });
      }
      if (datasetVersionId) {
        queryClient.invalidateQueries(
          datasetQueryKeys.snapshotModels(projectId, { datasetVersionId }),
        );
      }
      queryClient.invalidateQueries(
        datasetQueryKeys.exportedWithVersions(projectId, {
          withCount: true,
          includeNotCompleted: true,
          includeFastEasy: true,
        }),
      );
    } catch (e) {
      if (mode !== TrainMode.Default) {
        // let the parallel process handle the error
        throw e;
      }
      // A thrown value is not guaranteed to be an Error. Read `message`
      // defensively and fall back to the stringified value, so the snackbar
      // never shows `undefined` and the catch block itself cannot throw.
      const rawMessage = (e as { message?: unknown } | null | undefined)?.message;
      const message = typeof rawMessage === 'string' && rawMessage ? rawMessage : String(e);
      enqueueSnackbar(message, { variant: 'error' });
    }
  };
};

/**
 * Tells whether the large-image warning should be shown for the image size
 * configured in the train pipeline.
 *
 * The size is read from the first train section that has a `width` param,
 * together with that section's `height` param. The warning applies when
 * `width * height` is above `limits.largeImage.thresholdArea`, is at most
 * `limits.largeImage.maxArea`, and the selected project's label type is
 * Segmentation or BoundingBox.
 *
 * @param pipelineSections pipeline sections to inspect; `false` when absent
 * @param limitsProp train limits; `false` when absent or without `largeImage`
 * @returns `true` when the warning applies; `false` otherwise, including when
 *   no usable width/height pair is found
 */
export const useIsShowWarnAlert = (
  pipelineSections?: Pipeline['sections'],
  limitsProp?: ClientTrainLimits,
) => {
  // Hooks must run on every render, so this is called before any early return.
  const { labelType } = useGetSelectedProjectQuery().data ?? {};

  const largeImage = limitsProp?.largeImage;
  if (!pipelineSections || !largeImage) return false;

  const sizedSection = pipelineSections.train.find(section =>
    section.params.some(param => param.name === 'width'),
  );
  if (!sizedSection) return false;

  const width = sizedSection.params.find(param => param.name === 'width')?.value;
  const height = sizedSection.params.find(param => param.name === 'height')?.value;
  // A section with a width but no height (or with empty values) cannot yield
  // an area; report "no warning" instead of throwing on the missing param.
  if (width == null || height == null) return false;

  const area = width * height;
  return (
    area > largeImage.thresholdArea &&
    area <= largeImage.maxArea &&
    (labelType === LabelType.Segmentation || labelType === LabelType.BoundingBox)
  );
};

/**
 * Reports whether everything needed to initialize the custom training
 * configuration is available: the project model info, the model arch schemas
 * and the selected project's label type.
 */
export const useGotDataForCustomTraining = () => {
  const { data: project } = useGetSelectedProjectQuery();
  const { data: modelInfo } = useProjectModelInfoQuery(project?.id);
  const { data: modelArchSchemas } = useGetModelArchSchemas();

  const hasModelInfo = Boolean(modelInfo);
  const hasSchemas = Boolean(modelArchSchemas);
  const hasLabelType = Boolean(project?.labelType);
  return hasModelInfo && hasSchemas && hasLabelType;
};
