Files
ANSCORE/modules/ANSLPR/include/ONNX_detector.h

68 lines
2.5 KiB
C++

/*
// ONNX_detector.h: interface for the OnnxDetector class.
*/
#if !defined(ONNX_RUNTIME_DETECTOR_H)
#define ONNX_RUNTIME_DETECTOR_H
# pragma once
#include <iostream>
#include <vector>
#include <numeric>
#include <string>
#include <functional>
#include <onnxruntime_c_api.h>
#include <onnxruntime_cxx_api.h>
#ifdef ANSLPR_USE_CUDA
#include <cpu_provider_factory.h>
#endif //ANSLPR_USE_CUDA
#include "utils_alpr_detect.h"
/**
 * @brief Detector that runs a model through an ONNX Runtime session.
 *
 * The class keeps references to an Ort::Env and an Ort::SessionOptions that are
 * created by the caller (see the protected members), so both objects have to
 * stay alive for as long as the detector is in use.
 */
class OnnxDetector {
public:
    /**
     * @brief Constructor taking the model as a raw buffer.
     * @param env               - ONNX Runtime environment, created by the caller
     * @param model_data        - pointer to the model bytes (presumably a serialized ONNX model - TODO confirm)
     * @param model_data_length - size of the buffer pointed to by model_data
     * @param options           - session options, created by the caller
     */
    OnnxDetector(Ort::Env& env, const void* model_data, size_t model_data_length, const Ort::SessionOptions& options);

    /**
     * @brief Constructor taking the model as a path.
     * @param env        - ONNX Runtime environment, created by the caller
     * @param model_path - path of the model file
     * @param options    - session options, created by the caller
     */
    OnnxDetector(Ort::Env& env, const ORTCHAR_T* model_path, const Ort::SessionOptions& options);

    /** @brief Diagnostic dump; what is printed is defined in the implementation file. */
    void dump() const;

    /**
     * @brief Inference module.
     * @param img                   - input image
     * @param conf_threshold        - confidence threshold
     * @param iou_threshold         - IoU threshold for nms
     * @param preserve_aspect_ratio - presumably selects aspect-ratio-preserving resizing of img - TODO confirm
     * @return detection result - bounding box, score, class index
     */
    std::vector<std::vector<Detection>>
    Run(const cv::Mat& img, float conf_threshold, float iou_threshold, bool preserve_aspect_ratio);

    /**
     * @brief Inference module, variant without an explicit confidence threshold.
     * @param img           - input image
     * @param iou_threshold - IoU threshold for nms
     * @return a list of detection results (bounding box, score, class index);
     *         what each list entry stands for is not visible from this header - confirm in the implementation
     *
     * NOTE(review): std::list is used here but <list> is not among the headers this
     * file includes directly; it presumably arrives through another include.
     */
    std::list<std::vector<std::vector<Detection>>>
    Run(const cv::Mat& img, float iou_threshold);

    /**
     * @brief Size limit of the network input.
     * @return the maximum size of input image (ie width or height of dnn input layer)
     */
    int64_t max_image_size() const;

    /**
     * @brief Sanity check on the loaded session.
     * @return true when the session reports at least one input and at least one output
     */
    bool is_valid() const {
        // Outputs are only queried once the session is known to have inputs.
        if (!(session.GetInputCount() > 0)) {
            return false;
        }
        return session.GetOutputCount() > 0;
    }

protected:
    // Session options live outside the class; they are reached through this constant reference.
    const Ort::SessionOptions & sessionOptions;
    // The ONNX Runtime session owned by this detector.
    Ort::Session session;
    // The ONNX environment lives outside the class; it is reached through this constant reference.
    const Ort::Env& env;
};
// Non-maximum suppression algorithm to select boxes.
//   srcRects  - candidate rectangles (read-only input)
//   resRects  - output parameter receiving rectangles (presumably the boxes that were kept - TODO confirm)
//   resIndexs - output parameter receiving integer indices (presumably positions of the kept boxes in srcRects - TODO confirm)
//   thresh    - threshold used by the algorithm (presumably the overlap/IoU limit - TODO confirm against the implementation)
void nms(const std::vector<cv::Rect>& srcRects, std::vector<cv::Rect>& resRects, std::vector<int>& resIndexs, float thresh);
/**
 * @brief Product of all elements of a vector (element-wise multiplication folded
 *        into a single value; this is not a scalar/dot product of two vectors).
 * @tparam T element type; must be convertible from the integer 1 and usable with std::multiplies<T>
 * @param v - values to multiply together
 * @return the product of every element of v, or 1 converted to T when v is empty
 */
template <typename T>
T vectorProduct(const std::vector<T>& v)
{
    // std::accumulate takes its accumulator type from the seed argument. Seeding with a
    // plain int literal would narrow every intermediate product to int (truncating
    // floating-point values and wide integers), so the seed is given type T explicitly.
    return std::accumulate(v.begin(), v.end(), static_cast<T>(1), std::multiplies<T>());
}
#endif // !defined(ONNX_RUNTIME_DETECTOR_H)