Files
ANSCORE/modules/ANSOCR/ANSRTOCR/PaddleOCRV5RTEngine.cpp

158 lines
5.8 KiB
C++
Raw Normal View History

2026-03-28 16:54:11 +11:00
#include "PaddleOCRV5RTEngine.h"
#include <opencv2/imgproc.hpp>
#include <iostream>
namespace ANSCENTER {
namespace rtocr {
bool PaddleOCRV5RTEngine::Initialize(const std::string& detModelPath,
const std::string& clsModelPath,
const std::string& recModelPath,
const std::string& dictPath,
int gpuId,
const std::string& engineCacheDir) {
std::lock_guard<std::recursive_mutex> lock(_mutex);
gpuId_ = gpuId;
if (!engineCacheDir.empty()) {
engineCacheDir_ = engineCacheDir;
}
try {
// 1. Initialize detector
detector_ = std::make_unique<RTOCRDetector>();
if (!detector_->Initialize(detModelPath, gpuId_, engineCacheDir_, detMaxSideLen_)) {
std::cerr << "[PaddleOCRV5RTEngine] Failed to initialize detector" << std::endl;
return false;
}
// 2. Initialize classifier (optional - only if path provided)
if (!clsModelPath.empty()) {
classifier_ = std::make_unique<RTOCRClassifier>();
if (!classifier_->Initialize(clsModelPath, gpuId_, engineCacheDir_)) {
std::cerr << "[PaddleOCRV5RTEngine] Warning: Failed to initialize classifier, skipping"
<< std::endl;
classifier_.reset();
}
}
// 3. Initialize recognizer
recognizer_ = std::make_unique<RTOCRRecognizer>();
recognizer_->SetRecImageHeight(recImgH_);
recognizer_->SetRecImageMaxWidth(recImgMaxW_);
if (!recognizer_->Initialize(recModelPath, dictPath, gpuId_, engineCacheDir_)) {
std::cerr << "[PaddleOCRV5RTEngine] Failed to initialize recognizer" << std::endl;
return false;
}
std::cout << "[PaddleOCRV5RTEngine] Initialized successfully"
<< " (detector: yes, classifier: " << (classifier_ ? "yes" : "no")
<< ", recognizer: yes)" << std::endl;
return true;
}
catch (const std::exception& e) {
std::cerr << "[PaddleOCRV5RTEngine] Initialize failed: " << e.what() << std::endl;
return false;
}
}
/// @brief Runs the full pipeline on one image: detect text boxes, crop them,
///        optionally classify/flip each crop, then recognize the text.
///
/// The call is serialized on the engine mutex. Boxes whose crop comes back
/// empty are dropped, so the returned entries stay index-aligned with the
/// crops that were actually classified and recognized.
///
/// @param image Input image; an empty image yields an empty result.
/// @return One OCRPredictResult per successfully cropped box. Empty when the
///         detector or recognizer is missing, nothing is detected, or no crop
///         survives. If a std::exception is thrown, it is logged and whatever
///         has been accumulated in the result vector so far is returned.
std::vector<OCRPredictResult> PaddleOCRV5RTEngine::ocr(const cv::Mat& image) {
    std::lock_guard<std::recursive_mutex> lock(_mutex);
    std::vector<OCRPredictResult> results;
    if (!detector_ || !recognizer_ || image.empty()) return results;
    try {
        // 1. Detection: find text boxes
        std::vector<TextBox> textBoxes = detector_->Detect(
            image, detMaxSideLen_, detDbThresh_, detBoxThresh_,
            detUnclipRatio_, useDilation_);
        if (textBoxes.empty()) return results;

        // 2. Crop text regions. Boxes and crops are collected together in a
        // single pass so they stay aligned without cropping anything twice.
        std::vector<TextBox> keptBoxes;
        std::vector<cv::Mat> croppedImages;
        keptBoxes.reserve(textBoxes.size());
        croppedImages.reserve(textBoxes.size());
        for (size_t i = 0; i < textBoxes.size(); i++) {
            cv::Mat cropped = GetRotateCropImage(image, textBoxes[i]);
            if (cropped.empty()) continue;
            keptBoxes.push_back(textBoxes[i]);
            croppedImages.push_back(cropped);
        }
        textBoxes.swap(keptBoxes);
        // Nothing usable: skip running the later stages on an empty batch.
        if (croppedImages.empty()) return results;

        // 3. Classification (optional): check orientation and rotate if needed
        std::vector<int> clsLabels(croppedImages.size(), 0);
        std::vector<float> clsScores(croppedImages.size(), 0.0f);
        if (classifier_) {
            auto clsResults = classifier_->Classify(croppedImages, clsThresh_);
            // Bounded by both sizes in case the classifier returns fewer
            // (or more) entries than crops.
            for (size_t i = 0; i < clsResults.size() && i < croppedImages.size(); i++) {
                clsLabels[i] = clsResults[i].first;
                clsScores[i] = clsResults[i].second;
                // Rotate 180 degrees if label is odd and confidence is high enough
                if (clsLabels[i] % 2 == 1 && clsScores[i] > clsThresh_) {
                    cv::rotate(croppedImages[i], croppedImages[i], cv::ROTATE_180);
                }
            }
        }

        // 4. Recognition: extract text from cropped images
        std::vector<TextLine> textLines = recognizer_->RecognizeBatch(croppedImages);

        // 5. Combine results
        results.reserve(textBoxes.size());
        for (size_t i = 0; i < textBoxes.size(); i++) {
            OCRPredictResult res;
            // Convert box to [[x,y], ...] format
            for (int j = 0; j < 4; j++) {
                res.box.push_back({
                    static_cast<int>(textBoxes[i].points[j].x),
                    static_cast<int>(textBoxes[i].points[j].y)
                });
            }
            // Text/score keep their default values when the recognizer
            // returned fewer lines than there are boxes.
            if (i < textLines.size()) {
                res.text = textLines[i].text;
                res.score = textLines[i].score;
            }
            res.cls_label = clsLabels[i];
            res.cls_score = clsScores[i];
            results.push_back(res);
        }
        return results;
    }
    catch (const std::exception& e) {
        std::cerr << "[PaddleOCRV5RTEngine] OCR failed: " << e.what() << std::endl;
        return results;
    }
}
/// @brief Runs only the recognizer on an already-cropped image.
///
/// The call is serialized on the engine mutex.
///
/// @param croppedImage Image passed straight to the recognizer's Recognize().
/// @return The recognizer's result, or a TextLine built from an empty string
///         and 0.0f when no recognizer is set or the image is empty.
TextLine PaddleOCRV5RTEngine::recognizeOnly(const cv::Mat& croppedImage) {
    std::lock_guard<std::recursive_mutex> lock(_mutex);
    const bool canRecognize = recognizer_ && !croppedImage.empty();
    if (canRecognize) {
        return recognizer_->Recognize(croppedImage);
    }
    return { "", 0.0f };
}
2026-03-28 16:54:11 +11:00
} // namespace rtocr
} // namespace ANSCENTER