#include "ANSByteTrackEigen.h"
#include "boost/property_tree/ptree.hpp"
#include "boost/property_tree/json_parser.hpp"
#include "boost/foreach.hpp"
#include "boost/optional.hpp"
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace ANSCENTER {

namespace {

/// Width of the detection matrix handed to the ByteTrack tracker.
/// Row layout written by this file: [x_min, y_min, x_max, y_max, confidence, class_id, 0].
/// NOTE(review): the original code allocated 7 columns but supplied only 6 values,
/// which trips Eigen's comma-initializer assertion in debug builds and leaves
/// column 6 uninitialized in release builds. It is now zero-filled; confirm what,
/// if anything, the tracker expects in that column.
constexpr Eigen::Index kByteTrackInputColumns = 7;

/// Writes one detection into row `row` of `matrix`, converting the
/// top-left + size box into top-left / bottom-right corner form.
void WriteDetectionRow(Eigen::MatrixXf& matrix, Eigen::Index row,
                       float xMin, float yMin, float width, float height,
                       float conf, int classId)
{
    matrix.row(row) << xMin, yMin, xMin + width, yMin + height,
                       conf, static_cast<float>(classId), 0.0f;
}

}  // namespace

/// Parses a detection JSON document into the tracker's input matrix.
///
/// The keys read per entry of the "results" array are
/// x, y, width, height, prob, class_id and object_id.
///
/// @param jsonString  NUL-terminated JSON text; nullptr is treated as invalid input.
/// @param object_ids  Output. It is cleared, then receives exactly one entry per
///                    row of the returned matrix.
/// @return One row per detection (layout: see kByteTrackInputColumns).
///         If parsing fails, the rows parsed so far are kept, and one all-zero
///         row is appended together with an empty object_id. Row count and
///         object_ids.size() therefore always agree.
Eigen::MatrixXf GetByteTrackEigenInputs(const char* jsonString, std::vector<std::string>& object_ids)
{
    Eigen::MatrixXf ret(0, kByteTrackInputColumns);
    Eigen::Index filled = 0;  // rows of `ret` that hold a fully parsed detection
    object_ids.clear();
    try {
        // Streaming a null char* is undefined behaviour; an empty document
        // makes read_json throw, which takes the error path below instead.
        std::stringstream ss(jsonString != nullptr ? jsonString : "");
        boost::property_tree::ptree pt;
        boost::property_tree::read_json(ss, pt);

        const boost::property_tree::ptree& detections = pt.get_child("results");
        // Allocate once instead of growing the matrix one row at a time.
        ret = Eigen::MatrixXf::Zero(static_cast<Eigen::Index>(detections.size()), kByteTrackInputColumns);
        object_ids.reserve(detections.size());

        for (const auto& detection : detections) {
            const boost::property_tree::ptree& node = detection.second;
            // Read every field before writing anything, so a malformed entry
            // leaves no partial state behind.
            const float xMin = node.get<float>("x");
            const float yMin = node.get<float>("y");
            const float width = node.get<float>("width");
            const float height = node.get<float>("height");
            const float conf = node.get<float>("prob");
            const int classId = node.get<int>("class_id");
            std::string objectId = node.get<std::string>("object_id");

            WriteDetectionRow(ret, filled, xMin, yMin, width, height, conf, classId);
            object_ids.push_back(std::move(objectId));
            ++filled;
        }
        return ret;
    }
    catch (const std::exception&) {
        // Keep the successfully parsed rows and append a single all-zero row
        // (as before). Pair it with an empty id so the sizes stay in step.
        ret.conservativeResize(filled + 1, Eigen::NoChange);
        ret.row(filled).setZero();
        object_ids.resize(static_cast<std::size_t>(filled) + 1);
        return ret;
    }
}

/// Checks the license and configures the tracker with built-in defaults
/// (10 fps, buffer 30, thresholds 0.25 / 0.25, match threshold 0.8).
ANSByteTrackEigen::ANSByteTrackEigen()
{
    _licenseValid = false;
    CheckLicense();
    const int frame_rate = 10;
    const int track_buffer = 30;
    const float track_thresh = 0.25f;
    const float track_highthres = 0.25f;
    const float match_thresh = 0.8f;
    tracker.update_parameters(frame_rate, track_buffer, track_thresh, track_highthres, match_thresh);
}

ANSByteTrackEigen::~ANSByteTrackEigen()
{
}

/// Nothing to release explicitly; always succeeds.
bool ANSByteTrackEigen::Destroy()
{
    return true;
}

/// Reconfigures the tracker from a JSON document.
///
/// Expected shape: {"parameters": {"frame_rate", "track_buffer", "track_threshold",
/// "high_threshold", "match_thresold", optional "auto_frame_rate"}}.
/// The misspelled "match_thresold" key is kept because existing callers send it.
/// An empty string applies the defaults below.
///
/// @return true on success. On a parse error or missing key, it logs and returns false.
bool ANSByteTrackEigen::UpdateParameters(const std::string& trackerParameters)
{
    try {
        // Defaults used when no parameter JSON is supplied.
        int frameRate = 15;
        int trackBuffer = 300;
        double trackThreshold = 0.5;
        double highThreshold = 0.1;
        double matchThresold = 0.95;
        bool autoFrameRate = true;

        if (!trackerParameters.empty()) {
            std::stringstream ss(trackerParameters);
            boost::property_tree::ptree pt;
            boost::property_tree::read_json(ss, pt);

            const boost::property_tree::ptree& rootNode = pt.get_child("parameters");
            frameRate = rootNode.get<int>("frame_rate");
            trackBuffer = rootNode.get<int>("track_buffer");
            trackThreshold = rootNode.get<double>("track_threshold");
            highThreshold = rootNode.get<double>("high_threshold");
            matchThresold = rootNode.get<double>("match_thresold");

            // Optional: auto frame rate estimation. Accept numeric (0/1) and
            // JSON boolean (true/false) values; read_json stores both as text.
            if (const auto optNode = rootNode.get_child_optional("auto_frame_rate")) {
                if (const auto asInt = optNode->get_value_optional<int>()) {
                    autoFrameRate = *asInt != 0;
                }
                else {
                    autoFrameRate = optNode->get_value<bool>();
                }
            }
        }
        tracker.update_parameters(frameRate, trackBuffer, trackThreshold, highThreshold, matchThresold, autoFrameRate);
        return true;
    }
    catch (const std::exception& e) {
        this->_logger.LogFatal("ANSByteTrackEigen::UpdateParameters", e.what(), __FILE__, __LINE__);
        return false;
    }
}

/// Runs one tracking step on JSON detections and returns the tracks as JSON.
///
/// @param modelId        Echoed into every output entry as "model_id".
/// @param detectionData  Detection JSON (see GetByteTrackEigenInputs).
/// @return {"results": [...]} with one entry per track, or "" on an invalid
///         license or an exception.
std::string ANSByteTrackEigen::Update(int modelId, const std::string& detectionData)
{
    if (!_licenseValid) {
        this->_logger.LogFatal("ANSByteTrackEigen::Update", "Invalid license", __FILE__, __LINE__);
        return "";
    }
    try {
        // 1. Parse detections.
        std::vector<std::string> obj_ids;
        Eigen::MatrixXf objects = GetByteTrackEigenInputs(detectionData.c_str(), obj_ids);

        // 2. Do tracking.
        const auto outputs = tracker.update(objects, obj_ids);

        boost::property_tree::ptree trackingObjects;
        for (const auto& track : outputs) {
            const double x = track._tlwh[0];  // top-left x
            const double y = track._tlwh[1];  // top-left y
            const double w = track._tlwh[2];  // width
            const double h = track._tlwh[3];  // height

            boost::property_tree::ptree trackingNode;
            trackingNode.put("model_id", modelId);
            trackingNode.put("track_id", track.get_track_id());
            // Previously emitted the track id here; use the track's class id,
            // matching UpdateTracker.
            trackingNode.put("class_id", track.class_id);
            trackingNode.put("prob", track.get_score());
            trackingNode.put("x", track._tlwh[0]);
            trackingNode.put("y", track._tlwh[1]);
            trackingNode.put("width", track._tlwh[2]);
            trackingNode.put("height", track._tlwh[3]);
            trackingNode.put("left", x);
            trackingNode.put("top", y);
            trackingNode.put("right", x + w);
            trackingNode.put("bottom", y + h);
            trackingNode.put("valid", track.get_is_activated());
            // Previously emitted the frame id here; use the detection's
            // object id, matching UpdateTracker.
            trackingNode.put("object_id", track.object_id);
            trackingObjects.push_back(std::make_pair("", trackingNode));
        }

        // 3. Convert result.
        boost::property_tree::ptree root;
        root.add_child("results", trackingObjects);
        std::ostringstream stream;
        boost::property_tree::write_json(stream, root, false);
        return stream.str();
    }
    catch (const std::exception& e) {
        this->_logger.LogFatal("ANSByteTrackEigen::Update", e.what(), __FILE__, __LINE__);
        return "";
    }
}

/// Runs one tracking step on in-memory detections.
///
/// @param modelId           Currently unused by this method.
/// @param detectionObjects  Detections. Each must expose x, y, width, height,
///                          prob, class_id and object_id.
///                          NOTE(review): the element type is written here as
///                          TrackerObject; it must match the declaration in
///                          ANSByteTrackEigen.h.
/// @return One TrackerObject per track. Empty on an invalid license or an exception.
std::vector<TrackerObject> ANSByteTrackEigen::UpdateTracker(int modelId, const std::vector<TrackerObject>& detectionObjects)
{
    if (!_licenseValid) {
        this->_logger.LogFatal("ANSByteTrackEigen::UpdateTracker", "Invalid license", __FILE__, __LINE__);
        return std::vector<TrackerObject>();
    }
    try {
        // 1. Build the tracker input matrix (one row per detection).
        Eigen::MatrixXf ret = Eigen::MatrixXf::Zero(static_cast<Eigen::Index>(detectionObjects.size()), kByteTrackInputColumns);
        std::vector<std::string> object_ids;
        object_ids.reserve(detectionObjects.size());
        Eigen::Index row = 0;
        for (const auto& detection : detectionObjects) {
            WriteDetectionRow(ret, row++, detection.x, detection.y, detection.width, detection.height,
                              detection.prob, detection.class_id);
            object_ids.push_back(detection.object_id);
        }

        // 2. Do tracking.
        const auto outputs = tracker.update(ret, object_ids);

        std::vector<TrackerObject> trackingResults;
        trackingResults.reserve(outputs.size());
        for (const auto& track : outputs) {
            const double x = track._tlwh[0];  // top-left x
            const double y = track._tlwh[1];  // top-left y
            const double w = track._tlwh[2];  // width
            const double h = track._tlwh[3];  // height

            TrackerObject trackObj;
            trackObj.track_id = track.get_track_id();
            trackObj.class_id = track.class_id;
            trackObj.prob = track.get_score();
            trackObj.x = x;
            trackObj.y = y;
            trackObj.width = w;
            trackObj.height = h;
            trackObj.left = x;
            trackObj.top = y;
            trackObj.right = x + w;
            trackObj.bottom = y + h;
            trackObj.object_id = track.object_id;
            trackingResults.push_back(std::move(trackObj));
        }
        return trackingResults;
    }
    catch (const std::exception& e) {
        this->_logger.LogFatal("ANSByteTrackEigen::UpdateTracker", e.what(), __FILE__, __LINE__);
        return std::vector<TrackerObject>();
    }
}

}  // namespace ANSCENTER