Files
ANSCustomModels/tests/TestCommon.h

117 lines
4.0 KiB
C
Raw Normal View History

2026-04-05 14:30:43 +10:00
#pragma once
#include <chrono>
#include <filesystem>
#include <sstream>
#include <string>
#include <system_error>
#include <vector>

#include <gtest/gtest.h>
#include <opencv2/opencv.hpp>

#include "ANSLIB.h"
// ---------------------------------------------------------------------------
// Model directory paths — update these to match your local environment
// ---------------------------------------------------------------------------
namespace TestConfig {
// Absolute, developer-machine-specific locations of the demo model folders and
// sample videos. ModelExists()/VideoExists() below let callers probe for these
// assets before trying to load them.
inline const std::string FIRE_SMOKE_MODEL_DIR =
"C:\\Programs\\DemoAssets\\ModelsForANSVIS\\ANS_FireSmoke_v2.0";
inline const std::string HELMET_MODEL_DIR =
"C:\\Programs\\DemoAssets\\ModelsForANSVIS\\ANS_Helmet(GPU)_v1.0";
inline const std::string WEAPON_MODEL_DIR =
"C:\\Programs\\DemoAssets\\ModelsForANSVIS\\ANS_WeaponDetection(GPU)_1.0";
inline const std::string FIRE_SMOKE_VIDEO =
"C:\\Programs\\DemoAssets\\Videos\\FireNSmoke\\ANSFireFull.mp4";
inline const std::string HELMET_VIDEO =
"C:\\Programs\\DemoAssets\\Videos\\Helmet\\HM2.mp4";
inline const std::string WEAPON_VIDEO =
"C:\\Programs\\DemoAssets\\Videos\\Weapon\\AK47 Glock.mp4";

// Returns true if `path` names an existing directory (e.g. a model folder).
// Uses the non-throwing std::error_code overload: an unreadable or invalid
// path yields false instead of throwing std::filesystem::filesystem_error.
inline bool ModelExists(const std::string& path) {
    std::error_code ec;
    // is_directory() already implies existence, so no separate exists() probe.
    return std::filesystem::is_directory(path, ec);
}

// Returns true if `path` names an existing non-directory entry (e.g. a video
// file). Directories are rejected because they cannot be opened as a video.
// Like ModelExists(), any filesystem error is reported as false, not thrown.
inline bool VideoExists(const std::string& path) {
    std::error_code ec;
    // One status() call serves both checks instead of stat-ing the path twice.
    const std::filesystem::file_status st = std::filesystem::status(path, ec);
    return !ec && std::filesystem::exists(st) && !std::filesystem::is_directory(st);
}
} // namespace TestConfig
// ---------------------------------------------------------------------------
// Helper utilities
// ---------------------------------------------------------------------------
namespace TestUtils {
// Split a comma-separated label map (e.g. "fire,smoke") into class names.
// Entries are returned verbatim (no whitespace trimming). Empty fields between
// consecutive commas are kept as "". A single trailing comma does not add an
// extra entry, and an empty input yields an empty vector.
inline std::vector<std::string> ParseLabelMap(const std::string& labelMap) {
    std::vector<std::string> classes;
    std::istringstream ss(labelMap);
    std::string item;
    while (std::getline(ss, item, ',')) {
        // getline() clears `item` before refilling it, so moving from it is safe.
        classes.push_back(std::move(item));
    }
    return classes;
}

// Create a solid-color test frame (no model required).
// Returns a height x width, 3-channel 8-bit (CV_8UC3) image filled with
// `color` (mid-gray by default).
inline cv::Mat CreateTestFrame(int width, int height, cv::Scalar color = cv::Scalar(128, 128, 128)) {
    return cv::Mat(height, width, CV_8UC3, color);
}

// Create a frame with a bright orange-red region to simulate fire-like colors.
// Dark-gray background; the centered rectangle spans half the width and height.
inline cv::Mat CreateFireLikeFrame(int width, int height) {
    cv::Mat frame(height, width, CV_8UC3, cv::Scalar(50, 50, 50));
    cv::Rect fireRegion(width / 4, height / 4, width / 2, height / 2);
    frame(fireRegion) = cv::Scalar(0, 80, 255); // BGR: orange-red
    return frame;
}

// Create a frame with a light-gray haze region to simulate smoke-like colors.
// Near-black background; the centered rectangle spans half the width and height.
inline cv::Mat CreateSmokeLikeFrame(int width, int height) {
    cv::Mat frame(height, width, CV_8UC3, cv::Scalar(30, 30, 30));
    cv::Rect smokeRegion(width / 4, height / 4, width / 2, height / 2);
    frame(smokeRegion) = cv::Scalar(180, 180, 190); // BGR: light gray
    return frame;
}

// Invoke `func` once and return its wall-time duration in milliseconds.
// Uses steady_clock because it is guaranteed monotonic. high_resolution_clock
// may alias system_clock, which can jump when the system time is adjusted and
// would yield bogus (even negative) durations.
template <typename Func>
double MeasureMs(Func&& func) {
    const auto start = std::chrono::steady_clock::now();
    func();
    const auto end = std::chrono::steady_clock::now();
    return std::chrono::duration<double, std::milli>(end - start).count();
}

// Run up to `maxFrames` frames of `videoPath` through `detector`.
// `detector.RunInference(frame)` must return something with a size()
// (presumably a container of detections — confirm against detector types).
// Returns {totalDetections, averageMsPerFrame}. Returns {-1, 0.0} if the video
// cannot be opened. Returns {0, 0.0} if no frame was processed (e.g. an empty
// video or maxFrames <= 0). The timing covers the full RunInference call,
// including destruction of its result.
template <typename Detector>
std::pair<int, double> RunVideoFrames(Detector& detector, const std::string& videoPath, int maxFrames) {
    cv::VideoCapture cap(videoPath);
    if (!cap.isOpened()) return { -1, 0.0 };
    int totalDetections = 0;
    double totalMs = 0.0;
    int frameCount = 0;
    while (frameCount < maxFrames) {
        // A fresh Mat each iteration makes read() allocate a new buffer rather
        // than overwrite pixels the detector might still share via a shallow
        // cv::Mat copy.
        cv::Mat frame;
        if (!cap.read(frame)) break;
        totalMs += MeasureMs([&]() {
            auto results = detector.RunInference(frame);
            totalDetections += static_cast<int>(results.size());
        });
        ++frameCount;
    }
    // `cap` releases the video source in its destructor; no explicit release().
    const double avgMs = (frameCount > 0) ? totalMs / frameCount : 0.0;
    return { totalDetections, avgMs };
}
} // namespace TestUtils