#include "ANSCustomCodeHelmetDetection.h"

#include <algorithm>
#include <cctype>
#include <iostream>
#include <mutex>
#include <string>
#include <vector>

// Returns a copy of `input` with every character passed through std::toupper.
// Each char is converted via unsigned char first so that bytes >= 0x80 do not
// hit the undefined-behaviour case of std::toupper on negative values.
static std::string toUpperCase(const std::string& input)
{
    std::string result = input;
    std::transform(result.begin(), result.end(), result.begin(),
        [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
    return result;
}

ANSCustomHMD::ANSCustomHMD()
{
    _isInitialized = false;
    _readROIs = false;
}

ANSCustomHMD::~ANSCustomHMD()
{
    Destroy();
}

// Releases the detector and classifier instances and clears the initialised
// flag. The lock is taken inside the try block because RunInference() uses
// both models while holding _mutex; without it a concurrent call could use a
// model that is being released. Returns false only if something throws.
bool ANSCustomHMD::Destroy()
{
    try {
        std::lock_guard lock(_mutex);
        _detector.reset();
        _classifier.reset();
        _isInitialized = false;
        return true;
    }
    catch (...) {
        return false;
    }
}

// Calls Optimize(fp16) on both models.
// Returns true only when both models exist and both Optimize() calls return 1.
// Both calls are always made, even if the first one fails.
bool ANSCustomHMD::OptimizeModel(bool fp16)
{
    try {
        std::lock_guard lock(_mutex);
        if (!_detector || !_classifier) {
            return false;
        }
        const int detectorResult = _detector->Optimize(fp16);
        const int classifierResult = _classifier->Optimize(fp16);
        return (detectorResult == 1) && (classifierResult == 1);
    }
    catch (...) {
        return false;
    }
}

// Creates and loads the detector ("detector" / "detector.names") and the
// classifier ("classifier" / "classifier.names") from `modelDirectory`.
//
// modelDirectory          - folder passed to ANSLIB::LoadModelFromFolder.
// detectionScoreThreshold - minimum detector confidence; raised to 0.25 if lower.
// labelMap                - [out] comma-separated labels reported by this model.
//
// Returns true when both models loaded (LoadModelFromFolder returned 1).
// Any failure or exception is logged to std::cerr and yields false.
bool ANSCustomHMD::Initialize(const std::string& modelDirectory, float detectionScoreThreshold, std::string& labelMap)
{
    try {
        // Serialise with RunInference(), which uses the models under _mutex.
        std::lock_guard lock(_mutex);

        _modelDirectory = modelDirectory;
        _detectionScoreThreshold = detectionScoreThreshold;
        _isInitialized = false;

        // Create model instances using factory pattern (ABI-safe)
        _detector = ANSLIBPtr(ANSCENTER::ANSLIB::Create(), &ANSCENTER::ANSLIB::Destroy);
        _classifier = ANSLIBPtr(ANSCENTER::ANSLIB::Create(), &ANSCENTER::ANSLIB::Destroy);

        // A null factory result would otherwise be dereferenced below
        // (undefined behaviour, not an exception that catch(...) could handle).
        if (!_detector || !_classifier) {
            std::cerr << "ANSCustomHMD::Initialize: Failed to create model instances." << std::endl;
            return false;
        }

        // Check the hardware type and pick the backend for both models.
        engineType = _detector->GetEngineType();
        if (engineType == 1) {
            // NVIDIA GPU: Use TensorRT
            _detectorModelType = 31;      // TENSORRT
            _detectorDetectionType = 1;   // DETECTION
            _classifierModelType = 31;    // TENSORRT
            _classifierDetectionType = 0; // CLASSIFICATION
            std::cout << "NVIDIA GPU detected. Using TensorRT" << std::endl;
        }
        else {
            // CPU/Other: Use YOLO
            _detectorModelType = 3;       // YOLOV8/YOLOV11
            _detectorDetectionType = 1;   // DETECTION
            _classifierModelType = 20;    // ANSONNXCL
            _classifierDetectionType = 0; // CLASSIFICATION
            std::cout << "CPU detected. Using YOLO/ANSONNXCL" << std::endl;
        }

        if (_detectionScoreThreshold < 0.25f) {
            _detectionScoreThreshold = 0.25f;
        }

        // classId: 0=license plate, 1=motorcyclist, 2=helmet, 3=no_helmet
        labelMap = "license plate,motorcyclist,helmet,no_helmet";

#ifdef FNS_DEBUG
        this->_loadEngineOnCreate = true;
#endif
        const int loadEngineOnCreation = _loadEngineOnCreate ? 1 : 0;
        const int autoEngineDetection = 1;
        const std::string licenseKey = "";

        // Load detector model
        const float detScoreThreshold = _detectionScoreThreshold;
        const float detConfThreshold = 0.5f;
        const float detNMSThreshold = 0.5f;
        std::string detLabelMap;
        const int detResult = _detector->LoadModelFromFolder(
            licenseKey.c_str(), "detector", "detector.names",
            detScoreThreshold, detConfThreshold, detNMSThreshold,
            autoEngineDetection, _detectorModelType, _detectorDetectionType,
            loadEngineOnCreation, modelDirectory.c_str(), detLabelMap);
        if (detResult != 1) {
            std::cerr << "ANSCustomHMD::Initialize: Failed to load detector model." << std::endl;
            return false;
        }

        // Load classifier model
        const float clsScoreThreshold = 0.25f;
        const float clsConfThreshold = 0.5f;
        const float clsNMSThreshold = 0.5f;
        const int clsResult = _classifier->LoadModelFromFolder(
            licenseKey.c_str(), "classifier", "classifier.names",
            clsScoreThreshold, clsConfThreshold, clsNMSThreshold,
            autoEngineDetection, _classifierModelType, _classifierDetectionType,
            loadEngineOnCreation, modelDirectory.c_str(), _classifierLabels);
        if (clsResult != 1) {
            std::cerr << "ANSCustomHMD::Initialize: Failed to load classifier model." << std::endl;
            return false;
        }

        _isInitialized = true;
        return true;
    }
    catch (const std::exception& e) {
        std::cerr << "ANSCustomHMD::Initialize: Exception: " << e.what() << std::endl;
        return false;
    }
    catch (...) {
        std::cerr << "ANSCustomHMD::Initialize: Unknown exception." << std::endl;
        return false;
    }
}

// This model exposes no ROI or tunable parameters: all lists are cleared.
bool ANSCustomHMD::ConfigureParameters(CustomParams& param)
{
    param.ROI_Config.clear();
    param.ROI_Options.clear();
    param.Parameters.clear();
    param.ROI_Values.clear();
    return true;
}

// Convenience overload that uses the camera id "CustomCam".
std::vector<CustomObject> ANSCustomHMD::RunInference(const cv::Mat& input)
{
    return RunInference(input, "CustomCam");
}

// Runs the detector on `input` and maps its detections to CustomObjects.
// - classId 0 (license plate) and 1 (motorcyclist) are passed through.
// - classId 2/3 (helmet / no_helmet) are re-checked by running the classifier
//   on the cropped box. Classifier class 1 is reported as "no_helmet" and any
//   other class as "helmet". NOTE(review): this relies on the ordering in
//   classifier.names - confirm.
// Detections below _detectionScoreThreshold or with other class ids are
// dropped. Returns an empty vector if the model is not initialised, the
// input is empty or smaller than 10x10, or an exception is thrown.
std::vector<CustomObject> ANSCustomHMD::RunInference(const cv::Mat& input, const std::string& camera_id)
{
    std::lock_guard lock(_mutex);

    if (!_isInitialized || !_detector) {
        return {};
    }
    if (input.empty() || input.cols < 10 || input.rows < 10) {
        return {};
    }

    try {
        // One-time parameter reading
        if (!_readROIs) {
            for (const auto& param : _params.Parameters) {
                if (param.Name.find("ALPR") != std::string::npos) {
                    // ALPR feature currently disabled: the value is parsed but not used yet.
                    [[maybe_unused]] const std::string paramValue = toUpperCase(param.Value);
                }
            }
            _readROIs = true;
        }

        // Run object detection.
        // NOTE(review): element type reconstructed (template arguments were lost
        // in this copy of the file); it must match the output parameter of
        // ANSCENTER::ANSLIB::RunInference - confirm against ANSLIB's header.
        std::vector<ANSCENTER::Object> detectionResults;
        _detector->RunInference(input, camera_id.c_str(), detectionResults);
        if (detectionResults.empty()) {
            return {};
        }

        std::vector<CustomObject> results;
        results.reserve(detectionResults.size());

        const cv::Rect frameBounds(0, 0, input.cols, input.rows);
        const float scoreThreshold = _detectionScoreThreshold;

        for (const auto& obj : detectionResults) {
            if (obj.confidence < scoreThreshold) {
                continue;
            }

            CustomObject customObject;
            customObject.confidence = obj.confidence;
            customObject.box = obj.box;
            customObject.cameraId = camera_id;

            switch (obj.classId) {
            case 0: { // License plate
                customObject.classId = 0;
                customObject.className = "license plate";
                customObject.extraInfo = "license plate";
                results.push_back(std::move(customObject));
                break;
            }
            case 1: { // Motorcycle
                customObject.classId = 1;
                customObject.className = "motorcyclist";
                results.push_back(std::move(customObject));
                break;
            }
            case 2:   // Helmet
            case 3: { // No helmet - classify to confirm
                // Cropping needs a non-empty ROI that lies fully inside the
                // frame; skip anything else (a zero-area box would otherwise
                // hand an empty image to the classifier).
                if (obj.box.width <= 0 || obj.box.height <= 0 ||
                    (obj.box & frameBounds) != obj.box) {
                    continue;
                }
                if (!_classifier) {
                    continue;
                }

                cv::Mat croppedImage = input(obj.box);
                std::vector<ANSCENTER::Object> classifierResults;
                _classifier->RunInference(croppedImage, camera_id.c_str(), classifierResults);
                if (classifierResults.empty()) {
                    continue;
                }

                const bool isNoHelmet = (classifierResults[0].classId == 1);
                customObject.classId = isNoHelmet ? 3 : 2;
                customObject.className = isNoHelmet ? "no_helmet" : "helmet";
                results.push_back(std::move(customObject));
                break;
            }
            default:
                break;
            }
        }

        return results;
    }
    catch (const std::exception& e) {
        std::cerr << "ANSCustomHMD::RunInference: Exception: " << e.what() << std::endl;
    }
    catch (...) {
        std::cerr << "ANSCustomHMD::RunInference: Unknown exception." << std::endl;
    }
    return {};
}