// ONNXSAM3.h — SAM3 (Segment Anything 3) multi-session ONNX Runtime engine.
#ifndef ONNXSAM3_H
#define ONNXSAM3_H

#pragma once

#include "onnxruntime_cxx_api.h"
#include "opencv2/opencv.hpp"
#include "EPLoader.h"

#include <string>
#include <vector>
#include <cstdint>

#ifndef ONNXENGINE_API
#ifdef ENGINE_EXPORTS
#define ONNXENGINE_API __declspec(dllexport)
#else
#define ONNXENGINE_API __declspec(dllimport)
#endif
#endif
namespace ANSCENTER
{
/// Result from SAM3 segmentation inference.
///
/// One instance describes a single detected object: its bounding box,
/// the model's confidence score, a binary mask cropped to the box,
/// and a simplified polygon outline derived from that mask.
/// NOTE(review): box coordinates appear to be in original-image space
/// (postprocessResults receives origWidth/origHeight) — confirm with callers.
struct SAM3Result
{
    cv::Rect box;                     // bounding box
    float confidence = 0.0f;          // detection confidence score
    cv::Mat mask;                     // binary mask (within bounding box)
    std::vector<cv::Point2f> polygon; // simplified contour polygon
};
/// SAM3 engine using 3 separate ONNX Runtime sessions:
|
|
/// 1) Image encoder — produces backbone features + position encodings
|
|
/// 2) Language encoder — produces text attention mask + text features
|
|
/// 3) Decoder — combines image features + language features → boxes, scores, masks
|
|
///
|
|
/// This architecture avoids the CUDA EP crash that occurs with the
|
|
/// monolithic 3.3 GB model, since each sub-model is under 2 GB.
|
|
class ONNXENGINE_API ONNXSAM3
|
|
{
|
|
public:
|
|
/// Construct from a model folder containing:
|
|
/// anssam3_image_encoder.onnx (+.onnx_data)
|
|
/// anssam3_language_encoder.onnx (+.onnx_data)
|
|
/// anssam3_decoder.onnx (+.onnx_data)
|
|
explicit ONNXSAM3(const std::string& modelFolder,
|
|
EngineType engineType,
|
|
unsigned int num_threads = 1);
|
|
|
|
~ONNXSAM3();
|
|
|
|
// Non-copyable
|
|
ONNXSAM3(const ONNXSAM3&) = delete;
|
|
ONNXSAM3& operator=(const ONNXSAM3&) = delete;
|
|
|
|
/// Set text prompt (runs language encoder, caches results).
|
|
void setPrompt(const std::vector<int64_t>& inputIds,
|
|
const std::vector<int64_t>& attentionMask);
|
|
|
|
/// Run inference: image encoder + decoder with cached language features.
|
|
/// @param mat Input image (BGR).
|
|
/// @param segThreshold Score threshold for filtering detections.
|
|
/// @return SAM3Result objects with boxes/masks/polygons.
|
|
std::vector<SAM3Result> detect(const cv::Mat& mat,
|
|
float segThreshold = 0.5f);
|
|
|
|
int getInputSize() const { return m_inputSize; }
|
|
int getTokenLength() const { return m_tokenLength; }
|
|
int getMaskH() const { return m_maskH; }
|
|
int getMaskW() const { return m_maskW; }
|
|
bool isPromptSet() const { return m_promptSet; }
|
|
|
|
private:
|
|
/// Bundle holding one ORT session and its I/O names.
|
|
struct SessionBundle
|
|
{
|
|
Ort::Session* session = nullptr;
|
|
|
|
std::vector<std::string> inputNames_; // owns strings
|
|
std::vector<const char*> inputNames; // c_str pointers
|
|
|
|
std::vector<std::string> outputNames_; // owns strings
|
|
std::vector<const char*> outputNames; // c_str pointers
|
|
|
|
~SessionBundle();
|
|
};
|
|
|
|
/// Create one session bundle (EP attach, external data, GPU→CPU fallback).
|
|
/// @param forceCPU When true, skip GPU EP and always use CPU.
|
|
/// @param optLevel Graph optimization level for this session.
|
|
void createSessionBundle(SessionBundle& bundle,
|
|
const std::string& onnxPath,
|
|
const std::string& label,
|
|
bool forceCPU = false,
|
|
GraphOptimizationLevel optLevel = GraphOptimizationLevel::ORT_ENABLE_ALL);
|
|
|
|
/// Image preprocessing: BGR → RGB, resize to 1008, HWC→CHW, uint8.
|
|
void preprocessImage(const cv::Mat& mat, std::vector<uint8_t>& buffer);
|
|
|
|
/// Convert decoder outputs (boxes, scores, masks) → SAM3Result objects.
|
|
std::vector<SAM3Result> postprocessResults(
|
|
const float* boxesData, int numBoxes,
|
|
const float* scoresData,
|
|
const bool* masksData, int maskH, int maskW,
|
|
int origWidth, int origHeight,
|
|
float scoreThreshold);
|
|
|
|
// EP helpers (replicated from BasicOrtHandler)
|
|
bool TryAppendCUDA(Ort::SessionOptions& opts);
|
|
bool TryAppendDirectML(Ort::SessionOptions& opts);
|
|
bool TryAppendOpenVINO(Ort::SessionOptions& opts);
|
|
// ORT environment (shared across all 3 sessions)
|
|
Ort::Env* m_env = nullptr;
|
|
Ort::MemoryInfo* m_memInfo = nullptr;
|
|
|
|
// Three session bundles
|
|
SessionBundle m_imageEncoder;
|
|
SessionBundle m_langEncoder;
|
|
SessionBundle m_decoder;
|
|
|
|
// Engine configuration
|
|
EngineType m_engineType;
|
|
unsigned int m_numThreads;
|
|
std::string m_modelFolder;
|
|
|
|
// Model dimensions
|
|
int m_inputSize = 1008; // image spatial size (H=W)
|
|
int m_tokenLength = 32; // text token sequence length
|
|
int m_maskH = 0; // output mask height (from decoder output)
|
|
int m_maskW = 0; // output mask width (from decoder output)
|
|
|
|
// Cached language encoder outputs (set by setPrompt)
|
|
std::vector<char> m_cachedLangMask; // bool tensor data
|
|
std::vector<int64_t> m_cachedLangMaskShape;
|
|
std::vector<float> m_cachedLangFeatures; // float32 tensor data
|
|
std::vector<int64_t> m_cachedLangFeaturesShape;
|
|
bool m_promptSet = false;
|
|
};
} // namespace ANSCENTER

#endif // ONNXSAM3_H