#pragma once #include #include #include #include #include #include namespace Yolov10 { struct YoloDetection { short class_id; float confidence; cv::Rect box; }; class Inference { public: Inference() {} Inference(const std::string& model_path, const float& model_confidence_threshold); Inference(const std::string& model_path, const cv::Size model_input_shape, const float& model_confidence_threshold); std::vector RunInference(const cv::Mat& frame); private: void InitialModel(const std::string& model_path); void Preprocessing(const cv::Mat& frame); void PostProcessing(); cv::Rect GetBoundingBox(const cv::Rect& src) const; std::vector detections_; float model_confidence_threshold_; cv::Mat resized_frame_; cv::Point2f factor_; cv::Size2f model_input_shape_; cv::Size model_output_shape_; std::string _modelFilePath; ov::Tensor input_tensor_; ov::InferRequest inference_request_; ov::CompiledModel compiled_model_; }; void DrawDetectedObject(cv::Mat& frame, const std::vector& detections, const std::vector& class_names); std::vector GetClassNameFromMetadata(const std::string& metadata_path); Inference::Inference(const std::string& model_path, const float& model_confidence_threshold) { model_input_shape_ = cv::Size(640, 640); // Set the default size for models with dynamic shapes to prevent errors. model_confidence_threshold_ = model_confidence_threshold; _modelFilePath = model_path; InitialModel(model_path); } // If the model has dynamic shapes, we need to set the input shape. Inference::Inference(const std::string& model_path, const cv::Size model_input_shape, const float& model_confidence_threshold) { model_input_shape_ = model_input_shape; model_confidence_threshold_ = model_confidence_threshold; _modelFilePath = model_path; InitialModel(model_path); } void Inference::InitialModel(const std::string& model_path) { ov::Core core; std::shared_ptr model = core.read_model(model_path); if (model->is_dynamic()) { model->reshape({ 1, 3, static_cast(model_input_shape_.height), static_cast(model_input_shape_.width) }); } ov::preprocess::PrePostProcessor ppp = ov::preprocess::PrePostProcessor(model); ppp.input().tensor().set_element_type(ov::element::u8).set_layout("NHWC").set_color_format(ov::preprocess::ColorFormat::BGR); ppp.input().preprocess().convert_element_type(ov::element::f32).convert_color(ov::preprocess::ColorFormat::RGB).scale({ 255, 255, 255 }); ppp.input().model().set_layout("NCHW"); ppp.output().tensor().set_element_type(ov::element::f32); model = ppp.build(); //compiled_model_ = core.compile_model(model, "AUTO:GPU,CPU"); //core.set_property("AUTO", ov::device::priorities("GPU,CPU")); compiled_model_ = core.compile_model(model, "C"); std::vector available_devices = core.get_available_devices(); auto num_requests = compiled_model_.get_property(ov::optimal_number_of_infer_requests); inference_request_ = compiled_model_.create_infer_request(); const std::vector> inputs = model->inputs(); const ov::Shape input_shape = inputs[0].get_shape(); short height = input_shape[1]; short width = input_shape[2]; model_input_shape_ = cv::Size2f(width, height); const std::vector> outputs = model->outputs(); const ov::Shape output_shape = outputs[0].get_shape(); height = output_shape[1]; width = output_shape[2]; model_output_shape_ = cv::Size(width, height); } std::vector Inference::RunInference(const cv::Mat& frame) { Preprocessing(frame); inference_request_.infer(); PostProcessing(); return detections_; } void Inference::Preprocessing(const cv::Mat& frame) { cv::resize(frame, resized_frame_, model_input_shape_, 0, 0, cv::INTER_AREA); factor_.x = static_cast(frame.cols / model_input_shape_.width); factor_.y = static_cast(frame.rows / model_input_shape_.height); float* input_data = (float*)resized_frame_.data; input_tensor_ = ov::Tensor(compiled_model_.input().get_element_type(), compiled_model_.input().get_shape(), input_data); inference_request_.set_input_tensor(input_tensor_); } void Inference::PostProcessing() { const float* detections = inference_request_.get_output_tensor().data(); detections_.clear(); /* * 0 1 2 3 4 5 * x, y, w. h, confidence, class_id */ for (unsigned int i = 0; i < model_output_shape_.height; ++i) { const unsigned int index = i * model_output_shape_.width; const float& confidence = detections[index + 4]; if (confidence > model_confidence_threshold_) { const float& x = detections[index + 0]; const float& y = detections[index + 1]; const float& w = detections[index + 2]; const float& h = detections[index + 3]; YoloDetection result; result.class_id = static_cast(detections[index + 5]); if (result.class_id > 9) result.class_id = 9; result.confidence = confidence; result.box = GetBoundingBox(cv::Rect(x, y, w, h)); detections_.push_back(result); } } } void DrawDetectedObject(cv::Mat& frame, const std::vector& detections, const std::vector& class_names) { std::random_device rd; std::mt19937 gen(rd()); std::uniform_int_distribution dis(120, 255); for (const auto& detection : detections) { const cv::Rect& box = detection.box; const float& confidence = detection.confidence; const int& class_id = detection.class_id; const cv::Scalar color = cv::Scalar(dis(gen), dis(gen), dis(gen)); cv::rectangle(frame, box, color, 3); std::string class_string; if (class_names.empty()) class_string = "id[" + std::to_string(class_id) + "] " + std::to_string(confidence).substr(0, 4); else class_string = class_names[class_id] + " " + std::to_string(confidence).substr(0, 4); const cv::Size text_size = cv::getTextSize(class_string, cv::FONT_HERSHEY_SIMPLEX, 0.6, 2, 0); const cv::Rect text_box(box.x - 2, box.y - 27, text_size.width + 10, text_size.height + 15); cv::rectangle(frame, text_box, color, cv::FILLED); cv::putText(frame, class_string, cv::Point(box.x + 5, box.y - 5), cv::FONT_HERSHEY_SIMPLEX, 0.6, cv::Scalar(0, 0, 0), 2, 0); } } std::vector GetClassNameFromMetadata(const std::string& metadata_path) { std::vector class_names; class_names.push_back("person"); class_names.push_back("bicycle"); class_names.push_back("car"); class_names.push_back("motorcycle"); class_names.push_back("airplane"); class_names.push_back("bus"); class_names.push_back("train"); class_names.push_back("truck"); class_names.push_back("boat"); class_names.push_back("traffic light"); class_names.push_back("fire hydrant"); return class_names; } }