blob: 0ba17c4df9447a0803cdc0ebebd5f904e2525d6d [file] [log] [blame]
Austin Schuh3de38b02024-06-25 18:25:10 -07001// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2023 Google Inc. All rights reserved.
3// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: sameeragarwal@google.com (Sameer Agarwal)
30
31#ifndef CERES_INTERNAL_DENSE_QR_H_
32#define CERES_INTERNAL_DENSE_QR_H_
33
34// This include must come before any #ifndef check on Ceres compile options.
35// clang-format off
36#include "ceres/internal/config.h"
37// clang-format on
38
39#include <memory>
40#include <vector>
41
42#include "Eigen/Dense"
43#include "ceres/context_impl.h"
44#include "ceres/internal/disable_warnings.h"
45#include "ceres/internal/eigen.h"
46#include "ceres/internal/export.h"
47#include "ceres/linear_solver.h"
48#include "glog/logging.h"
49
50#ifndef CERES_NO_CUDA
51#include "ceres/context_impl.h"
52#include "ceres/cuda_buffer.h"
53#include "cublas_v2.h"
54#include "cuda_runtime.h"
55#include "cusolverDn.h"
56#endif // CERES_NO_CUDA
57
58namespace ceres::internal {
59
60// An interface that abstracts away the internal details of various dense linear
61// algebra libraries and offers a simple API for solving dense linear systems
62// using a QR factorization.
63class CERES_NO_EXPORT DenseQR {
64 public:
65 static std::unique_ptr<DenseQR> Create(const LinearSolver::Options& options);
66
67 virtual ~DenseQR();
68
69 // Computes the QR factorization of the given matrix.
70 //
71 // The input matrix lhs is assumed to be a column-major num_rows x num_cols
72 // matrix.
73 //
74 // The input matrix lhs may be modified by the implementation to store the
75 // factorization, irrespective of whether the factorization succeeds or not.
76 // As a result it is the user's responsibility to ensure that lhs is valid
77 // when Solve is called.
78 virtual LinearSolverTerminationType Factorize(int num_rows,
79 int num_cols,
80 double* lhs,
81 std::string* message) = 0;
82
83 // Computes the solution to the equation
84 //
85 // lhs * solution = rhs
86 //
87 // Calling Solve without calling Factorize is undefined behaviour. It is the
88 // user's responsibility to ensure that the input matrix lhs passed to
89 // Factorize has not been freed/modified when Solve is called.
90 virtual LinearSolverTerminationType Solve(const double* rhs,
91 double* solution,
92 std::string* message) = 0;
93
94 // Convenience method which combines a call to Factorize and Solve. Solve is
95 // only called if Factorize returns LinearSolverTerminationType::SUCCESS.
96 //
97 // The input matrix lhs may be modified by the implementation to store the
98 // factorization, irrespective of whether the method succeeds or not. It is
99 // the user's responsibility to ensure that lhs is valid if and when Solve is
100 // called again after this call.
101 LinearSolverTerminationType FactorAndSolve(int num_rows,
102 int num_cols,
103 double* lhs,
104 const double* rhs,
105 double* solution,
106 std::string* message);
107};
108
109class CERES_NO_EXPORT EigenDenseQR final : public DenseQR {
110 public:
111 LinearSolverTerminationType Factorize(int num_rows,
112 int num_cols,
113 double* lhs,
114 std::string* message) override;
115 LinearSolverTerminationType Solve(const double* rhs,
116 double* solution,
117 std::string* message) override;
118
119 private:
120 using QRType = Eigen::HouseholderQR<Eigen::Ref<ColMajorMatrix>>;
121 std::unique_ptr<QRType> qr_;
122};
123
124#ifndef CERES_NO_LAPACK
125class CERES_NO_EXPORT LAPACKDenseQR final : public DenseQR {
126 public:
127 LinearSolverTerminationType Factorize(int num_rows,
128 int num_cols,
129 double* lhs,
130 std::string* message) override;
131 LinearSolverTerminationType Solve(const double* rhs,
132 double* solution,
133 std::string* message) override;
134
135 private:
136 double* lhs_ = nullptr;
137 int num_rows_;
138 int num_cols_;
139 LinearSolverTerminationType termination_type_ =
140 LinearSolverTerminationType::FATAL_ERROR;
141 Vector work_;
142 Vector tau_;
143 Vector q_transpose_rhs_;
144};
145#endif // CERES_NO_LAPACK
146
147#ifndef CERES_NO_CUDA
148// Implementation of DenseQR using the 32-bit cuSolverDn interface. A
149// requirement for using this solver is that the lhs must not be rank deficient.
150// This is because cuSolverDn does not implement the singularity-checking
151// wrapper trtrs, hence this solver directly uses trsv from CUBLAS for the
152// backsubstitution.
153class CERES_NO_EXPORT CUDADenseQR final : public DenseQR {
154 public:
155 static std::unique_ptr<CUDADenseQR> Create(
156 const LinearSolver::Options& options);
157 CUDADenseQR(const CUDADenseQR&) = delete;
158 CUDADenseQR& operator=(const CUDADenseQR&) = delete;
159 LinearSolverTerminationType Factorize(int num_rows,
160 int num_cols,
161 double* lhs,
162 std::string* message) override;
163 LinearSolverTerminationType Solve(const double* rhs,
164 double* solution,
165 std::string* message) override;
166
167 private:
168 explicit CUDADenseQR(ContextImpl* context);
169
170 ContextImpl* context_ = nullptr;
171 // Number of rowns in the A matrix, to be cached between calls to *Factorize
172 // and *Solve.
173 size_t num_rows_ = 0;
174 // Number of columns in the A matrix, to be cached between calls to *Factorize
175 // and *Solve.
176 size_t num_cols_ = 0;
177 // GPU memory allocated for the A matrix (lhs matrix).
178 CudaBuffer<double> lhs_;
179 // GPU memory allocated for the B matrix (rhs vector).
180 CudaBuffer<double> rhs_;
181 // GPU memory allocated for the TAU matrix (scaling of householder vectors).
182 CudaBuffer<double> tau_;
183 // Scratch space for cuSOLVER on the GPU.
184 CudaBuffer<double> device_workspace_;
185 // Required for error handling with cuSOLVER.
186 CudaBuffer<int> error_;
187 // Cache the result of Factorize to ensure that when Solve is called, the
188 // factiorization of lhs is valid.
189 LinearSolverTerminationType factorize_result_ =
190 LinearSolverTerminationType::FATAL_ERROR;
191};
192
193#endif // CERES_NO_CUDA
194
195} // namespace ceres::internal
196
197#include "ceres/internal/reenable_warnings.h"
198
199#endif // CERES_INTERNAL_DENSE_QR_H_