Refactor project structure
This commit is contained in:
238
modules/ANSMOT/ANSByteTrackEigen.cpp
Normal file
238
modules/ANSMOT/ANSByteTrackEigen.cpp
Normal file
@@ -0,0 +1,238 @@
|
||||
#include "ANSByteTrackEigen.h"
|
||||
#include "boost/property_tree/ptree.hpp"
|
||||
#include "boost/property_tree/json_parser.hpp"
|
||||
#include "boost/foreach.hpp"
|
||||
#include "boost/optional.hpp"
|
||||
#include <map>
|
||||
namespace ANSCENTER {
|
||||
|
||||
Eigen::MatrixXf GetByteTrackEigenInputs(const char* jsonString, std::vector<std::string>&object_ids) {
|
||||
Eigen::MatrixXf ret(0, 7); // 0 row and 7 columns
|
||||
try {
|
||||
boost::property_tree::ptree pt;
|
||||
std::stringstream ss;
|
||||
ss << jsonString;
|
||||
object_ids.clear();
|
||||
boost::property_tree::read_json(ss, pt);
|
||||
boost::property_tree::ptree detections = pt.get_child("results");
|
||||
if (detections.size() > 0) {
|
||||
for (const auto& detection : detections) {
|
||||
float xMin = detection.second.get<float>("x");
|
||||
float yMin = detection.second.get<float>("y");
|
||||
float width = detection.second.get<float>("width");
|
||||
float height = detection.second.get<float>("height");
|
||||
float conf = detection.second.get<float>("prob");
|
||||
int classId = detection.second.get<int>("class_id");
|
||||
std::string object_id = detection.second.get<std::string>("object_id");
|
||||
double xMax = xMin + width;
|
||||
double yMax = yMin + height;
|
||||
Eigen::MatrixXf newRow(1, ret.cols());
|
||||
newRow << xMin, yMin, xMax, yMax, conf, classId;
|
||||
ret.conservativeResize(ret.rows() + 1, Eigen::NoChange);
|
||||
ret.row(ret.rows() - 1) = newRow;
|
||||
object_ids.push_back(object_id);
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
Eigen::MatrixXf newRow(1, ret.cols());
|
||||
newRow << 0, 0, 0, 0, 0, 0, 0;
|
||||
ret.conservativeResize(ret.rows() + 1, Eigen::NoChange);
|
||||
ret.row(ret.rows() - 1) = newRow;
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
ANSByteTrackEigen::ANSByteTrackEigen()
|
||||
{
|
||||
_licenseValid = false;
|
||||
CheckLicense();
|
||||
|
||||
int frame_rate = 30;
|
||||
int track_buffer = 30;
|
||||
float track_thresh = 0.25;
|
||||
float track_highthres = 0.25;
|
||||
float match_thresh = 0.8;
|
||||
|
||||
tracker.update_parameters(frame_rate, track_buffer, track_thresh, track_highthres, match_thresh);
|
||||
}
|
||||
|
||||
ANSByteTrackEigen::~ANSByteTrackEigen() {
|
||||
|
||||
}
|
||||
bool ANSByteTrackEigen::Destroy() {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ANSByteTrackEigen::UpdateParameters(const std::string& trackerParameters) {
|
||||
//// Use JSON Boost to parse paramter from trackerParameters
|
||||
try {
|
||||
int frameRate, trackBuffer;
|
||||
double trackThreshold, highThreshold, matchThresold;
|
||||
bool autoFrameRate;
|
||||
frameRate = 15;
|
||||
trackBuffer = 300;
|
||||
trackThreshold = 0.5;
|
||||
highThreshold = 0.1;
|
||||
matchThresold = 0.95;
|
||||
autoFrameRate = true;
|
||||
if (!trackerParameters.empty()) {
|
||||
std::stringstream ss;
|
||||
ss << trackerParameters;
|
||||
boost::property_tree::ptree pt;
|
||||
boost::property_tree::read_json(ss, pt);
|
||||
auto rootNode = pt.get_child("parameters");
|
||||
|
||||
auto childNode = rootNode.get_child("frame_rate");
|
||||
frameRate = childNode.get_value<int>();
|
||||
|
||||
childNode = rootNode.get_child("track_buffer");
|
||||
trackBuffer = childNode.get_value<int>();
|
||||
|
||||
childNode = rootNode.get_child("track_threshold");
|
||||
trackThreshold = childNode.get_value<float>();
|
||||
|
||||
childNode = rootNode.get_child("high_threshold");
|
||||
highThreshold = childNode.get_value<float>();
|
||||
|
||||
childNode = rootNode.get_child("match_thresold");
|
||||
matchThresold = childNode.get_value<float>();
|
||||
|
||||
// Optional: auto frame rate estimation
|
||||
if (auto optNode = rootNode.get_child_optional("auto_frame_rate")) {
|
||||
autoFrameRate = optNode->get_value<int>() != 0;
|
||||
}
|
||||
}
|
||||
tracker.update_parameters(frameRate, trackBuffer, trackThreshold, highThreshold, matchThresold, autoFrameRate);
|
||||
return true;
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
this->_logger.LogFatal("ANSByteTrackEigen::UpdateParameters", e.what(), __FILE__, __LINE__);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
std::string ANSByteTrackEigen::Update(int modelId, const std::string& detectionData) {
|
||||
if (!_licenseValid) {
|
||||
this->_logger.LogFatal("ANSByteTrackEigen::Update", "Invalid license", __FILE__, __LINE__);
|
||||
return "";
|
||||
}
|
||||
try {
|
||||
boost::property_tree::ptree root;
|
||||
boost::property_tree::ptree trackingObjects;
|
||||
boost::property_tree::ptree pt;
|
||||
std::vector<std::string>obj_ids;
|
||||
Eigen::MatrixXf objects = GetByteTrackEigenInputs(detectionData.c_str(), obj_ids);
|
||||
////2. Do tracking
|
||||
const auto outputs = tracker.update(objects, obj_ids);
|
||||
if (outputs.size() > 0) {
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
double x = outputs[i]._tlwh[0]; //top left x
|
||||
double y = outputs[i]._tlwh[1]; //top left y
|
||||
double w = outputs[i]._tlwh[2]; //width
|
||||
double h = outputs[i]._tlwh[3]; //height
|
||||
double left = x;
|
||||
double top = y;
|
||||
double right = x + w;
|
||||
double bottom = y + h;
|
||||
int object_id = outputs[i].get_frame_id(); // Should be object id
|
||||
bool is_valid = outputs[i].get_is_activated();
|
||||
boost::property_tree::ptree trackingNode;
|
||||
trackingNode.put("model_id", modelId);
|
||||
trackingNode.put("track_id", outputs[i].get_track_id());
|
||||
trackingNode.put("class_id", outputs[i].get_track_id());//It should be class id
|
||||
trackingNode.put("prob", outputs[i].get_score());
|
||||
trackingNode.put("x", outputs[i]._tlwh[0]);
|
||||
trackingNode.put("y", outputs[i]._tlwh[1]);
|
||||
trackingNode.put("width", outputs[i]._tlwh[2]);
|
||||
trackingNode.put("height", outputs[i]._tlwh[3]);
|
||||
trackingNode.put("left", left);
|
||||
trackingNode.put("top", top);
|
||||
trackingNode.put("right", right);
|
||||
trackingNode.put("bottom", bottom);
|
||||
trackingNode.put("valid", is_valid);
|
||||
trackingNode.put("object_id", object_id);
|
||||
trackingObjects.push_back(std::make_pair("", trackingNode));
|
||||
}
|
||||
}
|
||||
//3. Convert result
|
||||
root.add_child("results", trackingObjects);
|
||||
std::ostringstream stream;
|
||||
boost::property_tree::write_json(stream, root,false);
|
||||
std::string st = stream.str();
|
||||
return st;
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
this->_logger.LogFatal("ANSByteTrackEigen::Update", e.what(), __FILE__, __LINE__);
|
||||
return "";
|
||||
}
|
||||
}
|
||||
std::vector<TrackerObject> ANSByteTrackEigen::UpdateTracker(int modelId, const std::vector<TrackerObject>& detectionObjects) {
|
||||
if (!_licenseValid) {
|
||||
this->_logger.LogFatal("ANSByteTrackNCNN::UpdateTracker", "Invalid license", __FILE__, __LINE__);
|
||||
return std::vector<TrackerObject>();
|
||||
}
|
||||
try {
|
||||
std::vector<TrackerObject> trackingResults;
|
||||
Eigen::MatrixXf ret(0, 7); // 0 row and 7 columns
|
||||
std::vector<std::string> object_ids;
|
||||
for (const auto& detection : detectionObjects) {
|
||||
float xMin = detection.x;
|
||||
float yMin = detection.y;
|
||||
float width = detection.width;
|
||||
float height = detection.height;
|
||||
float conf = detection.prob;
|
||||
int classId = detection.class_id;
|
||||
std::string object_id = detection.object_id;
|
||||
double xMax = xMin + width;
|
||||
double yMax = yMin + height;
|
||||
Eigen::MatrixXf newRow(1, ret.cols());
|
||||
newRow << xMin, yMin, xMax, yMax, conf, classId;
|
||||
ret.conservativeResize(ret.rows() + 1, Eigen::NoChange);
|
||||
ret.row(ret.rows() - 1) = newRow;
|
||||
object_ids.push_back(object_id);
|
||||
}
|
||||
|
||||
//2. Do tracking
|
||||
const auto outputs = tracker.update(ret, object_ids);
|
||||
if (outputs.size() > 0) {
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
TrackerObject trackObj;
|
||||
double x = outputs[i]._tlwh[0]; //top left x
|
||||
double y = outputs[i]._tlwh[1]; //top left y
|
||||
double w = outputs[i]._tlwh[2]; //width
|
||||
double h = outputs[i]._tlwh[3]; //height
|
||||
double left = x;
|
||||
double top = y;
|
||||
double right = x + w;
|
||||
double bottom = y + h;
|
||||
std::string object_id = outputs[i].object_id; // Should be object id
|
||||
bool is_valid = outputs[i].get_is_activated();
|
||||
|
||||
trackObj.track_id = outputs[i].get_track_id();
|
||||
trackObj.class_id = outputs[i].class_id;
|
||||
trackObj.prob = outputs[i].get_score();
|
||||
trackObj.x = x;
|
||||
trackObj.y = y;
|
||||
trackObj.width = w;
|
||||
trackObj.height = h;
|
||||
trackObj.left =x;
|
||||
trackObj.top = y;
|
||||
trackObj.right = right;
|
||||
trackObj.bottom = bottom;
|
||||
trackObj.object_id = object_id;
|
||||
trackingResults.push_back(trackObj);
|
||||
}
|
||||
}
|
||||
return trackingResults;
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
this->_logger.LogFatal("ANSByteTrackNCNN::UpdateTracker", e.what(), __FILE__, __LINE__);
|
||||
return std::vector<TrackerObject>();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user