import React, { useMemo } from 'react';
import { Box, makeStyles } from '@material-ui/core';
import { useGetSelectedProjectQuery } from '@/serverStore/projects';
import { isModelTrainingHasLearningCurve } from '../../../../store/projectModelInfoState/utils';
import { LineChart, Typography } from '@clef/client-library';
import { round } from 'lodash';
import { Skeleton } from '@material-ui/lab';
import { useJobDetailForCurrentProject } from '@/serverStore/jobs';
import { useModelStatusQuery } from '@/serverStore/projectModels';
import { MetricLabel } from '@clef/shared/types';

// Card-like frame (border, shadow, rounded corners, minimum height) shared by
// both chart container variants below. `as const` keeps the CSS keyword values
// ('relative', 'center') as literal types so they stay assignable to the JSS
// style typings once spread.
const chartCardFrame = {
  position: 'relative',
  border: '1px solid #E6E7E9',
  boxShadow: '0px 1px 2px 1px rgba(48, 55, 79, 0.08)',
  // box-shadow: 0px 1px 2px 0px rgba(48, 55, 79, 0.16);
  borderRadius: 10,
  minHeight: 240,
  textAlign: 'center',
} as const;

const useStyles = makeStyles(theme => ({
  // Default chart container: small top padding plus a bottom margin.
  lineChartBox: {
    ...chartCardFrame,
    paddingTop: '8px',
    marginBottom: theme.spacing(5),
  },
  // Chart container without the bottom margin; pads top, bottom and left instead.
  lineChartBoxWithoutMarginBottom: {
    ...chartCardFrame,
    padding: theme.spacing(2, 0, 2, 2),
  },
  // Vertically centred, rotated Y-axis title pinned to the left edge.
  lineChartYTitle: {
    position: 'absolute',
    left: 0,
    top: '50%',
    transform: 'translateY(-50%) rotate(-90deg)',
  },
  // Slot pinned to the top-right corner of the chart container.
  topRightTag: {
    position: 'absolute',
    top: theme.spacing(2),
    right: theme.spacing(2),
  },
  // 12x12 colour swatch shown next to each legend label.
  legend: {
    borderRadius: 2,
    marginRight: theme.spacing(2),
    width: 12,
    height: 12,
  },
}));

/** Props accepted by `LearningCurveChart`. */
export type LearningCurveChartProps = {
  /** Id of the model whose training metrics are charted; may be `undefined`. */
  modelId: string | undefined;
  /** Optional node rendered in the top-right corner of the chart container. */
  topRightTag?: React.ReactNode;
  /** Forwarded to the chart's `lastPointPulsing` prop. Defaults to `false`. */
  isTraining?: boolean;
  /** Forwarded to the chart's `aspectRatio` prop. Defaults to `1.25`. */
  aspectRatio?: number;
  /** When `true`, uses the container style without a bottom margin. Defaults to `false`. */
  hideMarginBottom?: boolean;
  /** When `false`, renders a colour legend next to the chart. Defaults to `true` (hidden). */
  hideLegends?: boolean;
  /** When `false`, the validation metric series is charted as well. Defaults to `true` (hidden). */
  hideValidationCurve?: boolean;
};

/**
 * Charts the learning curve of a model's training job.
 *
 * The training-loss metric is always plotted (x = sample index). When
 * `hideValidationCurve` is `false`, the first validation metric found
 * (mAP / mIoU / F1) is plotted too, aligned onto the training x-axis by
 * timestamp. While no chart data or epoch count is available, a skeleton
 * placeholder is rendered instead.
 *
 * See `LearningCurveChartProps` for the individual props and their defaults.
 */
const LearningCurveChart: React.FC<LearningCurveChartProps> = ({
  topRightTag,
  isTraining = false,
  aspectRatio = 1.25,
  hideMarginBottom = false,
  modelId: selectedModelId,
  hideLegends = true,
  hideValidationCurve = true,
}) => {
  const styles = useStyles();
  const { id: projectId } = useGetSelectedProjectQuery().data ?? {};
  const { data: modelStatus } = useModelStatusQuery(projectId, selectedModelId);
  const hasLearningCurve = isModelTrainingHasLearningCurve(modelStatus?.status);
  // Only pass a model id to the job-detail query once the status indicates a learning curve.
  const { data: jobDetails } = useJobDetailForCurrentProject(
    hasLearningCurve ? selectedModelId : undefined,
    modelStatus?.status,
  );

  const chartData = useMemo(() => {
    const metrics = jobDetails?.metrics;
    if (!metrics) {
      return [];
    }
    const lossMetrics = metrics.find(item => item.name === MetricLabel.loss_train);
    // Skip the validation lookup entirely when the curve is hidden; it would be discarded anyway.
    const lossValMetrics = hideValidationCurve
      ? undefined
      : metrics.find(
          item =>
            item.name === MetricLabel.mAP_val ||
            item.name === MetricLabel.mIoU_val ||
            item.name === MetricLabel.f1_val,
        );

    // for training loss, x is the epoch, which is the index of the array
    const lossMetricsData = lossMetrics
      ? {
          ...lossMetrics,
          values: lossMetrics.values.map((value, index) => ({ ...value, x: index })),
          color: '#FD6F8E',
        }
      : undefined;

    // Hoisted so it is not re-evaluated for every validation point.
    const trainingValues = lossMetricsData?.values ?? [];

    // Validation always happens after training, so each validation point is placed at the
    // first training sample whose timestamp is at or after the validation timestamp.
    // Points later than every training sample are appended past the end of the training
    // series, offset by their own index.
    const lossValMetricsData = lossValMetrics
      ? {
          ...lossValMetrics,
          values: lossValMetrics.values.map((value, index) => {
            const matchedIndex = trainingValues.findIndex(
              ({ timestamp }) => timestamp >= value.timestamp,
            );
            const x = matchedIndex < 0 ? trainingValues.length + index : matchedIndex;
            return { ...value, x };
          }),
          color: '#FF9800',
        }
      : undefined;

    // Drop whichever series is missing, narrowing the element type without `!` assertions.
    return [lossMetricsData, lossValMetricsData].filter(
      (series): series is NonNullable<typeof series> => series !== undefined,
    );
    // `hideValidationCurve` is read above, so it must be a dependency; otherwise toggling
    // the prop would keep showing the previously memoized series.
  }, [jobDetails, hideValidationCurve]);

  // Prefer the configured epoch count; fall back to the number of samples in the first series.
  const epochs = useMemo(() => {
    const hyperParams = jobDetails?.hyperParams;
    if (!hyperParams) {
      return chartData[0]?.values.length;
    }
    return hyperParams.model.learningParams.epochs;
  }, [chartData, jobDetails]);

  // Nothing to plot yet (or an epoch count of 0/undefined): show a placeholder.
  if (!chartData.length || !epochs) {
    return (
      <Box marginBottom={5}>
        <Skeleton variant="rect" height={240} />
      </Box>
    );
  }

  // NOTE(review): `t` is not imported in the visible part of this file — presumably a
  // globally provided translation helper; confirm.
  return (
    <Box
      className={hideMarginBottom ? styles.lineChartBoxWithoutMarginBottom : styles.lineChartBox}
    >
      {topRightTag && <Box className={styles.topRightTag}>{topRightTag}</Box>}
      <Box position="relative" style={{ transform: 'translateX(-10px)' }}>
        <Box display="flex" alignItems="center" flexDirection="row">
          {/* <Box className={styles.lineChartYTitle}>{t('Error')}</Box> */}
          <LineChart
            labelFormatterX={v => round(v)}
            labelFormatterY={v => round(v, 2)}
            axisTitleX=""
            axisTitleY={t('Error')}
            data={chartData}
            showDataPoints={false}
            lastPointPulsing={isTraining}
            aspectRatio={aspectRatio}
            minX={0}
            // Extend the x-axis if any series already has more samples than the epoch count.
            maxX={Math.max(epochs, ...chartData.map(item => item.values.length))}
          />
          {!hideLegends && (
            <Box pr={4} ml={4}>
              {chartData.map(({ name, color }) => (
                <Box key={name} display="flex" alignItems="baseline" my={1}>
                  <span className={styles.legend} style={{ backgroundColor: color }} />
                  <Typography variant="body2">
                    {name.endsWith('_val') ? t('Dev') : t('Train')}
                  </Typography>
                </Box>
              ))}
            </Box>
          )}
        </Box>
      </Box>
    </Box>
  );
};

export default LearningCurveChart;
