// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2023 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Author: sameeragarwal@google.com (Sameer Agarwal)
30
31#ifndef CERES_INTERNAL_DENSE_CHOLESKY_H_
32#define CERES_INTERNAL_DENSE_CHOLESKY_H_
33
34// This include must come before any #ifndef check on Ceres compile options.
35// clang-format off
36#include "ceres/internal/config.h"
37// clang-format on
38
39#include <memory>
40#include <vector>
41
42#include "Eigen/Dense"
43#include "ceres/context_impl.h"
44#include "ceres/cuda_buffer.h"
45#include "ceres/linear_solver.h"
46#include "glog/logging.h"
47#ifndef CERES_NO_CUDA
48#include "ceres/context_impl.h"
49#include "cuda_runtime.h"
50#include "cusolverDn.h"
51#endif // CERES_NO_CUDA
52
53namespace ceres::internal {
54
55// An interface that abstracts away the internal details of various dense linear
56// algebra libraries and offers a simple API for solving dense symmetric
57// positive definite linear systems using a Cholesky factorization.
58class CERES_NO_EXPORT DenseCholesky {
59 public:
60 static std::unique_ptr<DenseCholesky> Create(
61 const LinearSolver::Options& options);
62
63 virtual ~DenseCholesky();
64
65 // Computes the Cholesky factorization of the given matrix.
66 //
67 // The input matrix lhs is assumed to be a column-major num_cols x num_cols
68 // matrix, that is symmetric positive definite with its lower triangular part
69 // containing the left hand side of the linear system being solved.
70 //
71 // The input matrix lhs may be modified by the implementation to store the
72 // factorization, irrespective of whether the factorization succeeds or not.
73 // As a result it is the user's responsibility to ensure that lhs is valid
74 // when Solve is called.
75 virtual LinearSolverTerminationType Factorize(int num_cols,
76 double* lhs,
77 std::string* message) = 0;
78
79 // Computes the solution to the equation
80 //
81 // lhs * solution = rhs
82 //
83 // Calling Solve without calling Factorize is undefined behaviour. It is the
84 // user's responsibility to ensure that the input matrix lhs passed to
85 // Factorize has not been freed/modified when Solve is called.
86 virtual LinearSolverTerminationType Solve(const double* rhs,
87 double* solution,
88 std::string* message) = 0;
89
90 // Convenience method which combines a call to Factorize and Solve. Solve is
91 // only called if Factorize returns LinearSolverTerminationType::SUCCESS.
92 //
93 // The input matrix lhs may be modified by the implementation to store the
94 // factorization, irrespective of whether the method succeeds or not. It is
95 // the user's responsibility to ensure that lhs is valid if and when Solve is
96 // called again after this call.
97 LinearSolverTerminationType FactorAndSolve(int num_cols,
98 double* lhs,
99 const double* rhs,
100 double* solution,
101 std::string* message);
102};
103
104class CERES_NO_EXPORT EigenDenseCholesky final : public DenseCholesky {
105 public:
106 LinearSolverTerminationType Factorize(int num_cols,
107 double* lhs,
108 std::string* message) override;
109 LinearSolverTerminationType Solve(const double* rhs,
110 double* solution,
111 std::string* message) override;
112
113 private:
114 using LLTType = Eigen::LLT<Eigen::Ref<Eigen::MatrixXd>, Eigen::Lower>;
115 std::unique_ptr<LLTType> llt_;
116};
117
118class CERES_NO_EXPORT FloatEigenDenseCholesky final : public DenseCholesky {
119 public:
120 LinearSolverTerminationType Factorize(int num_cols,
121 double* lhs,
122 std::string* message) override;
123 LinearSolverTerminationType Solve(const double* rhs,
124 double* solution,
125 std::string* message) override;
126
127 private:
128 Eigen::MatrixXf lhs_;
129 Eigen::VectorXf rhs_;
130 Eigen::VectorXf solution_;
131 using LLTType = Eigen::LLT<Eigen::MatrixXf, Eigen::Lower>;
132 std::unique_ptr<LLTType> llt_;
133};
134
135#ifndef CERES_NO_LAPACK
136class CERES_NO_EXPORT LAPACKDenseCholesky final : public DenseCholesky {
137 public:
138 LinearSolverTerminationType Factorize(int num_cols,
139 double* lhs,
140 std::string* message) override;
141 LinearSolverTerminationType Solve(const double* rhs,
142 double* solution,
143 std::string* message) override;
144
145 private:
146 double* lhs_ = nullptr;
147 int num_cols_ = -1;
148 LinearSolverTerminationType termination_type_ =
149 LinearSolverTerminationType::FATAL_ERROR;
150};
151
152class CERES_NO_EXPORT FloatLAPACKDenseCholesky final : public DenseCholesky {
153 public:
154 LinearSolverTerminationType Factorize(int num_cols,
155 double* lhs,
156 std::string* message) override;
157 LinearSolverTerminationType Solve(const double* rhs,
158 double* solution,
159 std::string* message) override;
160
161 private:
162 Eigen::MatrixXf lhs_;
163 Eigen::VectorXf rhs_and_solution_;
164 int num_cols_ = -1;
165 LinearSolverTerminationType termination_type_ =
166 LinearSolverTerminationType::FATAL_ERROR;
167};
168#endif // CERES_NO_LAPACK
169
170class DenseIterativeRefiner;
171
172// Computes an initial solution using the given instance of
173// DenseCholesky, and then refines it using the DenseIterativeRefiner.
174class CERES_NO_EXPORT RefinedDenseCholesky final : public DenseCholesky {
175 public:
176 RefinedDenseCholesky(
177 std::unique_ptr<DenseCholesky> dense_cholesky,
178 std::unique_ptr<DenseIterativeRefiner> iterative_refiner);
179 ~RefinedDenseCholesky() override;
180
181 LinearSolverTerminationType Factorize(int num_cols,
182 double* lhs,
183 std::string* message) override;
184 LinearSolverTerminationType Solve(const double* rhs,
185 double* solution,
186 std::string* message) override;
187
188 private:
189 std::unique_ptr<DenseCholesky> dense_cholesky_;
190 std::unique_ptr<DenseIterativeRefiner> iterative_refiner_;
191 double* lhs_ = nullptr;
192 int num_cols_;
193};
194
195#ifndef CERES_NO_CUDA
196// CUDA implementation of DenseCholesky using the cuSolverDN library using the
197// 32-bit legacy interface for maximum compatibility.
198class CERES_NO_EXPORT CUDADenseCholesky final : public DenseCholesky {
199 public:
200 static std::unique_ptr<CUDADenseCholesky> Create(
201 const LinearSolver::Options& options);
202 CUDADenseCholesky(const CUDADenseCholesky&) = delete;
203 CUDADenseCholesky& operator=(const CUDADenseCholesky&) = delete;
204 LinearSolverTerminationType Factorize(int num_cols,
205 double* lhs,
206 std::string* message) override;
207 LinearSolverTerminationType Solve(const double* rhs,
208 double* solution,
209 std::string* message) override;
210
211 private:
212 explicit CUDADenseCholesky(ContextImpl* context);
213
214 ContextImpl* context_ = nullptr;
215 // Number of columns in the A matrix, to be cached between calls to *Factorize
216 // and *Solve.
217 size_t num_cols_ = 0;
218 // GPU memory allocated for the A matrix (lhs matrix).
219 CudaBuffer<double> lhs_;
220 // GPU memory allocated for the B matrix (rhs vector).
221 CudaBuffer<double> rhs_;
222 // Scratch space for cuSOLVER on the GPU.
223 CudaBuffer<double> device_workspace_;
224 // Required for error handling with cuSOLVER.
225 CudaBuffer<int> error_;
226 // Cache the result of Factorize to ensure that when Solve is called, the
227 // factorization of lhs is valid.
228 LinearSolverTerminationType factorize_result_ =
229 LinearSolverTerminationType::FATAL_ERROR;
230};
231
232// A mixed-precision iterative refinement dense Cholesky solver using FP32 CUDA
233// Dense Cholesky for inner iterations, and FP64 outer refinements.
234// This class implements a modified version of the "Classical iterative
235// refinement" (Algorithm 4.1) from the following paper:
236// Haidar, Azzam, Harun Bayraktar, Stanimire Tomov, Jack Dongarra, and Nicholas
237// J. Higham. "Mixed-precision iterative refinement using tensor cores on GPUs
238// to accelerate solution of linear systems." Proceedings of the Royal Society A
239// 476, no. 2243 (2020): 20200110.
240//
241// The three key modifications from Algorithm 4.1 in the paper are:
242// 1. We use Cholesky factorization instead of LU factorization since our A is
243// symmetric positive definite.
244// 2. During the solution update, the up-cast and accumulation is performed in
245// one step with a custom kernel.
246class CERES_NO_EXPORT CUDADenseCholeskyMixedPrecision final
247 : public DenseCholesky {
248 public:
249 static std::unique_ptr<CUDADenseCholeskyMixedPrecision> Create(
250 const LinearSolver::Options& options);
251 CUDADenseCholeskyMixedPrecision(const CUDADenseCholeskyMixedPrecision&) =
252 delete;
253 CUDADenseCholeskyMixedPrecision& operator=(
254 const CUDADenseCholeskyMixedPrecision&) = delete;
255 LinearSolverTerminationType Factorize(int num_cols,
256 double* lhs,
257 std::string* message) override;
258 LinearSolverTerminationType Solve(const double* rhs,
259 double* solution,
260 std::string* message) override;
261
262 private:
263 CUDADenseCholeskyMixedPrecision(ContextImpl* context,
264 int max_num_refinement_iterations);
265
266 // Helper function to wrap Cuda boilerplate needed to call Spotrf.
267 LinearSolverTerminationType CudaCholeskyFactorize(std::string* message);
268 // Helper function to wrap Cuda boilerplate needed to call Spotrs.
269 LinearSolverTerminationType CudaCholeskySolve(std::string* message);
270 // Picks up the cuSolverDN and cuStream handles from the context in the
271 // options, and the number of refinement iterations from the options. If
272 // the context is unable to initialize CUDA, returns false with a
273 // human-readable message indicating the reason.
274 bool Init(const LinearSolver::Options& options, std::string* message);
275
276 ContextImpl* context_ = nullptr;
277 // Number of columns in the A matrix, to be cached between calls to *Factorize
278 // and *Solve.
279 size_t num_cols_ = 0;
280 CudaBuffer<double> lhs_fp64_;
281 CudaBuffer<double> rhs_fp64_;
282 CudaBuffer<float> lhs_fp32_;
283 // Scratch space for cuSOLVER on the GPU.
284 CudaBuffer<float> device_workspace_;
285 // Required for error handling with cuSOLVER.
286 CudaBuffer<int> error_;
287
288 // Solution to lhs * x = rhs.
289 CudaBuffer<double> x_fp64_;
290 // Incremental correction to x.
291 CudaBuffer<float> correction_fp32_;
292 // Residual to iterative refinement.
293 CudaBuffer<float> residual_fp32_;
294 CudaBuffer<double> residual_fp64_;
295
296 // Number of inner refinement iterations to perform.
297 int max_num_refinement_iterations_ = 0;
298 // Cache the result of Factorize to ensure that when Solve is called, the
299 // factorization of lhs is valid.
300 LinearSolverTerminationType factorize_result_ =
301 LinearSolverTerminationType::FATAL_ERROR;
302};
303
304#endif // CERES_NO_CUDA
305
306} // namespace ceres::internal
307
308#endif // CERES_INTERNAL_DENSE_CHOLESKY_H_