405 lines
13 KiB
C++
405 lines
13 KiB
C++
// Copyright (C) 2014 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_
|
|
#define DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_
|
|
|
|
#include "structural_track_association_trainer_abstract.h"
|
|
#include "../algs.h"
|
|
#include "svm.h"
|
|
#include <utility>
|
|
#include "track_association_function.h"
|
|
#include "structural_assignment_trainer.h"
|
|
#include <map>
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
namespace impl
|
|
{
|
|
template <
|
|
typename detection_type,
|
|
typename label_type
|
|
>
|
|
std::vector<detection_type> get_unlabeled_dets (
|
|
const std::vector<labeled_detection<detection_type,label_type> >& dets
|
|
)
|
|
{
|
|
std::vector<detection_type> temp;
|
|
temp.reserve(dets.size());
|
|
for (unsigned long i = 0; i < dets.size(); ++i)
|
|
temp.push_back(dets[i].det);
|
|
return temp;
|
|
}
|
|
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
class structural_track_association_trainer
|
|
{
|
|
public:
|
|
|
|
structural_track_association_trainer (
|
|
)
|
|
{
|
|
set_defaults();
|
|
}
|
|
|
|
void set_num_threads (
|
|
unsigned long num
|
|
)
|
|
{
|
|
num_threads = num;
|
|
}
|
|
|
|
unsigned long get_num_threads (
|
|
) const
|
|
{
|
|
return num_threads;
|
|
}
|
|
|
|
void set_epsilon (
|
|
double eps_
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(eps_ > 0,
|
|
"\t void structural_track_association_trainer::set_epsilon()"
|
|
<< "\n\t eps_ must be greater than 0"
|
|
<< "\n\t eps_: " << eps_
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
eps = eps_;
|
|
}
|
|
|
|
double get_epsilon (
|
|
) const { return eps; }
|
|
|
|
void set_max_cache_size (
|
|
unsigned long max_size
|
|
)
|
|
{
|
|
max_cache_size = max_size;
|
|
}
|
|
|
|
unsigned long get_max_cache_size (
|
|
) const
|
|
{
|
|
return max_cache_size;
|
|
}
|
|
|
|
void set_loss_per_false_association (
|
|
double loss
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(loss > 0,
|
|
"\t void structural_track_association_trainer::set_loss_per_false_association(loss)"
|
|
<< "\n\t Invalid inputs were given to this function "
|
|
<< "\n\t loss: " << loss
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
loss_per_false_association = loss;
|
|
}
|
|
|
|
double get_loss_per_false_association (
|
|
) const
|
|
{
|
|
return loss_per_false_association;
|
|
}
|
|
|
|
void set_loss_per_track_break (
|
|
double loss
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(loss > 0,
|
|
"\t void structural_track_association_trainer::set_loss_per_track_break(loss)"
|
|
<< "\n\t Invalid inputs were given to this function "
|
|
<< "\n\t loss: " << loss
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
loss_per_track_break = loss;
|
|
}
|
|
|
|
double get_loss_per_track_break (
|
|
) const
|
|
{
|
|
return loss_per_track_break;
|
|
}
|
|
|
|
void be_verbose (
|
|
)
|
|
{
|
|
verbose = true;
|
|
}
|
|
|
|
void be_quiet (
|
|
)
|
|
{
|
|
verbose = false;
|
|
}
|
|
|
|
void set_oca (
|
|
const oca& item
|
|
)
|
|
{
|
|
solver = item;
|
|
}
|
|
|
|
const oca get_oca (
|
|
) const
|
|
{
|
|
return solver;
|
|
}
|
|
|
|
void set_c (
|
|
double C_
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(C_ > 0,
|
|
"\t void structural_track_association_trainer::set_c()"
|
|
<< "\n\t C_ must be greater than 0"
|
|
<< "\n\t C_: " << C_
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
C = C_;
|
|
}
|
|
|
|
double get_c (
|
|
) const
|
|
{
|
|
return C;
|
|
}
|
|
|
|
bool learns_nonnegative_weights (
|
|
) const { return learn_nonnegative_weights; }
|
|
|
|
void set_learns_nonnegative_weights (
|
|
bool value
|
|
)
|
|
{
|
|
learn_nonnegative_weights = value;
|
|
}
|
|
|
|
template <
|
|
typename detection_type,
|
|
typename label_type
|
|
>
|
|
const track_association_function<detection_type> train (
|
|
const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples
|
|
) const
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(is_track_association_problem(samples),
|
|
"\t track_association_function structural_track_association_trainer::train()"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t is_track_association_problem(samples): " << is_track_association_problem(samples)
|
|
);
|
|
|
|
typedef typename detection_type::track_type track_type;
|
|
|
|
const unsigned long num_dims = find_num_dims(samples);
|
|
|
|
feature_extractor_track_association<detection_type> fe(num_dims, learn_nonnegative_weights?num_dims:0);
|
|
structural_assignment_trainer<feature_extractor_track_association<detection_type> > trainer(fe);
|
|
|
|
|
|
if (verbose)
|
|
trainer.be_verbose();
|
|
|
|
trainer.set_c(C);
|
|
trainer.set_epsilon(eps);
|
|
trainer.set_max_cache_size(max_cache_size);
|
|
trainer.set_num_threads(num_threads);
|
|
trainer.set_oca(solver);
|
|
trainer.set_loss_per_missed_association(loss_per_track_break);
|
|
trainer.set_loss_per_false_association(loss_per_false_association);
|
|
|
|
std::vector<std::pair<std::vector<detection_type>, std::vector<track_type> > > assignment_samples;
|
|
std::vector<std::vector<long> > labels;
|
|
for (unsigned long i = 0; i < samples.size(); ++i)
|
|
convert_dets_to_association_sets(samples[i], assignment_samples, labels);
|
|
|
|
|
|
return track_association_function<detection_type>(trainer.train(assignment_samples, labels));
|
|
}
|
|
|
|
template <
|
|
typename detection_type,
|
|
typename label_type
|
|
>
|
|
const track_association_function<detection_type> train (
|
|
const std::vector<std::vector<labeled_detection<detection_type,label_type> > >& sample
|
|
) const
|
|
{
|
|
std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > > samples;
|
|
samples.push_back(sample);
|
|
return train(samples);
|
|
}
|
|
|
|
private:
|
|
|
|
template <
|
|
typename detection_type,
|
|
typename label_type
|
|
>
|
|
static unsigned long find_num_dims (
|
|
const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples
|
|
)
|
|
{
|
|
typedef typename detection_type::track_type track_type;
|
|
// find a detection_type object so we can call get_similarity_features() and
|
|
// find out how big the feature vectors are.
|
|
|
|
// for all detection histories
|
|
for (unsigned long i = 0; i < samples.size(); ++i)
|
|
{
|
|
// for all time instances in the detection history
|
|
for (unsigned j = 0; j < samples[i].size(); ++j)
|
|
{
|
|
if (samples[i][j].size() > 0)
|
|
{
|
|
track_type new_track;
|
|
new_track.update_track(samples[i][j][0].det);
|
|
typename track_type::feature_vector_type feats;
|
|
new_track.get_similarity_features(samples[i][j][0].det, feats);
|
|
return feats.size();
|
|
}
|
|
}
|
|
}
|
|
|
|
DLIB_CASSERT(false,
|
|
"No detection objects were given in the call to dlib::structural_track_association_trainer::train()");
|
|
}
|
|
|
|
template <
|
|
typename detections_at_single_time_step,
|
|
typename detection_type,
|
|
typename track_type
|
|
>
|
|
static void convert_dets_to_association_sets (
|
|
const std::vector<detections_at_single_time_step>& det_history,
|
|
std::vector<std::pair<std::vector<detection_type>, std::vector<track_type> > >& data,
|
|
std::vector<std::vector<long> >& labels
|
|
)
|
|
{
|
|
if (det_history.size() < 1)
|
|
return;
|
|
|
|
typedef typename detections_at_single_time_step::value_type::label_type label_type;
|
|
std::vector<track_type> tracks;
|
|
// track_labels maps from detection labels to the index in tracks. So track
|
|
// with detection label X is at tracks[track_labels[X]].
|
|
std::map<label_type,unsigned long> track_labels;
|
|
add_dets_to_tracks(tracks, track_labels, det_history[0]);
|
|
|
|
using namespace impl;
|
|
for (unsigned long i = 1; i < det_history.size(); ++i)
|
|
{
|
|
data.push_back(std::make_pair(get_unlabeled_dets(det_history[i]), tracks));
|
|
labels.push_back(get_association_labels(det_history[i], track_labels));
|
|
add_dets_to_tracks(tracks, track_labels, det_history[i]);
|
|
}
|
|
}
|
|
|
|
template <
|
|
typename labeled_detection,
|
|
typename label_type
|
|
>
|
|
static std::vector<long> get_association_labels(
|
|
const std::vector<labeled_detection>& dets,
|
|
const std::map<label_type,unsigned long>& track_labels
|
|
)
|
|
{
|
|
std::vector<long> assoc(dets.size(),-1);
|
|
// find out which detections associate to what tracks
|
|
for (unsigned long i = 0; i < dets.size(); ++i)
|
|
{
|
|
typename std::map<label_type,unsigned long>::const_iterator j;
|
|
j = track_labels.find(dets[i].label);
|
|
// If this detection matches one of the tracks then record which track it
|
|
// matched with.
|
|
if (j != track_labels.end())
|
|
assoc[i] = j->second;
|
|
}
|
|
return assoc;
|
|
}
|
|
|
|
template <
|
|
typename track_type,
|
|
typename label_type,
|
|
typename labeled_detection
|
|
>
|
|
static void add_dets_to_tracks (
|
|
std::vector<track_type>& tracks,
|
|
std::map<label_type,unsigned long>& track_labels,
|
|
const std::vector<labeled_detection>& dets
|
|
)
|
|
{
|
|
std::vector<bool> updated_track(tracks.size(), false);
|
|
|
|
// first assign the dets to the tracks
|
|
for (unsigned long i = 0; i < dets.size(); ++i)
|
|
{
|
|
const label_type& label = dets[i].label;
|
|
if (track_labels.count(label))
|
|
{
|
|
const unsigned long track_idx = track_labels[label];
|
|
tracks[track_idx].update_track(dets[i].det);
|
|
updated_track[track_idx] = true;
|
|
}
|
|
else
|
|
{
|
|
// this detection creates a new track
|
|
track_type new_track;
|
|
new_track.update_track(dets[i].det);
|
|
tracks.push_back(new_track);
|
|
track_labels[label] = tracks.size()-1;
|
|
}
|
|
|
|
}
|
|
|
|
// Now propagate all the tracks that didn't get any detections.
|
|
for (unsigned long i = 0; i < updated_track.size(); ++i)
|
|
{
|
|
if (!updated_track[i])
|
|
tracks[i].propagate_track();
|
|
}
|
|
}
|
|
|
|
double C;
|
|
oca solver;
|
|
double eps;
|
|
bool verbose;
|
|
unsigned long num_threads;
|
|
unsigned long max_cache_size;
|
|
bool learn_nonnegative_weights;
|
|
double loss_per_track_break;
|
|
double loss_per_false_association;
|
|
|
|
void set_defaults ()
|
|
{
|
|
C = 100;
|
|
verbose = false;
|
|
eps = 0.001;
|
|
num_threads = 2;
|
|
max_cache_size = 5;
|
|
learn_nonnegative_weights = false;
|
|
loss_per_track_break = 1;
|
|
loss_per_false_association = 1;
|
|
}
|
|
};
|
|
|
|
}
|
|
|
|
#endif // DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_
|
|
|