67 lines
2.5 KiB
C++
67 lines
2.5 KiB
C++
#ifndef ANSONNXSAM3_H
|
|
#define ANSONNXSAM3_H
|
|
#pragma once
|
|
#include "ANSEngineCommon.h"
|
|
#include "ANSCLIPTokenizer.h"
|
|
#include "ONNXSAM3.h"
|
|
|
|
namespace ANSCENTER
|
|
{
|
|
|
|
/// ANSONNXSAM3 — text-prompted segmentation using ONNX Runtime.
|
|
///
|
|
/// Public API (ANSODBase) wrapper around ONNXSAM3 engine.
|
|
/// Supports all execution providers via BasicOrtHandler
|
|
/// (CUDA, DirectML, OpenVINO, CPU).
|
|
class ANSENGINE_API ANSONNXSAM3 : public ANSODBase
|
|
{
|
|
public:
|
|
virtual bool Initialize(std::string licenseKey, ModelConfig modelConfig,
|
|
const std::string& modelZipFilePath, const std::string& modelZipPassword,
|
|
std::string& labelMap) override;
|
|
virtual bool LoadModel(const std::string& modelZipFilePath, const std::string& modelZipPassword) override;
|
|
virtual bool LoadModelFromFolder(std::string licenseKey, ModelConfig modelConfig,
|
|
std::string modelName, std::string className,
|
|
const std::string& modelFolder, std::string& labelMap) override;
|
|
virtual bool OptimizeModel(bool fp16, std::string& optimizedModelFolder);
|
|
|
|
/// Set the text prompt for segmentation (pre-tokenized).
|
|
bool SetPrompt(const std::vector<int64_t>& inputIds,
|
|
const std::vector<int64_t>& attentionMask) override;
|
|
|
|
/// Set the text prompt by tokenizing the given text.
|
|
/// Requires merges.txt (CLIP BPE vocabulary) in the model folder.
|
|
bool SetPrompt(const std::string& text) override;
|
|
|
|
std::vector<Object> RunInference(const cv::Mat& input);
|
|
std::vector<Object> RunInference(const cv::Mat& input, const std::string& camera_id);
|
|
|
|
bool Destroy();
|
|
~ANSONNXSAM3();
|
|
|
|
private:
|
|
std::string _modelFilePath;
|
|
bool _modelLoadValid = false;
|
|
bool _fp16 = false;
|
|
|
|
// ONNX Runtime engine
|
|
std::unique_ptr<ONNXSAM3> m_engine;
|
|
|
|
// Tokenizer (loaded from merges.txt)
|
|
std::unique_ptr<ANSCLIPTokenizer> m_tokenizer;
|
|
|
|
// Segmentation threshold (applied after sigmoid)
|
|
float m_segThreshold = 0.5f;
|
|
|
|
// Token length (read from engine after load)
|
|
int m_tokenLength = 32;
|
|
|
|
// Prompt state
|
|
bool m_promptSet = false;
|
|
|
|
// Internal: create engine from model folder (3-session architecture)
|
|
bool InitEngine(const std::string& modelFolder);
|
|
};
|
|
}
|
|
#endif
|