import React, { useCallback } from 'react';
import {
  Box,
  Table,
  TableBody,
  TableCell,
  TableContainer,
  TableHead,
  TableRow,
  Dialog,
  makeStyles,
  DialogTitle,
  DialogContent,
} from '@material-ui/core';
import CloseOutlined from '@material-ui/icons/CloseOutlined';
import cx from 'classnames';
import { isEqual } from 'lodash';

import { format } from 'date-fns';
import { cloneDeep } from 'lodash';

import { Typography, ApiResponseLoader } from '@clef/client-library';
import { TransformParamValue } from '@clef/shared/types';
import { ModelArch } from '@clef/shared/types/model_arch';

import { TransformUISchema } from '@/types/client';
import { TRANSFORMS_UI_SCHEMA } from '@/constants/model_train';
import { CUSTOMIZED_REGIME } from '@/constants/model';
import { getModelArchModelSizeDisplayName } from '@/utils/job_train_utils';
import { RegisteredModelWithThreshold } from '@/api/model_api';
import { useJobDetailForCurrentProject } from '@/serverStore/jobs';
import { useDatasetExportedWithVersionsQuery } from '@/serverStore/dataset';

/**
 * JSS styles for the training-settings comparison dialog.
 * Rule order is kept stable because makeStyles emits rules in declaration order.
 */
const useStyles = makeStyles(theme => {
  // 1px grey line shared by the container border and the cell separators.
  const divider = `1px solid ${theme.palette.greyModern[200]}`;

  return {
    dialog: { minWidth: '1000px' },
    dialogContent: { padding: theme.spacing(4, 8, 10, 8) },
    tableContainer: { borderRadius: '10px', border: divider },
    compareSettingsTable: {
      '& th, & td': { borderRight: divider, borderBottom: divider },
      // The outer container already draws the right edge and the bottom edge.
      '& th:last-child, & td:last-child': { borderRight: 'none' },
      '& tbody tr:last-child td, & tbody tr:last-child th': { borderBottom: 'none' },
    },
    tableHeader: {
      background: theme.palette.grey[50],
      '& .MuiTableCell-root': { padding: '16px 16px' },
    },
    sectionHeader: { width: '20%' },
    modelInfoHeader: { width: '40%' },
    sectionColumn: { verticalAlign: 'top' },
    sectionColumnHighlight: { backgroundColor: theme.palette.yellow[50] },
    configItem: {
      display: 'flex',
      alignItems: 'flex-start',
      marginBottom: theme.spacing(3),
      '&:last-child': { marginBottom: 0 },
    },
    configLabel: {
      width: 134,
      color: theme.palette.greyModern[900],
      textTransform: 'capitalize',
      whiteSpace: 'nowrap',
    },
    configValue: {
      color: theme.palette.greyModern[500],
      textTransform: 'capitalize',
    },
    explanation: { color: theme.palette.greyModern[500] },
  };
});

/** Ids of the table rows that describe training configuration. */
const trainingParamsTypes = {
  HYPERPARAMETER: 'hyperparameter',
  TRANSFORM: 'transform',
  AUGMENTATION: 'augmentation',
};

/** Rows describing model metadata; these are never diff-highlighted. */
const modelInfoSections = [
  { id: 'modelName', label: 'Model' },
  { id: 'threshold', label: 'Confidence threshold' },
  { id: 'snapshot', label: 'Trained From' },
  { id: 'createdAt', label: 'Trained at' },
  { id: 'creator', label: 'Trained by' },
];

/** Rows describing training configuration, listed in display order. */
const trainingSections = [
  { id: trainingParamsTypes.HYPERPARAMETER, label: 'Hyperparameter' },
  { id: trainingParamsTypes.TRANSFORM, label: 'Transform' },
  { id: trainingParamsTypes.AUGMENTATION, label: 'Augmentation' },
];

/** All rows of the comparison table, top to bottom. */
const ComparisonTableDefaultSections = [...modelInfoSections, ...trainingSections];
/**
 * Transform/augmentation configuration object. Its keys are the transform names and its
 * values are that transform's parameters.
 */
type TransformConfig = Record<string, any>;

/** Subset of a job's `hyperParams.model` that the comparison table reads. */
type HyperParametersConfig = {
  archClass: string;
  backboneParams: { name: string };
  learningParams: { epochs: number };
};

/**
 * Raw value of one comparison cell. Transform/augmentation cells hold a single object
 * merged from the job's config list, so `TransformConfig` is part of the union.
 * `TransformConfig[]` is kept for backward compatibility.
 */
type ComparisonCellValue =
  | string
  | number
  | HyperParametersConfig
  | TransformConfig
  | TransformConfig[];

/** One row of the comparison table, with the raw values for both models. */
type ComparisonTableSection = {
  id: string;
  label: string;
  baseline: ComparisonCellValue;
  candidate: ComparisonCellValue;
};

/**
 * Which of the two compared models a cell belongs to.
 * A plain string enum is used instead of `const enum`, which is unsupported or inlined
 * inconsistently by some isolated-module toolchains.
 */
enum ModelType {
  BASELINE = 'baseline',
  CANDIDATE = 'candidate',
}

/** Props of the training-settings comparison dialog. */
export type ModelComparisonTrainingSettingsProps = {
  /** Model shown in the "Baseline model" column. */
  baseline: RegisteredModelWithThreshold;
  /** Model shown in the "Candidate model" column. */
  candidate: RegisteredModelWithThreshold;

  /** Whether the dialog is visible. */
  open: boolean;
  /** Invoked when the dialog asks to close (close icon click or the MUI Dialog's onClose). */
  onClose: () => void;
};

/**
 * Text shown in a transform/augmentation cell when the job saved no config and its
 * regime is not the customized regime.
 */
const DEFAULT_CONFIG_DESC = 'Fast train default config';

/** Transform parameters that are never displayed in the comparison table. */
const HIDDEN_TRANSFORM_PARAMS = ['always_apply', 'interpolation', 'border_mode'];

/** Row ids that hold training configuration; only these rows are diff-highlighted. */
const TRAINING_SECTION_IDS = Object.values(trainingParamsTypes);

/**
 * Dialog that compares the training settings of a baseline model and a candidate model
 * side by side.
 *
 * Job details for both models come from `useJobDetailForCurrentProject`. The table is
 * rendered inside `ApiResponseLoader`, which receives the combined loading and error state
 * of both requests. Training-config rows (hyperparameter, transform, augmentation) whose
 * values differ between the two models are highlighted.
 */
const ModelComparisonTrainingSettings: React.FC<ModelComparisonTrainingSettingsProps> = ({
  baseline,
  candidate,
  open,
  onClose,
}) => {
  const styles = useStyles();

  const {
    data: baselineDetailsData,
    isLoading: baselineDetailsLoading,
    error: baselineDetailsError,
  } = useJobDetailForCurrentProject(baseline.id);

  const {
    data: candidateDetailsData,
    isLoading: candidateDetailsLoading,
    error: candidateDetailsError,
  } = useJobDetailForCurrentProject(candidate.id);

  const { data: datasetExported } = useDatasetExportedWithVersionsQuery({
    withCount: true,
    includeNotCompleted: true,
    includeFastEasy: true,
  });

  const baselineHyperParams = baselineDetailsData?.hyperParams;
  const candidateHyperParams = candidateDetailsData?.hyperParams;
  const datasetVersions = datasetExported?.datasetVersions;

  // TRANSFORMS_UI_SCHEMA is a shared global object. This dialog only reads labels from it,
  // so one private copy per mount is enough. To change its content, keep it in state
  // (useState / jotai) instead of mutating this copy.
  const transformsUiSchema = React.useMemo(() => cloneDeep(TRANSFORMS_UI_SCHEMA), []);

  /** Returns the fetched job details for the given column. */
  const getJobDetails = (type: ModelType) =>
    type === ModelType.BASELINE ? baselineDetailsData : candidateDetailsData;

  /**
   * Produces the raw value of one table cell.
   *
   * @param data - the model whose column is being filled
   * @param type - which column the model belongs to; selects whose job details are read
   * @param sectionId - id of the table row
   * @returns the value to display or compare: a string/number, the job's
   *   `hyperParams.model`, a merged transform config object, or a fallback string
   *   (`'-'`, `''`, or {@link DEFAULT_CONFIG_DESC})
   */
  const formatData = useCallback(
    (data: RegisteredModelWithThreshold, type: ModelType, sectionId: string) => {
      const hyperParams = type === ModelType.BASELINE ? baselineHyperParams : candidateHyperParams;
      // Used when a job has no saved transform/augmentation config. Customized-regime jobs
      // get an empty cell; all others get the default-config label. `model` is optional-
      // chained so that missing model info cannot throw while the table is rendering.
      const missingConfigFallback =
        hyperParams?.model?.regime === CUSTOMIZED_REGIME ? '' : DEFAULT_CONFIG_DESC;
      const findDatasetVersion = () =>
        datasetVersions?.find(item => item.id === data.datasetVersionId);

      switch (sectionId) {
        case 'modelName':
        case 'threshold':
          return data[sectionId];
        case 'snapshot':
          return findDatasetVersion()?.name ?? '-';
        case 'createdAt':
          return data.createdAt ? format(new Date(data.createdAt), 'MM/dd/yyyy hh:mm a') : '-';
        case 'creator':
          // NOTE(review): this shows the creator of the dataset snapshot, not of the model.
          // Confirm that is what the "Trained by" row should display.
          return findDatasetVersion()?.creator ?? '-';
        case trainingParamsTypes.HYPERPARAMETER:
          return hyperParams?.model ?? '-';
        case trainingParamsTypes.TRANSFORM:
          // Merge the list of config objects into one object; later entries win on
          // duplicate keys.
          return (
            hyperParams?.preprocessingConfig?.reduce(
              (acc: TransformConfig, cur: TransformConfig) => ({ ...acc, ...cur }),
              {},
            ) ?? missingConfigFallback
          );
        case trainingParamsTypes.AUGMENTATION:
          return (
            hyperParams?.augmentationConfig?.reduce(
              (acc: TransformConfig, cur: TransformConfig) => ({ ...acc, ...cur }),
              {},
            ) ?? missingConfigFallback
          );
        default:
          return '-';
      }
    },
    [baselineHyperParams, candidateHyperParams, datasetVersions],
  );

  const tableDataFormatted = ComparisonTableDefaultSections.map(
    section =>
      ({
        ...section,
        baseline: formatData(baseline, ModelType.BASELINE, section.id),
        candidate: formatData(candidate, ModelType.CANDIDATE, section.id),
      } as ComparisonTableSection),
  );

  /** Renders the epoch count and model size of one job. */
  const renderHyperparameter = (hyperparameters: HyperParametersConfig, type: ModelType) => {
    const jobDetails = getJobDetails(type);
    const archName = jobDetails?.archName;
    const modelArch = jobDetails?.modelArch;
    return (
      <>
        <Box className={styles.configItem}>
          <Box className={styles.configLabel}>{t('Epoch')}</Box>
          <Box padding={2} />
          <Box className={styles.configValue}>{hyperparameters.learningParams?.epochs}</Box>
        </Box>
        <Box className={styles.configItem}>
          <Box className={styles.configLabel}>{t('Model size')}</Box>
          <Box padding={2} />
          <Box className={styles.configValue}>
            {archName ??
              getModelArchModelSizeDisplayName(
                modelArch as ModelArch | undefined,
                hyperparameters.backboneParams?.name,
              )}
          </Box>
        </Box>
      </>
    );
  };

  /**
   * Renders a merged transform/augmentation config as label/parameter rows, or passes a
   * fallback string through. The default-config label is rendered in muted text.
   */
  const renderTransforms = (transforms: string | TransformConfig, schema: TransformUISchema) => {
    if (typeof transforms === 'string') {
      return transforms === DEFAULT_CONFIG_DESC ? (
        <Typography className={styles.explanation}>{transforms}</Typography>
      ) : (
        transforms
      );
    }

    return Object.entries(transforms).map(([key, value], index) => {
      // A key absent from the UI schema falls back to its raw name instead of throwing.
      const fieldSchema = schema[key];
      const visibleParams = Object.entries(value as TransformParamValue).filter(
        ([paramName]) =>
          !HIDDEN_TRANSFORM_PARAMS.includes(paramName) &&
          // Probability is shown only for transforms whose schema exposes it.
          !(paramName === 'p' && !fieldSchema?.probability),
      );
      return (
        <Box key={key + index} className={styles.configItem}>
          <Box className={styles.configLabel}>{fieldSchema?.label ?? key}</Box>
          <Box padding={2} />
          <Box>
            {visibleParams.map(([paramName, paramValue]) => (
              <Box key={paramName} className={styles.configValue}>
                {paramName === 'p' ? 'Probability' : paramName.replace(/_/g, ' ')}:{' '}
                {paramValue}
              </Box>
            ))}
          </Box>
        </Box>
      );
    });
  };

  /** Renders one cell of the given row for the given column. */
  const renderSectionValue = (section: ComparisonTableSection, type: ModelType) => {
    const value = type === ModelType.BASELINE ? section.baseline : section.candidate;
    switch (section.id) {
      case trainingParamsTypes.HYPERPARAMETER:
        return renderHyperparameter(value as HyperParametersConfig, type);
      case trainingParamsTypes.TRANSFORM:
        return renderTransforms(value as string | TransformConfig, transformsUiSchema[1]);
      case trainingParamsTypes.AUGMENTATION:
        return renderTransforms(value as string | TransformConfig, transformsUiSchema[2]);
      case 'modelName':
        return <Typography maxWidth={300}>{value}</Typography>;
      default:
        return value;
    }
  };

  return (
    <Dialog open={open} classes={{ paper: styles.dialog }} onClose={onClose}>
      <DialogTitle>
        <Box display="flex" justifyContent="space-between" alignItems="center">
          <Typography variant="h2_semibold">{t('Compare Training Settings')}</Typography>
          <CloseOutlined onClick={onClose} />
        </Box>
      </DialogTitle>
      <DialogContent className={styles.dialogContent}>
        <ApiResponseLoader
          response={{ baselineDetailsData, candidateDetailsData }}
          loading={baselineDetailsLoading || candidateDetailsLoading}
          error={baselineDetailsError || candidateDetailsError}
          defaultHeight={300}
        >
          {_ => (
            <>
              <TableContainer className={styles.tableContainer}>
                <Table className={styles.compareSettingsTable}>
                  <TableHead className={styles.tableHeader}>
                    <TableRow>
                      <TableCell className={styles.sectionHeader}>{''}</TableCell>
                      <TableCell className={styles.modelInfoHeader}>
                        {t('Baseline model')}
                      </TableCell>
                      <TableCell className={styles.modelInfoHeader}>
                        {t('Candidate model')}
                      </TableCell>
                    </TableRow>
                  </TableHead>
                  <TableBody>
                    {tableDataFormatted.map(section => {
                      // Highlight only training-config rows whose values differ.
                      const valueCellClassName = cx(styles.sectionColumn, {
                        [styles.sectionColumnHighlight]:
                          TRAINING_SECTION_IDS.includes(section.id) &&
                          !isEqual(section.baseline, section.candidate),
                      });
                      return (
                        <TableRow key={section.id}>
                          <TableCell component="th" scope="row" className={styles.sectionColumn}>
                            {section.label}
                          </TableCell>
                          <TableCell className={valueCellClassName}>
                            {renderSectionValue(section, ModelType.BASELINE)}
                          </TableCell>
                          <TableCell className={valueCellClassName}>
                            {renderSectionValue(section, ModelType.CANDIDATE)}
                          </TableCell>
                        </TableRow>
                      );
                    })}
                  </TableBody>
                </Table>
              </TableContainer>
              <Box display="flex" justifyContent="flex-end" mt={2}>
                <Typography className={styles.explanation}>
                  {t('Different training configurations are highlighted.')}
                </Typography>
              </Box>
            </>
          )}
        </ApiResponseLoader>
      </DialogContent>
    </Dialog>
  );
};

export default ModelComparisonTrainingSettings;
