blob: c5d3eb9e4924ad9339141c7eec649b1b022575f8 [file] [log] [blame]
Parker Schuh6691f192017-01-14 17:01:02 -08001#include "aos/vision/blob/hierarchical_contour_merge.h"
2
Stephan Pleinescc500b92024-05-30 10:58:40 -07003#include <stdint.h>
4
5#include <algorithm>
Tyler Chatowbf0609c2021-07-31 16:13:27 -07006#include <cmath>
Parker Schuh6691f192017-01-14 17:01:02 -08007#include <queue>
Stephan Pleinescc500b92024-05-30 10:58:40 -07008#include <utility>
Parker Schuh6691f192017-01-14 17:01:02 -08009
10#include "aos/vision/blob/disjoint_set.h"
11
Stephan Pleinesf63bde82024-01-13 15:59:33 -080012namespace aos::vision {
Parker Schuh6691f192017-01-14 17:01:02 -080013
namespace {

// Returns `a` modulo `n` as a value in [0, n), for n > 0.
//
// The built-in % operator gives its result the sign of `a` (e.g. -1 % 5 == -1).
// That is wrong for ring-index arithmetic, where Mod(-1, n) must be n - 1 so the
// result can be used directly as an index.
int Mod(int a, int n) {
  const int r = a % n;
  return r < 0 ? r + n : r;
}

}  // namespace
19
// Prefix sums over a ring of values.
//
// Values are appended in order with Add(). Get(a, b) then returns, in O(1), the
// sum of the values at indices [a, b), walking forward around the ring and
// wrapping from the last value back to the first.
template <typename T>
class IntegralArray {
 public:
  IntegralArray() = default;

  // Reserves room for `size` values; no values are added.
  explicit IntegralArray(int size) { items_.reserve(size); }

  // Sum of the values in the half-open ring range [a, b).
  //
  // Both indices are reduced modulo the number of stored values; negative
  // indices wrap as well. If a == b after reduction, the range is empty and the
  // result is 0. That includes a "whole ring" request such as Get(0, size).
  // An empty array also yields 0 rather than dividing by zero.
  T Get(int a, int b) const {
    const int n = static_cast<int>(items_.size());
    if (n == 0) {
      return T(0);
    }
    a = Wrap(a, n);
    b = Wrap(b, n);
    if (a == b) {
      return T(0);
    }
    if (a < b) {
      return PrefixSum(b) - PrefixSum(a);
    }
    // The range wraps: it covers [a, n) and then [0, b).
    return items_.back() + PrefixSum(b) - PrefixSum(a);
  }

  // Appends `t` as the next value of the ring.
  void Add(T t) {
    if (items_.empty()) {
      items_.push_back(t);
    } else {
      items_.push_back(t + items_.back());
    }
  }

 private:
  // Sum of the first `k` values, for k in [0, size].
  T PrefixSum(int k) const { return k == 0 ? T(0) : items_[k - 1]; }

  // Reduces `i` into [0, n) for n > 0; unlike %, negative inputs wrap around.
  static int Wrap(int i, int n) {
    const int r = i % n;
    return r < 0 ? r + n : r;
  }

  // items_[i] is the inclusive running total of values 0..i.
  std::vector<T> items_;
};
56
57class IntegralLineFit {
58 public:
59 IntegralLineFit(int number_of_points, int min_line_length)
60 : xx_(number_of_points),
61 xy_(number_of_points),
62 yy_(number_of_points),
63 x_(number_of_points),
64 y_(number_of_points),
65 // These are not IntegralArrays.
66 n_(number_of_points),
67 min_len_(min_line_length) {}
68
69 void AddPt(Point pt) {
70 xx_.Add(pt.x * pt.x);
71 xy_.Add(pt.x * pt.y);
72 yy_.Add(pt.y * pt.y);
73 x_.Add(pt.x);
74 y_.Add(pt.y);
75 }
76
77 int GetNForRange(int st, int ed) {
78 int nv = (ed + 1) - st;
79 if (ed < st) {
80 nv += n_;
81 }
82 return nv;
83 }
84
85 float GetLineErrorRate(int st, int ed) {
86 int64_t nv = GetNForRange(st, ed);
87
88 int64_t px = x_.Get(st, ed);
89 int64_t py = y_.Get(st, ed);
90 int64_t pxx = xx_.Get(st, ed);
91 int64_t pxy = xy_.Get(st, ed);
92 int64_t pyy = yy_.Get(st, ed);
93
94 double nvsq = nv * nv;
95 double m_xx = (pxx * nv - px * px) / nvsq;
96 double m_xy = (pxy * nv - px * py) / nvsq;
97 double m_yy = (pyy * nv - py * py) / nvsq;
98
99 double b = m_xx + m_yy;
100 double c = m_xx * m_yy - m_xy * m_xy;
101 return ((b - sqrt(b * b - 4 * c)) / 2.0);
102 }
103
104 float GetErrorLineRange(int st, int ed) {
105 int nv = GetNForRange(st, ed);
106 int j = std::max(min_len_ - nv, 0) / 2;
107 return GetLineErrorRate((st - j + n_) % n_, (ed + 1 + j + n_) % n_);
108 }
109
110 FittedLine FitLine(int st, int ed, Point pst, Point ped) {
111 int nv = GetNForRange(st, ed);
112 // Adjust line out to be at least min_len_.
113 int j = std::max(min_len_ - nv, 0) / 2;
114
115 st = Mod(st - j, n_);
116 ed = Mod(ed + 1 + j, n_);
117 if (nv <= min_len_) {
118 return FittedLine{pst, pst};
119 }
120
121 int64_t px = x_.Get(st, ed);
122 int64_t py = y_.Get(st, ed);
123 int64_t pxx = xx_.Get(st, ed);
124 int64_t pxy = xy_.Get(st, ed);
125 int64_t pyy = yy_.Get(st, ed);
126
127 double nvsq = nv * nv;
128 double m_xx = (pxx * nv - px * px) / nvsq;
129 double m_xy = (pxy * nv - px * py) / nvsq;
130 double m_yy = (pyy * nv - py * py) / nvsq;
131 double m_x = px / ((double)nv);
132 double m_y = py / ((double)nv);
133
134 double b = (m_xx + m_yy) / 2.0;
135 double c = m_xx * m_yy - m_xy * m_xy;
136
137 double eiggen = sqrt(b * b - c);
138 double eigv = b - eiggen;
139
140 double vx = m_xx - eigv;
141 double vy = m_xy;
142 double mag = sqrt(vx * vx + vy * vy);
143 vx /= mag;
144 vy /= mag;
145
146 double av = vx * (pst.x - m_x) + vy * (pst.y - m_y);
147 double bv = vx * (ped.x - m_x) + vy * (ped.y - m_y);
148
149 Point apt = {(int)(m_x + vx * av), (int)(m_y + vy * av)};
150 Point bpt = {(int)(m_x + vx * bv), (int)(m_y + vy * bv)};
151
152 return FittedLine{apt, bpt};
153 }
154
155 private:
156 IntegralArray<int> xx_;
157 IntegralArray<int> xy_;
158 IntegralArray<int> yy_;
159 IntegralArray<int> x_;
160 IntegralArray<int> y_;
161
162 // Number of points in contour.
163 int n_;
164
165 // Minimum line length we will look for.
166 int min_len_;
167};
168
// A queued request to join two contour segments into one. The segment that
// starts at index `st` is joined with the segment that ends at index `ed`,
// giving the segment [st, ed]. The request is only honored if both segment
// boundaries are still as recorded when it is dequeued.
struct JoinEvent {
  int st;
  int ed;

  // Events are queued as std::pair<float, JoinEvent> and ranked by the float
  // error alone. pair's operator< consults this comparison only when two errors
  // tie. It then reports the events as equivalent, so ties carry no extra
  // ordering.
  friend bool operator<(const JoinEvent & /*lhs*/, const JoinEvent & /*rhs*/) {
    return false;
  }
};
177
178void HierarchicalMerge(ContourNode *stval, std::vector<FittedLine> *fit_lines,
179 float merge_rate, int min_len) {
180 ContourNode *c = stval;
181 // count the number of points in the contour.
182 int n = 0;
183 do {
184 n++;
185 c = c->next;
186 } while (c != stval);
187 IntegralLineFit fit(n, min_len);
188 c = stval;
189 std::vector<Point> pts;
190 do {
191 fit.AddPt(c->pt);
192 pts.push_back(c->pt);
193 c = c->next;
194 } while (c != stval);
195
196 DisjointSet ids(n);
197
198 std::vector<int> sts;
199 sts.reserve(n);
200 std::vector<int> eds;
201 eds.reserve(n);
202 for (int i = 0; i < n; i++) {
203 sts.push_back(i);
204 eds.push_back(i);
205 }
206
207 // Note priority queue takes a pair, so float is used as the priority.
208 std::priority_queue<std::pair<float, JoinEvent>> events;
209 for (int i = 0; i < n; i++) {
210 float err = fit.GetErrorLineRange(i - 1, i);
211 events.push(
212 std::pair<float, JoinEvent>(err, JoinEvent{(i - 1 + n) % n, i}));
213 }
214
215 while (events.size() > 0) {
216 auto event = events.top().second;
217 // Merge the lines that are most like a line.
218 events.pop();
219 int pi1 = ids.Find(event.st);
220 int pi2 = ids.Find(event.ed);
221 int st = sts[pi1];
222 int ed = eds[pi2];
223 if (st == event.st && ed == event.ed && pi1 != pi2) {
224 ids[pi2] = pi1;
225 int pi = sts[ids.Find((st - 1 + n) % n)];
226 int ni = eds[ids.Find((ed + 1 + n) % n)];
227 eds[pi1] = ed;
228 if (pi != st) {
229 float err = fit.GetErrorLineRange(pi, ed);
230 if (err < merge_rate) {
231 events.push(std::pair<float, JoinEvent>(err, JoinEvent{pi, ed}));
232 }
233 }
234 if (ni != ed) {
235 float err = fit.GetErrorLineRange(st, ni);
236 if (err < merge_rate) {
237 events.push(std::pair<float, JoinEvent>(err, JoinEvent{st, ni}));
238 }
239 }
240 }
241 }
242 for (int i = 0; i < n; i++) {
243 if (ids[i] == -1) {
244 int sti = sts[i];
245 int edi = eds[i];
246 if ((edi - sti + n) % n > min_len) {
247 auto line_fit = fit.FitLine(sti, edi, pts[sti], pts[edi]);
248 fit_lines->emplace_back(line_fit);
249 }
250 }
251 }
252}
253
Stephan Pleinesf63bde82024-01-13 15:59:33 -0800254} // namespace aos::vision