blob: 97922aab6c3163eedb384c85156c0a9d796342d9 [file] [log] [blame]
Austin Schuh70cc9552019-01-21 19:46:48 -08001// Ceres Solver - A fast non-linear least squares minimizer
Austin Schuh3de38b02024-06-25 18:25:10 -07002// Copyright 2023 Google Inc. All rights reserved.
Austin Schuh70cc9552019-01-21 19:46:48 -08003// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: keir@google.com (Keir Mierle)
30
31#include "ceres/small_blas.h"
32
33#include <limits>
Austin Schuh3de38b02024-06-25 18:25:10 -070034#include <string>
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080035
Austin Schuh70cc9552019-01-21 19:46:48 -080036#include "ceres/internal/eigen.h"
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080037#include "gtest/gtest.h"
Austin Schuh70cc9552019-01-21 19:46:48 -080038
39namespace ceres {
40namespace internal {
41
// Absolute tolerance for comparing small_blas results against Eigen
// reference results: five machine epsilons.
constexpr double kTolerance = 5.0 * std::numeric_limits<double>::epsilon();

// Selects which instantiation of a small_blas kernel a functor calls: the one
// with compile-time sizes (Static) or the Eigen::Dynamic one (Dynamic).
enum class DimType { Static, Dynamic };
Austin Schuh70cc9552019-01-21 19:46:48 -080046
Austin Schuh3de38b02024-06-25 18:25:10 -070047// Constructs matrix functor type.
48#define MATRIX_FUN_TY(FN) \
49 template <int kRowA, \
50 int kColA, \
51 int kRowB, \
52 int kColB, \
53 int kOperation, \
54 DimType kDimType> \
55 struct FN##Ty { \
56 void operator()(const double* A, \
57 const int num_row_a, \
58 const int num_col_a, \
59 const double* B, \
60 const int num_row_b, \
61 const int num_col_b, \
62 double* C, \
63 const int start_row_c, \
64 const int start_col_c, \
65 const int row_stride_c, \
66 const int col_stride_c) { \
67 if (kDimType == DimType::Static) { \
68 FN<kRowA, kColA, kRowB, kColB, kOperation>(A, \
69 num_row_a, \
70 num_col_a, \
71 B, \
72 num_row_b, \
73 num_col_b, \
74 C, \
75 start_row_c, \
76 start_col_c, \
77 row_stride_c, \
78 col_stride_c); \
79 } else { \
80 FN<Eigen::Dynamic, \
81 Eigen::Dynamic, \
82 Eigen::Dynamic, \
83 Eigen::Dynamic, \
84 kOperation>(A, \
85 num_row_a, \
86 num_col_a, \
87 B, \
88 num_row_b, \
89 num_col_b, \
90 C, \
91 start_row_c, \
92 start_col_c, \
93 row_stride_c, \
94 col_stride_c); \
95 } \
96 } \
97 };
Austin Schuh70cc9552019-01-21 19:46:48 -080098
Austin Schuh3de38b02024-06-25 18:25:10 -070099MATRIX_FUN_TY(MatrixMatrixMultiply)
100MATRIX_FUN_TY(MatrixMatrixMultiplyNaive)
101MATRIX_FUN_TY(MatrixTransposeMatrixMultiply)
102MATRIX_FUN_TY(MatrixTransposeMatrixMultiplyNaive)
Austin Schuh70cc9552019-01-21 19:46:48 -0800103
Austin Schuh3de38b02024-06-25 18:25:10 -0700104#undef MATRIX_FUN_TY
Austin Schuh70cc9552019-01-21 19:46:48 -0800105
Austin Schuh3de38b02024-06-25 18:25:10 -0700106// Initializes matrix entries.
107static void initMatrix(Matrix& mat) {
108 for (int i = 0; i < mat.rows(); ++i) {
109 for (int j = 0; j < mat.cols(); ++j) {
110 mat(i, j) = i + j + 1;
Austin Schuh70cc9552019-01-21 19:46:48 -0800111 }
112 }
113}
114
Austin Schuh3de38b02024-06-25 18:25:10 -0700115template <int kRowA,
116 int kColA,
117 int kColB,
118 DimType kDimType,
119 template <int, int, int, int, int, DimType>
120 class FunctorTy>
121struct TestMatrixFunctions {
122 void operator()() {
123 Matrix A(kRowA, kColA);
124 initMatrix(A);
125 const int kRowB = kColA;
126 Matrix B(kRowB, kColB);
127 initMatrix(B);
Austin Schuh70cc9552019-01-21 19:46:48 -0800128
Austin Schuh3de38b02024-06-25 18:25:10 -0700129 for (int row_stride_c = kRowA; row_stride_c < 3 * kRowA; ++row_stride_c) {
130 for (int col_stride_c = kColB; col_stride_c < 3 * kColB; ++col_stride_c) {
131 Matrix C(row_stride_c, col_stride_c);
132 C.setOnes();
Austin Schuh70cc9552019-01-21 19:46:48 -0800133
Austin Schuh3de38b02024-06-25 18:25:10 -0700134 Matrix C_plus = C;
135 Matrix C_minus = C;
136 Matrix C_assign = C;
Austin Schuh70cc9552019-01-21 19:46:48 -0800137
Austin Schuh3de38b02024-06-25 18:25:10 -0700138 Matrix C_plus_ref = C;
139 Matrix C_minus_ref = C;
140 Matrix C_assign_ref = C;
Austin Schuh70cc9552019-01-21 19:46:48 -0800141
Austin Schuh3de38b02024-06-25 18:25:10 -0700142 for (int start_row_c = 0; start_row_c + kRowA < row_stride_c;
143 ++start_row_c) {
144 for (int start_col_c = 0; start_col_c + kColB < col_stride_c;
145 ++start_col_c) {
146 C_plus_ref.block(start_row_c, start_col_c, kRowA, kColB) += A * B;
147 FunctorTy<kRowA, kColA, kRowB, kColB, 1, kDimType>()(A.data(),
148 kRowA,
149 kColA,
150 B.data(),
151 kRowB,
152 kColB,
153 C_plus.data(),
154 start_row_c,
155 start_col_c,
156 row_stride_c,
157 col_stride_c);
Austin Schuh70cc9552019-01-21 19:46:48 -0800158
Austin Schuh3de38b02024-06-25 18:25:10 -0700159 EXPECT_NEAR((C_plus_ref - C_plus).norm(), 0.0, kTolerance)
160 << "C += A * B \n"
161 << "row_stride_c : " << row_stride_c << "\n"
162 << "col_stride_c : " << col_stride_c << "\n"
163 << "start_row_c : " << start_row_c << "\n"
164 << "start_col_c : " << start_col_c << "\n"
165 << "Cref : \n"
166 << C_plus_ref << "\n"
167 << "C: \n"
168 << C_plus;
Austin Schuh70cc9552019-01-21 19:46:48 -0800169
Austin Schuh3de38b02024-06-25 18:25:10 -0700170 C_minus_ref.block(start_row_c, start_col_c, kRowA, kColB) -= A * B;
171 FunctorTy<kRowA, kColA, kRowB, kColB, -1, kDimType>()(
172 A.data(),
173 kRowA,
174 kColA,
175 B.data(),
176 kRowB,
177 kColB,
178 C_minus.data(),
179 start_row_c,
180 start_col_c,
181 row_stride_c,
182 col_stride_c);
Austin Schuh70cc9552019-01-21 19:46:48 -0800183
Austin Schuh3de38b02024-06-25 18:25:10 -0700184 EXPECT_NEAR((C_minus_ref - C_minus).norm(), 0.0, kTolerance)
185 << "C -= A * B \n"
186 << "row_stride_c : " << row_stride_c << "\n"
187 << "col_stride_c : " << col_stride_c << "\n"
188 << "start_row_c : " << start_row_c << "\n"
189 << "start_col_c : " << start_col_c << "\n"
190 << "Cref : \n"
191 << C_minus_ref << "\n"
192 << "C: \n"
193 << C_minus;
Austin Schuh70cc9552019-01-21 19:46:48 -0800194
Austin Schuh3de38b02024-06-25 18:25:10 -0700195 C_assign_ref.block(start_row_c, start_col_c, kRowA, kColB) = A * B;
Austin Schuh70cc9552019-01-21 19:46:48 -0800196
Austin Schuh3de38b02024-06-25 18:25:10 -0700197 FunctorTy<kRowA, kColA, kRowB, kColB, 0, kDimType>()(
198 A.data(),
199 kRowA,
200 kColA,
201 B.data(),
202 kRowB,
203 kColB,
204 C_assign.data(),
205 start_row_c,
206 start_col_c,
207 row_stride_c,
208 col_stride_c);
Austin Schuh70cc9552019-01-21 19:46:48 -0800209
Austin Schuh3de38b02024-06-25 18:25:10 -0700210 EXPECT_NEAR((C_assign_ref - C_assign).norm(), 0.0, kTolerance)
211 << "C = A * B \n"
212 << "row_stride_c : " << row_stride_c << "\n"
213 << "col_stride_c : " << col_stride_c << "\n"
214 << "start_row_c : " << start_row_c << "\n"
215 << "start_col_c : " << start_col_c << "\n"
216 << "Cref : \n"
217 << C_assign_ref << "\n"
218 << "C: \n"
219 << C_assign;
220 }
Austin Schuh70cc9552019-01-21 19:46:48 -0800221 }
222 }
223 }
224 }
Austin Schuh3de38b02024-06-25 18:25:10 -0700225};
226
227template <int kRowA,
228 int kColA,
229 int kColB,
230 DimType kDimType,
231 template <int, int, int, int, int, DimType>
232 class FunctorTy>
233struct TestMatrixTransposeFunctions {
234 void operator()() {
235 Matrix A(kRowA, kColA);
236 initMatrix(A);
237 const int kRowB = kRowA;
238 Matrix B(kRowB, kColB);
239 initMatrix(B);
240
241 for (int row_stride_c = kColA; row_stride_c < 3 * kColA; ++row_stride_c) {
242 for (int col_stride_c = kColB; col_stride_c < 3 * kColB; ++col_stride_c) {
243 Matrix C(row_stride_c, col_stride_c);
244 C.setOnes();
245
246 Matrix C_plus = C;
247 Matrix C_minus = C;
248 Matrix C_assign = C;
249
250 Matrix C_plus_ref = C;
251 Matrix C_minus_ref = C;
252 Matrix C_assign_ref = C;
253 for (int start_row_c = 0; start_row_c + kColA < row_stride_c;
254 ++start_row_c) {
255 for (int start_col_c = 0; start_col_c + kColB < col_stride_c;
256 ++start_col_c) {
257 C_plus_ref.block(start_row_c, start_col_c, kColA, kColB) +=
258 A.transpose() * B;
259
260 FunctorTy<kRowA, kColA, kRowB, kColB, 1, kDimType>()(A.data(),
261 kRowA,
262 kColA,
263 B.data(),
264 kRowB,
265 kColB,
266 C_plus.data(),
267 start_row_c,
268 start_col_c,
269 row_stride_c,
270 col_stride_c);
271
272 EXPECT_NEAR((C_plus_ref - C_plus).norm(), 0.0, kTolerance)
273 << "C += A' * B \n"
274 << "row_stride_c : " << row_stride_c << "\n"
275 << "col_stride_c : " << col_stride_c << "\n"
276 << "start_row_c : " << start_row_c << "\n"
277 << "start_col_c : " << start_col_c << "\n"
278 << "Cref : \n"
279 << C_plus_ref << "\n"
280 << "C: \n"
281 << C_plus;
282
283 C_minus_ref.block(start_row_c, start_col_c, kColA, kColB) -=
284 A.transpose() * B;
285
286 FunctorTy<kRowA, kColA, kRowB, kColB, -1, kDimType>()(
287 A.data(),
288 kRowA,
289 kColA,
290 B.data(),
291 kRowB,
292 kColB,
293 C_minus.data(),
294 start_row_c,
295 start_col_c,
296 row_stride_c,
297 col_stride_c);
298
299 EXPECT_NEAR((C_minus_ref - C_minus).norm(), 0.0, kTolerance)
300 << "C -= A' * B \n"
301 << "row_stride_c : " << row_stride_c << "\n"
302 << "col_stride_c : " << col_stride_c << "\n"
303 << "start_row_c : " << start_row_c << "\n"
304 << "start_col_c : " << start_col_c << "\n"
305 << "Cref : \n"
306 << C_minus_ref << "\n"
307 << "C: \n"
308 << C_minus;
309
310 C_assign_ref.block(start_row_c, start_col_c, kColA, kColB) =
311 A.transpose() * B;
312
313 FunctorTy<kRowA, kColA, kRowB, kColB, 0, kDimType>()(
314 A.data(),
315 kRowA,
316 kColA,
317 B.data(),
318 kRowB,
319 kColB,
320 C_assign.data(),
321 start_row_c,
322 start_col_c,
323 row_stride_c,
324 col_stride_c);
325
326 EXPECT_NEAR((C_assign_ref - C_assign).norm(), 0.0, kTolerance)
327 << "C = A' * B \n"
328 << "row_stride_c : " << row_stride_c << "\n"
329 << "col_stride_c : " << col_stride_c << "\n"
330 << "start_row_c : " << start_row_c << "\n"
331 << "start_col_c : " << start_col_c << "\n"
332 << "Cref : \n"
333 << C_assign_ref << "\n"
334 << "C: \n"
335 << C_assign;
336 }
337 }
338 }
339 }
340 }
341};
342
343TEST(BLAS, MatrixMatrixMultiply_5_3_7) {
344 TestMatrixFunctions<5, 3, 7, DimType::Static, MatrixMatrixMultiplyTy>()();
Austin Schuh70cc9552019-01-21 19:46:48 -0800345}
346
Austin Schuh3de38b02024-06-25 18:25:10 -0700347TEST(BLAS, MatrixMatrixMultiply_5_3_7_Dynamic) {
348 TestMatrixFunctions<5, 3, 7, DimType::Dynamic, MatrixMatrixMultiplyTy>()();
Austin Schuh70cc9552019-01-21 19:46:48 -0800349}
350
Austin Schuh3de38b02024-06-25 18:25:10 -0700351TEST(BLAS, MatrixMatrixMultiply_1_1_1) {
352 TestMatrixFunctions<1, 1, 1, DimType::Static, MatrixMatrixMultiplyTy>()();
353}
Austin Schuh70cc9552019-01-21 19:46:48 -0800354
Austin Schuh3de38b02024-06-25 18:25:10 -0700355TEST(BLAS, MatrixMatrixMultiply_1_1_1_Dynamic) {
356 TestMatrixFunctions<1, 1, 1, DimType::Dynamic, MatrixMatrixMultiplyTy>()();
357}
Austin Schuh70cc9552019-01-21 19:46:48 -0800358
Austin Schuh3de38b02024-06-25 18:25:10 -0700359TEST(BLAS, MatrixMatrixMultiply_9_9_9) {
360 TestMatrixFunctions<9, 9, 9, DimType::Static, MatrixMatrixMultiplyTy>()();
361}
Austin Schuh70cc9552019-01-21 19:46:48 -0800362
Austin Schuh3de38b02024-06-25 18:25:10 -0700363TEST(BLAS, MatrixMatrixMultiply_9_9_9_Dynamic) {
364 TestMatrixFunctions<9, 9, 9, DimType::Dynamic, MatrixMatrixMultiplyTy>()();
365}
Austin Schuh70cc9552019-01-21 19:46:48 -0800366
Austin Schuh3de38b02024-06-25 18:25:10 -0700367TEST(BLAS, MatrixMatrixMultiplyNaive_5_3_7) {
368 TestMatrixFunctions<5,
369 3,
370 7,
371 DimType::Static,
372 MatrixMatrixMultiplyNaiveTy>()();
373}
Austin Schuh70cc9552019-01-21 19:46:48 -0800374
Austin Schuh3de38b02024-06-25 18:25:10 -0700375TEST(BLAS, MatrixMatrixMultiplyNaive_5_3_7_Dynamic) {
376 TestMatrixFunctions<5,
377 3,
378 7,
379 DimType::Dynamic,
380 MatrixMatrixMultiplyNaiveTy>()();
381}
Austin Schuh70cc9552019-01-21 19:46:48 -0800382
Austin Schuh3de38b02024-06-25 18:25:10 -0700383TEST(BLAS, MatrixMatrixMultiplyNaive_1_1_1) {
384 TestMatrixFunctions<1,
385 1,
386 1,
387 DimType::Static,
388 MatrixMatrixMultiplyNaiveTy>()();
389}
Austin Schuh70cc9552019-01-21 19:46:48 -0800390
Austin Schuh3de38b02024-06-25 18:25:10 -0700391TEST(BLAS, MatrixMatrixMultiplyNaive_1_1_1_Dynamic) {
392 TestMatrixFunctions<1,
393 1,
394 1,
395 DimType::Dynamic,
396 MatrixMatrixMultiplyNaiveTy>()();
397}
Austin Schuh70cc9552019-01-21 19:46:48 -0800398
Austin Schuh3de38b02024-06-25 18:25:10 -0700399TEST(BLAS, MatrixMatrixMultiplyNaive_9_9_9) {
400 TestMatrixFunctions<9,
401 9,
402 9,
403 DimType::Static,
404 MatrixMatrixMultiplyNaiveTy>()();
405}
Austin Schuh70cc9552019-01-21 19:46:48 -0800406
Austin Schuh3de38b02024-06-25 18:25:10 -0700407TEST(BLAS, MatrixMatrixMultiplyNaive_9_9_9_Dynamic) {
408 TestMatrixFunctions<9,
409 9,
410 9,
411 DimType::Dynamic,
412 MatrixMatrixMultiplyNaiveTy>()();
413}
Austin Schuh70cc9552019-01-21 19:46:48 -0800414
Austin Schuh3de38b02024-06-25 18:25:10 -0700415TEST(BLAS, MatrixTransposeMatrixMultiply_5_3_7) {
416 TestMatrixTransposeFunctions<5,
417 3,
418 7,
419 DimType::Static,
420 MatrixTransposeMatrixMultiplyTy>()();
421}
Austin Schuh70cc9552019-01-21 19:46:48 -0800422
Austin Schuh3de38b02024-06-25 18:25:10 -0700423TEST(BLAS, MatrixTransposeMatrixMultiply_5_3_7_Dynamic) {
424 TestMatrixTransposeFunctions<5,
425 3,
426 7,
427 DimType::Dynamic,
428 MatrixTransposeMatrixMultiplyTy>()();
429}
Austin Schuh70cc9552019-01-21 19:46:48 -0800430
Austin Schuh3de38b02024-06-25 18:25:10 -0700431TEST(BLAS, MatrixTransposeMatrixMultiply_1_1_1) {
432 TestMatrixTransposeFunctions<1,
433 1,
434 1,
435 DimType::Static,
436 MatrixTransposeMatrixMultiplyTy>()();
437}
438
439TEST(BLAS, MatrixTransposeMatrixMultiply_1_1_1_Dynamic) {
440 TestMatrixTransposeFunctions<1,
441 1,
442 1,
443 DimType::Dynamic,
444 MatrixTransposeMatrixMultiplyTy>()();
445}
446
447TEST(BLAS, MatrixTransposeMatrixMultiply_9_9_9) {
448 TestMatrixTransposeFunctions<9,
449 9,
450 9,
451 DimType::Static,
452 MatrixTransposeMatrixMultiplyTy>()();
453}
454
455TEST(BLAS, MatrixTransposeMatrixMultiply_9_9_9_Dynamic) {
456 TestMatrixTransposeFunctions<9,
457 9,
458 9,
459 DimType::Dynamic,
460 MatrixTransposeMatrixMultiplyTy>()();
461}
462
463TEST(BLAS, MatrixTransposeMatrixMultiplyNaive_5_3_7) {
464 TestMatrixTransposeFunctions<5,
465 3,
466 7,
467 DimType::Static,
468 MatrixTransposeMatrixMultiplyNaiveTy>()();
469}
470
471TEST(BLAS, MatrixTransposeMatrixMultiplyNaive_5_3_7_Dynamic) {
472 TestMatrixTransposeFunctions<5,
473 3,
474 7,
475 DimType::Dynamic,
476 MatrixTransposeMatrixMultiplyNaiveTy>()();
477}
478
479TEST(BLAS, MatrixTransposeMatrixMultiplyNaive_1_1_1) {
480 TestMatrixTransposeFunctions<1,
481 1,
482 1,
483 DimType::Static,
484 MatrixTransposeMatrixMultiplyNaiveTy>()();
485}
486
487TEST(BLAS, MatrixTransposeMatrixMultiplyNaive_1_1_1_Dynamic) {
488 TestMatrixTransposeFunctions<1,
489 1,
490 1,
491 DimType::Dynamic,
492 MatrixTransposeMatrixMultiplyNaiveTy>()();
493}
494
495TEST(BLAS, MatrixTransposeMatrixMultiplyNaive_9_9_9) {
496 TestMatrixTransposeFunctions<9,
497 9,
498 9,
499 DimType::Static,
500 MatrixTransposeMatrixMultiplyNaiveTy>()();
501}
502
503TEST(BLAS, MatrixTransposeMatrixMultiplyNaive_9_9_9_Dynamic) {
504 TestMatrixTransposeFunctions<9,
505 9,
506 9,
507 DimType::Dynamic,
508 MatrixTransposeMatrixMultiplyNaiveTy>()();
Austin Schuh70cc9552019-01-21 19:46:48 -0800509}
510
511TEST(BLAS, MatrixVectorMultiply) {
512 for (int num_rows_a = 1; num_rows_a < 10; ++num_rows_a) {
513 for (int num_cols_a = 1; num_cols_a < 10; ++num_cols_a) {
514 Matrix A(num_rows_a, num_cols_a);
515 A.setOnes();
516
517 Vector b(num_cols_a);
518 b.setOnes();
519
520 Vector c(num_rows_a);
521 c.setOnes();
522
523 Vector c_plus = c;
524 Vector c_minus = c;
525 Vector c_assign = c;
526
527 Vector c_plus_ref = c;
528 Vector c_minus_ref = c;
529 Vector c_assign_ref = c;
530
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800531 // clang-format off
Austin Schuh70cc9552019-01-21 19:46:48 -0800532 c_plus_ref += A * b;
533 MatrixVectorMultiply<Eigen::Dynamic, Eigen::Dynamic, 1>(
534 A.data(), num_rows_a, num_cols_a,
535 b.data(),
536 c_plus.data());
537 EXPECT_NEAR((c_plus_ref - c_plus).norm(), 0.0, kTolerance)
538 << "c += A * b \n"
539 << "c_ref : \n" << c_plus_ref << "\n"
540 << "c: \n" << c_plus;
541
542 c_minus_ref -= A * b;
543 MatrixVectorMultiply<Eigen::Dynamic, Eigen::Dynamic, -1>(
544 A.data(), num_rows_a, num_cols_a,
545 b.data(),
546 c_minus.data());
547 EXPECT_NEAR((c_minus_ref - c_minus).norm(), 0.0, kTolerance)
Austin Schuh3de38b02024-06-25 18:25:10 -0700548 << "c -= A * b \n"
Austin Schuh70cc9552019-01-21 19:46:48 -0800549 << "c_ref : \n" << c_minus_ref << "\n"
550 << "c: \n" << c_minus;
551
552 c_assign_ref = A * b;
553 MatrixVectorMultiply<Eigen::Dynamic, Eigen::Dynamic, 0>(
554 A.data(), num_rows_a, num_cols_a,
555 b.data(),
556 c_assign.data());
557 EXPECT_NEAR((c_assign_ref - c_assign).norm(), 0.0, kTolerance)
Austin Schuh3de38b02024-06-25 18:25:10 -0700558 << "c = A * b \n"
Austin Schuh70cc9552019-01-21 19:46:48 -0800559 << "c_ref : \n" << c_assign_ref << "\n"
560 << "c: \n" << c_assign;
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800561 // clang-format on
Austin Schuh70cc9552019-01-21 19:46:48 -0800562 }
563 }
564}
565
566TEST(BLAS, MatrixTransposeVectorMultiply) {
567 for (int num_rows_a = 1; num_rows_a < 10; ++num_rows_a) {
568 for (int num_cols_a = 1; num_cols_a < 10; ++num_cols_a) {
569 Matrix A(num_rows_a, num_cols_a);
570 A.setRandom();
571
572 Vector b(num_rows_a);
573 b.setRandom();
574
575 Vector c(num_cols_a);
576 c.setOnes();
577
578 Vector c_plus = c;
579 Vector c_minus = c;
580 Vector c_assign = c;
581
582 Vector c_plus_ref = c;
583 Vector c_minus_ref = c;
584 Vector c_assign_ref = c;
585
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800586 // clang-format off
Austin Schuh70cc9552019-01-21 19:46:48 -0800587 c_plus_ref += A.transpose() * b;
588 MatrixTransposeVectorMultiply<Eigen::Dynamic, Eigen::Dynamic, 1>(
589 A.data(), num_rows_a, num_cols_a,
590 b.data(),
591 c_plus.data());
592 EXPECT_NEAR((c_plus_ref - c_plus).norm(), 0.0, kTolerance)
593 << "c += A' * b \n"
594 << "c_ref : \n" << c_plus_ref << "\n"
595 << "c: \n" << c_plus;
596
597 c_minus_ref -= A.transpose() * b;
598 MatrixTransposeVectorMultiply<Eigen::Dynamic, Eigen::Dynamic, -1>(
599 A.data(), num_rows_a, num_cols_a,
600 b.data(),
601 c_minus.data());
602 EXPECT_NEAR((c_minus_ref - c_minus).norm(), 0.0, kTolerance)
Austin Schuh3de38b02024-06-25 18:25:10 -0700603 << "c -= A' * b \n"
Austin Schuh70cc9552019-01-21 19:46:48 -0800604 << "c_ref : \n" << c_minus_ref << "\n"
605 << "c: \n" << c_minus;
606
607 c_assign_ref = A.transpose() * b;
608 MatrixTransposeVectorMultiply<Eigen::Dynamic, Eigen::Dynamic, 0>(
609 A.data(), num_rows_a, num_cols_a,
610 b.data(),
611 c_assign.data());
612 EXPECT_NEAR((c_assign_ref - c_assign).norm(), 0.0, kTolerance)
Austin Schuh3de38b02024-06-25 18:25:10 -0700613 << "c = A' * b \n"
Austin Schuh70cc9552019-01-21 19:46:48 -0800614 << "c_ref : \n" << c_assign_ref << "\n"
615 << "c: \n" << c_assign;
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800616 // clang-format on
Austin Schuh70cc9552019-01-21 19:46:48 -0800617 }
618 }
619}
620
621} // namespace internal
622} // namespace ceres