// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2015 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Author: sameeragarwal@google.com (Sameer Agarwal)
//
// Simple BLAS functions for use in the Schur Eliminator. These are
// fairly basic implementations which already yield a significant
// speedup in the eliminator performance.

#ifndef CERES_INTERNAL_SMALL_BLAS_H_
#define CERES_INTERNAL_SMALL_BLAS_H_

#include "ceres/internal/port.h"
#include "ceres/internal/eigen.h"
#include "glog/logging.h"
#include "small_blas_generic.h"

namespace ceres {
namespace internal {

// The following three macros are used to share code and reduce
// template junk across the various GEMM variants.
#define CERES_GEMM_BEGIN(name) \
  template<int kRowA, int kColA, int kRowB, int kColB, int kOperation> \
  inline void name(const double* A, \
                   const int num_row_a, \
                   const int num_col_a, \
                   const double* B, \
                   const int num_row_b, \
                   const int num_col_b, \
                   double* C, \
                   const int start_row_c, \
                   const int start_col_c, \
                   const int row_stride_c, \
                   const int col_stride_c)

#define CERES_GEMM_NAIVE_HEADER \
  DCHECK_GT(num_row_a, 0); \
  DCHECK_GT(num_col_a, 0); \
  DCHECK_GT(num_row_b, 0); \
  DCHECK_GT(num_col_b, 0); \
  DCHECK_GE(start_row_c, 0); \
  DCHECK_GE(start_col_c, 0); \
  DCHECK_GT(row_stride_c, 0); \
  DCHECK_GT(col_stride_c, 0); \
  DCHECK((kRowA == Eigen::Dynamic) || (kRowA == num_row_a)); \
  DCHECK((kColA == Eigen::Dynamic) || (kColA == num_col_a)); \
  DCHECK((kRowB == Eigen::Dynamic) || (kRowB == num_row_b)); \
  DCHECK((kColB == Eigen::Dynamic) || (kColB == num_col_b)); \
  const int NUM_ROW_A = (kRowA != Eigen::Dynamic ? kRowA : num_row_a); \
  const int NUM_COL_A = (kColA != Eigen::Dynamic ? kColA : num_col_a); \
  const int NUM_ROW_B = (kRowB != Eigen::Dynamic ? kRowB : num_row_b); \
  const int NUM_COL_B = (kColB != Eigen::Dynamic ? kColB : num_col_b);

#define CERES_GEMM_EIGEN_HEADER \
  const typename EigenTypes<kRowA, kColA>::ConstMatrixRef \
      Aref(A, num_row_a, num_col_a); \
  const typename EigenTypes<kRowB, kColB>::ConstMatrixRef \
      Bref(B, num_row_b, num_col_b); \
  MatrixRef Cref(C, row_stride_c, col_stride_c); \

#define CERES_CALL_GEMM(name) \
  name<kRowA, kColA, kRowB, kColB, kOperation>( \
      A, num_row_a, num_col_a, \
      B, num_row_b, num_col_b, \
      C, start_row_c, start_col_c, row_stride_c, col_stride_c);

#define CERES_GEMM_STORE_SINGLE(p, index, value) \
  if (kOperation > 0) { \
    p[index] += value; \
  } else if (kOperation < 0) { \
    p[index] -= value; \
  } else { \
    p[index] = value; \
  }

#define CERES_GEMM_STORE_PAIR(p, index, v1, v2) \
  if (kOperation > 0) { \
    p[index] += v1; \
    p[index + 1] += v2; \
  } else if (kOperation < 0) { \
    p[index] -= v1; \
    p[index + 1] -= v2; \
  } else { \
    p[index] = v1; \
    p[index + 1] = v2; \
  }

// For each of the matrix-matrix functions below there are three
// variants: Foo, FooNaive and FooEigen. Foo is the one to be called
// by the user. FooNaive is a basic loop-based implementation and
// FooEigen uses Eigen's implementation. Foo chooses between FooNaive
// and FooEigen depending on how many of the template arguments are
// fixed at compile time. Currently, FooEigen is called if all matrix
// dimensions are compile-time constants and FooNaive otherwise; this
// currently gives the best performance.
//
// The MatrixMatrixMultiply variants compute:
//
//   C op A * B;
//
// The MatrixTransposeMatrixMultiply variants compute:
//
//   C op A' * B
//
// where op can be +=, -=, or =.
//
// The template parameters (kRowA, kColA, kRowB, kColB) allow
// specialization of the loop at compile time. If this information is
// not available, then Eigen::Dynamic should be used as the template
// argument.
//
//   kOperation =  1 -> C += A * B
//   kOperation = -1 -> C -= A * B
//   kOperation =  0 -> C  = A * B
//
// The functions can write into matrices C which are larger than the
// matrix A * B. This is done by specifying the true size of C via
// row_stride_c and col_stride_c, and then indicating where A * B
// should be written into by start_row_c and start_col_c.
//
// Graphically, if row_stride_c = 10, col_stride_c = 12, start_row_c =
// 4 and start_col_c = 5, and A is 3x2 and B is 2x4, we get
//
//   ------------
//   ------------
//   ------------
//   ------------
//   -----xxxx---
//   -----xxxx---
//   -----xxxx---
//   ------------
//   ------------
//   ------------
//
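// For example, a call matching the figure above might look like the
// following sketch (illustrative only, not taken from the library; A, B
// and C are assumed to be row-major arrays of doubles of the stated
// sizes, with C being the full 10 x 12 buffer):
//
//   // Adds the 3x4 product A * B into C starting at row 4, column 5.
//   MatrixMatrixMultiply<3, 2, 2, 4, 1>(A, 3, 2, B, 2, 4, C, 4, 5, 10, 12);
//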
CERES_GEMM_BEGIN(MatrixMatrixMultiplyEigen) {
  CERES_GEMM_EIGEN_HEADER
  Eigen::Block<MatrixRef, kRowA, kColB>
      block(Cref, start_row_c, start_col_c, num_row_a, num_col_b);

  if (kOperation > 0) {
    block.noalias() += Aref * Bref;
  } else if (kOperation < 0) {
    block.noalias() -= Aref * Bref;
  } else {
    block.noalias() = Aref * Bref;
  }
}

CERES_GEMM_BEGIN(MatrixMatrixMultiplyNaive) {
  CERES_GEMM_NAIVE_HEADER
  DCHECK_EQ(NUM_COL_A, NUM_ROW_B);

  const int NUM_ROW_C = NUM_ROW_A;
  const int NUM_COL_C = NUM_COL_B;
  DCHECK_LE(start_row_c + NUM_ROW_C, row_stride_c);
  DCHECK_LE(start_col_c + NUM_COL_C, col_stride_c);
  const int span = 4;

  // Calculate the remainder part first.

  // Process the last odd column if present.
  if (NUM_COL_C & 1) {
    int col = NUM_COL_C - 1;
    const double* pa = &A[0];
    for (int row = 0; row < NUM_ROW_C; ++row, pa += NUM_COL_A) {
      const double* pb = &B[col];
      double tmp = 0.0;
      for (int k = 0; k < NUM_COL_A; ++k, pb += NUM_COL_B) {
        tmp += pa[k] * pb[0];
      }

      const int index = (row + start_row_c) * col_stride_c + start_col_c + col;
      CERES_GEMM_STORE_SINGLE(C, index, tmp);
    }

    // Return early: for an extremely small multiply this was the only column.
    if (NUM_COL_C == 1) {
      return;
    }
  }

  // Process the remaining pair of columns, if present.
  if (NUM_COL_C & 2) {
    int col = NUM_COL_C & (int)(~(span - 1));
    const double* pa = &A[0];
    for (int row = 0; row < NUM_ROW_C; ++row, pa += NUM_COL_A) {
      const double* pb = &B[col];
      double tmp1 = 0.0, tmp2 = 0.0;
      for (int k = 0; k < NUM_COL_A; ++k, pb += NUM_COL_B) {
        double av = pa[k];
        tmp1 += av * pb[0];
        tmp2 += av * pb[1];
      }

      const int index = (row + start_row_c) * col_stride_c + start_col_c + col;
      CERES_GEMM_STORE_PAIR(C, index, tmp1, tmp2);
    }

    // Return early: for an extremely small multiply there is no main part.
    if (NUM_COL_C < span) {
      return;
    }
  }

  // Calculate the main part with multiples of 4.
  int col_m = NUM_COL_C & (int)(~(span - 1));
  for (int col = 0; col < col_m; col += span) {
    for (int row = 0; row < NUM_ROW_C; ++row) {
      const int index = (row + start_row_c) * col_stride_c + start_col_c + col;
      MMM_mat1x4(NUM_COL_A, &A[row * NUM_COL_A],
                 &B[col], NUM_COL_B, &C[index], kOperation);
    }
  }
}

CERES_GEMM_BEGIN(MatrixMatrixMultiply) {
#ifdef CERES_NO_CUSTOM_BLAS

  CERES_CALL_GEMM(MatrixMatrixMultiplyEigen)
  return;

#else

  if (kRowA != Eigen::Dynamic && kColA != Eigen::Dynamic &&
      kRowB != Eigen::Dynamic && kColB != Eigen::Dynamic) {
    CERES_CALL_GEMM(MatrixMatrixMultiplyEigen)
  } else {
    CERES_CALL_GEMM(MatrixMatrixMultiplyNaive)
  }

#endif
}

CERES_GEMM_BEGIN(MatrixTransposeMatrixMultiplyEigen) {
  CERES_GEMM_EIGEN_HEADER
  Eigen::Block<MatrixRef, kColA, kColB> block(Cref,
                                              start_row_c, start_col_c,
                                              num_col_a, num_col_b);
  if (kOperation > 0) {
    block.noalias() += Aref.transpose() * Bref;
  } else if (kOperation < 0) {
    block.noalias() -= Aref.transpose() * Bref;
  } else {
    block.noalias() = Aref.transpose() * Bref;
  }
}

CERES_GEMM_BEGIN(MatrixTransposeMatrixMultiplyNaive) {
  CERES_GEMM_NAIVE_HEADER
  DCHECK_EQ(NUM_ROW_A, NUM_ROW_B);

  const int NUM_ROW_C = NUM_COL_A;
  const int NUM_COL_C = NUM_COL_B;
  DCHECK_LE(start_row_c + NUM_ROW_C, row_stride_c);
  DCHECK_LE(start_col_c + NUM_COL_C, col_stride_c);
  const int span = 4;

  // Process the remainder part first.

  // Process the last odd column if present.
  if (NUM_COL_C & 1) {
    int col = NUM_COL_C - 1;
    for (int row = 0; row < NUM_ROW_C; ++row) {
      const double* pa = &A[row];
      const double* pb = &B[col];
      double tmp = 0.0;
      for (int k = 0; k < NUM_ROW_A; ++k) {
        tmp += pa[0] * pb[0];
        pa += NUM_COL_A;
        pb += NUM_COL_B;
      }

      const int index = (row + start_row_c) * col_stride_c + start_col_c + col;
      CERES_GEMM_STORE_SINGLE(C, index, tmp);
    }

    // Return early: for an extremely small multiply this was the only column.
    if (NUM_COL_C == 1) {
      return;
    }
  }

  // Process the remaining pair of columns, if present.
  if (NUM_COL_C & 2) {
    int col = NUM_COL_C & (int)(~(span - 1));
    for (int row = 0; row < NUM_ROW_C; ++row) {
      const double* pa = &A[row];
      const double* pb = &B[col];
      double tmp1 = 0.0, tmp2 = 0.0;
      for (int k = 0; k < NUM_ROW_A; ++k) {
        double av = *pa;
        tmp1 += av * pb[0];
        tmp2 += av * pb[1];
        pa += NUM_COL_A;
        pb += NUM_COL_B;
      }

      const int index = (row + start_row_c) * col_stride_c + start_col_c + col;
      CERES_GEMM_STORE_PAIR(C, index, tmp1, tmp2);
    }

    // Return early: for an extremely small multiply there is no main part.
    if (NUM_COL_C < span) {
      return;
    }
  }

  // Process the main part with multiples of 4.
  int col_m = NUM_COL_C & (int)(~(span - 1));
  for (int col = 0; col < col_m; col += span) {
    for (int row = 0; row < NUM_ROW_C; ++row) {
      const int index = (row + start_row_c) * col_stride_c + start_col_c + col;
      MTM_mat1x4(NUM_ROW_A, &A[row], NUM_COL_A,
                 &B[col], NUM_COL_B, &C[index], kOperation);
    }
  }
}

CERES_GEMM_BEGIN(MatrixTransposeMatrixMultiply) {
#ifdef CERES_NO_CUSTOM_BLAS

  CERES_CALL_GEMM(MatrixTransposeMatrixMultiplyEigen)
  return;

#else

  if (kRowA != Eigen::Dynamic && kColA != Eigen::Dynamic &&
      kRowB != Eigen::Dynamic && kColB != Eigen::Dynamic) {
    CERES_CALL_GEMM(MatrixTransposeMatrixMultiplyEigen)
  } else {
    CERES_CALL_GEMM(MatrixTransposeMatrixMultiplyNaive)
  }

#endif
}

// Matrix-Vector multiplication
//
//   c op A * b;
//
// where op can be +=, -=, or =.
//
// The template parameters (kRowA, kColA) allow specialization of the
// loop at compile time. If this information is not available, then
// Eigen::Dynamic should be used as the template argument.
//
//   kOperation =  1 -> c += A * b
//   kOperation = -1 -> c -= A * b
//   kOperation =  0 -> c  = A * b
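//
// For example, an illustrative sketch (not taken from the library) for a
// 3x2 row-major matrix A, a 2-vector b and a 3-vector c:
//
//   MatrixVectorMultiply<3, 2, 0>(A, 3, 2, b, c);  // c = A * b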
template<int kRowA, int kColA, int kOperation>
inline void MatrixVectorMultiply(const double* A,
                                 const int num_row_a,
                                 const int num_col_a,
                                 const double* b,
                                 double* c) {
#ifdef CERES_NO_CUSTOM_BLAS
  const typename EigenTypes<kRowA, kColA>::ConstMatrixRef
      Aref(A, num_row_a, num_col_a);
  const typename EigenTypes<kColA>::ConstVectorRef bref(b, num_col_a);
  typename EigenTypes<kRowA>::VectorRef cref(c, num_row_a);

  // lazyProduct works better than .noalias() for matrix-vector
  // products.
  if (kOperation > 0) {
    cref += Aref.lazyProduct(bref);
  } else if (kOperation < 0) {
    cref -= Aref.lazyProduct(bref);
  } else {
    cref = Aref.lazyProduct(bref);
  }
#else

  DCHECK_GT(num_row_a, 0);
  DCHECK_GT(num_col_a, 0);
  DCHECK((kRowA == Eigen::Dynamic) || (kRowA == num_row_a));
  DCHECK((kColA == Eigen::Dynamic) || (kColA == num_col_a));

  const int NUM_ROW_A = (kRowA != Eigen::Dynamic ? kRowA : num_row_a);
  const int NUM_COL_A = (kColA != Eigen::Dynamic ? kColA : num_col_a);
  const int span = 4;

  // Calculate the remainder part first.

  // Process the last odd row if present.
  if (NUM_ROW_A & 1) {
    int row = NUM_ROW_A - 1;
    const double* pa = &A[row * NUM_COL_A];
    const double* pb = &b[0];
    double tmp = 0.0;
    for (int col = 0; col < NUM_COL_A; ++col) {
      tmp += (*pa++) * (*pb++);
    }
    CERES_GEMM_STORE_SINGLE(c, row, tmp);

    // Return early: for an extremely small multiply this was the only row.
    if (NUM_ROW_A == 1) {
      return;
    }
  }

  // Process the remaining pair of rows, if present.
  if (NUM_ROW_A & 2) {
    int row = NUM_ROW_A & (int)(~(span - 1));
    const double* pa1 = &A[row * NUM_COL_A];
    const double* pa2 = pa1 + NUM_COL_A;
    const double* pb = &b[0];
    double tmp1 = 0.0, tmp2 = 0.0;
    for (int col = 0; col < NUM_COL_A; ++col) {
      double bv = *pb++;
      tmp1 += *(pa1++) * bv;
      tmp2 += *(pa2++) * bv;
    }
    CERES_GEMM_STORE_PAIR(c, row, tmp1, tmp2);

    // Return early: for an extremely small multiply there is no main part.
    if (NUM_ROW_A < span) {
      return;
    }
  }

  // Calculate the main part with multiples of 4.
  int row_m = NUM_ROW_A & (int)(~(span - 1));
  for (int row = 0; row < row_m; row += span) {
    MVM_mat4x1(NUM_COL_A, &A[row * NUM_COL_A], NUM_COL_A,
               &b[0], &c[row], kOperation);
  }

#endif  // CERES_NO_CUSTOM_BLAS
}

// Similar to MatrixVectorMultiply, except that A is transposed, i.e.,
//
//   c op A' * b;
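//
// For example, an illustrative sketch (not taken from the library) for a
// 3x2 row-major matrix A, a 3-vector b and a 2-vector c:
//
//   MatrixTransposeVectorMultiply<3, 2, 1>(A, 3, 2, b, c);  // c += A' * b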
template<int kRowA, int kColA, int kOperation>
inline void MatrixTransposeVectorMultiply(const double* A,
                                          const int num_row_a,
                                          const int num_col_a,
                                          const double* b,
                                          double* c) {
#ifdef CERES_NO_CUSTOM_BLAS
  const typename EigenTypes<kRowA, kColA>::ConstMatrixRef
      Aref(A, num_row_a, num_col_a);
  const typename EigenTypes<kRowA>::ConstVectorRef bref(b, num_row_a);
  typename EigenTypes<kColA>::VectorRef cref(c, num_col_a);

  // lazyProduct works better than .noalias() for matrix-vector
  // products.
  if (kOperation > 0) {
    cref += Aref.transpose().lazyProduct(bref);
  } else if (kOperation < 0) {
    cref -= Aref.transpose().lazyProduct(bref);
  } else {
    cref = Aref.transpose().lazyProduct(bref);
  }
#else

  DCHECK_GT(num_row_a, 0);
  DCHECK_GT(num_col_a, 0);
  DCHECK((kRowA == Eigen::Dynamic) || (kRowA == num_row_a));
  DCHECK((kColA == Eigen::Dynamic) || (kColA == num_col_a));

  const int NUM_ROW_A = (kRowA != Eigen::Dynamic ? kRowA : num_row_a);
  const int NUM_COL_A = (kColA != Eigen::Dynamic ? kColA : num_col_a);
  const int span = 4;

  // Calculate the remainder part first.

  // Process the last odd column of A (the last entry of c) if present.
  if (NUM_COL_A & 1) {
    int row = NUM_COL_A - 1;
    const double* pa = &A[row];
    const double* pb = &b[0];
    double tmp = 0.0;
    for (int col = 0; col < NUM_ROW_A; ++col) {
      tmp += *pa * (*pb++);
      pa += NUM_COL_A;
    }
    CERES_GEMM_STORE_SINGLE(c, row, tmp);

    // Return early: for an extremely small multiply this was the only column.
    if (NUM_COL_A == 1) {
      return;
    }
  }

  // Process the remaining pair of columns, if present.
  if (NUM_COL_A & 2) {
    int row = NUM_COL_A & (int)(~(span - 1));
    const double* pa = &A[row];
    const double* pb = &b[0];
    double tmp1 = 0.0, tmp2 = 0.0;
    for (int col = 0; col < NUM_ROW_A; ++col) {
      double bv = *pb++;
      tmp1 += *(pa + 0) * bv;
      tmp2 += *(pa + 1) * bv;
      pa += NUM_COL_A;
    }
    CERES_GEMM_STORE_PAIR(c, row, tmp1, tmp2);

    // Return early: for an extremely small multiply there is no main part.
    if (NUM_COL_A < span) {
      return;
    }
  }

  // Calculate the main part with multiples of 4.
  int row_m = NUM_COL_A & (int)(~(span - 1));
  for (int row = 0; row < row_m; row += span) {
    MTV_mat4x1(NUM_ROW_A, &A[row], NUM_COL_A,
               &b[0], &c[row], kOperation);
  }

#endif  // CERES_NO_CUSTOM_BLAS
}

#undef CERES_GEMM_BEGIN
#undef CERES_GEMM_EIGEN_HEADER
#undef CERES_GEMM_NAIVE_HEADER
#undef CERES_CALL_GEMM
#undef CERES_GEMM_STORE_SINGLE
#undef CERES_GEMM_STORE_PAIR

}  // namespace internal
}  // namespace ceres

#endif  // CERES_INTERNAL_SMALL_BLAS_H_