156 lines
4.7 KiB
C++
156 lines
4.7 KiB
C++
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_Hh_
|
|
#define DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_Hh_
|
|
|
|
#include <vector>
|
|
#include "../matrix.h"
|
|
#include "../statistics.h"
|
|
#include "cross_validate_regression_trainer_abstract.h"
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename reg_funct_type,
|
|
typename sample_type,
|
|
typename label_type
|
|
>
|
|
matrix<double,1,4>
|
|
test_regression_function (
|
|
reg_funct_type& reg_funct,
|
|
const std::vector<sample_type>& x_test,
|
|
const std::vector<label_type>& y_test
|
|
)
|
|
{
|
|
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT( is_learning_problem(x_test,y_test) == true,
|
|
"\tmatrix test_regression_function()"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t is_learning_problem(x_test,y_test): "
|
|
<< is_learning_problem(x_test,y_test));
|
|
|
|
running_stats<double> rs, rs_mae;
|
|
running_scalar_covariance<double> rc;
|
|
|
|
for (unsigned long i = 0; i < x_test.size(); ++i)
|
|
{
|
|
// compute error
|
|
const double output = reg_funct(x_test[i]);
|
|
const double temp = output - y_test[i];
|
|
|
|
rs_mae.add(std::abs(temp));
|
|
rs.add(temp*temp);
|
|
rc.add(output, y_test[i]);
|
|
}
|
|
|
|
matrix<double,1,4> result;
|
|
result = rs.mean(), rc.correlation(), rs_mae.mean(), rs_mae.stddev();
|
|
return result;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename trainer_type,
|
|
typename sample_type,
|
|
typename label_type
|
|
>
|
|
matrix<double,1,4>
|
|
cross_validate_regression_trainer (
|
|
const trainer_type& trainer,
|
|
const std::vector<sample_type>& x,
|
|
const std::vector<label_type>& y,
|
|
const long folds
|
|
)
|
|
{
|
|
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(is_learning_problem(x,y) == true &&
|
|
1 < folds && folds <= static_cast<long>(x.size()),
|
|
"\tmatrix cross_validate_regression_trainer()"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t x.size(): " << x.size()
|
|
<< "\n\t folds: " << folds
|
|
<< "\n\t is_learning_problem(x,y): " << is_learning_problem(x,y)
|
|
);
|
|
|
|
|
|
|
|
const long num_in_test = x.size()/folds;
|
|
const long num_in_train = x.size() - num_in_test;
|
|
|
|
running_stats<double> rs, rs_mae;
|
|
running_scalar_covariance<double> rc;
|
|
|
|
std::vector<sample_type> x_test, x_train;
|
|
std::vector<label_type> y_test, y_train;
|
|
|
|
|
|
long next_test_idx = 0;
|
|
|
|
|
|
for (long i = 0; i < folds; ++i)
|
|
{
|
|
x_test.clear();
|
|
y_test.clear();
|
|
x_train.clear();
|
|
y_train.clear();
|
|
|
|
// load up the test samples
|
|
for (long cnt = 0; cnt < num_in_test; ++cnt)
|
|
{
|
|
x_test.push_back(x[next_test_idx]);
|
|
y_test.push_back(y[next_test_idx]);
|
|
next_test_idx = (next_test_idx + 1)%x.size();
|
|
}
|
|
|
|
// load up the training samples
|
|
long next = next_test_idx;
|
|
for (long cnt = 0; cnt < num_in_train; ++cnt)
|
|
{
|
|
x_train.push_back(x[next]);
|
|
y_train.push_back(y[next]);
|
|
next = (next + 1)%x.size();
|
|
}
|
|
|
|
|
|
try
|
|
{
|
|
const typename trainer_type::trained_function_type& df = trainer.train(x_train,y_train);
|
|
|
|
// do the training and testing
|
|
for (unsigned long j = 0; j < x_test.size(); ++j)
|
|
{
|
|
// compute error
|
|
const double output = df(x_test[j]);
|
|
const double temp = output - y_test[j];
|
|
|
|
rs_mae.add(std::abs(temp));
|
|
rs.add(temp*temp);
|
|
rc.add(output, y_test[j]);
|
|
}
|
|
}
|
|
catch (invalid_nu_error&)
|
|
{
|
|
// just ignore cases which result in an invalid nu
|
|
}
|
|
|
|
} // for (long i = 0; i < folds; ++i)
|
|
|
|
matrix<double,1,4> result;
|
|
result = rs.mean(), rc.correlation(), rs_mae.mean(), rs_mae.stddev();
|
|
return result;
|
|
}
|
|
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
#endif // DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_Hh_
|
|
|
|
|