00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042 #ifndef NOX_LAPACK_LINEARSOLVER_H
00043 #define NOX_LAPACK_LINEARSOLVER_H
00044
00045 #include "Teuchos_BLAS.hpp"
00046 #include "Teuchos_LAPACK.hpp"
00047 #include "Teuchos_ScalarTraits.hpp"
00048
00049 #include "NOX_LAPACK_Matrix.H"
00050
00051 namespace NOX {
00052
00053 namespace LAPACK {
00054
00056
00065 template <typename T>
00066 class LinearSolver {
00067
00068 public:
00069
00071 LinearSolver(int n);
00072
00074 LinearSolver(const LinearSolver<T>& s);
00075
00077 ~LinearSolver();
00078
00080 LinearSolver& operator=(const LinearSolver<T>& s);
00081
00083 Matrix<T>& getMatrix();
00084
00086 const Matrix<T>& getMatrix() const;
00087
00089
00092 void reset();
00093
00095
00100 void apply(bool trans, int ncols, const T* input, T* output) const;
00101
00103
00109 bool solve(bool trans, int ncols, T* output);
00110
00111 protected:
00112
00114 Matrix<T> mat;
00115
00117 Matrix<T> lu;
00118
00120 std::vector<int> pivots;
00121
00123 bool isValidLU;
00124
00126 Teuchos::BLAS<int,T> blas;
00127
00129 Teuchos::LAPACK<int,T> lapack;
00130
00131 };
00132
00133 }
00134
00135 }
00136
00137 template <typename T>
00138 NOX::LAPACK::LinearSolver<T>::LinearSolver(int n) :
00139 mat(n,n),
00140 lu(n,n),
00141 pivots(n),
00142 isValidLU(false),
00143 blas(),
00144 lapack()
00145 {
00146 }
00147
00148 template <typename T>
00149 NOX::LAPACK::LinearSolver<T>::LinearSolver(const NOX::LAPACK::LinearSolver<T>& s) :
00150 mat(s.mat),
00151 lu(s.lu),
00152 pivots(s.pivots),
00153 isValidLU(s.isValidLU),
00154 blas(),
00155 lapack()
00156 {
00157 }
00158
00159 template <typename T>
00160 NOX::LAPACK::LinearSolver<T>::~LinearSolver()
00161 {
00162 }
00163
00164 template <typename T>
00165 NOX::LAPACK::LinearSolver<T>&
00166 NOX::LAPACK::LinearSolver<T>::operator=(const NOX::LAPACK::LinearSolver<T>& s)
00167 {
00168 if (this != &s) {
00169 mat = s.mat;
00170 lu = s.lu;
00171 pivots = s.pivots;
00172 isValidLU = s.isValidLU;
00173 }
00174
00175 return *this;
00176 }
00177
00178 template <typename T>
00179 NOX::LAPACK::Matrix<T>&
00180 NOX::LAPACK::LinearSolver<T>::getMatrix()
00181 {
00182 return mat;
00183 }
00184
00185 template <typename T>
00186 const NOX::LAPACK::Matrix<T>&
00187 NOX::LAPACK::LinearSolver<T>::getMatrix() const
00188 {
00189 return mat;
00190 }
00191
00192 template <typename T>
00193 void
00194 NOX::LAPACK::LinearSolver<T>::reset()
00195 {
00196 isValidLU = false;
00197 }
00198
00199 template <typename T>
00200 void
00201 NOX::LAPACK::LinearSolver<T>::apply(bool trans, int ncols, const T* input,
00202 T* output) const
00203 {
00204 Teuchos::ETransp tr = Teuchos::NO_TRANS;
00205 if (trans) {
00206 if (Teuchos::ScalarTraits<T>::isComplex)
00207 tr = Teuchos::CONJ_TRANS;
00208 else
00209 tr = Teuchos::TRANS;
00210 }
00211
00212 int n = mat.numRows();
00213 blas.GEMM(tr, Teuchos::NO_TRANS, n, ncols, n, 1.0, &mat(0,0), n,
00214 input, n, 0.0, output, n);
00215 }
00216
00217 template <typename T>
00218 bool
00219 NOX::LAPACK::LinearSolver<T>::solve(bool trans, int ncols, T* output)
00220 {
00221 int info;
00222 int n = mat.numRows();
00223
00224
00225 if (!isValidLU) {
00226 lu = mat;
00227 lapack.GETRF(n, n, &lu(0,0), n, &pivots[0], &info);
00228 if (info != 0)
00229 return false;
00230 isValidLU = true;
00231 }
00232
00233
00234 char tr = 'N';
00235 if (trans) {
00236 if (Teuchos::ScalarTraits<T>::isComplex)
00237 tr = 'C';
00238 else
00239 tr = 'T';
00240 }
00241 lapack.GETRS(tr, n, ncols, &lu(0,0), n, &pivots[0], output, n, &info);
00242
00243 if (info != 0)
00244 return false;
00245 return true;
00246 }
00247
00248 #endif