197 lines
6.1 KiB
C++
197 lines
6.1 KiB
C++
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
|
|
|
|
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <cstdlib>
|
|
#include <ctime>
|
|
#include <cmath>
|
|
#include <dlib/svm.h>
|
|
|
|
#include "tester.h"
|
|
|
|
namespace
|
|
{
|
|
|
|
using namespace test;
|
|
using namespace dlib;
|
|
using namespace std;
|
|
|
|
// File-local logger named "test.rls"; test_rls() writes its summary statistics to it at LINFO level.
logger dlog("test.rls");
|
|
|
|
|
|
void test_rls()
|
|
{
|
|
dlib::rand rnd;
|
|
|
|
running_stats<double> rs1, rs2, rs3, rs4, rs5;
|
|
|
|
for (int k = 0; k < 2; ++k)
|
|
{
|
|
for (long num_vars = 1; num_vars < 4; ++num_vars)
|
|
{
|
|
print_spinner();
|
|
for (long size = 1; size < 300; ++size)
|
|
{
|
|
{
|
|
matrix<double> X = randm(size,num_vars,rnd);
|
|
matrix<double,0,1> Y = randm(size,1,rnd);
|
|
|
|
|
|
const double C = 1000;
|
|
const double forget_factor = 1.0;
|
|
rls r(forget_factor, C);
|
|
for (long i = 0; i < Y.size(); ++i)
|
|
{
|
|
r.train(trans(rowm(X,i)), Y(i));
|
|
}
|
|
|
|
|
|
matrix<double> w = pinv(1.0/C*identity_matrix<double>(X.nc()) + trans(X)*X)*trans(X)*Y;
|
|
|
|
rs1.add(length(r.get_w() - w));
|
|
}
|
|
|
|
{
|
|
matrix<double> X = randm(size,num_vars,rnd);
|
|
matrix<double,0,1> Y = randm(size,1,rnd);
|
|
|
|
matrix<double,0,1> G(size,1);
|
|
|
|
const double C = 10000;
|
|
const double forget_factor = 0.8;
|
|
rls r(forget_factor, C);
|
|
for (long i = 0; i < Y.size(); ++i)
|
|
{
|
|
r.train(trans(rowm(X,i)), Y(i));
|
|
|
|
G(i) = std::pow(forget_factor, i/2.0);
|
|
}
|
|
|
|
G = flipud(G);
|
|
|
|
X = diagm(G)*X;
|
|
Y = diagm(G)*Y;
|
|
|
|
matrix<double> w = pinv(1.0/C*identity_matrix<double>(X.nc()) + trans(X)*X)*trans(X)*Y;
|
|
|
|
rs5.add(length(r.get_w() - w));
|
|
}
|
|
|
|
{
|
|
matrix<double> X = randm(size,num_vars,rnd);
|
|
matrix<double> Y = colm(X,0)*10;
|
|
|
|
|
|
const double C = 1000000;
|
|
const double forget_factor = 1.0;
|
|
rls r(forget_factor, C);
|
|
for (long i = 0; i < Y.size(); ++i)
|
|
{
|
|
r.train(trans(rowm(X,i)), Y(i));
|
|
}
|
|
|
|
|
|
matrix<double> w = pinv(1.0/C*identity_matrix<double>(X.nc()) + trans(X)*X)*trans(X)*Y;
|
|
|
|
rs2.add(length(r.get_w() - w));
|
|
}
|
|
|
|
{
|
|
matrix<double> X = join_rows(randm(size,num_vars,rnd)-0.5, ones_matrix<double>(size,1));
|
|
matrix<double> Y = uniform_matrix<double>(size,1,10);
|
|
|
|
|
|
const double C = 1e7;
|
|
const double forget_factor = 1.0;
|
|
|
|
matrix<double> w = pinv(1.0/C*identity_matrix<double>(X.nc()) + trans(X)*X)*trans(X)*Y;
|
|
|
|
rls r(forget_factor, C);
|
|
for (long i = 0; i < Y.size(); ++i)
|
|
{
|
|
r.train(trans(rowm(X,i)), Y(i));
|
|
rs3.add(std::abs(r(trans(rowm(X,i))) - 10));
|
|
}
|
|
|
|
|
|
}
|
|
{
|
|
matrix<double> X = randm(size,num_vars,rnd)-0.5;
|
|
matrix<double> Y = colm(X,0)*10;
|
|
|
|
|
|
const double C = 1e6;
|
|
const double forget_factor = 0.7;
|
|
|
|
|
|
rls r(forget_factor, C);
|
|
DLIB_TEST(std::abs(r.get_c() - C) < 1e-10);
|
|
DLIB_TEST(std::abs(r.get_forget_factor() - forget_factor) < 1e-15);
|
|
DLIB_TEST(r.get_w().size() == 0);
|
|
|
|
for (long i = 0; i < Y.size(); ++i)
|
|
{
|
|
r.train(trans(rowm(X,i)), Y(i));
|
|
rs4.add(std::abs(r(trans(rowm(X,i))) - X(i,0)*10));
|
|
}
|
|
|
|
DLIB_TEST(r.get_w().size() == num_vars);
|
|
|
|
decision_function<linear_kernel<matrix<double,0,1> > > df = r.get_decision_function();
|
|
DLIB_TEST(std::abs(df(trans(rowm(X,0))) - r(trans(rowm(X,0)))) < 1e-15);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
dlog << LINFO << "rs1.mean(): " << rs1.mean();
|
|
dlog << LINFO << "rs2.mean(): " << rs2.mean();
|
|
dlog << LINFO << "rs3.mean(): " << rs3.mean();
|
|
dlog << LINFO << "rs4.mean(): " << rs4.mean();
|
|
dlog << LINFO << "rs5.mean(): " << rs5.mean();
|
|
dlog << LINFO << "rs1.max(): " << rs1.max();
|
|
dlog << LINFO << "rs2.max(): " << rs2.max();
|
|
dlog << LINFO << "rs3.max(): " << rs3.max();
|
|
dlog << LINFO << "rs4.max(): " << rs4.max();
|
|
dlog << LINFO << "rs5.max(): " << rs5.max();
|
|
|
|
DLIB_TEST_MSG(rs1.mean() < 1e-10, rs1.mean());
|
|
DLIB_TEST_MSG(rs2.mean() < 1e-9, rs2.mean());
|
|
DLIB_TEST_MSG(rs3.mean() < 1e-6, rs3.mean());
|
|
DLIB_TEST_MSG(rs4.mean() < 1e-6, rs4.mean());
|
|
DLIB_TEST_MSG(rs5.mean() < 1e-3, rs5.mean());
|
|
|
|
DLIB_TEST_MSG(rs1.max() < 1e-10, rs1.max());
|
|
DLIB_TEST_MSG(rs2.max() < 1e-6, rs2.max());
|
|
DLIB_TEST_MSG(rs3.max() < 0.001, rs3.max());
|
|
DLIB_TEST_MSG(rs4.max() < 0.01, rs4.max());
|
|
DLIB_TEST_MSG(rs5.max() < 0.1, rs5.max());
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class rls_tester : public tester
|
|
{
|
|
public:
|
|
rls_tester (
|
|
) :
|
|
tester ("test_rls",
|
|
"Runs tests on the rls component.")
|
|
{}
|
|
|
|
void perform_test (
|
|
)
|
|
{
|
|
test_rls();
|
|
}
|
|
} a;
|
|
|
|
}
|
|
|
|
|
|
|