124 lines
4.2 KiB
C++
124 lines
4.2 KiB
C++
#include "include/ocr_det.h"
|
|
|
|
namespace PaddleOCR {
|
|
|
|
/// Builds a detector from the model file at `model_path`.
///
/// The model is read through a local ov::Core, reshaped to a dynamic input,
/// compiled for the "CPU" device, and a single infer request is created and
/// kept for later calls.
///
/// @param model_path  Path handed to ov::Core::read_model.
Detector::Detector(std::string model_path)
{
    this->model_path = model_path;

    ov::Core core;
    this->model = core.read_model(this->model_path);

    // Input is reshaped to {1, 3, rows, cols} with both spatial dimensions
    // left dynamic and capped at limit_side_len_. The lower bounds (32 and 1)
    // are kept exactly as they were.
    const ov::Dimension rows_range(32, this->limit_side_len_);
    const ov::Dimension cols_range(1, this->limit_side_len_);
    this->model->reshape({ 1, 3, rows_range, cols_range });

    // Compiled with the device's default properties; no performance hint is set.
    this->compiled_model = core.compile_model(this->model, "CPU");

    // One request is created here and stored in the infer_request member.
    this->infer_request = this->compiled_model.create_infer_request();
}
|
|
void Detector::SetParameters(std::string limit_type,
|
|
std::string det_db_score_mode,
|
|
bool is_scale,
|
|
double det_db_thresh,
|
|
double det_db_box_thresh,
|
|
double det_db_unclip_ratio,
|
|
bool use_dilation)
|
|
{
|
|
std::lock_guard<std::recursive_mutex> lock(_mutex);
|
|
this->limit_type_ = limit_type;
|
|
this->det_db_score_mode_ = det_db_score_mode;
|
|
this->is_scale_ = is_scale;
|
|
this->det_db_thresh_ = det_db_thresh;
|
|
this->det_db_box_thresh_ = det_db_box_thresh;
|
|
this->det_db_unclip_ratio_ = det_db_unclip_ratio;
|
|
this->use_dilation_ = use_dilation;
|
|
}
|
|
void Detector::GetParameters(std::string& limit_type,
|
|
std::string& det_db_score_mode,
|
|
bool& is_scale,
|
|
double& det_db_thresh,
|
|
double& det_db_box_thresh,
|
|
double& det_db_unclip_ratio,
|
|
bool& use_dilation)
|
|
{
|
|
std::lock_guard<std::recursive_mutex> lock(_mutex);
|
|
limit_type = this->limit_type_;
|
|
det_db_score_mode = this->det_db_score_mode_;
|
|
is_scale = this->is_scale_;
|
|
det_db_thresh = this->det_db_thresh_;
|
|
det_db_box_thresh = this->det_db_box_thresh_;
|
|
det_db_unclip_ratio = this->det_db_unclip_ratio_;
|
|
use_dilation = this->use_dilation_;
|
|
}
|
|
/// Runs detection on `src_img` and appends one OCRPredictResult per box to `ocr_results`.
///
/// Steps, all performed while holding `_mutex`:
///   1. resize, normalize and permute the image into a float buffer;
///   2. run a synchronous inference on the stored infer request;
///   3. binarise the output map with det_db_thresh_ (optionally dilating it);
///   4. extract and filter boxes through post_processor_;
///   5. append the boxes to `ocr_results` and call Utility::sorted_boxes on it.
///
/// @param src_img      Image to analyse. Only the cv::Mat header is copied into
///                     the src_img member (cv::Mat assignment).
/// @param ocr_results  Output vector. New results are appended; existing entries
///                     are kept, and the whole vector is passed to sorted_boxes.
///
/// Any std::exception raised along the way is caught and its message written to
/// std::cerr; the function then returns normally. If the model output has fewer
/// than four dimensions, an error is logged and `ocr_results` is left untouched.
void Detector::Run(const cv::Mat& src_img, std::vector<OCRPredictResult>& ocr_results)
{
    std::lock_guard<std::recursive_mutex> lock(_mutex);

    try {
        // -------- pre-processing --------
        this->src_img = src_img;
        this->resize_op_.Run(this->src_img, this->resize_img, this->limit_type_,
                             this->limit_side_len_, this->ratio_h, this->ratio_w);
        this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, this->is_scale_);

        // Sizes are computed in size_t so the element count cannot overflow int.
        const size_t in_rows = static_cast<size_t>(resize_img.rows);
        const size_t in_cols = static_cast<size_t>(resize_img.cols);
        std::vector<float> input(3 * in_rows * in_cols, 0.0f);
        const ov::Shape input_shape = { 1, 3, in_rows, in_cols };
        this->permute_op_.Run(&resize_img, input.data());

        // -------- set input --------
        // NOTE(review): the tensor is built on top of `input`'s storage, so
        // `input` has to stay alive until infer() has returned (it does here).
        // Presumably the model input element type is f32, matching the float
        // buffer -- confirm against the model.
        auto input_port = this->compiled_model.input();
        ov::Tensor input_tensor(input_port.get_element_type(), input_shape, input.data());
        this->infer_request.set_input_tensor(input_tensor);

        // -------- start inference (synchronous) --------
        this->infer_request.infer();

        // -------- read output --------
        auto output = this->infer_request.get_output_tensor(0);
        const float* out_data = output.data<const float>();

        const ov::Shape output_shape = output.get_shape();
        // Dimensions 2 and 3 are read below; refuse anything of lower rank
        // instead of indexing past the end of the shape.
        if (output_shape.size() < 4) {
            std::cerr << "Detector::Run: expected a 4-D output tensor, got rank "
                      << output_shape.size() << std::endl;
            return;
        }
        const size_t map_rows = output_shape[2];
        const size_t map_cols = output_shape[3];
        const size_t map_size = map_rows * map_cols;

        // pred keeps the raw float scores; cbuf holds the same scores scaled
        // to 0..255 and truncated to 8 bits for thresholding.
        std::vector<float> pred(out_data, out_data + map_size);
        std::vector<unsigned char> cbuf(map_size);
        for (size_t i = 0; i < map_size; ++i) {
            cbuf[i] = static_cast<unsigned char>(out_data[i] * 255);
        }

        // Both Mats wrap the vectors above without copying.
        cv::Mat cbuf_map(static_cast<int>(map_rows), static_cast<int>(map_cols),
                         CV_8UC1, cbuf.data());
        cv::Mat pred_map(static_cast<int>(map_rows), static_cast<int>(map_cols),
                         CV_32F, pred.data());

        // -------- binarise --------
        const double threshold = this->det_db_thresh_ * 255;
        const double maxvalue = 255;
        cv::Mat bit_map;
        cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
        if (this->use_dilation_) {
            cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
            cv::dilate(bit_map, bit_map, dila_ele);
        }

        // -------- box extraction --------
        std::vector<std::vector<std::vector<int>>> boxes = post_processor_.BoxesFromBitmap(
            pred_map, bit_map, this->det_db_box_thresh_, this->det_db_unclip_ratio_,
            this->det_db_score_mode_);

        boxes = post_processor_.FilterTagDetRes(boxes, this->ratio_h, this->ratio_w, this->src_img);

        for (const auto& box : boxes) {
            OCRPredictResult res;
            res.box = box;
            ocr_results.push_back(res);
        }

        // Sort boxes from top to bottom, from left to right (per the original
        // comment; the ordering itself is implemented by Utility::sorted_boxes).
        Utility::sorted_boxes(ocr_results);
    }
    catch (const std::exception& e) {
        // Best-effort: failures are reported on stderr and not propagated.
        std::cerr << e.what() << std::endl;
    }
}
|
|
} |