637 lines
20 KiB
C++
637 lines
20 KiB
C++
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_SVM_C_EKm_TRAINER_Hh_
|
|
#define DLIB_SVM_C_EKm_TRAINER_Hh_
|
|
|
|
#include "../algs.h"
|
|
#include "function.h"
|
|
#include "kernel.h"
|
|
#include "empirical_kernel_map.h"
|
|
#include "svm_c_linear_trainer.h"
|
|
#include "svm_c_ekm_trainer_abstract.h"
|
|
#include "../statistics.h"
|
|
#include "../rand.h"
|
|
#include <vector>
|
|
|
|
namespace dlib
|
|
{
|
|
template <
|
|
typename K
|
|
>
|
|
class svm_c_ekm_trainer
|
|
{
|
|
|
|
public:
|
|
typedef K kernel_type;
|
|
typedef typename kernel_type::scalar_type scalar_type;
|
|
typedef typename kernel_type::sample_type sample_type;
|
|
typedef typename kernel_type::mem_manager_type mem_manager_type;
|
|
typedef decision_function<kernel_type> trained_function_type;
|
|
|
|
svm_c_ekm_trainer (
|
|
)
|
|
{
|
|
verbose = false;
|
|
ekm_stale = true;
|
|
|
|
initial_basis_size = 10;
|
|
basis_size_increment = 50;
|
|
max_basis_size = 300;
|
|
}
|
|
|
|
explicit svm_c_ekm_trainer (
|
|
const scalar_type& C
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(C > 0,
|
|
"\t svm_c_ekm_trainer::svm_c_ekm_trainer()"
|
|
<< "\n\t C must be greater than 0"
|
|
<< "\n\t C: " << C
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
|
|
ocas.set_c(C);
|
|
verbose = false;
|
|
ekm_stale = true;
|
|
|
|
initial_basis_size = 10;
|
|
basis_size_increment = 50;
|
|
max_basis_size = 300;
|
|
}
|
|
|
|
void set_epsilon (scalar_type eps)
/*!
    requires
        - eps > 0
    ensures
        - forwards eps to the internal linear trainer (ocas), which owns the
          stopping tolerance used during training
!*/
{
    // make sure requires clause is not broken
    DLIB_ASSERT(eps > 0,
        "\t void svm_c_ekm_trainer::set_epsilon()"
        << "\n\t eps must be greater than 0"
        << "\n\t eps: " << eps
        << "\n\t this: " << this
        );

    ocas.set_epsilon(eps);
}
|
|
|
|
const scalar_type get_epsilon () const
/*!
    ensures
        - returns the epsilon currently stored in the internal linear
          trainer (ocas)
!*/
{
    const scalar_type current_eps = ocas.get_epsilon();
    return current_eps;
}
|
|
|
|
void set_max_iterations (
|
|
unsigned long max_iter
|
|
)
|
|
{
|
|
ocas.set_max_iterations(max_iter);
|
|
}
|
|
|
|
unsigned long get_max_iterations (
|
|
)
|
|
{
|
|
return ocas.get_max_iterations();
|
|
}
|
|
|
|
void be_verbose (
|
|
)
|
|
{
|
|
verbose = true;
|
|
ocas.be_quiet();
|
|
}
|
|
|
|
void be_very_verbose (
|
|
)
|
|
{
|
|
verbose = true;
|
|
ocas.be_verbose();
|
|
}
|
|
|
|
void be_quiet (
|
|
)
|
|
{
|
|
verbose = false;
|
|
ocas.be_quiet();
|
|
}
|
|
|
|
void set_oca (
|
|
const oca& item
|
|
)
|
|
{
|
|
ocas.set_oca(item);
|
|
}
|
|
|
|
const oca get_oca () const
/*!
    ensures
        - returns a copy of the oca object held by the internal linear
          trainer (ocas)
!*/
{
    return ocas.get_oca();
}
|
|
|
|
const kernel_type get_kernel () const
/*!
    ensures
        - returns a copy of the kernel this trainer is configured with
!*/
{
    return kern;
}
|
|
|
|
void set_kernel (
|
|
const kernel_type& k
|
|
)
|
|
{
|
|
kern = k;
|
|
ekm_stale = true;
|
|
}
|
|
|
|
template <typename T>
|
|
void set_basis (
|
|
const T& basis_samples
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(basis_samples.size() > 0 && is_vector(mat(basis_samples)),
|
|
"\tvoid svm_c_ekm_trainer::set_basis(basis_samples)"
|
|
<< "\n\t You have to give a non-empty set of basis_samples and it must be a vector"
|
|
<< "\n\t basis_samples.size(): " << basis_samples.size()
|
|
<< "\n\t is_vector(mat(basis_samples)): " << is_vector(mat(basis_samples))
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
basis = mat(basis_samples);
|
|
ekm_stale = true;
|
|
}
|
|
|
|
bool basis_loaded(
|
|
) const
|
|
{
|
|
return (basis.size() != 0);
|
|
}
|
|
|
|
void clear_basis (
|
|
)
|
|
{
|
|
basis.set_size(0);
|
|
ekm.clear();
|
|
ekm_stale = true;
|
|
}
|
|
|
|
unsigned long get_max_basis_size (
|
|
) const
|
|
{
|
|
return max_basis_size;
|
|
}
|
|
|
|
void set_max_basis_size (
|
|
unsigned long max_basis_size_
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(max_basis_size_ > 0,
|
|
"\t void svm_c_ekm_trainer::set_max_basis_size()"
|
|
<< "\n\t max_basis_size_ must be greater than 0"
|
|
<< "\n\t max_basis_size_: " << max_basis_size_
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
max_basis_size = max_basis_size_;
|
|
if (initial_basis_size > max_basis_size)
|
|
initial_basis_size = max_basis_size;
|
|
}
|
|
|
|
unsigned long get_initial_basis_size (
|
|
) const
|
|
{
|
|
return initial_basis_size;
|
|
}
|
|
|
|
void set_initial_basis_size (
|
|
unsigned long initial_basis_size_
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(initial_basis_size_ > 0,
|
|
"\t void svm_c_ekm_trainer::set_initial_basis_size()"
|
|
<< "\n\t initial_basis_size_ must be greater than 0"
|
|
<< "\n\t initial_basis_size_: " << initial_basis_size_
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
initial_basis_size = initial_basis_size_;
|
|
|
|
if (initial_basis_size > max_basis_size)
|
|
max_basis_size = initial_basis_size;
|
|
}
|
|
|
|
unsigned long get_basis_size_increment (
|
|
) const
|
|
{
|
|
return basis_size_increment;
|
|
}
|
|
|
|
void set_basis_size_increment (
|
|
unsigned long basis_size_increment_
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(basis_size_increment_ > 0,
|
|
"\t void svm_c_ekm_trainer::set_basis_size_increment()"
|
|
<< "\n\t basis_size_increment_ must be greater than 0"
|
|
<< "\n\t basis_size_increment_: " << basis_size_increment_
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
basis_size_increment = basis_size_increment_;
|
|
}
|
|
|
|
void set_c (scalar_type C)
/*!
    requires
        - C > 0
    ensures
        - forwards C to set_c() of the internal linear trainer (ocas)
          (presumably this sets the penalty for both classes -- confirm
          against svm_c_linear_trainer)
!*/
{
    // make sure requires clause is not broken
    DLIB_ASSERT(C > 0,
        "\t void svm_c_ekm_trainer::set_c()"
        << "\n\t C must be greater than 0"
        << "\n\t C: " << C
        << "\n\t this: " << this
        );

    ocas.set_c(C);
}
|
|
|
|
const scalar_type get_c_class1 () const
/*!
    ensures
        - returns the class1 C value stored in the internal linear
          trainer (ocas)
!*/
{
    const scalar_type c1 = ocas.get_c_class1();
    return c1;
}
|
|
|
|
const scalar_type get_c_class2 () const
/*!
    ensures
        - returns the class2 C value stored in the internal linear
          trainer (ocas)
!*/
{
    const scalar_type c2 = ocas.get_c_class2();
    return c2;
}
|
|
|
|
void set_c_class1 (scalar_type C)
/*!
    requires
        - C > 0
    ensures
        - forwards C to set_c_class1() of the internal linear trainer (ocas)
!*/
{
    // make sure requires clause is not broken
    DLIB_ASSERT(C > 0,
        "\t void svm_c_ekm_trainer::set_c_class1()"
        << "\n\t C must be greater than 0"
        << "\n\t C: " << C
        << "\n\t this: " << this
        );

    ocas.set_c_class1(C);
}
|
|
|
|
void set_c_class2 (scalar_type C)
/*!
    requires
        - C > 0
    ensures
        - forwards C to set_c_class2() of the internal linear trainer (ocas)
!*/
{
    // make sure requires clause is not broken
    DLIB_ASSERT(C > 0,
        "\t void svm_c_ekm_trainer::set_c_class2()"
        << "\n\t C must be greater than 0"
        << "\n\t C: " << C
        << "\n\t this: " << this
        );

    ocas.set_c_class2(C);
}
|
|
|
|
template <
|
|
typename in_sample_vector_type,
|
|
typename in_scalar_vector_type
|
|
>
|
|
const decision_function<kernel_type> train (
|
|
const in_sample_vector_type& x,
|
|
const in_scalar_vector_type& y
|
|
) const
|
|
{
|
|
scalar_type obj;
|
|
if (basis_loaded())
|
|
return do_train_user_basis(mat(x),mat(y),obj);
|
|
else
|
|
return do_train_auto_basis(mat(x),mat(y),obj);
|
|
}
|
|
|
|
template <
|
|
typename in_sample_vector_type,
|
|
typename in_scalar_vector_type
|
|
>
|
|
const decision_function<kernel_type> train (
|
|
const in_sample_vector_type& x,
|
|
const in_scalar_vector_type& y,
|
|
scalar_type& svm_objective
|
|
) const
|
|
{
|
|
if (basis_loaded())
|
|
return do_train_user_basis(mat(x),mat(y),svm_objective);
|
|
else
|
|
return do_train_auto_basis(mat(x),mat(y),svm_objective);
|
|
}
|
|
|
|
|
|
private:
|
|
|
|
template <
|
|
typename in_sample_vector_type,
|
|
typename in_scalar_vector_type
|
|
>
|
|
const decision_function<kernel_type> do_train_user_basis (
|
|
const in_sample_vector_type& x,
|
|
const in_scalar_vector_type& y,
|
|
scalar_type& svm_objective
|
|
) const
|
|
/*!
|
|
requires
|
|
- basis_loaded() == true
|
|
ensures
|
|
- trains an SVM with the user supplied basis
|
|
!*/
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
|
|
"\t decision_function svm_c_ekm_trainer::train(x,y)"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t x.nr(): " << x.nr()
|
|
<< "\n\t y.nr(): " << y.nr()
|
|
<< "\n\t x.nc(): " << x.nc()
|
|
<< "\n\t y.nc(): " << y.nc()
|
|
<< "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
|
|
);
|
|
|
|
if (ekm_stale)
|
|
{
|
|
ekm.load(kern, basis);
|
|
ekm_stale = false;
|
|
}
|
|
|
|
// project all the samples with the ekm
|
|
running_stats<scalar_type> rs;
|
|
std::vector<matrix<scalar_type,0,1, mem_manager_type> > proj_samples;
|
|
proj_samples.reserve(x.size());
|
|
for (long i = 0; i < x.size(); ++i)
|
|
{
|
|
if (verbose)
|
|
{
|
|
scalar_type err;
|
|
proj_samples.push_back(ekm.project(x(i), err));
|
|
rs.add(err);
|
|
}
|
|
else
|
|
{
|
|
proj_samples.push_back(ekm.project(x(i)));
|
|
}
|
|
}
|
|
|
|
if (verbose)
|
|
{
|
|
std::cout << "\nMean EKM projection error: " << rs.mean() << std::endl;
|
|
std::cout << "Standard deviation of EKM projection error: " << rs.stddev() << std::endl;
|
|
}
|
|
|
|
// now do the training
|
|
decision_function<linear_kernel<matrix<scalar_type,0,1, mem_manager_type> > > df;
|
|
df = ocas.train(proj_samples, y, svm_objective);
|
|
|
|
if (verbose)
|
|
{
|
|
std::cout << "Final svm objective: " << svm_objective << std::endl;
|
|
}
|
|
|
|
decision_function<kernel_type> final_df;
|
|
final_df = ekm.convert_to_decision_function(df.basis_vectors(0));
|
|
final_df.b = df.b;
|
|
return final_df;
|
|
}
|
|
|
|
template <
|
|
typename in_sample_vector_type,
|
|
typename in_scalar_vector_type
|
|
>
|
|
const decision_function<kernel_type> do_train_auto_basis (
|
|
const in_sample_vector_type& x,
|
|
const in_scalar_vector_type& y,
|
|
scalar_type& svm_objective
|
|
) const
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
|
|
"\t decision_function svm_c_ekm_trainer::train(x,y)"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t x.nr(): " << x.nr()
|
|
<< "\n\t y.nr(): " << y.nr()
|
|
<< "\n\t x.nc(): " << x.nc()
|
|
<< "\n\t y.nc(): " << y.nc()
|
|
<< "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
|
|
);
|
|
|
|
|
|
std::vector<matrix<scalar_type,0,1, mem_manager_type> > proj_samples(x.size());
|
|
decision_function<linear_kernel<matrix<scalar_type,0,1, mem_manager_type> > > df;
|
|
|
|
// we will use a linearly_independent_subset_finder to store our basis set.
|
|
linearly_independent_subset_finder<kernel_type> lisf(get_kernel(), max_basis_size);
|
|
|
|
dlib::rand rnd;
|
|
|
|
// first pick the initial basis set randomly
|
|
for (unsigned long i = 0; i < 10*initial_basis_size && lisf.size() < initial_basis_size; ++i)
|
|
{
|
|
lisf.add(x(rnd.get_random_32bit_number()%x.size()));
|
|
}
|
|
|
|
ekm.load(lisf);
|
|
|
|
// first project all samples into the span of the current basis
|
|
for (long i = 0; i < x.size(); ++i)
|
|
{
|
|
proj_samples[i] = ekm.project(x(i));
|
|
}
|
|
|
|
|
|
svm_c_linear_trainer<linear_kernel<matrix<scalar_type,0,1,mem_manager_type> > > trainer(ocas);
|
|
|
|
const scalar_type min_epsilon = trainer.get_epsilon();
|
|
// while we are determining what the basis set will be we are going to use a very
|
|
// lose stopping condition. We will tighten it back up before producing the
|
|
// final decision_function.
|
|
trainer.set_epsilon(0.2);
|
|
|
|
scalar_type prev_svm_objective = std::numeric_limits<scalar_type>::max();
|
|
|
|
empirical_kernel_map<kernel_type> prev_ekm;
|
|
|
|
// This loop is where we try to generate a basis for SVM training. We will
|
|
// do this by repeatedly training the SVM and adding a few points which violate the
|
|
// margin to the basis in each iteration.
|
|
while (true)
|
|
{
|
|
// if the basis is already as big as it's going to get then just do the most
|
|
// accurate training right now.
|
|
if (lisf.size() == max_basis_size)
|
|
trainer.set_epsilon(min_epsilon);
|
|
|
|
while (true)
|
|
{
|
|
// now do the training.
|
|
df = trainer.train(proj_samples, y, svm_objective);
|
|
|
|
if (svm_objective < prev_svm_objective)
|
|
break;
|
|
|
|
// If the training didn't reduce the objective more than last time then
|
|
// try lowering the epsilon and doing it again.
|
|
if (trainer.get_epsilon() > min_epsilon)
|
|
{
|
|
trainer.set_epsilon(std::max(trainer.get_epsilon()*0.5, min_epsilon));
|
|
if (verbose)
|
|
std::cout << " *** Reducing epsilon to " << trainer.get_epsilon() << std::endl;
|
|
}
|
|
else
|
|
break;
|
|
}
|
|
|
|
if (verbose)
|
|
{
|
|
std::cout << "svm objective: " << svm_objective << std::endl;
|
|
std::cout << "basis size: " << lisf.size() << std::endl;
|
|
}
|
|
|
|
// if we failed to make progress on this iteration then we are done
|
|
if (svm_objective >= prev_svm_objective)
|
|
break;
|
|
|
|
prev_svm_objective = svm_objective;
|
|
|
|
// now add more elements to the basis
|
|
unsigned long count = 0;
|
|
for (unsigned long j = 0;
|
|
(j < 100*basis_size_increment) && (count < basis_size_increment) && (lisf.size() < max_basis_size);
|
|
++j)
|
|
{
|
|
// pick a random sample
|
|
const unsigned long idx = rnd.get_random_32bit_number()%x.size();
|
|
// If it is a margin violator then it is useful to add it into the basis set.
|
|
if (df(proj_samples[idx])*y(idx) < 1)
|
|
{
|
|
// Add the sample into the basis set if it is linearly independent of all the
|
|
// vectors already in the basis set.
|
|
if (lisf.add(x(idx)))
|
|
{
|
|
++count;
|
|
}
|
|
}
|
|
}
|
|
// if we couldn't add any more basis vectors then stop
|
|
if (count == 0)
|
|
{
|
|
if (verbose)
|
|
std::cout << "Stopping, couldn't add more basis vectors." << std::endl;
|
|
break;
|
|
}
|
|
|
|
|
|
// Project all the samples into the span of our newly enlarged basis. We will do this
|
|
// using the special transformation in the EKM that lets us project from a smaller
|
|
// basis set to a larger without needing to reevaluate kernel functions we have already
|
|
// computed.
|
|
ekm.swap(prev_ekm);
|
|
ekm.load(lisf);
|
|
projection_function<kernel_type> proj_part;
|
|
matrix<double> prev_to_new;
|
|
prev_ekm.get_transformation_to(ekm, prev_to_new, proj_part);
|
|
|
|
|
|
matrix<scalar_type,0,1, mem_manager_type> temp;
|
|
for (long i = 0; i < x.size(); ++i)
|
|
{
|
|
// assign to temporary to avoid memory allocation that would result if we
|
|
// assigned this expression straight into proj_samples[i]
|
|
temp = prev_to_new*proj_samples[i] + proj_part(x(i));
|
|
proj_samples[i] = temp;
|
|
|
|
}
|
|
}
|
|
|
|
// Reproject all the data samples using the final basis. We could just use what we
|
|
// already have but the recursive thing done above to compute the proj_samples
|
|
// might have accumulated a little numerical error. So lets just be safe.
|
|
running_stats<scalar_type> rs, rs_margin;
|
|
for (long i = 0; i < x.size(); ++i)
|
|
{
|
|
if (verbose)
|
|
{
|
|
scalar_type err;
|
|
proj_samples[i] = ekm.project(x(i),err);
|
|
rs.add(err);
|
|
// if this point is within the margin
|
|
if (df(proj_samples[i])*y(i) < 1)
|
|
rs_margin.add(err);
|
|
}
|
|
else
|
|
{
|
|
proj_samples[i] = ekm.project(x(i));
|
|
}
|
|
}
|
|
|
|
// do the final training
|
|
trainer.set_epsilon(min_epsilon);
|
|
df = trainer.train(proj_samples, y, svm_objective);
|
|
|
|
|
|
if (verbose)
|
|
{
|
|
std::cout << "\nMean EKM projection error: " << rs.mean() << std::endl;
|
|
std::cout << "Standard deviation of EKM projection error: " << rs.stddev() << std::endl;
|
|
std::cout << "Mean EKM projection error for margin violators: " << rs_margin.mean() << std::endl;
|
|
std::cout << "Standard deviation of EKM projection error for margin violators: " << ((rs_margin.current_n()>1)?rs_margin.stddev():0) << std::endl;
|
|
|
|
std::cout << "Final svm objective: " << svm_objective << std::endl;
|
|
}
|
|
|
|
|
|
decision_function<kernel_type> final_df;
|
|
final_df = ekm.convert_to_decision_function(df.basis_vectors(0));
|
|
final_df.b = df.b;
|
|
|
|
// we don't need the ekm anymore so clear it out
|
|
ekm.clear();
|
|
|
|
return final_df;
|
|
}
|
|
|
|
|
|
|
|
|
|
/*!
|
|
CONVENTION
|
|
- if (ekm_stale) then
|
|
- kern or basis have changed since the last time
|
|
they were loaded into the ekm
|
|
!*/
|
|
|
|
svm_c_linear_trainer<linear_kernel<matrix<scalar_type,0,1,mem_manager_type> > > ocas;
|
|
bool verbose;
|
|
|
|
kernel_type kern;
|
|
unsigned long max_basis_size;
|
|
unsigned long basis_size_increment;
|
|
unsigned long initial_basis_size;
|
|
|
|
|
|
matrix<sample_type,0,1,mem_manager_type> basis;
|
|
mutable empirical_kernel_map<kernel_type> ekm;
|
|
mutable bool ekm_stale;
|
|
|
|
};
|
|
|
|
}
|
|
|
|
#endif // DLIB_SVM_C_EKm_TRAINER_Hh_
|
|
|
|
|
|
|