Files
ANSCORE/modules/ANSOCR/ANSONNXOCR/ONNXOCRClassifier.cpp

148 lines
5.5 KiB
C++
Raw Normal View History

2026-03-28 16:54:11 +11:00
#include "ONNXOCRClassifier.h"
#include <opencv2/imgproc.hpp>
#include <algorithm>
#include <array>
#include <cmath>
#include <exception>
#include <iostream>
2026-04-14 20:30:21 +10:00
#include <chrono>
2026-03-28 16:54:11 +11:00
namespace ANSCENTER {
namespace onnxocr {
/// Builds the classifier from the ONNX model at onnx_path.
/// All setup is delegated to BasicOrtHandler.
///
/// @param onnx_path   Filesystem path of the classifier .onnx model.
/// @param num_threads Thread count, forwarded unchanged to BasicOrtHandler.
ONNXOCRClassifier::ONNXOCRClassifier(const std::string& onnx_path, unsigned int num_threads)
    : BasicOrtHandler(onnx_path, num_threads)
{
    // Intentionally empty: no classifier-specific state to initialize here.
}
2026-04-14 20:30:21 +10:00
/// Builds the classifier from the ONNX model at onnx_path, using caller-supplied
/// handler options. All setup is delegated to BasicOrtHandler.
///
/// @param onnx_path   Filesystem path of the classifier .onnx model.
/// @param options     OrtHandlerOptions forwarded unchanged to BasicOrtHandler.
/// @param num_threads Thread count, forwarded unchanged to BasicOrtHandler.
ONNXOCRClassifier::ONNXOCRClassifier(const std::string& onnx_path,
                                     const OrtHandlerOptions& options,
                                     unsigned int num_threads)
    : BasicOrtHandler(onnx_path, options, num_threads)
{
    // Intentionally empty: no classifier-specific state to initialize here.
}
2026-03-28 16:54:11 +11:00
/// Preprocesses one text-line crop into the classifier's input tensor.
///
/// The crop is resized directly to kClsImageW x kClsImageH. There is no
/// aspect-ratio preservation, which matches PaddleOCR's official ResizeImage
/// for PP-LCNet_x1_0_textline_ori. The result is converted to float and passed
/// through NormalizeAndPermute (ImageNet normalization, same as detection).
///
/// @param mat Input crop. 1- and 4-channel images are first converted to
///            3 channels (assumes BGR/BGRA channel order; TODO confirm upstream).
/// @return A float tensor of shape [1, 3, kClsImageH, kClsImageW], or a null
///         Ort::Value when mat is empty (same convention as transformBatch()).
/// @note The tensor borrows input_values_handler's storage rather than copying
///       it. It is only valid until input_values_handler is modified again,
///       e.g. by the next transform() call.
Ort::Value ONNXOCRClassifier::transform(const cv::Mat& mat) {
    if (mat.empty()) {
        // cv::resize throws on empty input; report "nothing to do" instead.
        return Ort::Value(nullptr);
    }

    // The model consumes exactly 3 channels. Without this step, a gray or
    // 4-channel crop yields a buffer whose size does not match the tensor shape.
    cv::Mat bgr;
    switch (mat.channels()) {
    case 1:  cv::cvtColor(mat, bgr, cv::COLOR_GRAY2BGR);  break;
    case 4:  cv::cvtColor(mat, bgr, cv::COLOR_BGRA2BGR);  break;
    default: bgr = mat;                                    break;
    }

    cv::Mat resized;
    cv::resize(bgr, resized, cv::Size(kClsImageW, kClsImageH));
    resized.convertTo(resized, CV_32FC3);

    auto data = NormalizeAndPermute(resized);
    // Stored in a member: CreateTensor wraps this buffer without copying it.
    input_values_handler.assign(data.begin(), data.end());

    // Use an explicit static shape, as Classify() and Warmup() do.
    // input_node_dims may contain -1 for dynamic axes, which CreateTensor rejects.
    const std::array<int64_t, 4> shape = { 1, 3, kClsImageH, kClsImageW };
    return Ort::Value::CreateTensor<float>(
        *memory_info_handler, input_values_handler.data(), input_values_handler.size(),
        shape.data(), shape.size());
}
/// Batch hook required by the handler interface.
///
/// Batched preprocessing is not implemented, because Classify() feeds images
/// one at a time. Only the first image is transformed; any others are ignored.
///
/// @param images Candidate crops.
/// @return transform(images.front()), or a null Ort::Value when images is empty.
Ort::Value ONNXOCRClassifier::transformBatch(const std::vector<cv::Mat>& images) {
    if (images.empty()) {
        return Ort::Value(nullptr);
    }
    return transform(images.front());
}
/// Predicts an orientation class for every crop in img_list.
///
/// Each crop runs through the model on its own (batch size 1). The crop is
/// resized directly to kClsImageW x kClsImageH, with no aspect-ratio
/// preservation, matching PaddleOCR's ResizeImage for
/// PP-LCNet_x1_0_textline_ori. It is then ImageNet-normalized via
/// NormalizeAndPermute.
///
/// @param img_list   Crops to classify. Empty entries are skipped. 1- and
///                   4-channel crops are converted to 3 channels (assumes
///                   BGR/BGRA order; TODO confirm).
/// @param cls_labels Output. Cleared, then sized to img_list.size(). Entry i
///                   is the argmax class index for crop i, or 0 if the crop
///                   was empty or inference failed.
/// @param cls_scores Output. Cleared, then sized to img_list.size(). Entry i
///                   is the raw model output at the argmax. PaddleOCR v5
///                   models include softmax, so this is a probability.
///                   It is 0.0f for skipped or failed crops.
/// @param cls_thresh NOTE(review): not read by this function. Presumably the
///                   caller compares cls_scores against it; confirm.
///
/// Holds _mutex for the whole call. Per-image failures are logged to
/// std::cerr and do not abort the rest of the batch.
void ONNXOCRClassifier::Classify(std::vector<cv::Mat>& img_list,
    std::vector<int>& cls_labels,
    std::vector<float>& cls_scores,
    float cls_thresh) {
    (void)cls_thresh;
    std::lock_guard<std::mutex> lock(_mutex);
    cls_labels.clear();
    cls_scores.clear();
    if (!ort_session || img_list.empty()) return;
    cls_labels.assign(img_list.size(), 0);
    cls_scores.assign(img_list.size(), 0.0f);

    const std::array<int64_t, 4> inputShape = { 1, 3, kClsImageH, kClsImageW };

    // Process one image at a time (dynamic shapes).
    for (size_t i = 0; i < img_list.size(); ++i) {
        const cv::Mat& src = img_list[i];
        if (src.empty()) continue;
        try {
            // The model expects exactly 3 channels.
            cv::Mat bgr;
            if (src.channels() == 1) {
                cv::cvtColor(src, bgr, cv::COLOR_GRAY2BGR);
            } else if (src.channels() == 4) {
                cv::cvtColor(src, bgr, cv::COLOR_BGRA2BGR);
            } else {
                bgr = src;
            }

            cv::Mat resized;
            cv::resize(bgr, resized, cv::Size(kClsImageW, kClsImageH));
            resized.convertTo(resized, CV_32FC3);
            auto inputData = NormalizeAndPermute(resized);

            Ort::Value inputTensor = Ort::Value::CreateTensor<float>(
                *memory_info_handler, inputData.data(), inputData.size(),
                inputShape.data(), inputShape.size());
            auto outputTensors = ort_session->Run(
                Ort::RunOptions{ nullptr },
                input_node_names.data(), &inputTensor, 1,
                output_node_names.data(), num_outputs);
            if (outputTensors.empty()) continue;

            const auto info = outputTensors[0].GetTensorTypeAndShapeInfo();
            const auto outShape = info.GetShape();
            const size_t elementCount = info.GetElementCount();
            if (elementCount == 0) continue;

            // Use dim 1 as the class count when present (fallback 2). Clamp it
            // to the real element count so a short or 1-D output is never
            // over-read.
            size_t numClasses = (outShape.size() > 1 && outShape[1] > 0)
                ? static_cast<size_t>(outShape[1]) : 2;
            numClasses = std::min(numClasses, elementCount);

            // Argmax over the raw outputs. Matches PaddleOCR official:
            // score = preds[i, argmax_idx]. max_element keeps the first
            // maximum on ties.
            const float* outData = outputTensors[0].GetTensorData<float>();
            const float* best = std::max_element(outData, outData + numClasses);
            cls_labels[i] = static_cast<int>(best - outData);
            cls_scores[i] = *best;
        }
        catch (const std::exception& e) {
            // Covers Ort::Exception and cv::Exception (both derive from std::exception).
            std::cerr << "[ONNXOCRClassifier] Inference failed for image " << i
                << ": " << e.what() << std::endl;
            cls_labels[i] = 0;
            cls_scores[i] = 0.0f;
        }
    }
}
2026-04-14 20:30:21 +10:00
/// Runs one dummy inference at the classifier's input size, so later
/// Classify() calls do not pay first-run costs. Logs the latency to std::cout.
///
/// Does nothing if already warmed up or if there is no session. Holds _mutex
/// for the whole call. Failures are logged to std::cerr. _warmedUp is set
/// regardless, so warmup is attempted at most once (best-effort).
void ONNXOCRClassifier::Warmup() {
    std::lock_guard<std::mutex> lock(_mutex);
    if (_warmedUp || !ort_session) return;
    try {
        // Mid-gray image at twice the target size, so the resize path used by
        // Classify() is exercised too.
        cv::Mat dummy(kClsImageH * 2, kClsImageW * 2, CV_8UC3, cv::Scalar(128, 128, 128));
        cv::Mat resized;
        cv::resize(dummy, resized, cv::Size(kClsImageW, kClsImageH));
        resized.convertTo(resized, CV_32FC3);
        auto inputData = NormalizeAndPermute(resized);

        const std::array<int64_t, 4> inputShape = { 1, 3, kClsImageH, kClsImageW };
        Ort::Value inputTensor = Ort::Value::CreateTensor<float>(
            *memory_info_handler, inputData.data(), inputData.size(),
            inputShape.data(), inputShape.size());

        // steady_clock is monotonic; high_resolution_clock is not guaranteed
        // to be, so it can report bogus intervals.
        const auto t0 = std::chrono::steady_clock::now();
        (void)ort_session->Run(
            Ort::RunOptions{ nullptr },
            input_node_names.data(), &inputTensor, 1,
            output_node_names.data(), num_outputs);
        const auto t1 = std::chrono::steady_clock::now();

        const double ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
        std::cout << "[ONNXOCRClassifier] Warmup [1,3,"
            << kClsImageH << "," << kClsImageW << "] "
            << ms << " ms" << std::endl;
    }
    catch (const std::exception& e) {
        // Covers Ort::Exception and cv::Exception (both derive from std::exception).
        std::cerr << "[ONNXOCRClassifier] Warmup failed: " << e.what() << std::endl;
    }
    // Set even on failure, so a failing session is not re-warmed on every call.
    _warmedUp = true;
}
2026-03-28 16:54:11 +11:00
} // namespace onnxocr
} // namespace ANSCENTER