import { getClosestMultiple } from '@clef/shared/utils/math';
import {
  Transform,
  TransformParams,
  TransformParamValue,
  TransformId,
  Pipeline,
  Media,
  TrainingParamsModel,
  LabelType,
  TransformType,
  ProjectId,
  ClientTrainLimits,
  HyperParamsSchema,
  ModelArchSchemaItem,
} from '@clef/shared/types';
import { TransformUISchemaOption } from '../types/client';
import { TRANSFORMS_UI_SCHEMA } from '../constants/model_train';
import TransformsApi from '../api/transforms_api';
import { JSONSchema7Type } from 'json-schema';
import { isEmpty, isNil, some } from 'lodash';
import {
  ModelArchModelSizeDisplayName,
  ModelSizeToArchNameForTooltip,
  customTrainingDefaultModelArchName,
} from '../constants/model';
import { ModelArch } from '@clef/shared/types/model_arch';

/**
 * Snaps `fieldValue` so that it (or its distance to `siblingValue`) is a multiple of
 * `transformMultiples`, using `getClosestMultiple`.
 *
 * - With no (falsy) sibling: a multiple of 1 returns `fieldValue` unchanged; otherwise the
 *   closest multiple of `fieldValue` itself is returned.
 * - With a sibling: the gap between the two values is snapped, and the result is placed on
 *   the same side of `siblingValue` that `fieldValue` was on.
 *
 * @param fieldValue Value being edited.
 * @param transformMultiples Required multiple (default 1).
 * @param siblingValue Paired bound (e.g. the `min` of a `max` field); 0 counts as "no sibling".
 */
export const getMultiplesInput = (
  fieldValue: number,
  transformMultiples: number = 1,
  siblingValue: number = 0,
): number => {
  if (!siblingValue) {
    return transformMultiples === 1
      ? fieldValue
      : getClosestMultiple(fieldValue, transformMultiples);
  }

  const snappedSpan = getClosestMultiple(
    Math.abs(fieldValue - siblingValue),
    transformMultiples,
  );
  if (fieldValue > siblingValue) {
    return snappedSpan + siblingValue;
  }
  return siblingValue - snappedSpan;
};

/**
 * Maps a crop bound field to its paired bound: the first `min` becomes `max` when the name
 * contains `min`; otherwise the first `max` becomes `min` (e.g. `x_min` <-> `x_max`).
 */
export const getCropSibling = (fieldName: string): string => {
  const isMinField = fieldName.includes('min');
  if (isMinField) {
    return fieldName.replace('min', 'max');
  }
  return fieldName.replace('max', 'min');
};

/**
 * Reads the output width/height configured on a single pipeline step.
 *
 * - `Resize` / `RescaleWithPadding`: the `width` and `height` params.
 * - `Crop`: `x_max - x_min` and `y_max - y_min`.
 * - Any other (or unresolvable) transform: both fields are left `undefined` at runtime,
 *   despite the declared `number` type.
 *
 * @param transform Pipeline step (id + param values).
 * @param transformParams Transform definitions used to resolve the step's id to a name.
 */
export const getTransformDimensions = (
  transform: { id: TransformId; params: TransformParamValue[] },
  transformParams: Transform[],
): {
  width: number;
  height: number;
} => {
  const definition = transformParams.find(({ id }) => id === transform.id);
  const transformName = definition?.name || '';
  const paramValue = (paramName: string) =>
    transform.params.find(({ name }) => name === paramName)?.value;

  let width;
  let height;

  if (transformName === 'RescaleWithPadding' || transformName === 'Resize') {
    width = paramValue('width');
    height = paramValue('height');
  } else if (transformName === 'Crop') {
    width = paramValue('x_max') - paramValue('x_min');
    height = paramValue('y_max') - paramValue('y_min');
  }

  return { width, height };
};

/**
 * Maximum pixel area a transform may output.
 *
 * `largeImageLimits.maxArea` takes precedence when limits are supplied. Otherwise the area
 * depends on the label type: segmentation 10000², bounding box 6000², anything else 1500².
 */
export const getMaxTransformPixels = (
  labelType?: string | null,
  largeImageLimits?: ClientTrainLimits['largeImage'],
): number => {
  if (largeImageLimits) {
    return largeImageLimits.maxArea;
  }

  switch (labelType) {
    case LabelType.Segmentation:
      return 10000 * 10000;
    case LabelType.BoundingBox:
      return 6000 * 6000;
    default:
      return 1500 * 1500;
  }
};

/**
 * True when `key` should be rendered as a single-value slider. That excludes range bounds
 * (`*lower_limit` / `*upper_limit`) and the `p` / `always_apply` keys.
 */
export const isSingleSliderKey = (key: string): boolean => {
  if (key === 'p' || key === 'always_apply') {
    return false;
  }
  return !/(?:lower|upper)_limit$/.test(key);
};

/** True when `key` ends with `border_mode` or `interpolation` (rendered as a dropdown). */
export const isDropdownKey = (key: string): boolean =>
  /(?:border_mode|interpolation)$/.test(key);

/**
 * Splits the train-section steps into those whose transform name is a key of
 * `TRANSFORMS_UI_SCHEMA[1]` (transforms) and those in `TRANSFORMS_UI_SCHEMA[2]`
 * (augmentations). Steps keep their original relative order; steps in neither group are dropped.
 *
 * @param apiTransforms Transform definitions used to map step ids to names.
 * @param sections Pipeline sections; only `train` is read.
 */
export const groupTransforms = (
  apiTransforms: Transform[],
  sections: Pipeline['sections'],
): { sectionTransforms: TransformParams[]; sectionAugmentations: TransformParams[] } => {
  // Ids of the API transforms whose name is an own key of the given UI-schema group.
  const idsInGroup = (group: object) =>
    new Set(
      apiTransforms
        .filter(({ name }) => Object.prototype.hasOwnProperty.call(group, name))
        .map(({ id }) => id),
    );

  const transformIds = idsInGroup(TRANSFORMS_UI_SCHEMA[1]);
  const augmentationIds = idsInGroup(TRANSFORMS_UI_SCHEMA[2]);

  return {
    sectionTransforms: sections.train.filter(({ id }) => transformIds.has(id)),
    sectionAugmentations: sections.train.filter(({ id }) => augmentationIds.has(id)),
  };
};

/**
 * Merges the edited kind of step (transforms or augmentations) from `newSections` with the
 * untouched other kind from `oldSections`.
 *
 * @param newSections Sections containing the user's edits.
 * @param oldSections Previously saved sections.
 * @param sectionType Which kind of step was edited. Any value other than TRANSFORM or
 *   AUGMENTATION leaves everything as it was in `oldSections`.
 * @param transformParams Transform definitions used to classify steps (via `groupTransforms`).
 * @returns The merged sections (with `valid` reset to `[]` whenever a merge happens) and the
 *   resulting transform and augmentation lists.
 */
export const updateSections = (
  newSections: Pipeline['sections'],
  oldSections: Pipeline['sections'],
  sectionType: TransformType,
  transformParams: Transform[],
): {
  upgradedSections: Pipeline['sections'];
  upgradedTransforms: TransformParams[];
  upgradedAugmentations: TransformParams[];
} => {
  const newByType = groupTransforms(transformParams, newSections);
  const oldByType = groupTransforms(transformParams, oldSections);

  // NOTE(review): the TRANSFORM branch puts augmentations before transforms in `train`, while
  // the AUGMENTATION branch puts transforms first. Confirm which order the pipeline expects.
  if (sectionType === TransformType.TRANSFORM) {
    return {
      upgradedSections: {
        train: [...oldByType.sectionAugmentations, ...newByType.sectionTransforms],
        valid: [],
      },
      upgradedTransforms: newByType.sectionTransforms,
      upgradedAugmentations: oldByType.sectionAugmentations,
    };
  }

  if (sectionType === TransformType.AUGMENTATION) {
    return {
      upgradedSections: {
        train: [...oldByType.sectionTransforms, ...newByType.sectionAugmentations],
        valid: [],
      },
      upgradedTransforms: oldByType.sectionTransforms,
      upgradedAugmentations: newByType.sectionAugmentations,
    };
  }

  return {
    upgradedSections: oldSections,
    upgradedTransforms: oldByType.sectionTransforms,
    upgradedAugmentations: oldByType.sectionAugmentations,
  };
};

/**
 * Picks up to `count` distinct media uniformly at random, without replacement.
 *
 * The input array is NOT modified. The previous implementation sorted `media` in place with a
 * random comparator, which both reordered the caller's array and produced a biased shuffle.
 *
 * @param media Candidates to sample from.
 * @param count Maximum number of items to return (default 5). Non-positive values yield `[]`.
 * @returns At most `count` items, never more than `media.length`.
 */
export const getRandomMedia = (media: Media[], count = 5): Media[] => {
  const pool = [...media];
  const picked: Media[] = [];
  while (picked.length < count && pool.length > 0) {
    const index = Math.floor(Math.random() * pool.length);
    // splice removes the chosen item, so each medium can be picked at most once.
    picked.push(...pool.splice(index, 1));
  }
  return picked;
};

/**
 * Builds a fresh pipeline step for `transform` with every parameter set to its default.
 *
 * A default from `rule[paramName].default` wins; otherwise the param's `jsonSchema.default`
 * is used.
 *
 * @param transform Transform definition whose `paramsDescription` lists the params.
 * @param rule Optional per-param default overrides.
 * @param checkValueValid When true, returns `null` if any resolved value is exactly `null`.
 * @returns `{ name, id, params }`, or `null` (see `checkValueValid`).
 */
export const extractNewTransform = (
  transform: Transform,
  rule?: { [key: string]: { default: number } },
  checkValueValid?: boolean,
) => {
  const params = Object.values(transform.paramsDescription).map(({ name, jsonSchema }) => {
    const ruleDefault = rule?.[name]?.default;
    // jsonSchema.default is only read when no rule override exists.
    return { name, value: ruleDefault ?? jsonSchema.default };
  });

  if (checkValueValid && params.some(({ value }) => value === null)) {
    return null;
  }

  return { name: transform.name, id: transform.id, params };
};

/**
 * Returns the train-section steps that are NOT one of the given resize options.
 *
 * Despite the name, this removes every step whose transform name matches an entry in
 * `resizeOptions` and keeps everything else.
 *
 * @param resizeOptions UI options whose `name`s identify the resize transforms.
 * @param transformParams Transform definitions used to map names to ids.
 * @param sections Pipeline sections; only `train` is read.
 */
export const getAppliedResizeOptions = (
  resizeOptions: TransformUISchemaOption[],
  transformParams: Transform[],
  sections: Pipeline['sections'],
) => {
  const resizeNames = new Set<unknown>(resizeOptions.map(({ name }) => name));
  const resizeIds = new Set<unknown>(
    transformParams.filter(({ name }) => resizeNames.has(name)).map(({ id }) => id),
  );

  return sections.train.filter(({ id }) => !resizeIds.has(id));
};

/**
 * Returns a copy of `sections` in which the param named `params.name` of the train step
 * `transformId` takes the value `params.value`.
 *
 * Inputs are not mutated. Untouched steps and params keep their identity, and every other
 * key of `sections` (including `valid`) is carried over by reference.
 */
export const handleSectionUpdate = (
  sections: Pipeline['sections'],
  transformId: string,
  params: TransformParamValue,
) => {
  const updatedTrain = sections.train.map(section => {
    if (section.id !== transformId) {
      return section;
    }
    const updatedParams = section.params.map(param =>
      param.name !== params.name ? param : { ...param, value: params.value },
    );
    return { ...section, params: updatedParams };
  });

  return { ...sections, train: updatedTrain };
};

/**
 * Returns pipeline sections in which every Resize / RescaleWithPadding / Crop step of the
 * train section has its output dimensions snapped to `transformMultiples` via
 * `getMultiplesInput`.
 *
 * - Resize / RescaleWithPadding: `width` and `height` are each snapped from their OWN current
 *   value. Previously `height` was overwritten with the snapped `width`.
 * - Crop: `x_max` / `y_max` are snapped relative to their `x_min` / `y_min` sibling.
 * - Any other transform (AutoResize, augmentations, unknown ids) is returned untouched.
 *   Previously such steps could have an `x_max`/`y_max` param rewritten, because their
 *   dimensions are undefined and the multiple check always passed.
 *
 * Inputs are not mutated. Adjusted steps are new objects; unchanged steps keep their identity.
 * `valid` is passed through by reference.
 *
 * @param pipelineSections Current pipeline sections.
 * @param transformParams Transform definitions used to resolve a step's id to its name.
 * @param transformMultiples Required multiple; `0` (default) disables any adjustment.
 */
export const getMultipleAdjustedPipelineSections = (
  pipelineSections: Pipeline['sections'],
  transformParams: Transform[],
  transformMultiples = 0,
) => ({
  train: pipelineSections.train.map(transform => {
    const transformName =
      transformParams.find(transParam => transParam.id === transform.id)?.name || '';

    const isResizeOrRescale = transformName === 'Resize' || transformName === 'RescaleWithPadding';
    const isCrop = transformName === 'Crop';

    // Only these transforms have dimensions that getTransformDimensions understands.
    if (!transformMultiples || (!isResizeOrRescale && !isCrop)) {
      return transform;
    }

    const { width, height } = getTransformDimensions(transform, transformParams);

    // Work on copies so the caller's pipeline objects are never mutated.
    const params = transform.params.map(param => ({ ...param }));
    const findParam = (paramName: string) => params.find(param => param.name === paramName);

    // Resize/Rescale: snap the dimension param using its own current value.
    const snapSizeParam = (targetName: string) => {
      const target = findParam(targetName);
      if (target) {
        target.value = getMultiplesInput(target.value, transformMultiples);
      }
    };

    // Crop: snap the upper bound relative to its lower-bound sibling.
    const snapCropParam = (targetName: string) => {
      const target = findParam(targetName);
      if (target) {
        target.value = getMultiplesInput(
          target.value,
          transformMultiples,
          findParam(getCropSibling(targetName))?.value,
        );
      }
    };

    let changed = false;

    if (width % transformMultiples !== 0) {
      if (isResizeOrRescale) {
        snapSizeParam('width');
      } else {
        snapCropParam('x_max');
      }
      changed = true;
    }

    if (height % transformMultiples !== 0) {
      if (isResizeOrRescale) {
        snapSizeParam('height');
      } else {
        snapCropParam('y_max');
      }
      changed = true;
    }

    return changed ? { ...transform, params } : transform;
  }),
  valid: pipelineSections.valid,
});

/**
 * Returns a new array of train steps suitable for sending to the API:
 * - the step with id `'autoResize'` is dropped;
 * - each remaining step's params are reduced to `{ name, value }`.
 * Other step keys are copied shallowly.
 */
export const getSanitizedTrainSection = (sections: Pipeline['sections']) =>
  sections.train
    .filter(({ id }) => id !== 'autoResize')
    .map(section => ({
      ...section,
      params: section.params.map(({ name, value }) => ({ name, value })),
    }));

/**
 * Dimensions produced by the LAST train-section step whose transform belongs to
 * `TRANSFORMS_UI_SCHEMA[1]`, or `null` if no such step exists.
 */
export const getPreviewDimensions = (
  transformParams: Transform[],
  sections: Pipeline['sections'],
) => {
  const transformIds = new Set(
    transformParams
      .filter(({ name }) => Object.prototype.hasOwnProperty.call(TRANSFORMS_UI_SCHEMA[1], name))
      .map(({ id }) => id),
  );

  const applied = sections.train.filter(({ id }) => transformIds.has(id));
  const lastApplied = applied[applied.length - 1];

  return lastApplied ? getTransformDimensions(lastApplied, transformParams) : null;
};

/**
 * Runs each URL in `mediaURLs` through the given train-section pipeline, one
 * `TransformsApi.executePipeline` call per URL issued in parallel. Returns `currentMedias`
 * with each item's `url` replaced by the transformed result at the same index.
 *
 * NOTE(review): results are paired with `currentMedias` purely by index, so the two arrays
 * are presumably expected to line up one-to-one. Confirm with callers.
 */
export const getApiUpdatedMedia = async (
  mediaURLs: string[],
  currentMedias: Media[],
  trainSection: TransformParams[],
  projectId?: ProjectId,
) => {
  const transformSingleUrl = async (url: string) => {
    const result = await TransformsApi.executePipeline(
      [url],
      { train: trainSection, valid: [] },
      projectId,
    );
    return result['train'][0];
  };

  const transformedUrls = await Promise.all(mediaURLs.map(url => transformSingleUrl(url)));

  return currentMedias.map((media, index) => ({ ...media, url: transformedUrls[index] }));
};

/**
 * Display metadata for one editable hyper-parameter, as built by `getHyperParamDetails`.
 */
export type HyperParamDetails = {
  // Key of the parameter on the training params model, e.g. 'learningParams.epochs' or 'archName'.
  name: string;
  // Label text (passed through `t`).
  label: string;
  // Help text (passed through `t`).
  description: string;
  // Current value read from the model.
  // NOTE(review): declared as string, but `getHyperParamDetails` fills it from model fields such
  // as epochs / IoU threshold that are presumably numeric (it only type-checks via an `as` cast).
  // Consider widening this type.
  default: string;
  // Result of `typeof` on the default value (e.g. 'number', 'string').
  type: string;
  // Allowed choices; set only for the model-size ('archName') entry.
  options?: string[];
};

/**
 * Lists the hyper-parameters shown in the training settings for `model`, in order:
 * epochs, model size (only if `attributeMapping.availableModelSizes` is truthy), and the NMS
 * IoU threshold.
 *
 * @param model Current training params; each entry's `default` is read from it.
 * @param attributeMapping Optional extra attributes; only `availableModelSizes` is read, and it
 *   also supplies the model-size options.
 */
export const getHyperParamDetails = (
  model: TrainingParamsModel,
  attributeMapping?: Record<string, any>,
) => {
  const epochsKey = 'learningParams.epochs';
  const iouThresholdKey = 'nmsParams.iou_threshold';
  const showModelSize = !!attributeMapping?.['availableModelSizes'];

  const epochsEntry = {
    name: epochsKey,
    label: t('Epoch'),
    description: t('Number of epochs to train the model for'),
    default: model[epochsKey],
    type: typeof model[epochsKey],
  };

  const modelSizeEntry = showModelSize && {
    name: 'archName',
    label: t('Model size'),
    description: t('The size of model'),
    default: model.archName,
    type: typeof model.archName,
    options: (attributeMapping?.['availableModelSizes'] as string[]) ?? [],
  };

  const iouThresholdEntry = {
    name: iouThresholdKey,
    label: t('Non-maximum suppression'),
    description: t('The IoU threshold for NMS'),
    default: model[iouThresholdKey],
    type: typeof model[iouThresholdKey],
  };

  // `filter(Boolean)` drops the model-size entry when it is `false`.
  return [epochsEntry, modelSizeEntry, iouThresholdEntry].filter(Boolean) as HyperParamDetails[];
};

/**
 * Returns `value` unless it is `null`/`undefined`, in which case `fallback` is coerced with
 * `Number`. A `0` value is kept as-is.
 */
export const getSchemaParamValue = (
  value: number | undefined,
  fallback: JSONSchema7Type | undefined,
): number => value ?? Number(fallback);

/**
 * Looks up `ModelArchModelSizeDisplayName[modelArch][modelSize]`. Returns `undefined` when
 * either argument is missing or empty, or when no entry exists.
 */
export const getModelArchModelSizeDisplayName = (
  modelArch: ModelArch | undefined,
  modelSize: string | undefined,
): string | undefined =>
  modelArch && modelSize ? ModelArchModelSizeDisplayName[modelArch]?.[modelSize] : undefined;

/**
 * Returns the `ModelSizeToArchNameForTooltip` entry for bounding-box, segmentation and
 * classification projects, or an empty object for any other label type.
 */
export const getModelArchsByLabelType = (labelType: LabelType): Record<string, string> => {
  switch (labelType) {
    case LabelType.BoundingBox:
    case LabelType.Segmentation:
    case LabelType.Classification:
      return ModelSizeToArchNameForTooltip[labelType] as Record<string, string>;
    default:
      return {};
  }
};

/**
 * Returns the `customTrainingDefaultModelArchName` entry for bounding-box, segmentation and
 * classification projects, or `undefined` for any other label type.
 */
export const getModelDefaultArchNameByLabelType = (labelType: LabelType): string | undefined => {
  switch (labelType) {
    case LabelType.BoundingBox:
    case LabelType.Segmentation:
    case LabelType.Classification:
      return customTrainingDefaultModelArchName[labelType] as string;
    default:
      return undefined;
  }
};

/**
 * Validates the parameter values of one transform and returns `{ [paramName]: message }`.
 *
 * - Returns `{}` when `transformValues`, `width` or `height` is missing.
 * - `Crop`: checks each bound against the image size (min >= 0, max <= image width/height,
 *   min < max).
 * - Other transforms that have a rule in
 *   `currentSchema.preprocessing.params[<transform name>].properties`: checks each param
 *   against the rule's `min`/`max` and the definition's `isRequired`. It then checks the
 *   product of the transform's own `height` and `width` params against
 *   `limits.largeImage.maxArea`.
 * When several checks fail for the same param, the later message wins.
 *
 * @param transformValues Current param values of the transform being validated.
 * @param width Image width (used by the Crop checks).
 * @param height Image height (used by the Crop checks).
 * @param transformParam Transform definition (name and `paramsDescription`).
 * @param limits Client training limits; only `largeImage.maxArea` is read.
 * @param currentSchema Model-arch schema providing per-transform validation rules.
 */
export const getErrorsByParamName = (
  transformValues?: TransformParams,
  width?: number,
  height?: number,
  transformParam?: Transform,
  limits?: ClientTrainLimits,
  currentSchema?: ModelArchSchemaItem,
) => {
  const errors = {} as Record<string, string>;
  const rule = currentSchema?.preprocessing?.params?.[transformParam?.name!]?.properties;
  if (!transformValues || width === undefined || height === undefined) {
    return errors;
  }

  if (transformParam?.name === 'Crop') {
    const { x_min, x_max, y_min, y_max } = transformValues.params.reduce(
      (res, { name, value }) => ({ ...res, [name]: value }),
      {},
    ) as Record<string, number>;
    if (x_min < 0) {
      errors.x_min = t('X min should be >= 0');
    }
    if (x_max > width) {
      errors.x_max = t('X max should be <= image width ({{width}})', { width });
    }
    if (x_min >= x_max) {
      errors.x_min = errors.x_max = t('X min should be less than X max');
    }
    if (y_min < 0) {
      errors.y_min = t('Y min should be >= 0');
    }
    if (y_max > height) {
      errors.y_max = t('Y max should be <= image height ({{height}})', { height });
    }
    if (y_min >= y_max) {
      errors.y_min = errors.y_max = t('Y min should be less than Y max');
    }
  } else if (rule) {
    // Output size requested by this transform's own `height`/`width` params. Named distinctly
    // so they no longer shadow the image `width`/`height` arguments.
    let outputHeight = 0;
    let outputWidth = 0;
    const areaRule = rule.area;

    transformValues.params.forEach(({ name, value }) => {
      if (name === 'height') {
        outputHeight = value;
      }
      if (name === 'width') {
        outputWidth = value;
      }

      const isRequired = transformParam?.paramsDescription[name]?.isRequired;
      const { min, max } = rule[name] ?? {};
      if (min !== undefined && value < min) {
        errors[name] = t('{{name}} should be >= {{min}}', { name, min });
      }
      // The per-param max is skipped when the schema defines an area max.
      // NOTE(review): the area check below reads `limits.largeImage.maxArea`, not
      // `areaRule.max`. With an area rule but no `limits`, no upper bound is enforced at all.
      // Confirm this is intended.
      if (max !== undefined && value > max && !(areaRule && areaRule.max)) {
        errors[name] = t('{{name}} should be <= {{max}}', { name, max });
      }
      if (isRequired && isNil(value)) {
        errors[name] = t('{{name}} is required', { name });
      }
    });

    const limit = limits?.largeImage?.maxArea;

    // An area violation overwrites any per-dimension message on height/width.
    if (limit && outputHeight * outputWidth > limit) {
      errors.height = t('Area must be <= {{max}}', { max: limit.toLocaleString() });
      errors.width = t('Area must be <= {{max}}', { max: limit.toLocaleString() });
    }
  }

  return errors;
};

/**
 * Runs `getErrorsByParamName` for every transform in `transformValuesList` and collects the
 * non-empty results keyed by transform name. If two transforms share a name, the later one wins.
 */
export const getErrorsByTransformAndParamName = (
  transformValuesList?: TransformParams[],
  width?: number,
  height?: number,
  transformParams?: Transform[],
  limits?: ClientTrainLimits,
  currentSchema?: ModelArchSchemaItem,
) => {
  const errorsByTransform = {} as { [transformName: string]: { [paramName: string]: string } };

  for (const transformValue of transformValuesList ?? []) {
    const definition = transformParams?.find(({ id }) => id === transformValue.id);
    const paramErrors = getErrorsByParamName(
      transformValue,
      width,
      height,
      definition,
      limits,
      currentSchema,
    );
    if (isEmpty(paramErrors)) {
      continue;
    }
    errorsByTransform[definition?.name!] = paramErrors;
  }

  return errorsByTransform;
};

/**
 * Validates `epochs` against `schema.definitions.LearningParams.properties.epochs`
 * (`minimum` defaults to 0, `maximum` to Infinity). Returns
 * `{ model: { 'learningParams.epochs': message } }` on violation, otherwise `{}`.
 * If both bounds are violated (minimum > maximum), the maximum message wins.
 */
export const getErrorsByHyperParams = (schema?: HyperParamsSchema, epochs?: number) => {
  const errorsByHyperParam = {} as { [hyperParamName: string]: { [value: string]: string } };
  if (schema === undefined || epochs === undefined) {
    return errorsByHyperParam;
  }

  const { minimum, maximum } = schema.definitions.LearningParams.properties.epochs;
  const min = minimum ?? 0;
  const max = maximum ?? Infinity;

  if (epochs < min) {
    errorsByHyperParam.model = {
      'learningParams.epochs': t('Number of epochs should be >= {{min}}', { min }),
    };
  }
  if (epochs > max) {
    errorsByHyperParam.model = {
      'learningParams.epochs': t('Number of epochs should be <= {{max}}', { max }),
    };
  }

  return errorsByHyperParam;
};
