Files
ANSCORE/modules/ANSMOT/ANSByteTrackEigen.cpp

239 lines
10 KiB
C++

#include "ANSByteTrackEigen.h"
#include "boost/property_tree/ptree.hpp"
#include "boost/property_tree/json_parser.hpp"
#include "boost/foreach.hpp"
#include "boost/optional.hpp"
#include <map>
namespace ANSCENTER {
/// Converts a JSON detection list into the row-per-detection matrix handed to the tracker.
///
/// Every child of the top-level "results" node must provide "x", "y", "width",
/// "height", "prob", "class_id" and "object_id". Each child becomes one row laid
/// out as [x_min, y_min, x_max, y_max, prob, class_id, 0].
///
/// @param jsonString  NUL-terminated JSON text. A null pointer is handled like
///                    unparsable input.
/// @param object_ids  Output. Cleared first, then filled with exactly one id per
///                    returned row, in row order.
/// @return A matrix with 7 columns. When the JSON cannot be parsed or a detection
///         lacks a field, the rows converted so far are kept and one all-zero
///         sentinel row (paired with an empty object id) is appended.
Eigen::MatrixXf GetByteTrackEigenInputs(const char* jsonString, std::vector<std::string>&object_ids) {
    Eigen::MatrixXf ret(0, 7); // 0 rows and 7 columns
    object_ids.clear();
    try {
        std::stringstream ss;
        if (jsonString != nullptr) {
            ss << jsonString;
        }
        boost::property_tree::ptree pt;
        boost::property_tree::read_json(ss, pt);

        // Bind by reference: copying the subtree would duplicate every detection node.
        const boost::property_tree::ptree& detections = pt.get_child("results");
        object_ids.reserve(detections.size());

        for (const auto& detection : detections) {
            const boost::property_tree::ptree& node = detection.second;

            // Read every field before touching the outputs so that a missing
            // field never leaves a half-written row behind.
            const float xMin = node.get<float>("x");
            const float yMin = node.get<float>("y");
            const float width = node.get<float>("width");
            const float height = node.get<float>("height");
            const float conf = node.get<float>("prob");
            const int classId = node.get<int>("class_id");
            const std::string object_id = node.get<std::string>("object_id");

            ret.conservativeResize(ret.rows() + 1, Eigen::NoChange);
            // The comma initializer must receive one value per column; the 7th
            // column is written as 0 so it is never left indeterminate.
            // NOTE(review): what the tracker reads from column 7 is not visible
            // here - 0 matches the sentinel row below; confirm against tracker.update.
            ret.row(ret.rows() - 1) << xMin, yMin, xMin + width, yMin + height,
                                       conf, static_cast<float>(classId), 0.0f;
            object_ids.push_back(object_id);
        }
        return ret;
    }
    catch (const std::exception&) {
        // Best-effort fallback: keep what was parsed and append a zero sentinel row.
        ret.conservativeResize(ret.rows() + 1, Eigen::NoChange);
        ret.row(ret.rows() - 1).setZero();
        // Keep object_ids the same length as the matrix has rows.
        object_ids.emplace_back();
        return ret;
    }
}
/// Constructs the tracker wrapper and applies the built-in default tracker settings.
ANSByteTrackEigen::ANSByteTrackEigen()
{
    // Mark the license as invalid before running the license check.
    _licenseValid = false;
    CheckLicense();

    // Built-in defaults passed to the underlying tracker.
    int defaultFrameRate = 30;
    int defaultTrackBuffer = 30;
    float defaultTrackThreshold = 0.25f;
    float defaultHighThreshold = 0.25f;
    float defaultMatchThreshold = 0.8f;
    tracker.update_parameters(defaultFrameRate,
                              defaultTrackBuffer,
                              defaultTrackThreshold,
                              defaultHighThreshold,
                              defaultMatchThreshold);
}
/// Destructor: performs no explicit work of its own.
ANSByteTrackEigen::~ANSByteTrackEigen() = default;
/// Tear-down hook. No work is performed here.
/// @return Always true.
bool ANSByteTrackEigen::Destroy()
{
    return true;
}
bool ANSByteTrackEigen::UpdateParameters(const std::string& trackerParameters) {
//// Use JSON Boost to parse paramter from trackerParameters
try {
int frameRate, trackBuffer;
double trackThreshold, highThreshold, matchThresold;
bool autoFrameRate;
frameRate = 15;
trackBuffer = 300;
trackThreshold = 0.5;
highThreshold = 0.1;
matchThresold = 0.95;
autoFrameRate = true;
if (!trackerParameters.empty()) {
std::stringstream ss;
ss << trackerParameters;
boost::property_tree::ptree pt;
boost::property_tree::read_json(ss, pt);
auto rootNode = pt.get_child("parameters");
auto childNode = rootNode.get_child("frame_rate");
frameRate = childNode.get_value<int>();
childNode = rootNode.get_child("track_buffer");
trackBuffer = childNode.get_value<int>();
childNode = rootNode.get_child("track_threshold");
trackThreshold = childNode.get_value<float>();
childNode = rootNode.get_child("high_threshold");
highThreshold = childNode.get_value<float>();
childNode = rootNode.get_child("match_thresold");
matchThresold = childNode.get_value<float>();
// Optional: auto frame rate estimation
if (auto optNode = rootNode.get_child_optional("auto_frame_rate")) {
autoFrameRate = optNode->get_value<int>() != 0;
}
}
tracker.update_parameters(frameRate, trackBuffer, trackThreshold, highThreshold, matchThresold, autoFrameRate);
return true;
}
catch (const std::exception& e) {
this->_logger.LogFatal("ANSByteTrackEigen::UpdateParameters", e.what(), __FILE__, __LINE__);
return false;
}
}
/// Runs one tracking step on JSON-encoded detections and returns the tracks as JSON.
///
/// @param modelId        Echoed into every result entry as "model_id".
/// @param detectionData  JSON detection list, parsed by GetByteTrackEigenInputs.
/// @return Compact JSON of the form {"results": [...]}; each entry carries
///         model_id, track_id, class_id, prob, x, y, width, height, left, top,
///         right, bottom, valid and object_id. Returns an empty string when the
///         license is invalid or an exception is raised (both are logged).
std::string ANSByteTrackEigen::Update(int modelId, const std::string& detectionData) {
    if (!_licenseValid) {
        this->_logger.LogFatal("ANSByteTrackEigen::Update", "Invalid license", __FILE__, __LINE__);
        return "";
    }
    try {
        // 1. Convert the JSON detections into the tracker's matrix input.
        std::vector<std::string> obj_ids;
        Eigen::MatrixXf objects = GetByteTrackEigenInputs(detectionData.c_str(), obj_ids);

        // 2. Do tracking.
        const auto outputs = tracker.update(objects, obj_ids);

        // 3. Convert each track into a JSON node.
        boost::property_tree::ptree trackingObjects;
        for (const auto& track : outputs) {
            // _tlwh holds top-left x, top-left y, width, height.
            // The corner values are computed in double, as before.
            const double left = track._tlwh[0];
            const double top = track._tlwh[1];
            const double boxWidth = track._tlwh[2];
            const double boxHeight = track._tlwh[3];
            const double right = left + boxWidth;
            const double bottom = top + boxHeight;
            const bool is_valid = track.get_is_activated();

            boost::property_tree::ptree trackingNode;
            trackingNode.put("model_id", modelId);
            trackingNode.put("track_id", track.get_track_id());
            // Report the track's own class and object ids (previously the track
            // id and the frame id were written into these two fields).
            trackingNode.put("class_id", track.class_id);
            trackingNode.put("prob", track.get_score());
            trackingNode.put("x", track._tlwh[0]);
            trackingNode.put("y", track._tlwh[1]);
            trackingNode.put("width", track._tlwh[2]);
            trackingNode.put("height", track._tlwh[3]);
            trackingNode.put("left", left);
            trackingNode.put("top", top);
            trackingNode.put("right", right);
            trackingNode.put("bottom", bottom);
            trackingNode.put("valid", is_valid);
            trackingNode.put("object_id", track.object_id);
            trackingObjects.push_back(std::make_pair("", trackingNode));
        }

        // 4. Serialise without pretty-printing.
        boost::property_tree::ptree root;
        root.add_child("results", trackingObjects);
        std::ostringstream stream;
        boost::property_tree::write_json(stream, root, false);
        return stream.str();
    }
    catch (const std::exception& e) {
        this->_logger.LogFatal("ANSByteTrackEigen::Update", e.what(), __FILE__, __LINE__);
        return "";
    }
}
std::vector<TrackerObject> ANSByteTrackEigen::UpdateTracker(int modelId, const std::vector<TrackerObject>& detectionObjects) {
if (!_licenseValid) {
this->_logger.LogFatal("ANSByteTrackNCNN::UpdateTracker", "Invalid license", __FILE__, __LINE__);
return std::vector<TrackerObject>();
}
try {
std::vector<TrackerObject> trackingResults;
Eigen::MatrixXf ret(0, 7); // 0 row and 7 columns
std::vector<std::string> object_ids;
for (const auto& detection : detectionObjects) {
float xMin = detection.x;
float yMin = detection.y;
float width = detection.width;
float height = detection.height;
float conf = detection.prob;
int classId = detection.class_id;
std::string object_id = detection.object_id;
double xMax = xMin + width;
double yMax = yMin + height;
Eigen::MatrixXf newRow(1, ret.cols());
newRow << xMin, yMin, xMax, yMax, conf, classId;
ret.conservativeResize(ret.rows() + 1, Eigen::NoChange);
ret.row(ret.rows() - 1) = newRow;
object_ids.push_back(object_id);
}
//2. Do tracking
const auto outputs = tracker.update(ret, object_ids);
if (outputs.size() > 0) {
for (int i = 0; i < outputs.size(); i++) {
TrackerObject trackObj;
double x = outputs[i]._tlwh[0]; //top left x
double y = outputs[i]._tlwh[1]; //top left y
double w = outputs[i]._tlwh[2]; //width
double h = outputs[i]._tlwh[3]; //height
double left = x;
double top = y;
double right = x + w;
double bottom = y + h;
std::string object_id = outputs[i].object_id; // Should be object id
bool is_valid = outputs[i].get_is_activated();
trackObj.track_id = outputs[i].get_track_id();
trackObj.class_id = outputs[i].class_id;
trackObj.prob = outputs[i].get_score();
trackObj.x = x;
trackObj.y = y;
trackObj.width = w;
trackObj.height = h;
trackObj.left =x;
trackObj.top = y;
trackObj.right = right;
trackObj.bottom = bottom;
trackObj.object_id = object_id;
trackingResults.push_back(trackObj);
}
}
return trackingResults;
}
catch (const std::exception& e) {
this->_logger.LogFatal("ANSByteTrackNCNN::UpdateTracker", e.what(), __FILE__, __LINE__);
return std::vector<TrackerObject>();
}
}
}