233 lines
6.7 KiB
C++
233 lines
6.7 KiB
C++
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_RLs_Hh_
|
|
#define DLIB_RLs_Hh_
|
|
|
|
#include "rls_abstract.h"
|
|
#include "../matrix.h"
|
|
#include "function.h"
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
class rls
|
|
{
|
|
|
|
public:
|
|
|
|
|
|
explicit rls(
|
|
double forget_factor_,
|
|
double C_ = 1000,
|
|
bool apply_forget_factor_to_C_ = false
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(0 < forget_factor_ && forget_factor_ <= 1 &&
|
|
0 < C_,
|
|
"\t rls::rls()"
|
|
<< "\n\t invalid arguments were given to this function"
|
|
<< "\n\t forget_factor_: " << forget_factor_
|
|
<< "\n\t C_: " << C_
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
|
|
C = C_;
|
|
forget_factor = forget_factor_;
|
|
apply_forget_factor_to_C = apply_forget_factor_to_C_;
|
|
}
|
|
|
|
rls(
|
|
)
|
|
{
|
|
C = 1000;
|
|
forget_factor = 1;
|
|
apply_forget_factor_to_C = false;
|
|
}
|
|
|
|
double get_c(
|
|
) const
|
|
{
|
|
return C;
|
|
}
|
|
|
|
double get_forget_factor(
|
|
) const
|
|
{
|
|
return forget_factor;
|
|
}
|
|
|
|
bool should_apply_forget_factor_to_C (
|
|
) const
|
|
{
|
|
return apply_forget_factor_to_C;
|
|
}
|
|
|
|
template <typename EXP>
|
|
void train (
|
|
const matrix_exp<EXP>& x,
|
|
double y
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(is_col_vector(x) &&
|
|
(get_w().size() == 0 || get_w().size() == x.size()),
|
|
"\t void rls::train()"
|
|
<< "\n\t invalid arguments were given to this function"
|
|
<< "\n\t is_col_vector(x): " << is_col_vector(x)
|
|
<< "\n\t x.size(): " << x.size()
|
|
<< "\n\t get_w().size(): " << get_w().size()
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
if (R.size() == 0)
|
|
{
|
|
R = identity_matrix<double>(x.size())*C;
|
|
w.set_size(x.size());
|
|
w = 0;
|
|
}
|
|
|
|
// multiply by forget factor and incorporate x*trans(x) into R.
|
|
const double l = 1.0/forget_factor;
|
|
const double temp = 1 + l*trans(x)*R*x;
|
|
tmp = R*x;
|
|
R = l*R - l*l*(tmp*trans(tmp))/temp;
|
|
|
|
// Since we multiplied by the forget factor, we need to add (1-forget_factor) of the
|
|
// identity matrix back in to keep the regularization alive.
|
|
if (forget_factor != 1 && !apply_forget_factor_to_C)
|
|
add_eye_to_inv(R, (1-forget_factor)/C);
|
|
|
|
// R should always be symmetric. This line improves numeric stability of this algorithm.
|
|
if (cnt%10 == 0)
|
|
R = 0.5*(R + trans(R));
|
|
++cnt;
|
|
|
|
w = w + R*x*(y - trans(x)*w);
|
|
|
|
}
|
|
|
|
|
|
|
|
const matrix<double,0,1>& get_w(
|
|
) const
|
|
{
|
|
return w;
|
|
}
|
|
|
|
template <typename EXP>
|
|
double operator() (
|
|
const matrix_exp<EXP>& x
|
|
) const
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(is_col_vector(x) && get_w().size() == x.size(),
|
|
"\t double rls::operator()()"
|
|
<< "\n\t invalid arguments were given to this function"
|
|
<< "\n\t is_col_vector(x): " << is_col_vector(x)
|
|
<< "\n\t x.size(): " << x.size()
|
|
<< "\n\t get_w().size(): " << get_w().size()
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
return dot(x,w);
|
|
}
|
|
|
|
decision_function<linear_kernel<matrix<double,0,1> > > get_decision_function (
|
|
) const
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(get_w().size() != 0,
|
|
"\t decision_function rls::get_decision_function()"
|
|
<< "\n\t invalid arguments were given to this function"
|
|
<< "\n\t get_w().size(): " << get_w().size()
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
decision_function<linear_kernel<matrix<double,0,1> > > df;
|
|
df.alpha.set_size(1);
|
|
df.basis_vectors.set_size(1);
|
|
df.b = 0;
|
|
df.alpha = 1;
|
|
df.basis_vectors(0) = w;
|
|
|
|
return df;
|
|
}
|
|
|
|
friend inline void serialize(const rls& item, std::ostream& out)
|
|
{
|
|
int version = 2;
|
|
serialize(version, out);
|
|
serialize(item.w, out);
|
|
serialize(item.R, out);
|
|
serialize(item.C, out);
|
|
serialize(item.forget_factor, out);
|
|
serialize(item.cnt, out);
|
|
serialize(item.apply_forget_factor_to_C, out);
|
|
}
|
|
|
|
friend inline void deserialize(rls& item, std::istream& in)
|
|
{
|
|
int version = 0;
|
|
deserialize(version, in);
|
|
if (!(1 <= version && version <= 2))
|
|
throw dlib::serialization_error("Unknown version number found while deserializing rls object.");
|
|
|
|
if (version >= 1)
|
|
{
|
|
deserialize(item.w, in);
|
|
deserialize(item.R, in);
|
|
deserialize(item.C, in);
|
|
deserialize(item.forget_factor, in);
|
|
}
|
|
item.cnt = 0;
|
|
item.apply_forget_factor_to_C = false;
|
|
if (version >= 2)
|
|
{
|
|
deserialize(item.cnt, in);
|
|
deserialize(item.apply_forget_factor_to_C, in);
|
|
}
|
|
}
|
|
|
|
private:
|
|
|
|
void add_eye_to_inv(
|
|
matrix<double>& m,
|
|
double C
|
|
)
|
|
/*!
|
|
ensures
|
|
- Let m == inv(M)
|
|
- this function returns inv(M + C*identity_matrix<double>(m.nr()))
|
|
!*/
|
|
{
|
|
for (long r = 0; r < m.nr(); ++r)
|
|
{
|
|
m = m - colm(m,r)*trans(colm(m,r))/(1/C + m(r,r));
|
|
}
|
|
}
|
|
|
|
|
|
matrix<double,0,1> w;
|
|
matrix<double> R;
|
|
double C;
|
|
double forget_factor;
|
|
int cnt = 0;
|
|
bool apply_forget_factor_to_C;
|
|
|
|
|
|
// This object is here only to avoid reallocation during training. It don't
|
|
// logically contribute to the state of this object.
|
|
matrix<double,0,1> tmp;
|
|
};
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_RLs_Hh_
|
|
|