// The following implementation is based on https://github.com/facebookresearch/segment-anything/blob/main/demo/src/components/helpers/onnxModelAPI.tsx

import type { Tensor } from 'onnxruntime-web';

/**
 * Original image dimensions together with the factor that maps original-image
 * coordinates into SAM's resized input space (see `getModelScale`).
 */
export interface ModelScale {
  /** Height of the original image, as passed to `getModelScale`. */
  height: number;
  /** Width of the original image, as passed to `getModelScale`. */
  width: number;
  /** Multiplier applied to original-image prompt coordinates before they are fed to SAM. */
  samScale: number;
}

/**
 * Kind of click prompt. The numeric values are written verbatim into the
 * `point_labels` tensor sent to SAM, so they must not be changed.
 */
export enum ModelInputClickType {
  Add = 1,
  Remove = 0,
}

/** Display colour associated with each click type. */
export const MODEL_INPUT_CLICK_TYPE_COLOR_LOOKUP: Record<ModelInputClickType, string> = {
  [ModelInputClickType.Remove]: '#AD2E24',
  [ModelInputClickType.Add]: '#CFF000',
};

/** A single click prompt placed on the original image. */
export interface ModelInput {
  /** Whether this is an Add or Remove click; forwarded to SAM as the point label. */
  clickType: ModelInputClickType;
  /** Horizontal position in original-image coordinates (multiplied by `samScale` before use). */
  x: number;
  /** Vertical position in original-image coordinates (multiplied by `samScale` before use). */
  y: number;
}

/** Arguments accepted by `getModelFeeds`. */
export interface ModelFeeds {
  /** Precomputed image embedding, forwarded unchanged as the `image_embeddings` feed. */
  tensor: Tensor;
  /** Click prompts in original-image coordinates; when omitted, `getModelFeeds` resolves to `undefined`. */
  prompts?: Array<ModelInput>;
  /** Original image size and the `samScale` factor used to map prompt coordinates. */
  modelScale: ModelScale;
}

/**
 * @deprecated Misspelled name kept for backward compatibility; use `ModelFeeds`.
 */
export type ModeFeeds = ModelFeeds;

/** SAM expects input images resized so that their longest side has this length. */
const LONG_SIDE_LENGTH = 1024;

/**
 * Compute the factor that maps original-image coordinates into SAM's resized
 * input space, in which the longest image side equals `LONG_SIDE_LENGTH`.
 *
 * Reference: https://github.com/facebookresearch/segment-anything/blob/main/demo/src/components/helpers/scaleHelper.tsx
 *
 * @param imageHeight - Height of the original image.
 * @param imageWidth - Width of the original image.
 * @returns The original dimensions plus `samScale`. No validation is done:
 *   a zero-sized image yields `samScale === Infinity`.
 */
export const getModelScale = (imageHeight: number, imageWidth: number): ModelScale => {
  const longestSide = Math.max(imageHeight, imageWidth);
  const scale: ModelScale = {
    height: imageHeight,
    width: imageWidth,
    samScale: LONG_SIDE_LENGTH / longestSide,
  };
  return scale;
};

/**
 * Build the named input feeds for the SAM ONNX mask decoder from a
 * precomputed image embedding and a set of click prompts.
 *
 * Reference: https://github.com/facebookresearch/segment-anything/blob/main/demo/src/components/helpers/onnxModelAPI.tsx
 *
 * @param feeds - Destructured arguments:
 *   - `tensor`: image embedding, forwarded unchanged as `image_embeddings`.
 *   - `prompts`: click prompts in original-image coordinates.
 *   - `modelScale`: original image size plus the `samScale` factor that maps
 *     prompt coordinates into SAM's resized input space.
 * @returns The decoder inputs keyed by input name, or `undefined` when
 *   `prompts` is absent. `onnxruntime-web` is only loaded when prompts exist.
 */
export const getModelFeeds = async ({ prompts, tensor, modelScale }: ModeFeeds) => {
  // Without click prompts there is nothing to decode. Bail out before loading
  // onnxruntime-web or allocating any tensors.
  if (!prompts) return;

  const { Tensor } = await import('onnxruntime-web');

  const n = prompts.length;

  // If there is no box input, a single padding point with label -1 and
  // coordinates (0.0, 0.0) must be appended, so allocate room for n + 1
  // points. With an empty `prompts` array only the padding point is sent.
  const pointCoords = new Float32Array(2 * (n + 1));
  const pointLabels = new Float32Array(n + 1);

  // Scale each prompt into SAM's input space. `forEach` avoids indexed access,
  // which would type as `ModelInput | undefined` under noUncheckedIndexedAccess.
  prompts.forEach(({ x, y, clickType }, i) => {
    pointCoords[2 * i] = x * modelScale.samScale;
    pointCoords[2 * i + 1] = y * modelScale.samScale;
    pointLabels[i] = clickType;
  });

  // The extra padding point: (0, 0) with label -1.
  pointCoords[2 * n] = 0.0;
  pointCoords[2 * n + 1] = 0.0;
  pointLabels[n] = -1.0;

  const pointCoordsTensor = new Tensor('float32', pointCoords, [1, n + 1, 2]);
  const pointLabelsTensor = new Tensor('float32', pointLabels, [1, n + 1]);

  // Original image size as [height, width]; no dims are given, so the Tensor
  // constructor treats it as a 1-D tensor.
  const imageSizeTensor = new Tensor('float32', [modelScale.height, modelScale.width]);

  // There is no previous mask, so default to an empty (all-zero) tensor
  const maskInput = new Tensor('float32', new Float32Array(256 * 256), [1, 1, 256, 256]);
  // There is no previous mask, so default to 0
  const hasMaskInput = new Tensor('float32', [0]);

  return {
    image_embeddings: tensor,
    point_coords: pointCoordsTensor,
    point_labels: pointLabelsTensor,
    orig_im_size: imageSizeTensor,
    mask_input: maskInput,
    has_mask_input: hasMaskInput,
  };
};
