123 lines
5.0 KiB
C
123 lines
5.0 KiB
C
|
|
#ifndef ANSSAM3_H
|
||
|
|
#define ANSSAM3_H
|
||
|
|
#pragma once
|
||
|
|
#include "ANSEngineCommon.h"
|
||
|
|
#include "ANSCLIPTokenizer.h"
|
||
|
|
#include "engine.h"
|
||
|
|
#include "EPLoader.h"
|
||
|
|
#include "ANSGpuFrameRegistry.h"
|
||
|
|
#include "NV12PreprocessHelper.h"
|
||
|
|
|
||
|
|
#include <unordered_map>
|
||
|
|
#include <unordered_set>
|
||
|
|
|
||
|
|
namespace ANSCENTER
|
||
|
|
{
|
||
|
|
/// SAM3 (Segment Anything Model 3) - text-prompted instance segmentation
|
||
|
|
/// Full TRT architecture:
|
||
|
|
/// 1) Image encoder (TRT FP32) — produces backbone features + position encodings
|
||
|
|
/// 2) Language encoder (TRT FP16) — produces text attention mask + text features
|
||
|
|
/// 3) Decoder (TRT FP16) — combines image + language features → boxes, scores, masks
|
||
|
|
class ANSENGINE_API ANSSAM3 : public ANSODBase
|
||
|
|
{
|
||
|
|
public:
|
||
|
|
virtual bool Initialize(std::string licenseKey, ModelConfig modelConfig,
|
||
|
|
const std::string& modelZipFilePath, const std::string& modelZipPassword,
|
||
|
|
std::string& labelMap) override;
|
||
|
|
virtual bool LoadModel(const std::string& modelZipFilePath, const std::string& modelZipPassword) override;
|
||
|
|
virtual bool LoadModelFromFolder(std::string licenseKey, ModelConfig modelConfig,
|
||
|
|
std::string modelName, std::string className,
|
||
|
|
const std::string& modelFolder, std::string& labelMap) override;
|
||
|
|
virtual bool OptimizeModel(bool fp16, std::string& optimizedModelFolder);
|
||
|
|
|
||
|
|
bool SetPrompt(const std::vector<int64_t>& inputIds, const std::vector<int64_t>& attentionMask) override;
|
||
|
|
bool SetPrompt(const std::string& text) override;
|
||
|
|
|
||
|
|
std::vector<Object> RunInference(const cv::Mat& input);
|
||
|
|
std::vector<Object> RunInference(const cv::Mat& input, const std::string& camera_id);
|
||
|
|
|
||
|
|
bool Destroy();
|
||
|
|
~ANSSAM3();
|
||
|
|
|
||
|
|
/// Precision mode for TRT engine builds
|
||
|
|
enum class TrtPrecision { FP16, BF16, FP32 };
|
||
|
|
|
||
|
|
private:
|
||
|
|
// -----------------------------------------------------------------
|
||
|
|
// TRTBundle — one TensorRT sub-model (engine + context + GPU buffers)
|
||
|
|
// -----------------------------------------------------------------
|
||
|
|
struct TRTBundle
|
||
|
|
{
|
||
|
|
std::unique_ptr<nvinfer1::IRuntime> runtime;
|
||
|
|
std::unique_ptr<nvinfer1::ICudaEngine> engine;
|
||
|
|
std::unique_ptr<nvinfer1::IExecutionContext> context;
|
||
|
|
|
||
|
|
std::vector<void*> gpuBuffers; // I/O buffers (GPU or host), one per tensor
|
||
|
|
std::vector<size_t> gpuBufferSizes; // bytes per buffer
|
||
|
|
std::unordered_set<int> hostBufferIdx; // indices of host-allocated buffers (shape tensors)
|
||
|
|
std::unordered_map<std::string, int> nameToIdx; // tensor name → buffer index
|
||
|
|
|
||
|
|
void destroy();
|
||
|
|
};
|
||
|
|
|
||
|
|
|
||
|
|
std::string _modelFilePath;
|
||
|
|
bool _modelLoadValid = false;
|
||
|
|
bool _fp16 = false;
|
||
|
|
|
||
|
|
// TRT engine bundles
|
||
|
|
TRTBundle m_imgEncoder;
|
||
|
|
TRTBundle m_langEncoder;
|
||
|
|
TRTBundle m_decoder;
|
||
|
|
|
||
|
|
cudaStream_t m_cudaStream = nullptr;
|
||
|
|
|
||
|
|
// NV12 fast-path helper (fused NV12→RGB CHW directly into TRT buffer)
|
||
|
|
NV12PreprocessHelper m_nv12Helper;
|
||
|
|
|
||
|
|
// Cached language encoder outputs (GPU-resident, set by SetPrompt)
|
||
|
|
void* m_cachedLangMask = nullptr; // bool/int32 [1, 32]
|
||
|
|
size_t m_cachedLangMaskBytes = 0;
|
||
|
|
void* m_cachedLangFeats = nullptr; // float32 [32, 1, 256]
|
||
|
|
size_t m_cachedLangFeatsBytes = 0;
|
||
|
|
|
||
|
|
// Model constants
|
||
|
|
int m_inputSize = 1008; // image spatial size (H=W) — 3-model split uses 1008
|
||
|
|
int m_tokenLength = 32;
|
||
|
|
bool m_promptSet = false;
|
||
|
|
|
||
|
|
// Segmentation threshold
|
||
|
|
float m_segThreshold = 0.5f;
|
||
|
|
|
||
|
|
// Tokenizer
|
||
|
|
std::unique_ptr<ANSCLIPTokenizer> m_tokenizer;
|
||
|
|
|
||
|
|
// --- Internal methods ---
|
||
|
|
bool BuildAndLoadEngine(TRTBundle& bundle, const std::string& onnxPath,
|
||
|
|
const std::string& label, TrtPrecision precision = TrtPrecision::FP16);
|
||
|
|
bool LoadTRTEngineBundle(TRTBundle& bundle, const std::string& enginePath, const std::string& label);
|
||
|
|
|
||
|
|
|
||
|
|
/// Pre-build any uncached TRT engines one at a time (avoids GPU OOM during build)
|
||
|
|
bool EnsureEnginesBuilt(const std::string& imgOnnx, const std::string& langOnnx, const std::string& decOnnx);
|
||
|
|
|
||
|
|
/// Generate engine cache filename: <stem>.engine.<GPUName>.<fp16|bf16|fp32>
|
||
|
|
std::string EngineFileName(const std::string& onnxPath, TrtPrecision precision = TrtPrecision::FP16) const;
|
||
|
|
|
||
|
|
std::vector<Object> Detect(const cv::Mat& input, const std::string& camera_id);
|
||
|
|
|
||
|
|
std::vector<Object> PostprocessInstances(
|
||
|
|
const float* boxesData, int numBoxes,
|
||
|
|
const float* scoresData,
|
||
|
|
const bool* masksData,
|
||
|
|
int maskH, int maskW,
|
||
|
|
int origWidth, int origHeight,
|
||
|
|
const std::string& camera_id);
|
||
|
|
|
||
|
|
static size_t DataTypeSize(nvinfer1::DataType dtype);
|
||
|
|
|
||
|
|
Logger m_trtLogger;
|
||
|
|
};
|
||
|
|
}
|
||
|
|
#endif
|