blob: 299051c5bcf193b5517690ff0a387ba4b9326054 [file] [log] [blame]
Austin Schuh70cc9552019-01-21 19:46:48 -08001// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2015 Google Inc. All rights reserved.
3// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: sameeragarwal@google.com (Sameer Agarwal)
30
31#include "ceres/linear_least_squares_problems.h"
32
33#include <cstdio>
34#include <memory>
35#include <string>
36#include <vector>
37
38#include "ceres/block_sparse_matrix.h"
39#include "ceres/block_structure.h"
40#include "ceres/casts.h"
41#include "ceres/file.h"
42#include "ceres/stringprintf.h"
43#include "ceres/triplet_sparse_matrix.h"
44#include "ceres/types.h"
45#include "glog/logging.h"
46
47namespace ceres {
48namespace internal {
49
50using std::string;
51
52LinearLeastSquaresProblem* CreateLinearLeastSquaresProblemFromId(int id) {
53 switch (id) {
54 case 0:
55 return LinearLeastSquaresProblem0();
56 case 1:
57 return LinearLeastSquaresProblem1();
58 case 2:
59 return LinearLeastSquaresProblem2();
60 case 3:
61 return LinearLeastSquaresProblem3();
62 case 4:
63 return LinearLeastSquaresProblem4();
64 default:
65 LOG(FATAL) << "Unknown problem id requested " << id;
66 }
67 return NULL;
68}
69
70/*
71A = [1 2]
72 [3 4]
73 [6 -10]
74
75b = [ 8
76 18
77 -18]
78
79x = [2
80 3]
81
82D = [1
83 2]
84
85x_D = [1.78448275;
86 2.82327586;]
87 */
88LinearLeastSquaresProblem* LinearLeastSquaresProblem0() {
89 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
90
91 TripletSparseMatrix* A = new TripletSparseMatrix(3, 2, 6);
92 problem->b.reset(new double[3]);
93 problem->D.reset(new double[2]);
94
95 problem->x.reset(new double[2]);
96 problem->x_D.reset(new double[2]);
97
98 int* Ai = A->mutable_rows();
99 int* Aj = A->mutable_cols();
100 double* Ax = A->mutable_values();
101
102 int counter = 0;
103 for (int i = 0; i < 3; ++i) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800104 for (int j = 0; j < 2; ++j) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800105 Ai[counter] = i;
106 Aj[counter] = j;
107 ++counter;
108 }
109 }
110
111 Ax[0] = 1.;
112 Ax[1] = 2.;
113 Ax[2] = 3.;
114 Ax[3] = 4.;
115 Ax[4] = 6;
116 Ax[5] = -10;
117 A->set_num_nonzeros(6);
118 problem->A.reset(A);
119
120 problem->b[0] = 8;
121 problem->b[1] = 18;
122 problem->b[2] = -18;
123
124 problem->x[0] = 2.0;
125 problem->x[1] = 3.0;
126
127 problem->D[0] = 1;
128 problem->D[1] = 2;
129
130 problem->x_D[0] = 1.78448275;
131 problem->x_D[1] = 2.82327586;
132 return problem;
133}
134
Austin Schuh70cc9552019-01-21 19:46:48 -0800135/*
136 A = [1 0 | 2 0 0
137 3 0 | 0 4 0
138 0 5 | 0 0 6
139 0 7 | 8 0 0
140 0 9 | 1 0 0
141 0 0 | 1 1 1]
142
143 b = [0
144 1
145 2
146 3
147 4
148 5]
149
150 c = A'* b = [ 3
151 67
152 33
153 9
154 17]
155
156 A'A = [10 0 2 12 0
157 0 155 65 0 30
158 2 65 70 1 1
159 12 0 1 17 1
160 0 30 1 1 37]
161
162 S = [ 42.3419 -1.4000 -11.5806
163 -1.4000 2.6000 1.0000
           -11.5806   1.0000  31.1935]
165
166 r = [ 4.3032
167 5.4000
168 5.0323]
169
170 S\r = [ 0.2102
171 2.1367
172 0.1388]
173
174 A\b = [-2.3061
175 0.3172
176 0.2102
177 2.1367
178 0.1388]
179*/
180// The following two functions create a TripletSparseMatrix and a
181// BlockSparseMatrix version of this problem.
182
183// TripletSparseMatrix version.
184LinearLeastSquaresProblem* LinearLeastSquaresProblem1() {
185 int num_rows = 6;
186 int num_cols = 5;
187
188 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800189 TripletSparseMatrix* A =
190 new TripletSparseMatrix(num_rows, num_cols, num_rows * num_cols);
Austin Schuh70cc9552019-01-21 19:46:48 -0800191 problem->b.reset(new double[num_rows]);
192 problem->D.reset(new double[num_cols]);
193 problem->num_eliminate_blocks = 2;
194
195 int* rows = A->mutable_rows();
196 int* cols = A->mutable_cols();
197 double* values = A->mutable_values();
198
199 int nnz = 0;
200
201 // Row 1
202 {
203 rows[nnz] = 0;
204 cols[nnz] = 0;
205 values[nnz++] = 1;
206
207 rows[nnz] = 0;
208 cols[nnz] = 2;
209 values[nnz++] = 2;
210 }
211
212 // Row 2
213 {
214 rows[nnz] = 1;
215 cols[nnz] = 0;
216 values[nnz++] = 3;
217
218 rows[nnz] = 1;
219 cols[nnz] = 3;
220 values[nnz++] = 4;
221 }
222
223 // Row 3
224 {
225 rows[nnz] = 2;
226 cols[nnz] = 1;
227 values[nnz++] = 5;
228
229 rows[nnz] = 2;
230 cols[nnz] = 4;
231 values[nnz++] = 6;
232 }
233
234 // Row 4
235 {
236 rows[nnz] = 3;
237 cols[nnz] = 1;
238 values[nnz++] = 7;
239
240 rows[nnz] = 3;
241 cols[nnz] = 2;
242 values[nnz++] = 8;
243 }
244
245 // Row 5
246 {
247 rows[nnz] = 4;
248 cols[nnz] = 1;
249 values[nnz++] = 9;
250
251 rows[nnz] = 4;
252 cols[nnz] = 2;
253 values[nnz++] = 1;
254 }
255
256 // Row 6
257 {
258 rows[nnz] = 5;
259 cols[nnz] = 2;
260 values[nnz++] = 1;
261
262 rows[nnz] = 5;
263 cols[nnz] = 3;
264 values[nnz++] = 1;
265
266 rows[nnz] = 5;
267 cols[nnz] = 4;
268 values[nnz++] = 1;
269 }
270
271 A->set_num_nonzeros(nnz);
272 CHECK(A->IsValid());
273
274 problem->A.reset(A);
275
276 for (int i = 0; i < num_cols; ++i) {
277 problem->D.get()[i] = 1;
278 }
279
280 for (int i = 0; i < num_rows; ++i) {
281 problem->b.get()[i] = i;
282 }
283
284 return problem;
285}
286
287// BlockSparseMatrix version
288LinearLeastSquaresProblem* LinearLeastSquaresProblem2() {
289 int num_rows = 6;
290 int num_cols = 5;
291
292 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
293
294 problem->b.reset(new double[num_rows]);
295 problem->D.reset(new double[num_cols]);
296 problem->num_eliminate_blocks = 2;
297
298 CompressedRowBlockStructure* bs = new CompressedRowBlockStructure;
299 std::unique_ptr<double[]> values(new double[num_rows * num_cols]);
300
301 for (int c = 0; c < num_cols; ++c) {
302 bs->cols.push_back(Block());
303 bs->cols.back().size = 1;
304 bs->cols.back().position = c;
305 }
306
307 int nnz = 0;
308
309 // Row 1
310 {
311 values[nnz++] = 1;
312 values[nnz++] = 2;
313
314 bs->rows.push_back(CompressedRow());
315 CompressedRow& row = bs->rows.back();
316 row.block.size = 1;
317 row.block.position = 0;
318 row.cells.push_back(Cell(0, 0));
319 row.cells.push_back(Cell(2, 1));
320 }
321
322 // Row 2
323 {
324 values[nnz++] = 3;
325 values[nnz++] = 4;
326
327 bs->rows.push_back(CompressedRow());
328 CompressedRow& row = bs->rows.back();
329 row.block.size = 1;
330 row.block.position = 1;
331 row.cells.push_back(Cell(0, 2));
332 row.cells.push_back(Cell(3, 3));
333 }
334
335 // Row 3
336 {
337 values[nnz++] = 5;
338 values[nnz++] = 6;
339
340 bs->rows.push_back(CompressedRow());
341 CompressedRow& row = bs->rows.back();
342 row.block.size = 1;
343 row.block.position = 2;
344 row.cells.push_back(Cell(1, 4));
345 row.cells.push_back(Cell(4, 5));
346 }
347
348 // Row 4
349 {
350 values[nnz++] = 7;
351 values[nnz++] = 8;
352
353 bs->rows.push_back(CompressedRow());
354 CompressedRow& row = bs->rows.back();
355 row.block.size = 1;
356 row.block.position = 3;
357 row.cells.push_back(Cell(1, 6));
358 row.cells.push_back(Cell(2, 7));
359 }
360
361 // Row 5
362 {
363 values[nnz++] = 9;
364 values[nnz++] = 1;
365
366 bs->rows.push_back(CompressedRow());
367 CompressedRow& row = bs->rows.back();
368 row.block.size = 1;
369 row.block.position = 4;
370 row.cells.push_back(Cell(1, 8));
371 row.cells.push_back(Cell(2, 9));
372 }
373
374 // Row 6
375 {
376 values[nnz++] = 1;
377 values[nnz++] = 1;
378 values[nnz++] = 1;
379
380 bs->rows.push_back(CompressedRow());
381 CompressedRow& row = bs->rows.back();
382 row.block.size = 1;
383 row.block.position = 5;
384 row.cells.push_back(Cell(2, 10));
385 row.cells.push_back(Cell(3, 11));
386 row.cells.push_back(Cell(4, 12));
387 }
388
389 BlockSparseMatrix* A = new BlockSparseMatrix(bs);
390 memcpy(A->mutable_values(), values.get(), nnz * sizeof(*A->values()));
391
392 for (int i = 0; i < num_cols; ++i) {
393 problem->D.get()[i] = 1;
394 }
395
396 for (int i = 0; i < num_rows; ++i) {
397 problem->b.get()[i] = i;
398 }
399
400 problem->A.reset(A);
401
402 return problem;
403}
404
/*
      A = [1 0
           3 0
           0 5
           0 7
           0 9]

      b = [0
           1
           2
           3
           4]
*/
420// BlockSparseMatrix version
421LinearLeastSquaresProblem* LinearLeastSquaresProblem3() {
422 int num_rows = 5;
423 int num_cols = 2;
424
425 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
426
427 problem->b.reset(new double[num_rows]);
428 problem->D.reset(new double[num_cols]);
429 problem->num_eliminate_blocks = 2;
430
431 CompressedRowBlockStructure* bs = new CompressedRowBlockStructure;
432 std::unique_ptr<double[]> values(new double[num_rows * num_cols]);
433
434 for (int c = 0; c < num_cols; ++c) {
435 bs->cols.push_back(Block());
436 bs->cols.back().size = 1;
437 bs->cols.back().position = c;
438 }
439
440 int nnz = 0;
441
442 // Row 1
443 {
444 values[nnz++] = 1;
445 bs->rows.push_back(CompressedRow());
446 CompressedRow& row = bs->rows.back();
447 row.block.size = 1;
448 row.block.position = 0;
449 row.cells.push_back(Cell(0, 0));
450 }
451
452 // Row 2
453 {
454 values[nnz++] = 3;
455 bs->rows.push_back(CompressedRow());
456 CompressedRow& row = bs->rows.back();
457 row.block.size = 1;
458 row.block.position = 1;
459 row.cells.push_back(Cell(0, 1));
460 }
461
462 // Row 3
463 {
464 values[nnz++] = 5;
465 bs->rows.push_back(CompressedRow());
466 CompressedRow& row = bs->rows.back();
467 row.block.size = 1;
468 row.block.position = 2;
469 row.cells.push_back(Cell(1, 2));
470 }
471
472 // Row 4
473 {
474 values[nnz++] = 7;
475 bs->rows.push_back(CompressedRow());
476 CompressedRow& row = bs->rows.back();
477 row.block.size = 1;
478 row.block.position = 3;
479 row.cells.push_back(Cell(1, 3));
480 }
481
482 // Row 5
483 {
484 values[nnz++] = 9;
485 bs->rows.push_back(CompressedRow());
486 CompressedRow& row = bs->rows.back();
487 row.block.size = 1;
488 row.block.position = 4;
489 row.cells.push_back(Cell(1, 4));
490 }
491
492 BlockSparseMatrix* A = new BlockSparseMatrix(bs);
493 memcpy(A->mutable_values(), values.get(), nnz * sizeof(*A->values()));
494
495 for (int i = 0; i < num_cols; ++i) {
496 problem->D.get()[i] = 1;
497 }
498
499 for (int i = 0; i < num_rows; ++i) {
500 problem->b.get()[i] = i;
501 }
502
503 problem->A.reset(A);
504
505 return problem;
506}
507
508/*
509 A = [1 2 0 0 0 1 1
510 1 4 0 0 0 5 6
511 0 0 9 0 0 3 1]
512
513 b = [0
514 1
515 2]
516*/
517// BlockSparseMatrix version
518//
519// This problem has the unique property that it has two different
520// sized f-blocks, but only one of them occurs in the rows involving
521// the one e-block. So performing Schur elimination on this problem
522// tests the Schur Eliminator's ability to handle non-e-block rows
523// correctly when their structure does not conform to the static
524// structure determined by DetectStructure.
525//
526// NOTE: This problem is too small and rank deficient to be solved without
527// the diagonal regularization.
528LinearLeastSquaresProblem* LinearLeastSquaresProblem4() {
529 int num_rows = 3;
530 int num_cols = 7;
531
532 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
533
534 problem->b.reset(new double[num_rows]);
535 problem->D.reset(new double[num_cols]);
536 problem->num_eliminate_blocks = 1;
537
538 CompressedRowBlockStructure* bs = new CompressedRowBlockStructure;
539 std::unique_ptr<double[]> values(new double[num_rows * num_cols]);
540
541 // Column block structure
542 bs->cols.push_back(Block());
543 bs->cols.back().size = 2;
544 bs->cols.back().position = 0;
545
546 bs->cols.push_back(Block());
547 bs->cols.back().size = 3;
548 bs->cols.back().position = 2;
549
550 bs->cols.push_back(Block());
551 bs->cols.back().size = 2;
552 bs->cols.back().position = 5;
553
554 int nnz = 0;
555
556 // Row 1 & 2
557 {
558 bs->rows.push_back(CompressedRow());
559 CompressedRow& row = bs->rows.back();
560 row.block.size = 2;
561 row.block.position = 0;
562
563 row.cells.push_back(Cell(0, nnz));
564 values[nnz++] = 1;
565 values[nnz++] = 2;
566 values[nnz++] = 1;
567 values[nnz++] = 4;
568
569 row.cells.push_back(Cell(2, nnz));
570 values[nnz++] = 1;
571 values[nnz++] = 1;
572 values[nnz++] = 5;
573 values[nnz++] = 6;
574 }
575
576 // Row 3
577 {
578 bs->rows.push_back(CompressedRow());
579 CompressedRow& row = bs->rows.back();
580 row.block.size = 1;
581 row.block.position = 2;
582
583 row.cells.push_back(Cell(1, nnz));
584 values[nnz++] = 9;
585 values[nnz++] = 0;
586 values[nnz++] = 0;
587
588 row.cells.push_back(Cell(2, nnz));
589 values[nnz++] = 3;
590 values[nnz++] = 1;
591 }
592
593 BlockSparseMatrix* A = new BlockSparseMatrix(bs);
594 memcpy(A->mutable_values(), values.get(), nnz * sizeof(*A->values()));
595
596 for (int i = 0; i < num_cols; ++i) {
597 problem->D.get()[i] = (i + 1) * 100;
598 }
599
600 for (int i = 0; i < num_rows; ++i) {
601 problem->b.get()[i] = i;
602 }
603
604 problem->A.reset(A);
605 return problem;
606}
607
608namespace {
609bool DumpLinearLeastSquaresProblemToConsole(const SparseMatrix* A,
610 const double* D,
611 const double* b,
612 const double* x,
613 int num_eliminate_blocks) {
614 CHECK(A != nullptr);
615 Matrix AA;
616 A->ToDenseMatrix(&AA);
617 LOG(INFO) << "A^T: \n" << AA.transpose();
618
619 if (D != NULL) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800620 LOG(INFO) << "A's appended diagonal:\n" << ConstVectorRef(D, A->num_cols());
Austin Schuh70cc9552019-01-21 19:46:48 -0800621 }
622
623 if (b != NULL) {
624 LOG(INFO) << "b: \n" << ConstVectorRef(b, A->num_rows());
625 }
626
627 if (x != NULL) {
628 LOG(INFO) << "x: \n" << ConstVectorRef(x, A->num_cols());
629 }
630 return true;
631}
632
633void WriteArrayToFileOrDie(const string& filename,
634 const double* x,
635 const int size) {
636 CHECK(x != nullptr);
637 VLOG(2) << "Writing array to: " << filename;
638 FILE* fptr = fopen(filename.c_str(), "w");
639 CHECK(fptr != nullptr);
640 for (int i = 0; i < size; ++i) {
641 fprintf(fptr, "%17f\n", x[i]);
642 }
643 fclose(fptr);
644}
645
646bool DumpLinearLeastSquaresProblemToTextFile(const string& filename_base,
647 const SparseMatrix* A,
648 const double* D,
649 const double* b,
650 const double* x,
651 int num_eliminate_blocks) {
652 CHECK(A != nullptr);
653 LOG(INFO) << "writing to: " << filename_base << "*";
654
655 string matlab_script;
656 StringAppendF(&matlab_script,
657 "function lsqp = load_trust_region_problem()\n");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800658 StringAppendF(&matlab_script, "lsqp.num_rows = %d;\n", A->num_rows());
659 StringAppendF(&matlab_script, "lsqp.num_cols = %d;\n", A->num_cols());
Austin Schuh70cc9552019-01-21 19:46:48 -0800660
661 {
662 string filename = filename_base + "_A.txt";
663 FILE* fptr = fopen(filename.c_str(), "w");
664 CHECK(fptr != nullptr);
665 A->ToTextFile(fptr);
666 fclose(fptr);
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800667 StringAppendF(
668 &matlab_script, "tmp = load('%s', '-ascii');\n", filename.c_str());
Austin Schuh70cc9552019-01-21 19:46:48 -0800669 StringAppendF(
670 &matlab_script,
671 "lsqp.A = sparse(tmp(:, 1) + 1, tmp(:, 2) + 1, tmp(:, 3), %d, %d);\n",
672 A->num_rows(),
673 A->num_cols());
674 }
675
Austin Schuh70cc9552019-01-21 19:46:48 -0800676 if (D != NULL) {
677 string filename = filename_base + "_D.txt";
678 WriteArrayToFileOrDie(filename, D, A->num_cols());
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800679 StringAppendF(
680 &matlab_script, "lsqp.D = load('%s', '-ascii');\n", filename.c_str());
Austin Schuh70cc9552019-01-21 19:46:48 -0800681 }
682
683 if (b != NULL) {
684 string filename = filename_base + "_b.txt";
685 WriteArrayToFileOrDie(filename, b, A->num_rows());
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800686 StringAppendF(
687 &matlab_script, "lsqp.b = load('%s', '-ascii');\n", filename.c_str());
Austin Schuh70cc9552019-01-21 19:46:48 -0800688 }
689
690 if (x != NULL) {
691 string filename = filename_base + "_x.txt";
692 WriteArrayToFileOrDie(filename, x, A->num_cols());
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800693 StringAppendF(
694 &matlab_script, "lsqp.x = load('%s', '-ascii');\n", filename.c_str());
Austin Schuh70cc9552019-01-21 19:46:48 -0800695 }
696
697 string matlab_filename = filename_base + ".m";
698 WriteStringToFileOrDie(matlab_script, matlab_filename);
699 return true;
700}
701} // namespace
702
703bool DumpLinearLeastSquaresProblem(const string& filename_base,
704 DumpFormatType dump_format_type,
705 const SparseMatrix* A,
706 const double* D,
707 const double* b,
708 const double* x,
709 int num_eliminate_blocks) {
710 switch (dump_format_type) {
711 case CONSOLE:
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800712 return DumpLinearLeastSquaresProblemToConsole(
713 A, D, b, x, num_eliminate_blocks);
Austin Schuh70cc9552019-01-21 19:46:48 -0800714 case TEXTFILE:
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800715 return DumpLinearLeastSquaresProblemToTextFile(
716 filename_base, A, D, b, x, num_eliminate_blocks);
Austin Schuh70cc9552019-01-21 19:46:48 -0800717 default:
718 LOG(FATAL) << "Unknown DumpFormatType " << dump_format_type;
719 }
720
721 return true;
722}
723
724} // namespace internal
725} // namespace ceres