Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 1 | #include <iostream> |
| 2 | #include <Eigen/Core> |
| 3 | #include <Eigen/Dense> |
| 4 | #include <Eigen/IterativeLinearSolvers> |
| 5 | #include <unsupported/Eigen/IterativeSolvers> |
| 6 | |
| 7 | class MatrixReplacement; |
| 8 | using Eigen::SparseMatrix; |
| 9 | |
| 10 | namespace Eigen { |
| 11 | namespace internal { |
| 12 | // MatrixReplacement looks-like a SparseMatrix, so let's inherits its traits: |
| 13 | template<> |
| 14 | struct traits<MatrixReplacement> : public Eigen::internal::traits<Eigen::SparseMatrix<double> > |
| 15 | {}; |
| 16 | } |
| 17 | } |
| 18 | |
| 19 | // Example of a matrix-free wrapper from a user type to Eigen's compatible type |
| 20 | // For the sake of simplicity, this example simply wrap a Eigen::SparseMatrix. |
| 21 | class MatrixReplacement : public Eigen::EigenBase<MatrixReplacement> { |
| 22 | public: |
| 23 | // Required typedefs, constants, and method: |
| 24 | typedef double Scalar; |
| 25 | typedef double RealScalar; |
| 26 | typedef int StorageIndex; |
| 27 | enum { |
| 28 | ColsAtCompileTime = Eigen::Dynamic, |
| 29 | MaxColsAtCompileTime = Eigen::Dynamic, |
| 30 | IsRowMajor = false |
| 31 | }; |
| 32 | |
| 33 | Index rows() const { return mp_mat->rows(); } |
| 34 | Index cols() const { return mp_mat->cols(); } |
| 35 | |
| 36 | template<typename Rhs> |
| 37 | Eigen::Product<MatrixReplacement,Rhs,Eigen::AliasFreeProduct> operator*(const Eigen::MatrixBase<Rhs>& x) const { |
| 38 | return Eigen::Product<MatrixReplacement,Rhs,Eigen::AliasFreeProduct>(*this, x.derived()); |
| 39 | } |
| 40 | |
| 41 | // Custom API: |
| 42 | MatrixReplacement() : mp_mat(0) {} |
| 43 | |
| 44 | void attachMyMatrix(const SparseMatrix<double> &mat) { |
| 45 | mp_mat = &mat; |
| 46 | } |
| 47 | const SparseMatrix<double> my_matrix() const { return *mp_mat; } |
| 48 | |
| 49 | private: |
| 50 | const SparseMatrix<double> *mp_mat; |
| 51 | }; |
| 52 | |
| 53 | |
| 54 | // Implementation of MatrixReplacement * Eigen::DenseVector though a specialization of internal::generic_product_impl: |
| 55 | namespace Eigen { |
| 56 | namespace internal { |
| 57 | |
| 58 | template<typename Rhs> |
| 59 | struct generic_product_impl<MatrixReplacement, Rhs, SparseShape, DenseShape, GemvProduct> // GEMV stands for matrix-vector |
| 60 | : generic_product_impl_base<MatrixReplacement,Rhs,generic_product_impl<MatrixReplacement,Rhs> > |
| 61 | { |
| 62 | typedef typename Product<MatrixReplacement,Rhs>::Scalar Scalar; |
| 63 | |
| 64 | template<typename Dest> |
| 65 | static void scaleAndAddTo(Dest& dst, const MatrixReplacement& lhs, const Rhs& rhs, const Scalar& alpha) |
| 66 | { |
| 67 | // This method should implement "dst += alpha * lhs * rhs" inplace, |
| 68 | // however, for iterative solvers, alpha is always equal to 1, so let's not bother about it. |
| 69 | assert(alpha==Scalar(1) && "scaling is not implemented"); |
| 70 | EIGEN_ONLY_USED_FOR_DEBUG(alpha); |
| 71 | |
| 72 | // Here we could simply call dst.noalias() += lhs.my_matrix() * rhs, |
| 73 | // but let's do something fancier (and less efficient): |
| 74 | for(Index i=0; i<lhs.cols(); ++i) |
| 75 | dst += rhs(i) * lhs.my_matrix().col(i); |
| 76 | } |
| 77 | }; |
| 78 | |
| 79 | } |
| 80 | } |
| 81 | |
| 82 | int main() |
| 83 | { |
| 84 | int n = 10; |
| 85 | Eigen::SparseMatrix<double> S = Eigen::MatrixXd::Random(n,n).sparseView(0.5,1); |
| 86 | S = S.transpose()*S; |
| 87 | |
| 88 | MatrixReplacement A; |
| 89 | A.attachMyMatrix(S); |
| 90 | |
| 91 | Eigen::VectorXd b(n), x; |
| 92 | b.setRandom(); |
| 93 | |
| 94 | // Solve Ax = b using various iterative solver with matrix-free version: |
| 95 | { |
| 96 | Eigen::ConjugateGradient<MatrixReplacement, Eigen::Lower|Eigen::Upper, Eigen::IdentityPreconditioner> cg; |
| 97 | cg.compute(A); |
| 98 | x = cg.solve(b); |
| 99 | std::cout << "CG: #iterations: " << cg.iterations() << ", estimated error: " << cg.error() << std::endl; |
| 100 | } |
| 101 | |
| 102 | { |
| 103 | Eigen::BiCGSTAB<MatrixReplacement, Eigen::IdentityPreconditioner> bicg; |
| 104 | bicg.compute(A); |
| 105 | x = bicg.solve(b); |
| 106 | std::cout << "BiCGSTAB: #iterations: " << bicg.iterations() << ", estimated error: " << bicg.error() << std::endl; |
| 107 | } |
| 108 | |
| 109 | { |
| 110 | Eigen::GMRES<MatrixReplacement, Eigen::IdentityPreconditioner> gmres; |
| 111 | gmres.compute(A); |
| 112 | x = gmres.solve(b); |
| 113 | std::cout << "GMRES: #iterations: " << gmres.iterations() << ", estimated error: " << gmres.error() << std::endl; |
| 114 | } |
| 115 | |
| 116 | { |
| 117 | Eigen::DGMRES<MatrixReplacement, Eigen::IdentityPreconditioner> gmres; |
| 118 | gmres.compute(A); |
| 119 | x = gmres.solve(b); |
| 120 | std::cout << "DGMRES: #iterations: " << gmres.iterations() << ", estimated error: " << gmres.error() << std::endl; |
| 121 | } |
| 122 | |
| 123 | { |
| 124 | Eigen::MINRES<MatrixReplacement, Eigen::Lower|Eigen::Upper, Eigen::IdentityPreconditioner> minres; |
| 125 | minres.compute(A); |
| 126 | x = minres.solve(b); |
| 127 | std::cout << "MINRES: #iterations: " << minres.iterations() << ", estimated error: " << minres.error() << std::endl; |
| 128 | } |
| 129 | } |