Initial setup for CLion
This commit is contained in:
581
engines/ONNXEngine/EPLoader.cpp
Normal file
581
engines/ONNXEngine/EPLoader.cpp
Normal file
@@ -0,0 +1,581 @@
|
||||
// EPLoader.cpp
|
||||
// Dynamic ONNX Runtime EP loader.
|
||||
// Loads onnxruntime.dll at runtime — no onnxruntime.lib linkage required.
|
||||
//
|
||||
// Compile this file in EXACTLY ONE project (ANSCore.dll / ANSCore.lib).
|
||||
// That project MUST define ANSCORE_EXPORTS in its Preprocessor Definitions.
|
||||
// All other projects that consume EPLoader include only EPLoader.h and link
|
||||
// against ANSCore — they must NOT add EPLoader.cpp to their source list.
|
||||
//
|
||||
// Windows: LoadLibraryExW + AddDllDirectory + GetProcAddress
|
||||
// Linux: dlopen (RTLD_NOW | RTLD_GLOBAL) + dlsym
|
||||
|
||||
#include "EPLoader.h"
|
||||
|
||||
// ORT C++ headers — included ONLY in this translation unit, never in EPLoader.h.
|
||||
// ORT_API_MANUAL_INIT must be defined project-wide (Preprocessor Definitions)
|
||||
// in every project that includes ORT headers, so all translation units see
|
||||
// Ort::Global<void>::api_ as an extern rather than a default-constructed object.
|
||||
#include <onnxruntime_cxx_api.h>

#include <algorithm>
#include <cstdio>    // std::sscanf (version-string parsing)
#include <cstdlib>   // getenv / setenv
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#ifdef _WIN32
# ifndef WIN32_LEAN_AND_MEAN
#  define WIN32_LEAN_AND_MEAN
# endif
# include <windows.h>
#else
# include <dlfcn.h>
# include <sys/stat.h>
#endif
|
||||
|
||||
|
||||
namespace ANSCENTER {
|
||||
|
||||
// ── Static member definitions ────────────────────────────────────────────
|
||||
#ifdef ANSCORE_EXPORTS
|
||||
std::mutex EPLoader::s_mutex;
|
||||
bool EPLoader::s_initialized = false;
|
||||
EPInfo EPLoader::s_info;
|
||||
# ifdef _WIN32
|
||||
std::string EPLoader::s_temp_ort_path;
|
||||
std::string EPLoader::s_temp_dir; // ← NEW: tracks our temp staging dir
|
||||
# endif
|
||||
#endif
|
||||
|
||||
// ── File-scope state ─────────────────────────────────────────────────────
|
||||
#ifdef _WIN32
|
||||
static HMODULE s_ort_module = nullptr;
|
||||
#else
|
||||
static void* s_ort_module = nullptr;
|
||||
#endif
|
||||
static const OrtApi* s_ort_api = nullptr;
|
||||
|
||||
|
||||
// ════════════════════════════════════════════════════════════════════════
|
||||
// File-scope helpers (anonymous namespace — not exported)
|
||||
// ════════════════════════════════════════════════════════════════════════
|
||||
namespace {
|
||||
|
||||
// True when `path` names an existing regular file (directories excluded).
bool FileExists(const std::string& path)
{
#ifdef _WIN32
    const DWORD attributes = GetFileAttributesA(path.c_str());
    if (attributes == INVALID_FILE_ATTRIBUTES)
        return false;
    return (attributes & FILE_ATTRIBUTE_DIRECTORY) == 0;
#else
    struct stat info {};
    if (::stat(path.c_str(), &info) != 0)
        return false;
    return S_ISREG(info.st_mode);
#endif
}
|
||||
|
||||
// Join two path components with the platform separator. A base that
// already ends in a separator (native or '/') is not doubled; an empty
// base yields the component unchanged.
std::string JoinPath(const std::string& base, const std::string& component)
{
#ifdef _WIN32
    constexpr char sep = '\\';
#else
    constexpr char sep = '/';
#endif
    if (base.empty())
        return component;

    const char last = base.back();
    if (last == sep || last == '/')
        return base + component;

    std::string joined = base;
    joined += sep;
    joined += component;
    return joined;
}
|
||||
|
||||
// Prepends ep_dir to the OS library search path so that dependent
// libraries placed next to the EP binaries can be resolved at load time.
void InjectDllSearchPath(const std::string& ep_dir)
{
#ifdef _WIN32
    // BUGFIX: the previous std::wstring(begin, end) byte-by-byte widening
    // corrupted any non-ASCII character in the path. Convert properly via
    // MultiByteToWideChar using the ANSI code page, which matches the *A
    // APIs used throughout this file.
    std::wstring wdir;
    const int needed = MultiByteToWideChar(CP_ACP, 0, ep_dir.c_str(), -1,
                                           nullptr, 0);
    if (needed > 1) {
        wdir.resize(static_cast<size_t>(needed) - 1);
        MultiByteToWideChar(CP_ACP, 0, ep_dir.c_str(), -1, &wdir[0], needed);
    }
    DLL_DIRECTORY_COOKIE cookie = AddDllDirectory(wdir.c_str());
    if (!cookie)
        std::cerr << "[EPLoader] WARNING: AddDllDirectory failed for: "
                  << ep_dir << " (error " << GetLastError() << ")" << std::endl;

    // Also prepend to PATH as a fallback for loads that do not honor
    // AddDllDirectory (legacy search order).
    char existing_path[32767] = {};
    GetEnvironmentVariableA("PATH", existing_path, sizeof(existing_path));
    std::string new_path = ep_dir + ";" + existing_path;
    if (!SetEnvironmentVariableA("PATH", new_path.c_str()))
        std::cerr << "[EPLoader] WARNING: SetEnvironmentVariable PATH failed."
                  << std::endl;
#else
    // NOTE(review): LD_LIBRARY_PATH changed at runtime only affects child
    // processes, not the already-running dynamic loader — presumably kept
    // for spawned tools; confirm.
    const char* existing = getenv("LD_LIBRARY_PATH");
    std::string new_path = ep_dir +
        (existing ? (":" + std::string(existing)) : "");
    setenv("LD_LIBRARY_PATH", new_path.c_str(), 1);
#endif
    std::cout << "[EPLoader] DLL search path injected: " << ep_dir << std::endl;
}
|
||||
|
||||
// ── GetOrtApi ────────────────────────────────────────────────────────
|
||||
// Returns the negotiated OrtApi* for the already-loaded ORT module,
// caching the result in s_ort_api.
// Throws std::runtime_error if the module is not loaded, the symbol is
// missing, or no compatible API version is exported.
const OrtApi* GetOrtApi()
{
    if (!s_ort_module)
        throw std::runtime_error(
            "[EPLoader] ORT DLL not loaded — call EPLoader::Current() first.");

    if (s_ort_api)
        return s_ort_api;

#ifdef _WIN32
    using Fn = const OrtApiBase* (ORT_API_CALL*)();
    auto fn = reinterpret_cast<Fn>(
        GetProcAddress(s_ort_module, "OrtGetApiBase"));
#else
    using Fn = const OrtApiBase* (ORT_API_CALL*)();
    auto fn = reinterpret_cast<Fn>(
        dlsym(s_ort_module, "OrtGetApiBase"));
#endif
    if (!fn)
        throw std::runtime_error(
            "[EPLoader] OrtGetApiBase symbol not found in loaded ORT DLL.");

    const OrtApiBase* base = fn();
    if (!base)
        throw std::runtime_error(
            "[EPLoader] OrtGetApiBase() returned null.");

    // ORT's API version tracks its minor release number (e.g. "1.16.x" → 16)
    // — parse "major.minor" and clamp against the headers' ORT_API_VERSION.
    int dllMaxApi = ORT_API_VERSION;
    {
        const char* verStr = base->GetVersionString();
        int major = 0, minor = 0;
        if (verStr && std::sscanf(verStr, "%d.%d", &major, &minor) == 2)
            dllMaxApi = minor;
    }
    // BUGFIX: was unqualified `min`, which only compiled on Windows through
    // the <windows.h> min macro and broke non-Windows builds. Parenthesized
    // so the macro cannot interfere when NOMINMAX is not defined.
    int targetApi = (std::min)(ORT_API_VERSION, dllMaxApi);

    const OrtApi* api = base->GetApi(targetApi);
    if (!api)
        throw std::runtime_error(
            "[EPLoader] No compatible ORT API version found in loaded DLL.");

    s_ort_api = api;
    return s_ort_api;
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
|
||||
// ════════════════════════════════════════════════════════════════════════
|
||||
// EPLoader public / private methods
|
||||
// ════════════════════════════════════════════════════════════════════════
|
||||
// Platform-specific file name of the ONNX Runtime shared library.
const char* OrtDllName()
{
#if defined(_WIN32)
    return "onnxruntime.dll";
#elif defined(__APPLE__)
    return "libonnxruntime.dylib";
#else
    return "libonnxruntime.so";
#endif
}
|
||||
|
||||
// Maps an engine type to the name of its EP subdirectory under
// Shared/ep/. Unrecognized values fall back to the CPU provider.
const char* EPLoader::SubdirName(EngineType type)
{
    if (type == EngineType::NVIDIA_GPU)   return "cuda";
    if (type == EngineType::AMD_GPU)      return "directml";
    if (type == EngineType::OPENVINO_GPU) return "openvino";
    return "cpu";
}
|
||||
|
||||
// Human-readable engine-type name used in log output only.
const char* EPLoader::EngineTypeName(EngineType type)
{
    if (type == EngineType::NVIDIA_GPU)   return "NVIDIA_GPU";
    if (type == EngineType::AMD_GPU)      return "AMD_GPU";
    if (type == EngineType::OPENVINO_GPU) return "OPENVINO_GPU";
    if (type == EngineType::CPU)          return "CPU";
    if (type == EngineType::AUTO_DETECT)  return "AUTO_DETECT";
    return "UNKNOWN";
}
|
||||
|
||||
/*static*/
|
||||
std::string EPLoader::ResolveEPDir(const std::string& shared_dir,
|
||||
EngineType type)
|
||||
{
|
||||
std::string ep_base = JoinPath(shared_dir, "ep");
|
||||
std::string subdir = JoinPath(ep_base, SubdirName(type));
|
||||
std::string dll_probe = JoinPath(subdir, OrtDllName());
|
||||
|
||||
if (FileExists(dll_probe)) {
|
||||
std::cout << "[EPLoader] EP subdir found: " << subdir << std::endl;
|
||||
return subdir;
|
||||
}
|
||||
|
||||
std::string flat_probe = JoinPath(shared_dir, OrtDllName());
|
||||
if (FileExists(flat_probe)) {
|
||||
std::cout << "[EPLoader] EP subdir not found — "
|
||||
"using flat Shared/ (backward compat): "
|
||||
<< shared_dir << std::endl;
|
||||
return shared_dir;
|
||||
}
|
||||
|
||||
std::cerr << "[EPLoader] WARNING: " << OrtDllName() << " not found in:\n"
|
||||
<< " " << subdir << "\n"
|
||||
<< " " << shared_dir << "\n"
|
||||
<< " LoadOrtDll will fail with a clear error."
|
||||
<< std::endl;
|
||||
return subdir;
|
||||
}
|
||||
|
||||
// Probe the machine's hardware via ANSLicenseHelper and return the
// matching engine type.
EngineType EPLoader::AutoDetect()
{
    std::cout << "[EPLoader] Auto-detecting hardware..." << std::endl;
    EngineType result = ANSLicenseHelper{}.CheckHardwareInformation();
    std::cout << "[EPLoader] Detected: " << EngineTypeName(result) << std::endl;
    return result;
}
|
||||
|
||||
/*static*/
|
||||
// One-time initialization: resolve the EP directory, inject search paths,
// load onnxruntime, and cache the result in s_info. Subsequent calls
// return the cached EPInfo unchanged.
//
// shared_dir : root of the Shared/ deployment folder.
// preferred  : engine to use; AUTO_DETECT triggers hardware probing.
// Returns    : reference to the process-wide EPInfo descriptor.
// Throws     : std::runtime_error if the ORT library cannot be loaded.
const EPInfo& EPLoader::Initialize(const std::string& shared_dir,
                                   EngineType preferred)
{
    // BUGFIX: the previous unlocked fast-path read of s_initialized was a
    // data race (plain bool tested outside the mutex). Acquire the lock
    // unconditionally; it is uncontended after startup, so this costs
    // nothing measurable.
    std::lock_guard<std::mutex> lock(s_mutex);
    if (s_initialized) return s_info;

    std::cout << "[EPLoader] Initializing..." << std::endl;
    std::cout << "[EPLoader] Shared dir : " << shared_dir << std::endl;
    std::cout << "[EPLoader] Preferred EP : " << EngineTypeName(preferred) << std::endl;

    // Resolve AUTO_DETECT to a concrete engine via hardware probing.
    EngineType type = (preferred == EngineType::AUTO_DETECT)
        ? AutoDetect() : preferred;

    std::string ep_dir = ResolveEPDir(shared_dir, type);

    // When the EP lives in a subdirectory (e.g. ep/openvino/), provider
    // DLLs may depend on runtime libraries that live in the parent
    // shared_dir (e.g. openvino.dll). Inject shared_dir into the DLL
    // search path so Windows can resolve those dependencies.
    if (ep_dir != shared_dir)
        InjectDllSearchPath(shared_dir);

    LoadOrtDll(ep_dir);

    s_info.type       = type;
    s_info.libraryDir = ep_dir;
    s_info.fromSubdir = (ep_dir != shared_dir);
    s_initialized     = true;

    std::cout << "[EPLoader] Ready. EP=" << EngineTypeName(type)
              << " dir=" << ep_dir << std::endl;
    return s_info;
}
|
||||
|
||||
/*static*/
|
||||
// Return the active EPInfo, lazily bootstrapping with the default shared
// directory and hardware auto-detection on first use.
const EPInfo& EPLoader::Current()
{
    if (s_initialized)
        return s_info;
    return Initialize(DEFAULT_SHARED_DIR, EngineType::AUTO_DETECT);
}
|
||||
|
||||
/*static*/
|
||||
// Lock-free peek at the initialization flag. May race with a concurrent
// Initialize(); treat the result as advisory only.
bool EPLoader::IsInitialized()
{
    return s_initialized;
}
|
||||
|
||||
void* EPLoader::GetOrtApiRaw()
|
||||
{
|
||||
Current();
|
||||
const OrtApi* api = GetOrtApi();
|
||||
if (!api)
|
||||
throw std::runtime_error(
|
||||
"[EPLoader] GetOrtApiRaw: OrtApi not available.");
|
||||
return static_cast<void*>(const_cast<OrtApi*>(api));
|
||||
}
|
||||
|
||||
|
||||
#ifdef _WIN32
|
||||
// ── MakeTempDir ──────────────────────────────────────────────────────────
|
||||
// Creates a process-unique staging directory under %TEMP%:
|
||||
// %TEMP%\anscenter_ort_<pid>\
|
||||
//
|
||||
// All ORT DLLs are copied here so they share the same directory.
|
||||
// ORT resolves provider DLLs (onnxruntime_providers_shared.dll, etc.)
|
||||
// relative to its own loaded path — they must be co-located with the
|
||||
// renamed onnxruntime DLL, not left behind in the original ep_dir.
|
||||
static std::string MakeTempDir()
|
||||
{
|
||||
char tmp[MAX_PATH] = {};
|
||||
GetTempPathA(MAX_PATH, tmp);
|
||||
std::string dir = std::string(tmp)
|
||||
+ "anscenter_ort_"
|
||||
+ std::to_string(GetCurrentProcessId());
|
||||
CreateDirectoryA(dir.c_str(), nullptr); // OK if already exists
|
||||
return dir;
|
||||
}
|
||||
|
||||
// ── CopyDirToTemp ─────────────────────────────────────────────────────────
|
||||
// Copies every .dll from ep_dir into temp_dir.
|
||||
// ORT and all its provider DLLs (onnxruntime_providers_shared.dll,
|
||||
// onnxruntime_providers_cuda.dll, onnxruntime_providers_tensorrt.dll,
|
||||
// DirectML.dll, etc.) must all live in the same folder so Windows can
|
||||
// resolve their mutual dependencies.
|
||||
static void CopyDirToTemp(const std::string& ep_dir,
                          const std::string& temp_dir)
{
    // Enumerate every *.dll directly inside ep_dir (non-recursive).
    std::string pattern = ep_dir + "\\*.dll";
    WIN32_FIND_DATAA fd{};
    HANDLE hFind = FindFirstFileA(pattern.c_str(), &fd);
    if (hFind == INVALID_HANDLE_VALUE) {
        // Nothing matched (or ep_dir is missing) — warn and bail; the
        // later LoadLibrary call will surface the real failure.
        std::cerr << "[EPLoader] WARNING: No DLLs found in ep_dir: "
                  << ep_dir << std::endl;
        return;
    }

    int copied = 0, skipped = 0;
    do {
        std::string src = ep_dir + "\\" + fd.cFileName;
        std::string dst = temp_dir + "\\" + fd.cFileName;

        if (CopyFileA(src.c_str(), dst.c_str(), /*bFailIfExists=*/FALSE)) {
            ++copied;
            std::cout << "[EPLoader] copied: " << fd.cFileName << std::endl;
        }
        else {
            // ERROR_FILE_EXISTS (80) is fine — stale copy from previous run.
            DWORD err = GetLastError();
            if (err != ERROR_FILE_EXISTS && err != ERROR_ALREADY_EXISTS) {
                std::cerr << "[EPLoader] WARNING: could not copy "
                          << fd.cFileName
                          << " (err=" << err << ") — skipping." << std::endl;
            }
            // NOTE(review): `skipped` also counts genuine copy failures, so
            // the "already present" summary below can over-report — confirm
            // whether failures should be counted separately.
            ++skipped;
        }
    } while (FindNextFileA(hFind, &fd));

    FindClose(hFind);
    std::cout << "[EPLoader] Staged " << copied << " DLL(s) to temp dir"
              << (skipped ? " (" + std::to_string(skipped) + " already present)" : "")
              << std::endl;
}
|
||||
|
||||
// ── DeleteTempDir ─────────────────────────────────────────────────────────
|
||||
// Removes the staging directory and all files in it.
|
||||
static void DeleteTempDir(const std::string& dir)
{
    // No staging dir was created this run — nothing to clean.
    if (dir.empty()) return;

    // Delete every file in the directory first: RemoveDirectoryA only
    // succeeds on an empty directory.
    std::string pattern = dir + "\\*";
    WIN32_FIND_DATAA fd{};
    HANDLE hFind = FindFirstFileA(pattern.c_str(), &fd);
    if (hFind != INVALID_HANDLE_VALUE) {
        do {
            // Skip the "." and ".." pseudo-entries.
            if (strcmp(fd.cFileName, ".") == 0 ||
                strcmp(fd.cFileName, "..") == 0) continue;
            std::string f = dir + "\\" + fd.cFileName;
            // Best-effort: DeleteFileA can fail while a DLL is still
            // mapped into a process; the stale file is overwritten on the
            // next run (CopyFileA with bFailIfExists=FALSE).
            DeleteFileA(f.c_str());
        } while (FindNextFileA(hFind, &fd));
        FindClose(hFind);
    }
    RemoveDirectoryA(dir.c_str());
    std::cout << "[EPLoader] Temp staging dir deleted: " << dir << std::endl;
}
|
||||
#endif // _WIN32
|
||||
|
||||
|
||||
// ── LoadOrtDll ───────────────────────────────────────────────────────────────
|
||||
void EPLoader::LoadOrtDll(const std::string& ep_dir)
|
||||
{
|
||||
if (s_ort_module) {
|
||||
std::cout << "[EPLoader] ORT DLL already loaded — skipping." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
// Inject ep_dir into the DLL search path so cudart64_*.dll and other
|
||||
// CUDA runtime DLLs that are NOT copied to temp are still found.
|
||||
InjectDllSearchPath(ep_dir);
|
||||
|
||||
std::string src_path = JoinPath(ep_dir, OrtDllName());
|
||||
std::cout << "[EPLoader] ORT source : " << src_path << std::endl;
|
||||
|
||||
#ifdef _WIN32
|
||||
// ── Windows: stage ALL ep_dir DLLs into a process-unique temp folder ────
|
||||
//
|
||||
// Why a renamed copy of onnxruntime.dll?
|
||||
// Windows DLL identity is keyed on base filename. A pre-loaded
|
||||
// System32\onnxruntime.dll would be returned by LoadLibraryExW regardless
|
||||
// of the full path we supply. A unique name bypasses that.
|
||||
//
|
||||
// Why copy ALL DLLs (not just onnxruntime.dll)?
|
||||
// ORT internally calls GetModuleFileName on its own HMODULE to discover
|
||||
// its directory, then loads provider DLLs (onnxruntime_providers_shared.dll,
|
||||
// onnxruntime_providers_cuda.dll, etc.) from that same directory.
|
||||
// Because our copy is named "anscenter_ort_<pid>.dll", ORT's
|
||||
// GetModuleHandleA("onnxruntime.dll") returns NULL (or the System32 copy),
|
||||
// so it resolves providers relative to %TEMP% — where they don't exist.
|
||||
// Staging ALL provider DLLs alongside our renamed copy fixes this.
|
||||
|
||||
std::string temp_dir = MakeTempDir();
|
||||
std::cout << "[EPLoader] Staging dir : " << temp_dir << std::endl;
|
||||
|
||||
// Copy every .dll from ep_dir into temp_dir (providers included).
|
||||
CopyDirToTemp(ep_dir, temp_dir);
|
||||
|
||||
// The main ORT DLL gets an additional process-unique alias so Windows
|
||||
// loads it as a fresh module rather than returning the System32 handle.
|
||||
std::string ort_alias = temp_dir + "\\anscenter_ort_"
|
||||
+ std::to_string(GetCurrentProcessId()) + ".dll";
|
||||
std::string ort_in_temp = temp_dir + "\\" + OrtDllName();
|
||||
|
||||
if (!CopyFileA(ort_in_temp.c_str(), ort_alias.c_str(), /*bFailIfExists=*/FALSE)) {
|
||||
DWORD err = GetLastError();
|
||||
if (err != ERROR_FILE_EXISTS && err != ERROR_ALREADY_EXISTS) {
|
||||
throw std::runtime_error(
|
||||
"[EPLoader] Failed to create ORT alias in temp dir.\n"
|
||||
" src : " + ort_in_temp + "\n"
|
||||
" dst : " + ort_alias + "\n"
|
||||
" err : " + std::to_string(err));
|
||||
}
|
||||
}
|
||||
std::cout << "[EPLoader] ORT alias : " << ort_alias << std::endl;
|
||||
|
||||
// Inject temp_dir so that the alias's own dependencies (and ORT's
|
||||
// runtime provider loading) also search here.
|
||||
InjectDllSearchPath(temp_dir);
|
||||
|
||||
// Log if a System32 copy is already resident — our alias bypasses it.
|
||||
HMODULE hExisting = GetModuleHandleA("onnxruntime.dll");
|
||||
if (hExisting) {
|
||||
char existingPath[MAX_PATH] = {};
|
||||
GetModuleFileNameA(hExisting, existingPath, MAX_PATH);
|
||||
// Only warn if it looks like a system copy that could cause confusion
|
||||
std::string existingStr(existingPath);
|
||||
std::transform(existingStr.begin(), existingStr.end(),
|
||||
existingStr.begin(), ::tolower);
|
||||
if (existingStr.find("system32") != std::string::npos ||
|
||||
existingStr.find("syswow64") != std::string::npos) {
|
||||
std::cerr << "[EPLoader] WARNING: System ORT is resident ("
|
||||
<< existingPath << ") - our alias overrides it." << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::wstring walias(ort_alias.begin(), ort_alias.end());
|
||||
s_ort_module = LoadLibraryExW(
|
||||
walias.c_str(),
|
||||
nullptr,
|
||||
LOAD_WITH_ALTERED_SEARCH_PATH); // resolves deps from temp_dir first
|
||||
|
||||
if (!s_ort_module) {
|
||||
DWORD err = GetLastError();
|
||||
DeleteTempDir(temp_dir);
|
||||
throw std::runtime_error(
|
||||
"[EPLoader] LoadLibraryExW failed: " + ort_alias +
|
||||
"\n Windows error code: " + std::to_string(err) +
|
||||
"\n Ensure all CUDA runtime DLLs (cudart64_*.dll etc.)"
|
||||
"\n exist in: " + ep_dir);
|
||||
}
|
||||
|
||||
// Store paths so Shutdown() can clean up.
|
||||
s_temp_ort_path = ort_alias;
|
||||
s_temp_dir = temp_dir;
|
||||
|
||||
#else
|
||||
// ── Linux / macOS ─────────────────────────────────────────────────────────
|
||||
// Full path is the module key on ELF platforms — no collision issue.
|
||||
// RTLD_GLOBAL exposes ORT symbols to provider .so files loaded internally.
|
||||
s_ort_module = dlopen(src_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
|
||||
if (!s_ort_module) {
|
||||
throw std::runtime_error(
|
||||
std::string("[EPLoader] dlopen failed: ") + src_path +
|
||||
"\n " + dlerror());
|
||||
}
|
||||
#endif
|
||||
|
||||
// ── Bootstrap ORT C++ API ─────────────────────────────────────────────────
|
||||
using OrtGetApiBase_fn = const OrtApiBase* (ORT_API_CALL*)();
|
||||
#ifdef _WIN32
|
||||
auto fn = reinterpret_cast<OrtGetApiBase_fn>(
|
||||
GetProcAddress(s_ort_module, "OrtGetApiBase"));
|
||||
#else
|
||||
auto fn = reinterpret_cast<OrtGetApiBase_fn>(
|
||||
dlsym(s_ort_module, "OrtGetApiBase"));
|
||||
#endif
|
||||
if (!fn)
|
||||
throw std::runtime_error(
|
||||
"[EPLoader] OrtGetApiBase not exported — is this a genuine onnxruntime build?\n"
|
||||
" path: " + src_path);
|
||||
|
||||
const OrtApiBase* base = fn();
|
||||
if (!base)
|
||||
throw std::runtime_error(
|
||||
"[EPLoader] OrtGetApiBase() returned null from: " + src_path);
|
||||
|
||||
// ── Version negotiation ───────────────────────────────────────────────────
|
||||
int dllMaxApi = ORT_API_VERSION;
|
||||
{
|
||||
const char* verStr = base->GetVersionString();
|
||||
int major = 0, minor = 0;
|
||||
if (verStr && sscanf(verStr, "%d.%d", &major, &minor) == 2)
|
||||
dllMaxApi = minor;
|
||||
}
|
||||
int targetApi = min(ORT_API_VERSION, dllMaxApi);
|
||||
|
||||
if (targetApi < ORT_API_VERSION) {
|
||||
std::cerr << "[EPLoader] WARNING: ORT DLL version "
|
||||
<< base->GetVersionString()
|
||||
<< " supports up to API " << dllMaxApi
|
||||
<< " but headers expect API " << ORT_API_VERSION << ".\n"
|
||||
<< " Using API " << targetApi
|
||||
<< ". Consider upgrading onnxruntime.dll in ep/ to match SDK headers."
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
const OrtApi* api = base->GetApi(targetApi);
|
||||
if (!api)
|
||||
throw std::runtime_error(
|
||||
"[EPLoader] GetApi(" + std::to_string(targetApi) +
|
||||
") returned null — the DLL may be corrupt.");
|
||||
|
||||
s_ort_api = api;
|
||||
Ort::Global<void>::api_ = api;
|
||||
|
||||
std::cout << "[EPLoader] ORT loaded successfully." << std::endl;
|
||||
std::cout << "[EPLoader] ORT DLL version : " << base->GetVersionString() << std::endl;
|
||||
std::cout << "[EPLoader] ORT header API : " << ORT_API_VERSION << std::endl;
|
||||
std::cout << "[EPLoader] ORT active API : " << targetApi << std::endl;
|
||||
}
|
||||
|
||||
// ── Shutdown ──────────────────────────────────────────────────────────────────
|
||||
// IMPORTANT: All Ort::Session, Ort::Env, Ort::SessionOptions objects
|
||||
// MUST be destroyed BEFORE calling Shutdown(), otherwise FreeLibrary will
|
||||
// unload code that is still referenced by those objects.
|
||||
void EPLoader::Shutdown()
{
    // Serialize against concurrent Initialize()/Shutdown() calls.
    std::lock_guard<std::mutex> lock(s_mutex);

    if (s_ort_module) {
        std::cout << "[EPLoader] Unloading ORT DLL..." << std::endl;
#ifdef _WIN32
        // Unmap the module BEFORE deleting the staged files — Windows
        // refuses to delete a DLL that is still loaded.
        FreeLibrary(s_ort_module);
        // Delete the entire staging directory (all copied provider DLLs).
        if (!s_temp_dir.empty()) {
            DeleteTempDir(s_temp_dir);
            s_temp_dir.clear();
        }
        s_temp_ort_path.clear();
#else
        dlclose(s_ort_module);
#endif
        s_ort_module = nullptr;
    }

    // Reset all cached state so a later Initialize() starts from scratch.
    s_ort_api = nullptr;
    s_initialized = false;
    s_info = EPInfo{};
    std::cout << "[EPLoader] Shutdown complete." << std::endl;
}
|
||||
|
||||
} // namespace ANSCENTER
|
||||
9
engines/ONNXEngine/EPLoader.h
Normal file
9
engines/ONNXEngine/EPLoader.h
Normal file
@@ -0,0 +1,9 @@
|
||||
#pragma once
|
||||
// ============================================================================
|
||||
// Forwarding header — EPLoader moved to ANSLibsLoader
|
||||
//
|
||||
// This file is retained for backward compatibility. All consuming projects
|
||||
// should update their include paths to reference ANSLibsLoader/include/
|
||||
// directly. Once all projects are updated, this file can be removed.
|
||||
// ============================================================================
|
||||
#include "../ANSLibsLoader/include/EPLoader.h"
|
||||
1422
engines/ONNXEngine/ONNXEngine.cpp
Normal file
1422
engines/ONNXEngine/ONNXEngine.cpp
Normal file
File diff suppressed because it is too large
Load Diff
518
engines/ONNXEngine/ONNXEngine.h
Normal file
518
engines/ONNXEngine/ONNXEngine.h
Normal file
@@ -0,0 +1,518 @@
|
||||
#pragma once
|
||||
#ifndef ONNXEngine_H
|
||||
#define ONNXEngine_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <typeinfo>
|
||||
#include <deque>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "onnxruntime_cxx_api.h"
|
||||
#include "opencv2/opencv.hpp"
|
||||
#include "EPLoader.h" // brings in EngineType via ANSLicenseHelper
|
||||
|
||||
#define LITEORT_CHAR wchar_t
|
||||
|
||||
#ifdef ENGINE_EXPORTS
|
||||
#define ONNXENGINE_API __declspec(dllexport)
|
||||
#else
|
||||
#define ONNXENGINE_API __declspec(dllimport)
|
||||
#endif
|
||||
|
||||
namespace ANSCENTER {
|
||||
|
||||
// ====================================================================
|
||||
// types
|
||||
// ====================================================================
|
||||
namespace types {
|
||||
|
||||
// Compile-time guard used by BoundingBoxType: the coordinate type _T1
// must be an integral or floating-point trivially-copyable type and the
// score type _T2 a trivially-copyable floating-point type.
template<typename _T1 = float, typename _T2 = float>
static inline void __assert_type()
{
    constexpr bool layout_ok =
        std::is_standard_layout_v<_T1> && std::is_trivially_copyable_v<_T1>
        && std::is_standard_layout_v<_T2> && std::is_trivially_copyable_v<_T2>;
    constexpr bool numeric_ok =
        std::is_floating_point<_T2>::value
        && (std::is_integral<_T1>::value || std::is_floating_point<_T1>::value);
    static_assert(layout_ok && numeric_ok, "not support type.");
}
|
||||
|
||||
// Generic axis-aligned bounding box with detection metadata.
// T1: coordinate type (int/float/double); T2: score type (floating point)
// — both checked at compile time by types::__assert_type.
template<typename T1 = float, typename T2 = float>
struct BoundingBoxType
{
    typedef T1 value_type;
    typedef T2 score_type;

    // Box corners: (x1, y1) and (x2, y2).
    value_type x1, y1, x2, y2;
    score_type score;           // detection confidence
    const char* label_text;     // non-owning label string (may be null)
    unsigned int label;         // numeric class id
    bool flag;                  // true once the box holds valid data

    // Convert coordinates/score to another instantiation's types
    // (defined out of line elsewhere).
    template<typename O1, typename O2 = score_type>
    BoundingBoxType<O1, O2> convert_type() const;

    // Intersection-over-union against another (possibly differently
    // typed) box (defined out of line elsewhere).
    template<typename O1, typename O2 = score_type>
    value_type iou_of(const BoundingBoxType<O1, O2>& other) const;

    value_type width() const;
    value_type height() const;
    value_type area() const;
    ::cv::Rect rect() const;       // OpenCV rect view of the box
    ::cv::Point2i tl() const;      // presumably top-left corner — see definition
    ::cv::Point2i rb() const;      // presumably bottom-right corner — see definition

    // Zero-initialize all fields and enforce the type constraints.
    BoundingBoxType() :
        x1(0), y1(0), x2(0), y2(0),
        score(0), label_text(nullptr), label(0), flag(false)
    {
        types::__assert_type<value_type, score_type>();
    }
};
|
||||
|
||||
template class BoundingBoxType<int, float>;
|
||||
template class BoundingBoxType<float, float>;
|
||||
template class BoundingBoxType<double, double>;
|
||||
|
||||
typedef BoundingBoxType<int, float> Boxi;
|
||||
typedef BoundingBoxType<float, float> Boxf;
|
||||
typedef BoundingBoxType<double, double> Boxd;
|
||||
|
||||
// 2D facial/body landmarks; `flag` marks populated results (same
// convention in all content structs below).
typedef struct LandmarksType {
    std::vector<cv::Point2f> points;
    bool flag;
    LandmarksType() : flag(false) {}
} Landmarks;

typedef Landmarks Landmarks2D;

// 3D landmark variant.
typedef struct Landmarks3DType {
    std::vector<cv::Point3f> points;
    bool flag;
    Landmarks3DType() : flag(false) {}
} Landmarks3D;

// Detection box bundled with its landmarks.
typedef struct BoxfWithLandmarksType {
    Boxf box;
    Landmarks landmarks;
    bool flag;
    BoxfWithLandmarksType() : flag(false) {}
} BoxfWithLandmarks;

// Head-pose angles (units not stated here — see producing model).
typedef struct EulerAnglesType {
    float yaw, pitch, roll;
    bool flag;
    EulerAnglesType() : flag(false) {}
} EulerAngles;

// Emotion classification result: score + class id + label text.
typedef struct EmotionsType {
    float score;
    unsigned int label;
    const char* text;  // non-owning label string
    bool flag;
    EmotionsType() : flag(false) {}
} Emotions;

// Age estimation: point estimate plus an interval with its probability.
typedef struct AgeType {
    float age;
    unsigned int age_interval[2];
    float interval_prob;
    bool flag;
    AgeType() : flag(false) {}
} Age;

// Gender classification result.
typedef struct GenderType {
    float score;
    unsigned int label;
    const char* text;  // non-owning label string
    bool flag;
    GenderType() : flag(false) {}
} Gender;

// Face embedding vector of dimension `dim`.
typedef struct FaceContentType {
    std::vector<float> embedding;
    unsigned int dim;
    bool flag;
    FaceContentType() : flag(false) {}
} FaceContent;

// Semantic segmentation: per-pixel class ids, colorized view, and a
// class-id → name lookup.
typedef struct SegmentContentType {
    cv::Mat class_mat;
    cv::Mat color_mat;
    std::unordered_map<int, std::string> names_map;
    bool flag;
    SegmentContentType() : flag(false) {}
} SegmentContent;

// Matting output: foreground, alpha, and composited images.
typedef struct MattingContentType {
    cv::Mat fgr_mat;
    cv::Mat pha_mat;
    cv::Mat merge_mat;
    bool flag;
    MattingContentType() : flag(false) {}
} MattingContent;

// Single-channel segmentation mask.
typedef struct SegmentationMaskContentType {
    cv::Mat mask;
    bool flag;
    SegmentationMaskContentType() : flag(false) {}
} SegmentationMaskContent;

// Top-k classification result: parallel score/text/label arrays.
typedef struct ImageNetContentType {
    std::vector<float> scores;
    std::vector<const char*> texts;
    std::vector<unsigned int> labels;
    bool flag;
    ImageNetContentType() : flag(false) {}
} ImageNetContent;

typedef ImageNetContent ClassificationContent;

// Style-transfer output image.
typedef struct StyleContentType {
    cv::Mat mat;
    bool flag;
    StyleContentType() : flag(false) {}
} StyleContent;

// Super-resolution output image.
typedef struct SuperResolutionContentType {
    cv::Mat mat;
    bool flag;
    SuperResolutionContentType() : flag(false) {}
} SuperResolutionContent;

// Face parsing: per-pixel labels plus a merged visualization.
typedef struct FaceParsingContentType {
    cv::Mat label;
    cv::Mat merge;
    bool flag;
    FaceParsingContentType() : flag(false) {}
} FaceParsingContent;

// Task-specific aliases sharing the plain mask layout.
typedef SegmentationMaskContent HairSegContent;
typedef SegmentationMaskContent HeadSegContent;
typedef SegmentationMaskContent FaceHairSegContent;
typedef SegmentationMaskContent PortraitSegContent;
|
||||
|
||||
} // namespace types
|
||||
|
||||
// ====================================================================
|
||||
// utils
|
||||
// ====================================================================
|
||||
namespace utils {
|
||||
namespace transform {
|
||||
|
||||
// Tensor memory layouts: channel-first (CHW) vs channel-last (HWC).
enum { CHW = 0, HWC = 1 };

// Build an Ort::Value input tensor from a single image.
// tensor_value_handler holds the float storage backing the tensor —
// presumably it must outlive the returned Ort::Value; confirm in the
// definition.
Ort::Value create_tensor(
    const cv::Mat& mat,
    const std::vector<int64_t>& tensor_dims,
    const Ort::MemoryInfo& memory_info_handler,
    std::vector<float>& tensor_value_handler,
    unsigned int data_format = CHW);

// Batched variant: one tensor for a vector of images.
Ort::Value create_tensor_batch(
    const std::vector<cv::Mat>& batch_mats,
    const std::vector<int64_t>& tensor_dims,
    const Ort::MemoryInfo& memory_info_handler,
    std::vector<float>& tensor_value_handler,
    unsigned int data_format = CHW);

// 5-D (video) tensor built from a temporal window of frames.
Ort::Value create_video_tensor_5d(
    const std::deque<cv::Mat>& frames,
    const std::vector<int64_t>& tensor_dims,
    const Ort::MemoryInfo& memory_info_handler,
    std::vector<float>& tensor_value_handler);

// Normalization helpers: scalar or per-channel mean/scale, returning a
// new Mat, writing into an out-param, or mutating in place.
cv::Mat normalize(const cv::Mat& mat, float mean, float scale);
cv::Mat normalize(const cv::Mat& mat, const float mean[3], const float scale[3]);
void normalize(const cv::Mat& inmat, cv::Mat& outmat, float mean, float scale);
void normalize_inplace(cv::Mat& mat_inplace, float mean, float scale);
void normalize_inplace(cv::Mat& mat_inplace, const float mean[3], const float scale[3]);
|
||||
|
||||
} // namespace transform
|
||||
} // namespace utils
|
||||
|
||||
// ====================================================================
|
||||
// Helpers
|
||||
// ====================================================================
|
||||
inline static std::string OrtCompatiableGetInputName(
|
||||
size_t index, OrtAllocator* allocator, Ort::Session* ort_session)
|
||||
{
|
||||
return std::string(ort_session->GetInputNameAllocated(index, allocator).get());
|
||||
}
|
||||
|
||||
inline static std::string OrtCompatiableGetOutputName(
|
||||
size_t index, OrtAllocator* allocator, Ort::Session* ort_session)
|
||||
{
|
||||
return std::string(ort_session->GetOutputNameAllocated(index, allocator).get());
|
||||
}
|
||||
|
||||
// ====================================================================
|
||||
// BasicOrtHandler
|
||||
// ====================================================================
|
||||
// Abstract base for ONNX Runtime model wrappers: owns the ORT session,
// caches input/output metadata, and lets subclasses supply the
// model-specific preprocessing via transform()/transformBatch().
class ONNXENGINE_API BasicOrtHandler
{
protected:

    // ── Input metadata ──
    const char* input_name = nullptr;            // presumably points into input_node_names_ — confirm in initialize_handler()
    std::vector<const char*> input_node_names;   // C-string views handed to ORT
    std::vector<std::string> input_node_names_;  // owning storage for the names
    std::vector<int64_t> input_node_dims;        // input tensor shape
    std::size_t input_tensor_size = 1;           // presumably the input element count — confirm
    std::vector<float> input_values_handler;     // reusable input staging buffer

    // ── Output metadata ──
    std::vector<const char*> output_node_names;
    std::vector<std::string> output_node_names_;
    std::vector<std::vector<int64_t>> output_node_dims;
    int num_outputs = 1;

    // ORT objects are held by raw pointer so construction can be deferred
    // until EPLoader has bootstrapped the runtime.
    Ort::Env* ort_env = nullptr; // ← pointer, no in-class init
    Ort::Session* ort_session = nullptr;
    Ort::MemoryInfo* memory_info_handler = nullptr;

    std::wstring onnx_path_w; // ← owns the wstring storage
    const LITEORT_CHAR* onnx_path = nullptr; // ← points into onnx_path_w
    const char* log_id = nullptr;


protected:
    const unsigned int num_threads;  // fixed at construction
    EngineType m_engineType;         // execution provider chosen for this session

protected:
    // Default: hardware auto-detection via ANSLicenseHelper through EPLoader
    explicit BasicOrtHandler(const std::string& _onnx_path,
                             unsigned int _num_threads = 1);

    // Explicit engine override per-session
    explicit BasicOrtHandler(const std::string& _onnx_path,
                             EngineType engineType,
                             unsigned int _num_threads = 1);

    virtual ~BasicOrtHandler();

    // Non-copyable: instances own raw ORT session/env pointers.
    BasicOrtHandler(const BasicOrtHandler&) = delete;
    BasicOrtHandler& operator=(const BasicOrtHandler&) = delete;
private:
    // One-time session/env/metadata setup shared by both constructors.
    void initialize_handler();
protected:
    // Model-specific image → input-tensor conversion.
    virtual Ort::Value transform(const cv::Mat& mat) = 0;
    virtual Ort::Value transformBatch(const std::vector<cv::Mat>& images) = 0;

    // EP-specific session option builders
    bool TryAppendCUDA(Ort::SessionOptions& opts);
    bool TryAppendDirectML(Ort::SessionOptions& opts);
    bool TryAppendOpenVINO(Ort::SessionOptions& opts);
};
|
||||
|
||||
// ====================================================================
|
||||
// SCRFD — face detection
|
||||
// ====================================================================
|
||||
// SCRFD single-stage face detector, optionally with 5-point facial
// landmarks (see use_kps). detect() returns score-thresholded,
// NMS-filtered boxes mapped back to the original image coordinates.
class SCRFD : public BasicOrtHandler
{
public:
    explicit SCRFD(const std::string& _onnx_path,unsigned int _num_threads = 1);
    explicit SCRFD(const std::string& _onnx_path,EngineType engineType,unsigned int _num_threads = 1);
    ~SCRFD() override = default;

    // Run detection on one BGR image.
    //   detected_boxes_kps [out] boxes (+ landmarks) surviving NMS
    //   score_threshold    minimum face confidence
    //   iou_threshold      NMS overlap threshold
    //   topk               maximum number of boxes returned
    void detect(const cv::Mat& mat,
                std::vector<types::BoxfWithLandmarks>& detected_boxes_kps,
                float score_threshold = 0.3f,
                float iou_threshold = 0.45f,
                unsigned int topk = 400);

private:
    // Anchor-center coordinates plus the FPN stride they belong to.
    typedef struct { float cx, cy, stride; } SCRFDPoint;
    // Resize parameters (scale ratio + padding offsets) recorded by
    // resize_unscale and used to map detections back to source coords.
    typedef struct { float ratio; int dw, dh; bool flag; } SCRFDScaleParams;

    // Input normalization constants: mean 127.5, scale 1/128.
    const float mean_vals[3] = { 127.5f, 127.5f, 127.5f };
    const float scale_vals[3] = { 1.f / 128.f, 1.f / 128.f, 1.f / 128.f };

    unsigned int fmc = 3;                              // number of FPN output levels
    bool use_kps = false;                              // model exports landmark heads?
    unsigned int num_anchors = 2;                      // anchors per spatial location
    std::vector<int> feat_stride_fpn = { 8, 16, 32 };  // stride per FPN level
    // Cached anchor centers keyed by stride (rebuilt when the flag below
    // is false — see generate_points).
    std::unordered_map<int, std::vector<SCRFDPoint>> center_points;
    bool center_points_is_update = false;

    static constexpr unsigned int nms_pre = 1000;   // candidate cap before NMS (per decode)
    static constexpr unsigned int max_nms = 30000;  // hard cap on total NMS input

    Ort::Value transform(const cv::Mat& mat_rs) override;
    Ort::Value transformBatch(const std::vector<cv::Mat>& images) override;

    void initial_context();
    // Resize `mat` into mat_rs preserving aspect ratio (padding as needed),
    // recording the scale/padding in scale_params.
    void resize_unscale(const cv::Mat& mat, cv::Mat& mat_rs,
                        int target_height, int target_width,
                        SCRFDScaleParams& scale_params);
    // (Re)build the anchor-center cache for the given input size.
    void generate_points(int target_height, int target_width);

    // Decode all model output tensors into boxes (+ landmarks).
    void generate_bboxes_kps(const SCRFDScaleParams& scale_params,
                             std::vector<types::BoxfWithLandmarks>& bbox_kps_collection,
                             std::vector<Ort::Value>& output_tensors,
                             float score_threshold,
                             float img_height, float img_width);

    // Decode one stride's score/box tensors (no landmarks).
    void generate_bboxes_single_stride(
        const SCRFDScaleParams& scale_params,
        Ort::Value& score_pred, Ort::Value& bbox_pred,
        unsigned int stride, float score_threshold,
        float img_height, float img_width,
        std::vector<types::BoxfWithLandmarks>& bbox_kps_collection);

    // Decode one stride's score/box/keypoint tensors.
    void generate_bboxes_kps_single_stride(
        const SCRFDScaleParams& scale_params,
        Ort::Value& score_pred, Ort::Value& bbox_pred, Ort::Value& kps_pred,
        unsigned int stride, float score_threshold,
        float img_height, float img_width,
        std::vector<types::BoxfWithLandmarks>& bbox_kps_collection);

    // IoU-based NMS keeping at most `topk` boxes in `output`.
    void nms_bboxes_kps(std::vector<types::BoxfWithLandmarks>& input,
                        std::vector<types::BoxfWithLandmarks>& output,
                        float iou_threshold, unsigned int topk);
};
|
||||
|
||||
// ====================================================================
|
||||
// GlintArcFace — face recognition
|
||||
// ====================================================================
|
||||
// ArcFace face-recognition embedder. detect() produces the embedding for
// one aligned face crop; detectBatch() processes several crops at once.
class GlintArcFace : public BasicOrtHandler
{
public:
    // Engine auto-detected via the base class.
    explicit GlintArcFace(const std::string& _onnx_path,
                          unsigned int _num_threads = 1)
        : BasicOrtHandler(_onnx_path, _num_threads)
    {
    }

    // Explicit engine selection.
    explicit GlintArcFace(const std::string& _onnx_path,
                          EngineType engineType,
                          unsigned int _num_threads = 1)
        : BasicOrtHandler(_onnx_path, engineType, _num_threads)
    {
    }

    ~GlintArcFace() override = default;

    // Compute the face descriptor for one image.
    void detect(const cv::Mat& mat, types::FaceContent& face_content);
    // Batched variant: one FaceContent per input image.
    void detectBatch(const std::vector<cv::Mat>& images,
                     std::vector<types::FaceContent>& face_contents);

private:
    // Input normalization constants: mean 127.5, scale 1/127.5.
    static constexpr float mean_val = 127.5f;
    static constexpr float scale_val = 1.f / 127.5f;

    Ort::Value transform(const cv::Mat& mat) override;
    Ort::Value transformBatch(const std::vector<cv::Mat>& images) override;
};
|
||||
|
||||
// ====================================================================
|
||||
// GlintCosFace — face recognition
|
||||
// ====================================================================
|
||||
// CosFace face-recognition embedder. Mirrors GlintArcFace's interface:
// detect() produces the embedding for one aligned face crop,
// detectBatch() for several crops at once.
class GlintCosFace : public BasicOrtHandler
{
public:
    // Engine auto-detected via the base class.
    explicit GlintCosFace(const std::string& _onnx_path,
                          unsigned int _num_threads = 1)
        : BasicOrtHandler(_onnx_path, _num_threads)
    {
    }

    // Explicit engine selection.
    explicit GlintCosFace(const std::string& _onnx_path,
                          EngineType engineType,
                          unsigned int _num_threads = 1)
        : BasicOrtHandler(_onnx_path, engineType, _num_threads)
    {
    }

    ~GlintCosFace() override = default;

    // Compute the face descriptor for one image.
    void detect(const cv::Mat& mat, types::FaceContent& face_content);
    // Batched variant: one FaceContent per input image.
    void detectBatch(const std::vector<cv::Mat>& images,
                     std::vector<types::FaceContent>& face_contents);

private:
    // Input normalization constants: mean 127.5, scale 1/127.5.
    static constexpr float mean_val = 127.5f;
    static constexpr float scale_val = 1.f / 127.5f;

    Ort::Value transform(const cv::Mat& mat) override;
    Ort::Value transformBatch(const std::vector<cv::Mat>& images) override;
};
|
||||
|
||||
// ====================================================================
|
||||
// MOVINET — action recognition
|
||||
// ====================================================================
|
||||
// MoviNet video action-recognition model. Consumes a window of frames and
// reports the best (class index, confidence) pair via inference().
class MOVINET : public BasicOrtHandler
{
public:
    // Default input geometry (see InputConfig below), auto-detected engine.
    explicit MOVINET(const std::string& _onnx_path,
                     unsigned int _num_threads = 1);

    // Override the input geometry (temporal window, spatial size, channels).
    explicit MOVINET(const std::string& _onnx_path,
                     int _temporal, int _width, int _height, int _channels = 3,
                     unsigned int _num_threads = 1);

    // Explicit engine selection, default geometry.
    explicit MOVINET(const std::string& _onnx_path,
                     EngineType engineType,
                     unsigned int _num_threads = 1);

    // Explicit engine selection and input geometry.
    explicit MOVINET(const std::string& _onnx_path,
                     EngineType engineType,
                     int _temporal, int _width, int _height, int _channels = 3,
                     unsigned int _num_threads = 1);

    ~MOVINET() override = default;

    // Run inference on a window of frames.
    //   out_result [out] (class index, score) of the winning class
    void inference(const std::deque<cv::Mat>& frames,
                   std::pair<int, float>& out_result);

private:
    // Model input geometry; defaults: 16 frames of 172x172 RGB.
    struct InputConfig {
        int temporal = 16;
        int width = 172;
        int height = 172;
        int channels = 3;
    } input_params;

    struct OutputConfig {
        int num_classes = 2;  // number of classes in the output layer
    } output_params;

    // Model I/O tensor names, resolved by init_io_names().
    std::string _MoviNetInputName;
    std::string _MoviNetOutputName;
    std::vector<float> input_tensor_values;  // reusable host input buffer

    void init_io_names();

    // Pack the frame window into the model's input tensor.
    Ort::Value transform(const std::deque<cv::Mat>& frames);
    // Turn the raw output scores into a (class index, score) pair.
    std::pair<int, float> post_processing(const float* pOutput);

    // Required by BasicOrtHandler pure virtuals
    Ort::Value transform(const cv::Mat& mat) override;
    Ort::Value transformBatch(const std::vector<cv::Mat>& images) override;
};
|
||||
|
||||
// ====================================================================
|
||||
// BoundingBoxType template implementations
|
||||
// ====================================================================
|
||||
template<typename T1, typename T2>
|
||||
template<typename O1, typename O2>
|
||||
inline ANSCENTER::types::BoundingBoxType<O1, O2>
|
||||
ANSCENTER::types::BoundingBoxType<T1, T2>::convert_type() const
|
||||
{
|
||||
types::__assert_type<O1, O2>();
|
||||
types::__assert_type<value_type, score_type>();
|
||||
BoundingBoxType<O1, O2> other;
|
||||
other.x1 = static_cast<O1>(x1);
|
||||
other.y1 = static_cast<O1>(y1);
|
||||
other.x2 = static_cast<O1>(x2);
|
||||
other.y2 = static_cast<O1>(y2);
|
||||
other.score = static_cast<O2>(score);
|
||||
other.label_text = label_text;
|
||||
other.label = label;
|
||||
other.flag = flag;
|
||||
return other;
|
||||
}
|
||||
|
||||
} // namespace ANSCENTER
|
||||
|
||||
#endif // ONNXEngine_H
|
||||
834
engines/ONNXEngine/ONNXSAM3.cpp
Normal file
834
engines/ONNXEngine/ONNXSAM3.cpp
Normal file
@@ -0,0 +1,834 @@
|
||||
#include "ONNXSAM3.h"

#include "ONNXEngine.h" // OrtCompatiableGetInputName/OutputName helpers

#include <algorithm>
#include <cmath>
#include <cstring>       // std::memcpy (preprocessImage)
#include <filesystem>
#include <fstream>
#include <iostream>
#include <iterator>      // std::size
#include <memory>        // std::unique_ptr (EP provider-option ownership)
#include <system_error>  // std::error_code
#include <unordered_map>
|
||||
|
||||
namespace ANSCENTER
|
||||
{
|
||||
// ====================================================================
|
||||
// SessionBundle destructor
|
||||
// ====================================================================
|
||||
|
||||
// ====================================================================
// SessionBundle destructor
// ====================================================================

// Release the owned Ort::Session. `delete` on nullptr is a no-op, so no
// guard is needed; the pointer is cleared to leave the bundle in a
// well-defined empty state.
ONNXSAM3::SessionBundle::~SessionBundle()
{
    delete session;
    session = nullptr;
}
|
||||
|
||||
// ====================================================================
|
||||
// EP helpers (same logic as BasicOrtHandler)
|
||||
// ====================================================================
|
||||
|
||||
// Try to attach the CUDA execution provider to `session_options`.
// Returns true on success, false when CUDA is unavailable/misconfigured.
//
// Fixes over the original:
//  * the OrtCUDAProviderOptionsV2 struct is owned by a unique_ptr so it is
//    released even when Update/Append throws (the original leaked it on
//    that path);
//  * the OrtStatus returned by Create/Update is checked via ThrowOnError
//    instead of being silently dropped;
//  * the option count is derived from the array instead of a magic `5`.
bool ONNXSAM3::TryAppendCUDA(Ort::SessionOptions& session_options)
{
    try {
        OrtCUDAProviderOptionsV2* raw_options = nullptr;
        Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&raw_options));
        std::unique_ptr<OrtCUDAProviderOptionsV2,
                        void (*)(OrtCUDAProviderOptionsV2*)>
            cuda_options(raw_options, [](OrtCUDAProviderOptionsV2* p) {
                Ort::GetApi().ReleaseCUDAProviderOptions(p);
            });

        const char* keys[] = {
            "device_id",
            "arena_extend_strategy",
            "cudnn_conv_algo_search",
            "cudnn_conv_use_max_workspace", // reduce cuDNN temp memory
            "do_copy_in_default_stream",    // allow async copies
        };
        const char* values[] = {
            "0",
            "kSameAsRequested",
            "HEURISTIC",
            "0", // 0 = minimal workspace
            "0", // 0 = use separate stream
        };
        static_assert(std::size(keys) == std::size(values),
                      "CUDA option keys/values must stay in sync");
        Ort::ThrowOnError(Ort::GetApi().UpdateCUDAProviderOptions(
            cuda_options.get(), keys, values, std::size(keys)));

        session_options.AppendExecutionProvider_CUDA_V2(*cuda_options);
        std::cout << "[ONNXSAM3] CUDA EP attached." << std::endl;
        return true;
    }
    catch (const Ort::Exception& e) {
        std::cerr << "[ONNXSAM3] CUDA EP failed: " << e.what() << std::endl;
        return false;
    }
}
|
||||
|
||||
// Attempt to attach the DirectML execution provider (device 0) to
// `session_options`. Returns true when ORT accepted the provider.
bool ONNXSAM3::TryAppendDirectML(Ort::SessionOptions& session_options)
{
    const std::unordered_map<std::string, std::string> dml_config = {
        { "device_id", "0" }
    };
    try {
        session_options.AppendExecutionProvider("DML", dml_config);
    }
    catch (const Ort::Exception& e) {
        std::cerr << "[ONNXSAM3] DirectML EP failed: " << e.what() << std::endl;
        return false;
    }
    std::cout << "[ONNXSAM3] DirectML EP attached." << std::endl;
    return true;
}
|
||||
|
||||
// Attempt to attach the OpenVINO execution provider, trying progressively
// more permissive device selections (NPU/GPU → GPU.0 → GPU/CPU) until one
// succeeds. Returns false when none of the configurations is accepted.
bool ONNXSAM3::TryAppendOpenVINO(Ort::SessionOptions& session_options)
{
    using OVConfig = std::unordered_map<std::string, std::string>;
    const std::vector<OVConfig> candidates = {
        {{"device_type","AUTO:NPU,GPU"},{"precision","FP16"},{"num_of_threads","4"},{"num_streams","4"}},
        {{"device_type","GPU.0"},       {"precision","FP16"},{"num_of_threads","4"},{"num_streams","4"}},
        {{"device_type","AUTO:GPU,CPU"},{"precision","FP16"},{"num_of_threads","4"},{"num_streams","4"}},
    };
    for (const OVConfig& candidate : candidates) {
        try {
            session_options.AppendExecutionProvider_OpenVINO_V2(candidate);
            std::cout << "[ONNXSAM3] OpenVINO EP attached ("
                      << candidate.at("device_type") << ")." << std::endl;
            return true;
        }
        catch (const Ort::Exception&) {
            // This device selection is unavailable — try the next one.
        }
    }
    return false;
}
|
||||
|
||||
// ====================================================================
|
||||
// createSessionBundle — create one ORT session with EP + external data
|
||||
// ====================================================================
|
||||
|
||||
// Create one ORT session with the requested EP and external-data handling,
// and cache its input/output names into `bundle`.
//
//   bundle    [out] receives the created session and its I/O name tables
//   onnxPath  path to the .onnx model on disk
//   label     human-readable tag used in log messages
//   forceCPU  when true, ignore m_engineType and run this model on CPU
//   optLevel  graph optimization level for this session
//
// Fixes over the original:
//  * the working-directory change is guarded by RAII — the original only
//    restored the CWD on the non-throwing path, so a rethrown session
//    failure left the process chdir'ed into the model directory;
//  * the model path is widened via std::filesystem instead of a byte-wise
//    std::string → std::wstring copy, which mangled non-ASCII paths.
void ONNXSAM3::createSessionBundle(SessionBundle& bundle,
                                   const std::string& onnxPath,
                                   const std::string& label,
                                   bool forceCPU,
                                   GraphOptimizationLevel optLevel)
{
    std::cout << "[ONNXSAM3] Creating " << label << " session..." << std::endl;

    Ort::SessionOptions opts;
    opts.SetIntraOpNumThreads(m_numThreads);
    opts.SetGraphOptimizationLevel(optLevel);
    opts.SetLogSeverityLevel(4);

    // Determine the effective engine type for this session.
    EngineType engine = forceCPU ? EngineType::CPU : m_engineType;
    if (forceCPU)
        std::cout << "[ONNXSAM3] " << label << ": forced to CPU to save GPU memory." << std::endl;

    std::vector<std::string> available = Ort::GetAvailableProviders();
    auto hasProvider = [&](const std::string& name) {
        return std::find(available.begin(), available.end(), name) != available.end();
    };

    // Attach the EP matching the engine type, if this ORT build offers it.
    bool epAttached = false;
    switch (engine)
    {
    case EngineType::NVIDIA_GPU:
        if (hasProvider("CUDAExecutionProvider"))
            epAttached = TryAppendCUDA(opts);
        break;
    case EngineType::AMD_GPU:
        if (hasProvider("DmlExecutionProvider"))
            epAttached = TryAppendDirectML(opts);
        break;
    case EngineType::OPENVINO_GPU:
        if (hasProvider("OpenVINOExecutionProvider"))
            epAttached = TryAppendOpenVINO(opts);
        break;
    case EngineType::CPU:
    default:
        epAttached = true;
        break;
    }
    if (!epAttached)
        std::cout << "[ONNXSAM3] " << label << ": using CPU EP." << std::endl;

    // -- CWD workaround for external data resolution --
    // Temporarily chdir into the model directory so ORT can resolve
    // external-data filenames. The RAII guard restores the previous CWD on
    // every exit path, including the rethrow below.
    std::filesystem::path modelFsPath(onnxPath);
    std::filesystem::path modelDir = modelFsPath.parent_path();

    struct CwdGuard {
        std::filesystem::path saved = std::filesystem::current_path();
        ~CwdGuard() {
            std::error_code ec;  // never throw from a destructor
            std::filesystem::current_path(saved, ec);
        }
    } cwdGuard;

    if (!modelDir.empty() && std::filesystem::is_directory(modelDir))
        std::filesystem::current_path(modelDir);

    // -- Pre-load external data file if one matches the model stem --
    // The external data filename stored inside the .onnx protobuf may
    // differ from the .onnx filename on disk (e.g. anssam3_image_encoder.onnx
    // internally references sam3_image_encoder.onnx.data). We only
    // pre-load when a stem-based candidate exists on disk. If no match
    // is found, we load the model from its FILE PATH (not memory buffer)
    // so that ORT resolves external data relative to the model directory.
    std::vector<char> extDataBuffer;
    std::filesystem::path extDataPath;
    {
        std::wstring stem = modelFsPath.stem().wstring();
        std::vector<std::filesystem::path> candidates = {
            modelDir / (stem + L".onnx_data"), // monolithic convention
            modelDir / (stem + L".onnx.data"), // split-model convention
        };
        for (auto& c : candidates) {
            if (std::filesystem::exists(c)) {
                extDataPath = c;
                break;
            }
        }
        if (extDataPath.empty()) {
            std::cout << "[ONNXSAM3] " << label
                      << ": no stem-matched external data; "
                      << "ORT will resolve from model directory." << std::endl;
        }
    }

    if (!extDataPath.empty() && std::filesystem::exists(extDataPath)) {
        auto fileSize = std::filesystem::file_size(extDataPath);
        std::cout << "[ONNXSAM3] " << label << ": external data "
                  << extDataPath.filename().string()
                  << " (" << (fileSize / (1024 * 1024)) << " MB)" << std::endl;
        try {
            std::ifstream ifs(extDataPath, std::ios::binary);
            if (ifs) {
                extDataBuffer.resize(static_cast<size_t>(fileSize));
                ifs.read(extDataBuffer.data(), static_cast<std::streamsize>(fileSize));
                ifs.close();

                // Register the staged bytes under the on-disk filename so
                // ORT finds them without touching the file again.
                std::vector<std::basic_string<ORTCHAR_T>> extFileNames = {
                    extDataPath.filename().wstring()
                };
                std::vector<char*> extBuffers = { extDataBuffer.data() };
                std::vector<size_t> extLengths = { extDataBuffer.size() };
                opts.AddExternalInitializersFromFilesInMemory(
                    extFileNames, extBuffers, extLengths);
            }
        }
        catch (const std::bad_alloc&) {
            // Multi-GB weight files may not fit in host RAM; let ORT fall
            // back to its own file-mapping path.
            std::cerr << "[ONNXSAM3] " << label
                      << ": could not allocate memory for external data. "
                      << "ORT will use file mapping." << std::endl;
            extDataBuffer.clear();
            extDataBuffer.shrink_to_fit();
        }
    }

    // -- Load .onnx proto into memory --
    std::vector<char> modelBuffer;
    bool useModelBuffer = false;
    try {
        auto modelFileSize = std::filesystem::file_size(modelFsPath);
        modelBuffer.resize(static_cast<size_t>(modelFileSize));
        std::ifstream mifs(modelFsPath, std::ios::binary);
        if (mifs) {
            mifs.read(modelBuffer.data(), static_cast<std::streamsize>(modelFileSize));
            mifs.close();
            useModelBuffer = true;
        }
    }
    catch (const std::exception& e) {
        std::cerr << "[ONNXSAM3] " << label
                  << ": could not read model file: " << e.what() << std::endl;
    }

    // -- Create session (with GPU → CPU fallback) --
    // Widen through std::filesystem so non-ASCII paths survive intact.
    std::wstring onnxPathW = modelFsPath.wstring();

    auto doCreate = [&](Ort::SessionOptions& sopts, const char* tag) {
        // Use memory-buffer loading when external data has been pre-loaded;
        // otherwise use file-path loading so ORT can resolve external data
        // relative to the model's directory on disk.
        if (useModelBuffer && !extDataBuffer.empty())
            bundle.session = new Ort::Session(*m_env, modelBuffer.data(), modelBuffer.size(), sopts);
        else
            bundle.session = new Ort::Session(*m_env, onnxPathW.c_str(), sopts);
        std::cout << "[ONNXSAM3] " << label << " session created (" << tag << ")." << std::endl;
    };

    try {
        doCreate(opts, "primary EP");
    }
    catch (const Ort::Exception& e) {
        std::cerr << "[ONNXSAM3] " << label << " session FAILED: " << e.what() << std::endl;
        if (engine != EngineType::CPU && epAttached) {
            std::cerr << "[ONNXSAM3] " << label << ": retrying with CPU..." << std::endl;
            Ort::SessionOptions cpuOpts;
            cpuOpts.SetIntraOpNumThreads(m_numThreads);
            cpuOpts.SetGraphOptimizationLevel(optLevel);
            cpuOpts.SetLogSeverityLevel(4);

            // Re-register the staged external data on the fresh options.
            if (!extDataBuffer.empty()) {
                std::vector<std::basic_string<ORTCHAR_T>> extFileNames = {
                    extDataPath.filename().wstring()
                };
                std::vector<char*> extBuffers = { extDataBuffer.data() };
                std::vector<size_t> extLengths = { extDataBuffer.size() };
                cpuOpts.AddExternalInitializersFromFilesInMemory(
                    extFileNames, extBuffers, extLengths);
            }
            doCreate(cpuOpts, "CPU fallback");
        } else {
            throw; // CwdGuard restores the working directory even on rethrow
        }
    }

    // Session creation has consumed the staged bytes; release host memory.
    extDataBuffer.clear(); extDataBuffer.shrink_to_fit();
    modelBuffer.clear();   modelBuffer.shrink_to_fit();

    // -- Read input/output names --
    // outputNames/inputNames hold raw pointers into the owning *_ vectors.
    Ort::Allocator allocator(*bundle.session, *m_memInfo);

    size_t numInputs = bundle.session->GetInputCount();
    bundle.inputNames_.resize(numInputs);
    bundle.inputNames.resize(numInputs);
    for (size_t i = 0; i < numInputs; ++i) {
        bundle.inputNames_[i] = OrtCompatiableGetInputName(i, allocator, bundle.session);
        bundle.inputNames[i] = bundle.inputNames_[i].c_str();
    }

    size_t numOutputs = bundle.session->GetOutputCount();
    bundle.outputNames_.resize(numOutputs);
    bundle.outputNames.resize(numOutputs);
    for (size_t i = 0; i < numOutputs; ++i) {
        bundle.outputNames_[i] = OrtCompatiableGetOutputName(i, allocator, bundle.session);
        bundle.outputNames[i] = bundle.outputNames_[i].c_str();
    }

    // Log I/O names and shapes (shared formatter for inputs and outputs).
    auto logShape = [&](const char* kind, size_t idx,
                        const std::string& ioName,
                        const std::vector<int64_t>& shape) {
        std::cout << "[ONNXSAM3] " << label << " " << kind << "[" << idx << "]: "
                  << ioName << " shape=[";
        for (size_t d = 0; d < shape.size(); ++d) {
            if (d > 0) std::cout << ",";
            std::cout << shape[d];
        }
        std::cout << "]" << std::endl;
    };
    for (size_t i = 0; i < numInputs; ++i) {
        auto info = bundle.session->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo();
        logShape("input", i, bundle.inputNames_[i], info.GetShape());
    }
    for (size_t i = 0; i < numOutputs; ++i) {
        auto info = bundle.session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo();
        logShape("output", i, bundle.outputNames_[i], info.GetShape());
    }
}
|
||||
|
||||
// ====================================================================
|
||||
// Constructor
|
||||
// ====================================================================
|
||||
|
||||
// ====================================================================
// Constructor
// ====================================================================

// Build the three-session SAM3 pipeline (image encoder, language encoder,
// decoder) from the models in `modelFolder`.
//
//   modelFolder folder containing sam3_image_encoder.onnx,
//               sam3_language_encoder.onnx and sam3_decoder.onnx
//   engineType  preferred execution provider for GPU-eligible sessions
//   num_threads intra-op thread count per session
ONNXSAM3::ONNXSAM3(const std::string& modelFolder,
                   EngineType engineType,
                   unsigned int num_threads)
    : m_engineType(engineType),
      m_numThreads(num_threads),
      m_modelFolder(modelFolder)
{
    // Initialize the dynamically-loaded ORT API (ORT_API_MANUAL_INIT
    // build). EPLoader::Current() is called for its side effect; the
    // returned info itself is not needed here (the original bound it to an
    // unused local).
    (void)EPLoader::Current();
    if (Ort::Global<void>::api_ == nullptr)
        Ort::InitApi(static_cast<const OrtApi*>(EPLoader::GetOrtApiRaw()));

    m_env = new Ort::Env(ORT_LOGGING_LEVEL_ERROR, "ONNXSAM3");
    m_memInfo = new Ort::MemoryInfo(
        Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault));

    // Build model paths portably: std::filesystem uses the platform's
    // preferred separator (the original hard-coded "\\").
    const std::filesystem::path root(modelFolder);
    const std::string imgPath  = (root / "sam3_image_encoder.onnx").string();
    const std::string langPath = (root / "sam3_language_encoder.onnx").string();
    const std::string decPath  = (root / "sam3_decoder.onnx").string();

    // Create 3 sessions.
    // Language encoder runs on CPU: it is only called once per prompt
    // change and keeping it on GPU wastes ~1.5 GB of VRAM that the
    // image encoder and decoder need for their activation tensors.
    // Image encoder uses ORT_ENABLE_BASIC: enables constant folding
    // and redundant-node elimination without the complex fusions
    // (MatMulScaleFusion) that are slower on this model and risk OOM
    // on 8 GB GPUs. Benchmarked: BASIC=4.6s vs ALL=5.2s.
    createSessionBundle(m_imageEncoder, imgPath, "ImageEncoder", false,
                        GraphOptimizationLevel::ORT_ENABLE_BASIC);
    createSessionBundle(m_langEncoder, langPath, "LangEncoder", true,
                        GraphOptimizationLevel::ORT_ENABLE_ALL);
    createSessionBundle(m_decoder, decPath, "Decoder");

    std::cout << "[ONNXSAM3] All 3 sessions created successfully." << std::endl;
}
|
||||
|
||||
// ====================================================================
|
||||
// Destructor
|
||||
// ====================================================================
|
||||
|
||||
// ====================================================================
// Destructor
// ====================================================================

ONNXSAM3::~ONNXSAM3()
{
    // Sessions must be torn down BEFORE the Ort::Env that created them.
    // Member destructors only run after this body, so the sessions are
    // released explicitly here, decoder first.
    auto releaseSession = [](Ort::Session*& session) {
        delete session;       // delete on nullptr is a no-op
        session = nullptr;
    };
    releaseSession(m_decoder.session);
    releaseSession(m_langEncoder.session);
    releaseSession(m_imageEncoder.session);

    delete m_memInfo;
    m_memInfo = nullptr;
    delete m_env;
    m_env = nullptr;
}
|
||||
|
||||
// ====================================================================
|
||||
// preprocessImage — BGR → RGB, resize to 1008, HWC→CHW, uint8
|
||||
// ====================================================================
|
||||
|
||||
// ====================================================================
// preprocessImage — BGR → RGB, resize to 1008, HWC→CHW, uint8
// ====================================================================

// Fill `buffer` with the image-encoder input: uint8 CHW data of shape
// [3, m_inputSize, m_inputSize] in RGB channel order.
void ONNXSAM3::preprocessImage(const cv::Mat& mat, std::vector<uint8_t>& buffer)
{
    // Resize to the encoder's square input, then swap BGR → RGB.
    cv::Mat resizedSquare;
    cv::resize(mat, resizedSquare, cv::Size(m_inputSize, m_inputSize));
    cv::Mat rgbImage;
    cv::cvtColor(resizedSquare, rgbImage, cv::COLOR_BGR2RGB);

    const size_t planeSize = static_cast<size_t>(m_inputSize) * m_inputSize;
    buffer.resize(planeSize * 3);

    // HWC → CHW: split into per-channel planes and copy each contiguously
    // (much faster than a per-pixel loop).
    cv::Mat planes[3];
    cv::split(rgbImage, planes);
    for (int ch = 0; ch < 3; ++ch)
        std::memcpy(buffer.data() + static_cast<size_t>(ch) * planeSize,
                    planes[ch].data, planeSize);
}
|
||||
|
||||
// ====================================================================
|
||||
// setPrompt — run language encoder, cache results
|
||||
// ====================================================================
|
||||
|
||||
// ====================================================================
// setPrompt — run language encoder, cache results
// ====================================================================

// Run the language encoder on a tokenized text prompt and cache its
// outputs (attention mask + text memory) for subsequent detect() calls.
//
//   inputIds      token ids of the prompt (fed as the "tokens" input)
//   attentionMask NOTE(review): currently unused — the exported language
//                 encoder appears to take only the tokens input. Kept for
//                 interface compatibility; confirm before removing.
void ONNXSAM3::setPrompt(const std::vector<int64_t>& inputIds,
                         const std::vector<int64_t>& attentionMask)
{
    if (!m_langEncoder.session) {
        std::cerr << "[ONNXSAM3] Language encoder not initialized." << std::endl;
        return;
    }

    // Language encoder input: "tokens" [1, 32] int64
    std::vector<int64_t> tokenShape = { 1, static_cast<int64_t>(inputIds.size()) };
    m_tokenLength = static_cast<int>(inputIds.size());

    // We need a non-const copy for CreateTensor (the tensor borrows this
    // buffer, which must stay alive through Run below).
    std::vector<int64_t> tokenData = inputIds;

    Ort::Value tokenTensor = Ort::Value::CreateTensor<int64_t>(
        *m_memInfo, tokenData.data(), tokenData.size(),
        tokenShape.data(), tokenShape.size());

    std::vector<Ort::Value> inputs;
    inputs.push_back(std::move(tokenTensor));

    // Run language encoder
    std::cout << "[ONNXSAM3] Running language encoder..." << std::endl;
    auto outputs = m_langEncoder.session->Run(
        Ort::RunOptions{nullptr},
        m_langEncoder.inputNames.data(),
        inputs.data(),
        inputs.size(),
        m_langEncoder.outputNames.data(),
        m_langEncoder.outputNames.size());

    // Language encoder outputs (from Python analysis):
    //   output[0]: text_attention_mask [1, 32] bool     → "language_mask" for decoder
    //   output[1]: text_memory [32, 1, 256] float32     → "language_features" for decoder
    //   output[2]: text_embeds [32, 1, 1024] float32    → NOT used by decoder

    // Find outputs by name or fall back to index
    int maskIdx = -1, featIdx = -1;
    for (size_t i = 0; i < m_langEncoder.outputNames_.size(); ++i) {
        const auto& name = m_langEncoder.outputNames_[i];
        if (name.find("attention_mask") != std::string::npos ||
            name.find("text_attention") != std::string::npos) {
            maskIdx = static_cast<int>(i);
        }
        else if (name.find("text_memory") != std::string::npos ||
                 name.find("memory") != std::string::npos) {
            featIdx = static_cast<int>(i);
        }
    }
    // Fallback: first output is mask, second is features
    if (maskIdx < 0) maskIdx = 0;
    if (featIdx < 0) featIdx = 1;

    // Cache language mask (bool), stored element-wise as 0/1.
    {
        auto info = outputs[maskIdx].GetTensorTypeAndShapeInfo();
        m_cachedLangMaskShape = info.GetShape();
        size_t count = info.GetElementCount();
        const bool* data = outputs[maskIdx].GetTensorData<bool>();
        m_cachedLangMask.resize(count);
        for (size_t i = 0; i < count; ++i)
            m_cachedLangMask[i] = data[i] ? 1 : 0;
    }

    // Cache language features (float32)
    {
        auto info = outputs[featIdx].GetTensorTypeAndShapeInfo();
        m_cachedLangFeaturesShape = info.GetShape();
        size_t count = info.GetElementCount();
        const float* data = outputs[featIdx].GetTensorData<float>();
        m_cachedLangFeatures.assign(data, data + count);
    }

    // Mark the prompt as ready so detect() can proceed, then log the
    // cached tensor shapes for diagnostics.
    m_promptSet = true;
    std::cout << "[ONNXSAM3] Language encoder done. Mask shape=[";
    for (size_t i = 0; i < m_cachedLangMaskShape.size(); ++i) {
        if (i > 0) std::cout << ",";
        std::cout << m_cachedLangMaskShape[i];
    }
    std::cout << "] Features shape=[";
    for (size_t i = 0; i < m_cachedLangFeaturesShape.size(); ++i) {
        if (i > 0) std::cout << ",";
        std::cout << m_cachedLangFeaturesShape[i];
    }
    std::cout << "]" << std::endl;
}
|
||||
|
||||
// ====================================================================
|
||||
// detect — image encoder + decoder pipeline
|
||||
// ====================================================================
|
||||
|
||||
// Runs the full two-stage pipeline for one frame:
//   1) image encoder — backbone FPN features + position encodings
//   2) decoder       — boxes / scores / masks, combining the image features
//                      with the language features cached by setPrompt().
//
// @param mat          Input image (BGR, any size; resized in preprocessImage).
// @param segThreshold Score threshold forwarded to postprocessResults.
// @return One SAM3Result per surviving detection; empty on empty input or
//         when no prompt has been set.
//
// NOTE: Ort::Value::CreateTensor does NOT copy its data buffer — every
// buffer handed to it below must stay alive until the corresponding
// session Run() has returned.
std::vector<SAM3Result> ONNXSAM3::detect(const cv::Mat& mat, float segThreshold)
{
    if (mat.empty()) return {};
    if (!m_promptSet) {
        std::cerr << "[ONNXSAM3] No prompt set. Call setPrompt() first." << std::endl;
        return {};
    }

    const int origW = mat.cols;
    const int origH = mat.rows;

    // ---- 1) Image Encoder ----
    std::vector<uint8_t> imgBuffer;
    preprocessImage(mat, imgBuffer);

    std::vector<int64_t> imgShape = { 3, m_inputSize, m_inputSize };
    Ort::Value imgTensor = Ort::Value::CreateTensor<uint8_t>(
        *m_memInfo, imgBuffer.data(), imgBuffer.size(),
        imgShape.data(), imgShape.size());

    std::vector<Ort::Value> imgInputs;
    imgInputs.push_back(std::move(imgTensor));

    auto imgOutputs = m_imageEncoder.session->Run(
        Ort::RunOptions{nullptr},
        m_imageEncoder.inputNames.data(),
        imgInputs.data(),
        imgInputs.size(),
        m_imageEncoder.outputNames.data(),
        m_imageEncoder.outputNames.size());

    // Image encoder outputs (6 total, matched by name):
    //   vision_pos_enc_0/1/2 — only _2 used by decoder
    //   backbone_fpn_0/1/2   — all 3 used by decoder

    // Build a map from image encoder output names to indices.
    std::unordered_map<std::string, int> imgOutputMap;
    for (size_t i = 0; i < m_imageEncoder.outputNames_.size(); ++i)
        imgOutputMap[m_imageEncoder.outputNames_[i]] = static_cast<int>(i);

    // Release unused outputs (vision_pos_enc_0, vision_pos_enc_1) to free
    // GPU memory before running the decoder. These are ~105 MB on CUDA.
    for (size_t i = 0; i < m_imageEncoder.outputNames_.size(); ++i) {
        const auto& name = m_imageEncoder.outputNames_[i];
        if (name == "vision_pos_enc_0" || name == "vision_pos_enc_1") {
            imgOutputs[i] = Ort::Value(nullptr);
        }
    }

    // ---- 2) Build decoder inputs ----
    size_t numDecInputs = m_decoder.inputNames.size();
    std::vector<Ort::Value> decInputs;
    decInputs.reserve(numDecInputs);

    // Prepare scalar and prompt tensors (all referenced, not copied, by the
    // Ort::Values built below — they must outlive the decoder Run()).
    int64_t origHeightVal = static_cast<int64_t>(origH);
    int64_t origWidthVal  = static_cast<int64_t>(origW);
    std::vector<int64_t> scalarShape = {}; // scalar = 0-dim tensor

    float boxCoordsData[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
    std::vector<int64_t> boxCoordsShape = { 1, 1, 4 };

    int64_t boxLabelsData[1] = { -1 }; // no box prompt
    std::vector<int64_t> boxLabelsShape = { 1, 1 };

    bool boxMasksData[1] = { false }; // no box prompt (language-only grounding)
    std::vector<int64_t> boxMasksShape = { 1, 1 };

    // Fallback storage for unknown/missing inputs. Declared OUTSIDE the loop
    // below on purpose: CreateTensor keeps a raw pointer to this storage, so
    // it must stay alive until the decoder Run() returns. (Previously these
    // were declared inside the loop body, leaving the fallback tensors
    // pointing at dead stack memory — use-after-scope.)
    float dummyVal = 0.0f;
    std::vector<int64_t> dummyShape = { 1 };

    // Build input tensors in the order expected by the decoder.
    for (size_t i = 0; i < numDecInputs; ++i) {
        const std::string& name = m_decoder.inputNames_[i];

        if (name == "original_height") {
            decInputs.push_back(Ort::Value::CreateTensor<int64_t>(
                *m_memInfo, &origHeightVal, 1, scalarShape.data(), scalarShape.size()));
        }
        else if (name == "original_width") {
            decInputs.push_back(Ort::Value::CreateTensor<int64_t>(
                *m_memInfo, &origWidthVal, 1, scalarShape.data(), scalarShape.size()));
        }
        else if (name == "backbone_fpn_0" || name == "backbone_fpn_1" ||
                 name == "backbone_fpn_2" || name == "vision_pos_enc_2") {
            // Find matching image encoder output by name and move it in.
            auto it = imgOutputMap.find(name);
            if (it != imgOutputMap.end()) {
                decInputs.push_back(std::move(imgOutputs[it->second]));
            } else {
                std::cerr << "[ONNXSAM3] Image encoder output not found: " << name << std::endl;
                decInputs.push_back(Ort::Value::CreateTensor<float>(
                    *m_memInfo, &dummyVal, 1, dummyShape.data(), dummyShape.size()));
            }
        }
        else if (name == "language_mask") {
            // Cached bool tensor from setPrompt(); stored as char because
            // std::vector<bool> cannot back an ORT tensor.
            decInputs.push_back(Ort::Value::CreateTensor<bool>(
                *m_memInfo,
                reinterpret_cast<bool*>(m_cachedLangMask.data()),
                m_cachedLangMask.size(),
                m_cachedLangMaskShape.data(),
                m_cachedLangMaskShape.size()));
        }
        else if (name == "language_features") {
            decInputs.push_back(Ort::Value::CreateTensor<float>(
                *m_memInfo,
                m_cachedLangFeatures.data(),
                m_cachedLangFeatures.size(),
                m_cachedLangFeaturesShape.data(),
                m_cachedLangFeaturesShape.size()));
        }
        else if (name == "box_coords") {
            decInputs.push_back(Ort::Value::CreateTensor<float>(
                *m_memInfo, boxCoordsData, 4,
                boxCoordsShape.data(), boxCoordsShape.size()));
        }
        else if (name == "box_labels") {
            decInputs.push_back(Ort::Value::CreateTensor<int64_t>(
                *m_memInfo, boxLabelsData, 1,
                boxLabelsShape.data(), boxLabelsShape.size()));
        }
        else if (name == "box_masks") {
            decInputs.push_back(Ort::Value::CreateTensor<bool>(
                *m_memInfo, boxMasksData, 1,
                boxMasksShape.data(), boxMasksShape.size()));
        }
        else {
            std::cerr << "[ONNXSAM3] Unknown decoder input: " << name << std::endl;
            // Feed a dummy 1-element float tensor so Run() still gets an
            // input at this position (storage hoisted above the loop).
            decInputs.push_back(Ort::Value::CreateTensor<float>(
                *m_memInfo, &dummyVal, 1, dummyShape.data(), dummyShape.size()));
        }
    }

    // ---- Debug: print decoder input stats for comparison with ANSSAM3 ----
    {
        std::cout << "[ONNXSAM3] Decoder inputs before Run:" << std::endl;
        for (size_t di = 0; di < numDecInputs; ++di) {
            const std::string& dname = m_decoder.inputNames_[di];
            auto info = decInputs[di].GetTensorTypeAndShapeInfo();
            auto shape = info.GetShape();
            auto elemType = info.GetElementType();
            std::cout << "  " << dname << " type=" << elemType << " shape=[";
            for (size_t d = 0; d < shape.size(); ++d) {
                if (d > 0) std::cout << ",";
                std::cout << shape[d];
            }
            std::cout << "]";
            // Print mean/first5 for float tensors (element-count cap guards
            // against walking an unexpectedly huge buffer).
            if (elemType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT && !shape.empty()) {
                size_t numElems = info.GetElementCount();
                if (numElems > 0 && numElems < 100000000) {
                    const float* data = decInputs[di].GetTensorData<float>();
                    double sum = 0;
                    for (size_t k = 0; k < numElems; ++k) sum += data[k];
                    double mean = sum / numElems;
                    std::cout << " mean=" << mean << " first5:";
                    for (size_t k = 0; k < std::min(numElems, (size_t)5); ++k)
                        std::cout << " " << data[k];
                }
            }
            // Print bool tensor values (for language_mask)
            else if (elemType == ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL && !shape.empty()) {
                size_t numElems = info.GetElementCount();
                const bool* data = decInputs[di].GetTensorData<bool>();
                std::cout << " vals:";
                for (size_t k = 0; k < std::min(numElems, (size_t)32); ++k)
                    std::cout << " " << (int)data[k];
            }
            // Print int64 scalar value
            else if (elemType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 && shape.empty()) {
                const int64_t* data = decInputs[di].GetTensorData<int64_t>();
                std::cout << " value=" << data[0];
            }
            std::cout << std::endl;
        }
    }

    // ---- 3) Run decoder ----
    auto decOutputs = m_decoder.session->Run(
        Ort::RunOptions{nullptr},
        m_decoder.inputNames.data(),
        decInputs.data(),
        decInputs.size(),
        m_decoder.outputNames.data(),
        m_decoder.outputNames.size());

    // Decoder outputs (from Python analysis):
    //   output[0]: boxes  [N, 4]       float32
    //   output[1]: scores [N]          float32
    //   output[2]: masks  [N, 1, H, W] bool

    // Find output indices by name (substring match; falls back to the
    // positional defaults above if a name does not match).
    int boxesIdx = 0, scoresIdx = 1, masksIdx = 2;
    for (size_t i = 0; i < m_decoder.outputNames_.size(); ++i) {
        const auto& name = m_decoder.outputNames_[i];
        if (name.find("box") != std::string::npos) boxesIdx = static_cast<int>(i);
        else if (name.find("score") != std::string::npos) scoresIdx = static_cast<int>(i);
        else if (name.find("mask") != std::string::npos) masksIdx = static_cast<int>(i);
    }

    // Get boxes
    auto boxInfo = decOutputs[boxesIdx].GetTensorTypeAndShapeInfo();
    auto boxShape = boxInfo.GetShape();
    int numBoxes = (boxShape.size() >= 1) ? static_cast<int>(boxShape[0]) : 0;
    const float* boxesData = decOutputs[boxesIdx].GetTensorData<float>();

    // Get scores
    const float* scoresData = decOutputs[scoresIdx].GetTensorData<float>();

    // Get masks — shape [N, 1, H, W]
    auto maskInfo = decOutputs[masksIdx].GetTensorTypeAndShapeInfo();
    auto maskShape = maskInfo.GetShape();
    int maskH = (maskShape.size() >= 3) ? static_cast<int>(maskShape[2]) : 0;
    int maskW = (maskShape.size() >= 4) ? static_cast<int>(maskShape[3]) : 0;
    const bool* masksData = decOutputs[masksIdx].GetTensorData<bool>();

    // Remember the decoder's mask resolution (exposed via getMaskH/W()).
    m_maskH = maskH;
    m_maskW = maskW;

    std::cout << "[ONNXSAM3] Decoder: " << numBoxes << " detections, "
              << "mask=" << maskH << "x" << maskW << std::endl;

    return postprocessResults(boxesData, numBoxes, scoresData,
                              masksData, maskH, maskW,
                              origW, origH, segThreshold);
}
|
||||
|
||||
// ====================================================================
|
||||
// postprocessResults — convert decoder outputs to SAM3Result
|
||||
// ====================================================================
|
||||
|
||||
std::vector<SAM3Result> ONNXSAM3::postprocessResults(
|
||||
const float* boxesData, int numBoxes,
|
||||
const float* scoresData,
|
||||
const bool* masksData, int maskH, int maskW,
|
||||
int origWidth, int origHeight,
|
||||
float scoreThreshold)
|
||||
{
|
||||
std::vector<SAM3Result> results;
|
||||
|
||||
for (int i = 0; i < numBoxes; ++i) {
|
||||
float score = scoresData[i];
|
||||
if (score < scoreThreshold)
|
||||
continue;
|
||||
|
||||
// Box: [x1, y1, x2, y2] in original image coordinates
|
||||
float x1 = boxesData[i * 4 + 0];
|
||||
float y1 = boxesData[i * 4 + 1];
|
||||
float x2 = boxesData[i * 4 + 2];
|
||||
float y2 = boxesData[i * 4 + 3];
|
||||
|
||||
// Clamp to image bounds
|
||||
x1 = std::max(0.0f, std::min(x1, static_cast<float>(origWidth)));
|
||||
y1 = std::max(0.0f, std::min(y1, static_cast<float>(origHeight)));
|
||||
x2 = std::max(0.0f, std::min(x2, static_cast<float>(origWidth)));
|
||||
y2 = std::max(0.0f, std::min(y2, static_cast<float>(origHeight)));
|
||||
|
||||
SAM3Result obj;
|
||||
obj.box = cv::Rect(
|
||||
static_cast<int>(x1), static_cast<int>(y1),
|
||||
static_cast<int>(x2 - x1), static_cast<int>(y2 - y1));
|
||||
obj.confidence = score;
|
||||
|
||||
if (obj.box.width <= 0 || obj.box.height <= 0)
|
||||
continue;
|
||||
|
||||
// Extract this instance's mask: [1, H, W] bool at index i
|
||||
// masksData layout: [N, 1, H, W]
|
||||
const bool* instanceMask = masksData + static_cast<size_t>(i) * 1 * maskH * maskW;
|
||||
|
||||
// Create binary mask at original resolution
|
||||
cv::Mat boolMask(maskH, maskW, CV_8UC1);
|
||||
for (int y = 0; y < maskH; ++y)
|
||||
for (int x = 0; x < maskW; ++x)
|
||||
boolMask.at<uint8_t>(y, x) = instanceMask[y * maskW + x] ? 255 : 0;
|
||||
|
||||
// Resize mask to original image dimensions
|
||||
cv::Mat fullMask;
|
||||
cv::resize(boolMask, fullMask, cv::Size(origWidth, origHeight),
|
||||
0, 0, cv::INTER_LINEAR);
|
||||
cv::threshold(fullMask, fullMask, 127, 255, cv::THRESH_BINARY);
|
||||
|
||||
// Crop mask to bounding box
|
||||
cv::Mat roiMask = fullMask(obj.box).clone();
|
||||
obj.mask = roiMask;
|
||||
|
||||
// Create polygon from contour in the ROI region
|
||||
std::vector<std::vector<cv::Point>> contours;
|
||||
cv::findContours(roiMask.clone(), contours, cv::RETR_EXTERNAL,
|
||||
cv::CHAIN_APPROX_SIMPLE);
|
||||
|
||||
if (!contours.empty()) {
|
||||
// Use the largest contour
|
||||
int largestIdx = 0;
|
||||
double largestArea = 0;
|
||||
for (int c = 0; c < static_cast<int>(contours.size()); ++c) {
|
||||
double area = cv::contourArea(contours[c]);
|
||||
if (area > largestArea) {
|
||||
largestArea = area;
|
||||
largestIdx = c;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<cv::Point> approx;
|
||||
double epsilon = 0.01 * cv::arcLength(contours[largestIdx], true);
|
||||
cv::approxPolyDP(contours[largestIdx], approx, epsilon, true);
|
||||
|
||||
// Convert to absolute coordinates (ROI is relative to box)
|
||||
for (const auto& pt : approx) {
|
||||
obj.polygon.emplace_back(
|
||||
static_cast<float>(pt.x + obj.box.x),
|
||||
static_cast<float>(pt.y + obj.box.y));
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: box corners as polygon
|
||||
if (obj.polygon.size() < 3) {
|
||||
obj.polygon = {
|
||||
cv::Point2f(static_cast<float>(obj.box.x),
|
||||
static_cast<float>(obj.box.y)),
|
||||
cv::Point2f(static_cast<float>(obj.box.x + obj.box.width),
|
||||
static_cast<float>(obj.box.y)),
|
||||
cv::Point2f(static_cast<float>(obj.box.x + obj.box.width),
|
||||
static_cast<float>(obj.box.y + obj.box.height)),
|
||||
cv::Point2f(static_cast<float>(obj.box.x),
|
||||
static_cast<float>(obj.box.y + obj.box.height))
|
||||
};
|
||||
}
|
||||
|
||||
results.push_back(std::move(obj));
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
} // namespace ANSCENTER
|
||||
141
engines/ONNXEngine/ONNXSAM3.h
Normal file
141
engines/ONNXEngine/ONNXSAM3.h
Normal file
@@ -0,0 +1,141 @@
|
||||
#ifndef ONNXSAM3_H
|
||||
#define ONNXSAM3_H
|
||||
#pragma once
|
||||
|
||||
#include "onnxruntime_cxx_api.h"
|
||||
#include "opencv2/opencv.hpp"
|
||||
#include "EPLoader.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
|
||||
#ifndef ONNXENGINE_API
|
||||
#ifdef ENGINE_EXPORTS
|
||||
#define ONNXENGINE_API __declspec(dllexport)
|
||||
#else
|
||||
#define ONNXENGINE_API __declspec(dllimport)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace ANSCENTER
|
||||
{
|
||||
    /// Result from SAM3 segmentation inference — one instance per detected
    /// object. `mask` is cropped to `box`; `polygon` vertices are in
    /// absolute (full-image) pixel coordinates.
    struct SAM3Result
    {
        cv::Rect box;                     // bounding box, original image coordinates
        float confidence = 0.0f;          // detection confidence score
        cv::Mat mask;                     // binary (0/255) mask, cropped to `box`
        std::vector<cv::Point2f> polygon; // simplified contour polygon (>= 3 points; box corners as fallback)
    };
|
||||
|
||||
    /// SAM3 engine using 3 separate ONNX Runtime sessions:
    ///   1) Image encoder    — produces backbone features + position encodings
    ///   2) Language encoder — produces text attention mask + text features
    ///   3) Decoder          — combines image features + language features → boxes, scores, masks
    ///
    /// This architecture avoids the CUDA EP crash that occurs with the
    /// monolithic 3.3 GB model, since each sub-model is under 2 GB.
    ///
    /// Typical usage: construct once, call setPrompt() to cache language
    /// features, then detect() per frame. Non-copyable.
    /// NOTE(review): no synchronization is visible in this declaration —
    /// assume instances are not thread-safe unless verified in the .cpp.
    class ONNXENGINE_API ONNXSAM3
    {
    public:
        /// Construct from a model folder containing:
        ///   anssam3_image_encoder.onnx    (+.onnx_data)
        ///   anssam3_language_encoder.onnx (+.onnx_data)
        ///   anssam3_decoder.onnx          (+.onnx_data)
        /// @param modelFolder Folder holding the three sub-models above.
        /// @param engineType  Execution-provider selection (see EPLoader.h).
        /// @param num_threads ORT thread count (default 1).
        explicit ONNXSAM3(const std::string& modelFolder,
                          EngineType engineType,
                          unsigned int num_threads = 1);

        ~ONNXSAM3();

        // Non-copyable (owns raw ORT session/env pointers).
        ONNXSAM3(const ONNXSAM3&) = delete;
        ONNXSAM3& operator=(const ONNXSAM3&) = delete;

        /// Set text prompt (runs language encoder, caches results).
        /// Must be called before detect(); the cached features are reused
        /// across subsequent detect() calls.
        void setPrompt(const std::vector<int64_t>& inputIds,
                       const std::vector<int64_t>& attentionMask);

        /// Run inference: image encoder + decoder with cached language features.
        /// @param mat Input image (BGR).
        /// @param segThreshold Score threshold for filtering detections.
        /// @return SAM3Result objects with boxes/masks/polygons.
        std::vector<SAM3Result> detect(const cv::Mat& mat,
                                       float segThreshold = 0.5f);

        int getInputSize() const { return m_inputSize; }
        int getTokenLength() const { return m_tokenLength; }
        int getMaskH() const { return m_maskH; }         // 0 until first detect()
        int getMaskW() const { return m_maskW; }         // 0 until first detect()
        bool isPromptSet() const { return m_promptSet; } // true after setPrompt()

    private:
        /// Bundle holding one ORT session and its I/O names.
        /// The `*Names_` vectors own the strings; the `*Names` vectors hold
        /// the matching c_str() pointers required by Ort::Session::Run.
        struct SessionBundle
        {
            Ort::Session* session = nullptr;

            std::vector<std::string> inputNames_;  // owns strings
            std::vector<const char*> inputNames;   // c_str pointers

            std::vector<std::string> outputNames_; // owns strings
            std::vector<const char*> outputNames;  // c_str pointers

            ~SessionBundle();
        };

        /// Create one session bundle (EP attach, external data, GPU→CPU fallback).
        /// @param forceCPU When true, skip GPU EP and always use CPU.
        /// @param optLevel Graph optimization level for this session.
        void createSessionBundle(SessionBundle& bundle,
                                 const std::string& onnxPath,
                                 const std::string& label,
                                 bool forceCPU = false,
                                 GraphOptimizationLevel optLevel = GraphOptimizationLevel::ORT_ENABLE_ALL);

        /// Image preprocessing: BGR → RGB, resize to 1008, HWC→CHW, uint8.
        void preprocessImage(const cv::Mat& mat, std::vector<uint8_t>& buffer);

        /// Convert decoder outputs (boxes, scores, masks) → SAM3Result objects.
        std::vector<SAM3Result> postprocessResults(
            const float* boxesData, int numBoxes,
            const float* scoresData,
            const bool* masksData, int maskH, int maskW,
            int origWidth, int origHeight,
            float scoreThreshold);

        // EP helpers (replicated from BasicOrtHandler).
        bool TryAppendCUDA(Ort::SessionOptions& opts);
        bool TryAppendDirectML(Ort::SessionOptions& opts);
        bool TryAppendOpenVINO(Ort::SessionOptions& opts);

        // ORT environment (shared across all 3 sessions)
        Ort::Env* m_env = nullptr;
        Ort::MemoryInfo* m_memInfo = nullptr;

        // Three session bundles
        SessionBundle m_imageEncoder;
        SessionBundle m_langEncoder;
        SessionBundle m_decoder;

        // Engine configuration
        EngineType m_engineType;
        unsigned int m_numThreads;
        std::string m_modelFolder;

        // Model dimensions
        int m_inputSize = 1008; // image spatial size (H=W)
        int m_tokenLength = 32; // text token sequence length
        int m_maskH = 0;        // output mask height (from decoder output)
        int m_maskW = 0;        // output mask width (from decoder output)

        // Cached language encoder outputs (set by setPrompt).
        // The mask is stored as char because std::vector<bool> is a packed
        // proxy container and cannot back an ORT bool tensor directly.
        std::vector<char> m_cachedLangMask;      // bool tensor data
        std::vector<int64_t> m_cachedLangMaskShape;
        std::vector<float> m_cachedLangFeatures; // float32 tensor data
        std::vector<int64_t> m_cachedLangFeaturesShape;
        bool m_promptSet = false;
    };
|
||||
}
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user