blob: 7ca0a5e770cfbc487cb0b25629117d68a0a49066 [file] [log] [blame]
Austin Schuh70cc9552019-01-21 19:46:48 -08001// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2018 Google Inc. All rights reserved.
3// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: sameeragarwal@google.com (Sameer Agarwal)
30
31#include "ceres/iterative_refiner.h"
32
33#include "Eigen/Dense"
34#include "ceres/internal/eigen.h"
35#include "ceres/sparse_cholesky.h"
36#include "ceres/sparse_matrix.h"
37#include "glog/logging.h"
38#include "gtest/gtest.h"
39
40namespace ceres {
41namespace internal {
42
// Helper macros that supply the body of a virtual method which the tests in
// this file are never expected to invoke. Reaching such a body logs
// "DO NOT CALL" via LOG(FATAL). The unreachable `return x;` in the second
// form only gives non-void methods a return statement so the compiler does
// not complain about a missing return.
#define DO_NOT_CALL { LOG(FATAL) << "DO NOT CALL"; }
#define DO_NOT_CALL_WITH_RETURN(x) { LOG(FATAL) << "DO NOT CALL"; return x; }
52
53// A fake SparseMatrix, which uses an Eigen matrix to do the real work.
54class FakeSparseMatrix : public SparseMatrix {
55 public:
56 FakeSparseMatrix(const Matrix& m) : m_(m) {}
57 virtual ~FakeSparseMatrix() {}
58
59 // y += Ax
60 virtual void RightMultiply(const double* x, double* y) const {
61 VectorRef(y, m_.cols()) += m_ * ConstVectorRef(x, m_.cols());
62 }
63 // y += A'x
64 virtual void LeftMultiply(const double* x, double* y) const {
65 // We will assume that this is a symmetric matrix.
66 RightMultiply(x, y);
67 }
68
69 virtual double* mutable_values() { return m_.data(); }
70 virtual const double* values() const { return m_.data(); }
71 virtual int num_rows() const { return m_.cols(); }
72 virtual int num_cols() const { return m_.cols(); }
73 virtual int num_nonzeros() const { return m_.cols() * m_.cols(); }
74
75 // The following methods are not needed for tests in this file.
76 virtual void SquaredColumnNorm(double* x) const DO_NOT_CALL;
77 virtual void ScaleColumns(const double* scale) DO_NOT_CALL;
78 virtual void SetZero() DO_NOT_CALL;
79 virtual void ToDenseMatrix(Matrix* dense_matrix) const DO_NOT_CALL;
80 virtual void ToTextFile(FILE* file) const DO_NOT_CALL;
81
82 private:
83 Matrix m_;
84};
85
86// A fake SparseCholesky which uses Eigen's Cholesky factorization to
87// do the real work. The template parameter allows us to work in
88// doubles or floats, even though the source matrix is double.
89template <typename Scalar>
90class FakeSparseCholesky : public SparseCholesky {
91 public:
92 FakeSparseCholesky(const Matrix& lhs) { lhs_ = lhs.cast<Scalar>(); }
93 virtual ~FakeSparseCholesky() {}
94
95 virtual LinearSolverTerminationType Solve(const double* rhs_ptr,
96 double* solution_ptr,
97 std::string* message) {
98 const int num_cols = lhs_.cols();
99 VectorRef solution(solution_ptr, num_cols);
100 ConstVectorRef rhs(rhs_ptr, num_cols);
101 solution = lhs_.llt().solve(rhs.cast<Scalar>()).template cast<double>();
102 return LINEAR_SOLVER_SUCCESS;
103 }
104
105 // The following methods are not needed for tests in this file.
106 virtual CompressedRowSparseMatrix::StorageType StorageType() const
107 DO_NOT_CALL_WITH_RETURN(CompressedRowSparseMatrix::UPPER_TRIANGULAR);
108 virtual LinearSolverTerminationType Factorize(CompressedRowSparseMatrix* lhs,
109 std::string* message)
110 DO_NOT_CALL_WITH_RETURN(LINEAR_SOLVER_FAILURE);
111
112 virtual LinearSolverTerminationType FactorAndSolve(
113 CompressedRowSparseMatrix* lhs,
114 const double* rhs,
115 double* solution,
116 std::string* message) DO_NOT_CALL_WITH_RETURN(LINEAR_SOLVER_FAILURE);
117
118 private:
119 Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> lhs_;
120};
121
// The helper macros are only meant for the fake classes; remove them so they
// do not leak into the rest of this file.
#undef DO_NOT_CALL_WITH_RETURN
#undef DO_NOT_CALL
124
125class IterativeRefinerTest : public ::testing::Test {
126 public:
127 void SetUp() {
128 num_cols_ = 5;
129 max_num_iterations_ = 30;
130 Matrix m(num_cols_, num_cols_);
131 m.setRandom();
132 lhs_ = m * m.transpose();
133 solution_.resize(num_cols_);
134 solution_.setRandom();
135 rhs_ = lhs_ * solution_;
136 };
137
138 protected:
139 int num_cols_;
140 int max_num_iterations_;
141 Matrix lhs_;
142 Vector rhs_, solution_;
143};
144
145TEST_F(IterativeRefinerTest, RandomSolutionWithExactFactorizationConverges) {
146 FakeSparseMatrix lhs(lhs_);
147 FakeSparseCholesky<double> sparse_cholesky(lhs_);
148 IterativeRefiner refiner(max_num_iterations_);
149 Vector refined_solution(num_cols_);
150 refined_solution.setRandom();
151 refiner.Refine(lhs, rhs_.data(), &sparse_cholesky, refined_solution.data());
152 EXPECT_NEAR((lhs_ * refined_solution - rhs_).norm(),
153 0.0,
154 std::numeric_limits<double>::epsilon() * 10);
155}
156
157TEST_F(IterativeRefinerTest,
158 RandomSolutionWithApproximationFactorizationConverges) {
159 FakeSparseMatrix lhs(lhs_);
160 // Use a single precision Cholesky factorization of the double
161 // precision matrix. This will give us an approximate factorization.
162 FakeSparseCholesky<float> sparse_cholesky(lhs_);
163 IterativeRefiner refiner(max_num_iterations_);
164 Vector refined_solution(num_cols_);
165 refined_solution.setRandom();
166 refiner.Refine(lhs, rhs_.data(), &sparse_cholesky, refined_solution.data());
167 EXPECT_NEAR((lhs_ * refined_solution - rhs_).norm(),
168 0.0,
169 std::numeric_limits<double>::epsilon() * 10);
170}
171
172} // namespace internal
173} // namespace ceres