// Source listing metadata: 197 lines, 8.2 KiB, C++
#include "ANSByteTrack.h"

#include <map>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "boost/property_tree/ptree.hpp"
#include "boost/property_tree/json_parser.hpp"
#include "boost/foreach.hpp"
#include "boost/optional.hpp"
|
|
namespace ANSCENTER {
|
|
|
|
template <typename T>
|
|
T get_data(const boost::property_tree::ptree& pt, const std::string& key)
|
|
{
|
|
T ret;
|
|
if (boost::optional<T> data = pt.get_optional<T>(key))
|
|
{
|
|
ret = data.get();
|
|
}
|
|
return ret;
|
|
}
|
|
/**
 * @brief Converts the "results" array of a parsed detection document into
 *        ByteTrack input objects.
 *
 * Each element of "results" is read for: class_id, prob, x, y, width, height,
 * left, top, right, bottom and object_id. Absent fields fall back to whatever
 * get_data() returns for a missing key.
 *
 * @param pt Parsed JSON root.
 * @return One ByteTrack::Object per element of "results", in document order.
 * @throws boost::property_tree::ptree_bad_path if "results" is missing
 *         (a std::exception; callers catch and log it).
 */
std::vector<ByteTrack::Object> GetByteTrackInputs(const boost::property_tree::ptree& pt)
{
    const boost::property_tree::ptree& results = pt.get_child("results");

    std::vector<ByteTrack::Object> inputs_ref;
    inputs_ref.reserve(results.size());

    for (const auto& child : results)
    {
        const boost::property_tree::ptree& result = child.second;

        const auto class_id = get_data<int>(result, "class_id");
        const auto prob = get_data<float>(result, "prob");
        const auto x = get_data<float>(result, "x");
        const auto y = get_data<float>(result, "y");
        const auto width = get_data<float>(result, "width");
        const auto height = get_data<float>(result, "height");
        const auto left = get_data<float>(result, "left");
        const auto top = get_data<float>(result, "top");
        const auto right = get_data<float>(result, "right");
        const auto bottom = get_data<float>(result, "bottom");
        const auto object_id = get_data<std::string>(result, "object_id");

        // Push the temporary directly so it is moved, not copied, into place.
        inputs_ref.push_back(ByteTrack::Object{ ByteTrack::Rect(x, y, width, height),
            class_id, prob, left, top, right, bottom, object_id });
    }

    return inputs_ref;
}
|
|
/**
 * @brief Constructs the wrapper: runs the license check, then applies the
 *        built-in tracker parameters.
 */
ANSByteTrack::ANSByteTrack() {
    // Start as "not licensed"; presumably CheckLicense() updates the flag.
    _licenseValid = false;
    CheckLicense();

    // Argument order matches the call in UpdateParameters:
    // frame rate, track buffer, track threshold, high threshold, match threshold.
    // NOTE(review): these differ from the fallbacks UpdateParameters("") uses
    // (15, 300, 0.5, 0.1, 0.95) — confirm which set is intended.
    const int initialFrameRate = 30;
    const int initialTrackBuffer = 30;
    const double initialTrackThreshold = 0.5;
    const double initialHighThreshold = 0.6;
    const double initialMatchThreshold = 0.8;
    tracker.update_parameters(initialFrameRate, initialTrackBuffer,
                              initialTrackThreshold, initialHighThreshold,
                              initialMatchThreshold);
}
|
|
|
|
/// Destructor. Intentionally empty: no explicit cleanup is performed here;
/// members are torn down by their own destructors.
ANSByteTrack::~ANSByteTrack() {}
|
|
/// Explicit teardown hook. Performs no work at present and always reports success.
bool ANSByteTrack::Destroy() {
    const bool succeeded = true;
    return succeeded;
}
|
|
|
|
bool ANSByteTrack::UpdateParameters(const std::string& trackerParameters) {
|
|
// Use JSON Boost to parse paramter from trackerParameters
|
|
try {
|
|
int frameRate, trackBuffer;
|
|
double trackThreshold, highThreshold, matchThresold;
|
|
bool autoFrameRate;
|
|
frameRate = 15;
|
|
trackBuffer = 300;
|
|
trackThreshold = 0.5;
|
|
highThreshold = 0.1;
|
|
matchThresold = 0.95;
|
|
autoFrameRate = true;
|
|
if (!trackerParameters.empty()) {
|
|
std::stringstream ss;
|
|
ss << trackerParameters;
|
|
boost::property_tree::ptree pt;
|
|
boost::property_tree::read_json(ss, pt);
|
|
auto rootNode = pt.get_child("parameters");
|
|
|
|
auto childNode = rootNode.get_child("frame_rate");
|
|
frameRate = childNode.get_value<int>();
|
|
|
|
childNode = rootNode.get_child("track_buffer");
|
|
trackBuffer = childNode.get_value<int>();
|
|
|
|
childNode = rootNode.get_child("track_threshold");
|
|
trackThreshold = childNode.get_value<float>();
|
|
|
|
childNode = rootNode.get_child("high_threshold");
|
|
highThreshold = childNode.get_value<float>();
|
|
|
|
childNode = rootNode.get_child("match_thresold");
|
|
matchThresold = childNode.get_value<float>();
|
|
|
|
// Optional: auto frame rate estimation
|
|
if (auto optNode = rootNode.get_child_optional("auto_frame_rate")) {
|
|
autoFrameRate = optNode->get_value<int>() != 0;
|
|
}
|
|
}
|
|
tracker.update_parameters(frameRate, trackBuffer, trackThreshold, highThreshold, matchThresold, autoFrameRate);
|
|
return true;
|
|
}
|
|
|
|
catch (const std::exception& e) {
|
|
this->_logger.LogFatal("ANSByteTrack::UpdateParameters", e.what(), __FILE__, __LINE__);
|
|
return false;
|
|
}
|
|
}
|
|
|
|
std::string ANSByteTrack::Update(int modelId, const std::string& detectionData) {
|
|
if (!_licenseValid) {
|
|
this->_logger.LogFatal("ANSByteTrack::Update", "Invalid license", __FILE__, __LINE__);
|
|
return "";
|
|
}
|
|
try {
|
|
boost::property_tree::ptree root;
|
|
boost::property_tree::ptree trackingObjects;
|
|
boost::property_tree::ptree pt;
|
|
//1. Get input
|
|
std::stringstream ss;
|
|
ss << detectionData;
|
|
boost::property_tree::read_json(ss, pt);
|
|
std::vector<ByteTrack::Object> objects = GetByteTrackInputs(pt);
|
|
//2. Do tracking
|
|
const auto outputs = tracker.update(objects);
|
|
if (outputs.size() > 0) {
|
|
for (int i = 0; i < outputs.size(); i++) {
|
|
boost::property_tree::ptree trackingNode;
|
|
trackingNode.put("model_id", modelId);
|
|
trackingNode.put("track_id", outputs[i]->getTrackId());
|
|
trackingNode.put("class_id", outputs[i]->class_id);
|
|
trackingNode.put("prob", outputs[i]->getScore());
|
|
trackingNode.put("x", outputs[i]->getRect().tlwh[0]);
|
|
trackingNode.put("y", outputs[i]->getRect().tlwh[1]);
|
|
trackingNode.put("width", outputs[i]->getRect().tlwh[2]);
|
|
trackingNode.put("height", outputs[i]->getRect().tlwh[3]);
|
|
trackingNode.put("left", outputs[i]->left);
|
|
trackingNode.put("top", outputs[i]->top);
|
|
trackingNode.put("right", outputs[i]->right);
|
|
trackingNode.put("bottom", outputs[i]->bottom);
|
|
trackingNode.put("valid", outputs[i]->isActivated());
|
|
trackingNode.put("object_id", outputs[i]->object_id);
|
|
// Add this node to the list.
|
|
trackingObjects.push_back(std::make_pair("", trackingNode));
|
|
}
|
|
}
|
|
//3. Convert result
|
|
root.add_child("results", trackingObjects);
|
|
std::ostringstream stream;
|
|
boost::property_tree::write_json(stream, root,false);
|
|
std::string st = stream.str();
|
|
return st;
|
|
}
|
|
catch (const std::exception& e) {
|
|
this->_logger.LogFatal("ANSByteTrack::Update", e.what(), __FILE__, __LINE__);
|
|
return "";
|
|
}
|
|
}
|
|
|
|
std::vector<TrackerObject> ANSByteTrack::UpdateTracker(int modelId, const std::vector<TrackerObject>& detectionObjects) {
|
|
if (!_licenseValid) {
|
|
this->_logger.LogFatal("ANSByteTrack::UpdateTracker", "Invalid license", __FILE__, __LINE__);
|
|
return std::vector<TrackerObject>();
|
|
}
|
|
try {
|
|
std::vector<TrackerObject> trackingResults;
|
|
std::vector<ByteTrack::Object> objects;
|
|
for (const auto& detObj : detectionObjects) {
|
|
ByteTrack::Object temp{ ByteTrack::Rect(detObj.x, detObj.y, detObj.width, detObj.height),
|
|
detObj.class_id,detObj.prob,detObj.left,detObj.top,detObj.right,detObj.bottom,detObj.object_id };
|
|
objects.push_back(temp);
|
|
}
|
|
//2. Do tracking
|
|
const auto outputs = tracker.update(objects);
|
|
if (outputs.size() > 0) {
|
|
for (int i = 0; i < outputs.size(); i++) {
|
|
TrackerObject trackObj;
|
|
trackObj.track_id = outputs[i]->getTrackId();
|
|
trackObj.class_id = outputs[i]->class_id;
|
|
trackObj.prob = outputs[i]->getScore();
|
|
trackObj.x = outputs[i]->getRect().tlwh[0];
|
|
trackObj.y = outputs[i]->getRect().tlwh[1];
|
|
trackObj.width = outputs[i]->getRect().tlwh[2];
|
|
trackObj.height = outputs[i]->getRect().tlwh[3];
|
|
trackObj.left = outputs[i]->left;
|
|
trackObj.top = outputs[i]->top;
|
|
trackObj.right = outputs[i]->right;
|
|
trackObj.bottom = outputs[i]->bottom;
|
|
trackObj.object_id = outputs[i]->object_id;
|
|
trackingResults.push_back(trackObj);
|
|
}
|
|
}
|
|
return trackingResults;
|
|
}
|
|
catch (const std::exception& e) {
|
|
this->_logger.LogFatal("ANSByteTrack::UpdateTracker", e.what(), __FILE__, __LINE__);
|
|
return std::vector<TrackerObject>();
|
|
}
|
|
|
|
}
|
|
}
|