// File: ANSCORE/modules/ANSOCR/ANSRtOCR.cpp
// (repository viewer listing: 409 lines, 16 KiB, C++)

#include "ANSRtOCR.h"
#include "Utility.h"
#include <opencv2/highgui.hpp>
namespace ANSCENTER {
/// @brief Prepares the TensorRT OCR engine for inference.
///
/// Runs the base-class initialization, checks that the detection and
/// recognizer model files exist, resolves the optional classifier model,
/// copies the tuning parameters from `_modelConfig` into the engine and
/// finally initializes the engine.
///
/// @param licenseKey        Forwarded unchanged to ANSOCRBase::Initialize.
/// @param modelConfig       Forwarded unchanged to ANSOCRBase::Initialize; this
///                          method reads the `_modelConfig` member afterwards
///                          (presumably populated by the base class - confirm).
/// @param modelZipFilePath  Forwarded unchanged to ANSOCRBase::Initialize.
/// @param modelZipPassword  Forwarded unchanged to ANSOCRBase::Initialize.
/// @param engineMode        Forwarded unchanged to ANSOCRBase::Initialize.
/// @return true only when the base initialization succeeded, both mandatory
///         model files exist and the engine reported a successful start.
///
/// @note A missing mandatory model file or an exception also clears
///       `_licenseValid`, which makes the RunInference overloads reject calls.
bool ANSRTOCR::Initialize(const std::string& licenseKey, OCRModelConfig modelConfig,
    const std::string& modelZipFilePath, const std::string& modelZipPassword, int engineMode) {
    try {
        const bool baseReady = ANSOCRBase::Initialize(licenseKey, modelConfig, modelZipFilePath, modelZipPassword, engineMode);
        if (!baseReady) return false;

        // Validate detection model
        if (!FileExist(_modelConfig.detectionModelFile)) {
            this->_logger.LogFatal("ANSRTOCR::Initialize", "Invalid detector model file: " + _modelConfig.detectionModelFile, __FILE__, __LINE__);
            _licenseValid = false;
            return false;
        }

        // Validate recognizer model
        if (!FileExist(_modelConfig.recognizerModelFile)) {
            this->_logger.LogFatal("ANSRTOCR::Initialize", "Invalid recognizer model file: " + _modelConfig.recognizerModelFile, __FILE__, __LINE__);
            _licenseValid = false;
            return false;
        }

        // Classifier is optional - controlled by useCLS flag and file existence
        std::string clsModelPath;
        if (_modelConfig.useCLS) {
            clsModelPath = _modelConfig.clsModelFile;
            if (!clsModelPath.empty() && !FileExist(clsModelPath)) {
                this->_logger.LogWarn("ANSRTOCR::Initialize", "Classifier model not found, skipping: " + clsModelPath, __FILE__, __LINE__);
                clsModelPath.clear(); // An empty path tells the engine to skip the classifier
            }
        }
        else {
            this->_logger.LogDebug("ANSRTOCR::Initialize", "Classifier disabled (useCLS=false)", __FILE__, __LINE__);
        }

        // Destroy() releases the engine, so it can legitimately be null here.
        // Fail cleanly instead of dereferencing a null pointer below.
        if (!_engine) {
            _isInitialized = false;
            this->_logger.LogFatal("ANSRTOCR::Initialize", "Engine instance is null", __FILE__, __LINE__);
            return false;
        }

        try {
            // Configure engine parameters from modelConfig
            _engine->SetDetMaxSideLen(_modelConfig.limitSideLen);
            _engine->SetDetDbThresh(static_cast<float>(_modelConfig.detectionDBThreshold));
            _engine->SetDetBoxThresh(static_cast<float>(_modelConfig.detectionBoxThreshold));
            _engine->SetDetUnclipRatio(static_cast<float>(_modelConfig.detectionDBUnclipRatio));
            _engine->SetClsThresh(static_cast<float>(_modelConfig.clsThreshold));
            _engine->SetUseDilation(_modelConfig.useDilation);
            _engine->SetGpuId(_modelConfig.gpuId);

            // Engine cache directory = folder that contains the detection model.
            // It stays empty when the model path has no directory component.
            std::string engineCacheDir;
            const auto separatorPos = _modelConfig.detectionModelFile.find_last_of("/\\");
            if (separatorPos != std::string::npos) {
                engineCacheDir = _modelConfig.detectionModelFile.substr(0, separatorPos);
            }

            _isInitialized = _engine->Initialize(
                _modelConfig.detectionModelFile,
                clsModelPath,
                _modelConfig.recognizerModelFile,
                _modelConfig.recogizerCharDictionaryPath,
                _modelConfig.gpuId,
                engineCacheDir);
            if (!_isInitialized) {
                // Make the failure visible; previously a false result was returned silently.
                this->_logger.LogError("ANSRTOCR::Initialize", "Failed to initialize TensorRT OCR engine", __FILE__, __LINE__);
            }
            return _isInitialized;
        }
        catch (const std::exception& e) {
            _licenseValid = false;
            this->_logger.LogFatal("ANSRTOCR::Initialize", e.what(), __FILE__, __LINE__);
            return false;
        }
        catch (...) {
            _licenseValid = false;
            this->_logger.LogFatal("ANSRTOCR::Initialize", "Failed to create TensorRT OCR engine", __FILE__, __LINE__);
            return false;
        }
    }
    catch (const std::exception& e) {
        this->_logger.LogFatal("ANSRTOCR::Initialize", e.what(), __FILE__, __LINE__);
        _licenseValid = false;
        return false;
    }
    catch (...) {
        // Mirror the inner handler so no exception type can leave this entry point.
        this->_logger.LogFatal("ANSRTOCR::Initialize", "Unknown exception occurred", __FILE__, __LINE__);
        _licenseValid = false;
        return false;
    }
}
/// @brief Convenience overload: runs OCR on a whole image with the default
///        camera id "OCRRTCAM".
///
/// @param input Image to analyse.
/// @return The detections of the camera-id overload, or an empty vector
///         (without logging) when the image is empty or has fewer than
///         10 columns or 10 rows.
std::vector<ANSCENTER::OCRObject> ANSRTOCR::RunInference(const cv::Mat& input) {
    const bool usable = !input.empty() && input.cols >= 10 && input.rows >= 10;
    if (!usable) {
        return std::vector<ANSCENTER::OCRObject>();
    }
    return RunInference(input, "OCRRTCAM");
}
/// @brief Runs OCR on a whole image and tags every result with a camera id.
///
/// The call is serialized on `_mutex`. It returns an empty vector (after
/// logging) when the license is invalid, the model is not initialized, the
/// image is empty or smaller than 10x10, or the engine pointer is null.
/// Exceptions are logged and whatever was collected so far is returned.
///
/// @param input    Image to analyse; single-channel images are converted to BGR.
/// @param cameraId Value copied into `cameraId` of every returned object.
/// @return One OCRObject per text box whose clamped size is larger than 1x1.
std::vector<ANSCENTER::OCRObject> ANSRTOCR::RunInference(const cv::Mat& input, const std::string& cameraId) {
    std::lock_guard<std::recursive_mutex> lock(_mutex);
    std::vector<ANSCENTER::OCRObject> OCRObjects;

    if (!_licenseValid) {
        this->_logger.LogError("ANSRTOCR::RunInference", "Invalid License", __FILE__, __LINE__);
        return OCRObjects;
    }
    if (!_isInitialized) {
        this->_logger.LogError("ANSRTOCR::RunInference", "Model is not initialized", __FILE__, __LINE__);
        return OCRObjects;
    }
    if (input.empty() || input.cols < 10 || input.rows < 10) {
        this->_logger.LogError("ANSRTOCR::RunInference", "Input image is invalid or too small", __FILE__, __LINE__);
        return OCRObjects;
    }

    try {
        // The engine is fed a private BGR copy of the caller's image.
        cv::Mat im;
        if (input.channels() != 1) {
            im = input.clone();
        }
        else {
            cv::cvtColor(input, im, cv::COLOR_GRAY2BGR);
        }

        if (!_engine) {
            this->_logger.LogFatal("ANSRTOCR::RunInference", "Engine instance is null", __FILE__, __LINE__);
            return OCRObjects;
        }

        // The engine handles large images correctly in two stages:
        //   1. Detection: internally scales to limitSideLen -> bounded GPU memory
        //   2. Recognition: crops each text box from the ORIGINAL full-res image
        // This preserves text detail without tiling (which fragments text at boundaries).
        const std::vector<rtocr::OCRPredictResult> predictions = _engine->ocr(im);

        for (const rtocr::OCRPredictResult& prediction : predictions) {
            if (prediction.box.size() != 4) {
                this->_logger.LogError("ANSRTOCR::RunInference", "Invalid OCR box size", __FILE__, __LINE__);
                continue;
            }

            // Truncate a quad corner to integer pixel coordinates.
            const auto cornerAt = [&prediction](size_t idx) {
                return cv::Point(
                    static_cast<int>(prediction.box[idx][0]),
                    static_cast<int>(prediction.box[idx][1]));
            };
            const cv::Point first = cornerAt(0);
            const cv::Point second = cornerAt(1);
            const cv::Point third = cornerAt(2);

            // Origin from corner 0; width from corners 0->1, height from corners 1->2.
            const int left = std::max(0, first.x);
            const int top = std::max(0, first.y);
            const int boxWidth = std::max(1, std::min(im.cols - left, second.x - first.x));
            const int boxHeight = std::max(1, std::min(im.rows - top, third.y - second.y));
            if (boxWidth <= 1 || boxHeight <= 1) {
                continue; // Degenerate after clamping
            }

            ANSCENTER::OCRObject detection;
            detection.box = cv::Rect(left, top, boxWidth, boxHeight);
            detection.classId = prediction.cls_label;
            detection.confidence = prediction.score;
            detection.className = prediction.text;
            detection.extraInfo = "cls label: " + std::to_string(prediction.cls_label)
                + "; cls score: " + std::to_string(prediction.cls_score);
            detection.cameraId = cameraId;
            OCRObjects.push_back(detection);
        }
    }
    catch (const std::exception& e) {
        this->_logger.LogFatal("ANSRTOCR::RunInference", e.what(), __FILE__, __LINE__);
    }
    catch (...) {
        this->_logger.LogFatal("ANSRTOCR::RunInference", "Unknown exception occurred", __FILE__, __LINE__);
    }
    return OCRObjects;
}
/// @brief Runs OCR inside the given regions of interest, or on the whole
///        image when no region is supplied.
///
/// ROI path (non-empty @p Bbox): each rectangle is clamped to the frame,
/// regions smaller than 5x5 after clamping are skipped, the crop is passed to
/// the single-image overload and the resulting boxes are translated back into
/// frame coordinates. The results keep the camera id assigned by that overload.
///
/// Full-frame path (empty @p Bbox): the engine runs on a BGR copy of the
/// image. Boxes that end up with a non-positive width or height after
/// clamping are dropped.
///
/// The call is serialized on `_mutex`. Failures (invalid license, not
/// initialized, empty input, null engine, exceptions) are logged and yield
/// the results collected so far; an image smaller than 10x10 returns an
/// empty vector silently.
///
/// @param input Image to analyse; single-channel images are converted to BGR.
/// @param Bbox  Regions of interest in @p input coordinates; may be empty.
/// @return Detected text objects in @p input coordinates.
std::vector<ANSCENTER::OCRObject> ANSRTOCR::RunInference(const cv::Mat& input, const std::vector<cv::Rect>& Bbox) {
    std::lock_guard<std::recursive_mutex> lock(_mutex);
    std::vector<ANSCENTER::OCRObject> OCRObjects;
    if (!_licenseValid) {
        this->_logger.LogError("ANSRTOCR::RunInference", "Invalid License", __FILE__, __LINE__);
        return OCRObjects;
    }
    if (!_isInitialized) {
        this->_logger.LogError("ANSRTOCR::RunInference", "Model is not initialized", __FILE__, __LINE__);
        return OCRObjects;
    }
    try {
        if (input.empty()) {
            this->_logger.LogError("ANSRTOCR::RunInference", "Input image is empty", __FILE__, __LINE__);
            return OCRObjects;
        }
        if ((input.cols < 10) || (input.rows < 10)) return OCRObjects;

        if (!Bbox.empty()) {
            // The single-image overload works on its own copy of every crop, so
            // the frame is only read here: share the caller's buffer instead of
            // deep-copying the whole image. Grayscale still needs a BGR buffer.
            cv::Mat frame;
            if (input.channels() == 1) {
                cv::cvtColor(input, frame, cv::COLOR_GRAY2BGR);
            }
            else {
                frame = input;
            }
            const int fWidth = frame.cols;
            const int fHeight = frame.rows;

            for (const cv::Rect& roi : Bbox) {
                const int x1 = std::max(0, roi.x);
                const int y1 = std::max(0, roi.y);
                const int width = std::min(fWidth - x1, roi.width);
                const int height = std::min(fHeight - y1, roi.height);
                if (width < 5 || height < 5) continue; // Too small (or fully outside the frame)

                const cv::Mat croppedObject = frame(cv::Rect(x1, y1, width, height));
                const std::vector<ANSCENTER::OCRObject> roiObjects = RunInference(croppedObject);
                for (const ANSCENTER::OCRObject& roiObject : roiObjects) {
                    // Translate from crop coordinates back to frame coordinates.
                    ANSCENTER::OCRObject detObj = roiObject;
                    detObj.box.x = std::max(0, roiObject.box.x + x1);
                    detObj.box.y = std::max(0, roiObject.box.y + y1);
                    detObj.box.width = std::min(fWidth - detObj.box.x, detObj.box.width);
                    detObj.box.height = std::min(fHeight - detObj.box.y, detObj.box.height);
                    OCRObjects.push_back(detObj);
                }
            }
        }
        else {
            // Destroy() releases the engine without clearing _isInitialized, so
            // the pointer must be checked before it is dereferenced.
            if (!_engine) {
                this->_logger.LogFatal("ANSRTOCR::RunInference", "Engine instance is null", __FILE__, __LINE__);
                return OCRObjects;
            }

            cv::Mat frame;
            if (input.channels() == 1) {
                cv::cvtColor(input, frame, cv::COLOR_GRAY2BGR);
            }
            else {
                frame = input.clone();
            }

            const std::vector<rtocr::OCRPredictResult> res_ocr = _engine->ocr(frame);
            for (const rtocr::OCRPredictResult& res : res_ocr) {
                if (res.box.size() != 4) continue;

                cv::Point rook_points[4];
                for (size_t m = 0; m < 4; ++m) {
                    rook_points[m] = cv::Point(
                        static_cast<int>(res.box[m][0]),
                        static_cast<int>(res.box[m][1]));
                }

                ANSCENTER::OCRObject ocrObject;
                ocrObject.box.x = std::max(0, rook_points[0].x);
                ocrObject.box.y = std::max(0, rook_points[0].y);
                ocrObject.box.width = std::min(frame.cols - ocrObject.box.x, rook_points[1].x - rook_points[0].x);
                ocrObject.box.height = std::min(frame.rows - ocrObject.box.y, rook_points[2].y - rook_points[1].y);
                // An empty rectangle is not a usable detection and makes
                // downstream cropping fail, so do not report it.
                if (ocrObject.box.width <= 0 || ocrObject.box.height <= 0) continue;

                ocrObject.classId = res.cls_label;
                ocrObject.confidence = res.score;
                ocrObject.className = res.text;
                ocrObject.extraInfo = "cls label:" + std::to_string(res.cls_label)
                    + ";cls score:" + std::to_string(res.cls_score);
                OCRObjects.push_back(ocrObject);
            }
        }
        return OCRObjects;
    }
    catch (const std::exception& e) {
        this->_logger.LogFatal("ANSRTOCR::RunInference", e.what(), __FILE__, __LINE__);
        return OCRObjects;
    }
    catch (...) {
        // Same policy as the camera-id overload: never let an exception escape.
        this->_logger.LogFatal("ANSRTOCR::RunInference", "Unknown exception occurred", __FILE__, __LINE__);
        return OCRObjects;
    }
}
/// @brief Runs OCR inside the given regions of interest, or on the whole
///        image when no region is supplied, and tags results with a camera id.
///
/// ROI path (non-empty @p Bbox): each rectangle is clamped to the frame,
/// regions smaller than 5x5 after clamping are skipped, the crop is passed to
/// the single-image overload and the resulting boxes are translated back into
/// frame coordinates.
///
/// Full-frame path (empty @p Bbox): the engine runs on a BGR copy of the
/// image. Boxes that end up with a non-positive width or height after
/// clamping are dropped.
///
/// The call is serialized on `_mutex`. Failures (invalid license, not
/// initialized, empty input, null engine, exceptions) are logged and yield
/// the results collected so far; an image smaller than 10x10 returns an
/// empty vector silently.
///
/// @param input    Image to analyse; single-channel images are converted to BGR.
/// @param Bbox     Regions of interest in @p input coordinates; may be empty.
/// @param cameraId Value copied into `cameraId` of every returned object.
/// @return Detected text objects in @p input coordinates.
std::vector<ANSCENTER::OCRObject> ANSRTOCR::RunInference(const cv::Mat& input, const std::vector<cv::Rect>& Bbox, const std::string& cameraId) {
    std::lock_guard<std::recursive_mutex> lock(_mutex);
    std::vector<ANSCENTER::OCRObject> OCRObjects;
    if (!_licenseValid) {
        this->_logger.LogError("ANSRTOCR::RunInference", "Invalid License", __FILE__, __LINE__);
        return OCRObjects;
    }
    if (!_isInitialized) {
        this->_logger.LogError("ANSRTOCR::RunInference", "Model is not initialized", __FILE__, __LINE__);
        return OCRObjects;
    }
    try {
        if (input.empty()) {
            this->_logger.LogError("ANSRTOCR::RunInference", "Input image is empty", __FILE__, __LINE__);
            return OCRObjects;
        }
        if ((input.cols < 10) || (input.rows < 10)) return OCRObjects;

        if (!Bbox.empty()) {
            // The single-image overload works on its own copy of every crop, so
            // the frame is only read here: share the caller's buffer instead of
            // deep-copying the whole image. Grayscale still needs a BGR buffer.
            cv::Mat frame;
            if (input.channels() == 1) {
                cv::cvtColor(input, frame, cv::COLOR_GRAY2BGR);
            }
            else {
                frame = input;
            }
            const int fWidth = frame.cols;
            const int fHeight = frame.rows;

            for (const cv::Rect& roi : Bbox) {
                const int x1 = std::max(0, roi.x);
                const int y1 = std::max(0, roi.y);
                const int width = std::min(fWidth - x1, roi.width);
                const int height = std::min(fHeight - y1, roi.height);
                if (width < 5 || height < 5) continue; // Too small (or fully outside the frame)

                const cv::Mat croppedObject = frame(cv::Rect(x1, y1, width, height));
                const std::vector<ANSCENTER::OCRObject> roiObjects = RunInference(croppedObject);
                for (const ANSCENTER::OCRObject& roiObject : roiObjects) {
                    // Translate from crop coordinates back to frame coordinates.
                    ANSCENTER::OCRObject detObj = roiObject;
                    detObj.box.x = std::max(0, roiObject.box.x + x1);
                    detObj.box.y = std::max(0, roiObject.box.y + y1);
                    detObj.box.width = std::min(fWidth - detObj.box.x, detObj.box.width);
                    detObj.box.height = std::min(fHeight - detObj.box.y, detObj.box.height);
                    detObj.cameraId = cameraId;
                    OCRObjects.push_back(detObj);
                }
            }
        }
        else {
            // Destroy() releases the engine without clearing _isInitialized, so
            // the pointer must be checked before it is dereferenced.
            if (!_engine) {
                this->_logger.LogFatal("ANSRTOCR::RunInference", "Engine instance is null", __FILE__, __LINE__);
                return OCRObjects;
            }

            // Every other inference path hands the engine a BGR image; do the
            // same here instead of passing a single-channel clone through.
            cv::Mat im;
            if (input.channels() == 1) {
                cv::cvtColor(input, im, cv::COLOR_GRAY2BGR);
            }
            else {
                im = input.clone();
            }

            const std::vector<rtocr::OCRPredictResult> res_ocr = _engine->ocr(im);
            for (const rtocr::OCRPredictResult& res : res_ocr) {
                if (res.box.size() != 4) continue;

                cv::Point rook_points[4];
                for (size_t m = 0; m < 4; ++m) {
                    rook_points[m] = cv::Point(
                        static_cast<int>(res.box[m][0]),
                        static_cast<int>(res.box[m][1]));
                }

                ANSCENTER::OCRObject ocrObject;
                ocrObject.box.x = std::max(0, rook_points[0].x);
                ocrObject.box.y = std::max(0, rook_points[0].y);
                ocrObject.box.width = std::min(im.cols - ocrObject.box.x, rook_points[1].x - rook_points[0].x);
                ocrObject.box.height = std::min(im.rows - ocrObject.box.y, rook_points[2].y - rook_points[1].y);
                // An empty rectangle is not a usable detection and makes
                // downstream cropping fail, so do not report it.
                if (ocrObject.box.width <= 0 || ocrObject.box.height <= 0) continue;

                ocrObject.classId = res.cls_label;
                ocrObject.confidence = res.score;
                ocrObject.className = res.text;
                ocrObject.extraInfo = "cls label:" + std::to_string(res.cls_label)
                    + ";cls score:" + std::to_string(res.cls_score);
                ocrObject.cameraId = cameraId;
                OCRObjects.push_back(ocrObject);
            }
        }
        return OCRObjects;
    }
    catch (const std::exception& e) {
        this->_logger.LogFatal("ANSRTOCR::RunInference", e.what(), __FILE__, __LINE__);
        return OCRObjects;
    }
    catch (...) {
        // Same policy as the camera-id overload: never let an exception escape.
        this->_logger.LogFatal("ANSRTOCR::RunInference", "Unknown exception occurred", __FILE__, __LINE__);
        return OCRObjects;
    }
}
/// @brief Runs only the recognition stage on an image that is already cropped.
///
/// The call is serialized on `_mutex`. No try/catch is applied here, so an
/// exception thrown by the engine propagates to the caller.
///
/// @param croppedImage Image forwarded to the engine's `recognizeOnly`.
/// @return The recognized text and its score, or {"", 0.0f} when the model is
///         not initialized, the engine is null or the image is empty.
std::pair<std::string, float> ANSRTOCR::RecognizeText(const cv::Mat& croppedImage) {
    std::lock_guard<std::recursive_mutex> lock(_mutex);

    const bool ready = _isInitialized && _engine && !croppedImage.empty();
    if (!ready) {
        return {"", 0.0f};
    }

    const auto recognized = _engine->recognizeOnly(croppedImage);
    return {recognized.text, recognized.score};
}
/// @brief Releases the engine through Destroy() and guarantees that no
///        exception leaves the destructor.
ANSRTOCR::~ANSRTOCR() {
    try {
        Destroy();
    }
    catch (const std::exception& e) {
        this->_logger.LogFatal("ANSRTOCR::~ANSRTOCR()", e.what(), __FILE__, __LINE__);
    }
    catch (...) {
        // Destructors are implicitly noexcept: an exception of any other type
        // escaping from here would end the process via std::terminate.
        this->_logger.LogFatal("ANSRTOCR::~ANSRTOCR()", "Unknown exception occurred", __FILE__, __LINE__);
    }
}
bool ANSRTOCR::Destroy() {
try {
if (_engine) _engine.reset();
return true;
}
catch (std::exception& e) {
this->_logger.LogFatal("ANSRTOCR::Destroy", e.what(), __FILE__, __LINE__);
return false;
}
}
} // namespace ANSCENTER