// ANSCORE/modules/ANSODEngine/ANSONNXSAM3.cpp
// ANSONNXSAM3: segmentation detector backed by a three-session ONNX engine
// (image encoder, language encoder, decoder).

#include "ANSONNXSAM3.h"
#include "ANSCLIPTokenizer.h"
#include "ONNXSAM3.h"
#include "Utility.h"
namespace ANSCENTER
{
// =========================================================================
// OptimizeModel — ONNX doesn't need separate optimization
// =========================================================================
/// @brief Validate the model folder; no conversion step is needed for ONNX.
/// @param fp16                 Requested precision flag; stored in _fp16 on success.
/// @param optimizedModelFolder Out: set to _modelFolder on success.
/// @return false if the base-class step fails or any required .onnx file is missing.
bool ANSONNXSAM3::OptimizeModel(bool fp16, std::string& optimizedModelFolder)
{
    std::lock_guard<std::recursive_mutex> lock(_mutex);
    if (!ANSODBase::OptimizeModel(fp16, optimizedModelFolder)) {
        return false;
    }
    // There is no separate optimization pass for ONNX models. Verify that every
    // sub-model InitEngine() requires is present, so that a missing file is
    // reported here rather than only when the engine is later created.
    static const char* const kRequiredModels[] = {
        "sam3_image_encoder.onnx",
        "sam3_language_encoder.onnx",
        "sam3_decoder.onnx",
    };
    for (const char* fileName : kRequiredModels) {
        const std::string path = CreateFilePath(_modelFolder, fileName);
        if (!FileExist(path)) {
            _logger.LogFatal("ANSONNXSAM3::OptimizeModel",
                "Model files not found in: " + _modelFolder + " (missing " + fileName + ")",
                __FILE__, __LINE__);
            return false;
        }
    }
    _fp16 = fp16;
    optimizedModelFolder = _modelFolder;
    return true;
}
// =========================================================================
// Initialize
// =========================================================================
/// @brief Initialize from a (possibly encrypted) model zip and build the ONNX engine.
/// @param licenseKey       Forwarded to ANSODBase::Initialize.
/// @param modelConfig      Forwarded to ANSODBase::Initialize.
/// @param modelZipFilePath Forwarded to ANSODBase::Initialize.
/// @param modelZipPassword Forwarded to ANSODBase::Initialize.
/// @param labelMap         Forwarded to ANSODBase::Initialize (in/out).
/// @return true once the engine is created; the CLIP tokenizer is optional and
///         a failure to load it does not fail initialization.
bool ANSONNXSAM3::Initialize(std::string licenseKey, ModelConfig modelConfig,
    const std::string& modelZipFilePath, const std::string& modelZipPassword,
    std::string& labelMap)
{
    std::lock_guard<std::recursive_mutex> lock(_mutex);
    ModelLoadingGuard mlg(_modelLoading);
    try {
        bool result = ANSODBase::Initialize(licenseKey, modelConfig,
            modelZipFilePath, modelZipPassword, labelMap);
        if (!result) return false;
        // This model emits masks + polygons, so always report segmentation output.
        _modelConfig.detectionType = DetectionType::SEGMENTATION;
        // Confidence thresholds below 0.1 are replaced with a default of 0.5.
        if (_modelConfig.modelConfThreshold < 0.1f)
            _modelConfig.modelConfThreshold = 0.5f;
        m_segThreshold = _modelConfig.modelConfThreshold;
        // Create ONNX engine from 3-model folder
        if (!InitEngine(_modelFolder)) {
            _logger.LogError("ANSONNXSAM3::Initialize",
                "Failed to init ONNX engine from folder: " + _modelFolder, __FILE__, __LINE__);
            _modelLoadValid = false;
            return false;
        }
        _modelLoadValid = true;
        _isInitialized = true;
        // Load the optional CLIP tokenizer (needed only for SetPrompt(text)).
        // The engine is already valid at this point, so a tokenizer failure is
        // logged but must not turn this call into a failure: that would return
        // false while the flags above say the model is loaded.
        m_tokenizer = std::make_unique<ANSCLIPTokenizer>();
        const std::string tokenizerPath = CreateFilePath(_modelFolder, "merges.txt");
        if (FileExist(tokenizerPath)) {
            try {
                m_tokenizer->Load(tokenizerPath);
                if (m_tokenizer->IsLoaded()) {
                    _logger.LogDebug("ANSONNXSAM3::Initialize",
                        "CLIP tokenizer loaded from: " + tokenizerPath, __FILE__, __LINE__);
                }
                else {
                    _logger.LogError("ANSONNXSAM3::Initialize",
                        "CLIP tokenizer failed to load from: " + tokenizerPath, __FILE__, __LINE__);
                }
            }
            catch (const std::exception& e) {
                _logger.LogError("ANSONNXSAM3::Initialize",
                    "CLIP tokenizer failed to load from: " + tokenizerPath + " (" + e.what() + ")",
                    __FILE__, __LINE__);
            }
        }
        return true;
    }
    catch (const std::exception& e) {
        _logger.LogFatal("ANSONNXSAM3::Initialize", e.what(), __FILE__, __LINE__);
        return false;
    }
}
// =========================================================================
// LoadModel
// =========================================================================
/// @brief Load a model zip and (re)create the ONNX engine from the extracted folder.
/// @param modelZipFilePath Forwarded to ANSODBase::LoadModel.
/// @param modelZipPassword Forwarded to ANSODBase::LoadModel.
/// @return true once the engine is created; the CLIP tokenizer is optional and
///         a failure to load it does not fail the load.
bool ANSONNXSAM3::LoadModel(const std::string& modelZipFilePath,
    const std::string& modelZipPassword)
{
    std::lock_guard<std::recursive_mutex> lock(_mutex);
    ModelLoadingGuard mlg(_modelLoading);
    try {
        bool result = ANSODBase::LoadModel(modelZipFilePath, modelZipPassword);
        if (!result) return false;
        // This model emits masks + polygons, so always report segmentation output.
        _modelConfig.detectionType = DetectionType::SEGMENTATION;
        // Confidence thresholds below 0.1 are replaced with a default of 0.5.
        if (_modelConfig.modelConfThreshold < 0.1f)
            _modelConfig.modelConfThreshold = 0.5f;
        m_segThreshold = _modelConfig.modelConfThreshold;
        // Create ONNX engine from 3-model folder
        if (!InitEngine(_modelFolder)) {
            _logger.LogError("ANSONNXSAM3::LoadModel",
                "Failed to init ONNX engine from folder: " + _modelFolder, __FILE__, __LINE__);
            _modelLoadValid = false;
            return false;
        }
        _modelLoadValid = true;
        _isInitialized = true;
        // Load the optional CLIP tokenizer (needed only for SetPrompt(text)).
        // The engine is already valid here, so a tokenizer failure is logged
        // but must not make this call return false with the flags set to loaded.
        m_tokenizer = std::make_unique<ANSCLIPTokenizer>();
        const std::string tokenizerPath = CreateFilePath(_modelFolder, "merges.txt");
        if (FileExist(tokenizerPath)) {
            try {
                m_tokenizer->Load(tokenizerPath);
                if (m_tokenizer->IsLoaded()) {
                    _logger.LogDebug("ANSONNXSAM3::LoadModel",
                        "CLIP tokenizer loaded from: " + tokenizerPath, __FILE__, __LINE__);
                }
                else {
                    _logger.LogError("ANSONNXSAM3::LoadModel",
                        "CLIP tokenizer failed to load from: " + tokenizerPath, __FILE__, __LINE__);
                }
            }
            catch (const std::exception& e) {
                _logger.LogError("ANSONNXSAM3::LoadModel",
                    "CLIP tokenizer failed to load from: " + tokenizerPath + " (" + e.what() + ")",
                    __FILE__, __LINE__);
            }
        }
        return true;
    }
    catch (const std::exception& e) {
        _logger.LogFatal("ANSONNXSAM3::LoadModel", e.what(), __FILE__, __LINE__);
        return false;
    }
}
// =========================================================================
// LoadModelFromFolder
// =========================================================================
/// @brief Load an already-extracted model folder and create the ONNX engine from it.
/// @param licenseKey  Forwarded to ANSODBase::LoadModelFromFolder.
/// @param modelConfig Forwarded to the base, then copied into _modelConfig.
/// @param modelName   Forwarded to ANSODBase::LoadModelFromFolder.
/// @param className   Forwarded to ANSODBase::LoadModelFromFolder.
/// @param modelFolder Folder holding the three sam3_*.onnx files and, optionally, merges.txt.
/// @param labelMap    Forwarded to ANSODBase::LoadModelFromFolder (in/out).
/// @return true once the engine is created; the CLIP tokenizer is optional.
bool ANSONNXSAM3::LoadModelFromFolder(std::string licenseKey, ModelConfig modelConfig,
    std::string modelName, std::string className,
    const std::string& modelFolder, std::string& labelMap)
{
    std::lock_guard<std::recursive_mutex> lock(_mutex);
    ModelLoadingGuard mlg(_modelLoading);
    try {
        bool result = ANSODBase::LoadModelFromFolder(licenseKey, modelConfig,
            modelName, className, modelFolder, labelMap);
        if (!result) return false;
        _modelConfig = modelConfig;
        // This model emits masks + polygons, so always report segmentation output.
        _modelConfig.detectionType = DetectionType::SEGMENTATION;
        // Confidence thresholds below 0.1 are replaced with a default of 0.5.
        if (_modelConfig.modelConfThreshold < 0.1f)
            _modelConfig.modelConfThreshold = 0.5f;
        m_segThreshold = _modelConfig.modelConfThreshold;
        _modelFilePath = modelFolder;
        // Create ONNX engine from 3-model folder
        if (!InitEngine(modelFolder)) {
            _logger.LogError("ANSONNXSAM3::LoadModelFromFolder",
                "Failed to init ONNX engine from folder: " + modelFolder, __FILE__, __LINE__);
            _modelLoadValid = false;
            return false;
        }
        _modelLoadValid = true;
        _isInitialized = true;
        // Load the optional CLIP tokenizer from the SAME folder the engine was
        // built from (the caller's modelFolder, not _modelFolder, which this
        // method does not set). A tokenizer failure is logged but does not
        // fail the load, because the engine is already valid.
        m_tokenizer = std::make_unique<ANSCLIPTokenizer>();
        const std::string tokenizerPath = CreateFilePath(modelFolder, "merges.txt");
        if (FileExist(tokenizerPath)) {
            try {
                m_tokenizer->Load(tokenizerPath);
                if (m_tokenizer->IsLoaded()) {
                    _logger.LogDebug("ANSONNXSAM3::LoadModelFromFolder",
                        "CLIP tokenizer loaded from: " + tokenizerPath, __FILE__, __LINE__);
                }
                else {
                    _logger.LogError("ANSONNXSAM3::LoadModelFromFolder",
                        "CLIP tokenizer failed to load from: " + tokenizerPath, __FILE__, __LINE__);
                }
            }
            catch (const std::exception& e) {
                _logger.LogError("ANSONNXSAM3::LoadModelFromFolder",
                    "CLIP tokenizer failed to load from: " + tokenizerPath + " (" + e.what() + ")",
                    __FILE__, __LINE__);
            }
        }
        return true;
    }
    catch (const std::exception& e) {
        _logger.LogFatal("ANSONNXSAM3::LoadModelFromFolder", e.what(), __FILE__, __LINE__);
        return false;
    }
}
// =========================================================================
// InitEngine — create ONNXSAM3 engine
// =========================================================================
bool ANSONNXSAM3::InitEngine(const std::string& modelFolder)
{
// Verify required model files exist
std::string imgPath = CreateFilePath(modelFolder, "sam3_image_encoder.onnx");
std::string langPath = CreateFilePath(modelFolder, "sam3_language_encoder.onnx");
std::string decPath = CreateFilePath(modelFolder, "sam3_decoder.onnx");
if (!FileExist(imgPath)) {
_logger.LogError("ANSONNXSAM3::InitEngine",
"Image encoder not found: " + imgPath, __FILE__, __LINE__);
return false;
}
if (!FileExist(langPath)) {
_logger.LogError("ANSONNXSAM3::InitEngine",
"Language encoder not found: " + langPath, __FILE__, __LINE__);
return false;
}
if (!FileExist(decPath)) {
_logger.LogError("ANSONNXSAM3::InitEngine",
"Decoder not found: " + decPath, __FILE__, __LINE__);
return false;
}
try {
m_engine = std::make_unique<ONNXSAM3>(modelFolder, EngineType::NVIDIA_GPU);
m_tokenLength = m_engine->getTokenLength();
_logger.LogDebug("ANSONNXSAM3::InitEngine",
"3-session ONNX engine created. inputSize=" + std::to_string(m_engine->getInputSize()) +
" tokenLength=" + std::to_string(m_tokenLength),
__FILE__, __LINE__);
return true;
}
catch (const Ort::Exception& e) {
_logger.LogError("ANSONNXSAM3::InitEngine",
std::string("ORT Exception: ") + e.what(), __FILE__, __LINE__);
return false;
}
catch (const std::exception& e) {
_logger.LogError("ANSONNXSAM3::InitEngine",
std::string("Failed to create ONNX engine: ") + e.what(), __FILE__, __LINE__);
return false;
}
}
// =========================================================================
// SetPrompt (pre-tokenized)
// =========================================================================
/// @brief Set the text prompt from pre-tokenized ids.
/// @param inputIds      Token ids; must contain exactly m_tokenLength entries.
/// @param attentionMask Attention mask; must contain exactly m_tokenLength entries.
/// @return true if the engine accepted the prompt; false if the engine is missing,
///         the sizes are wrong, or the engine threw.
bool ANSONNXSAM3::SetPrompt(const std::vector<int64_t>& inputIds,
    const std::vector<int64_t>& attentionMask)
{
    std::lock_guard<std::recursive_mutex> lock(_mutex);
    // Check the engine first: m_tokenLength comes from the engine
    // (see InitEngine), so a size error reported without one would be misleading.
    if (!m_engine) {
        _logger.LogError("ANSONNXSAM3::SetPrompt", "Engine not initialized", __FILE__, __LINE__);
        return false;
    }
    if (static_cast<int>(inputIds.size()) != m_tokenLength ||
        static_cast<int>(attentionMask.size()) != m_tokenLength) {
        _logger.LogError("ANSONNXSAM3::SetPrompt",
            "Token vectors must have exactly " + std::to_string(m_tokenLength) + " elements",
            __FILE__, __LINE__);
        return false;
    }
    try {
        m_engine->setPrompt(inputIds, attentionMask);
    }
    catch (const std::exception& e) {
        // Keep the same "log and return false" contract as the other entry points.
        _logger.LogError("ANSONNXSAM3::SetPrompt", e.what(), __FILE__, __LINE__);
        return false;
    }
    m_promptSet = true;
    return true;
}
// =========================================================================
// SetPrompt (text string — uses CLIP tokenizer)
// =========================================================================
/// @brief Set the prompt from free text using the loaded CLIP tokenizer.
/// @param text Prompt text; tokenized to m_tokenLength tokens.
/// @return The result of SetPrompt(inputIds, attentionMask), or false if the
///         tokenizer is not loaded or tokenization throws.
bool ANSONNXSAM3::SetPrompt(const std::string& text)
{
    std::lock_guard<std::recursive_mutex> lock(_mutex);
    if (!m_tokenizer || !m_tokenizer->IsLoaded()) {
        _logger.LogError("ANSONNXSAM3::SetPrompt",
            "Tokenizer not loaded. Place merges.txt in model folder, "
            "or use SetPrompt(inputIds, attentionMask) directly.",
            __FILE__, __LINE__);
        return false;
    }
    try {
        const auto tokens = m_tokenizer->Tokenize(text, m_tokenLength);
        // Propagate the token overload's result: it rejects a missing engine or
        // wrongly sized vectors, and callers must see that failure.
        return SetPrompt(tokens.inputIds, tokens.attentionMask);
    }
    catch (const std::exception& e) {
        _logger.LogError("ANSONNXSAM3::SetPrompt", e.what(), __FILE__, __LINE__);
        return false;
    }
}
// =========================================================================
// RunInference
// =========================================================================
/// @brief Run segmentation on @p input without a camera context.
/// Equivalent to RunInference(input, "") — detections get an empty cameraId.
std::vector<Object> ANSONNXSAM3::RunInference(const cv::Mat& input)
{
    const std::string noCameraId;
    return RunInference(input, noCameraId);
}
/// @brief Run segmentation on one frame and convert results to ANSODBase Objects.
/// @param input     Frame passed to ONNXSAM3::detect with m_segThreshold.
/// @param camera_id Copied into each Object and used for tracking/stabilization.
/// @return One Object per engine result (classId 0, className "object") carrying
///         box, confidence, mask and a normalized polygon; empty on any failure.
std::vector<Object> ANSONNXSAM3::RunInference(const cv::Mat& input, const std::string& camera_id)
{
    if (!PreInferenceCheck("ANSONNXSAM3::RunInference")) return {};
    // NOTE(review): m_engine is used here without taking _mutex, while Destroy()
    // and the load paths replace it under the lock. Presumably PreInferenceCheck()
    // and the _modelLoading guard keep these apart — confirm.
    if (!m_engine) {
        // Can be null if the first InitEngine() failed; never dereference it.
        _logger.LogError("ANSONNXSAM3::RunInference", "Engine not initialized", __FILE__, __LINE__);
        return {};
    }
    if (input.empty()) {
        // Nothing to segment, and zero image dimensions would be used below
        // to normalize polygon coordinates.
        return {};
    }
    try {
        // Run ONNX engine inference
        auto sam3Results = m_engine->detect(input, m_segThreshold);
        // Convert SAM3Result -> ANSODBase Object
        std::vector<Object> results;
        results.reserve(sam3Results.size());
        const float imgW = static_cast<float>(input.cols);
        const float imgH = static_cast<float>(input.rows);
        for (auto& sr : sam3Results) {
            Object obj;
            obj.box = sr.box;
            obj.classId = 0;
            obj.className = "object";
            obj.cameraId = camera_id;
            obj.confidence = sr.confidence;
            obj.mask = std::move(sr.mask);
            // Create normalized polygon from mask (closed, maxPoints-limited)
            obj.polygon = ANSUtilityHelper::MaskToNormalizedPolygon(
                obj.mask, obj.box, imgW, imgH);
            // Fallback: normalized box corners if mask polygon failed
            if (obj.polygon.empty()) {
                obj.polygon = ANSUtilityHelper::RectToNormalizedPolygon(obj.box, imgW, imgH);
            }
            results.push_back(std::move(obj));
        }
        if (_trackerEnabled) {
            results = ApplyTracking(results, camera_id);
            if (_stabilizationEnabled) results = StabilizeDetections(results, camera_id);
        }
        return results;
    }
    catch (const std::exception& e) {
        _logger.LogFatal("ANSONNXSAM3::RunInference", e.what(), __FILE__, __LINE__);
        return {};
    }
}
// =========================================================================
// Destroy
// =========================================================================
bool ANSONNXSAM3::Destroy()
{
std::lock_guard<std::recursive_mutex> lock(_mutex);
try {
m_engine.reset();
m_tokenizer.reset();
m_promptSet = false;
_modelLoadValid = false;
_isInitialized = false;
return true;
}
catch (const std::exception& e) {
_logger.LogFatal("ANSONNXSAM3::Destroy", e.what(), __FILE__, __LINE__);
return false;
}
}
/// @brief Destructor: releases all resources via Destroy().
ANSONNXSAM3::~ANSONNXSAM3()
{
    // Inside a destructor, virtual dispatch already resolves to this class, so
    // the qualified call is the same as a plain Destroy(); it just says so.
    // The bool result is intentionally ignored — there is no caller to report to.
    static_cast<void>(ANSONNXSAM3::Destroy());
}
}