blob: d74e570cef7ba3cf2c44bf81a8917feb57c3a81b [file] [log] [blame]
Austin Schuh70cc9552019-01-21 19:46:48 -08001// Ceres Solver - A fast non-linear least squares minimizer
Austin Schuh3de38b02024-06-25 18:25:10 -07002// Copyright 2023 Google Inc. All rights reserved.
Austin Schuh70cc9552019-01-21 19:46:48 -08003// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: David Gallup (dgallup@google.com)
30// Sameer Agarwal (sameeragarwal@google.com)
31
32#include "ceres/canonical_views_clustering.h"
33
Austin Schuh70cc9552019-01-21 19:46:48 -080034#include <unordered_map>
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080035#include <unordered_set>
Austin Schuh3de38b02024-06-25 18:25:10 -070036#include <vector>
Austin Schuh70cc9552019-01-21 19:46:48 -080037
38#include "ceres/graph.h"
Austin Schuh3de38b02024-06-25 18:25:10 -070039#include "ceres/internal/export.h"
Austin Schuh70cc9552019-01-21 19:46:48 -080040#include "ceres/map_util.h"
41#include "glog/logging.h"
42
Austin Schuh3de38b02024-06-25 18:25:10 -070043namespace ceres::internal {
Austin Schuh70cc9552019-01-21 19:46:48 -080044
Austin Schuh3de38b02024-06-25 18:25:10 -070045using IntMap = std::unordered_map<int, int>;
46using IntSet = std::unordered_set<int>;
Austin Schuh70cc9552019-01-21 19:46:48 -080047
Austin Schuh3de38b02024-06-25 18:25:10 -070048class CERES_NO_EXPORT CanonicalViewsClustering {
Austin Schuh70cc9552019-01-21 19:46:48 -080049 public:
Austin Schuh70cc9552019-01-21 19:46:48 -080050 // Compute the canonical views clustering of the vertices of the
51 // graph. centers will contain the vertices that are the identified
52 // as the canonical views/cluster centers, and membership is a map
53 // from vertices to cluster_ids. The i^th cluster center corresponds
54 // to the i^th cluster. It is possible depending on the
55 // configuration of the clustering algorithm that some of the
56 // vertices may not be assigned to any cluster. In this case they
57 // are assigned to a cluster with id = kInvalidClusterId.
58 void ComputeClustering(const CanonicalViewsClusteringOptions& options,
59 const WeightedGraph<int>& graph,
Austin Schuh3de38b02024-06-25 18:25:10 -070060 std::vector<int>* centers,
Austin Schuh70cc9552019-01-21 19:46:48 -080061 IntMap* membership);
62
63 private:
64 void FindValidViews(IntSet* valid_views) const;
Austin Schuh3de38b02024-06-25 18:25:10 -070065 double ComputeClusteringQualityDifference(
66 int candidate, const std::vector<int>& centers) const;
Austin Schuh70cc9552019-01-21 19:46:48 -080067 void UpdateCanonicalViewAssignments(const int canonical_view);
Austin Schuh3de38b02024-06-25 18:25:10 -070068 void ComputeClusterMembership(const std::vector<int>& centers,
Austin Schuh70cc9552019-01-21 19:46:48 -080069 IntMap* membership) const;
70
71 CanonicalViewsClusteringOptions options_;
72 const WeightedGraph<int>* graph_;
73 // Maps a view to its representative canonical view (its cluster
74 // center).
75 IntMap view_to_canonical_view_;
76 // Maps a view to its similarity to its current cluster center.
77 std::unordered_map<int, double> view_to_canonical_view_similarity_;
78};
79
80void ComputeCanonicalViewsClustering(
81 const CanonicalViewsClusteringOptions& options,
82 const WeightedGraph<int>& graph,
Austin Schuh3de38b02024-06-25 18:25:10 -070083 std::vector<int>* centers,
Austin Schuh70cc9552019-01-21 19:46:48 -080084 IntMap* membership) {
Austin Schuh3de38b02024-06-25 18:25:10 -070085 time_t start_time = time(nullptr);
Austin Schuh70cc9552019-01-21 19:46:48 -080086 CanonicalViewsClustering cv;
87 cv.ComputeClustering(options, graph, centers, membership);
88 VLOG(2) << "Canonical views clustering time (secs): "
Austin Schuh3de38b02024-06-25 18:25:10 -070089 << time(nullptr) - start_time;
Austin Schuh70cc9552019-01-21 19:46:48 -080090}
91
92// Implementation of CanonicalViewsClustering
93void CanonicalViewsClustering::ComputeClustering(
94 const CanonicalViewsClusteringOptions& options,
95 const WeightedGraph<int>& graph,
Austin Schuh3de38b02024-06-25 18:25:10 -070096 std::vector<int>* centers,
Austin Schuh70cc9552019-01-21 19:46:48 -080097 IntMap* membership) {
98 options_ = options;
99 CHECK(centers != nullptr);
100 CHECK(membership != nullptr);
101 centers->clear();
102 membership->clear();
103 graph_ = &graph;
104
105 IntSet valid_views;
106 FindValidViews(&valid_views);
Austin Schuh3de38b02024-06-25 18:25:10 -0700107 while (!valid_views.empty()) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800108 // Find the next best canonical view.
109 double best_difference = -std::numeric_limits<double>::max();
110 int best_view = 0;
111
112 // TODO(sameeragarwal): Make this loop multi-threaded.
113 for (const auto& view : valid_views) {
114 const double difference =
115 ComputeClusteringQualityDifference(view, *centers);
116 if (difference > best_difference) {
117 best_difference = difference;
118 best_view = view;
119 }
120 }
121
122 CHECK_GT(best_difference, -std::numeric_limits<double>::max());
123
124 // Add canonical view if quality improves, or if minimum is not
125 // yet met, otherwise break.
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800126 if ((best_difference <= 0) && (centers->size() >= options_.min_views)) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800127 break;
128 }
129
130 centers->push_back(best_view);
131 valid_views.erase(best_view);
132 UpdateCanonicalViewAssignments(best_view);
133 }
134
135 ComputeClusterMembership(*centers, membership);
136}
137
138// Return the set of vertices of the graph which have valid vertex
139// weights.
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800140void CanonicalViewsClustering::FindValidViews(IntSet* valid_views) const {
Austin Schuh70cc9552019-01-21 19:46:48 -0800141 const IntSet& views = graph_->vertices();
142 for (const auto& view : views) {
143 if (graph_->VertexWeight(view) != WeightedGraph<int>::InvalidWeight()) {
144 valid_views->insert(view);
145 }
146 }
147}
148
149// Computes the difference in the quality score if 'candidate' were
150// added to the set of canonical views.
151double CanonicalViewsClustering::ComputeClusteringQualityDifference(
Austin Schuh3de38b02024-06-25 18:25:10 -0700152 const int candidate, const std::vector<int>& centers) const {
Austin Schuh70cc9552019-01-21 19:46:48 -0800153 // View score.
154 double difference =
155 options_.view_score_weight * graph_->VertexWeight(candidate);
156
157 // Compute how much the quality score changes if the candidate view
158 // was added to the list of canonical views and its nearest
159 // neighbors became members of its cluster.
160 const IntSet& neighbors = graph_->Neighbors(candidate);
161 for (const auto& neighbor : neighbors) {
162 const double old_similarity =
163 FindWithDefault(view_to_canonical_view_similarity_, neighbor, 0.0);
164 const double new_similarity = graph_->EdgeWeight(neighbor, candidate);
165 if (new_similarity > old_similarity) {
166 difference += new_similarity - old_similarity;
167 }
168 }
169
170 // Number of views penalty.
171 difference -= options_.size_penalty_weight;
172
173 // Orthogonality.
Austin Schuh3de38b02024-06-25 18:25:10 -0700174 for (int center : centers) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800175 difference -= options_.similarity_penalty_weight *
Austin Schuh3de38b02024-06-25 18:25:10 -0700176 graph_->EdgeWeight(center, candidate);
Austin Schuh70cc9552019-01-21 19:46:48 -0800177 }
178
179 return difference;
180}
181
182// Reassign views if they're more similar to the new canonical view.
183void CanonicalViewsClustering::UpdateCanonicalViewAssignments(
184 const int canonical_view) {
185 const IntSet& neighbors = graph_->Neighbors(canonical_view);
186 for (const auto& neighbor : neighbors) {
187 const double old_similarity =
188 FindWithDefault(view_to_canonical_view_similarity_, neighbor, 0.0);
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800189 const double new_similarity = graph_->EdgeWeight(neighbor, canonical_view);
Austin Schuh70cc9552019-01-21 19:46:48 -0800190 if (new_similarity > old_similarity) {
191 view_to_canonical_view_[neighbor] = canonical_view;
192 view_to_canonical_view_similarity_[neighbor] = new_similarity;
193 }
194 }
195}
196
197// Assign a cluster id to each view.
198void CanonicalViewsClustering::ComputeClusterMembership(
Austin Schuh3de38b02024-06-25 18:25:10 -0700199 const std::vector<int>& centers, IntMap* membership) const {
Austin Schuh70cc9552019-01-21 19:46:48 -0800200 CHECK(membership != nullptr);
201 membership->clear();
202
203 // The i^th cluster has cluster id i.
204 IntMap center_to_cluster_id;
205 for (int i = 0; i < centers.size(); ++i) {
206 center_to_cluster_id[centers[i]] = i;
207 }
208
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800209 static constexpr int kInvalidClusterId = -1;
Austin Schuh70cc9552019-01-21 19:46:48 -0800210
211 const IntSet& views = graph_->vertices();
212 for (const auto& view : views) {
213 auto it = view_to_canonical_view_.find(view);
214 int cluster_id = kInvalidClusterId;
215 if (it != view_to_canonical_view_.end()) {
216 cluster_id = FindOrDie(center_to_cluster_id, it->second);
217 }
218
219 InsertOrDie(membership, view, cluster_id);
220 }
221}
222
Austin Schuh3de38b02024-06-25 18:25:10 -0700223} // namespace ceres::internal