Refactor project structure
This commit is contained in:
165
modules/ANSOCR/ANSONNXOCR/ONNXOCRRecognizer.cpp
Normal file
165
modules/ANSOCR/ANSONNXOCR/ONNXOCRRecognizer.cpp
Normal file
@@ -0,0 +1,165 @@
|
||||
#include "ONNXOCRRecognizer.h"
|
||||
|
||||
#include <opencv2/imgproc.hpp>
|
||||
#include <iostream>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <cmath>
|
||||
#include <cfloat>
|
||||
#include <cstring>
|
||||
|
||||
namespace ANSCENTER {
|
||||
namespace onnxocr {
|
||||
|
||||
/// Constructs a recognizer by forwarding the model path and thread count to
/// the BasicOrtHandler base class. No additional state is set up here; the
/// character dictionary must be supplied separately via LoadDictionary().
///
/// @param onnx_path   Path of the ONNX recognition model (passed to the base).
/// @param num_threads Thread count (passed to the base).
ONNXOCRRecognizer::ONNXOCRRecognizer(const std::string& onnx_path,
                                     unsigned int num_threads)
    : BasicOrtHandler(onnx_path, num_threads)
{
}
|
||||
|
||||
/// Loads the character dictionary used to map model class indices to text.
///
/// The result of LoadDict(dictPath) is stored in keys_. A dictionary with
/// fewer than two entries is treated as a load failure.
///
/// @param dictPath Path of the dictionary file handed to LoadDict().
/// @return true when keys_ ends up with at least two entries, false otherwise.
bool ONNXOCRRecognizer::LoadDictionary(const std::string& dictPath) {
    keys_ = LoadDict(dictPath);

    const bool usable = !(keys_.size() < 2);
    if (usable) {
        std::cout << "[ONNXOCRRecognizer] Loaded dictionary with " << keys_.size()
                  << " characters from: " << dictPath << std::endl;
        return true;
    }

    // Fewer than two entries: report the failure on stderr.
    std::cerr << "[ONNXOCRRecognizer] Failed to load dictionary: " << dictPath << std::endl;
    return false;
}
|
||||
|
||||
/// Builds an input tensor from a single image using the handler's stored
/// input dimensions.
///
/// Recognize() does not call this; it runs its own preprocessing with a
/// per-image width.
///
/// @param mat Source image.
/// @return A float tensor that views input_values_handler (the tensor does not
///         own the data) and is shaped by input_node_dims.
Ort::Value ONNXOCRRecognizer::transform(const cv::Mat& mat) {
    // Resize, switch to float, then normalize and reorder via the shared helper.
    cv::Mat recInput = ResizeRecImage(mat, imgH_, imgMaxW_);
    recInput.convertTo(recInput, CV_32FC3);
    auto planar = NormalizeAndPermuteCls(recInput);

    // Keep the values in a member buffer so they outlive the returned tensor.
    input_values_handler.assign(planar.begin(), planar.end());

    // NOTE(review): input_node_dims presumably has to match the element count
    // produced above — confirm against BasicOrtHandler.
    return Ort::Value::CreateTensor<float>(
        *memory_info_handler,
        input_values_handler.data(), input_values_handler.size(),
        input_node_dims.data(), input_node_dims.size());
}
|
||||
|
||||
/// Batch variant of transform(). The recognizer handles images one at a time
/// with per-image widths, so only the first image is converted.
///
/// @param images Candidate images; only the first element is used.
/// @return transform(images[0]) when the list is non-empty, otherwise a null
///         Ort::Value.
Ort::Value ONNXOCRRecognizer::transformBatch(const std::vector<cv::Mat>& images) {
    if (images.empty()) {
        return Ort::Value(nullptr);
    }
    return transform(images.front());
}
|
||||
|
||||
/// Recognizes the text contained in a single cropped text-line image.
///
/// Steps: resize to the fixed height imgH_ (width capped by imgMaxW_),
/// normalize to CHW float data, zero-pad on the right up to kRecImgW when the
/// resized image is narrower, run the ONNX session, and CTC-decode the first
/// output.
///
/// The whole call runs under _mutex.
///
/// @param croppedImage Cropped text-line image.
/// @return The decoded TextLine. A default-constructed TextLine is returned
///         when the session is missing, the image or dictionary is empty,
///         preprocessing yields an unexpected buffer size, the model output
///         does not have the expected shape, or ONNX Runtime throws.
TextLine ONNXOCRRecognizer::Recognize(const cv::Mat& croppedImage) {
    std::lock_guard<std::mutex> lock(_mutex);

    if (!ort_session || croppedImage.empty() || keys_.empty()) {
        return {};
    }

    try {
        // Preprocess: resize to fixed height, proportional width
        cv::Mat resized = ResizeRecImage(croppedImage, imgH_, imgMaxW_);
        const int resizedW = resized.cols;

        resized.convertTo(resized, CV_32FC3);
        // Recognition uses (pixel/255 - 0.5) / 0.5 normalization (same as classifier)
        auto normalizedData = NormalizeAndPermuteCls(resized);

        // The row-wise copy below and the tensor shape both assume a dense
        // 3 x imgH_ x resizedW buffer; refuse anything else rather than
        // reading past the end of normalizedData.
        const size_t planeH = static_cast<size_t>(imgH_);
        const size_t srcW = static_cast<size_t>(resizedW > 0 ? resizedW : 0);
        if (resizedW <= 0 || normalizedData.size() != 3 * planeH * srcW) {
            std::cerr << "[ONNXOCRRecognizer] Unexpected preprocessed size: "
                      << normalizedData.size() << std::endl;
            return {};
        }

        // Pad to at least kRecImgW width (matching official PaddleOCR behavior)
        // Official PaddleOCR: padding_im = np.zeros((C, H, W)), then copies normalized
        // image into left portion. Padding value = 0.0 in normalized space.
        const int imgW = std::max(resizedW, kRecImgW);
        const size_t dstW = static_cast<size_t>(imgW);

        std::vector<float> inputData;
        if (imgW > resizedW) {
            // Zero-pad on the right (CHW layout): copy each source row into the
            // left part of the wider destination row.
            inputData.resize(3 * planeH * dstW, 0.0f);
            for (size_t c = 0; c < 3; ++c) {
                for (size_t y = 0; y < planeH; ++y) {
                    std::memcpy(
                        &inputData[(c * planeH + y) * dstW],
                        &normalizedData[(c * planeH + y) * srcW],
                        srcW * sizeof(float));
                }
            }
        } else {
            inputData = std::move(normalizedData);
        }

        // Create input tensor with (possibly padded) width
        std::array<int64_t, 4> inputShape = { 1, 3, imgH_, imgW };
        Ort::Value inputTensor = Ort::Value::CreateTensor<float>(
            *memory_info_handler, inputData.data(), inputData.size(),
            inputShape.data(), inputShape.size());

        // Run inference
        auto outputTensors = ort_session->Run(
            Ort::RunOptions{ nullptr },
            input_node_names.data(), &inputTensor, 1,
            output_node_names.data(), num_outputs);

        // Validate the output before indexing into it: dims [1] and [2] are
        // read as (sequence length, class count).
        if (outputTensors.empty()) {
            std::cerr << "[ONNXOCRRecognizer] Inference returned no outputs" << std::endl;
            return {};
        }

        const auto outputShape = outputTensors[0].GetTensorTypeAndShapeInfo().GetShape();
        if (outputShape.size() < 3 || outputShape[1] <= 0 || outputShape[2] <= 0) {
            std::cerr << "[ONNXOCRRecognizer] Unexpected output shape (rank "
                      << outputShape.size() << ")" << std::endl;
            return {};
        }

        const float* outputData = outputTensors[0].GetTensorMutableData<float>();
        const int seqLen = static_cast<int>(outputShape[1]);
        const int numClasses = static_cast<int>(outputShape[2]);

        return CTCDecode(outputData, seqLen, numClasses);
    }
    catch (const Ort::Exception& e) {
        std::cerr << "[ONNXOCRRecognizer] Inference failed: " << e.what() << std::endl;
        return {};
    }
}
|
||||
|
||||
/// Recognizes a list of cropped text-line images.
///
/// Images are processed sequentially through Recognize() because each image
/// gets its own dynamic input width.
///
/// @param croppedImages Images to recognize.
/// @return One TextLine per input image, in the same order.
std::vector<TextLine> ONNXOCRRecognizer::RecognizeBatch(const std::vector<cv::Mat>& croppedImages) {
    std::vector<TextLine> lines;
    lines.reserve(croppedImages.size());

    for (const cv::Mat& crop : croppedImages) {
        lines.push_back(Recognize(crop));
    }

    return lines;
}
|
||||
|
||||
/// Greedy CTC decoding of a (seqLen x numClasses) row-major score matrix.
///
/// For every timestep the highest-scoring class is selected. A class is
/// emitted when it is not the blank (index 0), differs from the class chosen
/// at the previous timestep, and is a valid index into keys_. The result's
/// score is the arithmetic mean of the raw model values of the emitted
/// characters; it is left at TextLine's default when nothing is emitted.
///
/// @param outputData Pointer to seqLen * numClasses floats.
/// @param seqLen     Number of timesteps.
/// @param numClasses Number of classes per timestep.
/// @return TextLine holding the decoded text and its mean confidence.
TextLine ONNXOCRRecognizer::CTCDecode(const float* outputData, int seqLen, int numClasses) {
    std::string decoded;
    std::vector<float> confidences;

    int previous = 0;  // CTC blank is index 0

    for (int step = 0; step < seqLen; ++step) {
        const float* row = outputData + step * numClasses;

        // Argmax over the classes of this timestep (first maximum wins).
        int best = 0;
        float bestValue = -FLT_MAX;
        for (int k = 0; k < numClasses; ++k) {
            const float v = row[k];
            if (v > bestValue) {
                bestValue = v;
                best = k;
            }
        }

        const bool isBlank = (best == 0);
        const bool isRepeat = (best == previous);
        const bool inDictionary = best > 0 && best < static_cast<int>(keys_.size());

        if (!isBlank && !isRepeat && inDictionary) {
            // keys_[0]="#"(blank), keys_[1]=first_char, etc.
            decoded += keys_[best];
            // Use raw model output value as confidence (PaddleOCR v5 models include softmax)
            confidences.push_back(bestValue);
        }

        // Always remember the last choice, even when nothing was emitted.
        previous = best;
    }

    TextLine result;
    result.text = decoded;
    if (!confidences.empty()) {
        const float total = std::accumulate(confidences.begin(), confidences.end(), 0.0f);
        result.score = total / static_cast<float>(confidences.size());
    }
    return result;
}
|
||||
|
||||
} // namespace onnxocr
|
||||
} // namespace ANSCENTER
|
||||
Reference in New Issue
Block a user