import { makeStyles } from '@material-ui/core/styles';
import React, { useCallback, useMemo } from 'react';
import { Paper, Select, Typography, Box } from '@material-ui/core';
import { Alert } from '@material-ui/lab';
import CloseOutlined from '@material-ui/icons/CloseOutlined';
import { Button } from '@clef/client-library';
import MenuItem from '@material-ui/core/MenuItem';
import Modal from '@material-ui/core/Modal';
import CircularProgress from '@material-ui/core/CircularProgress';

import {
  Media,
  Transform,
  Pipeline,
  TransformId,
  TransformParamValue,
  ClientTrainLimits,
  LabelType,
  ModelArchSchemaItem,
} from '@clef/shared/types';
import {
  getMultiplesInput,
  isSingleSliderKey,
  isDropdownKey,
  getSchemaParamValue,
  getErrorsByParamName,
  getMaxTransformPixels,
  getTransformDimensions,
} from '../../../utils/job_train_utils';
import MediaCarousel from './MediaCarousel';
import ParamInput from '../../../components/DataAugmentation/ParamInput';
import { greyScale } from '@clef/client-library';
import { SingleSlider } from './Sliders/SingleSlider';
import { RangeSlider } from './Sliders/RangeSlider';
import { Dropdown } from './Dropdown';
import { TransformUISchema } from '../../../types/client';
import auto_resize from '../../../images/model-iteration/auto_resize.png';
import { isEmpty } from 'lodash';
import { TRANSFORM_TEXTS } from '@/constants/model_train';
import { useTrainingState } from '@/pages/DataBrowser/TrainModelButtonGroup/state';
import { useGetSelectedProjectQuery } from '@/serverStore/projects';

/** JSS classes used by TransformModal. */
const useStyles = makeStyles(theme => {
  const { spacing } = theme;
  // Values shared by more than one rule below.
  const itemGap = `${spacing(2)}px`;
  const sectionDivider = `1px solid ${greyScale[300]}`;

  return {
    // Centers the paper inside the Modal.
    modal: {
      display: 'flex',
      justifyContent: 'center',
      alignItems: 'center',
    },
    // Rounded outlined Select used for switching between resize options.
    modalSelect: {
      borderRadius: spacing(3.5),
      '& .MuiSelect-outlined': {
        padding: `${spacing(2)}px ${spacing(4)}px`,
        // Declared after `padding` so it overrides the shorthand's right side.
        paddingRight: spacing(8),
      },
    },
    // Scrollable modal body.
    paperContainer: {
      width: '80vw',
      height: '90vh',
      overflow: 'auto',
      padding: spacing(10),
      fontSize: 14,
      outline: 'none',
    },
    // Row of controls separated from the content below by a divider.
    pipelineParamsWrapper: {
      display: 'flex',
      flexWrap: 'wrap',
      justifyContent: 'space-between',
      alignItems: 'baseline',
      gap: itemGap,
      borderBottom: sectionDivider,
      marginBottom: 20,
      paddingBottom: 20,
    },
    pipelineParams: {
      display: 'flex',
      flexWrap: 'wrap',
      alignItems: 'flex-start',
      gap: itemGap,
    },
    // Spinner shown inside the "Add" button while executing.
    buttonProgress: {
      marginLeft: spacing(3),
      color: greyScale[100],
    },
    alerts: {
      display: 'flex',
      justifyContent: 'center',
      alignSelf: 'center',
      borderRadius: 8,
      margin: spacing(4, 0),
      alignItems: 'center',
      fontWeight: 400,
      fontSize: '14px',
    },
  };
});

/** Props accepted by TransformModal. */
interface Props {
  /** Media rendered in the preview carousel; the carousel is not shown when this is empty. */
  medias: Media[];
  /** Available transform definitions; the one whose `name` equals `schemaType` is edited. */
  transformParams: Transform[];
  /** UI metadata keyed by transform name (label, which controls to render, resize flag, ...). */
  uiSchema: TransformUISchema;
  /** Id of the resize entry in the train pipeline; its width/height params feed validation. */
  resizeId: string;
  /** Id of the rescale entry in the train pipeline; its width/height params feed validation. */
  rescaleId: string;
  /** Id of the crop entry; it is left out of the preview when resizing/rescaling after a crop. */
  cropId?: string;
  /** Forwarded to `getMultiplesInput` and to the preview carousel. */
  transformMultiples?: number;
  /** Together with the transform's `isResizeOption`, enables the resize-option switcher. */
  isTransformStep?: boolean;
  /** When true, the resize-option switcher is never shown. */
  disableResizeRescaleSwitch?: boolean;
  /** Invoked by the "Add" button. */
  handleCreateConfirmed: () => void;
  /** While true, "Add" is disabled and shows a spinner. */
  isExecuting: boolean;
  /** Whether the modal is open. */
  transformModalOpen: boolean;
  /** Called when the modal is dismissed or the close icon is clicked. */
  handleModalClose: () => void;
  /** Writes a single parameter value for the transform with the given id. */
  updateTransformParam: (
    transformId: TransformId,
    params: TransformParamValue,
    isTransformListUpdate?: boolean,
  ) => void;
  /** Called when an entry of the resize-option switcher is clicked. */
  handleMenuClick: (
    transformName: string,
    isNewEntry: boolean,
    isResizeOption?: boolean,
    event?: any,
    isResizeSwitch?: boolean,
  ) => void;
  /** Current pipeline whose `train` section holds the values being edited. */
  pipelineSections: Pipeline['sections'];
  /** Name of the transform being configured (key into `uiSchema`). */
  schemaType: string;
  /** Current value of the resize-option switcher. */
  selectedResize: string;
  /** Local slider values, one per range parameter, used by "Set as Lower/Upper Limit". */
  localParam: number[];
  /**
   * Setter for `localParam`.
   * NOTE(review): the component calls this with a functional updater; confirm callers pass a
   * React state setter.
   */
  setLocalParam: (param: number[]) => void;
  /** Training limits; when omitted, the training state's `advanced.limits` is used. */
  limits?: ClientTrainLimits;
  /** Forwarded to the preview carousel. */
  displayAllTransforms?: boolean;
  /** Model-architecture schema supplying per-parameter min/max rules for the transform. */
  schema?: ModelArchSchemaItem;
}

/**
 * Modal used to configure a single transform/augmentation (`schemaType`) before it is added
 * to the training pipeline.
 *
 * The flags on `uiSchema[schemaType]` decide which controls are rendered: range sliders,
 * single sliders, dropdowns, a probability slider, or the auto-resize ("estimate") input.
 * When none of `hasRangeSlider`, `probability` or `estimate` is set, one `ParamInput` is
 * rendered per parameter (except `always_apply` and `p`). A preview carousel is shown when
 * `medias` is non-empty.
 *
 * Every edit is forwarded through `updateTransformParam` for this transform's id. The "Add"
 * button calls `handleCreateConfirmed`. It is disabled while `isExecuting` is true or while
 * `getErrorsByParamName` reports any error.
 */
const TransformModal = ({
  medias,
  transformParams,
  uiSchema,
  resizeId,
  rescaleId,
  cropId,
  transformMultiples,
  isTransformStep = false,
  disableResizeRescaleSwitch = false,
  handleCreateConfirmed,
  isExecuting,
  transformModalOpen,
  handleModalClose,
  updateTransformParam,
  pipelineSections,
  handleMenuClick,
  schemaType,
  selectedResize,
  localParam,
  setLocalParam,
  limits: limitProps,
  displayAllTransforms = true,
  schema,
}: Props) => {
  const styles = useStyles();
  const { labelType } = useGetSelectedProjectQuery().data ?? {};

  // Truthy when a crop entry exists and a Resize/RescaleWithPadding is being configured.
  // In that case the crop is left out of the preview pipeline and a warning is shown.
  const resizingAfterCrop =
    pipelineSections.train.find(i => i.id === cropId) &&
    (schemaType === 'Resize' || schemaType === 'RescaleWithPadding');
  const pipelineSectionsWithResizingAfterCrop = {
    train: pipelineSections.train.filter(i => i.id !== cropId),
    valid: pipelineSections.valid,
  };

  const resizeOptions = Object.values(uiSchema).filter(option => option.isResizeOption);

  const transformParam = transformParams.find(t => t.name === schemaType);
  // Per-parameter validation rules (min/max) for this transform, if the schema defines any.
  const rule = schema?.preprocessing.params?.[transformParam?.name!]?.properties;

  const schemaParams = uiSchema[schemaType];

  const { state } = useTrainingState();
  const limits = limitProps || state.advanced.limits;

  const maxTransformPixels = getMaxTransformPixels(labelType, limits?.largeImage);

  // True when the configured Resize/RescaleWithPadding output exceeds the pixel budget.
  const showSizeAlert = useMemo(() => {
    if (!transformParam || (schemaType !== 'Resize' && schemaType !== 'RescaleWithPadding')) {
      return false;
    }
    const currentPipeline = pipelineSections.train.find(p => p.id === transformParam.id);
    if (!currentPipeline) return false;
    const { width, height } = getTransformDimensions(currentPipeline, transformParams);
    return Boolean(width && height && width * height > maxTransformPixels);
  }, [maxTransformPixels, pipelineSections.train, schemaType, transformParam, transformParams]);

  // Train-pipeline entry holding the values of the transform being edited.
  // The non-null assertion is type-only: `find` yields undefined when no entry exists yet.
  const transformValues = useMemo(() => {
    return pipelineSections.train.find(section => section.id === transformParam?.id)!;
  }, [pipelineSections.train, transformParam?.id]);

  /** Saved value of `paramName` on this transform's pipeline entry, or undefined. */
  const getPipelineParamValue = (paramName: string) =>
    transformValues?.params.find(({ name }) => name === paramName)?.value;

  // `?? {}` avoids a TypeError from Object.keys when the transform definition is missing.
  const paramNames = Object.keys(transformParam?.paramsDescription ?? {});
  const lowerRangeSliderKeys: string[] = schemaParams.hasRangeSlider
    ? paramNames.filter(param => /lower_limit$/.test(param))
    : [];
  const singleSliderKeys: string[] = schemaParams.hasSingleSlider
    ? paramNames.filter(param => isSingleSliderKey(param))
    : [];
  const dropdownKeys: string[] = schemaParams.hasDropdown
    ? paramNames.filter(param => isDropdownKey(param))
    : [];

  const autoResizeValue =
    getPipelineParamValue('Defect shortest side pixel size') ||
    transformParam?.paramsDescription?.estimate?.jsonSchema?.default;

  // The auto-resize input is only offered for anomaly detection and classification projects.
  const hideAutoResizeInput = ![LabelType.AnomalyDetection, LabelType.Classification].includes(
    labelType!,
  );

  const onParamInputBlur = useCallback(
    (value: number | undefined, field: string) => {
      updateTransformParam(transformParam?.id!, { name: field, value });
    },
    [transformParam, updateTransformParam],
  );

  // Width/height currently configured on the resize or rescale entry, used for validation.
  const { width, height } = useMemo(() => {
    const resizeOrRescale = pipelineSections.train.find(
      t => t.id === resizeId || t.id === rescaleId,
    );
    const { width, height } =
      (resizeOrRescale?.params.reduce(
        (res, { name, value }) => ({ ...res, [name]: value }),
        {},
      ) as Record<string, number>) ?? {};
    return { width, height };
  }, [pipelineSections.train, rescaleId, resizeId]);

  const errorsByParamName = useMemo(() => {
    return getErrorsByParamName(transformValues, width, height, transformParam, limits, schema);
  }, [height, limits, schema, transformParam, transformValues, width]);

  const hasError = !isEmpty(errorsByParamName);

  // HueSaturationValue's label is split into words, one per range slider, to title the sliders.
  const hueSaturationValueSliderLabels = useMemo(() => {
    if (transformParam?.name !== 'HueSaturationValue') return null;
    return transformParam?.label?.split(' ') ?? null;
  }, [transformParam?.label, transformParam?.name]);

  /** "Add" button shared by every layout; disabled while executing or on validation errors. */
  const renderAddButton = (id: string) => (
    <Button
      id={id}
      disabled={isExecuting || hasError}
      variant="contained"
      color="primary"
      onClick={() => {
        handleCreateConfirmed();
      }}
    >
      {t('Add')}
      {isExecuting && <CircularProgress size={24} className={styles.buttonProgress} />}
    </Button>
  );

  return (
    <Modal
      open={transformModalOpen}
      onClose={handleModalClose}
      className={styles.modal}
      data-testid={`transform-modal`}
      disableEnforceFocus
    >
      <Paper elevation={0} className={styles.paperContainer}>
        <Box display="flex" justifyContent="space-between" alignItems="center">
          <h2 data-testid={`transform-modal-title`}>{t(`${schemaParams?.label}`)}</h2>
          <CloseOutlined style={{ cursor: 'pointer' }} onClick={handleModalClose} />
        </Box>
        {/* Switcher between resize-type transforms, shown only on the transform step. */}
        {!disableResizeRescaleSwitch && isTransformStep && schemaParams?.isResizeOption && (
          <Box mb={4}>
            <Select
              variant="outlined"
              value={selectedResize}
              className={styles.modalSelect}
              data-testid={`resize-select`}
            >
              {resizeOptions.map(option => (
                <MenuItem
                  key={option.name}
                  value={option.name}
                  onClick={() =>
                    handleMenuClick(option.name, true, option.isResizeOption, null, true)
                  }
                >
                  {option.label}
                </MenuItem>
              ))}
            </Select>
          </Box>
        )}
        {!schemaParams.estimate && (
          <p data-testid={`transform-modal-description`}>
            {TRANSFORM_TEXTS[transformParam?.name ?? '']?.description ??
              transformParam?.description}
          </p>
        )}
        {(schemaParams.hasRangeSlider ||
          schemaParams.hasSingleSlider ||
          schemaParams.probability ||
          schemaParams.estimate) && (
          <div className={styles.pipelineParamsWrapper}>
            {schemaParams.hasRangeSlider && (
              <div>
                {lowerRangeSliderKeys.map(lowerKey => {
                  const upperKey = lowerKey.replace('lower', 'upper');
                  return (
                    <RangeSlider
                      key={lowerKey}
                      minKey={lowerKey}
                      maxKey={upperKey}
                      transform={transformParam!}
                      // Prefix before "_lower_limit" names the range; otherwise use the transform.
                      title={`${
                        lowerKey.substring(0, lowerKey.indexOf('_lower_limit')) ||
                        (transformParam?.label ?? transformParam?.name)
                      } Range`}
                      updateTransformParam={values => {
                        updateTransformParam(transformParam?.id!, {
                          name: lowerKey,
                          value: values[0],
                        });
                        updateTransformParam(transformParam?.id!, {
                          name: upperKey,
                          value: values[1],
                        });
                      }}
                      upperValue={getSchemaParamValue(
                        getPipelineParamValue(upperKey),
                        transformParam?.paramsDescription[upperKey].jsonSchema.default,
                      )}
                      lowerValue={getSchemaParamValue(
                        getPipelineParamValue(lowerKey),
                        transformParam?.paramsDescription[lowerKey].jsonSchema.default,
                      )}
                    />
                  );
                })}
              </div>
            )}
            {schemaParams.hasSingleSlider && (
              <div>
                {singleSliderKeys.map(sliderKey => {
                  return (
                    <SingleSlider
                      key={sliderKey}
                      title={sliderKey || transformParam!.label || transformParam!.name}
                      minKey={sliderKey}
                      maxKey={sliderKey}
                      transform={transformParam!}
                      value={getSchemaParamValue(
                        getPipelineParamValue(sliderKey),
                        transformParam?.paramsDescription[sliderKey].jsonSchema.default,
                      )}
                      updateTransformParam={value =>
                        updateTransformParam(transformParam?.id!, {
                          name: sliderKey,
                          value: value,
                        })
                      }
                    />
                  );
                })}
              </div>
            )}
            {schemaParams.hasDropdown && (
              <div>
                {dropdownKeys.map(key => {
                  return (
                    <Dropdown
                      key={key}
                      title={key || transformParam!.label || transformParam!.name}
                      value={
                        getPipelineParamValue(key) ||
                        transformParam?.paramsDescription[key].jsonSchema.default
                      }
                      transform={transformParam!}
                      updateTransformParam={value =>
                        updateTransformParam(transformParam?.id!, {
                          name: key,
                          value: value,
                        })
                      }
                    />
                  );
                })}
              </div>
            )}
            {schemaParams.probability && (
              <SingleSlider
                title={'Probability'}
                minKey={'p'}
                maxKey={'p'}
                transform={transformParam!}
                value={getSchemaParamValue(
                  getPipelineParamValue('p'),
                  transformParam?.paramsDescription['p'].jsonSchema.default,
                )}
                updateTransformParam={value =>
                  updateTransformParam(transformParam?.id!, {
                    name: 'p',
                    value: value,
                  })
                }
              />
            )}
            {schemaParams.estimate && (
              <>
                <Box display="flex">
                  <Box py={4}>
                    <Box pr={8}>
                      <Typography variant="body1">
                        {t(`${transformParam?.paramsDescription.estimate.description}`)}
                      </Typography>
                    </Box>

                    {!hideAutoResizeInput && (
                      <Box py={4}>
                        <ParamInput
                          jsonSchema={transformParam?.paramsDescription.estimate.jsonSchema || {}}
                          isRequired={
                            transformParam?.paramsDescription.estimate.isRequired || false
                          }
                          disabled={false}
                          onChange={value => {
                            updateTransformParam(transformParam?.id!, {
                              name: 'Defect shortest side pixel size',
                              value,
                            });
                          }}
                          // On blur, snap the stored value using getMultiplesInput.
                          onBlur={() =>
                            updateTransformParam(transformParam?.id!, {
                              name: 'Defect shortest side pixel size',
                              value: getMultiplesInput(autoResizeValue, transformMultiples),
                            })
                          }
                          label={transformParam?.paramsDescription.estimate.name || ''}
                          value={autoResizeValue}
                          id={'estimate'}
                        />
                      </Box>
                    )}
                  </Box>
                  <Box>
                    <img src={auto_resize} alt="auto" />
                  </Box>
                </Box>
                <Box display="flex" justifyContent="flex-end" width="100%" pt={2}>
                  {renderAddButton('add-pipeline-augmentation')}
                </Box>
              </>
            )}
            {!schemaParams.estimate && renderAddButton('add-pipeline-augmentation')}
          </div>
        )}

        {medias.length > 0 && (
          <MediaCarousel
            showCropTool={schemaParams.showCropTool}
            medias={medias}
            lowerRangeSliderKeys={lowerRangeSliderKeys}
            transformValues={transformValues}
            resizeId={resizeId}
            rescaleId={rescaleId}
            pipelineSections={
              resizingAfterCrop ? pipelineSectionsWithResizingAfterCrop : pipelineSections
            }
            localTransformParam={localParam}
            updateTransformParam={updateTransformParam}
            alwaysTransform={!schemaParams.hasRangeSlider && schemaParams.probability}
            transformMultiples={transformMultiples}
            displayAllTransforms={displayAllTransforms}
            transformParams={transformParams}
          />
        )}
        {/* Generic numeric inputs for transforms without dedicated slider/estimate controls. */}
        {!(schemaParams.hasRangeSlider || schemaParams.probability || schemaParams.estimate) && (
          <div className={styles.pipelineParamsWrapper}>
            <div className={styles.pipelineParams}>
              {paramNames.map((field: string) => {
                const fieldValue = getPipelineParamValue(field);
                const fieldRule = rule?.[field];
                return field === 'always_apply' || field === 'p' ? null : (
                  <ParamInput
                    key={field}
                    jsonSchema={transformParam?.paramsDescription[field].jsonSchema || {}}
                    isRequired={transformParam?.paramsDescription[field].isRequired || false}
                    disabled={false}
                    onChange={value => {
                      updateTransformParam(transformParam?.id!, {
                        name: field,
                        value,
                      });
                    }}
                    // Re-commit the typed text as a number on blur; an empty field is skipped.
                    onBlur={e => {
                      const { value } = e.target;
                      if (value) {
                        onParamInputBlur(Number(value), field);
                      }
                    }}
                    label={field}
                    value={fieldValue === null ? '' : fieldValue}
                    id={field}
                    min={fieldRule?.min}
                    max={fieldRule?.max}
                    errorMessage={errorsByParamName[field]}
                  />
                );
              })}
            </div>
            {renderAddButton('add-pipeline-transform')}

            {/* NOTE(review): this alert shares data-testid 'defaultSizeAlert' with the size alert below. */}
            {resizingAfterCrop && (
              <Box display="block" width={'100%'}>
                <Alert className={styles.alerts} severity="error" data-testid={'defaultSizeAlert'}>
                  {t(
                    'When resizing or rescaling after cropping, the crop will be removed. Please re-crop to ensure it fits within the resized image.',
                  )}
                </Alert>
              </Box>
            )}

            {showSizeAlert && (
              <Alert className={styles.alerts} severity="error" data-testid={'defaultSizeAlert'}>
                {t(
                  'Models need a consistent size for all media. Please add resize, crop or rescale with padding transform. To ensure good performance keep media size below {{num}} megapixels.',
                  {
                    // Pixels -> megapixels, rounded to two decimals.
                    num: Math.round(maxTransformPixels / 10_000) / 100,
                  },
                )}
              </Alert>
            )}
          </div>
        )}
        {/* Per-range preview slider whose local value can be committed as either limit. */}
        {schemaParams.hasRangeSlider &&
          lowerRangeSliderKeys.map((lowerKey, index) => {
            const upperKey = lowerKey.replace('lower', 'upper');
            return (
              <div key={lowerKey} className={styles.pipelineParams}>
                <SingleSlider
                  title={
                    hueSaturationValueSliderLabels
                      ? hueSaturationValueSliderLabels[index] + ' Shift'
                      : transformParam?.label || ''
                  }
                  transform={transformParam!}
                  minKey={lowerKey}
                  maxKey={upperKey}
                  value={Number(transformParam?.paramsDescription[lowerKey].jsonSchema.default)}
                  updateTransformParam={value => {
                    // `setLocalParam` is typed as a value setter but is given a functional
                    // updater here. Copy before writing so the previous state is never mutated.
                    // @ts-ignore
                    setLocalParam(localValues => {
                      const nextValues = [...localValues];
                      nextValues[index] = value;
                      return nextValues;
                    });
                  }}
                />

                <Button
                  id="set-lower-limit"
                  variant="text"
                  color="primary"
                  onClick={() => {
                    updateTransformParam(transformParam?.id!, {
                      name: lowerKey,
                      value: localParam[index],
                    });
                  }}
                >
                  Set as Lower Limit
                </Button>
                <Button
                  id="set-upper-limit"
                  variant="text"
                  color="primary"
                  onClick={() => {
                    updateTransformParam(transformParam?.id!, {
                      name: upperKey,
                      value: localParam[index],
                    });
                  }}
                >
                  Set as Upper Limit
                </Button>
              </div>
            );
          })}
      </Paper>
    </Modal>
  );
};

export default TransformModal;
