Fix NV12 crash when recreating the camera object
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
#include <cstring>
|
||||
#include <filesystem>
|
||||
#include <semaphore>
|
||||
#include "TRTCompat.h"
|
||||
|
||||
// Per-device mutex for CUDA graph capture.
|
||||
@@ -15,6 +16,95 @@ static std::mutex& graphCaptureMutex() {
|
||||
return m;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// GPU INFERENCE THROTTLE
|
||||
// ============================================================================
|
||||
// Global counting semaphore that limits how many Engine instances can execute
|
||||
// CUDA inference simultaneously. Without this, N separate Engine instances
|
||||
// (one per camera) all submit GPU work at once, causing:
|
||||
// 1. SM 100% saturation → each inference takes 5-10x longer
|
||||
// 2. GPU thermal throttling at 85°C → further slowdown
|
||||
// 3. cudaStreamSynchronize blocking indefinitely → system freeze
|
||||
//
|
||||
// Auto-computed from GPU VRAM:
|
||||
// ≤ 4 GB → 2 concurrent 8 GB → 4 concurrent
|
||||
// 6 GB → 3 concurrent 12+ GB → 6 concurrent
|
||||
// Multi-GPU: sum across all GPUs
|
||||
//
|
||||
// Excess threads wait on CPU (nearly zero cost) while the bounded set
|
||||
// runs efficiently on the GPU without thermal throttling.
|
||||
// Returns the process-wide semaphore that throttles concurrent GPU inference.
//
// The slot count is computed once, on first use, from the total VRAM of all
// visible CUDA devices (~1 slot per 2 GB, clamped to [2, 6] per GPU and
// [2, 64] overall). Falls back to 4 slots if device enumeration fails.
//
// Thread-safe: C++11 static-local initialization guarantees the sizing
// lambda runs exactly once even under concurrent first calls.
static std::counting_semaphore<64>& gpuInferenceSemaphore() {
    static int maxConcurrent = []() {
        int gpuCount = 0;
        // Check the error code explicitly: on a driverless machine the call
        // fails and gpuCount is unreliable.
        if (cudaGetDeviceCount(&gpuCount) != cudaSuccess || gpuCount <= 0)
            return 4; // fallback when no CUDA device is visible

        // Remember the caller's current device so probing each GPU below does
        // not leave a different device selected as a hidden side effect.
        int previousDevice = -1;
        const bool restoreDevice = (cudaGetDevice(&previousDevice) == cudaSuccess);

        int totalSlots = 0;
        for (int i = 0; i < gpuCount; ++i) {
            size_t freeMem = 0, totalMem = 0;
            if (cudaSetDevice(i) != cudaSuccess ||
                cudaMemGetInfo(&freeMem, &totalMem) != cudaSuccess) {
                // Device exists but cannot be queried: grant the minimum
                // rather than silently counting it as a 0 GB GPU.
                totalSlots += 2;
                continue;
            }
            int gbTotal = static_cast<int>(totalMem / (1024ULL * 1024ULL * 1024ULL));

            // Scale concurrency with VRAM: ~1 slot per 2 GB, min 2, max 6 per GPU
            totalSlots += std::clamp(gbTotal / 2, 2, 6);
        }

        if (restoreDevice)
            cudaSetDevice(previousDevice); // undo the probe loop's device switch

        totalSlots = std::clamp(totalSlots, 2, 64);
        std::cout << "Info [GPU Throttle]: max concurrent inferences = "
                  << totalSlots << " (across " << gpuCount << " GPU(s))" << std::endl;
        return totalSlots;
    }();
    static std::counting_semaphore<64> sem(maxConcurrent);
    return sem;
}
|
||||
|
||||
// RAII guard for the GPU inference semaphore
|
||||
struct GpuInferenceGuard {
|
||||
GpuInferenceGuard() { gpuInferenceSemaphore().acquire(); }
|
||||
~GpuInferenceGuard() { gpuInferenceSemaphore().release(); }
|
||||
GpuInferenceGuard(const GpuInferenceGuard&) = delete;
|
||||
GpuInferenceGuard& operator=(const GpuInferenceGuard&) = delete;
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// WDDM-SAFE STREAM SYNCHRONIZATION
|
||||
// ============================================================================
|
||||
// Under Windows WDDM, cudaStreamSynchronize calls cuStreamQuery in a tight
|
||||
// loop with SwitchToThread, holding nvcuda64's internal SRW lock the entire
|
||||
// time. When the GPU is busy with inference, this spin blocks ALL other CUDA
|
||||
// operations — including HW video decode (nvcuvid), cuMemAlloc, cuArrayDestroy.
|
||||
// If a camera Reconnect or decode buffer allocation needs an exclusive SRW lock
|
||||
// while inference is spinning, the entire system deadlocks.
|
||||
//
|
||||
// This function replaces cudaStreamSynchronize with a polling loop that
|
||||
// explicitly releases the SRW lock between queries by sleeping briefly.
|
||||
// This allows other CUDA operations to interleave with the sync wait.
|
||||
// Drop-in replacement for cudaStreamSynchronize that polls the stream and
// yields between queries instead of spinning inside the driver, so other
// CUDA operations (decode, alloc/free) get a chance to run while we wait.
// Returns whatever cudaStreamQuery reports once the stream is no longer
// busy (cudaSuccess, or the stream's error), matching the return contract
// of cudaStreamSynchronize.
static inline cudaError_t cudaStreamSynchronize_Safe(cudaStream_t stream) {
    // Immediate check: quick kernels complete here with zero sleep overhead.
    cudaError_t status = cudaStreamQuery(stream);
    if (status != cudaErrorNotReady) return status;

    // First ~10 polls only yield the timeslice (Sleep(0)) to catch sub-ms
    // completions; after that, sleep a full millisecond between polls so
    // threads doing CUDA cleanup get a window to take the driver's internal
    // lock exclusively. (A long Sleep(0) spin was observed to starve them.)
    int zeroSleepPollsLeft = 10;
    for (;;) {
        if (zeroSleepPollsLeft > 0) {
            --zeroSleepPollsLeft;
            Sleep(0);
        } else {
            Sleep(1);
        }
        status = cudaStreamQuery(stream);
        if (status != cudaErrorNotReady) return status;
    }
}
|
||||
|
||||
template <typename T>
|
||||
void Engine<T>::warmUp(int iterations) {
|
||||
if (m_verbose) {
|
||||
@@ -163,6 +253,16 @@ bool Engine<T>::runInference(const std::vector<std::vector<cv::cuda::GpuMat>>& i
|
||||
return runInferenceFromPool(inputs, featureVectors);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// GPU INFERENCE THROTTLE
|
||||
// ============================================================================
|
||||
// Limit how many Engine instances can run CUDA inference simultaneously.
|
||||
// Without this, 12 cameras each with their own Engine all submit GPU work
|
||||
// at once → SM 100% → thermal throttle → cudaStreamSynchronize hangs.
|
||||
// The semaphore lets excess threads wait on CPU (nearly zero cost) while
|
||||
// a bounded number use the GPU efficiently.
|
||||
GpuInferenceGuard gpuThrottle;
|
||||
|
||||
// ============================================================================
|
||||
// SINGLE-ENGINE SERIALISATION
|
||||
// ============================================================================
|
||||
@@ -376,7 +476,7 @@ bool Engine<T>::runInference(const std::vector<std::vector<cv::cuda::GpuMat>>& i
|
||||
std::cout << "Error: Failed to set optimization profile 0" << std::endl;
|
||||
return false;
|
||||
}
|
||||
cudaError_t syncErr = cudaStreamSynchronize(m_inferenceStream);
|
||||
cudaError_t syncErr = cudaStreamSynchronize_Safe(m_inferenceStream);
|
||||
if (syncErr != cudaSuccess) {
|
||||
std::cout << "Error: Failed to sync after profile change: "
|
||||
<< cudaGetErrorString(syncErr) << std::endl;
|
||||
@@ -642,7 +742,7 @@ bool Engine<T>::runInference(const std::vector<std::vector<cv::cuda::GpuMat>>& i
|
||||
if (graphExec) {
|
||||
// Launch the pre-captured graph (single API call replaces many).
|
||||
cudaGraphLaunch(graphExec, m_inferenceStream);
|
||||
cudaStreamSynchronize(m_inferenceStream);
|
||||
cudaStreamSynchronize_Safe(m_inferenceStream);
|
||||
|
||||
// CPU memcpy: pinned buffers -> featureVectors (interleaved by batch).
|
||||
for (int batch = 0; batch < batchSize; ++batch) {
|
||||
@@ -705,7 +805,7 @@ bool Engine<T>::runInference(const std::vector<std::vector<cv::cuda::GpuMat>>& i
|
||||
}
|
||||
}
|
||||
|
||||
cudaError_t syncErr = cudaStreamSynchronize(m_inferenceStream);
|
||||
cudaError_t syncErr = cudaStreamSynchronize_Safe(m_inferenceStream);
|
||||
if (syncErr != cudaSuccess) {
|
||||
std::string errMsg = "[Engine] runInference FAIL: cudaStreamSynchronize: "
|
||||
+ std::string(cudaGetErrorString(syncErr));
|
||||
|
||||
Reference in New Issue
Block a user