// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2023 Google Inc. All rights reserved.
3// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: joydeepb@cs.utexas.edu (Joydeep Biswas)
30//
31// A simple CUDA vector class.
32
33#ifndef CERES_INTERNAL_CUDA_VECTOR_H_
34#define CERES_INTERNAL_CUDA_VECTOR_H_
35
36// This include must come before any #ifndef check on Ceres compile options.
37// clang-format off
38#include "ceres/internal/config.h"
39// clang-format on
40
41#include <math.h>
42
43#include <memory>
44#include <string>
45
46#include "ceres/context_impl.h"
47#include "ceres/internal/export.h"
48#include "ceres/types.h"
49
50#ifndef CERES_NO_CUDA
51
52#include "ceres/cuda_buffer.h"
53#include "ceres/cuda_kernels_vector_ops.h"
54#include "ceres/internal/eigen.h"
55#include "cublas_v2.h"
56#include "cusparse.h"
57
58namespace ceres::internal {
59
// An Nx1 column vector, denoted y, hosted on the GPU, with CUDA-accelerated
// operations (dot products, norms, axpby updates, CPU<->GPU copies).
class CERES_NO_EXPORT CudaVector {
 public:
  // Create a pre-allocated vector of size N and return a pointer to it. The
  // caller must ensure that InitCuda() has already been successfully called on
  // context before calling this method.
  CudaVector(ContextImpl* context, int size);

  // Move constructor. NOTE(review): presumably transfers ownership of the GPU
  // buffer and cuSPARSE descriptor, leaving `other` empty — confirm against
  // the implementation; not marked noexcept in this declaration.
  CudaVector(CudaVector&& other);

  ~CudaVector();

  // Resize the vector to `size` elements. NOTE(review): whether existing
  // contents survive a resize is not visible here — verify before relying
  // on it.
  void Resize(int size);

  // Perform a deep copy of the vector.
  CudaVector& operator=(const CudaVector&);

  // Return the inner product x' * y.
  double Dot(const CudaVector& x) const;

  // Return the L2 norm of the vector (||y||_2).
  double Norm() const;

  // Set all elements to zero.
  void SetZero();

  // Copy from Eigen vector.
  void CopyFromCpu(const Vector& x);

  // Copy from CPU memory array. The array must hold at least num_rows()
  // doubles; it is the caller's responsibility to ensure this.
  void CopyFromCpu(const double* x);

  // Copy to Eigen vector.
  void CopyTo(Vector* x) const;

  // Copy to CPU memory array. It is the caller's responsibility to ensure
  // that the array is large enough.
  void CopyTo(double* x) const;

  // y = a * x + b * y.
  void Axpby(double a, const CudaVector& x, double b);

  // y = diag(d)' * diag(d) * x + y.
  void DtDxpy(const CudaVector& D, const CudaVector& x);

  // y = s * y.
  void Scale(double s);

  int num_rows() const { return num_rows_; }
  int num_cols() const { return 1; }

  // Raw pointers to the underlying GPU buffer (device memory — must not be
  // dereferenced on the host).
  const double* data() const { return data_.data(); }
  double* mutable_data() { return data_.data(); }

  // cuSPARSE descriptor for this dense vector, for use with cuSPARSE APIs.
  const cusparseDnVecDescr_t& descr() const { return descr_; }

 private:
  // Copy construction is deliberately disallowed; use operator= when an
  // explicit deep copy is wanted.
  CudaVector(const CudaVector&) = delete;
  // Releases descr_ (if any); used by the destructor and resizing paths.
  void DestroyDescriptor();

  // Number of elements (rows) in the vector.
  int num_rows_ = 0;
  // Non-owning pointer to the shared CUDA context.
  ContextImpl* context_ = nullptr;
  // Device-side storage for the vector elements.
  CudaBuffer<double> data_;
  // CuSparse object that describes this dense vector.
  cusparseDnVecDescr_t descr_ = nullptr;
};
126
127// Blas1 operations on Cuda vectors. These functions are needed as an
128// abstraction layer so that we can use different versions of a vector style
129// object in the conjugate gradients linear solver.
130// Context and num_threads arguments are not used by CUDA implementation,
131// context embedded into CudaVector is used instead.
132inline double Norm(const CudaVector& x,
133 ContextImpl* context = nullptr,
134 int num_threads = 1) {
135 (void)context;
136 (void)num_threads;
137 return x.Norm();
138}
139inline void SetZero(CudaVector& x,
140 ContextImpl* context = nullptr,
141 int num_threads = 1) {
142 (void)context;
143 (void)num_threads;
144 x.SetZero();
145}
146inline void Axpby(double a,
147 const CudaVector& x,
148 double b,
149 const CudaVector& y,
150 CudaVector& z,
151 ContextImpl* context = nullptr,
152 int num_threads = 1) {
153 (void)context;
154 (void)num_threads;
155 if (&x == &y && &y == &z) {
156 // z = (a + b) * z;
157 z.Scale(a + b);
158 } else if (&x == &z) {
159 // x is aliased to z.
160 // z = x
161 // = b * y + a * x;
162 z.Axpby(b, y, a);
163 } else if (&y == &z) {
164 // y is aliased to z.
165 // z = y = a * x + b * y;
166 z.Axpby(a, x, b);
167 } else {
168 // General case: all inputs and outputs are distinct.
169 z = y;
170 z.Axpby(a, x, b);
171 }
172}
173inline double Dot(const CudaVector& x,
174 const CudaVector& y,
175 ContextImpl* context = nullptr,
176 int num_threads = 1) {
177 (void)context;
178 (void)num_threads;
179 return x.Dot(y);
180}
181inline void Copy(const CudaVector& from,
182 CudaVector& to,
183 ContextImpl* context = nullptr,
184 int num_threads = 1) {
185 (void)context;
186 (void)num_threads;
187 to = from;
188}
189
190} // namespace ceres::internal
191
192#endif // CERES_NO_CUDA
#endif  // CERES_INTERNAL_CUDA_VECTOR_H_