Files
ANSCORE/engines/TensorRTAPI/include/engine/TRTCompat.h

51 lines
2.3 KiB
C

#pragma once
// ============================================================================
// TRTCompat.h -- TensorRT version compatibility macros
//
// Centralises all TRT-version-dependent API differences so that the rest of
// the codebase can be compiled against TRT 8.x or TRT 10.x without scattering
// #if blocks everywhere.
//
// Build 1: CUDA 11.8 + cuDNN 8 + TensorRT 8.6 + OpenCV 4.10 (SM 35-86)
// Build 2: CUDA 13.1 + cuDNN 9 + TensorRT 10 + OpenCV 4.13 (SM 75-121)
// ============================================================================
#include <NvInferVersion.h>
// ---------------------------------------------------------------------------
// Network creation
// ---------------------------------------------------------------------------
// TRT_CREATE_NETWORK(builder)
//   builder : an expression usable with `->` that names an object providing
//             createNetworkV2() (normally an nvinfer1::IBuilder*).
//   result  : whatever builder->createNetworkV2(...) returns.
//
// TRT 10+: every network is explicit-batch; implicit-batch mode no longer
//          exists. NetworkDefinitionCreationFlag::kEXPLICIT_BATCH is still
//          declared but is deprecated and ignored, so passing it would only
//          trigger deprecation warnings -- no creation flags are passed.
// TRT 8.x: explicit-batch mode must be requested by setting the
//          kEXPLICIT_BATCH bit in the creation-flags bitmask.
//
// NOTE: the TRT 8.x expansion names nvinfer1::NetworkDefinitionCreationFlag
// and uint32_t. This header only includes <NvInferVersion.h>, so those names
// must already be declared (e.g. via <NvInfer.h>) wherever the macro is used.
#if NV_TENSORRT_MAJOR >= 10
#define TRT_CREATE_NETWORK(builder) \
    (builder)->createNetworkV2(0)
#else
#define TRT_CREATE_NETWORK(builder)                     \
    (builder)->createNetworkV2(                         \
        1U << static_cast<uint32_t>(                    \
            nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH))
#endif
// ---------------------------------------------------------------------------
// Inference execution
// ---------------------------------------------------------------------------
// TRT_ENQUEUE(exec_ctx, cuda_stream, binding_ptrs) submits one inference on
// `cuda_stream` through `exec_ctx` and evaluates to whatever the underlying
// TensorRT enqueue call returns.
//
//   TRT 8.x : enqueueV2(binding_ptrs.data(), cuda_stream, nullptr).
//             `binding_ptrs` must provide .data(); its element pointer is
//             reinterpret_cast to void** and is expected to be an array of
//             buffer addresses indexed by binding position. The trailing
//             nullptr leaves enqueueV2's optional event argument unset.
//   TRT 10+ : enqueueV3(cuda_stream). Buffers are not passed here; their
//             addresses must already have been bound with setTensorAddress().
//             The third macro argument is discarded WITHOUT being evaluated,
//             so it may name something that only exists in TRT 8.x builds,
//             but it must not carry side effects the caller relies on.
#if NV_TENSORRT_MAJOR < 10
#define TRT_ENQUEUE(exec_ctx, cuda_stream, binding_ptrs)                \
    (exec_ctx)->enqueueV2(                                              \
        reinterpret_cast<void**>((binding_ptrs).data()), (cuda_stream), \
        nullptr)
#else
#define TRT_ENQUEUE(exec_ctx, cuda_stream, binding_ptrs) \
    (exec_ctx)->enqueueV3(cuda_stream)
#endif
// ---------------------------------------------------------------------------
// Feature-detection helpers
// ---------------------------------------------------------------------------
// Each macro expands to a parenthesised comparison on NV_TENSORRT_MAJOR, so it
// can be used both in `#if` directives and in ordinary C++ expressions. The
// two are always each other's negation and use the same TRT 10 cut-off as the
// `#if NV_TENSORRT_MAJOR >= 10` selections in this header.
//
// TRT_HAS_ENQUEUE_V3
//   1 when this codebase runs inference through enqueueV3() with tensor
//   addresses bound via setTensorAddress(), i.e. exactly when TRT_ENQUEUE
//   expands to enqueueV3. NOTE: TensorRT itself has offered enqueueV3 since
//   8.5, so on TRT 8.5/8.6 this is 0 even though the API exists. The value
//   encodes the code path chosen by this header, not raw API availability.
//   Do not change it independently of TRT_ENQUEUE.
//
// TRT_HAS_EXPLICIT_BATCH_FLAG
//   1 when networks must be created with the kEXPLICIT_BATCH creation flag
//   (TRT 8.x), i.e. exactly when TRT_CREATE_NETWORK passes that flag.
#define TRT_HAS_ENQUEUE_V3 (NV_TENSORRT_MAJOR >= 10)
#define TRT_HAS_EXPLICIT_BATCH_FLAG (NV_TENSORRT_MAJOR < 10)