// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_Hh_
|
|
#define DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_Hh_
|
|
|
|
#include "cross_validate_sequence_labeler_abstract.h"
|
|
#include <vector>
|
|
#include "../matrix.h"
|
|
#include "svm.h"
|
|
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename sequence_labeler_type,
|
|
typename sequence_type
|
|
>
|
|
const matrix<double> test_sequence_labeler (
|
|
const sequence_labeler_type& labeler,
|
|
const std::vector<sequence_type>& samples,
|
|
const std::vector<std::vector<unsigned long> >& labels
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT( is_sequence_labeling_problem(samples, labels) == true,
|
|
"\tmatrix test_sequence_labeler()"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t is_sequence_labeling_problem(samples, labels): "
|
|
<< is_sequence_labeling_problem(samples, labels));
|
|
|
|
matrix<double> res(labeler.num_labels(), labeler.num_labels());
|
|
res = 0;
|
|
|
|
std::vector<unsigned long> pred;
|
|
for (unsigned long i = 0; i < samples.size(); ++i)
|
|
{
|
|
labeler.label_sequence(samples[i], pred);
|
|
|
|
for (unsigned long j = 0; j < pred.size(); ++j)
|
|
{
|
|
const unsigned long truth = labels[i][j];
|
|
if (truth >= static_cast<unsigned long>(res.nr()))
|
|
{
|
|
// ignore labels the labeler doesn't know about.
|
|
continue;
|
|
}
|
|
|
|
res(truth, pred[j]) += 1;
|
|
}
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename trainer_type,
|
|
typename sequence_type
|
|
>
|
|
const matrix<double> cross_validate_sequence_labeler (
|
|
const trainer_type& trainer,
|
|
const std::vector<sequence_type>& samples,
|
|
const std::vector<std::vector<unsigned long> >& labels,
|
|
const long folds
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(is_sequence_labeling_problem(samples,labels) == true &&
|
|
1 < folds && folds <= static_cast<long>(samples.size()),
|
|
"\tmatrix cross_validate_sequence_labeler()"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t samples.size(): " << samples.size()
|
|
<< "\n\t folds: " << folds
|
|
<< "\n\t is_sequence_labeling_problem(samples,labels): " << is_sequence_labeling_problem(samples,labels)
|
|
);
|
|
|
|
#ifdef ENABLE_ASSERTS
|
|
for (unsigned long i = 0; i < labels.size(); ++i)
|
|
{
|
|
for (unsigned long j = 0; j < labels[i].size(); ++j)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(labels[i][j] < trainer.num_labels(),
|
|
"\t matrix cross_validate_sequence_labeler()"
|
|
<< "\n\t The labels are invalid."
|
|
<< "\n\t labels[i][j]: " << labels[i][j]
|
|
<< "\n\t trainer.num_labels(): " << trainer.num_labels()
|
|
<< "\n\t i: " << i
|
|
<< "\n\t j: " << j
|
|
);
|
|
}
|
|
}
|
|
#endif
|
|
|
|
|
|
|
|
|
|
const long num_in_test = samples.size()/folds;
|
|
const long num_in_train = samples.size() - num_in_test;
|
|
|
|
std::vector<sequence_type> x_test, x_train;
|
|
std::vector<std::vector<unsigned long> > y_test, y_train;
|
|
|
|
|
|
long next_test_idx = 0;
|
|
|
|
matrix<double> res;
|
|
|
|
|
|
for (long i = 0; i < folds; ++i)
|
|
{
|
|
x_test.clear();
|
|
y_test.clear();
|
|
x_train.clear();
|
|
y_train.clear();
|
|
|
|
// load up the test samples
|
|
for (long cnt = 0; cnt < num_in_test; ++cnt)
|
|
{
|
|
x_test.push_back(samples[next_test_idx]);
|
|
y_test.push_back(labels[next_test_idx]);
|
|
next_test_idx = (next_test_idx + 1)%samples.size();
|
|
}
|
|
|
|
// load up the training samples
|
|
long next = next_test_idx;
|
|
for (long cnt = 0; cnt < num_in_train; ++cnt)
|
|
{
|
|
x_train.push_back(samples[next]);
|
|
y_train.push_back(labels[next]);
|
|
next = (next + 1)%samples.size();
|
|
}
|
|
|
|
|
|
res += test_sequence_labeler(trainer.train(x_train,y_train), x_test, y_test);
|
|
|
|
} // for (long i = 0; i < folds; ++i)
|
|
|
|
return res;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_Hh_
|
|
|
|
|