Refactor project structure
This commit is contained in:
701
modules/ANSMOT/ByteTrack/src/BYTETracker.cpp
Normal file
701
modules/ANSMOT/ByteTrack/src/BYTETracker.cpp
Normal file
@@ -0,0 +1,701 @@
|
||||
#include "BYTETracker.h"
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
ByteTrack::BYTETracker::BYTETracker(const int& frame_rate,
|
||||
const int& track_buffer,
|
||||
const float& track_thresh,
|
||||
const float& high_thresh,
|
||||
const float& match_thresh) :
|
||||
track_thresh_(track_thresh),
|
||||
high_thresh_(high_thresh),
|
||||
match_thresh_(match_thresh),
|
||||
max_time_lost_(std::max(static_cast<size_t>(5), static_cast<size_t>(frame_rate / 30.0 * track_buffer))),
|
||||
track_buffer_(track_buffer),
|
||||
frame_id_(0),
|
||||
track_id_count_(0),
|
||||
auto_frame_rate_(false),
|
||||
estimated_fps_(static_cast<float>(frame_rate)),
|
||||
time_scale_factor_(1.0f),
|
||||
fps_sample_count_(0),
|
||||
has_last_update_time_(false)
|
||||
{
|
||||
tracked_stracks_.clear();
|
||||
lost_stracks_.clear();
|
||||
removed_stracks_.clear();
|
||||
}
|
||||
|
||||
ByteTrack::BYTETracker::~BYTETracker()
|
||||
{
|
||||
}
|
||||
|
||||
void ByteTrack::BYTETracker::update_parameters(int frameRate,
|
||||
int trackBuffer,
|
||||
double trackThreshold,
|
||||
double highThreshold,
|
||||
double matchThresold,
|
||||
bool autoFrameRate) {
|
||||
track_thresh_ = trackThreshold;
|
||||
high_thresh_ = highThreshold;
|
||||
match_thresh_ = matchThresold;
|
||||
track_buffer_ = trackBuffer;
|
||||
auto_frame_rate_ = autoFrameRate;
|
||||
estimated_fps_ = static_cast<float>(frameRate);
|
||||
time_scale_factor_ = 1.0f;
|
||||
fps_sample_count_ = 0;
|
||||
has_last_update_time_ = false;
|
||||
max_time_lost_ = std::max(static_cast<size_t>(5), static_cast<size_t>(frameRate / 30.0 * trackBuffer));
|
||||
frame_id_ = 0;
|
||||
track_id_count_ = 0;
|
||||
tracked_stracks_.clear();
|
||||
lost_stracks_.clear();
|
||||
removed_stracks_.clear();
|
||||
}
|
||||
|
||||
float ByteTrack::BYTETracker::getEstimatedFps() const {
|
||||
return estimated_fps_;
|
||||
}
|
||||
|
||||
void ByteTrack::BYTETracker::estimateFrameRate() {
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
|
||||
if (!has_last_update_time_) {
|
||||
last_update_time_ = now;
|
||||
has_last_update_time_ = true;
|
||||
return;
|
||||
}
|
||||
|
||||
double delta_sec = std::chrono::duration<double>(now - last_update_time_).count();
|
||||
last_update_time_ = now;
|
||||
|
||||
// Ignore unreasonable gaps (likely pauses, not real frame intervals)
|
||||
if (delta_sec < 0.001 || delta_sec > 5.0) {
|
||||
return;
|
||||
}
|
||||
|
||||
float current_fps = static_cast<float>(1.0 / delta_sec);
|
||||
|
||||
// Clamp to reasonable range
|
||||
current_fps = std::max(1.0f, std::min(current_fps, 120.0f));
|
||||
|
||||
fps_sample_count_++;
|
||||
|
||||
// EMA smoothing: use higher alpha during warmup for faster convergence
|
||||
float alpha = (fps_sample_count_ <= 10) ? 0.3f : 0.1f;
|
||||
estimated_fps_ = alpha * current_fps + (1.0f - alpha) * estimated_fps_;
|
||||
|
||||
// Compute time scale factor: ratio of actual interval to expected interval
|
||||
float expected_dt = 1.0f / estimated_fps_;
|
||||
time_scale_factor_ = static_cast<float>(delta_sec) / expected_dt;
|
||||
time_scale_factor_ = std::max(0.5f, std::min(time_scale_factor_, 10.0f));
|
||||
|
||||
// Only adjust max_time_lost_ after warmup and when change is significant
|
||||
if (fps_sample_count_ >= 10) {
|
||||
size_t new_max_time_lost = std::max(
|
||||
static_cast<size_t>(5),
|
||||
static_cast<size_t>(estimated_fps_ / 30.0 * track_buffer_));
|
||||
|
||||
// Only update if there's a meaningful change (>10%)
|
||||
double ratio = static_cast<double>(new_max_time_lost) / static_cast<double>(max_time_lost_);
|
||||
if (ratio > 1.1 || ratio < 0.9) {
|
||||
max_time_lost_ = new_max_time_lost;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ByteTrack::BYTETracker::STrackPtr> ByteTrack::BYTETracker::update(const std::vector<Object>& objects)
|
||||
{
|
||||
// Auto-estimate frame rate from update() call timing
|
||||
if (auto_frame_rate_) {
|
||||
estimateFrameRate();
|
||||
}
|
||||
|
||||
////////////////// Step 1: Get detections //////////////////
|
||||
frame_id_++;
|
||||
|
||||
// Create new STracks using the result of object detection
|
||||
std::vector<STrackPtr> det_stracks;
|
||||
std::vector<STrackPtr> det_low_stracks;
|
||||
|
||||
for (const auto &object : objects)
|
||||
{
|
||||
const auto strack = std::make_shared<STrack>(object.rect,
|
||||
object.prob,
|
||||
object.label,
|
||||
object.left,
|
||||
object.top,
|
||||
object.right,
|
||||
object.bottom,
|
||||
object.object_id);
|
||||
if (object.prob >= track_thresh_)
|
||||
{
|
||||
det_stracks.push_back(strack);
|
||||
}
|
||||
else
|
||||
{
|
||||
det_low_stracks.push_back(strack);
|
||||
}
|
||||
}
|
||||
|
||||
// Create lists of existing STrack
|
||||
std::vector<STrackPtr> active_stracks;
|
||||
std::vector<STrackPtr> non_active_stracks;
|
||||
std::vector<STrackPtr> strack_pool;
|
||||
|
||||
for (const auto& tracked_strack : tracked_stracks_)
|
||||
{
|
||||
if (!tracked_strack->isActivated())
|
||||
{
|
||||
non_active_stracks.push_back(tracked_strack);
|
||||
}
|
||||
else
|
||||
{
|
||||
active_stracks.push_back(tracked_strack);
|
||||
}
|
||||
}
|
||||
|
||||
strack_pool = jointStracks(active_stracks, lost_stracks_);
|
||||
|
||||
// Multi-predict: call predict() multiple times when frames are skipped
|
||||
int num_predicts = 1;
|
||||
if (auto_frame_rate_ && time_scale_factor_ > 1.5f) {
|
||||
num_predicts = std::min(static_cast<int>(std::round(time_scale_factor_)), 10);
|
||||
}
|
||||
for (int p = 0; p < num_predicts; p++)
|
||||
{
|
||||
for (auto &strack : strack_pool)
|
||||
{
|
||||
strack->predict();
|
||||
}
|
||||
}
|
||||
|
||||
////////////////// Step 2: First association, with IoU //////////////////
|
||||
// Adaptive matching: relax threshold during frame skips
|
||||
float effective_match_thresh = match_thresh_;
|
||||
if (num_predicts > 1) {
|
||||
effective_match_thresh = std::min(match_thresh_ + 0.005f * (num_predicts - 1), 0.99f);
|
||||
}
|
||||
|
||||
std::vector<STrackPtr> current_tracked_stracks;
|
||||
std::vector<STrackPtr> remain_tracked_stracks;
|
||||
std::vector<STrackPtr> remain_det_stracks;
|
||||
std::vector<STrackPtr> refind_stracks;
|
||||
|
||||
{
|
||||
std::vector<std::vector<int>> matches_idx;
|
||||
std::vector<int> unmatch_detection_idx, unmatch_track_idx;
|
||||
|
||||
const auto dists = calcIouDistance(strack_pool, det_stracks);
|
||||
linearAssignment(dists, strack_pool.size(), det_stracks.size(), effective_match_thresh,
|
||||
matches_idx, unmatch_track_idx, unmatch_detection_idx);
|
||||
|
||||
for (const auto &match_idx : matches_idx)
|
||||
{
|
||||
const auto track = strack_pool[match_idx[0]];
|
||||
const auto det = det_stracks[match_idx[1]];
|
||||
if (track->getSTrackState() == STrackState::Tracked)
|
||||
{
|
||||
track->update(*det, frame_id_);
|
||||
current_tracked_stracks.push_back(track);
|
||||
}
|
||||
else
|
||||
{
|
||||
track->reActivate(*det, frame_id_);
|
||||
refind_stracks.push_back(track);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &unmatch_idx : unmatch_detection_idx)
|
||||
{
|
||||
remain_det_stracks.push_back(det_stracks[unmatch_idx]);
|
||||
}
|
||||
|
||||
for (const auto &unmatch_idx : unmatch_track_idx)
|
||||
{
|
||||
if (strack_pool[unmatch_idx]->getSTrackState() == STrackState::Tracked)
|
||||
{
|
||||
remain_tracked_stracks.push_back(strack_pool[unmatch_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////// Step 3: Second association, using low score dets //////////////////
|
||||
std::vector<STrackPtr> current_lost_stracks;
|
||||
|
||||
{
|
||||
std::vector<std::vector<int>> matches_idx;
|
||||
std::vector<int> unmatch_track_idx, unmatch_detection_idx;
|
||||
|
||||
const auto dists = calcIouDistance(remain_tracked_stracks, det_low_stracks);
|
||||
linearAssignment(dists, remain_tracked_stracks.size(), det_low_stracks.size(), 0.5,
|
||||
matches_idx, unmatch_track_idx, unmatch_detection_idx);
|
||||
|
||||
for (const auto &match_idx : matches_idx)
|
||||
{
|
||||
const auto track = remain_tracked_stracks[match_idx[0]];
|
||||
const auto det = det_low_stracks[match_idx[1]];
|
||||
if (track->getSTrackState() == STrackState::Tracked)
|
||||
{
|
||||
track->update(*det, frame_id_);
|
||||
current_tracked_stracks.push_back(track);
|
||||
}
|
||||
else
|
||||
{
|
||||
track->reActivate(*det, frame_id_);
|
||||
refind_stracks.push_back(track);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &unmatch_track : unmatch_track_idx)
|
||||
{
|
||||
const auto track = remain_tracked_stracks[unmatch_track];
|
||||
if (track->getSTrackState() != STrackState::Lost)
|
||||
{
|
||||
track->markAsLost();
|
||||
current_lost_stracks.push_back(track);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////// Step 4: Init new stracks //////////////////
|
||||
std::vector<STrackPtr> current_removed_stracks;
|
||||
|
||||
{
|
||||
std::vector<int> unmatch_detection_idx;
|
||||
std::vector<int> unmatch_unconfirmed_idx;
|
||||
std::vector<std::vector<int>> matches_idx;
|
||||
|
||||
// Deal with unconfirmed tracks, usually tracks with only one beginning frame
|
||||
const auto dists = calcIouDistance(non_active_stracks, remain_det_stracks);
|
||||
linearAssignment(dists, non_active_stracks.size(), remain_det_stracks.size(), 0.7,
|
||||
matches_idx, unmatch_unconfirmed_idx, unmatch_detection_idx);
|
||||
|
||||
for (const auto &match_idx : matches_idx)
|
||||
{
|
||||
non_active_stracks[match_idx[0]]->update(*remain_det_stracks[match_idx[1]], frame_id_);
|
||||
current_tracked_stracks.push_back(non_active_stracks[match_idx[0]]);
|
||||
}
|
||||
|
||||
for (const auto &unmatch_idx : unmatch_unconfirmed_idx)
|
||||
{
|
||||
const auto track = non_active_stracks[unmatch_idx];
|
||||
track->markAsRemoved();
|
||||
current_removed_stracks.push_back(track);
|
||||
}
|
||||
|
||||
// Add new stracks
|
||||
for (const auto &unmatch_idx : unmatch_detection_idx)
|
||||
{
|
||||
const auto track = remain_det_stracks[unmatch_idx];
|
||||
if (track->getScore() < high_thresh_)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
track_id_count_++;
|
||||
track->activate(frame_id_, track_id_count_);
|
||||
current_tracked_stracks.push_back(track);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////// Step 5: Update state //////////////////
|
||||
for (const auto &lost_strack : lost_stracks_)
|
||||
{
|
||||
if (frame_id_ - lost_strack->getFrameId() > max_time_lost_)
|
||||
{
|
||||
lost_strack->markAsRemoved();
|
||||
current_removed_stracks.push_back(lost_strack);
|
||||
}
|
||||
}
|
||||
|
||||
tracked_stracks_ = jointStracks(current_tracked_stracks, refind_stracks);
|
||||
lost_stracks_ = subStracks(jointStracks(subStracks(lost_stracks_, tracked_stracks_), current_lost_stracks), removed_stracks_);
|
||||
removed_stracks_ = jointStracks(removed_stracks_, current_removed_stracks);
|
||||
|
||||
std::vector<STrackPtr> tracked_stracks_out, lost_stracks_out;
|
||||
removeDuplicateStracks(tracked_stracks_, lost_stracks_, tracked_stracks_out, lost_stracks_out);
|
||||
tracked_stracks_ = tracked_stracks_out;
|
||||
lost_stracks_ = lost_stracks_out;
|
||||
|
||||
std::vector<STrackPtr> output_stracks;
|
||||
for (const auto &track : tracked_stracks_)
|
||||
{
|
||||
output_stracks.push_back(track); // Pushback all trackers
|
||||
/* if (track->isActivated())
|
||||
{
|
||||
output_stracks.push_back(track);
|
||||
}*/
|
||||
}
|
||||
|
||||
return output_stracks;
|
||||
}
|
||||
std::vector<ByteTrack::BYTETracker::STrackPtr> ByteTrack::BYTETracker::jointStracks(const std::vector<STrackPtr> &a_tlist,
|
||||
const std::vector<STrackPtr> &b_tlist) const
|
||||
{
|
||||
std::map<int, int> exists;
|
||||
std::vector<STrackPtr> res;
|
||||
for (size_t i = 0; i < a_tlist.size(); i++)
|
||||
{
|
||||
exists.emplace(a_tlist[i]->getTrackId(), 1);
|
||||
res.push_back(a_tlist[i]);
|
||||
}
|
||||
for (size_t i = 0; i < b_tlist.size(); i++)
|
||||
{
|
||||
const int &tid = b_tlist[i]->getTrackId();
|
||||
if (exists.count(tid) == 0)
|
||||
{
|
||||
exists[tid] = 1;
|
||||
res.push_back(b_tlist[i]);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<ByteTrack::BYTETracker::STrackPtr> ByteTrack::BYTETracker::subStracks(const std::vector<STrackPtr> &a_tlist,
|
||||
const std::vector<STrackPtr> &b_tlist) const
|
||||
{
|
||||
std::map<int, STrackPtr> stracks;
|
||||
for (size_t i = 0; i < a_tlist.size(); i++)
|
||||
{
|
||||
stracks.emplace(a_tlist[i]->getTrackId(), a_tlist[i]);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < b_tlist.size(); i++)
|
||||
{
|
||||
const int &tid = b_tlist[i]->getTrackId();
|
||||
if (stracks.count(tid) != 0)
|
||||
{
|
||||
stracks.erase(tid);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<STrackPtr> res;
|
||||
std::map<int, STrackPtr>::iterator it;
|
||||
for (it = stracks.begin(); it != stracks.end(); ++it)
|
||||
{
|
||||
res.push_back(it->second);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void ByteTrack::BYTETracker::removeDuplicateStracks(const std::vector<STrackPtr> &a_stracks,
|
||||
const std::vector<STrackPtr> &b_stracks,
|
||||
std::vector<STrackPtr> &a_res,
|
||||
std::vector<STrackPtr> &b_res) const
|
||||
{
|
||||
const auto ious = calcIouDistance(a_stracks, b_stracks);
|
||||
|
||||
std::vector<std::pair<size_t, size_t>> overlapping_combinations;
|
||||
for (size_t i = 0; i < ious.size(); i++)
|
||||
{
|
||||
for (size_t j = 0; j < ious[i].size(); j++)
|
||||
{
|
||||
if (ious[i][j] < 0.15)
|
||||
{
|
||||
overlapping_combinations.emplace_back(i, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<bool> a_overlapping(a_stracks.size(), false), b_overlapping(b_stracks.size(), false);
|
||||
for (const auto &[a_idx, b_idx] : overlapping_combinations)
|
||||
{
|
||||
const int timep = a_stracks[a_idx]->getFrameId() - a_stracks[a_idx]->getStartFrameId();
|
||||
const int timeq = b_stracks[b_idx]->getFrameId() - b_stracks[b_idx]->getStartFrameId();
|
||||
if (timep > timeq)
|
||||
{
|
||||
b_overlapping[b_idx] = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
a_overlapping[a_idx] = true;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t ai = 0; ai < a_stracks.size(); ai++)
|
||||
{
|
||||
if (!a_overlapping[ai])
|
||||
{
|
||||
a_res.push_back(a_stracks[ai]);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t bi = 0; bi < b_stracks.size(); bi++)
|
||||
{
|
||||
if (!b_overlapping[bi])
|
||||
{
|
||||
b_res.push_back(b_stracks[bi]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ByteTrack::BYTETracker::linearAssignment(const std::vector<std::vector<float>> &cost_matrix,
|
||||
const int &cost_matrix_size,
|
||||
const int &cost_matrix_size_size,
|
||||
const float &thresh,
|
||||
std::vector<std::vector<int>> &matches,
|
||||
std::vector<int> &a_unmatched,
|
||||
std::vector<int> &b_unmatched) const
|
||||
{
|
||||
if (cost_matrix.size() == 0)
|
||||
{
|
||||
for (int i = 0; i < cost_matrix_size; i++)
|
||||
{
|
||||
a_unmatched.push_back(i);
|
||||
}
|
||||
for (int i = 0; i < cost_matrix_size_size; i++)
|
||||
{
|
||||
b_unmatched.push_back(i);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int> rowsol; std::vector<int> colsol;
|
||||
execLapjv(cost_matrix, rowsol, colsol, true, thresh);
|
||||
for (size_t i = 0; i < rowsol.size(); i++)
|
||||
{
|
||||
if (rowsol[i] >= 0)
|
||||
{
|
||||
std::vector<int> match;
|
||||
match.push_back(i);
|
||||
match.push_back(rowsol[i]);
|
||||
matches.push_back(match);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_unmatched.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < colsol.size(); i++)
|
||||
{
|
||||
if (colsol[i] < 0)
|
||||
{
|
||||
b_unmatched.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> ByteTrack::BYTETracker::calcIous(const std::vector<Rect<float>> &a_rect,
|
||||
const std::vector<Rect<float>> &b_rect) const
|
||||
{
|
||||
std::vector<std::vector<float>> ious;
|
||||
if (a_rect.size() * b_rect.size() == 0)
|
||||
{
|
||||
return ious;
|
||||
}
|
||||
|
||||
ious.resize(a_rect.size());
|
||||
for (size_t i = 0; i < ious.size(); i++)
|
||||
{
|
||||
ious[i].resize(b_rect.size());
|
||||
}
|
||||
|
||||
for (size_t bi = 0; bi < b_rect.size(); bi++)
|
||||
{
|
||||
for (size_t ai = 0; ai < a_rect.size(); ai++)
|
||||
{
|
||||
ious[ai][bi] = b_rect[bi].calcIoU(a_rect[ai]);
|
||||
}
|
||||
}
|
||||
return ious;
|
||||
}
|
||||
|
||||
std::vector<std::vector<float> > ByteTrack::BYTETracker::calcIouDistance(const std::vector<STrackPtr> &a_tracks,
|
||||
const std::vector<STrackPtr> &b_tracks) const
|
||||
{
|
||||
std::vector<ByteTrack::Rect<float>> a_rects, b_rects;
|
||||
for (size_t i = 0; i < a_tracks.size(); i++)
|
||||
{
|
||||
a_rects.push_back(a_tracks[i]->getRect());
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < b_tracks.size(); i++)
|
||||
{
|
||||
b_rects.push_back(b_tracks[i]->getRect());
|
||||
}
|
||||
|
||||
const auto ious = calcIous(a_rects, b_rects);
|
||||
|
||||
std::vector<std::vector<float>> cost_matrix;
|
||||
for (size_t i = 0; i < ious.size(); i++)
|
||||
{
|
||||
std::vector<float> iou;
|
||||
for (size_t j = 0; j < ious[i].size(); j++)
|
||||
{
|
||||
iou.push_back(1 - ious[i][j]);
|
||||
}
|
||||
cost_matrix.push_back(iou);
|
||||
}
|
||||
|
||||
return cost_matrix;
|
||||
}
|
||||
|
||||
double ByteTrack::BYTETracker::execLapjv(const std::vector<std::vector<float>> &cost,
|
||||
std::vector<int> &rowsol,
|
||||
std::vector<int> &colsol,
|
||||
bool extend_cost,
|
||||
float cost_limit,
|
||||
bool return_cost) const
|
||||
{
|
||||
std::vector<std::vector<float> > cost_c;
|
||||
cost_c.assign(cost.begin(), cost.end());
|
||||
|
||||
std::vector<std::vector<float> > cost_c_extended;
|
||||
|
||||
int n_rows = cost.size();
|
||||
int n_cols = cost[0].size();
|
||||
rowsol.resize(n_rows);
|
||||
colsol.resize(n_cols);
|
||||
|
||||
int n = 0;
|
||||
if (n_rows == n_cols)
|
||||
{
|
||||
n = n_rows;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!extend_cost)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
if (extend_cost || cost_limit < std::numeric_limits<float>::max())
|
||||
{
|
||||
n = n_rows + n_cols;
|
||||
cost_c_extended.resize(n);
|
||||
for (size_t i = 0; i < cost_c_extended.size(); i++)
|
||||
cost_c_extended[i].resize(n);
|
||||
|
||||
if (cost_limit < std::numeric_limits<float>::max())
|
||||
{
|
||||
for (size_t i = 0; i < cost_c_extended.size(); i++)
|
||||
{
|
||||
for (size_t j = 0; j < cost_c_extended[i].size(); j++)
|
||||
{
|
||||
cost_c_extended[i][j] = cost_limit / 2.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
float cost_max = -1;
|
||||
for (size_t i = 0; i < cost_c.size(); i++)
|
||||
{
|
||||
for (size_t j = 0; j < cost_c[i].size(); j++)
|
||||
{
|
||||
if (cost_c[i][j] > cost_max)
|
||||
cost_max = cost_c[i][j];
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < cost_c_extended.size(); i++)
|
||||
{
|
||||
for (size_t j = 0; j < cost_c_extended[i].size(); j++)
|
||||
{
|
||||
cost_c_extended[i][j] = cost_max + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = n_rows; i < cost_c_extended.size(); i++)
|
||||
{
|
||||
for (size_t j = n_cols; j < cost_c_extended[i].size(); j++)
|
||||
{
|
||||
cost_c_extended[i][j] = 0;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < n_rows; i++)
|
||||
{
|
||||
for (int j = 0; j < n_cols; j++)
|
||||
{
|
||||
cost_c_extended[i][j] = cost_c[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
cost_c.clear();
|
||||
cost_c.assign(cost_c_extended.begin(), cost_c_extended.end());
|
||||
}
|
||||
|
||||
double **cost_ptr;
|
||||
cost_ptr = new double *[n];
|
||||
for (int i = 0; i < n; i++)
|
||||
cost_ptr[i] = new double[n];
|
||||
|
||||
for (int i = 0; i < n; i++)
|
||||
{
|
||||
for (int j = 0; j < n; j++)
|
||||
{
|
||||
cost_ptr[i][j] = cost_c[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
int* x_c = new int[n];
|
||||
int *y_c = new int[n];
|
||||
|
||||
int ret = lapjv_internal(n, cost_ptr, x_c, y_c);
|
||||
if (ret != 0)
|
||||
{
|
||||
for (int i = 0; i < n; i++)
|
||||
delete[] cost_ptr[i];
|
||||
delete[] cost_ptr;
|
||||
delete[] x_c;
|
||||
delete[] y_c;
|
||||
return -1;
|
||||
}
|
||||
|
||||
double opt = 0.0;
|
||||
|
||||
if (n != n_rows)
|
||||
{
|
||||
for (int i = 0; i < n; i++)
|
||||
{
|
||||
if (x_c[i] >= n_cols)
|
||||
x_c[i] = -1;
|
||||
if (y_c[i] >= n_rows)
|
||||
y_c[i] = -1;
|
||||
}
|
||||
for (int i = 0; i < n_rows; i++)
|
||||
{
|
||||
rowsol[i] = x_c[i];
|
||||
}
|
||||
for (int i = 0; i < n_cols; i++)
|
||||
{
|
||||
colsol[i] = y_c[i];
|
||||
}
|
||||
|
||||
if (return_cost)
|
||||
{
|
||||
for (size_t i = 0; i < rowsol.size(); i++)
|
||||
{
|
||||
if (rowsol[i] != -1)
|
||||
{
|
||||
opt += cost_ptr[i][rowsol[i]];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (return_cost)
|
||||
{
|
||||
for (size_t i = 0; i < rowsol.size(); i++)
|
||||
{
|
||||
opt += cost_ptr[i][rowsol[i]];
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < n; i++)
|
||||
{
|
||||
delete[]cost_ptr[i];
|
||||
}
|
||||
delete[]cost_ptr;
|
||||
delete[]x_c;
|
||||
delete[]y_c;
|
||||
|
||||
return opt;
|
||||
}
|
||||
Reference in New Issue
Block a user