import { InferenceSession, Tensor } from 'onnxruntime-web';
import { useState, useEffect, useCallback, useRef } from 'react';
import { loadNpyTensor } from '../../../utils/npy';
import { getModelScale, ModelInput, getModelFeeds } from '../../../utils/sam_model';
import { ModelScale } from '../creators/useSAMAnnotationCreator';
import { filterSmallArea } from '../utils';

/**
 * Inputs for the `useSAM` hook.
 */
export interface UseSAMProps {
  /**
   * Height and width of the source image, used to derive the model input scale
   * via `getModelScale`. `null`/`undefined` while the size is not known.
   */
  imageSize: { height: number; width: number } | null | undefined;
  /** Serialized ONNX model, handed to `InferenceSession.create` once it is provided. */
  onnxModel: ArrayBuffer | undefined;
  /** Precomputed image embedding buffer, decoded by `loadNpyTensor` as a float32 tensor. */
  imageEmbedding: ArrayBuffer | undefined;
  /** Receives errors raised while loading the model or embedding, or while running inference. */
  onError?: (error: any) => void;
}

/**
 * React hook that prepares everything needed to run the SAM ONNX model on prompts.
 *
 * - Dynamically imports `onnxruntime-web` and creates an inference session from `onnxModel`
 *   (only once; an existing session is kept).
 * - Decodes `imageEmbedding` into a float32 tensor via `loadNpyTensor`.
 * - Derives the model input scale from `imageSize` via `getModelScale`.
 *
 * Errors from any of these steps, and from inference, are forwarded to `onError`
 * instead of being thrown.
 *
 * @returns `{ runModel }`. `runModel(prompts, enableSmallAreaFiltering = false)` resolves to
 *   the model's first output (passed through `filterSmallArea` when filtering is enabled).
 *   It resolves to `null` when the prompts, session, embedding tensor or scale are not
 *   available yet, when no feeds are produced, when the output is missing, or when
 *   inference fails.
 */
export const useSAM = ({ imageSize, onnxModel, imageEmbedding, onError }: UseSAMProps) => {
  const [modelSession, setModelSession] = useState<InferenceSession | null>(null);
  const [imageEmbeddingTensor, setImageEmbeddingTensor] = useState<Tensor | null>(null);

  // The ONNX model expects the input to be rescaled to 1024.
  // The modelScale state variable keeps track of the scale values.
  const [modelScale, setModelScale] = useState<ModelScale | null>(null);

  // Always call the latest onError from async callbacks without re-running effects
  // whenever the caller passes a new callback identity.
  const onErrorRef = useRef(onError);
  onErrorRef.current = onError;

  const imageHeight = imageSize?.height;
  const imageWidth = imageSize?.width;

  useEffect(() => {
    // Reset to null when the size is unknown so runModel cannot reuse the
    // previous image's scale.
    setModelScale(
      imageHeight !== undefined && imageWidth !== undefined
        ? getModelScale(imageHeight, imageWidth)
        : null,
    );
  }, [imageHeight, imageWidth]);

  // Load SAM model
  useEffect(() => {
    if (!onnxModel || modelSession) {
      return undefined;
    }

    // Set when this load is superseded (onnxModel changed) or the component unmounted,
    // so a late result does not overwrite newer state.
    let cancelled = false;

    // The runtime is imported dynamically (presumably to keep it out of the initial
    // bundle; the top-level import is only used for types here). A single flat chain
    // ensures a failed import is reported as well as a failed session creation.
    // NOTE(review): a session created by a cancelled load is dropped without being
    // released; consider releasing it if the runtime version supports that.
    import('onnxruntime-web')
      .then(({ InferenceSession }) => InferenceSession.create(onnxModel))
      .then(session => {
        if (!cancelled) {
          setModelSession(session);
        }
      })
      .catch(error => {
        if (!cancelled) {
          onErrorRef.current?.(error);
        }
      });

    return () => {
      cancelled = true;
    };
  }, [modelSession, onnxModel]);

  // Load image embedding
  useEffect(() => {
    let cancelled = false;

    // Drop the previous image's tensor right away so runModel never combines it
    // with the new image's scale while the new embedding is still decoding.
    setImageEmbeddingTensor(null);

    if (imageEmbedding) {
      loadNpyTensor(imageEmbedding, 'float32')
        .then(tensor => {
          // Ignore results from an embedding that has since been replaced.
          if (!cancelled) {
            setImageEmbeddingTensor(tensor);
          }
        })
        .catch(error => {
          if (!cancelled) {
            onErrorRef.current?.(error);
          }
        });
    }

    return () => {
      cancelled = true;
    };
  }, [imageEmbedding]);

  // Run the model on the prompt
  const runModel = useCallback(
    async (prompts: ModelInput[] | null, enableSmallAreaFiltering: boolean = false) => {
      if (!prompts || !modelSession || !imageEmbeddingTensor || !modelScale) {
        return null;
      }

      try {
        const feeds = await getModelFeeds({
          tensor: imageEmbeddingTensor,
          prompts,
          modelScale,
        });
        if (!feeds) {
          return null;
        }

        const results = await modelSession.run(feeds);
        const outputName = modelSession.outputNames[0];
        const result = outputName === undefined ? undefined : results[outputName];
        // Guard against a missing output rather than handing undefined to filterSmallArea.
        if (!result) {
          return null;
        }

        return enableSmallAreaFiltering ? filterSmallArea(result) : result;
      } catch (error) {
        onErrorRef.current?.(error);
        return null;
      }
    },
    [imageEmbeddingTensor, modelSession, modelScale],
  );

  return { runModel };
};
