51 lines
2.3 KiB
C
51 lines
2.3 KiB
C
#pragma once
|
|
// ============================================================================
|
|
// TRTCompat.h -- TensorRT version compatibility macros
|
|
//
|
|
// Centralises all TRT-version-dependent API differences so that the rest of
|
|
// the codebase can be compiled against TRT 8.x or TRT 10.x without scattering
|
|
// #if blocks everywhere.
|
|
//
|
|
// Build 1: CUDA 11.8 + cuDNN 8 + TensorRT 8.6 + OpenCV 4.10 (SM 35-86)
|
|
// Build 2: CUDA 13.1 + cuDNN 9 + TensorRT 10 + OpenCV 4.13 (SM 75-121)
|
|
// ============================================================================
|
|
|
|
#include <NvInferVersion.h>
|
|
|
|
// ---------------------------------------------------------------------------
// Network creation
// ---------------------------------------------------------------------------
// TRT_CREATE_NETWORK(builder) expands to `(builder)->createNetworkV2(flags)`
// and evaluates to whatever createNetworkV2 returns. `builder` must support
// `->` (a pointer to the TensorRT builder object). Because this is a macro,
// the nvinfer1 names used in the TRT 8.x expansion are looked up at the call
// site; this header only includes <NvInferVersion.h>, so callers must already
// have the nvinfer1 declarations in scope (e.g. via <NvInfer.h>).
//
// TRT 10+: implicit-batch networks no longer exist, so every network is
//          explicit-batch. NetworkDefinitionCreationFlag::kEXPLICIT_BATCH is
//          still declared but is deprecated and ignored; passing no flags
//          requests the same network without triggering deprecation warnings.
// TRT 8.x: explicit batch is opt-in, so the kEXPLICIT_BATCH bit must be set
//          in the creation-flags bitmask.
#if NV_TENSORRT_MAJOR >= 10
#define TRT_CREATE_NETWORK(builder) \
    (builder)->createNetworkV2(0U)
#else
#define TRT_CREATE_NETWORK(builder)                                   \
    (builder)->createNetworkV2(                                       \
        1U << static_cast<uint32_t>(                                  \
            nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH))
#endif
|
|
|
|
// ---------------------------------------------------------------------------
// Inference execution
// ---------------------------------------------------------------------------
// TRT_ENQUEUE(context, stream, buffers) enqueues inference for `context` on
// `stream` and evaluates to the return value of the underlying enqueue call.
// `context` must support `->`.
#if NV_TENSORRT_MAJOR < 10
// TRT 8.x: enqueueV2(bindings, stream, nullptr). `buffers` must expose
// `.data()` (e.g. a std::vector of device pointers) and is passed as a
// void** array indexed by binding position. No input-consumed event is
// requested (third argument is nullptr).
#define TRT_ENQUEUE(context, stream, buffers)                    \
    (context)->enqueueV2(                                         \
        reinterpret_cast<void**>((buffers).data()),               \
        (stream),                                                 \
        nullptr)
#else
// TRT 10+: enqueueV3(stream). Tensor addresses must already have been bound
// with setTensorAddress(); the `buffers` argument is accepted only for
// source compatibility and does not appear in the expansion, so it is never
// evaluated.
#define TRT_ENQUEUE(context, stream, buffers) \
    (context)->enqueueV3((stream))
#endif
|
|
|
|
// ---------------------------------------------------------------------------
// Feature-detection helpers
// ---------------------------------------------------------------------------
// Both flags work in `#if` directives and in ordinary C++ expressions (where
// they evaluate to bool). They encode the same NV_TENSORRT_MAJOR >= 10 test
// that selects between the TRT_CREATE_NETWORK / TRT_ENQUEUE variants, and
// the second is defined as the complement of the first so they cannot drift.
//
// TRT_HAS_ENQUEUE_V3: non-zero when TRT_ENQUEUE expands to enqueueV3, i.e.
//   when I/O tensors must be bound with setTensorAddress() before enqueuing.
//   NOTE: this tracks the code path this header selects, not raw API
//   availability. enqueueV3 also exists in TRT 8.5/8.6, but on TRT 8.x
//   TRT_ENQUEUE still calls enqueueV2, so this flag is 0 there. Changing its
//   condition without changing TRT_ENQUEUE would desynchronise callers.
// TRT_HAS_EXPLICIT_BATCH_FLAG: non-zero when network creation must pass
//   NetworkDefinitionCreationFlag::kEXPLICIT_BATCH (TRT 8.x). On TRT 10 the
//   enum value is still declared but deprecated and ignored.
#define TRT_HAS_ENQUEUE_V3 (NV_TENSORRT_MAJOR >= 10)
#define TRT_HAS_EXPLICIT_BATCH_FLAG (!TRT_HAS_ENQUEUE_V3)
|