#include "ANSByteTrack.h"
#include "boost/property_tree/ptree.hpp"
#include "boost/property_tree/json_parser.hpp"
#include "boost/foreach.hpp"
#include "boost/optional.hpp"
#include <sstream>
#include <string>
#include <vector>

namespace ANSCENTER
{
    /// @brief Reads the value stored under @p key in @p pt, converted to T.
    /// @param pt  Property tree to read from.
    /// @param key Key (path) of the value.
    /// @return The converted value, or a value-initialised T (0, 0.0f, "", ...)
    ///         when the key is absent or cannot be converted to T.
    ///         Value-initialisation matters: an uninitialised arithmetic T
    ///         would be returned as indeterminate garbage.
    template <typename T>
    T get_data(const boost::property_tree::ptree& pt, const std::string& key)
    {
        if (const boost::optional<T> data = pt.get_optional<T>(key)) {
            return *data;
        }
        return T{};
    }

    /// @brief Converts the "results" array of a parsed detection document into
    ///        ByteTrack input objects.
    /// @param pt Parsed JSON whose "results" child holds one node per detection
    ///           with keys class_id, prob, x, y, width, height, left, top,
    ///           right, bottom and object_id.
    /// @return One ByteTrack::Object per entry; keys missing from an entry are
    ///         value-initialised (see get_data).
    /// @throws boost::property_tree::ptree_bad_path if "results" is missing.
    std::vector<ByteTrack::Object> GetByteTrackInputs(const boost::property_tree::ptree& pt)
    {
        const boost::property_tree::ptree& results = pt.get_child("results");

        std::vector<ByteTrack::Object> inputs_ref;
        inputs_ref.reserve(results.size());

        for (const auto& child : results) {
            const boost::property_tree::ptree& result = child.second;
            // NOTE(review): field types reconstructed from usage; confirm they
            // match the ByteTrack::Object / TrackerObject declarations.
            const auto class_id = get_data<int>(result, "class_id");
            const auto prob = get_data<float>(result, "prob");
            const auto x = get_data<float>(result, "x");
            const auto y = get_data<float>(result, "y");
            const auto width = get_data<float>(result, "width");
            const auto height = get_data<float>(result, "height");
            const auto left = get_data<float>(result, "left");
            const auto top = get_data<float>(result, "top");
            const auto right = get_data<float>(result, "right");
            const auto bottom = get_data<float>(result, "bottom");
            const auto object_id = get_data<std::string>(result, "object_id");

            inputs_ref.push_back(ByteTrack::Object{
                ByteTrack::Rect(x, y, width, height),
                class_id, prob, left, top, right, bottom, object_id });
        }
        return inputs_ref;
    }

    /// @brief Checks the licence and applies the built-in tracker parameters.
    ANSByteTrack::ANSByteTrack()
    {
        // NOTE(review): CheckLicense() presumably sets _licenseValid; Update()
        // and UpdateTracker() refuse to run while it is false.
        _licenseValid = false;
        CheckLicense();
        // Same positional order as in UpdateParameters(): frame rate, track
        // buffer, track threshold, high threshold, match threshold.
        tracker.update_parameters(30, 30, 0.5, 0.6, 0.8);
    }

    ANSByteTrack::~ANSByteTrack()
    {
    }

    /// @brief Nothing to release; always succeeds.
    bool ANSByteTrack::Destroy()
    {
        return true;
    }

    /// @brief Parses tracker parameters from JSON and applies them to the tracker.
    /// @param trackerParameters JSON of the form
    ///        {"parameters": {"frame_rate": int, "track_buffer": int,
    ///         "track_threshold": double, "high_threshold": double,
    ///         "match_thresold": double, "auto_frame_rate": bool|int (optional)}}.
    ///        "match_threshold" is accepted as an alias when the legacy
    ///        "match_thresold" key is absent. An empty string applies the
    ///        built-in defaults below.
    /// @return true on success; false (and a fatal log) if parsing fails or a
    ///         required key is missing or not convertible.
    bool ANSByteTrack::UpdateParameters(const std::string& trackerParameters)
    {
        // Use Boost JSON to parse the parameters from trackerParameters.
        try {
            // NOTE(review): these defaults differ from the ones applied in the
            // constructor, so an empty string does not restore construction state.
            int frameRate = 15;
            int trackBuffer = 300;
            double trackThreshold = 0.5;
            double highThreshold = 0.1;
            double matchThresold = 0.95;
            bool autoFrameRate = true;

            if (!trackerParameters.empty()) {
                std::istringstream ss(trackerParameters);
                boost::property_tree::ptree pt;
                boost::property_tree::read_json(ss, pt);

                // Bind by reference instead of copying each subtree.
                const boost::property_tree::ptree& rootNode = pt.get_child("parameters");

                // ptree::get<T>(key) == get_child(key).get_value<T>(): it throws
                // if the key is missing or the value is not convertible.
                frameRate = rootNode.get<int>("frame_rate");
                trackBuffer = rootNode.get<int>("track_buffer");
                trackThreshold = rootNode.get<double>("track_threshold");
                highThreshold = rootNode.get<double>("high_threshold");

                // "match_thresold" (sic) is the established key. Fall back to the
                // correctly spelled alias only when the legacy key is absent.
                const bool useAlias = rootNode.count("match_thresold") == 0 &&
                                      rootNode.count("match_threshold") != 0;
                matchThresold = rootNode.get<double>(useAlias ? "match_threshold" : "match_thresold");

                // Optional: auto frame rate estimation.
                if (const auto optNode = rootNode.get_child_optional("auto_frame_rate")) {
                    // The JSON parser stores true/false as the text "true"/"false",
                    // which an int conversion rejects; accept numbers (non-zero ==
                    // true) and boolean literals alike.
                    if (const auto asInt = optNode->get_value_optional<int>()) {
                        autoFrameRate = (*asInt != 0);
                    }
                    else {
                        autoFrameRate = optNode->get_value<bool>();
                    }
                }
            }

            tracker.update_parameters(frameRate, trackBuffer, trackThreshold, highThreshold, matchThresold, autoFrameRate);
            return true;
        }
        catch (const std::exception& e) {
            this->_logger.LogFatal("ANSByteTrack::UpdateParameters", e.what(), __FILE__, __LINE__);
            return false;
        }
    }

    /// @brief Runs one tracking step on JSON-encoded detections.
    /// @param modelId       Copied into every output entry as "model_id".
    /// @param detectionData JSON with a "results" array (see GetByteTrackInputs).
    /// @return Compact JSON {"results": [...]} with one entry per track returned
    ///         by the tracker, or "" on invalid licence or any parse/track error.
    std::string ANSByteTrack::Update(int modelId, const std::string& detectionData)
    {
        if (!_licenseValid) {
            this->_logger.LogFatal("ANSByteTrack::Update", "Invalid license", __FILE__, __LINE__);
            return "";
        }
        try {
            // 1. Get input.
            boost::property_tree::ptree pt;
            std::istringstream ss(detectionData);
            boost::property_tree::read_json(ss, pt);
            std::vector<ByteTrack::Object> objects = GetByteTrackInputs(pt);

            // 2. Do tracking.
            const auto outputs = tracker.update(objects);

            boost::property_tree::ptree trackingObjects;
            for (const auto& track : outputs) {
                // Fetch the rectangle once instead of once per coordinate.
                const auto rect = track->getRect();

                boost::property_tree::ptree trackingNode;
                trackingNode.put("model_id", modelId);
                trackingNode.put("track_id", track->getTrackId());
                trackingNode.put("class_id", track->class_id);
                trackingNode.put("prob", track->getScore());
                trackingNode.put("x", rect.tlwh[0]);
                trackingNode.put("y", rect.tlwh[1]);
                trackingNode.put("width", rect.tlwh[2]);
                trackingNode.put("height", rect.tlwh[3]);
                trackingNode.put("left", track->left);
                trackingNode.put("top", track->top);
                trackingNode.put("right", track->right);
                trackingNode.put("bottom", track->bottom);
                trackingNode.put("valid", track->isActivated());
                trackingNode.put("object_id", track->object_id);
                // An empty key makes the node an array element.
                trackingObjects.push_back(std::make_pair("", trackingNode));
            }

            // 3. Convert result.
            // NOTE(review): Boost's write_json serialises an empty child as ""
            // rather than [], so "no tracks" is emitted as {"results":""}.
            boost::property_tree::ptree root;
            root.add_child("results", trackingObjects);
            std::ostringstream stream;
            boost::property_tree::write_json(stream, root, false);
            return stream.str();
        }
        catch (const std::exception& e) {
            this->_logger.LogFatal("ANSByteTrack::Update", e.what(), __FILE__, __LINE__);
            return "";
        }
    }

    /// @brief Runs one tracking step on in-memory detections.
    /// @param modelId          Not used by this overload.
    /// @param detectionObjects Detections for the current frame.
    /// @return One TrackerObject per track returned by the tracker; empty on
    ///         invalid licence or if tracking throws.
    std::vector<TrackerObject> ANSByteTrack::UpdateTracker(int modelId, const std::vector<TrackerObject>& detectionObjects)
    {
        if (!_licenseValid) {
            this->_logger.LogFatal("ANSByteTrack::UpdateTracker", "Invalid license", __FILE__, __LINE__);
            return {};
        }
        try {
            // 1. Convert the detections to tracker input.
            std::vector<ByteTrack::Object> objects;
            objects.reserve(detectionObjects.size());
            for (const auto& detObj : detectionObjects) {
                objects.push_back(ByteTrack::Object{
                    ByteTrack::Rect(detObj.x, detObj.y, detObj.width, detObj.height),
                    detObj.class_id, detObj.prob, detObj.left, detObj.top,
                    detObj.right, detObj.bottom, detObj.object_id });
            }

            // 2. Do tracking.
            const auto outputs = tracker.update(objects);

            std::vector<TrackerObject> trackingResults;
            trackingResults.reserve(outputs.size());
            for (const auto& track : outputs) {
                const auto rect = track->getRect();

                TrackerObject trackObj;
                trackObj.track_id = track->getTrackId();
                trackObj.class_id = track->class_id;
                trackObj.prob = track->getScore();
                trackObj.x = rect.tlwh[0];
                trackObj.y = rect.tlwh[1];
                trackObj.width = rect.tlwh[2];
                trackObj.height = rect.tlwh[3];
                trackObj.left = track->left;
                trackObj.top = track->top;
                trackObj.right = track->right;
                trackObj.bottom = track->bottom;
                trackObj.object_id = track->object_id;
                trackingResults.push_back(trackObj);
            }
            return trackingResults;
        }
        catch (const std::exception& e) {
            this->_logger.LogFatal("ANSByteTrack::UpdateTracker", e.what(), __FILE__, __LINE__);
            return {};
        }
    }
}