362 lines
11 KiB
C++
362 lines
11 KiB
C++
// Copyright (C) 2009 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
// This code was adapted from code from the JAMA part of NIST's TNT library.
|
|
// See: http://math.nist.gov/tnt/
|
|
#ifndef DLIB_MATRIX_LU_DECOMPOSITION_H
|
|
#define DLIB_MATRIX_LU_DECOMPOSITION_H
|
|
|
|
#include "matrix.h"
|
|
#include "matrix_utilities.h"
|
|
#include "matrix_subexp.h"
|
|
#include "matrix_trsm.h"
|
|
#include <algorithm>
|
|
|
|
#ifdef DLIB_USE_LAPACK
|
|
#include "lapack/getrf.h"
|
|
#endif
|
|
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
template <
|
|
typename matrix_exp_type
|
|
>
|
|
class lu_decomposition
|
|
{
|
|
public:
|
|
|
|
const static long NR = matrix_exp_type::NR;
|
|
const static long NC = matrix_exp_type::NC;
|
|
typedef typename matrix_exp_type::type type;
|
|
typedef typename matrix_exp_type::mem_manager_type mem_manager_type;
|
|
typedef typename matrix_exp_type::layout_type layout_type;
|
|
|
|
typedef matrix<type,0,0,mem_manager_type,layout_type> matrix_type;
|
|
typedef matrix<type,NR,1,mem_manager_type,layout_type> column_vector_type;
|
|
typedef matrix<long,NR,1,mem_manager_type,layout_type> pivot_column_vector_type;
|
|
|
|
// You have supplied an invalid type of matrix_exp_type. You have
|
|
// to use this object with matrices that contain float or double type data.
|
|
COMPILE_TIME_ASSERT((is_same_type<float, type>::value ||
|
|
is_same_type<double, type>::value ));
|
|
|
|
template <typename EXP>
|
|
lu_decomposition (
|
|
const matrix_exp<EXP> &A
|
|
);
|
|
|
|
bool is_square (
|
|
) const;
|
|
|
|
bool is_singular (
|
|
) const;
|
|
|
|
long nr(
|
|
) const;
|
|
|
|
long nc(
|
|
) const;
|
|
|
|
const matrix_type get_l (
|
|
) const;
|
|
|
|
const matrix_type get_u (
|
|
) const;
|
|
|
|
const pivot_column_vector_type& get_pivot (
|
|
) const;
|
|
|
|
type det (
|
|
) const;
|
|
|
|
template <typename EXP>
|
|
const matrix_type solve (
|
|
const matrix_exp<EXP> &B
|
|
) const;
|
|
|
|
private:
|
|
|
|
/* Array for internal storage of decomposition. */
|
|
matrix<type,0,0,mem_manager_type,column_major_layout> LU;
|
|
long m, n, pivsign;
|
|
pivot_column_vector_type piv;
|
|
|
|
|
|
};
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
// ----------------------------------------------------------------------------------------
|
|
// Public member functions
|
|
// ----------------------------------------------------------------------------------------
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <typename matrix_exp_type>
|
|
template <typename EXP>
|
|
lu_decomposition<matrix_exp_type>::
|
|
lu_decomposition (
|
|
const matrix_exp<EXP>& A
|
|
) :
|
|
LU(A),
|
|
m(A.nr()),
|
|
n(A.nc())
|
|
{
|
|
using namespace std;
|
|
using std::abs;
|
|
|
|
COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value));
|
|
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(A.size() > 0,
|
|
"\tlu_decomposition::lu_decomposition(A)"
|
|
<< "\n\tInvalid inputs were given to this function"
|
|
<< "\n\tA.size(): " << A.size()
|
|
<< "\n\tthis: " << this
|
|
);
|
|
|
|
#ifdef DLIB_USE_LAPACK
|
|
matrix<lapack::integer,0,1,mem_manager_type,layout_type> piv_temp;
|
|
lapack::getrf(LU, piv_temp);
|
|
|
|
pivsign = 1;
|
|
|
|
// Turn the piv_temp vector into a more useful form. This way we will have the identity
|
|
// rowm(A,piv) == L*U. The permutation vector that comes out of LAPACK is somewhat
|
|
// different.
|
|
piv = trans(range(0,m-1));
|
|
for (long i = 0; i < piv_temp.size(); ++i)
|
|
{
|
|
// -1 because FORTRAN is indexed starting with 1 instead of 0
|
|
if (piv(piv_temp(i)-1) != piv(i))
|
|
{
|
|
std::swap(piv(i), piv(piv_temp(i)-1));
|
|
pivsign = -pivsign;
|
|
}
|
|
}
|
|
|
|
#else
|
|
|
|
// Use a "left-looking", dot-product, Crout/Doolittle algorithm.
|
|
|
|
|
|
piv = trans(range(0,m-1));
|
|
pivsign = 1;
|
|
|
|
column_vector_type LUcolj(m);
|
|
|
|
// Outer loop.
|
|
for (long j = 0; j < n; j++)
|
|
{
|
|
|
|
// Make a copy of the j-th column to localize references.
|
|
LUcolj = colm(LU,j);
|
|
|
|
// Apply previous transformations.
|
|
for (long i = 0; i < m; i++)
|
|
{
|
|
// Most of the time is spent in the following dot product.
|
|
const long kmax = std::min(i,j);
|
|
type s;
|
|
if (kmax > 0)
|
|
s = rowm(LU,i, kmax)*colm(LUcolj,0,kmax);
|
|
else
|
|
s = 0;
|
|
|
|
LU(i,j) = LUcolj(i) -= s;
|
|
}
|
|
|
|
// Find pivot and exchange if necessary.
|
|
long p = j;
|
|
for (long i = j+1; i < m; i++)
|
|
{
|
|
if (abs(LUcolj(i)) > abs(LUcolj(p)))
|
|
{
|
|
p = i;
|
|
}
|
|
}
|
|
if (p != j)
|
|
{
|
|
long k=0;
|
|
for (k = 0; k < n; k++)
|
|
{
|
|
type t = LU(p,k);
|
|
LU(p,k) = LU(j,k);
|
|
LU(j,k) = t;
|
|
}
|
|
k = piv(p);
|
|
piv(p) = piv(j);
|
|
piv(j) = k;
|
|
pivsign = -pivsign;
|
|
}
|
|
|
|
// Compute multipliers.
|
|
if ((j < m) && (LU(j,j) != 0.0))
|
|
{
|
|
for (long i = j+1; i < m; i++)
|
|
{
|
|
LU(i,j) /= LU(j,j);
|
|
}
|
|
}
|
|
}
|
|
|
|
#endif
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <typename matrix_exp_type>
|
|
bool lu_decomposition<matrix_exp_type>::
|
|
is_square (
|
|
) const
|
|
{
|
|
return m == n;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <typename matrix_exp_type>
|
|
long lu_decomposition<matrix_exp_type>::
|
|
nr (
|
|
) const
|
|
{
|
|
return m;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <typename matrix_exp_type>
|
|
long lu_decomposition<matrix_exp_type>::
|
|
nc (
|
|
) const
|
|
{
|
|
return n;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <typename matrix_exp_type>
|
|
bool lu_decomposition<matrix_exp_type>::
|
|
is_singular (
|
|
) const
|
|
{
|
|
/* Is the matrix singular?
|
|
if upper triangular factor U (and hence A) is singular, false otherwise.
|
|
*/
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(is_square() == true,
|
|
"\tbool lu_decomposition::is_singular()"
|
|
<< "\n\tYou can only use this on square matrices"
|
|
<< "\n\tthis: " << this
|
|
);
|
|
|
|
type max_val, min_val;
|
|
find_min_and_max (abs(diag(LU)), min_val, max_val);
|
|
type eps = max_val;
|
|
if (eps != 0)
|
|
eps *= std::sqrt(std::numeric_limits<type>::epsilon())/10;
|
|
else
|
|
eps = 1; // there is no max so just use 1
|
|
|
|
return min_val < eps;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <typename matrix_exp_type>
|
|
const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>::
|
|
get_l (
|
|
) const
|
|
{
|
|
if (LU.nr() >= LU.nc())
|
|
return lowerm(LU,1.0);
|
|
else
|
|
return lowerm(subm(LU,0,0,m,m), 1.0);
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <typename matrix_exp_type>
|
|
const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>::
|
|
get_u (
|
|
) const
|
|
{
|
|
if (LU.nr() >= LU.nc())
|
|
return upperm(subm(LU,0,0,n,n));
|
|
else
|
|
return upperm(LU);
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <typename matrix_exp_type>
|
|
const typename lu_decomposition<matrix_exp_type>::pivot_column_vector_type& lu_decomposition<matrix_exp_type>::
|
|
get_pivot (
|
|
) const
|
|
{
|
|
return piv;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <typename matrix_exp_type>
|
|
typename lu_decomposition<matrix_exp_type>::type lu_decomposition<matrix_exp_type>::
|
|
det (
|
|
) const
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(is_square() == true,
|
|
"\ttype lu_decomposition::det()"
|
|
<< "\n\tYou can only use this on square matrices"
|
|
<< "\n\tthis: " << this
|
|
);
|
|
|
|
// Check if it is singular and if it is just return 0.
|
|
// We want to do this because a prod() operation can easily
|
|
// overcome a single diagonal element that is effectively 0 when
|
|
// LU is a big enough matrix.
|
|
if (is_singular())
|
|
return 0;
|
|
|
|
return prod(diag(LU))*static_cast<type>(pivsign);
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <typename matrix_exp_type>
|
|
template <typename EXP>
|
|
const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>::
|
|
solve (
|
|
const matrix_exp<EXP> &B
|
|
) const
|
|
{
|
|
COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value));
|
|
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(is_square() == true && B.nr() == nr(),
|
|
"\ttype lu_decomposition::solve()"
|
|
<< "\n\tInvalid arguments to this function"
|
|
<< "\n\tis_square(): " << (is_square()? "true":"false" )
|
|
<< "\n\tB.nr(): " << B.nr()
|
|
<< "\n\tnr(): " << nr()
|
|
<< "\n\tthis: " << this
|
|
);
|
|
|
|
// Copy right hand side with pivoting
|
|
matrix<type,0,0,mem_manager_type,column_major_layout> X(rowm(B, piv));
|
|
|
|
using namespace blas_bindings;
|
|
// Solve L*Y = B(piv,:)
|
|
triangular_solver(CblasLeft, CblasLower, CblasNoTrans, CblasUnit, LU, X);
|
|
// Solve U*X = Y;
|
|
triangular_solver(CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, LU, X);
|
|
return X;
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_MATRIX_LU_DECOMPOSITION_H
|
|
|
|
|