diff --git a/.claude/settings.local.json b/.claude/settings.local.json index 03457ee..e7aec9a 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -9,7 +9,10 @@ "Bash(cp \"C:/Projects/ANLS/ANSLIB/ANSCustomFireNSmokeDetection/dllmain.cpp\" \"C:/Projects/CLionProjects/ANSCustomModels/ANSCustomFireNSmokeDetection/dllmain.cpp\")", "Bash(cp \"C:/Projects/ANLS/ANSLIB/ANSCustomFireNSmokeDetection/framework.h\" \"C:/Projects/CLionProjects/ANSCustomModels/ANSCustomFireNSmokeDetection/framework.h\")", "Bash(cp \"C:/Projects/ANLS/ANSLIB/ANSCustomFireNSmokeDetection/pch.h\" \"C:/Projects/CLionProjects/ANSCustomModels/ANSCustomFireNSmokeDetection/pch.h\")", - "Bash(cp \"C:/Projects/ANLS/ANSLIB/ANSCustomFireNSmokeDetection/pch.cpp\" \"C:/Projects/CLionProjects/ANSCustomModels/ANSCustomFireNSmokeDetection/pch.cpp\")" + "Bash(cp \"C:/Projects/ANLS/ANSLIB/ANSCustomFireNSmokeDetection/pch.cpp\" \"C:/Projects/CLionProjects/ANSCustomModels/ANSCustomFireNSmokeDetection/pch.cpp\")", + "Read(//c/ANSLibs/opencv/x64/vc17/**)", + "Read(//c/ANSLibs/opencv/**)", + "Bash(find C:/ANSLibs -name *.dll -type f)" ] } } diff --git a/ANSCustomFireNSmokeDetection/ANSCustomFireNSmoke.cpp b/ANSCustomFireNSmokeDetection/ANSCustomFireNSmoke.cpp index 23c811e..72a68d9 100644 --- a/ANSCustomFireNSmokeDetection/ANSCustomFireNSmoke.cpp +++ b/ANSCustomFireNSmokeDetection/ANSCustomFireNSmoke.cpp @@ -56,7 +56,109 @@ void ANSCustomFS::ResetDetectionState() { ResetDetectedArea(); _retainDetectedArea = 0; _isFireNSmokeDetected = false; + _trackHistory.clear(); } +// --- Tracker-based voting functions --- + +void ANSCustomFS::UpdateTrackHistory(int trackId, int classId, const cv::Rect& bbox) { + auto it = _trackHistory.find(trackId); + if (it == _trackHistory.end()) { + TrackRecord record; + record.trackId = trackId; + record.classId = classId; + record.bboxHistory.push_back(bbox); + record.detectedCount = 1; + record.totalFrames = 1; + record.confirmed = false; + _trackHistory[trackId] = std::move(record); + } + else { + auto& record = it->second; + record.bboxHistory.push_back(bbox); + record.detectedCount++; + record.totalFrames++; + + // Slide the window: keep only VOTE_WINDOW entries in bbox history + while (static_cast(record.bboxHistory.size()) > VOTE_WINDOW) { + record.bboxHistory.pop_front(); + record.detectedCount = std::max(0, record.detectedCount - 1); + } + + // Update confirmed status + if (record.detectedCount >= VOTE_THRESHOLD) { + record.confirmed = true; + } + } +} + +bool ANSCustomFS::IsTrackConfirmed(int trackId) const { + auto it = _trackHistory.find(trackId); + if (it == _trackHistory.end()) return false; + return it->second.detectedCount >= VOTE_THRESHOLD; +} + +bool ANSCustomFS::HasBboxMovement(int trackId) const { + auto it = _trackHistory.find(trackId); + if (it == _trackHistory.end()) return false; + + const auto& history = it->second.bboxHistory; + if (history.size() < 2) return false; + + const cv::Rect& earliest = history.front(); + const cv::Rect& latest = history.back(); + + // Calculate center positions + float cx1 = earliest.x + earliest.width / 2.0f; + float cy1 = earliest.y + earliest.height / 2.0f; + float cx2 = latest.x + latest.width / 2.0f; + float cy2 = latest.y + latest.height / 2.0f; + + // Average size for normalization + float avgWidth = (earliest.width + latest.width) / 2.0f; + float avgHeight = (earliest.height + latest.height) / 2.0f; + if (avgWidth < 1.0f || avgHeight < 1.0f) return false; + + // Position change relative to average size + float posChange = std::sqrt( + std::pow((cx2 - cx1) / avgWidth, 2) + + std::pow((cy2 - cy1) / avgHeight, 2) + ); + + // Size change relative to average area + float area1 = static_cast(earliest.area()); + float area2 = static_cast(latest.area()); + float avgArea = (area1 + area2) / 2.0f; + float sizeChange = (avgArea > 0) ? std::abs(area2 - area1) / avgArea : 0.0f; + + return (posChange > BBOX_CHANGE_THRESHOLD) || (sizeChange > BBOX_CHANGE_THRESHOLD); +} + +void ANSCustomFS::AgeTracks(const std::unordered_set& detectedTrackIds) { + auto it = _trackHistory.begin(); + while (it != _trackHistory.end()) { + if (detectedTrackIds.find(it->first) == detectedTrackIds.end()) { + // Track was NOT detected this frame + it->second.totalFrames++; + + // Remove stale tracks that haven't been seen recently + if (it->second.totalFrames > VOTE_WINDOW && + it->second.detectedCount == 0) { + it = _trackHistory.erase(it); + continue; + } + + // Age out the sliding window (add a "miss" frame) + if (static_cast(it->second.bboxHistory.size()) >= VOTE_WINDOW) { + it->second.bboxHistory.pop_front(); + it->second.detectedCount = std::max(0, it->second.detectedCount - 1); + } + } + ++it; + } +} + +// --- End voting functions --- + void ANSCustomFS::UpdateNoDetectionCondition() { _isRealFireFrame = false; @@ -120,152 +222,6 @@ void ANSCustomFS::GetModelParameters() { _readROIs = true; } } -std::vector ANSCustomFS::ProcessExistingDetectedArea( - const cv::Mat& frame, - const std::string& camera_id, - const std::vector& fireNSmokeRects, - cv::Mat& draw) -{ -#ifdef FNS_DEBUG - cv::rectangle(draw, _detectedArea, cv::Scalar(255, 255, 0), 2); // Cyan -#endif - - // Run detection on ROI (no clone - just a view into frame) - cv::Mat activeROI = frame(_detectedArea); - std::vector detectedObjects; - _detector->RunInference(activeROI, camera_id.c_str(), detectedObjects); - - if (detectedObjects.empty()) { - UpdateNoDetectionCondition(); - return {}; - } - - std::vector output; - output.reserve(detectedObjects.size()); - - for (auto& detectedObj : detectedObjects) { - ProcessDetectedObject(frame, detectedObj, camera_id, fireNSmokeRects, output, draw); - } - - if (output.empty()) { - UpdateNoDetectionCondition(); - } - - return output; -} -bool ANSCustomFS::ProcessDetectedObject( - const cv::Mat& frame, - ANSCENTER::Object& detectedObj, - const std::string& camera_id, - const std::vector& fireNSmokeRects, - std::vector& output, cv::Mat& draw) -{ - // Adjust coordinates to frame space - detectedObj.box.x += _detectedArea.x; - detectedObj.box.y += _detectedArea.y; - detectedObj.cameraId = camera_id; - - // Check exclusive ROI overlap - if (IsROIOverlapping(detectedObj.box, _exclusiveROIs, INCLUSIVE_IOU_THRESHOLD)) { - return false; - } - - // Check confidence threshold - if (detectedObj.confidence <= _detectionScoreThreshold) { - UpdateNoDetectionCondition(); - return false; - } - - // Check if fire or smoke - if (!IsFireOrSmoke(detectedObj.classId, detectedObj.confidence)) { - UpdateNoDetectionCondition(); - return false; - } - - // Check for reflection - cv::Mat objectMask = frame(detectedObj.box); - if (detectReflection(objectMask)) { - UpdateNoDetectionCondition(); - return false; - } - - // Check area overlap - float areaOverlap = calculateIoU(_detectedArea, detectedObj.box); - if (areaOverlap >= MAX_AREA_OVERLAP) { - UpdateNoDetectionCondition(); - return false; - } - -#ifdef FNS_DEBUG - cv::Scalar color = (detectedObj.classId == 0) ? - cv::Scalar(0, 255, 255) : cv::Scalar(255, 0, 255); // Yellow/Purple - cv::rectangle(draw, detectedObj.box, color, 2); -#endif - - // Check motion correlation - if (!ValidateMotionCorrelation(fireNSmokeRects)) { - UpdateNoDetectionCondition(); - return false; - } - - if (!IsOverlapping(detectedObj, fireNSmokeRects, 0)) { - UpdateNoDetectionCondition(); - return false; - } - - // Filter validation - if (!ValidateWithFilter(frame, detectedObj, camera_id, output, draw)) - { - return false; - } - - return true; -} -bool ANSCustomFS::ValidateWithFilter( - const cv::Mat& frame, - const ANSCENTER::Object& detectedObj, - const std::string& camera_id, - std::vector& output, cv::Mat& draw) -{ - // Skip filter check after sufficient confirmation frames - if (_realFireCheck > FILTER_VERIFICATION_FRAMES) { - output.push_back(detectedObj); - UpdatePositiveDetection(); - return true; - } - - // Run filter inference - std::vector filteredObjects; - _filter->RunInference(frame, camera_id.c_str(), filteredObjects); - std::vector excludedObjects; - - for (const auto& filteredObj : filteredObjects) { - if (EXCLUDED_FILTER_CLASSES.find(filteredObj.classId) == EXCLUDED_FILTER_CLASSES.end()) { - excludedObjects.push_back(filteredObj); - -#ifdef FNS_DEBUG - cv::rectangle(draw, filteredObj.box, cv::Scalar(0, 255, 0), 2); - cv::putText(draw, filteredObj.className, - cv::Point(filteredObj.box.x, filteredObj.box.y - 10), - cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 255, 0), 2); -#endif - } - } - - // Check if detection overlaps with excluded objects - if (excludedObjects.empty() || !IsOverlapping(detectedObj, excludedObjects, 0)) { - output.push_back(detectedObj); - UpdatePositiveDetection(); - _realFireCheck++; - return true; - } - else { - // Decrement but don't go negative - _realFireCheck = std::max(0, _realFireCheck - 1); - _isRealFireFrame = (_realFireCheck > 0); - return false; - } -} std::vector ANSCustomFS::FindNewDetectedArea( const cv::Mat& frame, @@ -576,6 +532,12 @@ bool ANSCustomFS::Initialize(const std::string& modelDirectory, float detectionS return false; } + // Enable ByteTrack tracker on the detector for persistent track IDs + int trackerResult = _detector->SetTracker(0 /*BYTETRACK*/, 1 /*enable*/); + if (trackerResult != 1) { + std::cerr << "ANSCustomFS::Initialize: Warning - Failed to enable ByteTrack tracker." << std::endl; + } + // Load filter model (COCO general object detector for false positive filtering) float filterScoreThreshold = 0.25f; float filterConfThreshold = 0.5f; @@ -1031,7 +993,7 @@ std::vector ANSCustomFS::RunInference(const cv::Mat& input, const } } -// New helper function to process detected area +// Stage A: Process existing detected area with tracker-based voting std::vector ANSCustomFS::ProcessExistingDetectedArea( const cv::Mat& frame, const std::string& camera_id, @@ -1041,18 +1003,19 @@ std::vector ANSCustomFS::ProcessExistingDetectedArea( cv::Mat activeROI = frame(_detectedArea); - // Detect movement and objects - std::vector movementObjects = FindMovementObjects(frame, camera_id, draw); + // Run detector on ROI — tracker assigns persistent trackIds std::vector detectedObjects; _detector->RunInference(activeROI, camera_id.c_str(), detectedObjects); if (detectedObjects.empty()) { + // Age all existing tracks (missed frame for all) + AgeTracks({}); UpdateNoDetectionCondition(); return output; } - const bool skipMotionCheck = (_motionSpecificity < 0.0f) || (_motionSpecificity >= 1.0f); - const bool validMovement = !movementObjects.empty() && movementObjects.size() < MAX_MOTION_TRACKING; + // Collect detected track IDs this frame for aging + std::unordered_set detectedTrackIds; for (auto& detectedObj : detectedObjects) { // Adjust coordinates to full frame @@ -1060,84 +1023,75 @@ std::vector ANSCustomFS::ProcessExistingDetectedArea( detectedObj.box.y += _detectedArea.y; detectedObj.cameraId = camera_id; - // Skip if overlapping with exclusive ROIs + // 1. Exclusive ROI check — skip if overlapping exclusion zones if (IsROIOverlapping(detectedObj.box, _exclusiveROIs, INCLUSIVE_IOU_THRESHOLD)) { continue; } - // Check confidence thresholds + // 2. Confidence check — fire >= threshold, smoke >= smoke threshold const bool isValidFire = (detectedObj.classId == 0) && (detectedObj.confidence >= _detectionScoreThreshold); const bool isValidSmoke = (detectedObj.classId == 2) && (detectedObj.confidence >= _smokeDetetectionThreshold); if (!isValidFire && !isValidSmoke) { - UpdateNoDetectionCondition(); continue; } - // Check area overlap - const float area_threshold = calculateIoU(_detectedArea, detectedObj.box); - if (area_threshold >= MAX_AREA_OVERLAP) { - UpdateNoDetectionCondition(); - continue; - } + // 3. Update track history with this detection + int trackId = detectedObj.trackId; + detectedTrackIds.insert(trackId); + UpdateTrackHistory(trackId, detectedObj.classId, detectedObj.box); #ifdef FNS_DEBUG + // Draw detection with track info cv::Scalar color = (detectedObj.classId == 0) ? cv::Scalar(0, 255, 255) : cv::Scalar(255, 0, 255); cv::rectangle(draw, detectedObj.box, color, 2); + auto trackIt = _trackHistory.find(trackId); + if (trackIt != _trackHistory.end()) { + std::string label = "T" + std::to_string(trackId) + " " + + std::to_string(trackIt->second.detectedCount) + "/" + + std::to_string(VOTE_THRESHOLD); + cv::putText(draw, label, + cv::Point(detectedObj.box.x, detectedObj.box.y - 10), + cv::FONT_HERSHEY_SIMPLEX, 0.5, color, 2); + } #endif - // Check motion - if (!skipMotionCheck && !validMovement) { - UpdateNoDetectionCondition(); + // 4. Voting check — require consistent detection across frames + if (!IsTrackConfirmed(trackId)) { continue; } - if (!skipMotionCheck && !IsOverlapping(detectedObj, movementObjects, 0)) { - UpdateNoDetectionCondition(); + // 5. Movement check — verify bounding box is changing (not static false positive) + if (!HasBboxMovement(trackId)) { continue; } - // Process valid detection - if (!ProcessValidDetection(frame, camera_id, draw, detectedObj, output)) { + // 6. COCO filter — exclude detections that overlap with known non-fire objects + std::vector excludedObjects = FindExcludedObjects(frame, camera_id, draw); + if (!excludedObjects.empty() && IsOverlapping(detectedObj, excludedObjects, 0)) { + // Detection overlaps with a known object — not fire/smoke + _realFireCheck = std::max(0, _realFireCheck - 1); + if (_realFireCheck <= 0) { + _isRealFireFrame = false; + } continue; } + + // All checks passed — confirmed detection + AddConfirmedDetection(detectedObj, output); + _realFireCheck++; + } + + // Age out tracks not seen this frame + AgeTracks(detectedTrackIds); + + if (output.empty()) { + UpdateNoDetectionCondition(); } return output; } -bool ANSCustomFS::ProcessValidDetection( - const cv::Mat& frame, - const std::string& camera_id, - cv::Mat& draw, - ANSCENTER::Object& detectedObj, - std::vector& output) -{ - if (_realFireCheck > FILTERFRAMES) { - AddConfirmedDetection(detectedObj, output); - return true; - } - - std::vector excludedObjects = FindExcludedObjects(frame, camera_id, draw); - - if (excludedObjects.empty()) { - AddConfirmedDetection(detectedObj, output); - return true; - } - - if (!IsOverlapping(detectedObj, excludedObjects, 0)) { - AddConfirmedDetection(detectedObj, output); - _realFireCheck++; - return true; - } - - _realFireCheck = std::max(0, _realFireCheck - 1); - if (_realFireCheck <= 0) { - _isRealFireFrame = false; - } - return false; -} - void ANSCustomFS::AddConfirmedDetection(ANSCENTER::Object& detectedObj, std::vector& output) { output.push_back(std::move(detectedObj)); _isFireNSmokeDetected = true; diff --git a/ANSCustomFireNSmokeDetection/ANSCustomFireNSmoke.h b/ANSCustomFireNSmokeDetection/ANSCustomFireNSmoke.h index 0463f46..5f98e50 100644 --- a/ANSCustomFireNSmokeDetection/ANSCustomFireNSmoke.h +++ b/ANSCustomFireNSmokeDetection/ANSCustomFireNSmoke.h @@ -1,6 +1,7 @@ #include "ANSLIB.h" #include #include +#include #define RETAINFRAMES 80 #define FILTERFRAMES 10 @@ -13,6 +14,22 @@ class CUSTOM_API ANSCustomFS : public IANSCustomClass int priority; ImageSection(const cv::Rect& r) : region(r), priority(0) {} }; + + // Track record for voting-based detection confirmation + struct TrackRecord { + int trackId{ 0 }; + int classId{ 0 }; // fire=0, smoke=2 + std::deque bboxHistory; // bounding box history within window + int detectedCount{ 0 }; // frames detected in sliding window + int totalFrames{ 0 }; // total frames since track appeared + bool confirmed{ false }; // passed voting threshold + }; + + // Voting mechanism constants + static constexpr int VOTE_WINDOW = 15; + static constexpr int VOTE_THRESHOLD = 8; + static constexpr float BBOX_CHANGE_THRESHOLD = 0.05f; + private: using ANSLIBPtr = std::unique_ptr; @@ -74,6 +91,9 @@ private: float _smokeDetetectionThreshold{ 0 }; float _motionSpecificity{ 0 }; + // Tracker-based voting state + std::unordered_map _trackHistory; + cv::Rect GenerateMinimumSquareBoundingBox(const std::vector& detectedObjects, int minSize = 640); void UpdateNoDetectionCondition(); bool detectStaticFire(std::deque& frameQueue); @@ -101,7 +121,6 @@ private: void ResetDetectionState(); void GetModelParameters(); std::vector ProcessExistingDetectedArea(const cv::Mat& frame, const std::string& camera_id, cv::Mat& draw); - bool ProcessValidDetection(const cv::Mat& frame, const std::string& camera_id, cv::Mat& draw, ANSCENTER::Object& detectedObj, std::vector& output); void AddConfirmedDetection(ANSCENTER::Object& detectedObj, std::vector& output); #ifdef FNS_DEBUG void DisplayDebugFrame(cv::Mat& draw) { @@ -117,21 +136,11 @@ private: int getLowestPriorityRegion(); cv::Rect getRegionByPriority(int priority); - std::vector ProcessExistingDetectedArea( - const cv::Mat& frame, - const std::string& camera_id, - const std::vector& fireNSmokeRects, cv::Mat& draw); - bool ProcessDetectedObject( - const cv::Mat& frame, - ANSCENTER::Object& detectedObj, - const std::string& camera_id, - const std::vector& fireNSmokeRects, - std::vector& output, cv::Mat& draw); - bool ValidateWithFilter( - const cv::Mat& frame, - const ANSCENTER::Object& detectedObj, - const std::string& camera_id, - std::vector& output, cv::Mat& draw); + // Tracker-based voting methods + void UpdateTrackHistory(int trackId, int classId, const cv::Rect& bbox); + bool IsTrackConfirmed(int trackId) const; + bool HasBboxMovement(int trackId) const; + void AgeTracks(const std::unordered_set& detectedTrackIds); std::vector FindNewDetectedArea( const cv::Mat& frame, const std::string& camera_id, cv::Mat& draw); diff --git a/CMakeLists.txt b/CMakeLists.txt index e1b2b4c..4defbc5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,3 +7,9 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) add_subdirectory(ANSCustomHelmetDetection) add_subdirectory(ANSCustomFireNSmokeDetection) add_subdirectory(ANSCustomWeaponDetection) + +# Unit & integration tests (Google Test) +option(BUILD_TESTS "Build unit tests" ON) +if(BUILD_TESTS) + add_subdirectory(tests) +endif() diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt new file mode 100644 index 0000000..71e2ec6 --- /dev/null +++ b/tests/CMakeLists.txt @@ -0,0 +1,31 @@ +project(ANSCustomModels_Tests LANGUAGES CXX) + +# ---------- Google Test (fetched once, shared by all sub-projects) ---------- +include(FetchContent) +FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG v1.14.0 +) +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +FetchContent_MakeAvailable(googletest) + +enable_testing() + +# ---------- Common paths (propagated to sub-projects via variables) ---------- +set(ANSLIB_INCLUDE_DIR "C:/Projects/ANLS/ANSLIB/ANSLIB" CACHE PATH "") +set(OPENCV_INCLUDE_DIR "C:/ANSLibs/opencv/include" CACHE PATH "") +set(ANSLIB_LIB_DIR "C:/ProgramData/ANSCENTER/Shared" CACHE PATH "") +set(OPENCV_LIB_DIR "C:/ANSLibs/opencv/x64/vc17/lib" CACHE PATH "") +set(OPENCV_BIN_DIR "C:/ProgramData/ANSCENTER/Shared" CACHE PATH "") +set(TEST_COMMON_DIR "${CMAKE_CURRENT_SOURCE_DIR}" CACHE PATH "") + +# ---------- Place all test .exe files alongside the DLLs they need ---------- +# This ensures custom model DLLs (built by sibling projects) land in the same +# directory as the test executables so Windows can find them at runtime. +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin" CACHE PATH "" FORCE) + +# ---------- Sub-project test executables ---------- +add_subdirectory(FireNSmokeDetection) +add_subdirectory(HelmetDetection) +add_subdirectory(WeaponDetection) diff --git a/tests/FireNSmokeDetection/CMakeLists.txt b/tests/FireNSmokeDetection/CMakeLists.txt new file mode 100644 index 0000000..dc66467 --- /dev/null +++ b/tests/FireNSmokeDetection/CMakeLists.txt @@ -0,0 +1,54 @@ +project(FireNSmokeDetection_Tests LANGUAGES CXX) + +add_executable(${PROJECT_NAME} + FireNSmokeDetectionTest.cpp +) + +target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_17) + +target_compile_definitions(${PROJECT_NAME} PRIVATE + WIN32_LEAN_AND_MEAN + NOMINMAX + $<$:_DEBUG> + $<$:NDEBUG> +) + +target_include_directories(${PROJECT_NAME} PRIVATE + ${TEST_COMMON_DIR} + ${ANSLIB_INCLUDE_DIR} + ${OPENCV_INCLUDE_DIR} + ${CMAKE_SOURCE_DIR}/ANSCustomFireNSmokeDetection +) + +target_link_directories(${PROJECT_NAME} PRIVATE + ${ANSLIB_LIB_DIR} + ${OPENCV_LIB_DIR} +) + +target_link_libraries(${PROJECT_NAME} PRIVATE + gtest + gtest_main + ANSLIB + opencv_world4130 + ANSCustomFireNSmokeDetection +) + +if(MSVC) + target_compile_options(${PROJECT_NAME} PRIVATE /W3 /sdl /permissive-) +endif() + +# Copy required DLLs next to the test executable so Windows can find them +add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD + # ANSLIB.dll + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${ANSLIB_LIB_DIR}/ANSLIB.dll" + "$" + # OpenCV DLL + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${OPENCV_BIN_DIR}/opencv_world4130.dll" + "$" + COMMENT "Copying runtime DLLs for ${PROJECT_NAME}" +) + +include(GoogleTest) +gtest_discover_tests(${PROJECT_NAME} DISCOVERY_MODE PRE_TEST) diff --git a/tests/FireNSmokeDetection/FireNSmokeDetectionTest.cpp b/tests/FireNSmokeDetection/FireNSmokeDetectionTest.cpp new file mode 100644 index 0000000..a472e90 --- /dev/null +++ b/tests/FireNSmokeDetection/FireNSmokeDetectionTest.cpp @@ -0,0 +1,192 @@ +#include "TestCommon.h" +#include "ANSCustomFireNSmoke.h" + +// =========================================================================== +// Unit Tests — no model files required +// =========================================================================== + +class FireNSmokeUnitTest : public ::testing::Test { +protected: + ANSCustomFS detector; +}; + +TEST_F(FireNSmokeUnitTest, EmptyFrameReturnsNoDetections) { + cv::Mat empty; + auto results = detector.RunInference(empty); + EXPECT_TRUE(results.empty()); +} + +TEST_F(FireNSmokeUnitTest, TinyFrameReturnsNoDetections) { + cv::Mat tiny = TestUtils::CreateTestFrame(5, 5); + auto results = detector.RunInference(tiny); + EXPECT_TRUE(results.empty()); +} + +TEST_F(FireNSmokeUnitTest, UninitializedDetectorReturnsNoDetections) { + cv::Mat frame = TestUtils::CreateTestFrame(640, 480); + auto results = detector.RunInference(frame); + EXPECT_TRUE(results.empty()); +} + +TEST_F(FireNSmokeUnitTest, RunInferenceWithCameraId) { + cv::Mat frame = TestUtils::CreateTestFrame(640, 480); + auto results = detector.RunInference(frame, "test_cam_01"); + EXPECT_TRUE(results.empty()); +} + +TEST_F(FireNSmokeUnitTest, ConfigureParametersReturnsValidConfig) { + CustomParams params; + bool result = detector.ConfigureParameters(params); + EXPECT_TRUE(result); + + // Should have ExclusiveROIs ROI config + ASSERT_FALSE(params.ROI_Config.empty()); + EXPECT_EQ(params.ROI_Config[0].Name, "ExclusiveROIs"); + EXPECT_TRUE(params.ROI_Config[0].Rectangle); + EXPECT_FALSE(params.ROI_Config[0].Polygon); + EXPECT_FALSE(params.ROI_Config[0].Line); + EXPECT_EQ(params.ROI_Config[0].MinItems, 0); + EXPECT_EQ(params.ROI_Config[0].MaxItems, 20); + + // Should have SmokeScore and Sensitivity parameters + ASSERT_GE(params.Parameters.size(), 2u); + + bool hasSmokeScore = false; + bool hasSensitivity = false; + for (const auto& p : params.Parameters) { + if (p.Name == "SmokeScore") { + hasSmokeScore = true; + EXPECT_EQ(p.DataType, "float"); + EXPECT_EQ(p.MaxValue, 1); + EXPECT_EQ(p.MinValue, 0); + } + if (p.Name == "Sensitivity") { + hasSensitivity = true; + EXPECT_EQ(p.DataType, "float"); + } + } + EXPECT_TRUE(hasSmokeScore) << "Missing SmokeScore parameter"; + EXPECT_TRUE(hasSensitivity) << "Missing Sensitivity parameter"; +} + +TEST_F(FireNSmokeUnitTest, DestroySucceeds) { + EXPECT_TRUE(detector.Destroy()); +} + +TEST_F(FireNSmokeUnitTest, DestroyCanBeCalledMultipleTimes) { + EXPECT_TRUE(detector.Destroy()); + EXPECT_TRUE(detector.Destroy()); +} + +TEST_F(FireNSmokeUnitTest, InitializeWithInvalidDirectoryFails) { + std::string labelMap; + bool result = detector.Initialize("C:\\NonExistent\\Path\\Model", 0.5f, labelMap); + EXPECT_FALSE(result); +} + +TEST_F(FireNSmokeUnitTest, OptimizeBeforeInitializeReturnsFalse) { + EXPECT_FALSE(detector.OptimizeModel(true)); +} + +// =========================================================================== +// Integration Tests — require model files on disk +// =========================================================================== + +class FireNSmokeIntegrationTest : public ::testing::Test { +protected: + ANSCustomFS detector; + std::string labelMap; + std::vector classes; + + void SetUp() override { + if (!TestConfig::ModelExists(TestConfig::FIRE_SMOKE_MODEL_DIR)) { + GTEST_SKIP() << "Fire/Smoke model not found at: " << TestConfig::FIRE_SMOKE_MODEL_DIR; + } + bool ok = detector.Initialize(TestConfig::FIRE_SMOKE_MODEL_DIR, 0.5f, labelMap); + ASSERT_TRUE(ok) << "Failed to initialize Fire/Smoke detector"; + classes = TestUtils::ParseLabelMap(labelMap); + } + + void TearDown() override { + detector.Destroy(); + } +}; + +TEST_F(FireNSmokeIntegrationTest, InitializeProducesLabelMap) { + EXPECT_FALSE(labelMap.empty()); + EXPECT_FALSE(classes.empty()); +} + +TEST_F(FireNSmokeIntegrationTest, InferenceOnSolidFrameReturnsNoDetections) { + cv::Mat frame = TestUtils::CreateTestFrame(1920, 1080); + auto results = detector.RunInference(frame, "test_cam"); + EXPECT_TRUE(results.empty()) << "Solid gray frame should not trigger fire/smoke"; +} + +TEST_F(FireNSmokeIntegrationTest, InferenceOnSmallFrame) { + cv::Mat frame = TestUtils::CreateTestFrame(320, 240); + auto results = detector.RunInference(frame, "test_cam"); + SUCCEED(); +} + +TEST_F(FireNSmokeIntegrationTest, InferenceOnLargeFrame) { + cv::Mat frame = TestUtils::CreateTestFrame(3840, 2160); + auto results = detector.RunInference(frame, "test_cam"); + SUCCEED(); +} + +TEST_F(FireNSmokeIntegrationTest, DetectionResultFieldsAreValid) { + if (!TestConfig::VideoExists(TestConfig::FIRE_SMOKE_VIDEO)) { + GTEST_SKIP() << "Fire/Smoke test video not found"; + } + + cv::VideoCapture cap(TestConfig::FIRE_SMOKE_VIDEO); + ASSERT_TRUE(cap.isOpened()); + + bool detectionFound = false; + for (int i = 0; i < 300 && !detectionFound; i++) { + cv::Mat frame; + if (!cap.read(frame)) break; + + auto results = detector.RunInference(frame, "test_cam"); + for (const auto& obj : results) { + detectionFound = true; + EXPECT_GE(obj.confidence, 0.0f); + EXPECT_LE(obj.confidence, 1.0f); + EXPECT_GE(obj.box.width, 0); + EXPECT_GE(obj.box.height, 0); + EXPECT_TRUE(obj.classId == 0 || obj.classId == 2) + << "Expected fire (0) or smoke (2), got classId=" << obj.classId; + } + } + cap.release(); +} + +TEST_F(FireNSmokeIntegrationTest, PerformanceBenchmark) { + if (!TestConfig::VideoExists(TestConfig::FIRE_SMOKE_VIDEO)) { + GTEST_SKIP() << "Fire/Smoke test video not found"; + } + + auto [totalDetections, avgMs] = TestUtils::RunVideoFrames(detector, TestConfig::FIRE_SMOKE_VIDEO, 100); + ASSERT_GE(totalDetections, 0) << "Video could not be opened"; + + std::cout << "[FireNSmoke] 100 frames: avg=" << avgMs << "ms/frame, " + << "detections=" << totalDetections << std::endl; + + EXPECT_LT(avgMs, 200.0) << "Average inference time exceeds 200ms"; +} + +TEST_F(FireNSmokeIntegrationTest, ThreadSafetyConcurrentInference) { + cv::Mat frame1 = TestUtils::CreateTestFrame(640, 480, cv::Scalar(100, 100, 100)); + cv::Mat frame2 = TestUtils::CreateTestFrame(640, 480, cv::Scalar(200, 200, 200)); + + std::vector results1, results2; + + std::thread t1([&]() { results1 = detector.RunInference(frame1, "cam_1"); }); + std::thread t2([&]() { results2 = detector.RunInference(frame2, "cam_2"); }); + + t1.join(); + t2.join(); + + SUCCEED(); +} diff --git a/tests/HelmetDetection/CMakeLists.txt b/tests/HelmetDetection/CMakeLists.txt new file mode 100644 index 0000000..8f1c496 --- /dev/null +++ b/tests/HelmetDetection/CMakeLists.txt @@ -0,0 +1,54 @@ +project(HelmetDetection_Tests LANGUAGES CXX) + +add_executable(${PROJECT_NAME} + HelmetDetectionTest.cpp +) + +target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_17) + +target_compile_definitions(${PROJECT_NAME} PRIVATE + WIN32_LEAN_AND_MEAN + NOMINMAX + $<$:_DEBUG> + $<$:NDEBUG> +) + +target_include_directories(${PROJECT_NAME} PRIVATE + ${TEST_COMMON_DIR} + ${ANSLIB_INCLUDE_DIR} + ${OPENCV_INCLUDE_DIR} + ${CMAKE_SOURCE_DIR}/ANSCustomHelmetDetection +) + +target_link_directories(${PROJECT_NAME} PRIVATE + ${ANSLIB_LIB_DIR} + ${OPENCV_LIB_DIR} +) + +target_link_libraries(${PROJECT_NAME} PRIVATE + gtest + gtest_main + ANSLIB + opencv_world4130 + ANSCustomHelmetDetection +) + +if(MSVC) + target_compile_options(${PROJECT_NAME} PRIVATE /W3 /sdl /permissive-) +endif() + +# Copy required DLLs next to the test executable so Windows can find them +add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD + # ANSLIB.dll + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${ANSLIB_LIB_DIR}/ANSLIB.dll" + "$" + # OpenCV DLL + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${OPENCV_BIN_DIR}/opencv_world4130.dll" + "$" + COMMENT "Copying runtime DLLs for ${PROJECT_NAME}" +) + +include(GoogleTest) +gtest_discover_tests(${PROJECT_NAME} DISCOVERY_MODE PRE_TEST) diff --git a/tests/HelmetDetection/HelmetDetectionTest.cpp b/tests/HelmetDetection/HelmetDetectionTest.cpp new file mode 100644 index 0000000..87d687f --- /dev/null +++ b/tests/HelmetDetection/HelmetDetectionTest.cpp @@ -0,0 +1,162 @@ +#include "TestCommon.h" +#include "ANSCustomCodeHelmetDetection.h" + +// =========================================================================== +// Unit Tests — no model files required +// =========================================================================== + +class HelmetUnitTest : public ::testing::Test { +protected: + ANSCustomHMD detector; +}; + +TEST_F(HelmetUnitTest, EmptyFrameReturnsNoDetections) { + cv::Mat empty; + auto results = detector.RunInference(empty); + EXPECT_TRUE(results.empty()); +} + +TEST_F(HelmetUnitTest, TinyFrameReturnsNoDetections) { + cv::Mat tiny = TestUtils::CreateTestFrame(5, 5); + auto results = detector.RunInference(tiny); + EXPECT_TRUE(results.empty()); +} + +TEST_F(HelmetUnitTest, UninitializedDetectorReturnsNoDetections) { + cv::Mat frame = TestUtils::CreateTestFrame(640, 480); + auto results = detector.RunInference(frame); + EXPECT_TRUE(results.empty()); +} + +TEST_F(HelmetUnitTest, RunInferenceWithCameraId) { + cv::Mat frame = TestUtils::CreateTestFrame(640, 480); + auto results = detector.RunInference(frame, "test_cam_01"); + EXPECT_TRUE(results.empty()); +} + +TEST_F(HelmetUnitTest, ConfigureParametersReturnsValidConfig) { + CustomParams params; + bool result = detector.ConfigureParameters(params); + EXPECT_TRUE(result); +} + +TEST_F(HelmetUnitTest, DestroySucceeds) { + EXPECT_TRUE(detector.Destroy()); +} + +TEST_F(HelmetUnitTest, DestroyCanBeCalledMultipleTimes) { + EXPECT_TRUE(detector.Destroy()); + EXPECT_TRUE(detector.Destroy()); +} + +TEST_F(HelmetUnitTest, InitializeWithInvalidDirectoryFails) { + std::string labelMap; + bool result = detector.Initialize("C:\\NonExistent\\Path\\Model", 0.5f, labelMap); + EXPECT_FALSE(result); +} + +TEST_F(HelmetUnitTest, OptimizeBeforeInitializeReturnsFalse) { + EXPECT_FALSE(detector.OptimizeModel(true)); +} + +// =========================================================================== +// Integration Tests — require model files on disk +// =========================================================================== + +class HelmetIntegrationTest : public ::testing::Test { +protected: + ANSCustomHMD detector; + std::string labelMap; + std::vector classes; + + void SetUp() override { + if (!TestConfig::ModelExists(TestConfig::HELMET_MODEL_DIR)) { + GTEST_SKIP() << "Helmet model not found at: " << TestConfig::HELMET_MODEL_DIR; + } + bool ok = detector.Initialize(TestConfig::HELMET_MODEL_DIR, 0.6f, labelMap); + ASSERT_TRUE(ok) << "Failed to initialize Helmet detector"; + classes = TestUtils::ParseLabelMap(labelMap); + } + + void TearDown() override { + detector.Destroy(); + } +}; + +TEST_F(HelmetIntegrationTest, InitializeProducesLabelMap) { + EXPECT_FALSE(labelMap.empty()); + EXPECT_FALSE(classes.empty()); +} + +TEST_F(HelmetIntegrationTest, InferenceOnSolidFrameReturnsNoDetections) { + cv::Mat frame = TestUtils::CreateTestFrame(1920, 1080); + auto results = detector.RunInference(frame, "test_cam"); + EXPECT_TRUE(results.empty()) << "Solid gray frame should not trigger helmet detection"; +} + +TEST_F(HelmetIntegrationTest, InferenceOnSmallFrame) { + cv::Mat frame = TestUtils::CreateTestFrame(320, 240); + auto results = detector.RunInference(frame, "test_cam"); + SUCCEED(); +} + +TEST_F(HelmetIntegrationTest, InferenceOnLargeFrame) { + cv::Mat frame = TestUtils::CreateTestFrame(3840, 2160); + auto results = detector.RunInference(frame, "test_cam"); + SUCCEED(); +} + +TEST_F(HelmetIntegrationTest, DetectionResultFieldsAreValid) { + if (!TestConfig::VideoExists(TestConfig::HELMET_VIDEO)) { + GTEST_SKIP() << "Helmet test video not found"; + } + + cv::VideoCapture cap(TestConfig::HELMET_VIDEO); + ASSERT_TRUE(cap.isOpened()); + + bool detectionFound = false; + for (int i = 0; i < 300 && !detectionFound; i++) { + cv::Mat frame; + if (!cap.read(frame)) break; + + auto results = detector.RunInference(frame, "test_cam"); + for (const auto& obj : results) { + detectionFound = true; + EXPECT_GE(obj.confidence, 0.0f); + EXPECT_LE(obj.confidence, 1.0f); + EXPECT_GE(obj.box.width, 0); + EXPECT_GE(obj.box.height, 0); + EXPECT_GE(obj.classId, 0); + } + } + cap.release(); +} + +TEST_F(HelmetIntegrationTest, PerformanceBenchmark) { + if (!TestConfig::VideoExists(TestConfig::HELMET_VIDEO)) { + GTEST_SKIP() << "Helmet test video not found"; + } + + auto [totalDetections, avgMs] = TestUtils::RunVideoFrames(detector, TestConfig::HELMET_VIDEO, 100); + ASSERT_GE(totalDetections, 0) << "Video could not be opened"; + + std::cout << "[Helmet] 100 frames: avg=" << avgMs << "ms/frame, " + << "detections=" << totalDetections << std::endl; + + EXPECT_LT(avgMs, 200.0) << "Average inference time exceeds 200ms"; +} + +TEST_F(HelmetIntegrationTest, ThreadSafetyConcurrentInference) { + cv::Mat frame1 = TestUtils::CreateTestFrame(640, 480, cv::Scalar(100, 100, 100)); + cv::Mat frame2 = TestUtils::CreateTestFrame(640, 480, cv::Scalar(200, 200, 200)); + + std::vector results1, results2; + + std::thread t1([&]() { results1 = detector.RunInference(frame1, "cam_1"); }); + std::thread t2([&]() { results2 = detector.RunInference(frame2, "cam_2"); }); + + t1.join(); + t2.join(); + + SUCCEED(); +} diff --git a/tests/TestCommon.h b/tests/TestCommon.h new file mode 100644 index 0000000..926045e --- /dev/null +++ b/tests/TestCommon.h @@ -0,0 +1,116 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "ANSLIB.h" + +// --------------------------------------------------------------------------- +// Model directory paths — update these to match your local environment +// --------------------------------------------------------------------------- +namespace TestConfig { + +inline const std::string FIRE_SMOKE_MODEL_DIR = + "C:\\Programs\\DemoAssets\\ModelsForANSVIS\\ANS_FireSmoke_v2.0"; +inline const std::string HELMET_MODEL_DIR = + "C:\\Programs\\DemoAssets\\ModelsForANSVIS\\ANS_Helmet(GPU)_v1.0"; +inline const std::string WEAPON_MODEL_DIR = + "C:\\Programs\\DemoAssets\\ModelsForANSVIS\\ANS_WeaponDetection(GPU)_1.0"; + +inline const std::string FIRE_SMOKE_VIDEO = + "C:\\Programs\\DemoAssets\\Videos\\FireNSmoke\\ANSFireFull.mp4"; +inline const std::string HELMET_VIDEO = + "C:\\Programs\\DemoAssets\\Videos\\Helmet\\HM2.mp4"; +inline const std::string WEAPON_VIDEO = + "C:\\Programs\\DemoAssets\\Videos\\Weapon\\AK47 Glock.mp4"; + +// Check if model directory exists +inline bool ModelExists(const std::string& path) { + return std::filesystem::exists(path) && std::filesystem::is_directory(path); +} + +// Check if video file exists +inline bool VideoExists(const std::string& path) { + return std::filesystem::exists(path); +} + +} // namespace TestConfig + +// --------------------------------------------------------------------------- +// Helper utilities +// --------------------------------------------------------------------------- +namespace TestUtils { + +// Parse comma-separated label map into vector of class names +inline std::vector ParseLabelMap(const std::string& labelMap) { + std::vector classes; + std::stringstream ss(labelMap); + std::string item; + while (std::getline(ss, item, ',')) { + classes.push_back(item); + } + return classes; +} + +// Create a solid-color test frame (no model required) +inline cv::Mat CreateTestFrame(int width, int height, cv::Scalar color = cv::Scalar(128, 128, 128)) { + return cv::Mat(height, width, CV_8UC3, color); +} + +// Create a frame with a bright red/orange region to simulate fire-like colors +inline cv::Mat CreateFireLikeFrame(int width, int height) { + cv::Mat frame(height, width, CV_8UC3, cv::Scalar(50, 50, 50)); + cv::Rect fireRegion(width / 4, height / 4, width / 2, height / 2); + frame(fireRegion) = cv::Scalar(0, 80, 255); // BGR: orange-red + return frame; +} + +// Create a frame with a gray haze to simulate smoke-like colors +inline cv::Mat CreateSmokeLikeFrame(int width, int height) { + cv::Mat frame(height, width, CV_8UC3, cv::Scalar(30, 30, 30)); + cv::Rect smokeRegion(width / 4, height / 4, width / 2, height / 2); + frame(smokeRegion) = cv::Scalar(180, 180, 190); // BGR: light gray + return frame; +} + +// Measure inference time in milliseconds +template +double MeasureMs(Func&& func) { + auto start = std::chrono::high_resolution_clock::now(); + func(); + auto end = std::chrono::high_resolution_clock::now(); + return std::chrono::duration(end - start).count(); +} + +// Run N frames of video through a detector, return (totalDetections, avgMs) +template +std::pair RunVideoFrames(Detector& detector, const std::string& videoPath, int maxFrames) { + cv::VideoCapture cap(videoPath); + if (!cap.isOpened()) return { -1, 0.0 }; + + int totalDetections = 0; + double totalMs = 0.0; + int frameCount = 0; + + while (frameCount < maxFrames) { + cv::Mat frame; + if (!cap.read(frame)) break; + + double ms = MeasureMs([&]() { + auto results = detector.RunInference(frame); + totalDetections += static_cast(results.size()); + }); + totalMs += ms; + frameCount++; + } + + cap.release(); + double avgMs = (frameCount > 0) ? totalMs / frameCount : 0.0; + return { totalDetections, avgMs }; +} + +} // namespace TestUtils diff --git a/tests/WeaponDetection/CMakeLists.txt b/tests/WeaponDetection/CMakeLists.txt new file mode 100644 index 0000000..f88259b --- /dev/null +++ b/tests/WeaponDetection/CMakeLists.txt @@ -0,0 +1,54 @@ +project(WeaponDetection_Tests LANGUAGES CXX) + +add_executable(${PROJECT_NAME} + WeaponDetectionTest.cpp +) + +target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_17) + +target_compile_definitions(${PROJECT_NAME} PRIVATE + WIN32_LEAN_AND_MEAN + NOMINMAX + $<$:_DEBUG> + $<$:NDEBUG> +) + +target_include_directories(${PROJECT_NAME} PRIVATE + ${TEST_COMMON_DIR} + ${ANSLIB_INCLUDE_DIR} + ${OPENCV_INCLUDE_DIR} + ${CMAKE_SOURCE_DIR}/ANSCustomWeaponDetection +) + +target_link_directories(${PROJECT_NAME} PRIVATE + ${ANSLIB_LIB_DIR} + ${OPENCV_LIB_DIR} +) + +target_link_libraries(${PROJECT_NAME} PRIVATE + gtest + gtest_main + ANSLIB + opencv_world4130 + ANSCustomWeaponDetection +) + +if(MSVC) + target_compile_options(${PROJECT_NAME} PRIVATE /W3 /sdl /permissive-) +endif() + +# Copy required DLLs next to the test executable so Windows can find them +add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD + # ANSLIB.dll + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${ANSLIB_LIB_DIR}/ANSLIB.dll" + "$" + # OpenCV DLL + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${OPENCV_BIN_DIR}/opencv_world4130.dll" + "$" + COMMENT "Copying runtime DLLs for ${PROJECT_NAME}" +) + +include(GoogleTest) +gtest_discover_tests(${PROJECT_NAME} DISCOVERY_MODE PRE_TEST) diff --git a/tests/WeaponDetection/WeaponDetectionTest.cpp b/tests/WeaponDetection/WeaponDetectionTest.cpp new file mode 100644 index 0000000..d9ece80 --- /dev/null +++ b/tests/WeaponDetection/WeaponDetectionTest.cpp @@ -0,0 +1,162 @@ +#include "TestCommon.h" +#include "ANSCustomCodeWeaponDetection.h" + +// =========================================================================== +// Unit Tests — no model files required +// =========================================================================== + +class WeaponUnitTest : public ::testing::Test { +protected: + ANSCustomWD detector; +}; + +TEST_F(WeaponUnitTest, EmptyFrameReturnsNoDetections) { + cv::Mat empty; + auto results = detector.RunInference(empty); + EXPECT_TRUE(results.empty()); +} + +TEST_F(WeaponUnitTest, TinyFrameReturnsNoDetections) { + cv::Mat tiny = TestUtils::CreateTestFrame(5, 5); + auto results = detector.RunInference(tiny); + EXPECT_TRUE(results.empty()); +} + +TEST_F(WeaponUnitTest, UninitializedDetectorReturnsNoDetections) { + cv::Mat frame = TestUtils::CreateTestFrame(640, 480); + auto results = detector.RunInference(frame); + EXPECT_TRUE(results.empty()); +} + +TEST_F(WeaponUnitTest, RunInferenceWithCameraId) { + cv::Mat frame = TestUtils::CreateTestFrame(640, 480); + auto results = detector.RunInference(frame, "test_cam_01"); + EXPECT_TRUE(results.empty()); +} + +TEST_F(WeaponUnitTest, ConfigureParametersReturnsValidConfig) { + CustomParams params; + bool result = detector.ConfigureParameters(params); + EXPECT_TRUE(result); +} + +TEST_F(WeaponUnitTest, DestroySucceeds) { + EXPECT_TRUE(detector.Destroy()); +} + +TEST_F(WeaponUnitTest, DestroyCanBeCalledMultipleTimes) { + EXPECT_TRUE(detector.Destroy()); + EXPECT_TRUE(detector.Destroy()); +} + +TEST_F(WeaponUnitTest, InitializeWithInvalidDirectoryFails) { + std::string labelMap; + bool result = detector.Initialize("C:\\NonExistent\\Path\\Model", 0.5f, labelMap); + EXPECT_FALSE(result); +} + +TEST_F(WeaponUnitTest, OptimizeBeforeInitializeReturnsFalse) { + EXPECT_FALSE(detector.OptimizeModel(true)); +} + +// =========================================================================== +// Integration Tests — require model files on disk +// =========================================================================== + +class WeaponIntegrationTest : public ::testing::Test { +protected: + ANSCustomWD detector; + std::string labelMap; + std::vector classes; + + void SetUp() override { + if (!TestConfig::ModelExists(TestConfig::WEAPON_MODEL_DIR)) { + GTEST_SKIP() << "Weapon model not found at: " << TestConfig::WEAPON_MODEL_DIR; + } + bool ok = detector.Initialize(TestConfig::WEAPON_MODEL_DIR, 0.6f, labelMap); + ASSERT_TRUE(ok) << "Failed to initialize Weapon detector"; + classes = TestUtils::ParseLabelMap(labelMap); + } + + void TearDown() override { + detector.Destroy(); + } +}; + +TEST_F(WeaponIntegrationTest, InitializeProducesLabelMap) { + EXPECT_FALSE(labelMap.empty()); + EXPECT_FALSE(classes.empty()); +} + +TEST_F(WeaponIntegrationTest, InferenceOnSolidFrameReturnsNoDetections) { + cv::Mat frame = TestUtils::CreateTestFrame(1920, 1080); + auto results = detector.RunInference(frame, "test_cam"); + EXPECT_TRUE(results.empty()) << "Solid gray frame should not trigger weapon detection"; +} + +TEST_F(WeaponIntegrationTest, InferenceOnSmallFrame) { + cv::Mat frame = TestUtils::CreateTestFrame(320, 240); + auto results = detector.RunInference(frame, "test_cam"); + SUCCEED(); +} + +TEST_F(WeaponIntegrationTest, InferenceOnLargeFrame) { + cv::Mat frame = TestUtils::CreateTestFrame(3840, 2160); + auto results = detector.RunInference(frame, "test_cam"); + SUCCEED(); +} + +TEST_F(WeaponIntegrationTest, DetectionResultFieldsAreValid) { + if (!TestConfig::VideoExists(TestConfig::WEAPON_VIDEO)) { + GTEST_SKIP() << "Weapon test video not found"; + } + + cv::VideoCapture cap(TestConfig::WEAPON_VIDEO); + ASSERT_TRUE(cap.isOpened()); + + bool detectionFound = false; + for (int i = 0; i < 300 && !detectionFound; i++) { + cv::Mat frame; + if (!cap.read(frame)) break; + + auto results = detector.RunInference(frame, "test_cam"); + for (const auto& obj : results) { + detectionFound = true; + EXPECT_GE(obj.confidence, 0.0f); + EXPECT_LE(obj.confidence, 1.0f); + EXPECT_GE(obj.box.width, 0); + EXPECT_GE(obj.box.height, 0); + EXPECT_GE(obj.classId, 0); + } + } + cap.release(); +} + +TEST_F(WeaponIntegrationTest, PerformanceBenchmark) { + if (!TestConfig::VideoExists(TestConfig::WEAPON_VIDEO)) { + GTEST_SKIP() << "Weapon test video not found"; + } + + auto [totalDetections, avgMs] = TestUtils::RunVideoFrames(detector, TestConfig::WEAPON_VIDEO, 100); + ASSERT_GE(totalDetections, 0) << "Video could not be opened"; + + std::cout << "[Weapon] 100 frames: avg=" << avgMs << "ms/frame, " + << "detections=" << totalDetections << std::endl; + + EXPECT_LT(avgMs, 200.0) << "Average inference time exceeds 200ms"; +} + +TEST_F(WeaponIntegrationTest, ThreadSafetyConcurrentInference) { + cv::Mat frame1 = TestUtils::CreateTestFrame(640, 480, cv::Scalar(100, 100, 100)); + cv::Mat frame2 = TestUtils::CreateTestFrame(640, 480, cv::Scalar(200, 200, 200)); + + std::vector results1, results2; + + std::thread t1([&]() { results1 = detector.RunInference(frame1, "cam_1"); }); + std::thread t2([&]() { results2 = detector.RunInference(frame2, "cam_2"); }); + + t1.join(); + t2.join(); + + SUCCEED(); +}