// File: ANSCORE/engines/ONNXEngine/ONNXSAM3.h (142 lines, 5.5 KiB, C++)
// NOTE(review): the original lines here were file-browser chrome ("Files",
// path, size, language label) pasted into the source; they are not valid C++
// and have been wrapped into this comment so the header compiles.
#ifndef ONNXSAM3_H
#define ONNXSAM3_H
#pragma once
#include "onnxruntime_cxx_api.h"
#include "opencv2/opencv.hpp"
#include "EPLoader.h"
#include <string>
#include <vector>
#include <cstdint>
#ifndef ONNXENGINE_API
#ifdef ENGINE_EXPORTS
#define ONNXENGINE_API __declspec(dllexport)
#else
#define ONNXENGINE_API __declspec(dllimport)
#endif
#endif
namespace ANSCENTER
{
/// Result from SAM3 segmentation inference — one instance per detected object.
/// Produced by ONNXSAM3::detect(); coordinates are presumably in the original
/// (pre-resize) image space, since postprocessResults receives
/// origWidth/origHeight — TODO confirm against the .cpp.
struct SAM3Result
{
cv::Rect box; // bounding box (presumably original-image pixels; see note above)
float confidence = 0.0f; // detection confidence score; assumed in [0, 1] — range not enforced here
cv::Mat mask; // binary mask (within bounding box)
std::vector<cv::Point2f> polygon; // simplified contour polygon derived from the mask
};
/// SAM3 engine using 3 separate ONNX Runtime sessions:
/// 1) Image encoder — produces backbone features + position encodings
/// 2) Language encoder — produces text attention mask + text features
/// 3) Decoder — combines image features + language features → boxes, scores, masks
///
/// This architecture avoids the CUDA EP crash that occurs with the
/// monolithic 3.3 GB model, since each sub-model is under 2 GB.
///
/// Intended usage: construct once, call setPrompt() per text query (it caches
/// the language-encoder outputs), then call detect() per frame.
/// NOTE(review): behavior of detect() before any setPrompt() call is defined
/// only in the .cpp — likely empty result or throw; confirm there.
/// NOTE(review): nothing in this header establishes thread-safety; assume
/// single-threaded use unless the implementation documents otherwise.
class ONNXENGINE_API ONNXSAM3
{
public:
/// Construct from a model folder containing:
/// anssam3_image_encoder.onnx (+.onnx_data)
/// anssam3_language_encoder.onnx (+.onnx_data)
/// anssam3_decoder.onnx (+.onnx_data)
/// @param modelFolder Directory holding the three sub-model files above.
/// @param engineType  Execution-provider selection (see EPLoader.h); GPU EPs
///                    fall back to CPU per createSessionBundle's contract.
/// @param num_threads Thread count forwarded to ORT session options — presumably
///                    intra-op threads; TODO confirm in the .cpp.
explicit ONNXSAM3(const std::string& modelFolder,
EngineType engineType,
unsigned int num_threads = 1);
~ONNXSAM3();
// Non-copyable: this class owns raw ORT env/session pointers (see members below),
// so a copy would double-free them.
ONNXSAM3(const ONNXSAM3&) = delete;
ONNXSAM3& operator=(const ONNXSAM3&) = delete;
/// Set text prompt (runs language encoder, caches results into
/// m_cachedLang* members and sets m_promptSet).
/// @param inputIds      Tokenized prompt ids — presumably padded to
///                      m_tokenLength (32); TODO confirm tokenizer contract.
/// @param attentionMask Mask matching inputIds — assumed 1 = real token,
///                      0 = padding; verify against the language model.
void setPrompt(const std::vector<int64_t>& inputIds,
const std::vector<int64_t>& attentionMask);
/// Run inference: image encoder + decoder with cached language features.
/// @param mat Input image (BGR).
/// @param segThreshold Score threshold for filtering detections.
/// @return SAM3Result objects with boxes/masks/polygons.
std::vector<SAM3Result> detect(const cv::Mat& mat,
float segThreshold = 0.5f);
// Trivial accessors for model dimensions and prompt state (see members below).
int getInputSize() const { return m_inputSize; }
int getTokenLength() const { return m_tokenLength; }
int getMaskH() const { return m_maskH; }
int getMaskW() const { return m_maskW; }
bool isPromptSet() const { return m_promptSet; }
private:
/// Bundle holding one ORT session and its I/O names.
/// The parallel string/const char* vectors exist because the ORT C++ API's
/// Run() takes arrays of C strings; the string vectors own the storage the
/// pointer vectors reference, so they must never be resized after setup.
/// NOTE(review): owns a raw Session* with a user-declared dtor but copy
/// operations are not deleted — copying a bundle would double-delete the
/// session. Consider std::unique_ptr<Ort::Session> or deleting copies.
struct SessionBundle
{
Ort::Session* session = nullptr;
std::vector<std::string> inputNames_; // owns strings
std::vector<const char*> inputNames; // c_str pointers into inputNames_
std::vector<std::string> outputNames_; // owns strings
std::vector<const char*> outputNames; // c_str pointers into outputNames_
~SessionBundle(); // defined in .cpp; presumably deletes session — confirm
};
/// Create one session bundle (EP attach, external data, GPU→CPU fallback).
/// @param forceCPU When true, skip GPU EP and always use CPU.
/// @param optLevel Graph optimization level for this session.
void createSessionBundle(SessionBundle& bundle,
const std::string& onnxPath,
const std::string& label,
bool forceCPU = false,
GraphOptimizationLevel optLevel = GraphOptimizationLevel::ORT_ENABLE_ALL);
/// Image preprocessing: BGR → RGB, resize to 1008, HWC→CHW, uint8.
/// @param buffer Output tensor data; caller-owned, resized/filled here.
void preprocessImage(const cv::Mat& mat, std::vector<uint8_t>& buffer);
/// Convert decoder outputs (boxes, scores, masks) → SAM3Result objects.
/// Filters by scoreThreshold and maps results back to origWidth×origHeight.
std::vector<SAM3Result> postprocessResults(
const float* boxesData, int numBoxes,
const float* scoresData,
const bool* masksData, int maskH, int maskW,
int origWidth, int origHeight,
float scoreThreshold);
// EP helpers (replicated from BasicOrtHandler); each returns whether the
// execution provider was successfully appended to the session options.
bool TryAppendCUDA(Ort::SessionOptions& opts);
bool TryAppendDirectML(Ort::SessionOptions& opts);
bool TryAppendOpenVINO(Ort::SessionOptions& opts);
// ORT environment (shared across all 3 sessions).
// NOTE(review): raw owning pointers with lifetime managed in ctor/dtor —
// std::unique_ptr would express ownership; requires a matching .cpp change.
Ort::Env* m_env = nullptr;
Ort::MemoryInfo* m_memInfo = nullptr;
// Three session bundles, one per sub-model (see class comment).
SessionBundle m_imageEncoder;
SessionBundle m_langEncoder;
SessionBundle m_decoder;
// Engine configuration captured at construction time.
EngineType m_engineType;
unsigned int m_numThreads;
std::string m_modelFolder;
// Model dimensions
int m_inputSize = 1008; // image spatial size (H=W), matches preprocessImage's resize
int m_tokenLength = 32; // text token sequence length
int m_maskH = 0; // output mask height (from decoder output; 0 until first run)
int m_maskW = 0; // output mask width (from decoder output; 0 until first run)
// Cached language encoder outputs (set by setPrompt, consumed by detect).
std::vector<char> m_cachedLangMask; // bool tensor data (char used because vector<bool> is a packed proxy)
std::vector<int64_t> m_cachedLangMaskShape;
std::vector<float> m_cachedLangFeatures; // float32 tensor data
std::vector<int64_t> m_cachedLangFeaturesShape;
bool m_promptSet = false; // true once setPrompt has cached language features
};
}
#endif