Files
ANSCORE/modules/ANSMOT/ANSUCMC.cpp

217 lines
9.1 KiB
C++
Raw Normal View History

2026-03-28 16:54:11 +11:00
#include "ANSUCMC.h"

#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "boost/foreach.hpp"
#include "boost/optional.hpp"
#include "boost/property_tree/json_parser.hpp"
#include "boost/property_tree/ptree.hpp"
namespace ANSCENTER {
// Read the value stored under `key` in `pt`, converted to T.
//
// Returns a value-initialised T{} (0 for arithmetic types, "" for std::string) when
// the key is absent or its text cannot be converted to T (get_optional yields an empty
// optional in both cases). Previously `ret` was left default-initialised, which for
// arithmetic T returned an indeterminate value.
template <typename T>
T get_data(const boost::property_tree::ptree& pt, const std::string& key)
{
    if (const boost::optional<T> data = pt.get_optional<T>(key))
    {
        return *data;
    }
    return T{};
}
// Convert the "results" array of a parsed detection JSON tree into UCMC tracker inputs.
//
// Each child of "results" contributes class_id (int), prob, x, y, width, height (float)
// and object_id (string). A missing field takes whatever get_data returns for an absent
// key. Any other fields (e.g. left/top/right/bottom) are not read, because the box is
// built from x/y/width/height only.
// Throws boost::property_tree::ptree_bad_path if `pt` has no "results" child.
std::vector<UCMC::Object> GetUCMCTrackInputs(const boost::property_tree::ptree& pt)
{
    const boost::property_tree::ptree& results = pt.get_child("results");

    std::vector<UCMC::Object> inputs;
    inputs.reserve(results.size());
    for (const auto& entry : results)
    {
        const boost::property_tree::ptree& result = entry.second;
        UCMC::Object obj;
        obj.rect.x = get_data<float>(result, "x");
        obj.rect.y = get_data<float>(result, "y");
        obj.rect.width = get_data<float>(result, "width");
        obj.rect.height = get_data<float>(result, "height");
        obj.label = get_data<int>(result, "class_id");
        obj.prob = get_data<float>(result, "prob");
        obj.object_id = get_data<std::string>(result, "object_id");
        inputs.push_back(obj);
    }
    return inputs;
}
// Construct the tracker wrapper: run the license check, then load the built-in
// UCMC tracker configuration.
ANSUCMCTrack::ANSUCMCTrack()
{
    // Start unlicensed. Update()/UpdateTracker() refuse to run while _licenseValid is false.
    // NOTE(review): presumably CheckLicense() sets _licenseValid on success — confirm in its definition.
    _licenseValid = false;
    CheckLicense();

    // Built-in scalar parameters handed to the tracker at construction.
    double a1{ 100.0 };
    double a2{ 100.0 };
    double wx{ 5.0 };
    double wy{ 5.0 };
    double vmax{ 10.0 };
    double max_age{ 10.0 };
    double high_score{ 0.5 };
    double conf_threshold{ 0.01 };

    // Ki holds 12 values (laid out as 3 rows of 4) and Ko holds 16 values (4 rows of 4).
    // NOTE(review): these look like camera intrinsic/extrinsic matrices — confirm
    // against UCMC::Tracker::update_parameters.
    std::vector<double> Ki{
        1040.0, 0.0,    680.0, 0.0,
        0.0,    1040.0, 382.0, 0.0,
        0.0,    0.0,    1.0,   0.0 };
    std::vector<double> Ko{
        -0.33962705646204017, -0.9403932802759871,  -0.01771837833151917, 0,
        -0.49984364998094355,  0.1964143950108213,  -0.843550657048088,   1,
         0.7967495140209739,  -0.2776362077328922,  -0.5367572524549995,  33.64,
         0,                    0,                    0,                   1 };

    tracker.update_parameters(a1, a2, wx, wy, vmax, max_age, high_score, conf_threshold, 0.1, Ki, Ko);
}
// No explicit cleanup. Member sub-objects (including `tracker`) are destroyed by
// their own destructors.
ANSUCMCTrack::~ANSUCMCTrack() = default;
// Release hook. This implementation performs no explicit teardown and always
// reports success.
bool ANSUCMCTrack::Destroy()
{
    constexpr bool kDestroyed = true;
    return kDestroyed;
}
bool ANSUCMCTrack::UpdateParameters(const std::string& trackerParameters) {
// Use JSON Boost to parse paramter from trackerParameters
try {
double a1 = 100.0;
double a2 = 100.0;
double wx = 5.0;
double wy = 5.0;
double vmax = 10.0;
double max_age = 10.0;
double high_score = 0.5;
double conf_threshold = 0.01;
std::vector<double> Ki = { 1040., 0., 680., 0.,0., 1040., 382., 0.,0., 0., 1., 0. };
std::vector<double> Ko = { -0.33962705646204017, -0.9403932802759871, -0.01771837833151917, 0,
-0.49984364998094355, 0.1964143950108213, -0.843550657048088, 1,
0.7967495140209739, -0.2776362077328922, -0.5367572524549995, 33.64,
0, 0, 0, 1 };
tracker.update_parameters(a1, a2, wx, wy, vmax, max_age, high_score, conf_threshold, 0.1, Ki, Ko);
return true;
}
catch (const std::exception& e) {
this->_logger.LogFatal("ANSUCMCTrack::UpdateParameters", e.what(), __FILE__, __LINE__);
return false;
}
}
std::string ANSUCMCTrack::Update(int modelId, const std::string& detectionData) {
if (!_licenseValid) {
this->_logger.LogFatal("ANSUCMCTrack::Update", "Invalid license", __FILE__, __LINE__);
return "";
}
try {
boost::property_tree::ptree root;
boost::property_tree::ptree trackingObjects;
boost::property_tree::ptree pt;
//1. Get input
std::stringstream ss;
ss << detectionData;
boost::property_tree::read_json(ss, pt);
std::vector<UCMC::Object> objects = GetUCMCTrackInputs(pt);
//2. Do tracking
std::vector<UCMC::Obj> outputs = tracker.update(objects);
if (outputs.size() > 0) {
for (int i = 0; i < outputs.size(); i++) {
boost::property_tree::ptree trackingNode;
trackingNode.put("model_id", modelId);
trackingNode.put("track_id", outputs[i].track_idx);
trackingNode.put("class_id", outputs[i].obj.label);
trackingNode.put("prob", outputs[i].obj.prob);
trackingNode.put("x", outputs[i].obj.rect.x);
trackingNode.put("y", outputs[i].obj.rect.y);
trackingNode.put("width", outputs[i].obj.rect.width);
trackingNode.put("height", outputs[i].obj.rect.height);
trackingNode.put("left", outputs[i].obj.rect.x);
trackingNode.put("top", outputs[i].obj.rect.y);
trackingNode.put("right", outputs[i].obj.rect.x+ outputs[i].obj.rect.width);
trackingNode.put("bottom", outputs[i].obj.rect.y + outputs[i].obj.rect.height);
trackingNode.put("valid", true);
trackingNode.put("object_id", outputs[i].obj.object_id);
// Add this node to the list.
trackingObjects.push_back(std::make_pair("", trackingNode));
}
}
//3. Convert result
root.add_child("results", trackingObjects);
std::ostringstream stream;
boost::property_tree::write_json(stream, root,false);
std::string st = stream.str();
return st;
}
catch (const std::exception& e) {
this->_logger.LogFatal("ANSUCMCTrack::Update", e.what(), __FILE__, __LINE__);
return "";
}
}
std::vector<TrackerObject> ANSUCMCTrack::UpdateTracker(int modelId, const std::vector<TrackerObject>& detectionObjects) {
if (!_licenseValid) {
this->_logger.LogFatal("ANSByteTrack::UpdateTracker", "Invalid license", __FILE__, __LINE__);
return std::vector<TrackerObject>();
}
try {
std::vector<TrackerObject> trackingResults;
std::vector<UCMC::Object> objects;
for (const auto& detObj : detectionObjects) {
const auto class_id = detObj.class_id;
const auto prob = detObj.prob;
const auto x = detObj.x;
const auto y = detObj.y;
const auto width = detObj.width;
const auto height = detObj.height;
const auto left = detObj.left;
const auto top = detObj.top;
const auto right = detObj.right;
const auto bottom = detObj.bottom;
const auto object_id = detObj.object_id;
UCMC::Object temp;
temp.rect.x = x;
temp.rect.y = y;
temp.rect.width = width;
temp.rect.height = height;
temp.label = class_id;
temp.prob = prob;
temp.object_id = object_id;
objects.push_back(temp);
}
//2. Do tracking
const auto outputs = tracker.update(objects);
if (outputs.size() > 0) {
for (int i = 0; i < outputs.size(); i++) {
TrackerObject trackObj;
trackObj.track_id = outputs[i].track_idx;
trackObj.class_id = outputs[i].obj.label;
trackObj.prob = outputs[i].obj.prob;
trackObj.x = outputs[i].obj.rect.x;
trackObj.y = outputs[i].obj.rect.y;
trackObj.width = outputs[i].obj.rect.width;
trackObj.height = outputs[i].obj.rect.height;
trackObj.left = outputs[i].obj.rect.x;
trackObj.top = outputs[i].obj.rect.y;
trackObj.right = outputs[i].obj.rect.x + outputs[i].obj.rect.width;
trackObj.bottom = outputs[i].obj.rect.y + outputs[i].obj.rect.height;
trackObj.object_id = outputs[i].obj.object_id;
trackingResults.push_back(trackObj);
}
}
return trackingResults;
}
catch (const std::exception& e) {
this->_logger.LogFatal("ANSUCMCTrack::UpdateTracker", e.what(), __FILE__, __LINE__);
return std::vector<TrackerObject>();
}
}
}