blob: 06c34d42272bc6fdd0c0db7815ba8afdf193d523 [file] [log] [blame]
Parker Schuh6691f192017-01-14 17:01:02 -08001#include "aos/vision/blob/hierarchical_contour_merge.h"
2
Tyler Chatowbf0609c2021-07-31 16:13:27 -07003#include <cmath>
Parker Schuh6691f192017-01-14 17:01:02 -08004#include <queue>
5
6#include "aos/vision/blob/disjoint_set.h"
7
Stephan Pleinesf63bde82024-01-13 15:59:33 -08008namespace aos::vision {
Parker Schuh6691f192017-01-14 17:01:02 -08009
10namespace {
11
// Reduces |a| modulo |n| into the range [0, n). Requires n > 0.
//
// The built-in % operator truncates toward zero, so a negative |a| would
// produce a negative remainder. This helper is used to turn offsets into
// indices on an n-element ring, where a negative result would index out of
// bounds, so negative values wrap around instead: Mod(-1, n) == n - 1.
int Mod(int a, int n) {
  const int remainder = a % n;
  return remainder < 0 ? remainder + n : remainder;
}
13
14} // namespace
15
// Prefix-sum ("integral") table over a ring of values.
//
// Values are appended with Add(); entry i of items_ holds the sum of the
// first i + 1 values added. Get() then answers "sum of the values in the
// half-open ring range [a, b)" with a constant number of lookups.
template <typename T>
class IntegralArray {
 public:
  IntegralArray() {}
  // Reserves storage for |size| values. The array still starts out empty;
  // values must be supplied through Add(). Non-positive sizes reserve nothing.
  IntegralArray(int size) {
    if (size > 0) items_.reserve(size);
  }

  // Returns the sum of the values at ring positions [a, b).
  //
  // Both |a| and |b| are first wrapped into [0, size), so negative indices and
  // indices past the end are accepted. When the wrapped |b| is smaller than
  // the wrapped |a| the range runs off the end of the ring and continues from
  // position 0. Because the endpoints are wrapped, a == b (mod size) always
  // denotes the empty range; the full ring cannot be expressed.
  //
  // Returns a value-initialized T for the empty range and for an empty array.
  T Get(int a, int b) const {
    const int size = static_cast<int>(items_.size());
    if (size == 0) return T{};

    a = Wrap(a, size);
    b = Wrap(b, size);
    if (a == b) return T{};

    // items_[k - 1] is the inclusive sum of positions [0, k - 1], i.e. the
    // exclusive sum of [0, k). Position 0 has nothing before it.
    const T sum_before_a = (a == 0) ? T{} : items_[a - 1];
    const T sum_before_b = (b == 0) ? T{} : items_[b - 1];

    if (b < a) {
      // Wrapped range: [a, size) followed by [0, b).
      return items_.back() + sum_before_b - sum_before_a;
    }
    return sum_before_b - sum_before_a;
  }

  // Appends |t| to the ring, storing the running total up to and including it.
  void Add(T t) {
    if (items_.empty()) {
      items_.push_back(t);
    } else {
      items_.push_back(t + items_.back());
    }
  }

 private:
  // Wraps |index| into [0, size) for size > 0, sending negative values to the
  // end of the ring rather than to a negative remainder.
  static int Wrap(int index, int size) {
    const int remainder = index % size;
    return remainder < 0 ? remainder + size : remainder;
  }

  std::vector<T> items_;
};
52
53class IntegralLineFit {
54 public:
55 IntegralLineFit(int number_of_points, int min_line_length)
56 : xx_(number_of_points),
57 xy_(number_of_points),
58 yy_(number_of_points),
59 x_(number_of_points),
60 y_(number_of_points),
61 // These are not IntegralArrays.
62 n_(number_of_points),
63 min_len_(min_line_length) {}
64
65 void AddPt(Point pt) {
66 xx_.Add(pt.x * pt.x);
67 xy_.Add(pt.x * pt.y);
68 yy_.Add(pt.y * pt.y);
69 x_.Add(pt.x);
70 y_.Add(pt.y);
71 }
72
73 int GetNForRange(int st, int ed) {
74 int nv = (ed + 1) - st;
75 if (ed < st) {
76 nv += n_;
77 }
78 return nv;
79 }
80
81 float GetLineErrorRate(int st, int ed) {
82 int64_t nv = GetNForRange(st, ed);
83
84 int64_t px = x_.Get(st, ed);
85 int64_t py = y_.Get(st, ed);
86 int64_t pxx = xx_.Get(st, ed);
87 int64_t pxy = xy_.Get(st, ed);
88 int64_t pyy = yy_.Get(st, ed);
89
90 double nvsq = nv * nv;
91 double m_xx = (pxx * nv - px * px) / nvsq;
92 double m_xy = (pxy * nv - px * py) / nvsq;
93 double m_yy = (pyy * nv - py * py) / nvsq;
94
95 double b = m_xx + m_yy;
96 double c = m_xx * m_yy - m_xy * m_xy;
97 return ((b - sqrt(b * b - 4 * c)) / 2.0);
98 }
99
100 float GetErrorLineRange(int st, int ed) {
101 int nv = GetNForRange(st, ed);
102 int j = std::max(min_len_ - nv, 0) / 2;
103 return GetLineErrorRate((st - j + n_) % n_, (ed + 1 + j + n_) % n_);
104 }
105
106 FittedLine FitLine(int st, int ed, Point pst, Point ped) {
107 int nv = GetNForRange(st, ed);
108 // Adjust line out to be at least min_len_.
109 int j = std::max(min_len_ - nv, 0) / 2;
110
111 st = Mod(st - j, n_);
112 ed = Mod(ed + 1 + j, n_);
113 if (nv <= min_len_) {
114 return FittedLine{pst, pst};
115 }
116
117 int64_t px = x_.Get(st, ed);
118 int64_t py = y_.Get(st, ed);
119 int64_t pxx = xx_.Get(st, ed);
120 int64_t pxy = xy_.Get(st, ed);
121 int64_t pyy = yy_.Get(st, ed);
122
123 double nvsq = nv * nv;
124 double m_xx = (pxx * nv - px * px) / nvsq;
125 double m_xy = (pxy * nv - px * py) / nvsq;
126 double m_yy = (pyy * nv - py * py) / nvsq;
127 double m_x = px / ((double)nv);
128 double m_y = py / ((double)nv);
129
130 double b = (m_xx + m_yy) / 2.0;
131 double c = m_xx * m_yy - m_xy * m_xy;
132
133 double eiggen = sqrt(b * b - c);
134 double eigv = b - eiggen;
135
136 double vx = m_xx - eigv;
137 double vy = m_xy;
138 double mag = sqrt(vx * vx + vy * vy);
139 vx /= mag;
140 vy /= mag;
141
142 double av = vx * (pst.x - m_x) + vy * (pst.y - m_y);
143 double bv = vx * (ped.x - m_x) + vy * (ped.y - m_y);
144
145 Point apt = {(int)(m_x + vx * av), (int)(m_y + vy * av)};
146 Point bpt = {(int)(m_x + vx * bv), (int)(m_y + vy * bv)};
147
148 return FittedLine{apt, bpt};
149 }
150
151 private:
152 IntegralArray<int> xx_;
153 IntegralArray<int> xy_;
154 IntegralArray<int> yy_;
155 IntegralArray<int> x_;
156 IntegralArray<int> y_;
157
158 // Number of points in contour.
159 int n_;
160
161 // Minimum line length we will look for.
162 int min_len_;
163};
164
// A pending request to merge contour ranges, identified by the index where
// the merged range would start and the index where it would end.
struct JoinEvent {
  int st;
  int ed;

  // JoinEvent is stored as the second half of a std::pair<float, JoinEvent>,
  // and std::pair's ordering requires operator< on both halves. Ordering is
  // meant to come from the float alone, so no JoinEvent ever compares less
  // than another.
  bool operator<(const JoinEvent & /*unused*/) const { return false; }
};
173
174void HierarchicalMerge(ContourNode *stval, std::vector<FittedLine> *fit_lines,
175 float merge_rate, int min_len) {
176 ContourNode *c = stval;
177 // count the number of points in the contour.
178 int n = 0;
179 do {
180 n++;
181 c = c->next;
182 } while (c != stval);
183 IntegralLineFit fit(n, min_len);
184 c = stval;
185 std::vector<Point> pts;
186 do {
187 fit.AddPt(c->pt);
188 pts.push_back(c->pt);
189 c = c->next;
190 } while (c != stval);
191
192 DisjointSet ids(n);
193
194 std::vector<int> sts;
195 sts.reserve(n);
196 std::vector<int> eds;
197 eds.reserve(n);
198 for (int i = 0; i < n; i++) {
199 sts.push_back(i);
200 eds.push_back(i);
201 }
202
203 // Note priority queue takes a pair, so float is used as the priority.
204 std::priority_queue<std::pair<float, JoinEvent>> events;
205 for (int i = 0; i < n; i++) {
206 float err = fit.GetErrorLineRange(i - 1, i);
207 events.push(
208 std::pair<float, JoinEvent>(err, JoinEvent{(i - 1 + n) % n, i}));
209 }
210
211 while (events.size() > 0) {
212 auto event = events.top().second;
213 // Merge the lines that are most like a line.
214 events.pop();
215 int pi1 = ids.Find(event.st);
216 int pi2 = ids.Find(event.ed);
217 int st = sts[pi1];
218 int ed = eds[pi2];
219 if (st == event.st && ed == event.ed && pi1 != pi2) {
220 ids[pi2] = pi1;
221 int pi = sts[ids.Find((st - 1 + n) % n)];
222 int ni = eds[ids.Find((ed + 1 + n) % n)];
223 eds[pi1] = ed;
224 if (pi != st) {
225 float err = fit.GetErrorLineRange(pi, ed);
226 if (err < merge_rate) {
227 events.push(std::pair<float, JoinEvent>(err, JoinEvent{pi, ed}));
228 }
229 }
230 if (ni != ed) {
231 float err = fit.GetErrorLineRange(st, ni);
232 if (err < merge_rate) {
233 events.push(std::pair<float, JoinEvent>(err, JoinEvent{st, ni}));
234 }
235 }
236 }
237 }
238 for (int i = 0; i < n; i++) {
239 if (ids[i] == -1) {
240 int sti = sts[i];
241 int edi = eds[i];
242 if ((edi - sti + n) % n > min_len) {
243 auto line_fit = fit.FitLine(sti, edi, pts[sti], pts[edi]);
244 fit_lines->emplace_back(line_fit);
245 }
246 }
247 }
248}
249
Stephan Pleinesf63bde82024-01-13 15:59:33 -0800250} // namespace aos::vision