blob: 5a3e7e2cad19c5ea225f84fef9e9fbddf8b97fc6 [file] [log] [blame]
Austin Schuh3de38b02024-06-25 18:25:10 -07001// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2023 Google Inc. All rights reserved.
3// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: sameeragarwal@google.com (Sameer Agarwal)
30
31#include "ceres/dense_cholesky.h"
32
33#include <algorithm>
34#include <memory>
35#include <string>
36#include <utility>
37#include <vector>
38
39#include "ceres/internal/config.h"
40#include "ceres/iterative_refiner.h"
41
42#ifndef CERES_NO_CUDA
43#include "ceres/context_impl.h"
44#include "ceres/cuda_kernels_vector_ops.h"
45#include "cuda_runtime.h"
46#include "cusolverDn.h"
47#endif // CERES_NO_CUDA
48
49#ifndef CERES_NO_LAPACK
50
51// C interface to the LAPACK Cholesky factorization and triangular solve.
52extern "C" void dpotrf_(
53 const char* uplo, const int* n, double* a, const int* lda, int* info);
54
55extern "C" void dpotrs_(const char* uplo,
56 const int* n,
57 const int* nrhs,
58 const double* a,
59 const int* lda,
60 double* b,
61 const int* ldb,
62 int* info);
63
64extern "C" void spotrf_(
65 const char* uplo, const int* n, float* a, const int* lda, int* info);
66
67extern "C" void spotrs_(const char* uplo,
68 const int* n,
69 const int* nrhs,
70 const float* a,
71 const int* lda,
72 float* b,
73 const int* ldb,
74 int* info);
75#endif
76
77namespace ceres::internal {
78
79DenseCholesky::~DenseCholesky() = default;
80
81std::unique_ptr<DenseCholesky> DenseCholesky::Create(
82 const LinearSolver::Options& options) {
83 std::unique_ptr<DenseCholesky> dense_cholesky;
84
85 switch (options.dense_linear_algebra_library_type) {
86 case EIGEN:
87 // Eigen mixed precision solver not yet implemented.
88 if (options.use_mixed_precision_solves) {
89 dense_cholesky = std::make_unique<FloatEigenDenseCholesky>();
90 } else {
91 dense_cholesky = std::make_unique<EigenDenseCholesky>();
92 }
93 break;
94
95 case LAPACK:
96#ifndef CERES_NO_LAPACK
97 // LAPACK mixed precision solver not yet implemented.
98 if (options.use_mixed_precision_solves) {
99 dense_cholesky = std::make_unique<FloatLAPACKDenseCholesky>();
100 } else {
101 dense_cholesky = std::make_unique<LAPACKDenseCholesky>();
102 }
103 break;
104#else
105 LOG(FATAL) << "Ceres was compiled without support for LAPACK.";
106#endif
107
108 case CUDA:
109#ifndef CERES_NO_CUDA
110 if (options.use_mixed_precision_solves) {
111 dense_cholesky = CUDADenseCholeskyMixedPrecision::Create(options);
112 } else {
113 dense_cholesky = CUDADenseCholesky::Create(options);
114 }
115 break;
116#else
117 LOG(FATAL) << "Ceres was compiled without support for CUDA.";
118#endif
119
120 default:
121 LOG(FATAL) << "Unknown dense linear algebra library type : "
122 << DenseLinearAlgebraLibraryTypeToString(
123 options.dense_linear_algebra_library_type);
124 }
125
126 if (options.max_num_refinement_iterations > 0) {
127 auto refiner = std::make_unique<DenseIterativeRefiner>(
128 options.max_num_refinement_iterations);
129 dense_cholesky = std::make_unique<RefinedDenseCholesky>(
130 std::move(dense_cholesky), std::move(refiner));
131 }
132
133 return dense_cholesky;
134}
135
136LinearSolverTerminationType DenseCholesky::FactorAndSolve(
137 int num_cols,
138 double* lhs,
139 const double* rhs,
140 double* solution,
141 std::string* message) {
142 LinearSolverTerminationType termination_type =
143 Factorize(num_cols, lhs, message);
144 if (termination_type == LinearSolverTerminationType::SUCCESS) {
145 termination_type = Solve(rhs, solution, message);
146 }
147 return termination_type;
148}
149
150LinearSolverTerminationType EigenDenseCholesky::Factorize(
151 int num_cols, double* lhs, std::string* message) {
152 Eigen::Map<Eigen::MatrixXd> m(lhs, num_cols, num_cols);
153 llt_ = std::make_unique<LLTType>(m);
154 if (llt_->info() != Eigen::Success) {
155 *message = "Eigen failure. Unable to perform dense Cholesky factorization.";
156 return LinearSolverTerminationType::FAILURE;
157 }
158
159 *message = "Success.";
160 return LinearSolverTerminationType::SUCCESS;
161}
162
163LinearSolverTerminationType EigenDenseCholesky::Solve(const double* rhs,
164 double* solution,
165 std::string* message) {
166 if (llt_->info() != Eigen::Success) {
167 *message = "Eigen failure. Unable to perform dense Cholesky factorization.";
168 return LinearSolverTerminationType::FAILURE;
169 }
170
171 VectorRef(solution, llt_->cols()) =
172 llt_->solve(ConstVectorRef(rhs, llt_->cols()));
173 *message = "Success.";
174 return LinearSolverTerminationType::SUCCESS;
175}
176
177LinearSolverTerminationType FloatEigenDenseCholesky::Factorize(
178 int num_cols, double* lhs, std::string* message) {
179 // TODO(sameeragarwal): Check if this causes a double allocation.
180 lhs_ = Eigen::Map<Eigen::MatrixXd>(lhs, num_cols, num_cols).cast<float>();
181 llt_ = std::make_unique<LLTType>(lhs_);
182 if (llt_->info() != Eigen::Success) {
183 *message = "Eigen failure. Unable to perform dense Cholesky factorization.";
184 return LinearSolverTerminationType::FAILURE;
185 }
186
187 *message = "Success.";
188 return LinearSolverTerminationType::SUCCESS;
189}
190
191LinearSolverTerminationType FloatEigenDenseCholesky::Solve(
192 const double* rhs, double* solution, std::string* message) {
193 if (llt_->info() != Eigen::Success) {
194 *message = "Eigen failure. Unable to perform dense Cholesky factorization.";
195 return LinearSolverTerminationType::FAILURE;
196 }
197
198 rhs_ = ConstVectorRef(rhs, llt_->cols()).cast<float>();
199 solution_ = llt_->solve(rhs_);
200 VectorRef(solution, llt_->cols()) = solution_.cast<double>();
201 *message = "Success.";
202 return LinearSolverTerminationType::SUCCESS;
203}
204
205#ifndef CERES_NO_LAPACK
206LinearSolverTerminationType LAPACKDenseCholesky::Factorize(
207 int num_cols, double* lhs, std::string* message) {
208 lhs_ = lhs;
209 num_cols_ = num_cols;
210
211 const char uplo = 'L';
212 int info = 0;
213 dpotrf_(&uplo, &num_cols_, lhs_, &num_cols_, &info);
214
215 if (info < 0) {
216 termination_type_ = LinearSolverTerminationType::FATAL_ERROR;
217 LOG(FATAL) << "Congratulations, you found a bug in Ceres. "
218 << "Please report it. "
219 << "LAPACK::dpotrf fatal error. "
220 << "Argument: " << -info << " is invalid.";
221 } else if (info > 0) {
222 termination_type_ = LinearSolverTerminationType::FAILURE;
223 *message = StringPrintf(
224 "LAPACK::dpotrf numerical failure. "
225 "The leading minor of order %d is not positive definite.",
226 info);
227 } else {
228 termination_type_ = LinearSolverTerminationType::SUCCESS;
229 *message = "Success.";
230 }
231 return termination_type_;
232}
233
234LinearSolverTerminationType LAPACKDenseCholesky::Solve(const double* rhs,
235 double* solution,
236 std::string* message) {
237 const char uplo = 'L';
238 const int nrhs = 1;
239 int info = 0;
240
241 VectorRef(solution, num_cols_) = ConstVectorRef(rhs, num_cols_);
242 dpotrs_(
243 &uplo, &num_cols_, &nrhs, lhs_, &num_cols_, solution, &num_cols_, &info);
244
245 if (info < 0) {
246 termination_type_ = LinearSolverTerminationType::FATAL_ERROR;
247 LOG(FATAL) << "Congratulations, you found a bug in Ceres. "
248 << "Please report it. "
249 << "LAPACK::dpotrs fatal error. "
250 << "Argument: " << -info << " is invalid.";
251 }
252
253 *message = "Success";
254 termination_type_ = LinearSolverTerminationType::SUCCESS;
255
256 return termination_type_;
257}
258
259LinearSolverTerminationType FloatLAPACKDenseCholesky::Factorize(
260 int num_cols, double* lhs, std::string* message) {
261 num_cols_ = num_cols;
262 lhs_ = Eigen::Map<Eigen::MatrixXd>(lhs, num_cols, num_cols).cast<float>();
263
264 const char uplo = 'L';
265 int info = 0;
266 spotrf_(&uplo, &num_cols_, lhs_.data(), &num_cols_, &info);
267
268 if (info < 0) {
269 termination_type_ = LinearSolverTerminationType::FATAL_ERROR;
270 LOG(FATAL) << "Congratulations, you found a bug in Ceres. "
271 << "Please report it. "
272 << "LAPACK::spotrf fatal error. "
273 << "Argument: " << -info << " is invalid.";
274 } else if (info > 0) {
275 termination_type_ = LinearSolverTerminationType::FAILURE;
276 *message = StringPrintf(
277 "LAPACK::spotrf numerical failure. "
278 "The leading minor of order %d is not positive definite.",
279 info);
280 } else {
281 termination_type_ = LinearSolverTerminationType::SUCCESS;
282 *message = "Success.";
283 }
284 return termination_type_;
285}
286
287LinearSolverTerminationType FloatLAPACKDenseCholesky::Solve(
288 const double* rhs, double* solution, std::string* message) {
289 const char uplo = 'L';
290 const int nrhs = 1;
291 int info = 0;
292 rhs_and_solution_ = ConstVectorRef(rhs, num_cols_).cast<float>();
293 spotrs_(&uplo,
294 &num_cols_,
295 &nrhs,
296 lhs_.data(),
297 &num_cols_,
298 rhs_and_solution_.data(),
299 &num_cols_,
300 &info);
301
302 if (info < 0) {
303 termination_type_ = LinearSolverTerminationType::FATAL_ERROR;
304 LOG(FATAL) << "Congratulations, you found a bug in Ceres. "
305 << "Please report it. "
306 << "LAPACK::dpotrs fatal error. "
307 << "Argument: " << -info << " is invalid.";
308 }
309
310 *message = "Success";
311 termination_type_ = LinearSolverTerminationType::SUCCESS;
312 VectorRef(solution, num_cols_) =
313 rhs_and_solution_.head(num_cols_).cast<double>();
314 return termination_type_;
315}
316
317#endif // CERES_NO_LAPACK
318
319RefinedDenseCholesky::RefinedDenseCholesky(
320 std::unique_ptr<DenseCholesky> dense_cholesky,
321 std::unique_ptr<DenseIterativeRefiner> iterative_refiner)
322 : dense_cholesky_(std::move(dense_cholesky)),
323 iterative_refiner_(std::move(iterative_refiner)) {}
324
325RefinedDenseCholesky::~RefinedDenseCholesky() = default;
326
327LinearSolverTerminationType RefinedDenseCholesky::Factorize(
328 const int num_cols, double* lhs, std::string* message) {
329 lhs_ = lhs;
330 num_cols_ = num_cols;
331 return dense_cholesky_->Factorize(num_cols, lhs, message);
332}
333
334LinearSolverTerminationType RefinedDenseCholesky::Solve(const double* rhs,
335 double* solution,
336 std::string* message) {
337 CHECK(lhs_ != nullptr);
338 auto termination_type = dense_cholesky_->Solve(rhs, solution, message);
339 if (termination_type != LinearSolverTerminationType::SUCCESS) {
340 return termination_type;
341 }
342
343 iterative_refiner_->Refine(
344 num_cols_, lhs_, rhs, dense_cholesky_.get(), solution);
345 return LinearSolverTerminationType::SUCCESS;
346}
347
348#ifndef CERES_NO_CUDA
349
350CUDADenseCholesky::CUDADenseCholesky(ContextImpl* context)
351 : context_(context),
352 lhs_{context},
353 rhs_{context},
354 device_workspace_{context},
355 error_(context, 1) {}
356
357LinearSolverTerminationType CUDADenseCholesky::Factorize(int num_cols,
358 double* lhs,
359 std::string* message) {
360 factorize_result_ = LinearSolverTerminationType::FATAL_ERROR;
361 lhs_.Reserve(num_cols * num_cols);
362 num_cols_ = num_cols;
363 lhs_.CopyFromCpu(lhs, num_cols * num_cols);
364 int device_workspace_size = 0;
365 if (cusolverDnDpotrf_bufferSize(context_->cusolver_handle_,
366 CUBLAS_FILL_MODE_LOWER,
367 num_cols,
368 lhs_.data(),
369 num_cols,
370 &device_workspace_size) !=
371 CUSOLVER_STATUS_SUCCESS) {
372 *message = "cuSolverDN::cusolverDnDpotrf_bufferSize failed.";
373 return LinearSolverTerminationType::FATAL_ERROR;
374 }
375 device_workspace_.Reserve(device_workspace_size);
376 if (cusolverDnDpotrf(context_->cusolver_handle_,
377 CUBLAS_FILL_MODE_LOWER,
378 num_cols,
379 lhs_.data(),
380 num_cols,
381 reinterpret_cast<double*>(device_workspace_.data()),
382 device_workspace_.size(),
383 error_.data()) != CUSOLVER_STATUS_SUCCESS) {
384 *message = "cuSolverDN::cusolverDnDpotrf failed.";
385 return LinearSolverTerminationType::FATAL_ERROR;
386 }
387 int error = 0;
388 error_.CopyToCpu(&error, 1);
389 if (error < 0) {
390 LOG(FATAL) << "Congratulations, you found a bug in Ceres - "
391 << "please report it. "
392 << "cuSolverDN::cusolverDnXpotrf fatal error. "
393 << "Argument: " << -error << " is invalid.";
394 // The following line is unreachable, but return failure just to be
395 // pedantic, since the compiler does not know that.
396 return LinearSolverTerminationType::FATAL_ERROR;
397 } else if (error > 0) {
398 *message = StringPrintf(
399 "cuSolverDN::cusolverDnDpotrf numerical failure. "
400 "The leading minor of order %d is not positive definite.",
401 error);
402 factorize_result_ = LinearSolverTerminationType::FAILURE;
403 return LinearSolverTerminationType::FAILURE;
404 }
405 *message = "Success";
406 factorize_result_ = LinearSolverTerminationType::SUCCESS;
407 return LinearSolverTerminationType::SUCCESS;
408}
409
410LinearSolverTerminationType CUDADenseCholesky::Solve(const double* rhs,
411 double* solution,
412 std::string* message) {
413 if (factorize_result_ != LinearSolverTerminationType::SUCCESS) {
414 *message = "Factorize did not complete successfully previously.";
415 return factorize_result_;
416 }
417 rhs_.CopyFromCpu(rhs, num_cols_);
418 if (cusolverDnDpotrs(context_->cusolver_handle_,
419 CUBLAS_FILL_MODE_LOWER,
420 num_cols_,
421 1,
422 lhs_.data(),
423 num_cols_,
424 rhs_.data(),
425 num_cols_,
426 error_.data()) != CUSOLVER_STATUS_SUCCESS) {
427 *message = "cuSolverDN::cusolverDnDpotrs failed.";
428 return LinearSolverTerminationType::FATAL_ERROR;
429 }
430 int error = 0;
431 error_.CopyToCpu(&error, 1);
432 if (error != 0) {
433 LOG(FATAL) << "Congratulations, you found a bug in Ceres. "
434 << "Please report it."
435 << "cuSolverDN::cusolverDnDpotrs fatal error. "
436 << "Argument: " << -error << " is invalid.";
437 }
438 rhs_.CopyToCpu(solution, num_cols_);
439 *message = "Success";
440 return LinearSolverTerminationType::SUCCESS;
441}
442
443std::unique_ptr<CUDADenseCholesky> CUDADenseCholesky::Create(
444 const LinearSolver::Options& options) {
445 if (options.dense_linear_algebra_library_type != CUDA ||
446 options.context == nullptr || !options.context->IsCudaInitialized()) {
447 return nullptr;
448 }
449 return std::unique_ptr<CUDADenseCholesky>(
450 new CUDADenseCholesky(options.context));
451}
452
453std::unique_ptr<CUDADenseCholeskyMixedPrecision>
454CUDADenseCholeskyMixedPrecision::Create(const LinearSolver::Options& options) {
455 if (options.dense_linear_algebra_library_type != CUDA ||
456 !options.use_mixed_precision_solves || options.context == nullptr ||
457 !options.context->IsCudaInitialized()) {
458 return nullptr;
459 }
460 return std::unique_ptr<CUDADenseCholeskyMixedPrecision>(
461 new CUDADenseCholeskyMixedPrecision(
462 options.context, options.max_num_refinement_iterations));
463}
464
465LinearSolverTerminationType
466CUDADenseCholeskyMixedPrecision::CudaCholeskyFactorize(std::string* message) {
467 int device_workspace_size = 0;
468 if (cusolverDnSpotrf_bufferSize(context_->cusolver_handle_,
469 CUBLAS_FILL_MODE_LOWER,
470 num_cols_,
471 lhs_fp32_.data(),
472 num_cols_,
473 &device_workspace_size) !=
474 CUSOLVER_STATUS_SUCCESS) {
475 *message = "cuSolverDN::cusolverDnSpotrf_bufferSize failed.";
476 return LinearSolverTerminationType::FATAL_ERROR;
477 }
478 device_workspace_.Reserve(device_workspace_size);
479 if (cusolverDnSpotrf(context_->cusolver_handle_,
480 CUBLAS_FILL_MODE_LOWER,
481 num_cols_,
482 lhs_fp32_.data(),
483 num_cols_,
484 device_workspace_.data(),
485 device_workspace_.size(),
486 error_.data()) != CUSOLVER_STATUS_SUCCESS) {
487 *message = "cuSolverDN::cusolverDnSpotrf failed.";
488 return LinearSolverTerminationType::FATAL_ERROR;
489 }
490 int error = 0;
491 error_.CopyToCpu(&error, 1);
492 if (error < 0) {
493 LOG(FATAL) << "Congratulations, you found a bug in Ceres - "
494 << "please report it. "
495 << "cuSolverDN::cusolverDnSpotrf fatal error. "
496 << "Argument: " << -error << " is invalid.";
497 // The following line is unreachable, but return failure just to be
498 // pedantic, since the compiler does not know that.
499 return LinearSolverTerminationType::FATAL_ERROR;
500 }
501 if (error > 0) {
502 *message = StringPrintf(
503 "cuSolverDN::cusolverDnSpotrf numerical failure. "
504 "The leading minor of order %d is not positive definite.",
505 error);
506 factorize_result_ = LinearSolverTerminationType::FAILURE;
507 return LinearSolverTerminationType::FAILURE;
508 }
509 *message = "Success";
510 return LinearSolverTerminationType::SUCCESS;
511}
512
513LinearSolverTerminationType CUDADenseCholeskyMixedPrecision::CudaCholeskySolve(
514 std::string* message) {
515 CHECK_EQ(cudaMemcpyAsync(correction_fp32_.data(),
516 residual_fp32_.data(),
517 num_cols_ * sizeof(float),
518 cudaMemcpyDeviceToDevice,
519 context_->DefaultStream()),
520 cudaSuccess);
521 if (cusolverDnSpotrs(context_->cusolver_handle_,
522 CUBLAS_FILL_MODE_LOWER,
523 num_cols_,
524 1,
525 lhs_fp32_.data(),
526 num_cols_,
527 correction_fp32_.data(),
528 num_cols_,
529 error_.data()) != CUSOLVER_STATUS_SUCCESS) {
530 *message = "cuSolverDN::cusolverDnDpotrs failed.";
531 return LinearSolverTerminationType::FATAL_ERROR;
532 }
533 int error = 0;
534 error_.CopyToCpu(&error, 1);
535 if (error != 0) {
536 LOG(FATAL) << "Congratulations, you found a bug in Ceres. "
537 << "Please report it."
538 << "cuSolverDN::cusolverDnDpotrs fatal error. "
539 << "Argument: " << -error << " is invalid.";
540 }
541 *message = "Success";
542 return LinearSolverTerminationType::SUCCESS;
543}
544
545CUDADenseCholeskyMixedPrecision::CUDADenseCholeskyMixedPrecision(
546 ContextImpl* context, int max_num_refinement_iterations)
547 : context_(context),
548 lhs_fp64_{context},
549 rhs_fp64_{context},
550 lhs_fp32_{context},
551 device_workspace_{context},
552 error_(context, 1),
553 x_fp64_{context},
554 correction_fp32_{context},
555 residual_fp32_{context},
556 residual_fp64_{context},
557 max_num_refinement_iterations_(max_num_refinement_iterations) {}
558
559LinearSolverTerminationType CUDADenseCholeskyMixedPrecision::Factorize(
560 int num_cols, double* lhs, std::string* message) {
561 num_cols_ = num_cols;
562
563 // Copy fp64 version of lhs to GPU.
564 lhs_fp64_.Reserve(num_cols * num_cols);
565 lhs_fp64_.CopyFromCpu(lhs, num_cols * num_cols);
566
567 // Create an fp32 copy of lhs, lhs_fp32.
568 lhs_fp32_.Reserve(num_cols * num_cols);
569 CudaFP64ToFP32(lhs_fp64_.data(),
570 lhs_fp32_.data(),
571 num_cols * num_cols,
572 context_->DefaultStream());
573
574 // Factorize lhs_fp32.
575 factorize_result_ = CudaCholeskyFactorize(message);
576 return factorize_result_;
577}
578
579LinearSolverTerminationType CUDADenseCholeskyMixedPrecision::Solve(
580 const double* rhs, double* solution, std::string* message) {
581 // If factorization failed, return failure.
582 if (factorize_result_ != LinearSolverTerminationType::SUCCESS) {
583 *message = "Factorize did not complete successfully previously.";
584 return factorize_result_;
585 }
586
587 // Reserve memory for all arrays.
588 rhs_fp64_.Reserve(num_cols_);
589 x_fp64_.Reserve(num_cols_);
590 correction_fp32_.Reserve(num_cols_);
591 residual_fp32_.Reserve(num_cols_);
592 residual_fp64_.Reserve(num_cols_);
593
594 // Initialize x = 0.
595 CudaSetZeroFP64(x_fp64_.data(), num_cols_, context_->DefaultStream());
596
597 // Initialize residual = rhs.
598 rhs_fp64_.CopyFromCpu(rhs, num_cols_);
599 residual_fp64_.CopyFromGPUArray(rhs_fp64_.data(), num_cols_);
600
601 for (int i = 0; i <= max_num_refinement_iterations_; ++i) {
602 // Cast residual from fp64 to fp32.
603 CudaFP64ToFP32(residual_fp64_.data(),
604 residual_fp32_.data(),
605 num_cols_,
606 context_->DefaultStream());
607 // [fp32] c = lhs^-1 * residual.
608 auto result = CudaCholeskySolve(message);
609 if (result != LinearSolverTerminationType::SUCCESS) {
610 return result;
611 }
612 // [fp64] x += c.
613 CudaDsxpy(x_fp64_.data(),
614 correction_fp32_.data(),
615 num_cols_,
616 context_->DefaultStream());
617 if (i < max_num_refinement_iterations_) {
618 // [fp64] residual = rhs - lhs * x
619 // This is done in two steps:
620 // 1. [fp64] residual = rhs
621 residual_fp64_.CopyFromGPUArray(rhs_fp64_.data(), num_cols_);
622 // 2. [fp64] residual = residual - lhs * x
623 double alpha = -1.0;
624 double beta = 1.0;
625 cublasDsymv(context_->cublas_handle_,
626 CUBLAS_FILL_MODE_LOWER,
627 num_cols_,
628 &alpha,
629 lhs_fp64_.data(),
630 num_cols_,
631 x_fp64_.data(),
632 1,
633 &beta,
634 residual_fp64_.data(),
635 1);
636 }
637 }
638 x_fp64_.CopyToCpu(solution, num_cols_);
639 *message = "Success.";
640 return LinearSolverTerminationType::SUCCESS;
641}
642
643#endif // CERES_NO_CUDA
644
645} // namespace ceres::internal