182 lines
6.4 KiB
C++
182 lines
6.4 KiB
C++
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_
|
|
#define DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_
|
|
|
|
#include "cross_validate_assignment_trainer_abstract.h"
|
|
#include <vector>
|
|
#include "../matrix.h"
|
|
#include "svm.h"
|
|
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
    typename assignment_function
    >
double test_assignment_function (
    const assignment_function& assigner,
    const std::vector<typename assignment_function::sample_type>& samples,
    const std::vector<typename assignment_function::label_type>& labels
)
// Scores an assignment function against ground truth.
//
// For every sample s, assigner(samples[s]) is compared element by element with
// labels[s].  The result is the number of matching elements divided by the
// number of compared elements, pooled over all samples (a sample with more
// elements therefore carries more weight).  If nothing was compared at all
// (e.g. samples is empty) the function returns 1.
//
// Preconditions (checked only when ENABLE_ASSERTS is defined):
//   - is_forced_assignment_problem(samples, labels) if assigner.forces_assignment(),
//   - is_assignment_problem(samples, labels) otherwise.
{
    // make sure requires clause is not broken
#ifdef ENABLE_ASSERTS
    if (assigner.forces_assignment())
    {
        DLIB_ASSERT(is_forced_assignment_problem(samples, labels),
            "\t double test_assignment_function()"
            << "\n\t invalid inputs were given to this function"
            << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
            << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
            << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
            );
    }
    else
    {
        DLIB_ASSERT(is_assignment_problem(samples, labels),
            "\t double test_assignment_function()"
            << "\n\t invalid inputs were given to this function"
            << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
            << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
            );
    }
#endif

    // Counters are kept as doubles so the final ratio needs no casts.
    double num_correct = 0;
    double num_compared = 0;

    for (unsigned long s = 0; s < samples.size(); ++s)
    {
        const std::vector<long>& predicted = assigner(samples[s]);
        const typename assignment_function::label_type& truth = labels[s];

        for (unsigned long j = 0; j < predicted.size(); ++j)
        {
            num_compared += 1;
            if (predicted[j] == truth[j])
                num_correct += 1;
        }
    }

    // An empty comparison set is treated as perfect accuracy.
    return (num_compared == 0) ? 1.0 : num_correct/num_compared;
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
    typename trainer_type
    >
double cross_validate_assignment_trainer (
    const trainer_type& trainer,
    const std::vector<typename trainer_type::sample_type>& samples,
    const std::vector<typename trainer_type::label_type>& labels,
    const long folds
)
// Performs k-fold cross-validation of an assignment trainer.
//
// The data is split into `folds` consecutive test windows of
// samples.size()/folds samples each.  For every window, trainer.train() is
// called on the remaining samples, and the resulting function is scored on the
// window by comparing its output element by element with the labels.  The
// return value is the fraction of matching elements, pooled over all folds,
// or 1 if nothing was compared.
//
// When samples.size() is not a multiple of folds, the samples after the last
// test window are never tested; they only ever appear in training sets.
//
// Preconditions (checked only when ENABLE_ASSERTS is defined):
//   - 1 < folds <= samples.size(),
//   - is_forced_assignment_problem(samples, labels) if trainer.forces_assignment(),
//     is_assignment_problem(samples, labels) otherwise.
{
    // make sure requires clause is not broken
#ifdef ENABLE_ASSERTS
    if (trainer.forces_assignment())
    {
        DLIB_ASSERT(is_forced_assignment_problem(samples, labels) &&
                    1 < folds && folds <= static_cast<long>(samples.size()),
            "\t double cross_validate_assignment_trainer()"
            << "\n\t invalid inputs were given to this function"
            << "\n\t samples.size(): " << samples.size()
            << "\n\t folds: " << folds
            << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
            << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
            << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
            );
    }
    else
    {
        DLIB_ASSERT(is_assignment_problem(samples, labels) &&
                    1 < folds && folds <= static_cast<long>(samples.size()),
            "\t double cross_validate_assignment_trainer()"
            << "\n\t invalid inputs were given to this function"
            << "\n\t samples.size(): " << samples.size()
            << "\n\t folds: " << folds
            << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
            << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
            );
    }
#endif

    typedef typename trainer_type::sample_type sample_type;
    typedef typename trainer_type::label_type label_type;

    const long num_in_test = samples.size()/folds;
    const long num_in_train = samples.size() - num_in_test;

    // These buffers are reused for every fold.  clear() keeps their capacity,
    // so reserving once here avoids regrowing them element by element.
    std::vector<sample_type> samples_test, samples_train;
    std::vector<label_type> labels_test, labels_train;
    samples_test.reserve(num_in_test);
    labels_test.reserve(num_in_test);
    samples_train.reserve(num_in_train);
    labels_train.reserve(num_in_train);

    long next_test_idx = 0;
    double total_right = 0;
    double total = 0;

    for (long fold = 0; fold < folds; ++fold)
    {
        samples_test.clear();
        labels_test.clear();
        samples_train.clear();
        labels_train.clear();

        // load up the test samples
        for (long cnt = 0; cnt < num_in_test; ++cnt)
        {
            samples_test.push_back(samples[next_test_idx]);
            labels_test.push_back(labels[next_test_idx]);
            next_test_idx = (next_test_idx + 1)%samples.size();
        }

        // load up the training samples: everything after the test window,
        // wrapping around to the start of the data.
        long next = next_test_idx;
        for (long cnt = 0; cnt < num_in_train; ++cnt)
        {
            samples_train.push_back(samples[next]);
            labels_train.push_back(labels[next]);
            next = (next + 1)%samples.size();
        }

        const typename trainer_type::trained_function_type& df = trainer.train(samples_train,labels_train);

        // check how good df is on the test data.  The index is named t so it
        // does not shadow the fold counter of the enclosing loop.
        for (unsigned long t = 0; t < samples_test.size(); ++t)
        {
            const std::vector<long>& out = df(samples_test[t]);
            for (unsigned long j = 0; j < out.size(); ++j)
            {
                if (out[j] == labels_test[t][j])
                    ++total_right;

                ++total;
            }
        }

    } // for (long fold = 0; fold < folds; ++fold)

    if (total != 0)
        return total_right/total;
    else
        return 1;
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_
|
|
|