// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2023 Google Inc. All rights reserved.
3// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: joydeepb@cs.utexas.edu (Joydeep Biswas)
30//
31// A simple CUDA vector class.
32
33// This include must come before any #ifndef check on Ceres compile options.
34// clang-format off
35#include "ceres/internal/config.h"
36// clang-format on
37
38#include <math.h>
39
40#include "ceres/context_impl.h"
41#include "ceres/internal/export.h"
42#include "ceres/types.h"
43
44#ifndef CERES_NO_CUDA
45
46#include "ceres/cuda_buffer.h"
47#include "ceres/cuda_kernels_vector_ops.h"
48#include "ceres/cuda_vector.h"
49#include "cublas_v2.h"
50
51namespace ceres::internal {
52
53CudaVector::CudaVector(ContextImpl* context, int size)
54 : context_(context), data_(context, size) {
55 DCHECK_NE(context, nullptr);
56 DCHECK(context->IsCudaInitialized());
57 Resize(size);
58}
59
60CudaVector::CudaVector(CudaVector&& other)
61 : num_rows_(other.num_rows_),
62 context_(other.context_),
63 data_(std::move(other.data_)),
64 descr_(other.descr_) {
65 other.num_rows_ = 0;
66 other.descr_ = nullptr;
67}
68
69CudaVector& CudaVector::operator=(const CudaVector& other) {
70 if (this != &other) {
71 Resize(other.num_rows());
72 data_.CopyFromGPUArray(other.data_.data(), num_rows_);
73 }
74 return *this;
75}
76
77void CudaVector::DestroyDescriptor() {
78 if (descr_ != nullptr) {
79 CHECK_EQ(cusparseDestroyDnVec(descr_), CUSPARSE_STATUS_SUCCESS);
80 descr_ = nullptr;
81 }
82}
83
84CudaVector::~CudaVector() { DestroyDescriptor(); }
85
86void CudaVector::Resize(int size) {
87 data_.Reserve(size);
88 num_rows_ = size;
89 DestroyDescriptor();
90 CHECK_EQ(cusparseCreateDnVec(&descr_, num_rows_, data_.data(), CUDA_R_64F),
91 CUSPARSE_STATUS_SUCCESS);
92}
93
94double CudaVector::Dot(const CudaVector& x) const {
95 double result = 0;
96 CHECK_EQ(cublasDdot(context_->cublas_handle_,
97 num_rows_,
98 data_.data(),
99 1,
100 x.data(),
101 1,
102 &result),
103 CUBLAS_STATUS_SUCCESS)
104 << "CuBLAS cublasDdot failed.";
105 return result;
106}
107
108double CudaVector::Norm() const {
109 double result = 0;
110 CHECK_EQ(cublasDnrm2(
111 context_->cublas_handle_, num_rows_, data_.data(), 1, &result),
112 CUBLAS_STATUS_SUCCESS)
113 << "CuBLAS cublasDnrm2 failed.";
114 return result;
115}
116
117void CudaVector::CopyFromCpu(const double* x) {
118 data_.CopyFromCpu(x, num_rows_);
119}
120
121void CudaVector::CopyFromCpu(const Vector& x) {
122 if (x.rows() != num_rows_) {
123 Resize(x.rows());
124 }
125 CopyFromCpu(x.data());
126}
127
128void CudaVector::CopyTo(Vector* x) const {
129 CHECK(x != nullptr);
130 x->resize(num_rows_);
131 data_.CopyToCpu(x->data(), num_rows_);
132}
133
134void CudaVector::CopyTo(double* x) const {
135 CHECK(x != nullptr);
136 data_.CopyToCpu(x, num_rows_);
137}
138
139void CudaVector::SetZero() {
140 // Allow empty vector to be zeroed
141 if (num_rows_ == 0) return;
142 CHECK(data_.data() != nullptr);
143 CudaSetZeroFP64(data_.data(), num_rows_, context_->DefaultStream());
144}
145
146void CudaVector::Axpby(double a, const CudaVector& x, double b) {
147 if (&x == this) {
148 Scale(a + b);
149 return;
150 }
151 CHECK_EQ(num_rows_, x.num_rows_);
152 if (b != 1.0) {
153 // First scale y by b.
154 CHECK_EQ(
155 cublasDscal(context_->cublas_handle_, num_rows_, &b, data_.data(), 1),
156 CUBLAS_STATUS_SUCCESS)
157 << "CuBLAS cublasDscal failed.";
158 }
159 // Then add a * x to y.
160 CHECK_EQ(cublasDaxpy(context_->cublas_handle_,
161 num_rows_,
162 &a,
163 x.data(),
164 1,
165 data_.data(),
166 1),
167 CUBLAS_STATUS_SUCCESS)
168 << "CuBLAS cublasDaxpy failed.";
169}
170
171void CudaVector::DtDxpy(const CudaVector& D, const CudaVector& x) {
172 CudaDtDxpy(
173 data_.data(), D.data(), x.data(), num_rows_, context_->DefaultStream());
174}
175
176void CudaVector::Scale(double s) {
177 CHECK_EQ(
178 cublasDscal(context_->cublas_handle_, num_rows_, &s, data_.data(), 1),
179 CUBLAS_STATUS_SUCCESS)
180 << "CuBLAS cublasDscal failed.";
181}
182
183} // namespace ceres::internal
184
185#endif // CERES_NO_CUDA