blob: 74699381c7e834159317c19002165d3ff7f357b2 [file] [log] [blame]
Austin Schuh189376f2018-12-20 22:11:15 +11001#include <iostream>
2#include <Eigen/Core>
3#include <Eigen/Dense>
4#include <Eigen/IterativeLinearSolvers>
5#include <unsupported/Eigen/IterativeSolvers>
6
7class MatrixReplacement;
8using Eigen::SparseMatrix;
9
10namespace Eigen {
11namespace internal {
12 // MatrixReplacement looks-like a SparseMatrix, so let's inherits its traits:
13 template<>
14 struct traits<MatrixReplacement> : public Eigen::internal::traits<Eigen::SparseMatrix<double> >
15 {};
16}
17}
18
19// Example of a matrix-free wrapper from a user type to Eigen's compatible type
20// For the sake of simplicity, this example simply wrap a Eigen::SparseMatrix.
21class MatrixReplacement : public Eigen::EigenBase<MatrixReplacement> {
22public:
23 // Required typedefs, constants, and method:
24 typedef double Scalar;
25 typedef double RealScalar;
26 typedef int StorageIndex;
27 enum {
28 ColsAtCompileTime = Eigen::Dynamic,
29 MaxColsAtCompileTime = Eigen::Dynamic,
30 IsRowMajor = false
31 };
32
33 Index rows() const { return mp_mat->rows(); }
34 Index cols() const { return mp_mat->cols(); }
35
36 template<typename Rhs>
37 Eigen::Product<MatrixReplacement,Rhs,Eigen::AliasFreeProduct> operator*(const Eigen::MatrixBase<Rhs>& x) const {
38 return Eigen::Product<MatrixReplacement,Rhs,Eigen::AliasFreeProduct>(*this, x.derived());
39 }
40
41 // Custom API:
42 MatrixReplacement() : mp_mat(0) {}
43
44 void attachMyMatrix(const SparseMatrix<double> &mat) {
45 mp_mat = &mat;
46 }
47 const SparseMatrix<double> my_matrix() const { return *mp_mat; }
48
49private:
50 const SparseMatrix<double> *mp_mat;
51};
52
53
54// Implementation of MatrixReplacement * Eigen::DenseVector though a specialization of internal::generic_product_impl:
55namespace Eigen {
56namespace internal {
57
58 template<typename Rhs>
59 struct generic_product_impl<MatrixReplacement, Rhs, SparseShape, DenseShape, GemvProduct> // GEMV stands for matrix-vector
60 : generic_product_impl_base<MatrixReplacement,Rhs,generic_product_impl<MatrixReplacement,Rhs> >
61 {
62 typedef typename Product<MatrixReplacement,Rhs>::Scalar Scalar;
63
64 template<typename Dest>
65 static void scaleAndAddTo(Dest& dst, const MatrixReplacement& lhs, const Rhs& rhs, const Scalar& alpha)
66 {
67 // This method should implement "dst += alpha * lhs * rhs" inplace,
68 // however, for iterative solvers, alpha is always equal to 1, so let's not bother about it.
69 assert(alpha==Scalar(1) && "scaling is not implemented");
70 EIGEN_ONLY_USED_FOR_DEBUG(alpha);
71
72 // Here we could simply call dst.noalias() += lhs.my_matrix() * rhs,
73 // but let's do something fancier (and less efficient):
74 for(Index i=0; i<lhs.cols(); ++i)
75 dst += rhs(i) * lhs.my_matrix().col(i);
76 }
77 };
78
79}
80}
81
82int main()
83{
84 int n = 10;
85 Eigen::SparseMatrix<double> S = Eigen::MatrixXd::Random(n,n).sparseView(0.5,1);
86 S = S.transpose()*S;
87
88 MatrixReplacement A;
89 A.attachMyMatrix(S);
90
91 Eigen::VectorXd b(n), x;
92 b.setRandom();
93
94 // Solve Ax = b using various iterative solver with matrix-free version:
95 {
96 Eigen::ConjugateGradient<MatrixReplacement, Eigen::Lower|Eigen::Upper, Eigen::IdentityPreconditioner> cg;
97 cg.compute(A);
98 x = cg.solve(b);
99 std::cout << "CG: #iterations: " << cg.iterations() << ", estimated error: " << cg.error() << std::endl;
100 }
101
102 {
103 Eigen::BiCGSTAB<MatrixReplacement, Eigen::IdentityPreconditioner> bicg;
104 bicg.compute(A);
105 x = bicg.solve(b);
106 std::cout << "BiCGSTAB: #iterations: " << bicg.iterations() << ", estimated error: " << bicg.error() << std::endl;
107 }
108
109 {
110 Eigen::GMRES<MatrixReplacement, Eigen::IdentityPreconditioner> gmres;
111 gmres.compute(A);
112 x = gmres.solve(b);
113 std::cout << "GMRES: #iterations: " << gmres.iterations() << ", estimated error: " << gmres.error() << std::endl;
114 }
115
116 {
117 Eigen::DGMRES<MatrixReplacement, Eigen::IdentityPreconditioner> gmres;
118 gmres.compute(A);
119 x = gmres.solve(b);
120 std::cout << "DGMRES: #iterations: " << gmres.iterations() << ", estimated error: " << gmres.error() << std::endl;
121 }
122
123 {
124 Eigen::MINRES<MatrixReplacement, Eigen::Lower|Eigen::Upper, Eigen::IdentityPreconditioner> minres;
125 minres.compute(A);
126 x = minres.solve(b);
127 std::cout << "MINRES: #iterations: " << minres.iterations() << ", estimated error: " << minres.error() << std::endl;
128 }
129}