// Copyright (C) 2009 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #include #include #include #include #include #include "../stl_checked.h" #include "../array.h" #include "../rand.h" #include #include "tester.h" namespace { using namespace test; using namespace dlib; using namespace std; logger dlog("test.matrix_chol"); dlib::rand rnd; // ---------------------------------------------------------------------------------------- template const matrix symm(const mat_type& m) { return m*trans(m); } // ---------------------------------------------------------------------------------------- template const matrix randmat(long r, long c) { matrix m(r,c); for (long row = 0; row < m.nr(); ++row) { for (long col = 0; col < m.nc(); ++col) { m(row,col) = static_cast(rnd.get_random_double()); } } return m; } template const matrix randmat() { matrix m; for (long row = 0; row < m.nr(); ++row) { for (long col = 0; col < m.nc(); ++col) { m(row,col) = static_cast(rnd.get_random_double()); } } return m; } // ---------------------------------------------------------------------------------------- template void test_cholesky ( const matrix_type& m) { typedef typename matrix_type::type type; const type eps = 10*max(abs(m))*sqrt(std::numeric_limits::epsilon()); dlog << LDEBUG << "test_cholesky(): " << m.nr() << " x " << m.nc() << " eps: " << eps; print_spinner(); cholesky_decomposition test(m); // none of the matrices we should be passing in to test_cholesky() should be non-spd. DLIB_TEST(test.is_spd() == true); type temp; DLIB_TEST_MSG( (temp= max(abs(test.get_l()*trans(test.get_l()) - m))) < eps,temp); { matrix mat = chol(m); DLIB_TEST_MSG( (temp= max(abs(mat*trans(mat) - m))) < eps,temp); } matrix m2; matrix col; m2 = identity_matrix(m.nr()); DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2))); m2 = randmat(m.nr(),5); DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2))); m2 = randmat(m.nr(),1); DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2))); col = randmat(m.nr(),1); DLIB_TEST_MSG(equal(m*test.solve(col), col,eps),max(abs(m*test.solve(m2)- m2))); // now make us a non-spd matrix if (m.nr() > 2) { matrix sm(lowerm(m)); sm(1,1) = 0; cholesky_decomposition test2(sm); DLIB_TEST_MSG(test2.is_spd() == false, test2.get_l()); cholesky_decomposition test3(sm*trans(sm)); DLIB_TEST_MSG(test3.is_spd() == false, test3.get_l()); sm = sm*trans(sm); sm(1,1) = 5; sm(1,0) -= 1; cholesky_decomposition test4(sm); DLIB_TEST_MSG(test4.is_spd() == false, test4.get_l()); } } // ---------------------------------------------------------------------------------------- void matrix_test_double() { test_cholesky(uniform_matrix(1,1,1) + 10*symm(randmat(1,1))); test_cholesky(uniform_matrix(2,2,1) + 10*symm(randmat(2,2))); test_cholesky(uniform_matrix(3,3,1) + 10*symm(randmat(3,3))); test_cholesky(uniform_matrix(4,4,1) + 10*symm(randmat(4,4))); test_cholesky(uniform_matrix(15,15,1) + 10*symm(randmat(15,15))); test_cholesky(uniform_matrix(101,101,1) + 10*symm(randmat(101,101))); typedef matrix mat; test_cholesky(mat(uniform_matrix(101,101,1) + 10*symm(randmat(101,101)))); } // ---------------------------------------------------------------------------------------- void matrix_test_float() { test_cholesky(uniform_matrix(1,1,1) + 2*symm(randmat(1,1))); test_cholesky(uniform_matrix(2,2,1) + 2*symm(randmat(2,2))); test_cholesky(uniform_matrix(3,3,1) + 2*symm(randmat(3,3))); typedef matrix mat; test_cholesky(mat(uniform_matrix(3,3,1) + 2*symm(randmat(3,3)))); } // ---------------------------------------------------------------------------------------- class matrix_tester : public tester { public: matrix_tester ( ) : tester ("test_matrix_chol", "Runs tests on the matrix cholesky component.") { //rnd.set_seed(cast_to_string(time(0))); } void perform_test ( ) { dlog << LINFO << "seed string: " << rnd.get_seed(); dlog << LINFO << "begin testing with double"; matrix_test_double(); dlog << LINFO << "begin testing with float"; matrix_test_float(); } } a; }