// Copyright (C) 2016 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ElASTIC_NET_Hh_
|
|
#define DLIB_ElASTIC_NET_Hh_
|
|
|
|
#include "../matrix.h"
|
|
#include "elastic_net_abstract.h"
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
class elastic_net
|
|
{
|
|
public:
|
|
|
|
template <typename EXP>
|
|
explicit elastic_net(
|
|
const matrix_exp<EXP>& XX
|
|
) : eps(1e-5), max_iterations(50000), verbose(false)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(XX.size() > 0 &&
|
|
XX.nr() == XX.nc(),
|
|
"\t elastic_net::elastic_net(XX)"
|
|
<< " \n\t XX must be a non-empty square matrix."
|
|
<< " \n\t XX.nr(): " << XX.nr()
|
|
<< " \n\t XX.nc(): " << XX.nc()
|
|
<< " \n\t this: " << this
|
|
);
|
|
|
|
|
|
// If the number of columns in X is big and in particular bigger than the number of
|
|
// rows then we can get rid of them by doing some SVD magic. Doing this doesn't
|
|
// make the final results of anything change but makes all the matrices have
|
|
// dimensions that are X.nr() in size, which can be much smaller.
|
|
matrix<double,0,1> s;
|
|
svd3(XX,u,eig_vals,eig_vects);
|
|
s = sqrt(eig_vals);
|
|
X = eig_vects*diagm(s);
|
|
u = eig_vects*inv(diagm(s));
|
|
|
|
|
|
|
|
samples.resize(X.nr()*2);
|
|
|
|
for (size_t i = 0; i < samples.size(); ++i)
|
|
index.push_back(i);
|
|
active_size = index.size();
|
|
|
|
|
|
// setup the training samples used in the SVM optimizer below
|
|
for (size_t i = 0; i < samples.size(); ++i)
|
|
{
|
|
auto& x = samples[i];
|
|
const long idx = i/2;
|
|
if (i%2 == 0)
|
|
x.label = +1;
|
|
else
|
|
x.label = -1;
|
|
|
|
x.r = idx%X.nr();
|
|
}
|
|
}
|
|
|
|
template <typename EXP1, typename EXP2>
|
|
elastic_net(
|
|
const matrix_exp<EXP1>& XX,
|
|
const matrix_exp<EXP2>& XY
|
|
) : elastic_net(XX)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(XX.size() > 0 &&
|
|
XX.nr() == XX.nc() &&
|
|
is_col_vector(XY) &&
|
|
XX.nc() == XY.size() ,
|
|
"\t elastic_net::elastic_net(XX,XY)"
|
|
<< " \n\t Invalid inputs were given to this function."
|
|
<< " \n\t XX.size(): " << XX.size()
|
|
<< " \n\t is_col_vector(XY): " << is_col_vector(XY)
|
|
<< " \n\t XX.nr(): " << XX.nr()
|
|
<< " \n\t XX.nc(): " << XX.nc()
|
|
<< " \n\t XY.size(): " << XY.size()
|
|
<< " \n\t this: " << this
|
|
);
|
|
|
|
set_xy(XY);
|
|
}
|
|
|
|
long size (
|
|
) const { return u.nr(); }
|
|
|
|
template <typename EXP>
|
|
void set_xy(
|
|
const matrix_exp<EXP>& XY
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(is_col_vector(XY) &&
|
|
XY.size() == size(),
|
|
"\t void elastic_net::set_y(Y)"
|
|
<< " \n\t Invalid inputs were given to this function."
|
|
<< " \n\t is_col_vector(XY): " << is_col_vector(XY)
|
|
<< " \n\t size(): " << size()
|
|
<< " \n\t XY.size(): " << XY.size()
|
|
<< " \n\t this: " << this
|
|
);
|
|
|
|
Y = trans(u)*XY;
|
|
// We can use the ynorm after it has been projected because the only place Y
|
|
// appears in the algorithm is in terms of dot products with w and x vectors.
|
|
// But those vectors are always in the span of X and therefore we only see the
|
|
// part of the norm of Y that is in the span of X (and hence u since u and X
|
|
// have the same span by construction)
|
|
ynorm = length_squared(Y);
|
|
xdoty = X*Y;
|
|
eig_vects_xdoty = trans(eig_vects)*xdoty;
|
|
|
|
w.set_size(Y.size());
|
|
// zero out any memory of previous solutions
|
|
alpha.assign(X.nr()*2, 0);
|
|
}
|
|
|
|
bool have_target_values (
|
|
) const { return Y.size() != 0; }
|
|
|
|
void set_epsilon(
|
|
double eps_
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(eps_ > 0,
|
|
"\t void elastic_net::set_epsilon()"
|
|
<< " \n\t eps_ must be greater than 0"
|
|
<< " \n\t eps_: " << eps_
|
|
<< " \n\t this: " << this
|
|
);
|
|
|
|
eps = eps_;
|
|
}
|
|
|
|
unsigned long get_max_iterations (
|
|
) const { return max_iterations; }
|
|
|
|
void set_max_iterations (
|
|
unsigned long max_iter
|
|
)
|
|
{
|
|
max_iterations = max_iter;
|
|
}
|
|
|
|
void be_verbose (
|
|
)
|
|
{
|
|
verbose = true;
|
|
}
|
|
|
|
void be_quiet (
|
|
)
|
|
{
|
|
verbose = false;
|
|
}
|
|
|
|
double get_epsilon (
|
|
) const { return eps; }
|
|
|
|
matrix<double,0,1> operator() (
|
|
double ridge_lambda,
|
|
double lasso_budget = std::numeric_limits<double>::infinity()
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(have_target_values() &&
|
|
ridge_lambda > 0 &&
|
|
lasso_budget > 0 ,
|
|
"\t matrix<double,0,1> elastic_net::operator()()"
|
|
<< " \n\t Invalid inputs were given to this function."
|
|
<< " \n\t have_target_values(): " << have_target_values()
|
|
<< " \n\t ridge_lambda: " << ridge_lambda
|
|
<< " \n\t lasso_budget: " << lasso_budget
|
|
<< " \n\t this: " << this
|
|
);
|
|
|
|
|
|
// First check if lasso_budget is so big that it isn't even active. We do this
|
|
// by doing just ridge regression and checking the result.
|
|
matrix<double,0,1> betas = eig_vects*tmp(inv(diagm(eig_vals + ridge_lambda))*eig_vects_xdoty);
|
|
if (sum(abs(betas)) <= lasso_budget)
|
|
return betas;
|
|
|
|
|
|
// Set w back to 0. We will compute the w corresponding to what is currently
|
|
// in alpha layer on. This way w and alpha are always in sync.
|
|
w = 0;
|
|
wy_mult = 0;
|
|
wdoty = 0;
|
|
|
|
|
|
// return dot(w,x)
|
|
auto dot = [&](const matrix<double,0,1>& w, const en_sample2& x)
|
|
{
|
|
const double xmul = -x.label*(1/lasso_budget);
|
|
// Do the base dot product but don't forget to add in the -(1/t)*y part from the svm reduction paper
|
|
double val = rowm(X,x.r)*w + xmul*wdoty + wy_mult*xdoty(x.r) + xmul*wy_mult*ynorm;
|
|
|
|
return val;
|
|
};
|
|
|
|
|
|
// perform w += scale*x;
|
|
auto add_to = [&](matrix<double,0,1>& w, double scale, const en_sample2& x)
|
|
{
|
|
const double xmul = -x.label*(1/lasso_budget);
|
|
wy_mult += scale*xmul;
|
|
wdoty += scale*xdoty(x.r);
|
|
w += scale*trans(rowm(X,x.r));
|
|
|
|
};
|
|
|
|
const double Dii = ridge_lambda;
|
|
|
|
// setup the training samples used in the SVM optimizer below
|
|
for (size_t i = 0; i < samples.size(); ++i)
|
|
{
|
|
auto& x = samples[i];
|
|
|
|
const double xmul = -x.label*(1/lasso_budget);
|
|
x.xdotx = xmul*xmul*ynorm;
|
|
for (long c = 0; c < X.nc(); ++c)
|
|
x.xdotx += std::pow(X(x.r,c)+xmul*Y(c), 2.0) - std::pow(xmul*Y(c),2.0);
|
|
|
|
// compute the correct w given whatever might be in alpha.
|
|
if (alpha[i] != 0)
|
|
add_to(w, x.label*alpha[i], samples[i]);
|
|
}
|
|
|
|
|
|
// Now run the optimizer
|
|
double PG_max_prev = std::numeric_limits<double>::infinity();
|
|
double PG_min_prev = -std::numeric_limits<double>::infinity();
|
|
|
|
|
|
unsigned int iter;
|
|
for (iter = 0; iter < max_iterations; ++iter)
|
|
{
|
|
// randomly shuffle the indices
|
|
for (unsigned long i = 0; i < active_size; ++i)
|
|
{
|
|
// pick a random index >= i
|
|
const long j = i + rnd.get_random_32bit_number()%(active_size-i);
|
|
std::swap(index[i], index[j]);
|
|
}
|
|
|
|
double PG_max = -std::numeric_limits<double>::infinity();
|
|
double PG_min = std::numeric_limits<double>::infinity();
|
|
for (size_t ii = 0; ii < active_size; ++ii)
|
|
{
|
|
const auto i = index[ii];
|
|
const auto& x = samples[i];
|
|
double G = x.label*dot(w, x) - 1 + Dii*alpha[i];
|
|
|
|
double PG = 0;
|
|
if (alpha[i] == 0)
|
|
{
|
|
if (G > PG_max_prev)
|
|
{
|
|
// shrink the active set of training examples
|
|
--active_size;
|
|
std::swap(index[ii], index[active_size]);
|
|
--ii;
|
|
continue;
|
|
}
|
|
|
|
if (G < 0)
|
|
PG = G;
|
|
}
|
|
else
|
|
{
|
|
PG = G;
|
|
}
|
|
|
|
if (PG > PG_max)
|
|
PG_max = PG;
|
|
if (PG < PG_min)
|
|
PG_min = PG;
|
|
|
|
// if PG != 0
|
|
if (std::abs(PG) > 1e-12)
|
|
{
|
|
const double alpha_old = alpha[i];
|
|
alpha[i] = std::max(alpha[i] - G/(x.xdotx+Dii), (double)0.0);
|
|
const double delta = (alpha[i]-alpha_old)*x.label;
|
|
add_to(w, delta, x);
|
|
}
|
|
}
|
|
|
|
if (verbose)
|
|
{
|
|
using namespace std;
|
|
cout << "gap: " << PG_max - PG_min << endl;
|
|
cout << "active_size: " << active_size << endl;
|
|
cout << "iter: " << iter << endl;
|
|
cout << endl;
|
|
}
|
|
|
|
if (PG_max - PG_min <= eps)
|
|
{
|
|
// stop if we are within eps tolerance and the last iteration
|
|
// was over all the samples
|
|
if (active_size == index.size())
|
|
break;
|
|
|
|
// Turn off shrinking on the next iteration. We will stop if the
|
|
// tolerance is still <= eps when shrinking is off.
|
|
active_size = index.size();
|
|
PG_max_prev = std::numeric_limits<double>::infinity();
|
|
PG_min_prev = -std::numeric_limits<double>::infinity();
|
|
}
|
|
else
|
|
{
|
|
PG_max_prev = PG_max;
|
|
PG_min_prev = PG_min;
|
|
if (PG_max_prev <= 0)
|
|
PG_max_prev = std::numeric_limits<double>::infinity();
|
|
if (PG_min_prev >= 0)
|
|
PG_min_prev = -std::numeric_limits<double>::infinity();
|
|
}
|
|
|
|
|
|
// recalculate wdoty every so often to avoid drift.
|
|
if (iter%100 == 0)
|
|
wdoty = dlib::dot(Y, w);
|
|
}
|
|
|
|
|
|
betas.set_size(alpha.size()/2);
|
|
for (long i = 0; i < betas.size(); ++i)
|
|
betas(i) = lasso_budget*(alpha[2*i] - alpha[2*i+1]);
|
|
betas /= sum(mat(alpha));
|
|
return betas;
|
|
}
|
|
|
|
|
|
private:
|
|
|
|
struct en_sample2
|
|
{
|
|
// X location
|
|
long r;
|
|
|
|
|
|
double label;
|
|
|
|
double xdotx;
|
|
};
|
|
|
|
std::vector<en_sample2> samples;
|
|
std::vector<double> alpha;
|
|
double ynorm;
|
|
matrix<double> X;
|
|
matrix<double,0,1> Y;
|
|
matrix<double,0,1> xdoty;
|
|
double wdoty;
|
|
double wy_mult; // logically, the real w is what is in the w vector + wy_mult*Y
|
|
matrix<double,0,1> w;
|
|
std::vector<long> index;
|
|
unsigned long active_size;
|
|
|
|
matrix<double,0,1> eig_vects_xdoty;
|
|
matrix<double,0,1> eig_vals;
|
|
matrix<double> eig_vects;
|
|
matrix<double> u;
|
|
|
|
dlib::rand rnd;
|
|
|
|
|
|
double eps;
|
|
unsigned long max_iterations;
|
|
bool verbose;
|
|
};
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_ElASTIC_NET_Hh_
|
|
|
|
|