Files
ANSCORE/tests/ANSODEngine-UnitTest/ANSSAM3-UnitTest.cpp

155 lines
6.2 KiB
C++
Raw Normal View History

2026-03-29 08:45:38 +11:00
// Separate translation unit for ANSSAM3 (TensorRT) tests.
// TensorRT/CUDA headers conflict with Windows SDK (ACCESS_MASK ambiguous symbol)
// when included in the same .cpp as ONNX Runtime headers, so we isolate them here.
#include "ANSSAM3.h"
#include <iostream>
#include <chrono>
#include <opencv2/opencv.hpp>
int SAM3TRT_UnitTest()
{
std::string videoFile = "E:\\Programs\\DemoAssets\\Videos\\video_20.mp4";
std::string modelFolder = "C:\\Projects\\ANSVIS\\Models\\ANS_SAM_v3.0";
ANSCENTER::ANSSAM3 infHandle;
ANSCENTER::ModelConfig modelConfig;
modelConfig.modelConfThreshold = 0.5f;
std::string labelmap;
if (!infHandle.LoadModelFromFolder("", modelConfig, "anssam3", "", modelFolder, labelmap)) {
std::cerr << "SAM3TRT_UnitTest: LoadModelFromFolder failed\n";
return -1;
}
infHandle.SetPrompt("person");
cv::VideoCapture capture(videoFile);
if (!capture.isOpened()) {
std::cerr << "SAM3TRT_UnitTest: could not open video file\n";
return -1;
}
while (true) {
cv::Mat frame;
if (!capture.read(frame)) {
std::cout << "\nEnd of video.\n";
break;
}
auto start = std::chrono::system_clock::now();
std::vector<ANSCENTER::Object> masks = infHandle.RunInference(frame);
auto end = std::chrono::system_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
printf("Time = %lld ms\n", static_cast<long long int>(elapsed.count()));
for (size_t i = 0; i < masks.size(); i++) {
cv::rectangle(frame, masks[i].box, 123, 2);
}
cv::imshow("SAM3 TensorRT Test", frame);
if (cv::waitKey(30) == 27) break;
}
capture.release();
cv::destroyAllWindows();
infHandle.Destroy();
std::cout << "SAM3TRT_UnitTest: done.\n";
return 0;
}
int SAM3TRT_ImageTest()
{
std::string modelFilePath = "C:\\Projects\\ANSVIS\\Models\\ANS_SAM_v3.0.zip";
std::string modelFolder = "C:\\Projects\\ANSVIS\\Models\\ANS_SAM_v3.0";
std::string imageFile = "C:\\Projects\\Research\\sam3onnx\\sam3-onnx\\images\\dog.jpg";
ANSCENTER::ANSSAM3 infHandle;
ANSCENTER::ModelConfig modelConfig;
modelConfig.modelConfThreshold = 0.5f;
std::string labelmap;
std::string licenseKey = "";
std::string modelZipFilePassword = "";
if (!infHandle.Initialize(licenseKey, modelConfig, modelFilePath, modelZipFilePassword, labelmap)) {
std::cerr << "SAM3TRT_ImageTest: LoadModelFromFolder failed\n";
return -1;
}
infHandle.SetPrompt("dog");
cv::Mat image = cv::imread(imageFile);
if (image.empty()) {
std::cerr << "SAM3TRT_ImageTest: could not read image: " << imageFile << "\n";
return -1;
}
const int NUM_RUNS = 5;
for (int run = 0; run < NUM_RUNS; run++) {
auto start = std::chrono::system_clock::now();
std::vector<ANSCENTER::Object> results = infHandle.RunInference(image);
auto end = std::chrono::system_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
std::cout << "SAM3TRT_ImageTest run " << (run + 1) << "/" << NUM_RUNS
<< ": " << results.size() << " detections in "
<< elapsed.count() << " ms\n";
if (run == NUM_RUNS - 1) {
for (size_t i = 0; i < results.size(); i++) {
const auto& obj = results[i];
std::cout << " [" << i << "] box=" << obj.box
<< " conf=" << obj.confidence
<< " polygon=" << obj.polygon.size() << " pts\n";
// Draw bounding box
cv::Scalar boxColor(0, 255, 0);
cv::rectangle(image, obj.box, boxColor, 2);
// Draw label
std::string label = obj.className + " " + std::to_string(obj.confidence).substr(0, 4);
int baseline = 0;
cv::Size textSize = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseline);
cv::rectangle(image,
cv::Point(obj.box.x, obj.box.y - textSize.height - 4),
cv::Point(obj.box.x + textSize.width, obj.box.y),
boxColor, cv::FILLED);
cv::putText(image, label, cv::Point(obj.box.x, obj.box.y - 2),
cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0), 1, cv::LINE_AA);
// Draw polygon (normalized coordinates)
if (obj.polygon.size() >= 3) {
std::vector<cv::Point> polyPts;
polyPts.reserve(obj.polygon.size());
for (const auto& pt : obj.polygon) {
polyPts.emplace_back(
static_cast<int>(pt.x * image.cols),
static_cast<int>(pt.y * image.rows));
}
cv::Mat overlay = image.clone();
std::vector<std::vector<cv::Point>> polys = { polyPts };
cv::Scalar polyColor((i * 67 + 50) % 256, (i * 123 + 100) % 256, (i * 37 + 150) % 256);
cv::fillPoly(overlay, polys, polyColor);
cv::addWeighted(overlay, 0.4, image, 0.6, 0, image);
cv::polylines(image, polys, true, polyColor, 2, cv::LINE_AA);
}
// Mask overlay fallback
else if (!obj.mask.empty()) {
cv::Mat colorMask(obj.mask.size(), CV_8UC3,
cv::Scalar((i * 67 + 50) % 256, (i * 123 + 100) % 256, (i * 37 + 150) % 256));
cv::Mat roiImg = image(obj.box);
cv::Mat maskBool;
if (obj.mask.type() != CV_8UC1)
obj.mask.convertTo(maskBool, CV_8UC1, 255.0);
else
maskBool = obj.mask;
colorMask.copyTo(roiImg, maskBool);
cv::addWeighted(roiImg, 0.4, image(obj.box), 0.6, 0, image(obj.box));
}
}
}
}
cv::imshow("SAM3 TRT Image Test", image);
cv::waitKey(0);
cv::destroyAllWindows();
infHandle.Destroy();
std::cout << "SAM3TRT_ImageTest: done.\n";
return 0;
}