import { availableModels } from '@/constants'
import { InferenceSession, Tensor, TypedTensor, env } from 'onnxruntime-web'
import npyjs from 'npyjs'

// Directory (relative to the site root) that serves the ONNX Runtime
// WebAssembly binaries.
const ORT_WASM_DIR = '/wasm'

// Map each onnxruntime-web WebAssembly file name to its location under
// ORT_WASM_DIR. The training build ('ort-training-wasm-simd.wasm') is not
// mapped here.
env.wasm.wasmPaths = {
  'ort-wasm-simd-threaded.wasm': `${ORT_WASM_DIR}/ort-wasm-simd-threaded.wasm`,
  'ort-wasm-simd.wasm': `${ORT_WASM_DIR}/ort-wasm-simd.wasm`,
  'ort-wasm-threaded.wasm': `${ORT_WASM_DIR}/ort-wasm-threaded.wasm`,
  'ort-wasm.wasm': `${ORT_WASM_DIR}/ort-wasm.wasm`,
}

export { InferenceSession, Tensor }

/**
 * Pending or settled `InferenceSession` promises, keyed by model name.
 *
 * Created with a null prototype so that names like "constructor" or
 * "toString" are never mistaken for cached entries.
 */
const modelCache: Record<string, Promise<InferenceSession>> = Object.create(null)

/**
 * Returns the ONNX Runtime session for `modelName`, creating it on first use.
 *
 * The creation promise is stored before any `await`, so concurrent callers
 * share a single `InferenceSession.create` call. If a cached load rejected,
 * the error is logged and the load is retried. Callers that observed the same
 * failure share one retry instead of each creating (and orphaning) a session.
 *
 * @param modelName - Key into `availableModels` (imported from `@/constants`).
 * @returns The cached or newly created inference session.
 * @throws Error if `modelName` has no entry in `availableModels`; otherwise
 *   rejects with whatever `InferenceSession.create` rejects with.
 */
export const getModel = async (
  modelName: string,
): Promise<InferenceSession> => {
  const cached = modelCache[modelName]
  if (cached !== undefined) {
    try {
      return await cached
    } catch (error) {
      console.warn(error)
      // Another caller that saw the same failure may already have replaced the
      // entry with a retry; join it rather than starting a second load.
      const retry = modelCache[modelName]
      if (retry !== undefined && retry !== cached) {
        return await retry
      }
    }
  }

  const modelSource = availableModels[modelName]
  if (modelSource === undefined) {
    // Fail fast with a clear message, and do not cache a doomed promise.
    throw new Error(`Unknown model: ${modelName}`)
  }

  // Store the promise synchronously (before awaiting) so later callers reuse it.
  const pending = InferenceSession.create(modelSource)
  modelCache[modelName] = pending
  return await pending
}

/**
 * Downloads a NumPy `.npy` file and wraps its contents in a `float32` ONNX tensor.
 *
 * The tensor's dims are taken verbatim from the `.npy` header. The tensor is
 * built with type 'float32', so the file presumably must contain float32 data.
 * NOTE(review): an earlier version hard-coded dims [1, 256, 64, 64]; confirm
 * this matches what the model expects.
 *
 * @param embeddingUrl - URL of the `.npy` file to fetch.
 * @returns A `float32` tensor with the array's data and shape.
 * @throws Error if the HTTP response is not successful (`res.ok` is false).
 *   Parse and Tensor-construction errors from npyjs / onnxruntime-web propagate.
 */
export const loadTensorEmbedding = async (
  embeddingUrl: string,
): Promise<TypedTensor<'float32'>> => {
  const res = await fetch(embeddingUrl)
  // Without this check an HTTP error page would be handed to the .npy parser,
  // producing a misleading parse error instead of the real cause.
  if (!res.ok) {
    throw new Error(
      `Failed to fetch embedding from ${embeddingUrl}: ${res.status} ${res.statusText}`,
    )
  }
  const segResponseData = await res.arrayBuffer()
  const npLoader = new npyjs()
  const npArray = await npLoader.parse(segResponseData)

  return new Tensor('float32', npArray.data, npArray.shape)
}
