Refactor project structure
This commit is contained in:
196
modules/ANSMOT/ANSByteTrack.cpp
Normal file
196
modules/ANSMOT/ANSByteTrack.cpp
Normal file
@@ -0,0 +1,196 @@
|
||||
#include "ANSByteTrack.h"
|
||||
#include "boost/property_tree/ptree.hpp"
|
||||
#include "boost/property_tree/json_parser.hpp"
|
||||
#include "boost/foreach.hpp"
|
||||
#include "boost/optional.hpp"
|
||||
#include <map>
|
||||
namespace ANSCENTER {
|
||||
|
||||
template <typename T>
|
||||
T get_data(const boost::property_tree::ptree& pt, const std::string& key)
|
||||
{
|
||||
T ret;
|
||||
if (boost::optional<T> data = pt.get_optional<T>(key))
|
||||
{
|
||||
ret = data.get();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
std::vector<ByteTrack::Object> GetByteTrackInputs(const boost::property_tree::ptree& pt)
|
||||
{
|
||||
std::vector<ByteTrack::Object> inputs_ref;
|
||||
inputs_ref.clear();
|
||||
BOOST_FOREACH(const boost::property_tree::ptree::value_type & child, pt.get_child("results"))
|
||||
{
|
||||
const boost::property_tree::ptree& result = child.second;
|
||||
const auto class_id = get_data<int>(result, "class_id");
|
||||
const auto prob = get_data<float>(result, "prob");
|
||||
const auto x = get_data<float>(result, "x");
|
||||
const auto y = get_data<float>(result, "y");
|
||||
const auto width = get_data<float>(result, "width");
|
||||
const auto height = get_data<float>(result, "height");
|
||||
const auto left = get_data<float>(result, "left");
|
||||
const auto top = get_data<float>(result, "top");
|
||||
const auto right = get_data<float>(result, "right");
|
||||
const auto bottom = get_data<float>(result, "bottom");
|
||||
const auto object_id = get_data<std::string>(result, "object_id");
|
||||
ByteTrack::Object temp{ ByteTrack::Rect(x, y, width, height),
|
||||
class_id,prob,left,top,right,bottom,object_id };
|
||||
inputs_ref.push_back(temp);
|
||||
}
|
||||
return inputs_ref;
|
||||
}
|
||||
ANSByteTrack::ANSByteTrack() {
|
||||
_licenseValid = false;
|
||||
CheckLicense();
|
||||
tracker.update_parameters(30, 30, 0.5, 0.6, 0.8);
|
||||
}
|
||||
|
||||
ANSByteTrack::~ANSByteTrack() {
|
||||
|
||||
}
|
||||
bool ANSByteTrack::Destroy() {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ANSByteTrack::UpdateParameters(const std::string& trackerParameters) {
|
||||
// Use JSON Boost to parse paramter from trackerParameters
|
||||
try {
|
||||
int frameRate, trackBuffer;
|
||||
double trackThreshold, highThreshold, matchThresold;
|
||||
bool autoFrameRate;
|
||||
frameRate = 15;
|
||||
trackBuffer = 300;
|
||||
trackThreshold = 0.5;
|
||||
highThreshold = 0.1;
|
||||
matchThresold = 0.95;
|
||||
autoFrameRate = true;
|
||||
if (!trackerParameters.empty()) {
|
||||
std::stringstream ss;
|
||||
ss << trackerParameters;
|
||||
boost::property_tree::ptree pt;
|
||||
boost::property_tree::read_json(ss, pt);
|
||||
auto rootNode = pt.get_child("parameters");
|
||||
|
||||
auto childNode = rootNode.get_child("frame_rate");
|
||||
frameRate = childNode.get_value<int>();
|
||||
|
||||
childNode = rootNode.get_child("track_buffer");
|
||||
trackBuffer = childNode.get_value<int>();
|
||||
|
||||
childNode = rootNode.get_child("track_threshold");
|
||||
trackThreshold = childNode.get_value<float>();
|
||||
|
||||
childNode = rootNode.get_child("high_threshold");
|
||||
highThreshold = childNode.get_value<float>();
|
||||
|
||||
childNode = rootNode.get_child("match_thresold");
|
||||
matchThresold = childNode.get_value<float>();
|
||||
|
||||
// Optional: auto frame rate estimation
|
||||
if (auto optNode = rootNode.get_child_optional("auto_frame_rate")) {
|
||||
autoFrameRate = optNode->get_value<int>() != 0;
|
||||
}
|
||||
}
|
||||
tracker.update_parameters(frameRate, trackBuffer, trackThreshold, highThreshold, matchThresold, autoFrameRate);
|
||||
return true;
|
||||
}
|
||||
|
||||
catch (const std::exception& e) {
|
||||
this->_logger.LogFatal("ANSByteTrack::UpdateParameters", e.what(), __FILE__, __LINE__);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
std::string ANSByteTrack::Update(int modelId, const std::string& detectionData) {
|
||||
if (!_licenseValid) {
|
||||
this->_logger.LogFatal("ANSByteTrack::Update", "Invalid license", __FILE__, __LINE__);
|
||||
return "";
|
||||
}
|
||||
try {
|
||||
boost::property_tree::ptree root;
|
||||
boost::property_tree::ptree trackingObjects;
|
||||
boost::property_tree::ptree pt;
|
||||
//1. Get input
|
||||
std::stringstream ss;
|
||||
ss << detectionData;
|
||||
boost::property_tree::read_json(ss, pt);
|
||||
std::vector<ByteTrack::Object> objects = GetByteTrackInputs(pt);
|
||||
//2. Do tracking
|
||||
const auto outputs = tracker.update(objects);
|
||||
if (outputs.size() > 0) {
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
boost::property_tree::ptree trackingNode;
|
||||
trackingNode.put("model_id", modelId);
|
||||
trackingNode.put("track_id", outputs[i]->getTrackId());
|
||||
trackingNode.put("class_id", outputs[i]->class_id);
|
||||
trackingNode.put("prob", outputs[i]->getScore());
|
||||
trackingNode.put("x", outputs[i]->getRect().tlwh[0]);
|
||||
trackingNode.put("y", outputs[i]->getRect().tlwh[1]);
|
||||
trackingNode.put("width", outputs[i]->getRect().tlwh[2]);
|
||||
trackingNode.put("height", outputs[i]->getRect().tlwh[3]);
|
||||
trackingNode.put("left", outputs[i]->left);
|
||||
trackingNode.put("top", outputs[i]->top);
|
||||
trackingNode.put("right", outputs[i]->right);
|
||||
trackingNode.put("bottom", outputs[i]->bottom);
|
||||
trackingNode.put("valid", outputs[i]->isActivated());
|
||||
trackingNode.put("object_id", outputs[i]->object_id);
|
||||
// Add this node to the list.
|
||||
trackingObjects.push_back(std::make_pair("", trackingNode));
|
||||
}
|
||||
}
|
||||
//3. Convert result
|
||||
root.add_child("results", trackingObjects);
|
||||
std::ostringstream stream;
|
||||
boost::property_tree::write_json(stream, root,false);
|
||||
std::string st = stream.str();
|
||||
return st;
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
this->_logger.LogFatal("ANSByteTrack::Update", e.what(), __FILE__, __LINE__);
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<TrackerObject> ANSByteTrack::UpdateTracker(int modelId, const std::vector<TrackerObject>& detectionObjects) {
|
||||
if (!_licenseValid) {
|
||||
this->_logger.LogFatal("ANSByteTrack::UpdateTracker", "Invalid license", __FILE__, __LINE__);
|
||||
return std::vector<TrackerObject>();
|
||||
}
|
||||
try {
|
||||
std::vector<TrackerObject> trackingResults;
|
||||
std::vector<ByteTrack::Object> objects;
|
||||
for (const auto& detObj : detectionObjects) {
|
||||
ByteTrack::Object temp{ ByteTrack::Rect(detObj.x, detObj.y, detObj.width, detObj.height),
|
||||
detObj.class_id,detObj.prob,detObj.left,detObj.top,detObj.right,detObj.bottom,detObj.object_id };
|
||||
objects.push_back(temp);
|
||||
}
|
||||
//2. Do tracking
|
||||
const auto outputs = tracker.update(objects);
|
||||
if (outputs.size() > 0) {
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
TrackerObject trackObj;
|
||||
trackObj.track_id = outputs[i]->getTrackId();
|
||||
trackObj.class_id = outputs[i]->class_id;
|
||||
trackObj.prob = outputs[i]->getScore();
|
||||
trackObj.x = outputs[i]->getRect().tlwh[0];
|
||||
trackObj.y = outputs[i]->getRect().tlwh[1];
|
||||
trackObj.width = outputs[i]->getRect().tlwh[2];
|
||||
trackObj.height = outputs[i]->getRect().tlwh[3];
|
||||
trackObj.left = outputs[i]->left;
|
||||
trackObj.top = outputs[i]->top;
|
||||
trackObj.right = outputs[i]->right;
|
||||
trackObj.bottom = outputs[i]->bottom;
|
||||
trackObj.object_id = outputs[i]->object_id;
|
||||
trackingResults.push_back(trackObj);
|
||||
}
|
||||
}
|
||||
return trackingResults;
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
this->_logger.LogFatal("ANSByteTrack::UpdateTracker", e.what(), __FILE__, __LINE__);
|
||||
return std::vector<TrackerObject>();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user