// ANSSAM3.h — 123 lines, 5.0 KiB (last modified 2026-03-28 16:54:11 +11:00)
// NOTE(review): the original pasted text above this header was web-viewer chrome
// ("Files / Raw / Permalink / ..."); converted to a comment so the header compiles.
#ifndef ANSSAM3_H
#define ANSSAM3_H
#pragma once
#include "ANSEngineCommon.h"
#include "ANSCLIPTokenizer.h"
#include "engine.h"
#include "EPLoader.h"
#include "ANSGpuFrameRegistry.h"
#include "NV12PreprocessHelper.h"
#include <unordered_map>
#include <unordered_set>
namespace ANSCENTER
{
/// SAM3 (Segment Anything Model 3) - text-prompted instance segmentation
/// Full TRT architecture:
/// 1) Image encoder (TRT FP32) — produces backbone features + position encodings
/// 2) Language encoder (TRT FP16) — produces text attention mask + text features
/// 3) Decoder (TRT FP16) — combines image + language features → boxes, scores, masks
class ANSENGINE_API ANSSAM3 : public ANSODBase
{
public:
/// Initialize the engine from a password-protected model archive.
/// @param licenseKey       licensing key forwarded to the engine
/// @param modelConfig      runtime model configuration
/// @param modelZipFilePath path to the packaged model zip
/// @param modelZipPassword password for the zip archive
/// @param labelMap         [out] label map read from the model package
/// @return true on success
virtual bool Initialize(std::string licenseKey, ModelConfig modelConfig,
const std::string& modelZipFilePath, const std::string& modelZipPassword,
std::string& labelMap) override;
/// Load (or reload) the model from a zip archive without re-running full initialization.
virtual bool LoadModel(const std::string& modelZipFilePath, const std::string& modelZipPassword) override;
/// Load the model from an already-extracted folder instead of a zip.
/// @param modelFolder directory containing the model files
/// @param labelMap    [out] label map read from the folder
virtual bool LoadModelFromFolder(std::string licenseKey, ModelConfig modelConfig,
std::string modelName, std::string className,
const std::string& modelFolder, std::string& labelMap) override;
/// Build optimized (TRT) engines; fp16 selects half-precision builds.
/// @param optimizedModelFolder [out] folder where optimized engines are written
/// NOTE(review): not marked `override` — confirm whether ANSODBase declares OptimizeModel.
virtual bool OptimizeModel(bool fp16, std::string& optimizedModelFolder);
/// Set the text prompt from pre-tokenized ids + attention mask (lengths presumably
/// m_tokenLength = 32 — TODO confirm). Runs the language encoder and caches its outputs.
bool SetPrompt(const std::vector<int64_t>& inputIds, const std::vector<int64_t>& attentionMask) override;
/// Set the text prompt from raw text; tokenization is done via m_tokenizer.
bool SetPrompt(const std::string& text) override;
/// Run detection/segmentation on a BGR image; requires a prompt to have been set.
std::vector<Object> RunInference(const cv::Mat& input);
/// Overload carrying a camera id, propagated into the returned Objects.
std::vector<Object> RunInference(const cv::Mat& input, const std::string& camera_id);
/// Release all engines, GPU buffers and the CUDA stream.
/// NOTE(review): non-virtual — confirm callers never invoke Destroy() through the base.
bool Destroy();
/// NOTE(review): ensure ANSODBase has a virtual destructor (instances may be
/// deleted through a base pointer).
~ANSSAM3();
/// Precision mode for TRT engine builds
enum class TrtPrecision { FP16, BF16, FP32 };
private:
// -----------------------------------------------------------------
// TRTBundle — one TensorRT sub-model (engine + context + GPU buffers)
// -----------------------------------------------------------------
struct TRTBundle
{
std::unique_ptr<nvinfer1::IRuntime> runtime;
std::unique_ptr<nvinfer1::ICudaEngine> engine;
std::unique_ptr<nvinfer1::IExecutionContext> context;
std::vector<void*> gpuBuffers; // I/O buffers (GPU or host), one per tensor
std::vector<size_t> gpuBufferSizes; // bytes per buffer
std::unordered_set<int> hostBufferIdx; // indices of host-allocated buffers (shape tensors)
std::unordered_map<std::string, int> nameToIdx; // tensor name → buffer index
// Frees all buffers and resets the engine/context/runtime (manual cleanup;
// raw void* buffers are not RAII-owned).
void destroy();
};
std::string _modelFilePath; // path the model was loaded from
bool _modelLoadValid = false; // true once a model has loaded successfully
bool _fp16 = false; // requested FP16 optimization mode
// TRT engine bundles
TRTBundle m_imgEncoder; // image encoder sub-model
TRTBundle m_langEncoder; // language (text) encoder sub-model
TRTBundle m_decoder; // box/score/mask decoder sub-model
cudaStream_t m_cudaStream = nullptr; // shared stream for all three engines
// NV12 fast-path helper (fused NV12→RGB CHW directly into TRT buffer)
NV12PreprocessHelper m_nv12Helper;
// Cached language encoder outputs (GPU-resident, set by SetPrompt)
void* m_cachedLangMask = nullptr; // bool/int32 [1, 32]
size_t m_cachedLangMaskBytes = 0;
void* m_cachedLangFeats = nullptr; // float32 [32, 1, 256]
size_t m_cachedLangFeatsBytes = 0;
// Model constants
int m_inputSize = 1008; // image spatial size (H=W) — 3-model split uses 1008
int m_tokenLength = 32; // fixed text token count fed to the language encoder
bool m_promptSet = false; // guards RunInference until SetPrompt succeeds
// Segmentation threshold
float m_segThreshold = 0.5f; // mask binarization / instance acceptance cutoff — TODO confirm which
// Tokenizer
std::unique_ptr<ANSCLIPTokenizer> m_tokenizer; // CLIP tokenizer used by SetPrompt(text)
// --- Internal methods ---
/// Build a TRT engine from ONNX (or load it) into `bundle`; `label` is used for logging — TODO confirm.
bool BuildAndLoadEngine(TRTBundle& bundle, const std::string& onnxPath,
const std::string& label, TrtPrecision precision = TrtPrecision::FP16);
/// Deserialize a prebuilt .engine file into `bundle` and allocate its I/O buffers.
bool LoadTRTEngineBundle(TRTBundle& bundle, const std::string& enginePath, const std::string& label);
/// Pre-build any uncached TRT engines one at a time (avoids GPU OOM during build)
bool EnsureEnginesBuilt(const std::string& imgOnnx, const std::string& langOnnx, const std::string& decOnnx);
/// Generate engine cache filename: <stem>.engine.<GPUName>.<fp16|bf16|fp32>
std::string EngineFileName(const std::string& onnxPath, TrtPrecision precision = TrtPrecision::FP16) const;
/// Full pipeline for one frame: preprocess → image encoder → decoder → postprocess.
std::vector<Object> Detect(const cv::Mat& input, const std::string& camera_id);
/// Convert raw decoder outputs (boxes, scores, binary masks at maskH×maskW) into
/// Objects scaled back to the original image size (origWidth × origHeight).
std::vector<Object> PostprocessInstances(
const float* boxesData, int numBoxes,
const float* scoresData,
const bool* masksData,
int maskH, int maskW,
int origWidth, int origHeight,
const std::string& camera_id);
/// Byte size of one element of a TensorRT data type (used for buffer allocation).
static size_t DataTypeSize(nvinfer1::DataType dtype);
Logger m_trtLogger; // TensorRT logger sink
};
}
#endif