83 lines
3.1 KiB
C++
83 lines
3.1 KiB
C++
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_ABSTRACT_Hh_
|
|
#ifdef DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_ABSTRACT_Hh_
|
|
|
|
#include <vector>
|
|
#include "../matrix.h"
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename reg_funct_type,
|
|
typename sample_type,
|
|
typename label_type
|
|
>
|
|
matrix<double,1,4>
|
|
test_regression_function (
|
|
reg_funct_type& reg_funct,
|
|
const std::vector<sample_type>& x_test,
|
|
const std::vector<label_type>& y_test
|
|
);
|
|
/*!
|
|
requires
|
|
- is_learning_problem(x_test, y_test)
|
|
- reg_funct_type == some kind of regression function object
|
|
(e.g. a decision_function created by the svr_trainer )
|
|
ensures
|
|
- Tests reg_funct against the given samples in x_test and target values in
|
|
y_test and returns a matrix M summarizing the results. Specifically:
|
|
- M(0) == the mean squared error.
|
|
The MSE is given by: sum over i: pow(reg_funct(x_test[i]) - y_test[i], 2.0)
|
|
- M(1) == the correlation between reg_funct(x_test[i]) and y_test[i].
|
|
This is a number between -1 and 1.
|
|
- M(2) == the mean absolute error.
|
|
This is given by: sum over i: abs(reg_funct(x_test[i]) - y_test[i])
|
|
- M(3) == the standard deviation of the absolute error.
|
|
!*/
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename trainer_type,
|
|
typename sample_type,
|
|
typename label_type
|
|
>
|
|
matrix<double,1,4>
|
|
cross_validate_regression_trainer (
|
|
const trainer_type& trainer,
|
|
const std::vector<sample_type>& x,
|
|
const std::vector<label_type>& y,
|
|
const long folds
|
|
);
|
|
/*!
|
|
requires
|
|
- is_learning_problem(x,y)
|
|
- 1 < folds <= x.size()
|
|
- trainer_type == some kind of regression trainer object (e.g. svr_trainer)
|
|
ensures
|
|
- Performs k-fold cross validation by using the given trainer to solve a
|
|
regression problem for the given number of folds. Each fold is tested using
|
|
the output of the trainer. A matrix M summarizing the results is returned.
|
|
Specifically:
|
|
- M(0) == the mean squared error.
|
|
The MSE is given by: sum over i: pow(reg_funct(x[i]) - y[i], 2.0)
|
|
- M(1) == the correlation between a predicted y value and its true value.
|
|
This is a number between -1 and 1.
|
|
- M(2) == the mean absolute error.
|
|
This is given by: sum over i: abs(reg_funct(x_test[i]) - y_test[i])
|
|
- M(3) == the standard deviation of the absolute error.
|
|
!*/
|
|
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
#endif // DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_ABSTRACT_Hh_
|
|
|
|
|
|
|