Add unit tests
This commit is contained in:
187
tests/ANSODEngine-UnitTest/yolov10.h
Normal file
187
tests/ANSODEngine-UnitTest/yolov10.h
Normal file
@@ -0,0 +1,187 @@
|
||||
#pragma once
|
||||
#include <openvino/op/transpose.hpp>
|
||||
#include <openvino/core/node.hpp>
|
||||
#include <opencv2/dnn.hpp>
|
||||
#include <opencv2/imgproc.hpp>
|
||||
#include <models/detection_model_ssd.h>
|
||||
#include <random>
|
||||
namespace Yolov10 {
|
||||
struct YoloDetection {
|
||||
short class_id;
|
||||
float confidence;
|
||||
cv::Rect box;
|
||||
};
|
||||
class Inference {
|
||||
public:
|
||||
Inference() {}
|
||||
Inference(const std::string& model_path, const float& model_confidence_threshold);
|
||||
Inference(const std::string& model_path, const cv::Size model_input_shape, const float& model_confidence_threshold);
|
||||
|
||||
std::vector<YoloDetection> RunInference(const cv::Mat& frame);
|
||||
|
||||
private:
|
||||
void InitialModel(const std::string& model_path);
|
||||
void Preprocessing(const cv::Mat& frame);
|
||||
void PostProcessing();
|
||||
cv::Rect GetBoundingBox(const cv::Rect& src) const;
|
||||
std::vector<YoloDetection> detections_;
|
||||
float model_confidence_threshold_;
|
||||
cv::Mat resized_frame_;
|
||||
cv::Point2f factor_;
|
||||
cv::Size2f model_input_shape_;
|
||||
cv::Size model_output_shape_;
|
||||
std::string _modelFilePath;
|
||||
ov::Tensor input_tensor_;
|
||||
ov::InferRequest inference_request_;
|
||||
ov::CompiledModel compiled_model_;
|
||||
};
|
||||
|
||||
void DrawDetectedObject(cv::Mat& frame, const std::vector<YoloDetection>& detections, const std::vector<std::string>& class_names);
|
||||
std::vector<std::string> GetClassNameFromMetadata(const std::string& metadata_path);
|
||||
|
||||
|
||||
Inference::Inference(const std::string& model_path, const float& model_confidence_threshold) {
|
||||
model_input_shape_ = cv::Size(640, 640); // Set the default size for models with dynamic shapes to prevent errors.
|
||||
model_confidence_threshold_ = model_confidence_threshold;
|
||||
_modelFilePath = model_path;
|
||||
InitialModel(model_path);
|
||||
}
|
||||
|
||||
// If the model has dynamic shapes, we need to set the input shape.
|
||||
Inference::Inference(const std::string& model_path, const cv::Size model_input_shape, const float& model_confidence_threshold) {
|
||||
model_input_shape_ = model_input_shape;
|
||||
model_confidence_threshold_ = model_confidence_threshold;
|
||||
_modelFilePath = model_path;
|
||||
|
||||
InitialModel(model_path);
|
||||
}
|
||||
|
||||
|
||||
void Inference::InitialModel(const std::string& model_path) {
|
||||
ov::Core core;
|
||||
std::shared_ptr<ov::Model> model = core.read_model(model_path);
|
||||
if (model->is_dynamic()) {
|
||||
model->reshape({ 1, 3, static_cast<long int>(model_input_shape_.height), static_cast<long int>(model_input_shape_.width) });
|
||||
}
|
||||
ov::preprocess::PrePostProcessor ppp = ov::preprocess::PrePostProcessor(model);
|
||||
|
||||
ppp.input().tensor().set_element_type(ov::element::u8).set_layout("NHWC").set_color_format(ov::preprocess::ColorFormat::BGR);
|
||||
ppp.input().preprocess().convert_element_type(ov::element::f32).convert_color(ov::preprocess::ColorFormat::RGB).scale({ 255, 255, 255 });
|
||||
ppp.input().model().set_layout("NCHW");
|
||||
ppp.output().tensor().set_element_type(ov::element::f32);
|
||||
|
||||
model = ppp.build();
|
||||
//compiled_model_ = core.compile_model(model, "AUTO:GPU,CPU");
|
||||
//core.set_property("AUTO", ov::device::priorities("GPU,CPU"));
|
||||
compiled_model_ = core.compile_model(model, "C");
|
||||
std::vector<std::string> available_devices = core.get_available_devices();
|
||||
auto num_requests = compiled_model_.get_property(ov::optimal_number_of_infer_requests);
|
||||
|
||||
inference_request_ = compiled_model_.create_infer_request();
|
||||
const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
|
||||
const ov::Shape input_shape = inputs[0].get_shape();
|
||||
|
||||
short height = input_shape[1];
|
||||
short width = input_shape[2];
|
||||
model_input_shape_ = cv::Size2f(width, height);
|
||||
|
||||
const std::vector<ov::Output<ov::Node>> outputs = model->outputs();
|
||||
const ov::Shape output_shape = outputs[0].get_shape();
|
||||
|
||||
height = output_shape[1];
|
||||
width = output_shape[2];
|
||||
model_output_shape_ = cv::Size(width, height);
|
||||
}
|
||||
|
||||
std::vector<YoloDetection> Inference::RunInference(const cv::Mat& frame) {
|
||||
Preprocessing(frame);
|
||||
inference_request_.infer();
|
||||
PostProcessing();
|
||||
|
||||
return detections_;
|
||||
}
|
||||
|
||||
void Inference::Preprocessing(const cv::Mat& frame) {
|
||||
cv::resize(frame, resized_frame_, model_input_shape_, 0, 0, cv::INTER_AREA);
|
||||
factor_.x = static_cast<float>(frame.cols / model_input_shape_.width);
|
||||
factor_.y = static_cast<float>(frame.rows / model_input_shape_.height);
|
||||
float* input_data = (float*)resized_frame_.data;
|
||||
input_tensor_ = ov::Tensor(compiled_model_.input().get_element_type(), compiled_model_.input().get_shape(), input_data);
|
||||
inference_request_.set_input_tensor(input_tensor_);
|
||||
}
|
||||
|
||||
void Inference::PostProcessing() {
|
||||
const float* detections = inference_request_.get_output_tensor().data<const float>();
|
||||
detections_.clear();
|
||||
/*
|
||||
* 0 1 2 3 4 5
|
||||
* x, y, w. h, confidence, class_id
|
||||
*/
|
||||
|
||||
for (unsigned int i = 0; i < model_output_shape_.height; ++i) {
|
||||
const unsigned int index = i * model_output_shape_.width;
|
||||
const float& confidence = detections[index + 4];
|
||||
if (confidence > model_confidence_threshold_) {
|
||||
const float& x = detections[index + 0];
|
||||
const float& y = detections[index + 1];
|
||||
const float& w = detections[index + 2];
|
||||
const float& h = detections[index + 3];
|
||||
YoloDetection result;
|
||||
result.class_id = static_cast<const short>(detections[index + 5]);
|
||||
if (result.class_id > 9) result.class_id = 9;
|
||||
result.confidence = confidence;
|
||||
result.box = GetBoundingBox(cv::Rect(x, y, w, h));
|
||||
detections_.push_back(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void DrawDetectedObject(cv::Mat& frame, const std::vector<YoloDetection>& detections, const std::vector<std::string>& class_names) {
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<int> dis(120, 255);
|
||||
|
||||
for (const auto& detection : detections) {
|
||||
const cv::Rect& box = detection.box;
|
||||
const float& confidence = detection.confidence;
|
||||
const int& class_id = detection.class_id;
|
||||
|
||||
const cv::Scalar color = cv::Scalar(dis(gen), dis(gen), dis(gen));
|
||||
cv::rectangle(frame, box, color, 3);
|
||||
|
||||
std::string class_string;
|
||||
|
||||
if (class_names.empty())
|
||||
class_string = "id[" + std::to_string(class_id) + "] " + std::to_string(confidence).substr(0, 4);
|
||||
else
|
||||
class_string = class_names[class_id] + " " + std::to_string(confidence).substr(0, 4);
|
||||
|
||||
const cv::Size text_size = cv::getTextSize(class_string, cv::FONT_HERSHEY_SIMPLEX, 0.6, 2, 0);
|
||||
const cv::Rect text_box(box.x - 2, box.y - 27, text_size.width + 10, text_size.height + 15);
|
||||
|
||||
cv::rectangle(frame, text_box, color, cv::FILLED);
|
||||
cv::putText(frame, class_string, cv::Point(box.x + 5, box.y - 5), cv::FONT_HERSHEY_SIMPLEX, 0.6, cv::Scalar(0, 0, 0), 2, 0);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> GetClassNameFromMetadata(const std::string& metadata_path) {
|
||||
std::vector<std::string> class_names;
|
||||
|
||||
class_names.push_back("person");
|
||||
class_names.push_back("bicycle");
|
||||
class_names.push_back("car");
|
||||
class_names.push_back("motorcycle");
|
||||
class_names.push_back("airplane");
|
||||
class_names.push_back("bus");
|
||||
class_names.push_back("train");
|
||||
class_names.push_back("truck");
|
||||
class_names.push_back("boat");
|
||||
class_names.push_back("traffic light");
|
||||
class_names.push_back("fire hydrant");
|
||||
|
||||
return class_names;
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user