155 lines
6.2 KiB
C++
155 lines
6.2 KiB
C++
// Separate translation unit for ANSSAM3 (TensorRT) tests.
// TensorRT/CUDA headers conflict with the Windows SDK (ambiguous ACCESS_MASK symbol)
// when included in the same .cpp as ONNX Runtime headers, so we isolate them here.
|
|
|
|
#include "ANSSAM3.h"

#include <chrono>
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>
|
|
|
|
int SAM3TRT_UnitTest()
|
|
{
|
|
std::string videoFile = "E:\\Programs\\DemoAssets\\Videos\\video_20.mp4";
|
|
std::string modelFolder = "C:\\Projects\\ANSVIS\\Models\\ANS_SAM_v3.0";
|
|
|
|
ANSCENTER::ANSSAM3 infHandle;
|
|
ANSCENTER::ModelConfig modelConfig;
|
|
modelConfig.modelConfThreshold = 0.5f;
|
|
std::string labelmap;
|
|
|
|
if (!infHandle.LoadModelFromFolder("", modelConfig, "anssam3", "", modelFolder, labelmap)) {
|
|
std::cerr << "SAM3TRT_UnitTest: LoadModelFromFolder failed\n";
|
|
return -1;
|
|
}
|
|
infHandle.SetPrompt("person");
|
|
|
|
cv::VideoCapture capture(videoFile);
|
|
if (!capture.isOpened()) {
|
|
std::cerr << "SAM3TRT_UnitTest: could not open video file\n";
|
|
return -1;
|
|
}
|
|
|
|
while (true) {
|
|
cv::Mat frame;
|
|
if (!capture.read(frame)) {
|
|
std::cout << "\nEnd of video.\n";
|
|
break;
|
|
}
|
|
auto start = std::chrono::system_clock::now();
|
|
std::vector<ANSCENTER::Object> masks = infHandle.RunInference(frame);
|
|
auto end = std::chrono::system_clock::now();
|
|
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
|
|
printf("Time = %lld ms\n", static_cast<long long int>(elapsed.count()));
|
|
|
|
for (size_t i = 0; i < masks.size(); i++) {
|
|
cv::rectangle(frame, masks[i].box, 123, 2);
|
|
}
|
|
cv::imshow("SAM3 TensorRT Test", frame);
|
|
if (cv::waitKey(30) == 27) break;
|
|
}
|
|
capture.release();
|
|
cv::destroyAllWindows();
|
|
infHandle.Destroy();
|
|
std::cout << "SAM3TRT_UnitTest: done.\n";
|
|
return 0;
|
|
}
|
|
|
|
int SAM3TRT_ImageTest()
|
|
{
|
|
std::string modelFilePath = "C:\\Projects\\ANSVIS\\Models\\ANS_SAM_v3.0.zip";
|
|
|
|
std::string modelFolder = "C:\\Projects\\ANSVIS\\Models\\ANS_SAM_v3.0";
|
|
std::string imageFile = "C:\\Projects\\Research\\sam3onnx\\sam3-onnx\\images\\dog.jpg";
|
|
|
|
ANSCENTER::ANSSAM3 infHandle;
|
|
ANSCENTER::ModelConfig modelConfig;
|
|
modelConfig.modelConfThreshold = 0.5f;
|
|
std::string labelmap;
|
|
std::string licenseKey = "";
|
|
std::string modelZipFilePassword = "";
|
|
if (!infHandle.Initialize(licenseKey, modelConfig, modelFilePath, modelZipFilePassword, labelmap)) {
|
|
std::cerr << "SAM3TRT_ImageTest: LoadModelFromFolder failed\n";
|
|
return -1;
|
|
}
|
|
infHandle.SetPrompt("dog");
|
|
|
|
cv::Mat image = cv::imread(imageFile);
|
|
if (image.empty()) {
|
|
std::cerr << "SAM3TRT_ImageTest: could not read image: " << imageFile << "\n";
|
|
return -1;
|
|
}
|
|
|
|
const int NUM_RUNS = 5;
|
|
for (int run = 0; run < NUM_RUNS; run++) {
|
|
auto start = std::chrono::system_clock::now();
|
|
std::vector<ANSCENTER::Object> results = infHandle.RunInference(image);
|
|
auto end = std::chrono::system_clock::now();
|
|
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
|
|
|
|
std::cout << "SAM3TRT_ImageTest run " << (run + 1) << "/" << NUM_RUNS
|
|
<< ": " << results.size() << " detections in "
|
|
<< elapsed.count() << " ms\n";
|
|
|
|
if (run == NUM_RUNS - 1) {
|
|
for (size_t i = 0; i < results.size(); i++) {
|
|
const auto& obj = results[i];
|
|
std::cout << " [" << i << "] box=" << obj.box
|
|
<< " conf=" << obj.confidence
|
|
<< " polygon=" << obj.polygon.size() << " pts\n";
|
|
|
|
// Draw bounding box
|
|
cv::Scalar boxColor(0, 255, 0);
|
|
cv::rectangle(image, obj.box, boxColor, 2);
|
|
|
|
// Draw label
|
|
std::string label = obj.className + " " + std::to_string(obj.confidence).substr(0, 4);
|
|
int baseline = 0;
|
|
cv::Size textSize = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseline);
|
|
cv::rectangle(image,
|
|
cv::Point(obj.box.x, obj.box.y - textSize.height - 4),
|
|
cv::Point(obj.box.x + textSize.width, obj.box.y),
|
|
boxColor, cv::FILLED);
|
|
cv::putText(image, label, cv::Point(obj.box.x, obj.box.y - 2),
|
|
cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0), 1, cv::LINE_AA);
|
|
|
|
// Draw polygon (normalized coordinates)
|
|
if (obj.polygon.size() >= 3) {
|
|
std::vector<cv::Point> polyPts;
|
|
polyPts.reserve(obj.polygon.size());
|
|
for (const auto& pt : obj.polygon) {
|
|
polyPts.emplace_back(
|
|
static_cast<int>(pt.x * image.cols),
|
|
static_cast<int>(pt.y * image.rows));
|
|
}
|
|
cv::Mat overlay = image.clone();
|
|
std::vector<std::vector<cv::Point>> polys = { polyPts };
|
|
cv::Scalar polyColor((i * 67 + 50) % 256, (i * 123 + 100) % 256, (i * 37 + 150) % 256);
|
|
cv::fillPoly(overlay, polys, polyColor);
|
|
cv::addWeighted(overlay, 0.4, image, 0.6, 0, image);
|
|
cv::polylines(image, polys, true, polyColor, 2, cv::LINE_AA);
|
|
}
|
|
// Mask overlay fallback
|
|
else if (!obj.mask.empty()) {
|
|
cv::Mat colorMask(obj.mask.size(), CV_8UC3,
|
|
cv::Scalar((i * 67 + 50) % 256, (i * 123 + 100) % 256, (i * 37 + 150) % 256));
|
|
cv::Mat roiImg = image(obj.box);
|
|
cv::Mat maskBool;
|
|
if (obj.mask.type() != CV_8UC1)
|
|
obj.mask.convertTo(maskBool, CV_8UC1, 255.0);
|
|
else
|
|
maskBool = obj.mask;
|
|
colorMask.copyTo(roiImg, maskBool);
|
|
cv::addWeighted(roiImg, 0.4, image(obj.box), 0.6, 0, image(obj.box));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
cv::imshow("SAM3 TRT Image Test", image);
|
|
cv::waitKey(0);
|
|
cv::destroyAllWindows();
|
|
infHandle.Destroy();
|
|
std::cout << "SAM3TRT_ImageTest: done.\n";
|
|
return 0;
|
|
}
|