Files
ANSCORE/modules/ANSODEngine/Movienet.cpp

369 lines
15 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include "Movienet.h"
#include <filesystem>
namespace ANSCENTER {
// Derives the square input edge length from a model file name.
//
// The name is expected to carry a three-character variant tag "_A<digit>"
// (the 'A' may be upper or lower case) immediately before its last '.':
//   _A0, _A1 -> 172    _A2 -> 224    _A3 -> 256    _A4 -> 290    _A5 -> 320
// e.g. "Violence0305_A5.onnx" -> 320, "MyModel0908_A4.onnx" -> 290.
//
// Returns 0 when there is no '.', when fewer than three characters precede
// it, or when those characters are not a recognised variant tag.
static int GetMovinetSizeFromFilename(const std::string& filename)
{
    // Input edge length per variant, indexed by the digit following "_A".
    static constexpr int kEdgeByVariant[6] = { 172, 172, 224, 256, 290, 320 };
    constexpr size_t kTagLength = 3; // '_' + 'A' + digit

    const size_t extensionDot = filename.find_last_of('.');
    if (extensionDot == std::string::npos) {
        return 0;
    }
    if (extensionDot < kTagLength) {
        return 0; // not enough room for the tag in front of the extension
    }

    const char* tag = filename.data() + (extensionDot - kTagLength);
    const bool hasMarker = (tag[0] == '_') && (tag[1] == 'A' || tag[1] == 'a');
    if (!hasMarker) {
        return 0;
    }

    const int variant = tag[2] - '0';
    if (variant < 0 || variant > 5) {
        return 0;
    }
    return kEdgeByVariant[variant];
}
// Returns true when `name` ends with ".onnx", compared case-insensitively.
static bool HasOnnxExtensionNoCase(const std::string& name)
{
    static const char kExt[] = ".onnx";
    constexpr size_t kExtLen = sizeof(kExt) - 1;
    if (name.size() < kExtLen) return false;
    const size_t start = name.size() - kExtLen;
    for (size_t i = 0; i < kExtLen; ++i) {
        // Cast through unsigned char: std::tolower is undefined for negative values.
        const int lowered = std::tolower(static_cast<unsigned char>(name[start + i]));
        if (lowered != kExt[i]) return false;
    }
    return true;
}

// Resolves the ONNX model file to load and the matching square input size.
//
// folder     Directory that holds the model file(s).
// outWidth   Receives the input width  (always written before returning).
// outHeight  Receives the input height (always written before returning).
// modelName  Optional explicit model name, with or without ".onnx"
//            (e.g. "Violence0305_A2"). When non-empty it is used as-is and the
//            folder is not scanned; the size comes from its _A<digit> suffix,
//            or defaults to 172x172 if the suffix is not recognised.
//
// When modelName is empty the folder is scanned once for regular files whose
// name ends in ".onnx" (any case) and whose stem ends in _A0.._A5; the file
// with the highest variant digit wins (first one encountered on ties).
// If nothing matches, or the scan fails, falls back to "movinet.onnx" at
// 172x172.
//
// Returns the full path of the chosen model file. The file is not checked for
// existence on the explicit-name and fallback paths.
static std::string ResolveMovinetModel(const std::string& folder,
                                       int& outWidth, int& outHeight,
                                       const std::string& modelName = "")
{
    constexpr int kDefaultSize = 172;
    constexpr size_t kExtLen = 5;    // length of ".onnx"
    constexpr size_t kSuffixLen = 3; // length of "_A<digit>"
    constexpr int kTopVariant = 5;

    // If a specific model name was given, use it directly.
    if (!modelName.empty()) {
        std::string fname = modelName;
        // Append .onnx only if no such extension is present in any letter case;
        // otherwise "X.ONNX" would turn into "X.ONNX.onnx" and lose its suffix.
        if (!HasOnnxExtensionNoCase(fname)) {
            fname += ".onnx";
        }
        std::string fullPath = CreateFilePath(folder, fname);
        const int sz = GetMovinetSizeFromFilename(fname);
        if (sz > 0) {
            outWidth = sz;
            outHeight = sz;
            std::cout << "ANSMOVIENET: Using specified model '" << fname << "' with input size " << sz << "x" << sz << "\n";
        }
        else {
            // Model name given but no recognized _A<digit> suffix — use default.
            outWidth = kDefaultSize;
            outHeight = kDefaultSize;
            std::cout << "ANSMOVIENET: Using specified model '" << fname << "' but failed to detect input size from filename. Defaulting to 172x172.\n";
        }
        return fullPath;
    }

    // No model name specified — scan the folder once and keep the best variant.
    try {
        namespace fs = std::filesystem;
        if (fs::is_directory(folder)) {
            int bestVariant = -1;
            int bestSize = 0;
            std::string bestPath;
            for (const auto& entry : fs::directory_iterator(folder)) {
                if (!entry.is_regular_file()) continue;
                const std::string fname = entry.path().filename().string();
                if (!HasOnnxExtensionNoCase(fname)) continue;

                // The stem must end with "_A<digit>" (upper-case 'A', digit 0..5).
                const size_t stemLen = fname.size() - kExtLen;
                if (stemLen < kSuffixLen) continue;
                if (fname[stemLen - 3] != '_' || fname[stemLen - 2] != 'A') continue;
                const int variant = fname[stemLen - 1] - '0';
                if (variant < 0 || variant > kTopVariant) continue;
                // Strictly greater: on ties the first file encountered is kept.
                if (variant <= bestVariant) continue;

                const int sz = GetMovinetSizeFromFilename(fname);
                if (sz <= 0) continue;
                bestVariant = variant;
                bestSize = sz;
                bestPath = entry.path().string();
                if (bestVariant == kTopVariant) break; // nothing can beat A5
            }
            if (bestVariant >= 0) {
                outWidth = bestSize;
                outHeight = bestSize;
                return bestPath;
            }
        }
    }
    catch (const std::exception& e) {
        // Filesystem error — fall through to legacy.
        std::cout << "ANSMOVIENET: Error scanning model folder '" << folder << "': " << e.what() << ". Falling back to default model.\n";
    }
    catch (...) {
        // Unknown error — fall through to legacy.
        std::cout << "ANSMOVIENET: Error scanning model folder '" << folder << "'. Falling back to default model.\n";
    }

    // Legacy fallback.
    outWidth = kDefaultSize;
    outHeight = kDefaultSize;
    return CreateFilePath(folder, "movinet.onnx");
}
// Initializes the detector from a model archive.
//
// Runs the base-class initialization, resolves the ONNX model file and its
// input size via ResolveMovinetModel(_modelFolder, ...), clamps the thresholds
// and (re)creates the MOVINET detector.
//
// labelMap is an output: it is overwritten with "Face".
// Returns true on success, false if a std::exception was thrown while setting
// up the detector (the error is logged).
//
// NOTE(review): the base-class result is not checked here, unlike
// LoadModel()/LoadModelFromFolder(), and _licenseValid is forced to true right
// after the base call. That looks like a deliberate override — confirm before
// changing, since callers may depend on it.
// NOTE(review): "Face" is an unexpected label for a video classifier — verify.
bool ANSMOVIENET::Initialize(std::string licenseKey, ModelConfig modelConfig, const std::string& modelZipFilePath, const std::string& modelZipPassword, std::string& labelMap) {
    [[maybe_unused]] const bool baseResult = ANSODBase::Initialize(licenseKey, modelConfig, modelZipFilePath, modelZipPassword, labelMap);
    labelMap = "Face";
    _licenseValid = true;
    // Currently unreachable because of the assignment above; kept so the gate
    // takes effect again if the override is ever removed.
    if (!_licenseValid) return false;
    try {
        _modelConfig = modelConfig;
        _modelConfig.modelType = ModelType::MOVIENET;
        _modelConfig.detectionType = DetectionType::CLASSIFICATION;

        // Auto-detect model variant and matching input size.
        int detectedW = 0, detectedH = 0;
        std::string onnxModel = ResolveMovinetModel(_modelFolder, detectedW, detectedH);
        _modelConfig.inpHeight = detectedH;
        _modelConfig.inpWidth = detectedW;

        // Thresholds below 0.2 are replaced with 0.5.
        if (_modelConfig.modelMNSThreshold < 0.2)
            _modelConfig.modelMNSThreshold = 0.5;
        if (_modelConfig.modelConfThreshold < 0.2)
            _modelConfig.modelConfThreshold = 0.5;

        // Drop any previous detector before building the new one.
        if (_isInitialized) {
            _movienet_detector.reset();
            _isInitialized = false;
        }
        unsigned int numThreads = 1;
        this->_movienet_detector = std::make_unique<MOVINET>(
            onnxModel, TEMPORAL_LENGTH, detectedW, detectedH, 3, numThreads);
        _isInitialized = true;
        return true;
    }
    catch (const std::exception& e) {
        this->_logger.LogFatal("ANSMOVIENET::Initialize", e.what(), __FILE__, __LINE__);
        return false;
    }
}
// Loads a model archive and builds the MOVINET detector from it.
//
// Delegates to ANSODBase::LoadModel first and returns false if that fails.
// Otherwise resolves the ONNX file and input size from _modelFolder, stores
// the size in _modelConfig and creates the detector.
// Returns true on success; false if the base call fails or a std::exception
// is thrown (the exception is logged).
bool ANSMOVIENET::LoadModel(const std::string& modelZipFilePath, const std::string& modelZipPassword) {
    try {
        if (!ANSODBase::LoadModel(modelZipFilePath, modelZipPassword)) {
            return false;
        }

        // Pick the model file and the input size that goes with it.
        int inputW = 0;
        int inputH = 0;
        std::string modelPath = ResolveMovinetModel(_modelFolder, inputW, inputH);
        _modelConfig.inpWidth = inputW;
        _modelConfig.inpHeight = inputH;

        unsigned int threadCount = 1;
        _movienet_detector = std::make_unique<MOVINET>(
            modelPath, TEMPORAL_LENGTH, inputW, inputH, 3, threadCount);
        _isInitialized = true;
        return true;
    }
    catch (const std::exception& e) {
        this->_logger.LogFatal("ANSMOVIENET::LoadModel", e.what(), __FILE__, __LINE__);
        return false;
    }
}
// Loads a model directly from an unpacked folder and builds the detector.
//
// Delegates to ANSODBase::LoadModelFromFolder first and returns false if that
// fails. Then resolves the ONNX file and input size from modelFolder:
// a non-empty modelName (e.g. "Violence0305_A2") selects that exact model,
// otherwise the folder is scanned for the best _A<digit> variant.
// Thresholds below 0.2 are replaced with 0.5.
//
// Returns true on success; false if the base call fails or a std::exception
// is thrown (the exception is logged).
bool ANSMOVIENET::LoadModelFromFolder(std::string licenseKey, ModelConfig modelConfig, std::string modelName, std::string className, const std::string& modelFolder, std::string& labelMap) {
    try {
        bool result = ANSODBase::LoadModelFromFolder(licenseKey, modelConfig, modelName, className, modelFolder, labelMap);
        if (!result) return false;
        _modelConfig = modelConfig;
        _modelConfig.modelType = ModelType::MOVIENET;
        _modelConfig.detectionType = DetectionType::CLASSIFICATION;

        // Resolve model path and input dimensions.
        int detectedW = 0, detectedH = 0;
        std::string onnxModel = ResolveMovinetModel(modelFolder, detectedW, detectedH, modelName);
        _modelConfig.inpWidth = detectedW;
        _modelConfig.inpHeight = detectedH;

        if (_modelConfig.modelMNSThreshold < 0.2)
            _modelConfig.modelMNSThreshold = 0.5;
        if (_modelConfig.modelConfThreshold < 0.2)
            _modelConfig.modelConfThreshold = 0.5;

        // Drop any previous detector before building the new one.
        if (_isInitialized) {
            _movienet_detector.reset();
            _isInitialized = false;
        }
        unsigned int numThreads = 1;
        this->_movienet_detector = std::make_unique<MOVINET>(
            onnxModel, TEMPORAL_LENGTH, detectedW, detectedH, 3, numThreads);
        _isInitialized = true;
        return _isInitialized;
    }
    catch (const std::exception& e) {
        // Log under this function's own name so failures are attributed correctly.
        this->_logger.LogFatal("ANSMOVIENET::LoadModelFromFolder", e.what(), __FILE__, __LINE__);
        return false;
    }
}
// Reports where the (already optimized) model lives; performs no conversion.
//
// fp16                  Not used by this implementation.
// optimizedModelFolder  Output: parent folder of _modelFilePath when the file
//                       exists, otherwise an empty string.
// Returns true if _modelFilePath exists, false otherwise.
bool ANSMOVIENET::OptimizeModel(bool fp16, std::string& optimizedModelFolder) {
    const bool modelPresent = FileExist(_modelFilePath);
    if (!modelPresent) {
        optimizedModelFolder.clear();
        this->_logger.LogFatal("ANSMOVIENET::OptimizeModel", "This model is not exist. Please check the model path again.", __FILE__, __LINE__);
        return false;
    }

    optimizedModelFolder = GetParentFolder(_modelFilePath);
    this->_logger.LogDebug("ANSMOVIENET::OptimizeModel", "This model is optimized. No need other optimization.", __FILE__, __LINE__);
    return true;
}
// Releases the detector and per-camera state via Destroy().
// No exception is allowed to leave the destructor: an escaping exception
// would call std::terminate, so anything thrown is logged and swallowed.
ANSMOVIENET::~ANSMOVIENET() {
    try {
        Destroy();
    }
    catch (const std::exception& e) {
        this->_logger.LogFatal("ANSMOVIENET::Destroy", e.what(), __FILE__, __LINE__);
    }
    catch (...) {
        // Non-std exceptions must not escape either.
        this->_logger.LogFatal("ANSMOVIENET::Destroy", "Unknown exception during destruction", __FILE__, __LINE__);
    }
}
// Tears down the detector and all per-camera state while holding _mutex.
// Returns true on success, false if a std::exception was thrown (logged).
bool ANSMOVIENET::Destroy() {
    try {
        std::lock_guard<std::recursive_mutex> guard(_mutex);

        // Forget every camera's frame window and restart the frame counter.
        _cameraQueues.clear();
        _globalFrameCounter = 0;

        // Release the detector and mark the instance as uninitialized.
        _movienet_detector.reset();
        _isInitialized = false;
    }
    catch (std::exception& e) {
        this->_logger.LogFatal("ANSMOVIENET::Destroy", e.what(), __FILE__, __LINE__);
        return false;
    }
    return true;
}
// Inference functions

// Runs inference for one frame of the given camera and, when the tracker is
// enabled, passes the result through ApplyTracking(); stabilization via
// StabilizeDetections() is applied only if tracking is enabled as well.
std::vector<Object> ANSMOVIENET::RunInference(const cv::Mat& input, const std::string& camera_id) {
    std::vector<Object> detections = Inference(input, camera_id);
    if (!_trackerEnabled) {
        return detections;
    }

    detections = ApplyTracking(detections, camera_id);
    if (_stabilizationEnabled) {
        detections = StabilizeDetections(detections, camera_id);
    }
    return detections;
}
// Convenience overload: runs inference under the fixed camera id
// "MovienetCam". Unlike the two-argument overload, it calls Inference()
// directly and applies no tracking or stabilization.
std::vector<Object> ANSMOVIENET::RunInference(const cv::Mat& input) {
    const std::string defaultCameraId = "MovienetCam";
    return Inference(input, defaultCameraId);
}
// Prunes per-camera frame queues so _cameraQueues cannot grow without bound.
// Invoked from Inference(), which already holds _mutex, so no lock is taken here.
void ANSMOVIENET::CleanupStaleQueues() {
    if (_cameraQueues.empty()) {
        return;
    }

    // Pass 1: evict every queue idle for more than STALE_THRESHOLD frames.
    auto cursor = _cameraQueues.begin();
    while (cursor != _cameraQueues.end()) {
        const int idleFrames = _globalFrameCounter - cursor->second.lastAccessFrame;
        if (idleFrames > STALE_THRESHOLD) {
            cursor = _cameraQueues.erase(cursor);
        }
        else {
            ++cursor;
        }
    }

    // Pass 2: if the hard cap is still exceeded, evict least-recently-used queues.
    const int queueCount = static_cast<int>(_cameraQueues.size());
    if (queueCount <= MAX_QUEUES) {
        return;
    }

    // (lastAccessFrame, key) pairs; sorting puts the oldest access first.
    std::vector<std::pair<int, std::string>> byLastAccess;
    byLastAccess.reserve(_cameraQueues.size());
    for (const auto& [cameraKey, queueState] : _cameraQueues) {
        byLastAccess.emplace_back(queueState.lastAccessFrame, cameraKey);
    }
    std::sort(byLastAccess.begin(), byLastAccess.end());

    const int excess = queueCount - MAX_QUEUES;
    const int candidates = static_cast<int>(byLastAccess.size());
    for (int idx = 0; idx < excess && idx < candidates; ++idx) {
        _cameraQueues.erase(byLastAccess[idx].second);
    }
}
// Feeds one frame into the sliding window of the given camera and, when due,
// runs the MOVINET classifier on that window.
//
// input      Frame to append; it is cloned before being queued.
// camera_id  Key selecting the per-camera frame window (may be empty).
//
// The classifier runs only when the window holds TEMPORAL_LENGTH frames and
// the stride condition is met: on the TEMPORAL_LENGTH-th frame, then every
// _inferenceStride frames after that. On all other calls the result is empty.
//
// Returns at most one Object covering the whole frame (class id, class name,
// confidence). Returns an empty vector if the instance is not initialized,
// the license is invalid, the frame is empty, or an exception occurs
// (errors are logged). The whole call runs under _mutex.
std::vector<Object> ANSMOVIENET::Inference(const cv::Mat& input, const std::string& camera_id) {
    std::lock_guard<std::recursive_mutex> lock(_mutex);
    std::vector<Object> detectedObjects;
    if (!_isInitialized || !_licenseValid || !_movienet_detector) {
        this->_logger.LogError("ANSMOVIENET::Inference",
            "Model is not initialized or license is not valid", __FILE__, __LINE__);
        return detectedObjects;
    }
    if (input.empty()) {
        this->_logger.LogError("ANSMOVIENET::Inference",
            "Input frame is empty", __FILE__, __LINE__);
        return detectedObjects;
    }
    try {
        _globalFrameCounter++;
        std::string cameraKey = camera_id;
        auto& state = _cameraQueues[cameraKey];

        // Update access timestamp (used by CleanupStaleQueues()).
        state.lastAccessFrame = _globalFrameCounter;
        state.frameCount++;

        // Add frame to queue; clone so the window owns its pixel data.
        state.frames.push_back(input.clone());

        // Maintain queue at TEMPORAL_LENGTH.
        while (state.frames.size() > static_cast<size_t>(TEMPORAL_LENGTH)) {
            state.frames.pop_front();
        }

        // A zero stride would make the modulo below undefined behaviour
        // (division by zero); treat it as "run on every frame".
        const int configuredStride = static_cast<int>(_inferenceStride);
        const int stride = (configuredStride != 0) ? configuredStride : 1;

        // Run inference when:
        // 1. Full window available
        // 2. Stride condition met
        const bool hasFullWindow = (state.frames.size() == static_cast<size_t>(TEMPORAL_LENGTH));
        const bool strideReady = (state.frameCount == TEMPORAL_LENGTH)
            || (state.frameCount > TEMPORAL_LENGTH
                && (state.frameCount - TEMPORAL_LENGTH) % stride == 0);
        if (hasFullWindow && strideReady) {
            std::pair<int, float> result;
            _movienet_detector->inference(state.frames, result);
            if (result.first >= 0) {
                Object obj;
                obj.classId = result.first;
                // Class ids outside the label list are reported as "Unknown".
                obj.className = (result.first < static_cast<int>(_classes.size())) ?
                    _classes[result.first] : "Unknown";
                obj.confidence = result.second;
                // The classification applies to the whole frame.
                obj.box = cv::Rect(0, 0, input.cols, input.rows);
                detectedObjects.push_back(obj);
            }
        }

        // Store in base class camera data (bounded to QUEUE_SIZE entries).
        CameraData& cameraData = GetCameraData(cameraKey);
        cameraData._detectionQueue.push_back(detectedObjects);
        if (cameraData._detectionQueue.size() > QUEUE_SIZE) {
            cameraData._detectionQueue.pop_front();
        }

        // ----- Periodic self-cleanup -----
        // Done last: it may erase entries of _cameraQueues, which would
        // invalidate `state`.
        if (_globalFrameCounter % CLEANUP_INTERVAL == 0) {
            CleanupStaleQueues();
        }
    }
    catch (const std::exception& e) {
        this->_logger.LogError("ANSMOVIENET::Inference",
            std::string("Exception during inference: ") + e.what(), __FILE__, __LINE__);
        detectedObjects.clear();
    }
    catch (...) {
        this->_logger.LogError("ANSMOVIENET::Inference",
            "Unknown exception during inference", __FILE__, __LINE__);
        detectedObjects.clear();
    }
    return detectedObjects;
}
}