import { resizeCanvas } from '@clef/client-library';
import { useMemo } from 'react';
import useImage from 'use-image';
import { imageToImageData } from '../../DataBrowser/utils';

/**
 * Palettes that can be used to colorize the attention heatmap.
 * Each member's string value is identical to its key.
 */
export const enum ColorMap {
  /** Reversed "hot" palette: white for the lowest heat through yellow and red to near-black. */
  hotR = 'hotR',
  /** Near-black through purple and orange to pale yellow; the default palette in this module. */
  inferno = 'inferno',
}

/** Inputs accepted by `useHeatmapCanvas`. */
interface Props {
  /** Source of the main image; its width/height are the target size of the returned heatmap. */
  imageSrc: string;
  /** Source of the Grad-CAM heatmap image; nullish when there is no heatmap to load. */
  attentionHeatmapSrc: string | undefined | null;
  /**
   * Threshold as a fraction of the full heat value (it is multiplied by 255 before comparison,
   * so presumably 0-1). Heat values below it are rendered fully transparent.
   */
  confidenceThreshold: number;
  /** Alpha fraction (multiplied by 255) for pixels at or above the threshold; presumably 0-1. */
  opacity?: number;
  /** Palette used to colorize heat values. */
  colorMap?: ColorMap;
}

/**
 * RGB color stops (each channel 0-255) for every supported colormap, ordered from the lowest
 * heat value to the highest. Each palette has 11 stops, one per bucket produced by
 * `getRGBBasedOnHeatMapData` (heat value / 25, floored).
 *
 * The explicit `Record<ColorMap, ...>` annotation makes the compiler reject a `ColorMap` member
 * without a palette and any stop that is not an RGB triple.
 * NOTE(review): the stops look like samples of matplotlib's `inferno` and `hot_r` — confirm source.
 */
const COLORMAPS: Record<ColorMap, ReadonlyArray<readonly [number, number, number]>> = {
  [ColorMap.inferno]: [
    [0, 0, 3],
    [22, 11, 57],
    [65, 9, 103],
    [106, 23, 110],
    [147, 37, 103],
    [187, 55, 84],
    [220, 80, 57],
    [243, 119, 25],
    [251, 164, 10],
    [245, 215, 69],
    [252, 254, 164],
  ],
  [ColorMap.hotR]: [
    [255, 255, 255],
    [255, 255, 156],
    [255, 255, 54],
    [255, 225, 0],
    [255, 157, 0],
    [255, 89, 0],
    [255, 23, 0],
    [210, 0, 0],
    [144, 0, 0],
    [76, 0, 0],
    [10, 0, 0],
  ],
};

/**
 * Default alpha fraction for heatmap pixels at or above the confidence threshold
 * (used as the `opacity` default of `useHeatmapCanvas`).
 */
const GRADCAM_HEATMAP_OPACITY = 0.4;

/**
 * Map a heat value to the RGB color of the matching colormap stop.
 *
 * The value is bucketed in steps of 25 (0-24 -> 0, 25-49 -> 1, ..., 250-255 -> 10), and the
 * bucket selects a stop in the colormap. Buckets outside the table are clamped to the first/last
 * stop, and NaN maps to the first stop. Without the clamp, values of 275 or more, negative values,
 * or NaN would index past the table and throw a TypeError.
 *
 * @param heatMapValue heat value, expected in 0-255
 * @param colormap palette to sample; defaults to inferno
 * @returns the `{ r, g, b }` channels of the selected stop
 */
const getRGBBasedOnHeatMapData = (heatMapValue: number, colormap: ColorMap = ColorMap.inferno) => {
  const colormapValues = COLORMAPS[colormap];
  const lastIndex = colormapValues.length - 1;
  const bucket = Math.floor(heatMapValue / 25);
  // Math.max/min propagate NaN, so handle it explicitly before clamping.
  const index = Number.isNaN(bucket) ? 0 : Math.min(Math.max(bucket, 0), lastIndex);
  const [r, g, b] = colormapValues[index];
  return { r, g, b };
};

/**
 * Build a colorized Grad-CAM attention heatmap overlay for an image.
 *
 * The Grad-CAM image is received as a 1-channel PNG. After decoding, the first (red) byte of
 * each RGBA pixel is read as the heat value (0-255). Each value is colored through `colorMap`.
 * Values below `confidenceThreshold * 255` become fully transparent; all others get alpha
 * `255 * opacity`. The result is resized to the main image's dimensions.
 *
 * @returns the canvas produced by `resizeCanvas`, or `undefined` when either image is not
 *   (yet) available, the heatmap has zero width or height, or no 2D context can be obtained
 */
export const useHeatmapCanvas = ({
  imageSrc = '',
  attentionHeatmapSrc = '',
  opacity = GRADCAM_HEATMAP_OPACITY,
  confidenceThreshold = 0,
  colorMap = ColorMap.inferno,
}: Props) => {
  // Gradcam image
  const [attentionHeatmap] = useImage(attentionHeatmapSrc || '', 'use-credentials');

  const attentionHeatmapWidth = attentionHeatmap?.width ?? 0;
  const attentionHeatmapHeight = attentionHeatmap?.height ?? 0;

  // Main image: only its dimensions are used, as the resize target
  const [image] = useImage(imageSrc || '', 'use-credentials');

  const imageWidth = image?.width ?? 0;
  const imageHeight = image?.height ?? 0;

  const classificationHeatmapArray = useMemo(
    () =>
      imageToImageData(attentionHeatmap, {
        width: attentionHeatmapWidth,
        height: attentionHeatmapHeight,
      }),
    [attentionHeatmap, attentionHeatmapWidth, attentionHeatmapHeight],
  );

  return useMemo(() => {
    if (!attentionHeatmap || !image) {
      return;
    }
    // Allocating/reading image data with a zero width or height throws IndexSizeError.
    if (attentionHeatmapWidth === 0 || attentionHeatmapHeight === 0) {
      return;
    }

    const offscreenHeatmapCanvas = new OffscreenCanvas(
      attentionHeatmapWidth,
      attentionHeatmapHeight,
    );

    const heatMapContext = offscreenHeatmapCanvas.getContext('2d', {
      desynchronized: true,
    });
    if (!heatMapContext) {
      return;
    }

    // NOTE(review): this only affects drawImage on this context; putImageData below ignores it.
    // Whether the resize step smooths pixels is up to resizeCanvas — confirm.
    heatMapContext.imageSmoothingEnabled = false;

    // A new canvas is fully transparent black, so allocate that blank buffer directly
    // instead of reading it back with getImageData.
    const modifiedHeatMapImage = heatMapContext.createImageData(
      attentionHeatmapWidth,
      attentionHeatmapHeight,
    );
    const pixels = modifiedHeatMapImage.data;

    // Loop-invariant values, hoisted out of the per-pixel loop.
    const thresholdValue = confidenceThreshold * 255;
    const visibleAlpha = 255 * opacity;
    const length = Math.min(classificationHeatmapArray.length, pixels.length);

    // Create 3-channel (plus alpha) image from the heat values in the red channel
    for (let i = 0; i < length; i += 4) {
      const heatMapValue = classificationHeatmapArray[i];

      // get the specified RGB value based on heatMapValue
      const { r, g, b } = getRGBBasedOnHeatMapData(heatMapValue, colorMap);

      pixels[i] = r;
      pixels[i + 1] = g;
      pixels[i + 2] = b;
      pixels[i + 3] = heatMapValue < thresholdValue ? 0 : visibleAlpha;
    }
    heatMapContext.putImageData(modifiedHeatMapImage, 0, 0);

    // resize it to original image
    return resizeCanvas(offscreenHeatmapCanvas, { width: imageWidth, height: imageHeight });
  }, [
    attentionHeatmap,
    image,
    attentionHeatmapWidth,
    attentionHeatmapHeight,
    imageWidth,
    imageHeight,
    classificationHeatmapArray,
    confidenceThreshold,
    opacity,
    colorMap,
  ]);
};
