188 lines
6.8 KiB
C++
188 lines
6.8 KiB
C++
// Copyright (C) 2013 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_Hh_
|
|
#define DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_Hh_
|
|
|
|
#include "cross_validate_sequence_segmenter_abstract.h"
|
|
#include "sequence_segmenter.h"
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
namespace impl
|
|
{
|
|
template <
|
|
typename sequence_segmenter_type,
|
|
typename sequence_type
|
|
>
|
|
const matrix<double,1,3> raw_metrics_test_sequence_segmenter (
|
|
const sequence_segmenter_type& segmenter,
|
|
const std::vector<sequence_type>& samples,
|
|
const std::vector<std::vector<std::pair<unsigned long,unsigned long> > >& segments
|
|
)
|
|
{
|
|
std::vector<std::pair<unsigned long,unsigned long> > truth;
|
|
std::vector<std::pair<unsigned long,unsigned long> > pred;
|
|
|
|
double true_hits = 0;
|
|
double total_detections = 0;
|
|
double total_true_segments = 0;
|
|
|
|
for (unsigned long i = 0; i < samples.size(); ++i)
|
|
{
|
|
segmenter.segment_sequence(samples[i], pred);
|
|
truth = segments[i];
|
|
// sort the segments so they will be in the same orders
|
|
std::sort(truth.begin(), truth.end());
|
|
std::sort(pred.begin(), pred.end());
|
|
|
|
total_true_segments += truth.size();
|
|
total_detections += pred.size();
|
|
|
|
unsigned long j=0,k=0;
|
|
while (j < pred.size() && k < truth.size())
|
|
{
|
|
if (pred[j].first == truth[k].first &&
|
|
pred[j].second == truth[k].second)
|
|
{
|
|
++true_hits;
|
|
++j;
|
|
++k;
|
|
}
|
|
else if (pred[j].first < truth[k].first)
|
|
{
|
|
++j;
|
|
}
|
|
else
|
|
{
|
|
++k;
|
|
}
|
|
}
|
|
}
|
|
|
|
matrix<double,1,3> res;
|
|
res = total_detections, total_true_segments, true_hits;
|
|
return res;
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename sequence_segmenter_type,
|
|
typename sequence_type
|
|
>
|
|
const matrix<double,1,3> test_sequence_segmenter (
|
|
const sequence_segmenter_type& segmenter,
|
|
const std::vector<sequence_type>& samples,
|
|
const std::vector<std::vector<std::pair<unsigned long,unsigned long> > >& segments
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT( is_sequence_segmentation_problem(samples, segments) == true,
|
|
"\tmatrix test_sequence_segmenter()"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t is_sequence_segmentation_problem(samples, segments): "
|
|
<< is_sequence_segmentation_problem(samples, segments));
|
|
|
|
const matrix<double,1,3> metrics = impl::raw_metrics_test_sequence_segmenter(segmenter, samples, segments);
|
|
|
|
const double total_detections = metrics(0);
|
|
const double total_true_segments = metrics(1);
|
|
const double true_hits = metrics(2);
|
|
|
|
const double precision = (total_detections ==0) ? 1 : true_hits/total_detections;
|
|
const double recall = (total_true_segments==0) ? 1 : true_hits/total_true_segments;
|
|
const double f1 = (precision+recall ==0) ? 0 : 2*precision*recall/(precision+recall);
|
|
|
|
matrix<double,1,3> res;
|
|
res = precision, recall, f1;
|
|
return res;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename trainer_type,
|
|
typename sequence_type
|
|
>
|
|
const matrix<double,1,3> cross_validate_sequence_segmenter (
|
|
const trainer_type& trainer,
|
|
const std::vector<sequence_type>& samples,
|
|
const std::vector<std::vector<std::pair<unsigned long,unsigned long> > >& segments,
|
|
const long folds
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT( is_sequence_segmentation_problem(samples, segments) == true &&
|
|
1 < folds && folds <= static_cast<long>(samples.size()),
|
|
"\tmatrix cross_validate_sequence_segmenter()"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t folds: " << folds
|
|
<< "\n\t is_sequence_segmentation_problem(samples, segments): "
|
|
<< is_sequence_segmentation_problem(samples, segments));
|
|
|
|
|
|
const long num_in_test = samples.size()/folds;
|
|
const long num_in_train = samples.size() - num_in_test;
|
|
|
|
std::vector<sequence_type> x_test, x_train;
|
|
std::vector<std::vector<std::pair<unsigned long,unsigned long> > > y_test, y_train;
|
|
|
|
long next_test_idx = 0;
|
|
|
|
matrix<double,1,3> metrics;
|
|
metrics = 0;
|
|
|
|
for (long i = 0; i < folds; ++i)
|
|
{
|
|
x_test.clear();
|
|
y_test.clear();
|
|
x_train.clear();
|
|
y_train.clear();
|
|
|
|
// load up the test samples
|
|
for (long cnt = 0; cnt < num_in_test; ++cnt)
|
|
{
|
|
x_test.push_back(samples[next_test_idx]);
|
|
y_test.push_back(segments[next_test_idx]);
|
|
next_test_idx = (next_test_idx + 1)%samples.size();
|
|
}
|
|
|
|
// load up the training samples
|
|
long next = next_test_idx;
|
|
for (long cnt = 0; cnt < num_in_train; ++cnt)
|
|
{
|
|
x_train.push_back(samples[next]);
|
|
y_train.push_back(segments[next]);
|
|
next = (next + 1)%samples.size();
|
|
}
|
|
|
|
|
|
metrics += impl::raw_metrics_test_sequence_segmenter(trainer.train(x_train,y_train), x_test, y_test);
|
|
} // for (long i = 0; i < folds; ++i)
|
|
|
|
|
|
const double total_detections = metrics(0);
|
|
const double total_true_segments = metrics(1);
|
|
const double true_hits = metrics(2);
|
|
|
|
const double precision = (total_detections ==0) ? 1 : true_hits/total_detections;
|
|
const double recall = (total_true_segments==0) ? 1 : true_hits/total_true_segments;
|
|
const double f1 = (precision+recall ==0) ? 0 : 2*precision*recall/(precision+recall);
|
|
|
|
matrix<double,1,3> res;
|
|
res = precision, recall, f1;
|
|
return res;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_Hh_
|
|
|
|
|