// Copyright (C) 2014 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_Hh_
|
|
#define DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_Hh_
|
|
|
|
#include "cross_validate_track_association_trainer_abstract.h"
|
|
#include "structural_track_association_trainer.h"
|
|
|
|
namespace dlib
|
|
{
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
namespace impl
|
|
{
|
|
template <
|
|
typename track_association_function,
|
|
typename detection_type,
|
|
typename label_type
|
|
>
|
|
void test_track_association_function (
|
|
const track_association_function& assoc,
|
|
const std::vector<std::vector<labeled_detection<detection_type,label_type> > >& samples,
|
|
unsigned long& total_dets,
|
|
unsigned long& correctly_associated_dets
|
|
)
|
|
{
|
|
const typename track_association_function::association_function_type& f = assoc.get_assignment_function();
|
|
|
|
typedef typename detection_type::track_type track_type;
|
|
using namespace impl;
|
|
|
|
dlib::rand rnd;
|
|
std::vector<track_type> tracks;
|
|
std::map<label_type,long> track_idx; // tracks[track_idx[id]] == track with ID id.
|
|
|
|
for (unsigned long j = 0; j < samples.size(); ++j)
|
|
{
|
|
std::vector<labeled_detection<detection_type,label_type> > dets = samples[j];
|
|
// Shuffle the order of the detections so we can be sure that there isn't
|
|
// anything funny going on like the detections always coming in the same
|
|
// order relative to their labels and the association function just gets
|
|
// lucky by picking the same assignment ordering every time. So this way
|
|
// we know the assignment function really is doing something rather than
|
|
// just being lucky.
|
|
randomize_samples(dets, rnd);
|
|
|
|
total_dets += dets.size();
|
|
std::vector<long> assignments = f(get_unlabeled_dets(dets), tracks);
|
|
std::vector<bool> updated_track(tracks.size(), false);
|
|
// now update all the tracks with the detections that associated to them.
|
|
for (unsigned long k = 0; k < assignments.size(); ++k)
|
|
{
|
|
// If the detection is associated to tracks[assignments[k]]
|
|
if (assignments[k] != -1)
|
|
{
|
|
tracks[assignments[k]].update_track(dets[k].det);
|
|
updated_track[assignments[k]] = true;
|
|
|
|
// if this detection was supposed to go to this track
|
|
if (track_idx.count(dets[k].label) && track_idx[dets[k].label]==assignments[k])
|
|
++correctly_associated_dets;
|
|
|
|
track_idx[dets[k].label] = assignments[k];
|
|
}
|
|
else
|
|
{
|
|
track_type new_track;
|
|
new_track.update_track(dets[k].det);
|
|
tracks.push_back(new_track);
|
|
|
|
// if this detection was supposed to go to a new track
|
|
if (track_idx.count(dets[k].label) == 0)
|
|
++correctly_associated_dets;
|
|
|
|
track_idx[dets[k].label] = tracks.size()-1;
|
|
}
|
|
}
|
|
|
|
// Now propagate all the tracks that didn't get any detections.
|
|
for (unsigned long k = 0; k < updated_track.size(); ++k)
|
|
{
|
|
if (!updated_track[k])
|
|
tracks[k].propagate_track();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename track_association_function,
|
|
typename detection_type,
|
|
typename label_type
|
|
>
|
|
double test_track_association_function (
|
|
const track_association_function& assoc,
|
|
const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples
|
|
)
|
|
{
|
|
unsigned long total_dets = 0;
|
|
unsigned long correctly_associated_dets = 0;
|
|
|
|
for (unsigned long i = 0; i < samples.size(); ++i)
|
|
{
|
|
impl::test_track_association_function(assoc, samples[i], total_dets, correctly_associated_dets);
|
|
}
|
|
|
|
return (double)correctly_associated_dets/(double)total_dets;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename trainer_type,
|
|
typename detection_type,
|
|
typename label_type
|
|
>
|
|
double cross_validate_track_association_trainer (
|
|
const trainer_type& trainer,
|
|
const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples,
|
|
const long folds
|
|
)
|
|
{
|
|
const long num_in_test = samples.size()/folds;
|
|
const long num_in_train = samples.size() - num_in_test;
|
|
|
|
std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > > samples_train;
|
|
|
|
long next_test_idx = 0;
|
|
unsigned long total_dets = 0;
|
|
unsigned long correctly_associated_dets = 0;
|
|
|
|
for (long i = 0; i < folds; ++i)
|
|
{
|
|
samples_train.clear();
|
|
|
|
// load up the training samples
|
|
long next = (next_test_idx + num_in_test)%samples.size();
|
|
for (long cnt = 0; cnt < num_in_train; ++cnt)
|
|
{
|
|
samples_train.push_back(samples[next]);
|
|
next = (next + 1)%samples.size();
|
|
}
|
|
|
|
const track_association_function<detection_type>& df = trainer.train(samples_train);
|
|
for (long cnt = 0; cnt < num_in_test; ++cnt)
|
|
{
|
|
impl::test_track_association_function(df, samples[next_test_idx], total_dets, correctly_associated_dets);
|
|
next_test_idx = (next_test_idx + 1)%samples.size();
|
|
}
|
|
}
|
|
|
|
return (double)correctly_associated_dets/(double)total_dets;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_Hh_
|
|
|
|
|