// --- Source-listing metadata (captured from a file viewer; kept as a
// --- comment so this header remains valid C++) ---
//   ANSCORE/engines/TensorRTAPI/include/engine/TRTEngineCache.h
//   178 lines, 7.1 KiB, C++
#pragma once
// TRTEngineCache.h — Process-wide cache for shared TensorRT ICudaEngine instances.
//
// When multiple AI tasks load the same model (same .engine file + GPU), this cache
// ensures only ONE copy of the model weights lives in VRAM. Each task creates its
// own IExecutionContext from the shared ICudaEngine (TRT-supported pattern).
//
// Usage in loadNetwork():
// auto& cache = TRTEngineCache::instance();
// auto hit = cache.tryGet(enginePath, gpuIdx);
// if (hit.engine) {
// m_engine = hit.engine; m_runtime = hit.runtime; // cache hit
// } else {
// // ... deserialize as usual ...
// m_engine = cache.putIfAbsent(enginePath, gpuIdx, runtime, engine);
// }
//
// In ~Engine():
// cache.release(enginePath, gpuIdx);
#include <atomic>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

#include <NvInfer.h>
/// Process-wide flag: set to true during DLL_PROCESS_DETACH when ExitProcess
/// is in progress (lpReserved != NULL). Worker threads are already dead in
/// this state, so thread::join() would deadlock and CUDA/TRT calls are unsafe.
/// Checked by Engine::~Engine to skip cleanup that requires live threads or GPUs.
///
/// @return Reference to the single process-wide flag. Meyers-singleton
///         function-local static: one instance across all translation units
///         (the function is `inline`), thread-safe init since C++11.
inline std::atomic<bool>& g_processExiting() noexcept {
    static std::atomic<bool> s_flag{false};
    return s_flag;
}
/// Process-wide cache of deserialized TensorRT engines, keyed by
/// (engine file path, GPU index). Guarantees at most one ICudaEngine per key
/// is resident in VRAM; callers share it via shared_ptr and create their own
/// IExecutionContext instances from it (TRT-supported pattern).
///
/// Thread-safety: every public method takes m_mutex; the bypass flag is a
/// std::atomic read outside the lock, so it may race benignly with concurrent
/// tryGet/putIfAbsent calls (worst case: one extra engine is deserialized and
/// not cached).
class TRTEngineCache {
public:
    /// Result of tryGet(): both pointers are null on a cache miss.
    struct CacheHit {
        std::shared_ptr<nvinfer1::ICudaEngine> engine;
        std::shared_ptr<nvinfer1::IRuntime> runtime;
    };

    /// Meyers-singleton accessor — thread-safe initialization since C++11.
    static TRTEngineCache& instance() {
        static TRTEngineCache s_instance;
        return s_instance;
    }

    /// Global bypass — when true, tryGet() always returns miss, putIfAbsent()
    /// is a no-op, and buildLoadNetwork/loadNetwork force single-GPU path.
    /// Used by OptimizeModelStr to prevent inner engines (created by
    /// custom DLLs via ANSLIB.dll) from creating pools/caching.
    /// Stored as a member of the singleton to guarantee a single instance
    /// across all translation units (avoids MSVC inline static duplication).
    static std::atomic<bool>& globalBypass() {
        return instance().m_globalBypass;
    }
    // Public because globalBypass() hands out a reference to it and legacy
    // callers may touch it directly; prefer the globalBypass() accessor.
    std::atomic<bool> m_globalBypass{false};

    /// Try to get a cached engine. Returns {nullptr, nullptr} on miss.
    /// On hit, increments the key's refcount (caller must pair with release()).
    [[nodiscard]] CacheHit tryGet(const std::string& engineFilePath, int gpuIndex) {
        if (globalBypass().load(std::memory_order_relaxed)) return {nullptr, nullptr};
        std::lock_guard<std::mutex> lock(m_mutex);
        auto it = m_cache.find({engineFilePath, gpuIndex});
        if (it != m_cache.end()) {
            it->second.refcount++;
            std::cout << "[TRTEngineCache] HIT: " << engineFilePath
                      << " GPU[" << gpuIndex << "] refs=" << it->second.refcount << std::endl;
            return {it->second.engine, it->second.runtime};
        }
        return {nullptr, nullptr};
    }

    /// Store a newly deserialized engine. If another thread already stored the
    /// same key (race), returns the existing one and the caller's copy is
    /// discarded (destroyed after the lock is released, when the by-value
    /// parameters go out of scope). Increments refcount for the returned engine.
    [[nodiscard]] std::shared_ptr<nvinfer1::ICudaEngine> putIfAbsent(
            const std::string& engineFilePath, int gpuIndex,
            std::shared_ptr<nvinfer1::IRuntime> runtime,
            std::shared_ptr<nvinfer1::ICudaEngine> engine) {
        if (globalBypass().load(std::memory_order_relaxed)) return engine; // don't cache
        std::lock_guard<std::mutex> lock(m_mutex);
        CacheKey key{engineFilePath, gpuIndex};
        auto it = m_cache.find(key);
        if (it != m_cache.end()) {
            // Another thread beat us — use theirs, discard ours
            it->second.refcount++;
            std::cout << "[TRTEngineCache] RACE: using existing for " << engineFilePath
                      << " GPU[" << gpuIndex << "] refs=" << it->second.refcount << std::endl;
            return it->second.engine;
        }
        // First to store — insert with refcount 1 (this caller's reference).
        CachedEntry entry;
        entry.engine = std::move(engine);
        entry.runtime = std::move(runtime);
        entry.refcount = 1;
        auto inserted = m_cache.emplace(std::move(key), std::move(entry));
        std::cout << "[TRTEngineCache] STORED: " << engineFilePath
                  << " GPU[" << gpuIndex << "] refs=1" << std::endl;
        return inserted.first->second.engine;
    }

    /// Decrement refcount. When refcount reaches 0, the engine is evicted
    /// immediately to release VRAM and file handles (allows ModelOptimizer to
    /// rebuild .engine files while LabVIEW is running). Unknown keys and
    /// already-zero refcounts are silently ignored (guards against double
    /// release from ~Engine).
    void release(const std::string& engineFilePath, int gpuIndex) {
        std::lock_guard<std::mutex> lock(m_mutex);
        auto it = m_cache.find({engineFilePath, gpuIndex});
        if (it != m_cache.end() && it->second.refcount > 0) {
            it->second.refcount--;
            std::cout << "[TRTEngineCache] RELEASE: " << engineFilePath
                      << " GPU[" << gpuIndex << "] refs=" << it->second.refcount << std::endl;
            if (it->second.refcount <= 0) {
                std::cout << "[TRTEngineCache] EVICT (refcount=0): " << engineFilePath
                          << " GPU[" << gpuIndex << "]" << std::endl;
                m_cache.erase(it);
            }
        }
    }

    /// Remove all entries with refcount == 0 (call at shutdown or when VRAM tight).
    void evictUnused() {
        std::lock_guard<std::mutex> lock(m_mutex);
        for (auto it = m_cache.begin(); it != m_cache.end(); ) {
            if (it->second.refcount <= 0) {
                std::cout << "[TRTEngineCache] EVICT: " << it->first.path
                          << " GPU[" << it->first.gpuIndex << "]" << std::endl;
                it = m_cache.erase(it);   // erase returns the next valid iterator
            } else {
                ++it;
            }
        }
    }

    /// Clear all cached engines immediately (call during DLL_PROCESS_DETACH
    /// BEFORE destroying engine handles, to avoid calling into unloaded TRT DLLs).
    /// NOTE: ignores refcounts — any Engine still holding a shared_ptr keeps its
    /// copy alive, but the cache forgets it.
    void clearAll() {
        std::lock_guard<std::mutex> lock(m_mutex);
        std::cout << "[TRTEngineCache] CLEAR ALL (" << m_cache.size() << " entries)" << std::endl;
        m_cache.clear(); // shared_ptrs released — engines destroyed while TRT is still loaded
    }

    /// Number of cached engines (for diagnostics).
    [[nodiscard]] size_t size() const {
        std::lock_guard<std::mutex> lock(m_mutex);
        return m_cache.size();
    }

private:
    TRTEngineCache() = default;
    TRTEngineCache(const TRTEngineCache&) = delete;
    TRTEngineCache& operator=(const TRTEngineCache&) = delete;

    /// Cache key: absolute .engine path + CUDA device ordinal.
    struct CacheKey {
        std::string path;
        int gpuIndex = 0;
        bool operator==(const CacheKey& o) const {
            return path == o.path && gpuIndex == o.gpuIndex;
        }
    };

    /// Hash for CacheKey. Uses a boost::hash_combine-style mixer: the previous
    /// `hash(path) ^ (hash(int) << 16)` combiner was collision-prone because
    /// std::hash<int> is typically the identity, so the GPU index only
    /// perturbed bits >= 16 and XOR made symmetric collisions trivial.
    struct CacheKeyHash {
        size_t operator()(const CacheKey& k) const {
            size_t h = std::hash<std::string>{}(k.path);
            // 0x9e3779b97f4a7c15 = 2^64 / golden ratio (Fibonacci hashing);
            // truncates harmlessly on 32-bit size_t.
            h ^= std::hash<int>{}(k.gpuIndex) +
                 static_cast<size_t>(0x9e3779b97f4a7c15ULL) +
                 (h << 6) + (h >> 2);
            return h;
        }
    };

    /// One cached model: shared engine + the runtime that owns its
    /// deserialization context, plus the number of live borrowers.
    struct CachedEntry {
        std::shared_ptr<nvinfer1::ICudaEngine> engine;
        std::shared_ptr<nvinfer1::IRuntime> runtime;
        int refcount = 0;
    };

    std::unordered_map<CacheKey, CachedEntry, CacheKeyHash> m_cache;
    mutable std::mutex m_mutex; // mutable so size() can stay const
};