import { DefectDistributionWithAssignment } from '@clef/shared/utils/auto_split_core_algorithm';
import { Defect } from '@clef/shared/types';
import produce from 'immer';
import { PieChartProps, DistributionChartProps, DistributionType } from '@clef/client-library';
import { splitColors } from '@clef/client-library';

/**
 * Two-level numeric table used by the aggregation helpers in this module.
 * The outer key is the grouping (a defect name or a split name); the inner key is a
 * column within that group, and the value is the accumulated total.
 */
type AggregateMappingType = Record<string, Record<string, number>>;

/** Names of the splits handled by auto-split, in train / dev / test order. */
export const supportedAutoSplit: string[] = ['train', 'dev', 'test'];

/**
 * Totals, for every defect, how much of it is assigned to each split.
 *
 * For each row and each defect with a non-zero (truthy) weight
 * `defect_distribution[defect.id]`, this adds `assignment[i] * weight` to
 * `result[defect.name][split_i]`, where splits are indexed train (0), dev (1), test (2).
 * The row's `count` field is not read here.
 *
 * @example
 * // defects = [{ id: 100, name: 'defectA' }, { id: 101, name: 'defectB' }]
 * // [ { count: 59, defect_distribution: { 100: 1, 101: 0 }, assignment: [41, 12, 6] },
 * //   { count: 81, defect_distribution: { 100: 0, 101: 1 }, assignment: [49, 16, 16] } ]
 * // =>
 * // { defectA: { train: 41, dev: 12, test: 6 }, defectB: { train: 49, dev: 16, test: 16 } }
 *
 * @param defectDistributionWithAssignment rows pairing a defect-id -> weight map with per-split amounts
 * @param defects defects to report; each gets an entry (all zeros if it never appears)
 * @returns defect name -> split -> weighted total
 */
export const distributionAssignmentAggregateByDefect = (
  defectDistributionWithAssignment: DefectDistributionWithAssignment[],
  defects: Defect[],
): AggregateMappingType => {
  const splits = ['train', 'dev', 'test'];
  // Built in one pass. Spreading the accumulator inside reduce was O(defects^2).
  // Object.fromEntries defines own properties, just like the computed-key spread it replaces.
  const totals: AggregateMappingType = Object.fromEntries(
    defects.map(({ name }): [string, { [split: string]: number }] => [
      name,
      { train: 0, dev: 0, test: 0 },
    ]),
  );
  // One locally mutated accumulator, instead of an immer `produce` copy per row.
  for (const { defect_distribution, assignment } of defectDistributionWithAssignment) {
    for (const { id, name } of defects) {
      const weight = defect_distribution?.[id];
      // Missing, zero or NaN weights contribute nothing, so skip them.
      if (!weight) continue;
      splits.forEach((split, index) => {
        totals[name][split] += assignment[index] * weight;
      });
    }
  }
  return totals;
};

/**
 * Totals, for every split, how much of each defect is assigned to it. The table is the
 * transpose of `distributionAssignmentAggregateByDefect`'s output.
 *
 * For each row and each defect with a non-zero (truthy) weight
 * `defect_distribution[defect.id]`, this adds `assignment[i] * weight` to
 * `result[split_i][defect.name]`, where splits are indexed train (0), dev (1), test (2).
 * The row's `count` field is not read here.
 *
 * @example
 * // defects = [{ id: 100, name: 'defectA' }, { id: 101, name: 'defectB' }]
 * // [ { count: 59, defect_distribution: { 100: 1, 101: 0 }, assignment: [41, 12, 6] },
 * //   { count: 81, defect_distribution: { 100: 0, 101: 1 }, assignment: [49, 16, 16] } ]
 * // =>
 * // { train: { defectA: 41, defectB: 49 }, dev: { defectA: 12, defectB: 16 }, test: { defectA: 6, defectB: 16 } }
 *
 * @param defectDistributionWithAssignment rows pairing a defect-id -> weight map with per-split amounts
 * @param defects defects to report; each appears (as 0 if never seen) under every split
 * @returns split -> defect name -> weighted total
 */
export const distributionAssignmentAggregateBySplit = (
  defectDistributionWithAssignment: DefectDistributionWithAssignment[],
  defects: Defect[],
): AggregateMappingType => {
  const splits = ['train', 'dev', 'test'];
  // Built in one pass. Spreading the accumulator inside reduce was O(defects^2).
  const zeroes: { [defectName: string]: number } = Object.fromEntries(
    defects.map(({ name }): [string, number] => [name, 0]),
  );
  const totals: AggregateMappingType = {
    train: { ...zeroes },
    dev: { ...zeroes },
    test: { ...zeroes },
  };
  // One locally mutated accumulator, instead of an immer `produce` copy per row.
  for (const { defect_distribution, assignment } of defectDistributionWithAssignment) {
    for (const { id, name } of defects) {
      const weight = defect_distribution?.[id];
      // Missing, zero or NaN weights contribute nothing, so skip them.
      if (!weight) continue;
      splits.forEach((split, index) => {
        totals[split][name] += assignment[index] * weight;
      });
    }
  }
  return totals;
};

/**
 * Converts an `AggregateMappingType` into `DistributionChartProps['distributionData']`.
 *
 * Each outer key becomes one entry `{ name, distributions, caption }`. Its
 * `distributions` list the inner key/value pairs as `{ distributor, value }`, in the
 * object's own-key order.
 *
 * @param aggregatedMapping group -> (distributor -> value) table
 * @param captionMapping optional caption per group name. Only own keys are consulted, so a
 *   group named e.g. "constructor" gets `undefined` rather than an inherited
 *   `Object.prototype` member.
 * @returns one chart entry per group
 */
export const aggregateMappingTypeToDistributionList = (
  aggregatedMapping: AggregateMappingType,
  captionMapping: { [key: string]: React.ReactNode } = {},
): DistributionChartProps['distributionData'] =>
  Object.entries(aggregatedMapping).map(([groupedBy, distributionsMapping]) => ({
    name: groupedBy,
    distributions: Object.entries(distributionsMapping).map(
      ([distributorName, distributionValue]): DistributionType => ({
        distributor: distributorName,
        value: distributionValue,
      }),
    ),
    caption: Object.prototype.hasOwnProperty.call(captionMapping, groupedBy)
      ? captionMapping[groupedBy]
      : undefined,
  }));

/**
 * Sums the per-split assignments of all rows into one pie-chart slice per split.
 *
 * Slice `i` (train, dev, test) has value `sum(row.assignment[i])` over all rows and
 * color `splitColors[i % splitColors.length]`. `defect_distribution` and `count` are
 * not read here.
 *
 * @example
 * // [ { count: 59, defect_distribution: { 100: 1, 101: 0 }, assignment: [41, 12, 6] },
 * //   { count: 81, defect_distribution: { 100: 0, 101: 1 }, assignment: [49, 16, 16] } ]
 * // => [ { name: 'train', value: 90, color }, { name: 'dev', value: 28, color }, { name: 'test', value: 22, color } ]
 *
 * @param defectDistributionWithAssignment rows whose `assignment` is indexed train, dev, test
 * @returns three slices, in train / dev / test order
 */
export const distributionAssignmentToSplitStats = (
  defectDistributionWithAssignment: DefectDistributionWithAssignment[],
): PieChartProps['chartData'] =>
  ['train', 'dev', 'test'].map((name, index) => ({
    name,
    color: splitColors[index % splitColors.length],
    // Plain column sum over rows, in row order. The previous version re-ran immer
    // `produce` once per row, copying the slice array, just to accumulate these totals.
    value: defectDistributionWithAssignment.reduce(
      (sum, { assignment }) => sum + assignment[index],
      0,
    ),
  }));
