// Copyright (C) 2008 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_RBf_NETWORK_ABSTRACT_
|
|
#ifdef DLIB_RBf_NETWORK_ABSTRACT_
|
|
|
|
#include "../algs.h"
|
|
#include "function_abstract.h"
|
|
#include "kernel_abstract.h"
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename K
|
|
>
|
|
class rbf_network_trainer
|
|
{
|
|
/*!
|
|
REQUIREMENTS ON K
|
|
is a kernel function object as defined in dlib/svm/kernel_abstract.h
|
|
(since this is supposed to be a RBF network it is probably reasonable
|
|
to use some sort of radial basis kernel)
|
|
|
|
INITIAL VALUE
|
|
- get_num_centers() == 10
|
|
|
|
WHAT THIS OBJECT REPRESENTS
|
|
This object implements a trainer for a radial basis function network.
|
|
|
|
The implementation of this algorithm follows the normal RBF training
|
|
process. For more details see the code or the Wikipedia article
|
|
about RBF networks.
|
|
!*/
|
|
|
|
public:
|
|
typedef K kernel_type;
|
|
typedef typename kernel_type::scalar_type scalar_type;
|
|
typedef typename kernel_type::sample_type sample_type;
|
|
typedef typename kernel_type::mem_manager_type mem_manager_type;
|
|
typedef decision_function<kernel_type> trained_function_type;
|
|
|
|
rbf_network_trainer (
|
|
);
|
|
/*!
|
|
ensures
|
|
- this object is properly initialized
|
|
!*/
|
|
|
|
void set_kernel (
|
|
const kernel_type& k
|
|
);
|
|
/*!
|
|
ensures
|
|
- #get_kernel() == k
|
|
!*/
|
|
|
|
const kernel_type& get_kernel (
|
|
) const;
|
|
/*!
|
|
ensures
|
|
- returns a copy of the kernel function in use by this object
|
|
!*/
|
|
|
|
void set_num_centers (
|
|
const unsigned long num_centers
|
|
);
|
|
/*!
|
|
ensures
|
|
- #get_num_centers() == num_centers
|
|
!*/
|
|
|
|
const unsigned long get_num_centers (
|
|
) const;
|
|
/*!
|
|
ensures
|
|
- returns the maximum number of centers (a.k.a. basis_vectors in the
|
|
trained decision_function) you will get when you train this object on data.
|
|
!*/
|
|
|
|
template <
|
|
typename in_sample_vector_type,
|
|
typename in_scalar_vector_type
|
|
>
|
|
const decision_function<kernel_type> train (
|
|
const in_sample_vector_type& x,
|
|
const in_scalar_vector_type& y
|
|
) const
|
|
/*!
|
|
requires
|
|
- x == a matrix or something convertible to a matrix via mat().
|
|
Also, x should contain sample_type objects.
|
|
- y == a matrix or something convertible to a matrix via mat().
|
|
Also, y should contain scalar_type objects.
|
|
- is_learning_problem(x,y) == true
|
|
ensures
|
|
- trains a RBF network given the training samples in x and
|
|
labels in y and returns the resulting decision_function
|
|
throws
|
|
- std::bad_alloc
|
|
!*/
|
|
|
|
void swap (
|
|
rbf_network_trainer& item
|
|
);
|
|
/*!
|
|
ensures
|
|
- swaps *this and item
|
|
!*/
|
|
|
|
};
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <typename K>
|
|
void swap (
|
|
rbf_network_trainer<K>& a,
|
|
rbf_network_trainer<K>& b
|
|
) { a.swap(b); }
|
|
/*!
|
|
provides a global swap
|
|
!*/
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_RBf_NETWORK_ABSTRACT_
|
|
|
|
|
|
|