#ifndef ONNXSAM3_H #define ONNXSAM3_H #pragma once #include "onnxruntime_cxx_api.h" #include "opencv2/opencv.hpp" #include "EPLoader.h" #include #include #include #ifndef ONNXENGINE_API #ifdef ENGINE_EXPORTS #define ONNXENGINE_API __declspec(dllexport) #else #define ONNXENGINE_API __declspec(dllimport) #endif #endif namespace ANSCENTER { /// Result from SAM3 segmentation inference. struct SAM3Result { cv::Rect box; // bounding box float confidence = 0.0f; // detection confidence score cv::Mat mask; // binary mask (within bounding box) std::vector polygon; // simplified contour polygon }; /// SAM3 engine using 3 separate ONNX Runtime sessions: /// 1) Image encoder — produces backbone features + position encodings /// 2) Language encoder — produces text attention mask + text features /// 3) Decoder — combines image features + language features → boxes, scores, masks /// /// This architecture avoids the CUDA EP crash that occurs with the /// monolithic 3.3 GB model, since each sub-model is under 2 GB. class ONNXENGINE_API ONNXSAM3 { public: /// Construct from a model folder containing: /// anssam3_image_encoder.onnx (+.onnx_data) /// anssam3_language_encoder.onnx (+.onnx_data) /// anssam3_decoder.onnx (+.onnx_data) explicit ONNXSAM3(const std::string& modelFolder, EngineType engineType, unsigned int num_threads = 1); ~ONNXSAM3(); // Non-copyable ONNXSAM3(const ONNXSAM3&) = delete; ONNXSAM3& operator=(const ONNXSAM3&) = delete; /// Set text prompt (runs language encoder, caches results). void setPrompt(const std::vector& inputIds, const std::vector& attentionMask); /// Run inference: image encoder + decoder with cached language features. /// @param mat Input image (BGR). /// @param segThreshold Score threshold for filtering detections. /// @return SAM3Result objects with boxes/masks/polygons. std::vector detect(const cv::Mat& mat, float segThreshold = 0.5f); int getInputSize() const { return m_inputSize; } int getTokenLength() const { return m_tokenLength; } int getMaskH() const { return m_maskH; } int getMaskW() const { return m_maskW; } bool isPromptSet() const { return m_promptSet; } private: /// Bundle holding one ORT session and its I/O names. struct SessionBundle { Ort::Session* session = nullptr; std::vector inputNames_; // owns strings std::vector inputNames; // c_str pointers std::vector outputNames_; // owns strings std::vector outputNames; // c_str pointers ~SessionBundle(); }; /// Create one session bundle (EP attach, external data, GPU→CPU fallback). /// @param forceCPU When true, skip GPU EP and always use CPU. /// @param optLevel Graph optimization level for this session. void createSessionBundle(SessionBundle& bundle, const std::string& onnxPath, const std::string& label, bool forceCPU = false, GraphOptimizationLevel optLevel = GraphOptimizationLevel::ORT_ENABLE_ALL); /// Image preprocessing: BGR → RGB, resize to 1008, HWC→CHW, uint8. void preprocessImage(const cv::Mat& mat, std::vector& buffer); /// Convert decoder outputs (boxes, scores, masks) → SAM3Result objects. std::vector postprocessResults( const float* boxesData, int numBoxes, const float* scoresData, const bool* masksData, int maskH, int maskW, int origWidth, int origHeight, float scoreThreshold); // EP helpers (replicated from BasicOrtHandler) bool TryAppendCUDA(Ort::SessionOptions& opts); bool TryAppendDirectML(Ort::SessionOptions& opts); bool TryAppendOpenVINO(Ort::SessionOptions& opts); // ORT environment (shared across all 3 sessions) Ort::Env* m_env = nullptr; Ort::MemoryInfo* m_memInfo = nullptr; // Three session bundles SessionBundle m_imageEncoder; SessionBundle m_langEncoder; SessionBundle m_decoder; // Engine configuration EngineType m_engineType; unsigned int m_numThreads; std::string m_modelFolder; // Model dimensions int m_inputSize = 1008; // image spatial size (H=W) int m_tokenLength = 32; // text token sequence length int m_maskH = 0; // output mask height (from decoder output) int m_maskW = 0; // output mask width (from decoder output) // Cached language encoder outputs (set by setPrompt) std::vector m_cachedLangMask; // bool tensor data std::vector m_cachedLangMaskShape; std::vector m_cachedLangFeatures; // float32 tensor data std::vector m_cachedLangFeaturesShape; bool m_promptSet = false; }; } #endif