Files
ANSCORE/modules/ANSMOT/ANSByteTrackNCNN.cpp

198 lines
8.4 KiB
C++
Raw Permalink Normal View History

2026-03-28 16:54:11 +11:00
#include "ANSByteTrackNCNN.h"

#include <map>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "boost/foreach.hpp"
#include "boost/optional.hpp"
#include "boost/property_tree/json_parser.hpp"
#include "boost/property_tree/ptree.hpp"
namespace ANSCENTER{
/// @brief Reads the value stored under @p key in @p pt, converted to T.
/// @param pt  Property tree node to read from.
/// @param key Path of the child to read (property_tree path syntax).
/// @return The converted value, or a value-initialised T (0, 0.0f, empty string, ...)
///         when the key is absent or its text cannot be converted to T.
///         Never throws for a missing/unconvertible key (get_optional is used).
template <typename T>
T get_data(const boost::property_tree::ptree& pt, const std::string& key)
{
    // Value-initialise: a plain `T ret;` leaves arithmetic types indeterminate when the
    // key is missing, and returning that indeterminate value is undefined behaviour.
    T ret{};
    if (boost::optional<T> data = pt.get_optional<T>(key))
    {
        ret = *data;
    }
    return ret;
}
/// @brief Converts the "results" array of a parsed detection document into tracker inputs.
/// @param pt Parsed detection document; must contain a "results" child whose entries
///           carry class_id, prob, x, y, width, height, left, top, right, bottom, object_id.
/// @return One ByteTrackNCNN::Object per entry of "results", in document order.
///         Individual fields are read through get_data, so a missing field does not throw.
/// @throws boost::property_tree::ptree_bad_path if "results" is absent.
std::vector<ByteTrackNCNN::Object> GetByteTrackNCNNInputs(const boost::property_tree::ptree& pt)
{
    const boost::property_tree::ptree& results = pt.get_child("results");

    std::vector<ByteTrackNCNN::Object> inputs_ref;
    inputs_ref.reserve(results.size());  // one output object per child of "results"

    for (const auto& child : results)
    {
        const boost::property_tree::ptree& result = child.second;
        const auto class_id = get_data<int>(result, "class_id");
        const auto prob = get_data<float>(result, "prob");
        const auto x = get_data<float>(result, "x");
        const auto y = get_data<float>(result, "y");
        const auto width = get_data<float>(result, "width");
        const auto height = get_data<float>(result, "height");
        const auto left = get_data<float>(result, "left");
        const auto top = get_data<float>(result, "top");
        const auto right = get_data<float>(result, "right");
        const auto bottom = get_data<float>(result, "bottom");
        const auto object_id = get_data<std::string>(result, "object_id");
        // Push a temporary so it is moved into the vector rather than copied.
        inputs_ref.push_back(ByteTrackNCNN::Object{ ByteTrackNCNN::Rect(x, y, width, height),
            class_id, prob, left, top, right, bottom, object_id });
    }
    return inputs_ref;
}
/// @brief Constructs the tracker wrapper, runs the license check and applies the
///        construction-time tracker parameters.
ANSByteTrackNCNN::ANSByteTrackNCNN() {
    // Start unlicensed; CheckLicense() presumably sets _licenseValid on success
    // (NOTE(review): confirm in CheckLicense's definition).
    _licenseValid = false;
    CheckLicense();
    // Argument order matches UpdateParameters():
    // (frame rate, track buffer, track threshold, high threshold, match threshold).
    // NOTE(review): these differ from the fallback values UpdateParameters() applies
    // for an empty string (15, 300, 0.5, 0.1, 0.95) — confirm which set is intended.
    tracker.update_parameters(10, 30, 0.5, 0.6, 0.8);
}
/// @brief Destructor. Performs no explicit cleanup of its own.
ANSByteTrackNCNN::~ANSByteTrackNCNN() = default;
/// @brief Tear-down hook. This implementation releases nothing explicitly.
/// @return Always true.
bool ANSByteTrackNCNN::Destroy()
{
    // Nothing to tear down here; report success unconditionally.
    return true;
}
bool ANSByteTrackNCNN::UpdateParameters(const std::string& trackerParameters) {
try {
// Use JSON Boost to parse paramter from trackerParameters
int frameRate, trackBuffer;
double trackThreshold, highThreshold, matchThresold;
bool autoFrameRate;
frameRate = 15;
trackBuffer = 300;
trackThreshold = 0.5;
highThreshold = 0.1;
matchThresold = 0.95;
autoFrameRate = true;
if (!trackerParameters.empty()) {
std::stringstream ss;
ss << trackerParameters;
boost::property_tree::ptree pt;
boost::property_tree::read_json(ss, pt);
auto rootNode = pt.get_child("parameters");
auto childNode = rootNode.get_child("frame_rate");
frameRate = childNode.get_value<int>();
childNode = rootNode.get_child("track_buffer");
trackBuffer = childNode.get_value<int>();
childNode = rootNode.get_child("track_threshold");
trackThreshold = childNode.get_value<float>();
childNode = rootNode.get_child("high_threshold");
highThreshold = childNode.get_value<float>();
childNode = rootNode.get_child("match_thresold");
matchThresold = childNode.get_value<float>();
// Optional: auto frame rate estimation
if (auto optNode = rootNode.get_child_optional("auto_frame_rate")) {
autoFrameRate = optNode->get_value<int>() != 0;
}
}
tracker.update_parameters(frameRate, trackBuffer, trackThreshold, highThreshold, matchThresold, autoFrameRate);
return true;
}
catch (const std::exception& e) {
this->_logger.LogFatal("ANSByteTrackNCNN::UpdateParameters", e.what(), __FILE__, __LINE__);
return false;
}
}
/// @brief Runs one tracking step on JSON-encoded detections and returns JSON-encoded tracks.
/// @param modelId       Copied into every output record as "model_id".
/// @param detectionData JSON document with a "results" array (parsed by GetByteTrackNCNNInputs).
/// @return Compact JSON of the form {"results":[...]} with one record per track returned by
///         tracker.update(), or an empty string if the license is invalid or any step throws
///         (the error is logged).
std::string ANSByteTrackNCNN::Update(int modelId, const std::string& detectionData) {
    if (!_licenseValid) {
        this->_logger.LogFatal("ANSByteTrackNCNN::Update", "Invalid license", __FILE__, __LINE__);
        return "";
    }
    try {
        //1. Parse input detections.
        boost::property_tree::ptree pt;
        std::istringstream ss(detectionData);
        boost::property_tree::read_json(ss, pt);
        std::vector<ByteTrackNCNN::Object> objects = GetByteTrackNCNNInputs(pt);

        //2. Do tracking.
        const auto outputs = tracker.update(objects);

        //3. Convert result.
        boost::property_tree::ptree trackingObjects;
        for (const auto& track : outputs) {
            boost::property_tree::ptree trackingNode;
            trackingNode.put("model_id", modelId);
            trackingNode.put("track_id", track.track_id);
            trackingNode.put("class_id", track.class_id);
            trackingNode.put("prob", track.score);
            trackingNode.put("x", track.tlwh[0]);
            trackingNode.put("y", track.tlwh[1]);
            trackingNode.put("width", track.tlwh[2]);
            trackingNode.put("height", track.tlwh[3]);
            trackingNode.put("left", track.left);
            trackingNode.put("top", track.top);
            trackingNode.put("right", track.right);
            trackingNode.put("bottom", track.bottom);
            trackingNode.put("valid", track.is_activated);
            trackingNode.put("object_id", track.object_id);
            // Children with an empty key are serialised by write_json as array elements.
            trackingObjects.push_back(std::make_pair("", trackingNode));
        }
        // NOTE(review): boost's write_json renders a childless node as its (empty) data,
        // so with no tracks this presumably emits "results":"" rather than [] — confirm
        // downstream consumers accept that.
        boost::property_tree::ptree root;
        root.add_child("results", trackingObjects);

        std::ostringstream stream;
        boost::property_tree::write_json(stream, root, false);
        return stream.str();
    }
    catch (const std::exception& e) {
        this->_logger.LogFatal("ANSByteTrackNCNN::Update", e.what(), __FILE__, __LINE__);
        return "";
    }
}
std::vector<TrackerObject> ANSByteTrackNCNN::UpdateTracker(int modelId, const std::vector<TrackerObject>& detectionObjects) {
if (!_licenseValid) {
this->_logger.LogFatal("ANSByteTrackNCNN::UpdateTracker", "Invalid license", __FILE__, __LINE__);
return std::vector<TrackerObject>();
}
try {
std::vector<TrackerObject> trackingResults;
std::vector<ByteTrackNCNN::Object> objects;
for (const auto& detObj : detectionObjects) {
ByteTrackNCNN::Object temp{ ByteTrackNCNN::Rect(detObj.x, detObj.y, detObj.width, detObj.height),
detObj.class_id,detObj.prob,detObj.left,detObj.top,detObj.right,detObj.bottom,detObj.object_id };
objects.push_back(temp);
}
//2. Do tracking
const auto outputs = tracker.update(objects);
if (outputs.size() > 0) {
for (int i = 0; i < outputs.size(); i++) {
TrackerObject trackObj;
trackObj.track_id = outputs[i].track_id;
trackObj.class_id = outputs[i].class_id;
trackObj.prob = outputs[i].score;
trackObj.x = outputs[i].tlwh[0];
trackObj.y = outputs[i].tlwh[1];
trackObj.width = outputs[i].tlwh[2];
trackObj.height = outputs[i].tlwh[3];
trackObj.left = outputs[i].left;
trackObj.top = outputs[i].top;
trackObj.right = outputs[i].right;
trackObj.bottom = outputs[i].bottom;
trackObj.object_id = outputs[i].object_id;
trackingResults.push_back(trackObj);
}
}
return trackingResults;
}
catch (const std::exception& e) {
this->_logger.LogFatal("ANSByteTrackNCNN::UpdateTracker", e.what(), __FILE__, __LINE__);
return std::vector<TrackerObject>();
}
}
}