68 lines
2.5 KiB
C++
68 lines
2.5 KiB
C++
/*
 * OnnxDetector: interface for the ONNX Runtime based object detector class.
 */
|
|
#if !defined(ONNX_RUNTIME_DETECTOR_H)
|
|
#define ONNX_RUNTIME_DETECTOR_H
|
|
# pragma once
|
|
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <list>
#include <numeric>
#include <string>
#include <vector>

#include <onnxruntime_c_api.h>
#include <onnxruntime_cxx_api.h>

#ifdef ANSLPR_USE_CUDA
// NOTE(review): this is guarded by ANSLPR_USE_CUDA but pulls in the CPU
// provider factory header; confirm whether cuda_provider_factory.h was intended.
#include <cpu_provider_factory.h>
#endif //ANSLPR_USE_CUDA

#include "utils_alpr_detect.h"
|
|
class OnnxDetector {
|
|
public:
|
|
/***
|
|
* @brief constructor
|
|
* @param model_path - path of the TorchScript weight file
|
|
*/
|
|
OnnxDetector(Ort::Env& env, const void* model_data, size_t model_data_length, const Ort::SessionOptions& options);
|
|
OnnxDetector(Ort::Env& env, const ORTCHAR_T* model_path, const Ort::SessionOptions& options);
|
|
void dump() const;
|
|
/***
|
|
* @brief inference module
|
|
* @param img - input image
|
|
* @param conf_threshold - confidence threshold
|
|
* @param iou_threshold - IoU threshold for nms
|
|
* @return detection result - bounding box, score, class index
|
|
*/
|
|
std::vector<std::vector<Detection>>
|
|
Run(const cv::Mat& img, float conf_threshold, float iou_threshold, bool preserve_aspect_ratio);
|
|
/***
|
|
* @brief inference module
|
|
* @param img - input image
|
|
* @param conf_threshold - confidence threshold
|
|
* @param iou_threshold - IoU threshold for nms
|
|
* @return detection result - bounding box, score, class index
|
|
*/
|
|
std::list<std::vector<std::vector<Detection>>>
|
|
Run(const cv::Mat& img, float iou_threshold);
|
|
/***
|
|
* @brief
|
|
* @return the maximum size of input image (ie width or height of dnn input layer)
|
|
*/
|
|
int64_t max_image_size() const;
|
|
bool is_valid() const {
|
|
return (session.GetInputCount() > 0 && session.GetOutputCount() > 0);
|
|
}
|
|
protected:
|
|
//session options are created outside the class. The classifier access to its options through a constant reference
|
|
const Ort::SessionOptions & sessionOptions;
|
|
Ort::Session session;
|
|
//ONNX environment are created outside the class. The classifier access to its envirponment through a constant reference
|
|
const Ort::Env& env;
|
|
};
|
|
//non max suppession algorithm to select boxes
|
|
void nms(const std::vector<cv::Rect>& srcRects, std::vector<cv::Rect>& resRects, std::vector<int>& resIndexs, float thresh);
|
|
/// Product of all elements of `v`.
///
/// The fold is seeded with `T{1}`-equivalent `static_cast<T>(1)` so the running
/// product is accumulated in `T`. Seeding with the int literal `1` made
/// std::accumulate accumulate in `int`, truncating floating-point factors at
/// every step and truncating int64_t products that exceed the int range.
///
/// @param v  factors to multiply
/// @return   product of the elements, or 1 when `v` is empty
template <typename T>
T vectorProduct(const std::vector<T>& v)
{
    return std::accumulate(v.begin(), v.end(), static_cast<T>(1), std::multiplies<T>());
}
|
|
#endif // !defined(ONNX_RUNTIME_DETECTOR_H)
|