246 lines
8.9 KiB
C++
246 lines
8.9 KiB
C++
#include "decode.h"
|
|
#include "stdio.h"
|
|
#include "device_launch_parameters.h"
|
|
|
|
|
|
namespace nvinfer1
|
|
{
|
|
// Default constructor. The plugin has no configurable state.
// mPluginNamespace starts as the empty namespace so that getPluginNamespace()
// never returns an indeterminate pointer if it is queried before
// setPluginNamespace() has been called.
DecodePlugin::DecodePlugin()
    : mPluginNamespace("")
{
}
|
|
|
|
// Destructor. The constructors in this file allocate nothing, so there is
// nothing to release here.
DecodePlugin::~DecodePlugin() {}
|
|
|
|
// Create the plugin at runtime from a serialized byte stream.
// serialize() writes nothing and getSerializationSize() reports 0 bytes, so
// there is no state to restore: `data` and `length` are intentionally unused.
// mPluginNamespace starts as the empty namespace so that getPluginNamespace()
// never returns an indeterminate pointer before setPluginNamespace() is called.
DecodePlugin::DecodePlugin(const void* data, size_t length)
    : mPluginNamespace("")
{
    (void)data;
    (void)length;
}
|
|
|
|
// Persist plugin state into `buffer`. The plugin is stateless, so nothing is
// written (getSerializationSize() reports 0 bytes).
void DecodePlugin::serialize(void* buffer) const TRT_NOEXCEPT
{
    (void)buffer;
}
|
|
|
|
// Number of bytes serialize() writes: none, since the plugin has no state.
size_t DecodePlugin::getSerializationSize() const TRT_NOEXCEPT
{
    return size_t{0};
}
|
|
|
|
// One-time setup hook. No resources are needed, so it reports 0 (success)
// immediately.
int DecodePlugin::initialize() TRT_NOEXCEPT
{
    const int kSuccess = 0;
    return kSuccess;
}
|
|
|
|
// Shape of the single output tensor: (floatsNeeded, 1, 1).
//
// Layout per image:
//   - one float used by CalDetection as the detection counter;
//   - room for 2 decodeplugin::Detection records per grid cell at each of the
//     strides 8, 16 and 32, expressed in floats.
//
// The size depends only on the compile-time decodeplugin::INPUT_H/INPUT_W, so
// `index`, `inputs` and `nbInputDims` are not consulted. The per-level
// expression must stay identical to the one used in forwardGpu(), which uses
// the same value as the per-image stride of the output buffer.
Dims DecodePlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT
{
    int floatsNeeded = 1;  // detection counter slot
    for (int stride = 8; stride <= 32; stride *= 2)
    {
        floatsNeeded += decodeplugin::INPUT_H / stride * decodeplugin::INPUT_W / stride * 2 * sizeof(decodeplugin::Detection) / sizeof(float);
    }
    return Dims3(floatsNeeded, 1, 1);
}
|
|
|
|
// Record the namespace this plugin belongs to. Only the pointer is stored (the
// string is not copied), so the caller must keep it alive while the plugin
// may still report it.
void DecodePlugin::setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT
{
    this->mPluginNamespace = pluginNamespace;
}
|
|
|
|
// Return the namespace pointer last passed to setPluginNamespace().
const char* DecodePlugin::getPluginNamespace() const TRT_NOEXCEPT
{
    return this->mPluginNamespace;
}
|
|
|
|
// Data type of the output at `index`. The decoder always produces FP32,
// independent of the input types, so none of the arguments are consulted.
DataType DecodePlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT
{
    (void)index;
    (void)inputTypes;
    (void)nbInputs;
    return DataType::kFLOAT;
}
|
|
|
|
// Whether output `outputIndex` is broadcast across the batch. Never: each image
// gets its own block of detections.
bool DecodePlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT
{
    (void)outputIndex;
    (void)inputIsBroadcasted;
    (void)nbInputs;
    return false;
}
|
|
|
|
// Whether input `inputIndex` may be broadcast across the batch without being
// replicated. Never: CalDetection reads a separate slice of every input per image.
bool DecodePlugin::canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT
{
    (void)inputIndex;
    return false;
}
|
|
|
|
// Configuration hook. The tensor descriptors are ignored: all sizes the plugin
// needs come from the decodeplugin::INPUT_H/INPUT_W constants.
void DecodePlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT
{
    (void)in;
    (void)nbInput;
    (void)out;
    (void)nbOutput;
}
|
|
|
|
// Attach the plugin to an execution context. The plugin uses neither cuDNN,
// cuBLAS nor the GPU allocator, so the handles are ignored.
void DecodePlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT
{
    (void)cudnnContext;
    (void)cublasContext;
    (void)gpuAllocator;
}
|
|
|
|
// Detach the plugin from its execution context. attachToContext() acquires
// nothing, so there is nothing to undo.
void DecodePlugin::detachFromContext() TRT_NOEXCEPT
{
}
|
|
|
|
// Plugin type string. It must match DecodePluginCreator::getPluginName() so
// the registry can pair the plugin with its creator.
const char* DecodePlugin::getPluginType() const TRT_NOEXCEPT
{
    const char* const kPluginType = "Decode_TRT";
    return kPluginType;
}
|
|
|
|
// Plugin version string. It must match DecodePluginCreator::getPluginVersion().
const char* DecodePlugin::getPluginVersion() const TRT_NOEXCEPT
{
    const char* const kPluginVersion = "1";
    return kPluginVersion;
}
|
|
|
|
// Release this plugin object. The object deletes itself, so the caller must not
// touch it afterwards; it must have been created with `new`.
void DecodePlugin::destroy() TRT_NOEXCEPT
{
    delete this;
}
|
|
|
|
// Produce a new, independently destroyable copy of this plugin. The only state
// carried over is the namespace pointer (shared, not deep-copied).
IPluginV2IOExt* DecodePlugin::clone() const TRT_NOEXCEPT
{
    auto* copy = new DecodePlugin();
    copy->setPluginNamespace(mPluginNamespace);
    return copy;
}
|
|
|
|
// Logistic sigmoid 1 / (1 + e^-data), evaluated entirely in single precision
// (the previous `1.` literals forced a double-precision divide on the device).
__device__ float Logist(float data) { return 1.0f / (1.0f + expf(-data)); }
|
|
|
|
// Decode one pyramid level of detector head outputs into
// decodeplugin::Detection records.
//
// Launch layout: 1-D grid with one thread per (batch item, grid cell) pair.
// Threads whose flat index is >= num_elem return immediately. No shared memory.
//
// Parameters:
//   input       - this level's head output. For each batch item the kernel reads
//                 (4 + 2 + 10) * 2 * total_grid floats, where
//                 total_grid = (INPUT_H / step) * (INPUT_W / step):
//                 box regressions (2 anchors x 4 planes), class logits
//                 (2 anchors x 2 planes), landmark regressions (2 anchors x 10
//                 planes). Each plane is total_grid contiguous values.
//   output      - per batch item, output_elem floats. Element 0 is a running
//                 detection counter (kept as a float; must be zeroed before the
//                 launch), followed by packed Detection records.
//   num_elem    - number of live threads: batch size * total_grid.
//   step        - stride of this level relative to the network input.
//   anchor      - base anchor side in input pixels; anchor k (0 or 1) uses
//                 anchor * (k + 1).
//   output_elem - distance, in floats, between consecutive batch items in output.
//
// Each kept detection stores a box [x1, y1, x2, y2] and five landmark (x, y)
// pairs, all in input-pixel coordinates, plus its confidence.
__global__ void CalDetection(const float* input, float* output, int num_elem, int step, int anchor, int output_elem) {
    int idx = threadIdx.x + blockDim.x * blockIdx.x;
    if (idx >= num_elem) return;

    const int h = decodeplugin::INPUT_H / step;
    const int w = decodeplugin::INPUT_W / step;
    const int total_grid = h * w;
    const int bn_idx = idx / total_grid;
    idx -= bn_idx * total_grid;  // from here on, idx is the cell index within one image
    const int y = idx / w;
    const int x = idx % w;

    const float* cur_input = input + bn_idx * (4 + 2 + 10) * 2 * total_grid;
    const float* bbox_reg = &cur_input[0];
    const float* cls_reg = &cur_input[2 * 4 * total_grid];
    const float* lmk_reg = &cur_input[2 * 4 * total_grid + 2 * 2 * total_grid];

    float* res_count = output + bn_idx * output_elem;

    // Prior centre for this cell, normalised to [0, 1]. Float literals are used
    // throughout so that no arithmetic is promoted to double precision.
    const float prior_cx = ((float)x + 0.5f) / w;
    const float prior_cy = ((float)y + 0.5f) / h;

    for (int k = 0; k < 2; ++k) {
        const float conf1_logit = cls_reg[idx + k * total_grid * 2];
        const float conf2_logit = cls_reg[idx + k * total_grid * 2 + total_grid];
        // Two-class softmax exp(c2) / (exp(c1) + exp(c2)) rewritten as
        // 1 / (1 + exp(c1 - c2)). The result is the same mathematically, but this
        // form cannot overflow to inf / inf = NaN when a logit is large.
        const float conf = 1.0f / (1.0f + expf(conf1_logit - conf2_logit));
        if (conf <= 0.02f) continue;

        // Reserve the next record slot for this image. The counter is a float;
        // it stays exact as long as fewer than 2^24 detections are recorded per
        // image. Capacity is never exceeded: each (cell, anchor) pair writes at
        // most one record, and the buffer holds 2 records per cell.
        const int count = (int)atomicAdd(res_count, 1.0f);
        char* data = (char*)res_count + sizeof(float) + count * sizeof(decodeplugin::Detection);
        decodeplugin::Detection* det = (decodeplugin::Detection*)(data);

        // Prior size for anchor k, normalised by the input size.
        const float prior_w = (float)anchor * (k + 1) / decodeplugin::INPUT_W;
        const float prior_h = (float)anchor * (k + 1) / decodeplugin::INPUT_H;

        // Box decode: the centre offset is scaled by 0.1 * prior size, and
        // width/height are prior size * exp(0.2 * regression).
        const float* box = bbox_reg + idx + k * total_grid * 4;
        const float cx = prior_cx + box[0] * 0.1f * prior_w;
        const float cy = prior_cy + box[total_grid] * 0.1f * prior_h;
        const float bw = prior_w * expf(box[total_grid * 2] * 0.2f);
        const float bh = prior_h * expf(box[total_grid * 3] * 0.2f);
        const float x1 = cx - bw / 2;
        const float y1 = cy - bh / 2;

        // Corners [x1, y1, x2, y2] scaled to input-pixel coordinates.
        det->bbox[0] = x1 * decodeplugin::INPUT_W;
        det->bbox[1] = y1 * decodeplugin::INPUT_H;
        det->bbox[2] = (bw + x1) * decodeplugin::INPUT_W;
        det->bbox[3] = (bh + y1) * decodeplugin::INPUT_H;
        det->class_confidence = conf;

        // Five landmark (x, y) pairs, offset from the prior centre.
        const float* lmk = lmk_reg + idx + k * total_grid * 10;
        for (int i = 0; i < 10; i += 2) {
            det->landmark[i] = (prior_cx + lmk[total_grid * i] * 0.1f * prior_w) * decodeplugin::INPUT_W;
            det->landmark[i + 1] = (prior_cy + lmk[total_grid * (i + 1)] * 0.1f * prior_h) * decodeplugin::INPUT_H;
        }
    }
}
|
|
|
|
// Enqueue the decode work for all three pyramid levels (strides 8, 16, 32).
//
// inputs    - three device pointers, one per level, in increasing stride order.
// output    - device buffer of batchSize blocks of totalCount floats each. The
//             first float of each block is the detection counter; it is zeroed
//             here before the kernels run.
// stream    - all memsets and kernel launches are enqueued on this stream;
//             nothing is synchronized here.
// batchSize - number of images. Non-positive values enqueue no work (a
//             zero-sized launch would otherwise divide by zero below).
void DecodePlugin::forwardGpu(const float* const* inputs, float* output, cudaStream_t stream, int batchSize)
{
    if (batchSize <= 0) return;

    // Per-image output stride in floats. This must match getOutputDimensions():
    // 1 counter float plus room for 2 Detection records per cell per level.
    int totalCount = 1;
    totalCount += decodeplugin::INPUT_H / 8 * decodeplugin::INPUT_W / 8 * 2 * sizeof(decodeplugin::Detection) / sizeof(float);
    totalCount += decodeplugin::INPUT_H / 16 * decodeplugin::INPUT_W / 16 * 2 * sizeof(decodeplugin::Detection) / sizeof(float);
    totalCount += decodeplugin::INPUT_H / 32 * decodeplugin::INPUT_W / 32 * 2 * sizeof(decodeplugin::Detection) / sizeof(float);

    // Reset each image's detection counter.
    for (int b = 0; b < batchSize; ++b) {
        cudaMemsetAsync(output + b * totalCount, 0, sizeof(float), stream);
    }

    int base_step = 8;
    int base_anchor = 16;
    for (int i = 0; i < 3; ++i)
    {
        // Thread count must equal batchSize * (H/step) * (W/step), exactly the
        // grid CalDetection indexes. The unparenthesised form
        // batch*H/step*W/step can be larger when H or W is not a multiple of
        // step. The surplus threads would then map to a batch index
        // >= batchSize and read/write out of bounds.
        const int num_elem = batchSize * (decodeplugin::INPUT_H / base_step) * (decodeplugin::INPUT_W / base_step);
        if (num_elem > 0) {
            const int thread_count = (num_elem < thread_count_) ? num_elem : thread_count_;
            const int block_count = (num_elem + thread_count - 1) / thread_count;
            CalDetection<<<block_count, thread_count, 0, stream>>>(
                inputs[i], output, num_elem, base_step, base_anchor, totalCount);
        }

        base_step *= 2;
        base_anchor *= 4;
    }
}
|
|
|
|
// Run the decoder for one batch on `stream`. The single output buffer is
// outputs[0]; the workspace is unused.
//
// Returns 0 on success and -1 if the CUDA runtime reports an error after the
// memsets and kernel launches were issued. Launch failures are not returned by
// the launch itself, so they are queried explicitly here. This does not
// synchronize the stream, so faults that occur later during kernel execution
// surface at the caller's next synchronization point.
// NOTE(review): cudaGetLastError() also reports any earlier unchecked error on
// this host thread.
int DecodePlugin::enqueue(int batchSize, const void* const* inputs, void* TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT
{
    (void)workspace;
    forwardGpu((const float* const*)inputs, (float*)outputs[0], stream, batchSize);
    return (cudaGetLastError() == cudaSuccess) ? 0 : -1;
}
|
|
|
|
// Storage for the creator's static field description. The plugin takes no
// creation-time attributes, so the attribute list is empty (the constructor
// points mFC at it).
PluginFieldCollection DecodePluginCreator::mFC{};
std::vector<PluginField> DecodePluginCreator::mPluginAttributes{};
|
|
|
|
// Creator constructor: expose an empty attribute list through mFC.
DecodePluginCreator::DecodePluginCreator()
{
    mPluginAttributes.clear();
    mFC.fields = mPluginAttributes.data();
    mFC.nbFields = mPluginAttributes.size();
}
|
|
|
|
// Name under which this creator is registered. It must match
// DecodePlugin::getPluginType().
const char* DecodePluginCreator::getPluginName() const TRT_NOEXCEPT
{
    const char* const kPluginName = "Decode_TRT";
    return kPluginName;
}
|
|
|
|
// Version this creator produces. It must match DecodePlugin::getPluginVersion().
const char* DecodePluginCreator::getPluginVersion() const TRT_NOEXCEPT
{
    const char* const kPluginVersion = "1";
    return kPluginVersion;
}
|
|
|
|
// Describe the creation-time fields the plugin accepts (none). The returned
// pointer refers to static storage shared by all creators.
const PluginFieldCollection* DecodePluginCreator::getFieldNames() TRT_NOEXCEPT
{
    const PluginFieldCollection* fields = &mFC;
    return fields;
}
|
|
|
|
// Build a fresh DecodePlugin. `name` and `fc` are ignored because the plugin
// has no configurable fields.
// NOTE(review): the plugin stores mNamespace.c_str() by pointer (see
// DecodePlugin::setPluginNamespace), so it stays valid only while this
// creator's mNamespace is alive and unchanged.
IPluginV2IOExt* DecodePluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT
{
    auto* plugin = new DecodePlugin();
    plugin->setPluginNamespace(mNamespace.c_str());
    return plugin;
}
|
|
|
|
// Recreate a DecodePlugin from serialized data (which carries no state). The
// returned object is expected to be released through DecodePlugin::destroy(),
// which deletes it.
IPluginV2IOExt* DecodePluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT
{
    auto* plugin = new DecodePlugin(serialData, serialLength);
    plugin->setPluginNamespace(mNamespace.c_str());
    return plugin;
}
|
|
|
|
} |