// Copyright (C) 2016 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
|
|
|
|
#include <dlib/optimization/elastic_net.h>
|
|
#include "tester.h"
|
|
#include <dlib/svm.h>
|
|
#include <dlib/rand.h>
|
|
#include <dlib/string.h>
|
|
#include <vector>
|
|
#include <sstream>
|
|
#include <ctime>
|
|
|
|
namespace
|
|
{
|
|
using namespace test;
|
|
using namespace dlib;
|
|
using namespace std;
|
|
// Logger used by the tests in this file; it is created under the name "test.elastic_net".
dlib::logger dlog("test.elastic_net");
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
// Reference elastic net solver used to cross-check dlib::elastic_net.
//
// The regression problem is recast as an L2-loss linear SVM without a bias term and
// handed to svm_c_linear_dcd_trainer.  Every row of X yields two training samples:
//     row - Y/lasso_budget   with label +1
//     row + Y/lasso_budget   with label -1
// and the SVM penalty is set to C = 1/(2*ridge_lambda).  The regression weights are
// then recovered from the dual variables (alpha) kept in the trainer's optimizer state.
//
// Arguments:
//   X            - data matrix.  Two SVM samples are generated per row of X.
//   Y            - column vector that must have X.nc() elements.
//   ridge_lambda - must be > 0 since it is the divisor in C = 1/(2*ridge_lambda).
//   lasso_budget - must be > 0 since Y is scaled by 1/lasso_budget.
//   eps          - stopping tolerance forwarded to trainer.set_epsilon().
//
// Returns a column vector with alpha.size()/2 elements (expected to be one per row of
// X, assuming the trainer keeps one alpha per training sample).
matrix<double,0,1> basic_elastic_net(
    const matrix<double>& X,
    const matrix<double,0,1>& Y,
    double ridge_lambda,
    double lasso_budget,
    double eps
)
{
    DLIB_CASSERT(X.nc() == Y.nr(),
        "X must have one column per element of Y."
        << "\n\t X.nc(): " << X.nc()
        << "\n\t Y.nr(): " << Y.nr());
    // Both values are used as divisors below, so reject anything that is not positive.
    DLIB_CASSERT(ridge_lambda > 0 && lasso_budget > 0,
        "ridge_lambda and lasso_budget must be positive."
        << "\n\t ridge_lambda: " << ridge_lambda
        << "\n\t lasso_budget: " << lasso_budget);

    typedef matrix<double,0,1> sample_type;
    typedef linear_kernel<sample_type> kernel_type;

    svm_c_linear_dcd_trainer<kernel_type> trainer;
    trainer.solve_svm_l2_problem(true);
    trainer.set_c(1/(2*ridge_lambda));
    trainer.set_epsilon(eps);
    trainer.enable_shrinking(true);
    trainer.include_bias(false);

    // The scaled target vector is the same for every row, so build it only once.
    const sample_type scaled_Y = (1/lasso_budget)*Y;

    // Two samples are produced per row of X; reserve up front to avoid reallocations.
    const std::size_t num_samples = 2*static_cast<std::size_t>(X.nr());
    std::vector<sample_type> samples;
    std::vector<double> labels;
    samples.reserve(num_samples);
    labels.reserve(num_samples);
    for (long r = 0; r < X.nr(); ++r)
    {
        const sample_type row = trans(rowm(X,r));

        samples.push_back(row - scaled_Y);
        labels.push_back(+1);
        samples.push_back(row + scaled_Y);
        labels.push_back(-1);
    }

    // Only the dual variables stored in the optimizer state are needed, so the decision
    // function returned by train() is intentionally discarded.
    svm_c_linear_dcd_trainer<kernel_type>::optimizer_state state;
    trainer.train(samples, labels, state);
    const auto& alpha = state.get_alpha();

    // Each consecutive (+1, -1) pair of alphas gives one weight.  The result is
    // normalized by the sum of all the alphas.
    matrix<double,0,1> betas(static_cast<long>(alpha.size()/2));
    for (long i = 0; i < betas.size(); ++i)
        betas(i) = lasso_budget*(alpha[2*i] - alpha[2*i+1]);
    betas /= sum(mat(alpha));
    return betas;
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
class test_elastic_net : public tester
|
|
{
|
|
public:
|
|
test_elastic_net (
|
|
) :
|
|
tester (
|
|
"test_elastic_net",
|
|
"Run tests on the elastic_net object.",
|
|
0
|
|
)
|
|
{
|
|
}
|
|
|
|
void perform_test (
|
|
)
|
|
{
|
|
matrix<double> w = {1,2,0,4, 0,0,0,0,0, 6, 7,8,0, 9, 0};
|
|
|
|
matrix<double> X = randm(w.size(),1000);
|
|
matrix<double> Y = trans(X)*w;
|
|
Y += 0.1*(randm(Y.nr(), Y.nc())-0.5);
|
|
|
|
|
|
double ridge_lambda = 0.1;
|
|
double lasso_budget = sum(abs(w));
|
|
double eps = 0.0000001;
|
|
|
|
dlib::elastic_net solver(X*trans(X),X*Y);
|
|
solver.set_epsilon(eps);
|
|
|
|
|
|
matrix<double,0,1> results;
|
|
matrix<double,0,1> results2;
|
|
for (double s = 1.2; s > 0.10; s *= 0.9)
|
|
{
|
|
print_spinner();
|
|
dlog << LINFO << "s: "<< s;
|
|
// make sure the two solvers agree.
|
|
results = basic_elastic_net(X, Y, ridge_lambda, lasso_budget*s, eps);
|
|
results2 = solver(ridge_lambda, lasso_budget*s);
|
|
dlog << LINFO << "error: "<< max(abs(results - results2));
|
|
DLIB_TEST(max(abs(results - results2)) < 1e-3);
|
|
}
|
|
}
|
|
} a;
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
|
|
|