import React, { useMemo, useState } from 'react';
import cx from 'classnames';
import {
  DatasetGroupOptions,
  DatasetVersion,
  JobMetricData,
  MetricLabel,
  RegisteredModel,
} from '@clef/shared/types';
import { useJobDetailForCurrentProject } from '@/serverStore/jobs';
import { Box, Paper, Table, TableBody, TableCell, TableRow, makeStyles } from '@material-ui/core';
import { groupBy, includes, isEmpty, startCase, cloneDeep, upperFirst } from 'lodash';
import { useGetDatasetStatsQuery } from '@/serverStore/dataset';
import { defaultSelectOptions } from '@/constants/data_browser';
import { TRANSFORMS_UI_SCHEMA } from '@/constants/model_train';
import { useUserInfo } from '@/hooks/api/useUserApi';
import { MediaSplitName, SplitMapping, splitColorMap } from '@/constants/stats_card';
import { ApiResponseLoader, Button, DistributionChart, Typography } from '@clef/client-library';
import { useGetSelectedProjectQuery } from '@/serverStore/projects';
import SnapshotDetailItem from '@/pages/DatasetSnapshot/SnapshotSummary/SnapshotDetailItem';
import LoadingProgress from '../LoadingProgress';
import { useMaglevGetJobLogsApi } from '@/hooks/api/useMaglevApi';
import { LazyLog } from 'react-lazylog';
import { isModelTrainingFailed } from '@/store/projectModelInfoState/utils';
import ImageDetailsDialog from './ImageDetailsDialog';
import MetricsChart from '../../MetricsChart';
import { DISPLAY_SUMMARY_EVAL_METRICS } from '../../JobDetailsPage';
import { getModelArchModelSizeDisplayName } from '@/utils/job_train_utils';
import { ModelArch } from '@clef/shared/types/model_arch';
import { useIsFeatureEnabledAndMayHideForSnowflake } from '@/hooks/useFeatureGate';

/** Metrics plotted in the "Loss Chart" section of the panel. */
export const DISPLAY_LOSS_METRICS: MetricLabel[] = [MetricLabel.loss_train];

/**
 * Metrics plotted in the "Validation Chart" section of the panel. Metrics whose names differ only
 * by a `_val` suffix are grouped into the same chart (see getFormattedMetrics).
 */
export const DISPLAY_VALIDATION_METRICS: MetricLabel[] = [
  MetricLabel.loss_val,
  MetricLabel.mIoU_val,
  MetricLabel.mAP_val,
  MetricLabel.f1_val,
];

/** Width in px of the metrics-chart container and the max width of each individual chart. */
const CHART_WIDTH = 540;

/** Props accepted by TrainingInformationPanel. */
interface TrainingInformationPanelProps {
  /** Model whose training job details are shown. */
  model: RegisteredModel;
  /**
   * Data snapshot the model was trained from. The panel also tolerates it being absent at runtime
   * and then shows "Deleted data snapshot".
   */
  datasetVersion: DatasetVersion;
  /** When true, the loss/validation chart section is not rendered. Defaults to false. */
  hideCharts?: boolean;
}

/**
 * JSS styles shared by TrainingInformationPanel, TransformTable and AugmentationConfigTable.
 *
 * NOTE(review): rule order is significant when several classes are applied to the same element
 * and set the same property. For example, `infoBlock` and `statsInfo` are combined on the split
 * box and both set `minWidth`; `statsInfo` is declared later and presumably wins. Confirm before
 * reordering keys.
 * `theme.palette.greyModern` and `theme.boxShadow` are project-specific theme extensions.
 */
const useStyles = makeStyles(theme => ({
  // Grey background wrapper around the loss and validation chart columns.
  chartsContainer: {
    padding: theme.spacing(8, 5, 3, 5),
    marginBottom: theme.spacing(8),
    backgroundColor: theme.palette.greyModern[50],
  },
  // One chart column (title + charts); the last column gets no right margin.
  chartAndTitleContainer: {
    marginRight: theme.spacing(5),
    marginBottom: theme.spacing(5),
    '&:last-child': {
      marginRight: 0,
    },
  },
  chartTitle: {
    paddingBottom: theme.spacing(2),
  },
  // NOTE(review): not referenced anywhere in this file; candidate for removal.
  chartContainer: {
    width: 480,
    borderRadius: 6,
    backgroundColor: theme.palette.common.white,
    boxShadow: theme.boxShadow.default,
  },
  metricsChartContainer: {
    width: CHART_WIDTH,
    borderRadius: 6,
  },
  // Left-hand label column of the main table ("Trained from", "Split", ...).
  tableFirstLevelHeader: {
    fontWeight: 'bold',
    width: 160,
  },
  // Label column of the nested augmentation table.
  tableSecondLevelHeader: {
    width: 120,
  },
  tableSecondLevelContent: {
    color: theme.palette.greyModern[500],
  },
  // Base cell style: no borders, bottom spacing only, top-aligned so multi-row cells line up.
  commonTableCell: {
    padding: theme.spacing(0, 0, 6, 0),
    border: 'none',
    verticalAlign: 'top',
  },
  // Muted inline value rendered after a label (e.g. "Epoch <value>").
  tableCellContentValue: {
    color: theme.palette.greyModern[500],
    paddingLeft: theme.spacing(2),
  },
  infoBlock: {
    minWidth: 200,
    maxWidth: 600,
    height: 130,
    borderRadius: 6,
  },
  // Split-distribution box; relative positioning anchors the absolutely-positioned detailTitle.
  statsInfo: {
    flex: 1,
    minWidth: 270,
    position: 'relative',
    padding: theme.spacing(7, 12, 3, 0),
  },
  detailTitle: {
    position: 'absolute',
    top: theme.spacing(4),
    fontSize: 14,
    fontWeight: 700,
  },
  // NOTE(review): not referenced anywhere in this file; candidate for removal.
  sectionTitle: {
    fontWeight: 500,
  },
}));

/**
 * Renders the "Transform" rows of the training-information table for the displayed
 * preprocessing transforms (Resize, RescaleWithPadding, Crop).
 *
 * Each present transform gets its own row. The "Transform" header cell sits in the first row and
 * spans exactly the rows that are rendered. Returns null when none of these transforms are present.
 *
 * @param props.transformMap merged preprocessing config keyed by transform name
 *   (e.g. `{ Resize: { height, width } }`).
 */
const TransformTable = (props: { transformMap: any }) => {
  const styles = useStyles();
  const { transformMap } = props;
  const {
    Resize: resize,
    RescaleWithPadding: rescaleWithPadding,
    Crop: crop,
  } = transformMap ?? {};

  // Collect the rows first. The header's rowSpan is then the number of rows actually emitted,
  // not the number of keys in transformMap, which may contain transforms that are not displayed.
  const rows: { key: string; content: React.ReactNode }[] = [];

  if (resize) {
    rows.push({
      key: 'resize',
      content: t('Resize {{resizeText}}', {
        resizeText: (
          <Typography display="inline" className={styles.tableCellContentValue}>
            {t('Height: {{height}} x Width: {{width}}', {
              height: resize.height,
              width: resize.width,
            })}
          </Typography>
        ),
      }),
    });
  }

  if (rescaleWithPadding) {
    rows.push({
      key: 'rescaleWithPadding',
      content: t('Rescale {{rescaleWithPadding}}', {
        rescaleWithPadding: (
          <Typography display="inline" className={styles.tableCellContentValue}>
            {t('Height: {{height}} x Width: {{width}} | Padding Value: {{paddingValue}}', {
              height: rescaleWithPadding.height,
              width: rescaleWithPadding.width,
              paddingValue: rescaleWithPadding.padding_value,
            })}
          </Typography>
        ),
      }),
    });
  }

  if (crop) {
    rows.push({
      key: 'crop',
      content: t('Crop {{cropText}}', {
        cropText: (
          <Typography display="inline" className={styles.tableCellContentValue}>
            {t('X Min: {{xMin}}, Y Min: {{yMin}}, X Max: {{xMax}}, Y Max: {{yMax}}', {
              xMin: crop.x_min,
              yMin: crop.y_min,
              xMax: crop.x_max,
              yMax: crop.y_max,
            })}
          </Typography>
        ),
      }),
    });
  }

  if (rows.length === 0) {
    return null;
  }

  return (
    <>
      {rows.map(({ key, content }, index) => (
        <TableRow key={key}>
          {/* The header cell lives only in the first row and spans all transform rows. */}
          {index === 0 && (
            <TableCell
              className={cx(styles.tableFirstLevelHeader, styles.commonTableCell)}
              padding="none"
              component="th"
              scope="row"
              rowSpan={rows.length}
            >
              {t('Transform')}
            </TableCell>
          )}
          <TableCell className={styles.commonTableCell}>{content}</TableCell>
        </TableRow>
      ))}
    </>
  );
};

/**
 * Renders a nested table with one row per augmentation step. Each row shows:
 * - the step's display label, taken from TRANSFORMS_UI_SCHEMA and falling back to the raw key;
 * - its probability `p`;
 * - every other parameter except `always_apply`.
 *
 * NOTE(review): `config` is read as a list of single-key objects, e.g.
 * `[{ HorizontalFlip: { p: 0.5 } }]`. This shape is inferred from the access pattern below;
 * confirm it against the API.
 */
const AugmentationConfigTable = (props: { config: any }) => {
  const { config } = props;

  // TRANSFORMS_UI_SCHEMA is a global object, so work on a copy; it is only read here (labels).
  // The copy is made once per mount rather than on every render.
  // If you need to change its content, please use useState or jotai.
  const transformsUiSchema = useMemo(() => cloneDeep(TRANSFORMS_UI_SCHEMA), []);

  const styles = useStyles();
  return (
    <Table>
      <TableBody>
        {config.map((augmentation: any) => {
          const configName = Object.keys(augmentation)[0];
          // Tolerate a missing/null parameter object instead of crashing the whole panel.
          const params = augmentation[configName] ?? {};
          const probability = params.p;
          const metricsExcludeProbability: { metricsName: string; metricsValue: number }[] =
            Object.entries(params)
              .filter(([name]) => !['always_apply', 'p'].includes(name))
              .map(([metricsName, metricsValue]: [string, any]) => ({
                metricsName,
                metricsValue,
              }));
          // Augmentations unknown to the UI schema fall back to their raw config name.
          const configLabel = transformsUiSchema[2]?.[configName]?.label ?? configName;
          return (
            <TableRow key={configName}>
              <TableCell
                className={cx(styles.tableSecondLevelHeader, styles.commonTableCell)}
                padding="none"
                component="th"
                scope="row"
              >
                {t('{{configName}}', {
                  configName: configLabel,
                })}
              </TableCell>
              <TableCell
                className={cx(styles.commonTableCell, styles.tableSecondLevelContent)}
                padding="none"
                component="th"
                scope="row"
              >
                {t('Probability {{probability}}{{otherConfigsText}}', {
                  probability,
                  otherConfigsText: metricsExcludeProbability.length ? (
                    <Typography display="inline">
                      {t(' | {{otherConfigs}}', {
                        otherConfigs: metricsExcludeProbability.map(
                          ({ metricsName, metricsValue }, index) => (
                            <span key={index}>
                              {t('{{metricsNameFormatted}}: {{metricsValue}}{{separator}}', {
                                metricsNameFormatted: startCase(metricsName),
                                metricsValue,
                                separator: index < metricsExcludeProbability.length - 1 ? ', ' : '',
                              })}
                            </span>
                          ),
                        ),
                      })}
                    </Typography>
                  ) : (
                    ''
                  ),
                })}
              </TableCell>
            </TableRow>
          );
        })}
      </TableBody>
    </Table>
  );
};

/**
 * True when `metric` has exactly one value and its name is listed in
 * DISPLAY_SUMMARY_EVAL_METRICS. Such metrics are presumably one-off summary evaluation results
 * rather than a time series, so the charts skip them.
 */
const isEvalMetric = (metric: JobMetricData) =>
  metric.values.length === 1 && includes(DISPLAY_SUMMARY_EVAL_METRICS, metric.name);

/**
 * Selects the metrics named in `labelArr` and groups them into chart buckets.
 *
 * A metric whose name ends in `_val` lands in the same bucket as its base name, so train and
 * validation series share a chart. Single-value summary eval metrics are excluded.
 *
 * @param metrics job metrics; `undefined` yields an empty object.
 * @param labelArr metric names to keep.
 * @returns an object mapping chart name to the metrics plotted in that chart.
 */
const getFormattedMetrics = (metrics: JobMetricData[] | undefined, labelArr: MetricLabel[]) => {
  // MetricLabel.old_mean_iou_train and MetricLabel.mIOU carry the same value, so the
  // single-point eval variant has to be filtered out explicitly.
  const shouldPlot = (metric: JobMetricData) =>
    includes(labelArr, metric.name) && !isEvalMetric(metric);

  // "loss_val" and "loss" map to the same chart key.
  const toChartName = (metric: JobMetricData) => metric.name.replace(/_val$/, '');

  const plotted = (metrics ?? []).filter(shouldPlot);
  return groupBy(plotted, toChartName);
};

/**
 * Read-only summary of a model's training job. It shows:
 * - loss and validation charts, unless `hideCharts` is set;
 * - the source data snapshot and its split distribution, with a "View Images" dialog;
 * - when and by whom the model was trained;
 * - hyperparameters, preprocessing transforms and augmentations;
 * - for failed trainings, when the feature is enabled, the raw training logs.
 *
 * Renders a LoadingProgress while the job details are loading.
 */
const TrainingInformationPanel = (props: TrainingInformationPanelProps) => {
  const { model, datasetVersion, hideCharts = false } = props;
  const styles = useStyles();
  const { data: project } = useGetSelectedProjectQuery();
  const { data: jobDetails, isLoading: isJobDetailLoading } = useJobDetailForCurrentProject(
    model.id,
  );
  const { metrics, createdAt, creatorId, hyperParams, modelArch, archName } = jobDetails || {};
  const { data: mediaGroupBySplit, isLoading: isDatasetStatsLoading } = useGetDatasetStatsQuery({
    selectOptions: defaultSelectOptions,
    groupOptions: [DatasetGroupOptions.SPLIT],
    version: datasetVersion?.version,
  });
  const isTrainingLogEnabled = useIsFeatureEnabledAndMayHideForSnowflake().trainingLogs;
  // The "Logs" row is only rendered for failed trainings with the feature enabled. Request the
  // logs only in that case instead of fetching them for every model shown in this panel.
  const showTrainingLogs = isTrainingLogEnabled && isModelTrainingFailed(model.status);
  const [jobLogs, jobLogsLoading, jobLogsError] = useMaglevGetJobLogsApi(
    project && showTrainingLogs ? { projectId: project.id, jobId: model.id } : undefined,
  );
  const [imageDetailsDialogOpen, setImageDetailsDialogOpen] = useState<boolean>(false);

  // Group metrics per chart once per metrics change, not on every re-render
  // (e.g. when the image dialog is toggled).
  const trainingMetrics = useMemo(
    () => getFormattedMetrics(metrics, DISPLAY_LOSS_METRICS),
    [metrics],
  );
  const validationMetrics = useMemo(
    () => getFormattedMetrics(metrics, DISPLAY_VALIDATION_METRICS),
    [metrics],
  );

  // One entry per SplitMapping split; the count comes from the dataset stats
  // (a null split counts as "Unassigned"), 0 when absent.
  const mediaSplitFormatted = useMemo(() => {
    return Object.entries(SplitMapping).map(([split, name]) => ({
      name,
      value:
        mediaGroupBySplit?.find(item => {
          if (item.split === null && split === MediaSplitName.Unassigned) return true;
          return item.split === split;
        })?.count ?? 0,
    }));
  }, [mediaGroupBySplit]);

  // NOTE(review): the charts below receive `Number(createdAt)` as startTime while this line
  // passes createdAt to `new Date()`. Confirm createdAt is an epoch-ms number; for an ISO
  // string Number() would yield NaN.
  const trainedAt = createdAt ? new Date(createdAt).toLocaleString() : '';
  const users = useUserInfo(creatorId);

  const epochs = hyperParams?.model.learningParams.epochs;
  const modelSize = hyperParams?.model.backboneParams.name;

  // preprocessingConfig is a list of objects that are merged into one map keyed by transform
  // name (presumably one transform per entry, e.g. [{ Resize: {...} }, { Crop: {...} }]).
  const transformMap: Record<string, any> = useMemo(() => {
    const transform = hyperParams?.preprocessingConfig;
    return (
      (transform &&
        transform.reduce((acc: Record<string, any>, cur: any) => {
          return { ...acc, ...cur };
        }, {} as Record<string, any>)) ??
      {}
    );
  }, [hyperParams]);

  const augmentationConfig = hyperParams?.augmentationConfig;

  const onViewImagesClicked = () => {
    setImageDetailsDialogOpen(true);
  };

  return isJobDetailLoading ? (
    <LoadingProgress />
  ) : (
    <Box display="flex" flexDirection="column">
      {!hideCharts && (
        <Box display="flex" flexDirection="row" flexWrap="wrap" className={styles.chartsContainer}>
          <Box
            display="flex"
            flexDirection="column"
            alignItems="flex-start"
            justifyContent="flex-start"
            className={styles.chartAndTitleContainer}
          >
            <Box className={styles.chartTitle}>
              <Typography variant="body_bold">{t('Loss Chart')}</Typography>
            </Box>
            {!isEmpty(trainingMetrics) ? (
              <Box display="flex" flexWrap="wrap" className={styles.metricsChartContainer}>
                {Object.entries(trainingMetrics).map(([metricGroupName, metrics]) => {
                  return (
                    <Box maxWidth={CHART_WIDTH} mt={2} key={metricGroupName} mr={4} width="100%">
                      <Paper>
                        <MetricsChart
                          name={metricGroupName}
                          metrics={metrics}
                          startTime={Number(createdAt)}
                        />
                      </Paper>
                    </Box>
                  );
                })}
              </Box>
            ) : (
              <Box mt={2}>
                <Typography color="textSecondary">{t('No metrics available')}</Typography>
              </Box>
            )}
          </Box>
          <Box
            display="flex"
            flexDirection="column"
            alignItems="flex-start"
            justifyContent="flex-start"
            className={styles.chartAndTitleContainer}
          >
            <Box className={styles.chartTitle}>
              <Typography variant="body_bold">{t('Validation Chart')}</Typography>
            </Box>
            {!isEmpty(validationMetrics) ? (
              <Box display="flex" flexWrap="wrap" className={styles.metricsChartContainer}>
                {Object.entries(validationMetrics).map(([metricGroupName, metrics]) => {
                  return (
                    <Box maxWidth={CHART_WIDTH} mt={2} key={metricGroupName} mr={4} width="100%">
                      <Paper>
                        <MetricsChart
                          name={metricGroupName}
                          metrics={metrics}
                          startTime={Number(createdAt)}
                          customConfigs={{
                            axisTitleY:
                              metricGroupName === 'f1'
                                ? upperFirst(metricGroupName)
                                : metricGroupName,
                          }}
                        />
                      </Paper>
                    </Box>
                  );
                })}
              </Box>
            ) : (
              <Box mt={2}>
                <Typography color="textSecondary">
                  {t(
                    'No data available. The Validation Chart only displays when the Dev set has at least 6 images.',
                  )}
                </Typography>
              </Box>
            )}
          </Box>
        </Box>
      )}
      <Table>
        {/* trained from */}
        <TableBody>
          <TableRow>
            <TableCell
              className={cx(styles.tableFirstLevelHeader, styles.commonTableCell)}
              padding="none"
              component="th"
              scope="row"
            >
              {t('Trained from')}
            </TableCell>
            <TableCell className={styles.commonTableCell}>
              {datasetVersion?.name ?? t('Deleted data snapshot')}
            </TableCell>
          </TableRow>
          {/* Split */}
          {datasetVersion && (
            <TableRow>
              <TableCell
                className={cx(styles.tableFirstLevelHeader, styles.commonTableCell)}
                padding="none"
                component="th"
                scope="row"
              >
                {t('Split')}
              </TableCell>
              <TableCell className={styles.commonTableCell}>
                {isDatasetStatsLoading ? (
                  <LoadingProgress size={20} alignItems="flex-start" />
                ) : (
                  mediaGroupBySplit && (
                    <Box
                      display="flex"
                      flexDirection="row"
                      alignItems="center"
                      justifyContent="flex-start"
                    >
                      <Box className={cx(styles.infoBlock, styles.statsInfo)}>
                        <Box className={styles.detailTitle}>
                          {t('Split distribution on labeled images')}
                        </Box>
                        <DistributionChart
                          distributionData={[
                            {
                              distributions: mediaSplitFormatted.map(item => ({
                                distributor: item.name,
                                value: item.value,
                              })),
                            },
                          ]}
                          distributorColorMap={splitColorMap}
                          hideLabel
                          bandWidth={8}
                          compact
                          size="small"
                        />
                        {mediaSplitFormatted.map((item, index) => (
                          <SnapshotDetailItem key={index} item={item} index={index} />
                        ))}
                      </Box>
                      <Button
                        variant="outlined"
                        id={'training-info-panel-view-images'}
                        onClick={onViewImagesClicked}
                      >
                        {t('View Images')}
                      </Button>
                    </Box>
                  )
                )}
              </TableCell>
            </TableRow>
          )}
          {/* trained at */}
          <TableRow>
            <TableCell
              className={cx(styles.tableFirstLevelHeader, styles.commonTableCell)}
              padding="none"
              component="th"
              scope="row"
            >
              {t('Trained at')}
            </TableCell>
            <TableCell className={styles.commonTableCell}>{trainedAt}</TableCell>
          </TableRow>
          {/* trained by */}
          <TableRow>
            <TableCell
              className={cx(styles.tableFirstLevelHeader, styles.commonTableCell)}
              padding="none"
              component="th"
              scope="row"
            >
              {t('Trained by')}
            </TableCell>
            <TableCell className={styles.commonTableCell}>
              <Box>{users ? `${users.name} ${users.lastName}` : null}</Box>
            </TableCell>
          </TableRow>
          {/* Hyperparameter: header spans the "Epoch" and "Model size" rows */}
          <TableRow>
            <TableCell
              className={cx(styles.tableFirstLevelHeader, styles.commonTableCell)}
              padding="none"
              component="th"
              scope="row"
              rowSpan={2}
            >
              {t('Hyperparameter')}
            </TableCell>
            <TableCell className={styles.commonTableCell}>
              {t('Epoch {{epochs}}', {
                epochs: (
                  <Typography display="inline" className={styles.tableCellContentValue}>
                    {t('{{epochs}}', { epochs })}
                  </Typography>
                ),
              })}
            </TableCell>
          </TableRow>
          <TableRow>
            <TableCell className={styles.commonTableCell}>
              {t('Model size {{modelSizeArchText}}', {
                modelSizeArchText: (
                  <Typography display="inline" className={styles.tableCellContentValue}>
                    {archName ??
                      getModelArchModelSizeDisplayName(modelArch as ModelArch, modelSize)}
                  </Typography>
                ),
              })}
            </TableCell>
          </TableRow>
          {/* Transform */}
          <TransformTable transformMap={transformMap} />
          {/* Augmentation */}
          {augmentationConfig && (
            <TableRow>
              <TableCell
                className={cx(styles.tableFirstLevelHeader, styles.commonTableCell)}
                padding="none"
                component="th"
                scope="row"
              >
                {t('Augmentation')}
              </TableCell>
              <TableCell className={styles.commonTableCell}>
                <AugmentationConfigTable config={augmentationConfig} />
              </TableCell>
            </TableRow>
          )}
          {/* Logs: same condition that gates the job-log request above */}
          {showTrainingLogs && (
            <TableRow>
              <TableCell
                className={cx(styles.tableFirstLevelHeader, styles.commonTableCell)}
                padding="none"
                component="th"
                scope="row"
              >
                {t('Logs')}
              </TableCell>
              <TableCell className={styles.commonTableCell}>
                <ApiResponseLoader response={jobLogs} loading={jobLogsLoading} error={jobLogsError}>
                  {loadedResponse => (
                    <LazyLog
                      text={loadedResponse}
                      height={600}
                      lineClassName="cy-view-details-job-log"
                    />
                  )}
                </ApiResponseLoader>
              </TableCell>
            </TableRow>
          )}
        </TableBody>
      </Table>
      <ImageDetailsDialog
        open={imageDetailsDialogOpen}
        onClose={() => {
          setImageDetailsDialogOpen(false);
        }}
        modelId={model.id}
        datasetVersion={datasetVersion?.version}
      />
    </Box>
  );
};

export default TrainingInformationPanel;
