blob: fbbcadc87dc935daa52e69e0c726fcca99897687 [file] [log] [blame]
Austin Schuh3de38b02024-06-25 18:25:10 -07001// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2023 Google Inc. All rights reserved.
3// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: sameeragarwal@google.com (Sameer Agarwal)
30
31#include "ceres/dense_qr.h"
32
33#include <algorithm>
34#include <memory>
35#include <string>
36
37#ifndef CERES_NO_CUDA
38#include "ceres/context_impl.h"
39#include "cublas_v2.h"
40#include "cusolverDn.h"
41#endif // CERES_NO_CUDA
42
43#ifndef CERES_NO_LAPACK
44
45// LAPACK routines for solving a linear least squares problem using QR
46// factorization. This is done in three stages:
47//
48// A * x = b
49// Q * R * x = b (dgeqrf)
50// R * x = Q' * b (dormqr)
51// x = R^{-1} * Q'* b (dtrtrs)
52
53// clang-format off
54
// Compute the QR factorization of a.
//
// a is an m x n column major matrix (denoted by "A" in the above description).
// On exit it holds R in its upper triangle and the elementary reflectors
// that define Q below it.
// lda is the leading dimension of a. lda >= max(1, m).
// tau is an array of size min(m, n). On exit it contains the scalar factors
// of the elementary reflectors.
// work is an array of size max(1, lwork). On exit, if info = 0, work[0]
// contains the optimal size of work.
//
// If lwork >= 1, it is the size of work. If lwork = -1, a workspace query is
// assumed: dgeqrf only computes the optimal size of the work array and
// returns it in work[0].
//
// info = 0, successful exit.
// info < 0, if info = -i, then the i^th argument had an illegal value.
extern "C" void dgeqrf_(const int* m, const int* n, double* a, const int* lda,
                        double* tau, double* work, const int* lwork, int* info);
71
// Apply Q or Q' to b, where Q is the orthogonal factor produced by dgeqrf.
// b is overwritten with the result.
//
// b is an m x n column major matrix.
// side = 'L' applies Q or Q' on the left, side = 'R' applies Q or Q' on the
// right.
// trans = 'N' applies Q, trans = 'T' applies Q'.
// k is the number of elementary reflectors whose product defines the matrix Q.
// If side = 'L', m >= k >= 0; if side = 'R', n >= k >= 0.
// a is an lda x k column major matrix containing the reflectors as returned
// by dgeqrf.
// tau contains the scalar factors of the elementary reflectors, as returned
// by dgeqrf.
// ldb is the leading dimension of b.
// work is an array of size max(1, lwork).
// lwork, if positive, is the size of work. If lwork = -1, a workspace query
// is assumed.
//
// info = 0, successful exit.
// info < 0, if info = -i, then the i^th argument had an illegal value.
extern "C" void dormqr_(const char* side, const char* trans, const int* m,
                        const int* n, const int* k, double* a, const int* lda,
                        double* tau, double* b, const int* ldb, double* work,
                        const int* lwork, int* info);
91
// Solve a triangular system of the form A * x = b. b is overwritten with x.
//
// uplo = 'U', A is upper triangular; uplo = 'L', A is lower triangular.
// trans = 'N', 'T', 'C' specifies the form of the system: A, A^T, A^H.
// diag = 'N', A is not unit triangular; diag = 'U', A is unit triangular.
// n is the order of the matrix A.
// nrhs is the number of columns of b.
// a is an lda x n column major matrix.
// b is an ldb x nrhs column major matrix.
//
// info = 0, successful.
//      = -i < 0, the i^th argument has an illegal value.
//      = i > 0, the i^th diagonal element of A is zero.
extern "C" void dtrtrs_(const char* uplo, const char* trans, const char* diag,
                        const int* n, const int* nrhs, double* a, const int* lda,
                        double* b, const int* ldb, int* info);
108// clang-format on
109
110#endif
111
112namespace ceres::internal {
113
114DenseQR::~DenseQR() = default;
115
116std::unique_ptr<DenseQR> DenseQR::Create(const LinearSolver::Options& options) {
117 std::unique_ptr<DenseQR> dense_qr;
118
119 switch (options.dense_linear_algebra_library_type) {
120 case EIGEN:
121 dense_qr = std::make_unique<EigenDenseQR>();
122 break;
123
124 case LAPACK:
125#ifndef CERES_NO_LAPACK
126 dense_qr = std::make_unique<LAPACKDenseQR>();
127 break;
128#else
129 LOG(FATAL) << "Ceres was compiled without support for LAPACK.";
130#endif
131
132 case CUDA:
133#ifndef CERES_NO_CUDA
134 dense_qr = CUDADenseQR::Create(options);
135 break;
136#else
137 LOG(FATAL) << "Ceres was compiled without support for CUDA.";
138#endif
139
140 default:
141 LOG(FATAL) << "Unknown dense linear algebra library type : "
142 << DenseLinearAlgebraLibraryTypeToString(
143 options.dense_linear_algebra_library_type);
144 }
145 return dense_qr;
146}
147
148LinearSolverTerminationType DenseQR::FactorAndSolve(int num_rows,
149 int num_cols,
150 double* lhs,
151 const double* rhs,
152 double* solution,
153 std::string* message) {
154 LinearSolverTerminationType termination_type =
155 Factorize(num_rows, num_cols, lhs, message);
156 if (termination_type == LinearSolverTerminationType::SUCCESS) {
157 termination_type = Solve(rhs, solution, message);
158 }
159 return termination_type;
160}
161
162LinearSolverTerminationType EigenDenseQR::Factorize(int num_rows,
163 int num_cols,
164 double* lhs,
165 std::string* message) {
166 Eigen::Map<ColMajorMatrix> m(lhs, num_rows, num_cols);
167 qr_ = std::make_unique<QRType>(m);
168 *message = "Success.";
169 return LinearSolverTerminationType::SUCCESS;
170}
171
172LinearSolverTerminationType EigenDenseQR::Solve(const double* rhs,
173 double* solution,
174 std::string* message) {
175 VectorRef(solution, qr_->cols()) =
176 qr_->solve(ConstVectorRef(rhs, qr_->rows()));
177 *message = "Success.";
178 return LinearSolverTerminationType::SUCCESS;
179}
180
181#ifndef CERES_NO_LAPACK
182LinearSolverTerminationType LAPACKDenseQR::Factorize(int num_rows,
183 int num_cols,
184 double* lhs,
185 std::string* message) {
186 int lwork = -1;
187 double work_size;
188 int info = 0;
189
190 // Compute the size of the temporary workspace needed to compute the QR
191 // factorization in the dgeqrf call below.
192 dgeqrf_(&num_rows,
193 &num_cols,
194 lhs_,
195 &num_rows,
196 tau_.data(),
197 &work_size,
198 &lwork,
199 &info);
200 if (info < 0) {
201 LOG(FATAL) << "Congratulations, you found a bug in Ceres."
202 << "Please report it."
203 << "LAPACK::dgels fatal error."
204 << "Argument: " << -info << " is invalid.";
205 }
206
207 lhs_ = lhs;
208 num_rows_ = num_rows;
209 num_cols_ = num_cols;
210
211 lwork = static_cast<int>(work_size);
212
213 if (work_.size() < lwork) {
214 work_.resize(lwork);
215 }
216 if (tau_.size() < num_cols) {
217 tau_.resize(num_cols);
218 }
219
220 if (q_transpose_rhs_.size() < num_rows) {
221 q_transpose_rhs_.resize(num_rows);
222 }
223
224 // Factorize the lhs_ using the workspace that we just constructed above.
225 dgeqrf_(&num_rows,
226 &num_cols,
227 lhs_,
228 &num_rows,
229 tau_.data(),
230 work_.data(),
231 &lwork,
232 &info);
233
234 if (info < 0) {
235 LOG(FATAL) << "Congratulations, you found a bug in Ceres."
236 << "Please report it. dgeqrf fatal error."
237 << "Argument: " << -info << " is invalid.";
238 }
239
240 termination_type_ = LinearSolverTerminationType::SUCCESS;
241 *message = "Success.";
242 return termination_type_;
243}
244
245LinearSolverTerminationType LAPACKDenseQR::Solve(const double* rhs,
246 double* solution,
247 std::string* message) {
248 if (termination_type_ != LinearSolverTerminationType::SUCCESS) {
249 *message = "QR factorization failed and solve called.";
250 return termination_type_;
251 }
252
253 std::copy_n(rhs, num_rows_, q_transpose_rhs_.data());
254
255 const char side = 'L';
256 char trans = 'T';
257 const int num_c_cols = 1;
258 const int lwork = work_.size();
259 int info = 0;
260 dormqr_(&side,
261 &trans,
262 &num_rows_,
263 &num_c_cols,
264 &num_cols_,
265 lhs_,
266 &num_rows_,
267 tau_.data(),
268 q_transpose_rhs_.data(),
269 &num_rows_,
270 work_.data(),
271 &lwork,
272 &info);
273 if (info < 0) {
274 LOG(FATAL) << "Congratulations, you found a bug in Ceres."
275 << "Please report it. dormr fatal error."
276 << "Argument: " << -info << " is invalid.";
277 }
278
279 const char uplo = 'U';
280 trans = 'N';
281 const char diag = 'N';
282 dtrtrs_(&uplo,
283 &trans,
284 &diag,
285 &num_cols_,
286 &num_c_cols,
287 lhs_,
288 &num_rows_,
289 q_transpose_rhs_.data(),
290 &num_rows_,
291 &info);
292
293 if (info < 0) {
294 LOG(FATAL) << "Congratulations, you found a bug in Ceres."
295 << "Please report it. dormr fatal error."
296 << "Argument: " << -info << " is invalid.";
297 } else if (info > 0) {
298 *message =
299 "QR factorization failure. The factorization is not full rank. R has "
300 "zeros on the diagonal.";
301 termination_type_ = LinearSolverTerminationType::FAILURE;
302 } else {
303 std::copy_n(q_transpose_rhs_.data(), num_cols_, solution);
304 termination_type_ = LinearSolverTerminationType::SUCCESS;
305 }
306
307 return termination_type_;
308}
309
310#endif // CERES_NO_LAPACK
311
312#ifndef CERES_NO_CUDA
313
314CUDADenseQR::CUDADenseQR(ContextImpl* context)
315 : context_(context),
316 lhs_{context},
317 rhs_{context},
318 tau_{context},
319 device_workspace_{context},
320 error_(context, 1) {}
321
322LinearSolverTerminationType CUDADenseQR::Factorize(int num_rows,
323 int num_cols,
324 double* lhs,
325 std::string* message) {
326 factorize_result_ = LinearSolverTerminationType::FATAL_ERROR;
327 lhs_.Reserve(num_rows * num_cols);
328 tau_.Reserve(std::min(num_rows, num_cols));
329 num_rows_ = num_rows;
330 num_cols_ = num_cols;
331 lhs_.CopyFromCpu(lhs, num_rows * num_cols);
332 int device_workspace_size = 0;
333 if (cusolverDnDgeqrf_bufferSize(context_->cusolver_handle_,
334 num_rows,
335 num_cols,
336 lhs_.data(),
337 num_rows,
338 &device_workspace_size) !=
339 CUSOLVER_STATUS_SUCCESS) {
340 *message = "cuSolverDN::cusolverDnDgeqrf_bufferSize failed.";
341 return LinearSolverTerminationType::FATAL_ERROR;
342 }
343 device_workspace_.Reserve(device_workspace_size);
344 if (cusolverDnDgeqrf(context_->cusolver_handle_,
345 num_rows,
346 num_cols,
347 lhs_.data(),
348 num_rows,
349 tau_.data(),
350 reinterpret_cast<double*>(device_workspace_.data()),
351 device_workspace_.size(),
352 error_.data()) != CUSOLVER_STATUS_SUCCESS) {
353 *message = "cuSolverDN::cusolverDnDgeqrf failed.";
354 return LinearSolverTerminationType::FATAL_ERROR;
355 }
356 int error = 0;
357 error_.CopyToCpu(&error, 1);
358 if (error < 0) {
359 LOG(FATAL) << "Congratulations, you found a bug in Ceres - "
360 << "please report it. "
361 << "cuSolverDN::cusolverDnDgeqrf fatal error. "
362 << "Argument: " << -error << " is invalid.";
363 // The following line is unreachable, but return failure just to be
364 // pedantic, since the compiler does not know that.
365 return LinearSolverTerminationType::FATAL_ERROR;
366 }
367
368 *message = "Success";
369 factorize_result_ = LinearSolverTerminationType::SUCCESS;
370 return LinearSolverTerminationType::SUCCESS;
371}
372
373LinearSolverTerminationType CUDADenseQR::Solve(const double* rhs,
374 double* solution,
375 std::string* message) {
376 if (factorize_result_ != LinearSolverTerminationType::SUCCESS) {
377 *message = "Factorize did not complete successfully previously.";
378 return factorize_result_;
379 }
380 rhs_.CopyFromCpu(rhs, num_rows_);
381 int device_workspace_size = 0;
382 if (cusolverDnDormqr_bufferSize(context_->cusolver_handle_,
383 CUBLAS_SIDE_LEFT,
384 CUBLAS_OP_T,
385 num_rows_,
386 1,
387 num_cols_,
388 lhs_.data(),
389 num_rows_,
390 tau_.data(),
391 rhs_.data(),
392 num_rows_,
393 &device_workspace_size) !=
394 CUSOLVER_STATUS_SUCCESS) {
395 *message = "cuSolverDN::cusolverDnDormqr_bufferSize failed.";
396 return LinearSolverTerminationType::FATAL_ERROR;
397 }
398 device_workspace_.Reserve(device_workspace_size);
399 // Compute rhs = Q^T * rhs, assuming that lhs has already been factorized.
400 // The result of factorization would have stored Q in a packed form in lhs_.
401 if (cusolverDnDormqr(context_->cusolver_handle_,
402 CUBLAS_SIDE_LEFT,
403 CUBLAS_OP_T,
404 num_rows_,
405 1,
406 num_cols_,
407 lhs_.data(),
408 num_rows_,
409 tau_.data(),
410 rhs_.data(),
411 num_rows_,
412 reinterpret_cast<double*>(device_workspace_.data()),
413 device_workspace_.size(),
414 error_.data()) != CUSOLVER_STATUS_SUCCESS) {
415 *message = "cuSolverDN::cusolverDnDormqr failed.";
416 return LinearSolverTerminationType::FATAL_ERROR;
417 }
418 int error = 0;
419 error_.CopyToCpu(&error, 1);
420 if (error < 0) {
421 LOG(FATAL) << "Congratulations, you found a bug in Ceres. "
422 << "Please report it."
423 << "cuSolverDN::cusolverDnDormqr fatal error. "
424 << "Argument: " << -error << " is invalid.";
425 }
426 // Compute the solution vector as x = R \ (Q^T * rhs). Since the previous step
427 // replaced rhs by (Q^T * rhs), this is just x = R \ rhs.
428 if (cublasDtrsv(context_->cublas_handle_,
429 CUBLAS_FILL_MODE_UPPER,
430 CUBLAS_OP_N,
431 CUBLAS_DIAG_NON_UNIT,
432 num_cols_,
433 lhs_.data(),
434 num_rows_,
435 rhs_.data(),
436 1) != CUBLAS_STATUS_SUCCESS) {
437 *message = "cuBLAS::cublasDtrsv failed.";
438 return LinearSolverTerminationType::FATAL_ERROR;
439 }
440 rhs_.CopyToCpu(solution, num_cols_);
441 *message = "Success";
442 return LinearSolverTerminationType::SUCCESS;
443}
444
445std::unique_ptr<CUDADenseQR> CUDADenseQR::Create(
446 const LinearSolver::Options& options) {
447 if (options.dense_linear_algebra_library_type != CUDA ||
448 options.context == nullptr || !options.context->IsCudaInitialized()) {
449 return nullptr;
450 }
451 return std::unique_ptr<CUDADenseQR>(new CUDADenseQR(options.context));
452}
453
454#endif // CERES_NO_CUDA
455
456} // namespace ceres::internal