// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_LAPACk_SDD_Hh_
|
|
#define DLIB_LAPACk_SDD_Hh_
|
|
|
|
#include "fortran_id.h"
|
|
#include "../matrix.h"
|
|
|
|
namespace dlib
|
|
{
|
|
namespace lapack
|
|
{
|
|
namespace binding
|
|
{
|
|
extern "C"
|
|
{
|
|
void DLIB_FORTRAN_ID(dgesdd) (char const* jobz,
|
|
const integer* m, const integer* n, double* a, const integer* lda,
|
|
double* s, double* u, const integer* ldu,
|
|
double* vt, const integer* ldvt,
|
|
double* work, const integer* lwork, integer* iwork, integer* info);
|
|
|
|
void DLIB_FORTRAN_ID(sgesdd) (char const* jobz,
|
|
const integer* m, const integer* n, float* a, const integer* lda,
|
|
float* s, float* u, const integer* ldu,
|
|
float* vt, const integer* ldvt,
|
|
float* work, const integer* lwork, integer* iwork, integer* info);
|
|
|
|
}
|
|
|
|
// Double precision overload: forwards to the Fortran dgesdd routine, taking
// the scalar arguments by value and handing their addresses to LAPACK.
// Returns the INFO value written by dgesdd (0 if the routine never writes it).
inline integer gesdd (
    const char jobz,
    const integer m,
    const integer n,
    double* a,
    const integer lda,
    double* s,
    double* u,
    const integer ldu,
    double* vt,
    const integer ldvt,
    double* work,
    const integer lwork,
    integer* iwork
)
{
    integer status = 0;
    DLIB_FORTRAN_ID(dgesdd)(
        &jobz, &m, &n,
        a, &lda,
        s,
        u, &ldu,
        vt, &ldvt,
        work, &lwork, iwork,
        &status);
    return status;
}
|
|
|
|
// Single precision overload: forwards to the Fortran sgesdd routine, taking
// the scalar arguments by value and handing their addresses to LAPACK.
// Returns the INFO value written by sgesdd (0 if the routine never writes it).
inline integer gesdd (
    const char jobz,
    const integer m,
    const integer n,
    float* a,
    const integer lda,
    float* s,
    float* u,
    const integer ldu,
    float* vt,
    const integer ldvt,
    float* work,
    const integer lwork,
    integer* iwork
)
{
    integer status = 0;
    DLIB_FORTRAN_ID(sgesdd)(
        &jobz, &m, &n,
        a, &lda,
        s,
        u, &ldu,
        vt, &ldvt,
        work, &lwork, iwork,
        &status);
    return status;
}
|
|
}
|
|
|
|
// ------------------------------------------------------------------------------------
|
|
|
|
/* -- LAPACK driver routine (version 3.1) -- */
|
|
/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */
|
|
/* November 2006 */
|
|
|
|
/* .. Scalar Arguments .. */
|
|
/* .. */
|
|
/* .. Array Arguments .. */
|
|
/* .. */
|
|
|
|
/* Purpose */
|
|
/* ======= */
|
|
|
|
/* DGESDD computes the singular value decomposition (SVD) of a real */
|
|
/* M-by-N matrix A, optionally computing the left and right singular */
|
|
/* vectors. If singular vectors are desired, it uses a */
|
|
/* divide-and-conquer algorithm. */
|
|
|
|
/* The SVD is written */
|
|
|
|
/* A = U * SIGMA * transpose(V) */
|
|
|
|
/* where SIGMA is an M-by-N matrix which is zero except for its */
|
|
/* min(m,n) diagonal elements, U is an M-by-M orthogonal matrix, and */
|
|
/* V is an N-by-N orthogonal matrix. The diagonal elements of SIGMA */
|
|
/* are the singular values of A; they are real and non-negative, and */
|
|
/* are returned in descending order. The first min(m,n) columns of */
|
|
/* U and V are the left and right singular vectors of A. */
|
|
|
|
/* Note that the routine returns VT = V**T, not V. */
|
|
|
|
/* The divide and conquer algorithm makes very mild assumptions about */
|
|
/* floating point arithmetic. It will work on machines with a guard */
|
|
/* digit in add/subtract, or on those binary machines without guard */
|
|
/* digits which subtract like the Cray X-MP, Cray Y-MP, Cray C-90, or */
|
|
/* Cray-2. It could conceivably fail on hexadecimal or decimal machines */
|
|
/* without guard digits, but we know of none. */
|
|
|
|
/* Arguments */
|
|
/* ========= */
|
|
|
|
/* JOBZ (input) CHARACTER*1 */
|
|
/* Specifies options for computing all or part of the matrix U: */
|
|
/* = 'A': all M columns of U and all N rows of V**T are */
|
|
/* returned in the arrays U and VT; */
|
|
/* = 'S': the first min(M,N) columns of U and the first */
|
|
/* min(M,N) rows of V**T are returned in the arrays U */
|
|
/* and VT; */
|
|
/* = 'O': If M >= N, the first N columns of U are overwritten */
|
|
/* on the array A and all rows of V**T are returned in */
|
|
/* the array VT; */
|
|
/* otherwise, all columns of U are returned in the */
|
|
/* array U and the first M rows of V**T are overwritten */
|
|
/* in the array A; */
|
|
/* = 'N': no columns of U or rows of V**T are computed. */
|
|
|
|
/* M (input) INTEGER */
|
|
/* The number of rows of the input matrix A. M >= 0. */
|
|
|
|
/* N (input) INTEGER */
|
|
/* The number of columns of the input matrix A. N >= 0. */
|
|
|
|
/* A (input/output) DOUBLE PRECISION array, dimension (LDA,N) */
|
|
/* On entry, the M-by-N matrix A. */
|
|
/* On exit, */
|
|
/* if JOBZ = 'O', A is overwritten with the first N columns */
|
|
/* of U (the left singular vectors, stored */
|
|
/* columnwise) if M >= N; */
|
|
/* A is overwritten with the first M rows */
|
|
/* of V**T (the right singular vectors, stored */
|
|
/* rowwise) otherwise. */
|
|
/* if JOBZ .ne. 'O', the contents of A are destroyed. */
|
|
|
|
/* LDA (input) INTEGER */
|
|
/* The leading dimension of the array A. LDA >= max(1,M). */
|
|
|
|
/* S (output) DOUBLE PRECISION array, dimension (min(M,N)) */
|
|
/* The singular values of A, sorted so that S(i) >= S(i+1). */
|
|
|
|
/* U (output) DOUBLE PRECISION array, dimension (LDU,UCOL) */
|
|
/* UCOL = M if JOBZ = 'A' or JOBZ = 'O' and M < N; */
|
|
/* UCOL = min(M,N) if JOBZ = 'S'. */
|
|
/* If JOBZ = 'A' or JOBZ = 'O' and M < N, U contains the M-by-M */
|
|
/* orthogonal matrix U; */
|
|
/* if JOBZ = 'S', U contains the first min(M,N) columns of U */
|
|
/* (the left singular vectors, stored columnwise); */
|
|
/* if JOBZ = 'O' and M >= N, or JOBZ = 'N', U is not referenced. */
|
|
|
|
/* LDU (input) INTEGER */
|
|
/* The leading dimension of the array U. LDU >= 1; if */
|
|
/* JOBZ = 'S' or 'A' or JOBZ = 'O' and M < N, LDU >= M. */
|
|
|
|
/* VT (output) DOUBLE PRECISION array, dimension (LDVT,N) */
|
|
/* If JOBZ = 'A' or JOBZ = 'O' and M >= N, VT contains the */
|
|
/* N-by-N orthogonal matrix V**T; */
|
|
/* if JOBZ = 'S', VT contains the first min(M,N) rows of */
|
|
/* V**T (the right singular vectors, stored rowwise); */
|
|
/* if JOBZ = 'O' and M < N, or JOBZ = 'N', VT is not referenced. */
|
|
|
|
/* LDVT (input) INTEGER */
|
|
/* The leading dimension of the array VT. LDVT >= 1; if */
|
|
/* JOBZ = 'A' or JOBZ = 'O' and M >= N, LDVT >= N; */
|
|
/* if JOBZ = 'S', LDVT >= min(M,N). */
|
|
|
|
/* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */
|
|
/* On exit, if INFO = 0, WORK(1) returns the optimal LWORK; */
|
|
|
|
/* LWORK (input) INTEGER */
|
|
/* The dimension of the array WORK. LWORK >= 1. */
|
|
/* If JOBZ = 'N', */
|
|
/* LWORK >= 3*min(M,N) + max(max(M,N),7*min(M,N)). */
|
|
/* If JOBZ = 'O', */
|
|
/* LWORK >= 3*min(M,N)*min(M,N) + */
|
|
/* max(max(M,N),5*min(M,N)*min(M,N)+4*min(M,N)). */
|
|
/* If JOBZ = 'S' or 'A' */
|
|
/* LWORK >= 3*min(M,N)*min(M,N) + */
|
|
/* max(max(M,N),4*min(M,N)*min(M,N)+4*min(M,N)). */
|
|
/* For good performance, LWORK should generally be larger. */
|
|
/* If LWORK = -1 but other input arguments are legal, WORK(1) */
|
|
/* returns the optimal LWORK. */
|
|
|
|
/* IWORK (workspace) INTEGER array, dimension (8*min(M,N)) */
|
|
|
|
/* INFO (output) INTEGER */
|
|
/* = 0: successful exit. */
|
|
/* < 0: if INFO = -i, the i-th argument had an illegal value. */
|
|
/* > 0: DBDSDC did not converge, updating process failed. */
|
|
|
|
/* Further Details */
|
|
/* =============== */
|
|
|
|
/* Based on contributions by */
|
|
/* Ming Gu and Huan Ren, Computer Science Division, University of */
|
|
/* California at Berkeley, USA */
|
|
|
|
// ------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename T,
|
|
long NR1, long NR2, long NR3, long NR4,
|
|
long NC1, long NC2, long NC3, long NC4,
|
|
typename MM
|
|
>
|
|
int gesdd (
|
|
const char jobz,
|
|
matrix<T,NR1,NC1,MM,column_major_layout>& a,
|
|
matrix<T,NR2,NC2,MM,column_major_layout>& s,
|
|
matrix<T,NR3,NC3,MM,column_major_layout>& u,
|
|
matrix<T,NR4,NC4,MM,column_major_layout>& vt
|
|
)
|
|
{
|
|
matrix<T,0,1,MM,column_major_layout> work;
|
|
matrix<integer,0,1,MM,column_major_layout> iwork;
|
|
|
|
const long m = a.nr();
|
|
const long n = a.nc();
|
|
s.set_size(std::min(m,n), 1);
|
|
|
|
// make sure the iwork memory block is big enough
|
|
if (iwork.size() < 8*std::min(m,n))
|
|
iwork.set_size(8*std::min(m,n), 1);
|
|
|
|
if (jobz == 'A')
|
|
{
|
|
u.set_size(m,m);
|
|
vt.set_size(n,n);
|
|
}
|
|
else if (jobz == 'S')
|
|
{
|
|
u.set_size(m, std::min(m,n));
|
|
vt.set_size(std::min(m,n), n);
|
|
}
|
|
else if (jobz == 'O')
|
|
{
|
|
DLIB_CASSERT(false, "jobz == 'O' not supported");
|
|
}
|
|
else
|
|
{
|
|
u.set_size(NR3?NR3:1, NC3?NC3:1);
|
|
vt.set_size(NR4?NR4:1, NC4?NC4:1);
|
|
}
|
|
|
|
// figure out how big the workspace needs to be.
|
|
T work_size = 1;
|
|
int info = binding::gesdd(jobz, a.nr(), a.nc(), &a(0,0), a.nr(),
|
|
&s(0,0), &u(0,0), u.nr(), &vt(0,0), vt.nr(),
|
|
&work_size, -1, &iwork(0,0));
|
|
|
|
if (info != 0)
|
|
return info;
|
|
|
|
// There is a bug in an older version of LAPACK in Debian etch
|
|
// that causes the gesdd to return the wrong value for work_size
|
|
// when jobz == 'N'. So verify the value of work_size.
|
|
if (jobz == 'N')
|
|
{
|
|
using std::min;
|
|
using std::max;
|
|
const T min_work_size = 3*min(m,n) + max(max(m,n),7*min(m,n));
|
|
if (work_size < min_work_size)
|
|
work_size = min_work_size;
|
|
}
|
|
|
|
if (work.size() < work_size)
|
|
work.set_size(static_cast<long>(work_size), 1);
|
|
|
|
// compute the actual SVD
|
|
info = binding::gesdd(jobz, a.nr(), a.nc(), &a(0,0), a.nr(),
|
|
&s(0,0), &u(0,0), u.nr(), &vt(0,0), vt.nr(),
|
|
&work(0,0), work.size(), &iwork(0,0));
|
|
|
|
return info;
|
|
}
|
|
|
|
// ------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename T,
|
|
long NR1, long NR2, long NR3, long NR4,
|
|
long NC1, long NC2, long NC3, long NC4,
|
|
typename MM
|
|
>
|
|
int gesdd (
|
|
const char jobz,
|
|
matrix<T,NR1,NC1,MM,row_major_layout>& a,
|
|
matrix<T,NR2,NC2,MM,row_major_layout>& s,
|
|
matrix<T,NR3,NC3,MM,row_major_layout>& u_,
|
|
matrix<T,NR4,NC4,MM,row_major_layout>& vt_
|
|
)
|
|
{
|
|
matrix<T,0,1,MM,row_major_layout> work;
|
|
matrix<integer,0,1,MM,row_major_layout> iwork;
|
|
|
|
// Row major order matrices are transposed from LAPACK's point of view.
|
|
matrix<T,NR4,NC4,MM,row_major_layout>& u = vt_;
|
|
matrix<T,NR3,NC3,MM,row_major_layout>& vt = u_;
|
|
|
|
|
|
const long m = a.nc();
|
|
const long n = a.nr();
|
|
s.set_size(std::min(m,n), 1);
|
|
|
|
// make sure the iwork memory block is big enough
|
|
if (iwork.size() < 8*std::min(m,n))
|
|
iwork.set_size(8*std::min(m,n), 1);
|
|
|
|
if (jobz == 'A')
|
|
{
|
|
u.set_size(m,m);
|
|
vt.set_size(n,n);
|
|
}
|
|
else if (jobz == 'S')
|
|
{
|
|
u.set_size(std::min(m,n), m);
|
|
vt.set_size(n, std::min(m,n));
|
|
}
|
|
else if (jobz == 'O')
|
|
{
|
|
DLIB_CASSERT(false, "jobz == 'O' not supported");
|
|
}
|
|
else
|
|
{
|
|
u.set_size(NR4?NR4:1, NC4?NC4:1);
|
|
vt.set_size(NR3?NR3:1, NC3?NC3:1);
|
|
}
|
|
|
|
// figure out how big the workspace needs to be.
|
|
T work_size = 1;
|
|
int info = binding::gesdd(jobz, m, n, &a(0,0), a.nc(),
|
|
&s(0,0), &u(0,0), u.nc(), &vt(0,0), vt.nc(),
|
|
&work_size, -1, &iwork(0,0));
|
|
|
|
if (info != 0)
|
|
return info;
|
|
|
|
// There is a bug in an older version of LAPACK in Debian etch
|
|
// that causes the gesdd to return the wrong value for work_size
|
|
// when jobz == 'N'. So verify the value of work_size.
|
|
if (jobz == 'N')
|
|
{
|
|
using std::min;
|
|
using std::max;
|
|
const T min_work_size = 3*min(m,n) + max(max(m,n),7*min(m,n));
|
|
if (work_size < min_work_size)
|
|
work_size = min_work_size;
|
|
}
|
|
|
|
|
|
if (work.size() < work_size)
|
|
work.set_size(static_cast<long>(work_size), 1);
|
|
|
|
// compute the actual SVD
|
|
info = binding::gesdd(jobz, m, n, &a(0,0), a.nc(),
|
|
&s(0,0), &u(0,0), u.nc(), &vt(0,0), vt.nc(),
|
|
&work(0,0), work.size(), &iwork(0,0));
|
|
|
|
return info;
|
|
}
|
|
|
|
// ------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
#endif // DLIB_LAPACk_SDD_Hh_
|
|
|
|
|