blob: 155570c81e57fb4902531b45876e3a18104d63e8 [file] [log] [blame]
Austin Schuh3de38b02024-06-25 18:25:10 -07001// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2023 Google Inc. All rights reserved.
3// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: sameeragarwal@google.com (Sameer Agarwal)
30
#include "ceres/dense_qr.h"

#include <limits>
#include <memory>
#include <numeric>
#include <string>
#include <tuple>
#include <vector>

#include "Eigen/Dense"
#include "ceres/internal/eigen.h"
#include "ceres/linear_solver.h"
#include "glog/logging.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
45
46namespace ceres::internal {
47
48using Param = DenseLinearAlgebraLibraryType;
49
50namespace {
51
52std::string ParamInfoToString(testing::TestParamInfo<Param> info) {
53 return DenseLinearAlgebraLibraryTypeToString(info.param);
54}
55
56} // namespace
57
58class DenseQRTest : public ::testing::TestWithParam<Param> {};
59
60TEST_P(DenseQRTest, FactorAndSolve) {
61 // TODO(sameeragarwal): Convert these tests into type parameterized tests so
62 // that we can test the single and double precision solvers.
63
64 using Scalar = double;
65 using MatrixType = Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>;
66 using VectorType = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>;
67
68 LinearSolver::Options options;
69 ContextImpl context;
70#ifndef CERES_NO_CUDA
71 options.context = &context;
72 std::string error;
73 CHECK(context.InitCuda(&error)) << error;
74#endif // CERES_NO_CUDA
75 options.dense_linear_algebra_library_type = GetParam();
76 const double kEpsilon = std::numeric_limits<double>::epsilon() * 1.5e4;
77 std::unique_ptr<DenseQR> dense_qr = DenseQR::Create(options);
78
79 const int kNumTrials = 10;
80 const int kMinNumCols = 1;
81 const int kMaxNumCols = 10;
82 const int kMinRowsFactor = 1;
83 const int kMaxRowsFactor = 3;
84 for (int num_cols = kMinNumCols; num_cols < kMaxNumCols; ++num_cols) {
85 for (int num_rows = kMinRowsFactor * num_cols;
86 num_rows < kMaxRowsFactor * num_cols;
87 ++num_rows) {
88 for (int trial = 0; trial < kNumTrials; ++trial) {
89 MatrixType lhs = MatrixType::Random(num_rows, num_cols);
90 Vector x = VectorType::Random(num_cols);
91 Vector rhs = lhs * x;
92 Vector actual = Vector::Random(num_cols);
93 LinearSolver::Summary summary;
94 summary.termination_type = dense_qr->FactorAndSolve(num_rows,
95 num_cols,
96 lhs.data(),
97 rhs.data(),
98 actual.data(),
99 &summary.message);
100 ASSERT_EQ(summary.termination_type,
101 LinearSolverTerminationType::SUCCESS);
102 ASSERT_NEAR((x - actual).norm() / x.norm(), 0.0, kEpsilon)
103 << "\nexpected: " << x.transpose()
104 << "\nactual : " << actual.transpose();
105 }
106 }
107 }
108}
109
110namespace {
111
112// NOTE: preprocessor directives in a macro are not standard conforming
113decltype(auto) MakeValues() {
114 return ::testing::Values(EIGEN
115#ifndef CERES_NO_LAPACK
116 ,
117 LAPACK
118#endif
119#ifndef CERES_NO_CUDA
120 ,
121 CUDA
122#endif
123 );
124}
125
126} // namespace
127
128INSTANTIATE_TEST_SUITE_P(_, DenseQRTest, MakeValues(), ParamInfoToString);
129
130} // namespace ceres::internal