/**
 * A group of media sharing the same per-defect weights, used as input of the auto split algorithm.
 */
export type DefectDistribution = {
  // number of media that belong to this group
  count: number;
  // defect id -> weight of that defect for every media of the group (0 or missing when absent);
  // presumably the number of instances of the defect per media - confirm against the caller
  defect_distribution: { [defectId: number]: number };
};

/**
 * A group of media together with the number of its media assigned to each split.
 */
export type DefectDistributionWithAssignment = {
  // number of media that belong to this group
  count: number;
  // The final DefectDistributionWithAssignment can carry null here for media with no defects;
  // null is not used in the core algorithm.
  defect_distribution: { [defectId: number]: number } | null;
  // number of media assigned to [train, dev, test], in that order
  assignment: [number, number, number];
};

/**
 * Two-level lookup: group key -> column name -> numeric value.
 * In this file the group key is a JSON-serialized defect distribution and the columns are
 * `train`, `dev` and `test`.
 */
export type AggregateMappingType = {
  [groupBy: string]: { [column: string]: number };
};

// A very small value added to both the numerator and the denominator of the per-round target ratio,
// so the division stays defined when the remaining target percentages sum to zero.
const EPSILON = 0.00001;

/**
 * Core algorithm for splitting media into train/dev/test so that, for every defect, the share of the
 * defect landing in each split is as close as possible to the requested target split.
 *
 * Distributions are processed from the most to the least defect-heavy: total defect weight first, then
 * number of distinct defects; remaining ties keep their input order. The returned list follows this
 * processing order, not the input order.
 *
 * A distribution whose weight is zero for every id in `defectIds` cannot influence any defect's split,
 * so it is divided according to the average target share of the listed defects (and goes entirely to
 * train when `defectIds` is empty).
 *
 * @param defectIds - list of defect ids
 * @param defectDistributions - list of all defect distributions and corresponding media count
 * @param targetSplitPerDefect - the target split (in percentage) of train/dev/test for each defect;
 *   a missing entry of a split is treated as 0
 * @returns defectDistributions (sorted as described above) with the train/dev/test assignment of each
 * @example
 * // input
 * defectIds = [100, 101];
 * defectDistributions = [
 *   { count: 59, defect_distribution: { 100: 1, 101: 0 } },
 *   { count: 81, defect_distribution: { 100: 0, 101: 1 } },
 * ];
 * targetSplitPerDefect = { 100: [70, 20, 10], 101: [60, 20, 20] };
 * // output
 * autoSplitCoreAlgorithm(defectIds, defectDistributions, targetSplitPerDefect) = [
 *   { count: 59, defect_distribution: { 100: 1, 101: 0 }, assignment: [41, 12, 6] },
 *   { count: 81, defect_distribution: { 100: 0, 101: 1 }, assignment: [49, 16, 16] },
 * ];
 */
export const autoSplitCoreAlgorithm = (
  defectIds: number[],
  defectDistributions: DefectDistribution[],
  targetSplitPerDefect: {
    [defectId: number]: number[];
  },
): DefectDistributionWithAssignment[] => {
  // train, dev, test
  const SPLIT_COUNT = 3;

  // Sort on a copy, by total defect weight and then by number of distinct defects (both descending).
  // The sort keys are computed once per distribution, and the comparator returns 0 for equal keys so
  // that it is a consistent comparison function and the (stable) sort keeps ties in input order.
  const defectDistributionsSorted = defectDistributions
    .map(distribution => {
      const weights = Object.values(distribution.defect_distribution);
      return {
        distribution,
        totalDefects: weights.reduce((acc, value) => acc + (value ?? 0), 0),
        distinctDefects: weights.reduce((acc, value) => acc + (value ? 1 : 0), 0),
      };
    })
    .sort((a, b) => b.totalDefects - a.totalDefects || b.distinctDefects - a.distinctDefects)
    .map(({ distribution }) => distribution);

  // One [train, dev, test] triple per sorted distribution, filled in round by round
  const finalAssignmentPerDistribution = defectDistributionsSorted.map(
    (): [number, number, number] => [0, 0, 0],
  );

  // There are 3 rounds, each round splits the not-yet-assigned media into targetGroup vs remainingGroup
  // Round 1: targetGroup = train, remainingGroup = dev + test
  // Round 2: from rest of media, targetGroup = dev, remainingGroup = test
  // Round 3: from rest of media, targetGroup = test, remainingGroup = (none), so all assigned to test
  for (let round = 0; round < SPLIT_COUNT; round++) {
    // Share of the still unassigned media that this round's targetGroup should receive, per defect
    const targetGroupPercentagePerDefect: { [defectId: number]: number } = {};
    // Defect weight already given to targetGroup / remainingGroup while walking the distributions
    const accCountPerDefect: {
      [defectId: number]: { targetGroup: number; remainingGroup: number };
    } = {};
    let percentageSum = 0;

    for (const id of defectIds) {
      const targetSplit = targetSplitPerDefect[id];
      // Splits before `round` are already done: only the current and the remaining ratios count
      let remainingRatio = 0;
      for (let index = round; index < targetSplit.length; index++) {
        remainingRatio += targetSplit[index];
      }
      const percentage = ((targetSplit[round] ?? 0) + EPSILON) / (remainingRatio + EPSILON);
      targetGroupPercentagePerDefect[id] = percentage;
      accCountPerDefect[id] = { targetGroup: 0, remainingGroup: 0 };
      percentageSum += percentage;
    }

    // Used for distributions that carry none of the listed defects (see below)
    const fallbackPercentage = defectIds.length > 0 ? percentageSum / defectIds.length : 1;

    defectDistributionsSorted.forEach((distribution, defectDistributionsIndex) => {
      const assignment = finalAssignmentPerDistribution[defectDistributionsIndex];
      // remaining media count = total count - media assigned in previous rounds
      const remainingMediaCount =
        distribution.count - (assignment[0] + assignment[1] + assignment[2]);

      // https://docs.google.com/document/d/1Gbt3gCF4-wNx7zKW6P2X_F3YHDdRQ4ozXKRkyR4ch90
      // Closed-form minimum of the convex quadratic deviation from the target share
      let optimalNumerator = 0;
      let optimalDenominator = 0;
      for (const id of defectIds) {
        const weight = distribution.defect_distribution[id] ?? 0;
        const percentage = targetGroupPercentagePerDefect[id];
        const accumulated = accCountPerDefect[id];
        optimalNumerator +=
          (percentage * (accumulated.remainingGroup + remainingMediaCount * weight) -
            (1 - percentage) * accumulated.targetGroup) *
          weight;
        optimalDenominator += weight * weight;
      }

      // With no weight on any listed defect the quadratic is flat (0 / 0): every assignment is
      // equally good, so fall back to the average target share instead of producing NaN.
      const optimal =
        optimalDenominator > 0
          ? optimalNumerator / optimalDenominator
          : remainingMediaCount * fallbackPercentage;
      // the assigned count must be an integer within [0, remainingMediaCount]
      const optimalRounded = Math.round(Math.max(0, Math.min(optimal, remainingMediaCount)));

      // accumulate the defects of this distribution for the following distributions of the round
      for (const id of defectIds) {
        const weight = distribution.defect_distribution[id] ?? 0;
        accCountPerDefect[id].targetGroup += optimalRounded * weight;
        accCountPerDefect[id].remainingGroup += (remainingMediaCount - optimalRounded) * weight;
      }
      assignment[round] = optimalRounded;
    });
  }

  return defectDistributionsSorted.map((distribution, disIndex) => ({
    ...distribution,
    assignment: finalAssignmentPerDistribution[disIndex],
  }));
};

/**
 * Constrained by the actual distributions, the result of autoSplitCoreAlgorithm will not precisely match
 * the given targetSplitPerDefect. This function calculates, for each defect, the average deviation (in
 * percentage points) of the resulting split from the target split.
 *
 * A defect that does not occur in any distribution has nothing to split, so its deviation is reported
 * as 0 (instead of the NaN a 0 / 0 division would give).
 *
 * @param defectIds - list of defect ids
 * @param defectDistributionsResults - the result of autoSplitCoreAlgorithm
 * @param targetSplitPerDefect - the target split (in percentage) of train/dev/test for each defect
 * @returns average deviation for each defect. e.g. target = [70,20,10]; result = [72,19,9]; average deviation = (2+1+1)/3 = 1.33
 */
export const calculateDeviationFromCoreAlgorithm = (
  defectIds: number[],
  defectDistributionsResults: DefectDistributionWithAssignment[],
  targetSplitPerDefect: {
    [defectId: number]: number[];
  },
): { [defectId: number]: number } => {
  // train, dev, test
  const SPLIT_COUNT = 3;
  const deviationPerDefect: { [defectId: number]: number } = {};

  for (const id of defectIds) {
    // Total weight of the defect, and how much of it ended up in each split
    let totalDefectCount = 0;
    const defectCountPerSplit = [0, 0, 0];
    for (const { count, defect_distribution, assignment } of defectDistributionsResults) {
      // media without defects carry a null distribution and contribute nothing
      const weight = defect_distribution?.[id] ?? 0;
      totalDefectCount += weight * count;
      for (let split = 0; split < SPLIT_COUNT; split++) {
        defectCountPerSplit[split] += assignment[split] * weight;
      }
    }

    if (totalDefectCount === 0) {
      deviationPerDefect[id] = 0;
      continue;
    }

    let totalDeviationFromTarget = 0;
    for (let split = 0; split < SPLIT_COUNT; split++) {
      const resultPercentage = (defectCountPerSplit[split] / totalDefectCount) * 100;
      totalDeviationFromTarget += Math.abs(resultPercentage - targetSplitPerDefect[id][split]);
    }
    deviationPerDefect[id] = totalDeviationFromTarget / SPLIT_COUNT;
  }

  return deviationPerDefect;
};

/**
 * Splits a number of media into [train, dev, test] counts following the given percentages.
 * @param mediaCount - number of media to split
 * @param targetSplitPercentage - target percentages of [train, dev, test]
 * @returns the [train, dev, test] counts, always summing to mediaCount
 */
export const splitValueToNoDefectAssignment = (
  mediaCount: number,
  targetSplitPercentage: [number, number, number],
): [number, number, number] => {
  const [, devPercentage, testPercentage] = targetSplitPercentage;
  const shareOf = (percentage: number): number => (mediaCount * percentage) / 100;

  // Test is computed first and floored: it is the split most likely to be close to zero, hence the
  // most sensitive to rounding. Flooring one value and rounding the other avoids two x.5 values both
  // rounding up and exceeding mediaCount.
  const testCount = Math.floor(shareOf(testPercentage));
  const devCount = Math.round(shareOf(devPercentage));

  // train receives whatever is left
  return [mediaCount - devCount - testCount, devCount, testCount];
};

/**
 * Turns a list of distributions with assignment into a lookup keyed by the JSON-serialized
 * defect distribution. When two entries serialize to the same key, the later one wins.
 * @param defectDistributionWithAssignment - distributions with their train/dev/test assignment
 * @returns mapping from serialized distribution to its { train, dev, test } counts
 * @example
 * // input
 * [
 *   { count: 59, defect_distribution: { 100: 1, 101: 0 }, assignment: [41, 12, 6] },
 *   { count: 81, defect_distribution: { 100: 0, 101: 1 }, assignment: [49, 16, 16] },
 * ]
 * // output
 * {
 *   '{"100":1,"101":0}': { train: 41, dev: 12, test: 6 },
 *   '{"100":0,"101":1}': { train: 49, dev: 16, test: 16 },
 * }
 */
export const assignmentToDistributionToAssignSplitMapping = (
  defectDistributionWithAssignment: DefectDistributionWithAssignment[],
): AggregateMappingType => {
  const splitPerDistribution: AggregateMappingType = {};
  for (const entry of defectDistributionWithAssignment) {
    const distributionKey = JSON.stringify(entry.defect_distribution);
    splitPerDistribution[distributionKey] = {
      train: entry.assignment[0],
      dev: entry.assignment[1],
      test: entry.assignment[2],
    };
  }
  return splitPerDistribution;
};
