blob: ec617ca3d01423c41db63777e8e71cd73044f137 [file] [log] [blame]
Parker Schuh6691f192017-01-14 17:01:02 -08001#include "aos/vision/blob/hierarchical_contour_merge.h"
2
#include <math.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

#include "aos/vision/blob/disjoint_set.h"
7
8namespace aos {
9namespace vision {
10
11namespace {
12
// Returns a modulo n. For positive n the result is normalized into [0, n), so
// an index that stepped backwards past the start of a ring (e.g. -1) wraps to
// the end instead of staying negative (C++'s % truncates toward zero).
// For n < 0 the result is the same as a % n.
int Mod(int a, int n) {
  const int r = a % n;
  return (r < 0 && n > 0) ? r + n : r;
}
14
15} // namespace
16
// Prefix sums over a ring of values. Get(a, b) returns the sum of the values
// in the half-open index range [a, b) in O(1). Indices are reduced modulo the
// number of values added, so a range may wrap past the end back to index 0.
template <typename T>
class IntegralArray {
 public:
  IntegralArray() = default;

  // `size` is only a capacity hint for the values later appended with Add().
  explicit IntegralArray(int size) {
    if (size > 0) items_.reserve(size);
  }

  // Sum of the values at ring indices [a, b). Any integers are accepted;
  // negative ones wrap from the end. a == b (mod size) is an empty range and
  // sums to zero. Every query on an empty array also sums to zero.
  T Get(int a, int b) const {
    const int size = static_cast<int>(items_.size());
    if (size == 0) return T{};
    a = Wrap(a, size);
    b = Wrap(b, size);
    if (a == b) return T{};
    if (a < b) return Prefix(b) - Prefix(a);
    // Wrapping range: [a, size) followed by [0, b).
    return items_.back() + Prefix(b) - Prefix(a);
  }

  // Appends the next value of the ring.
  void Add(T t) {
    items_.push_back(items_.empty() ? t : t + items_.back());
  }

 private:
  // Sum of the first k values, for k in [0, size].
  T Prefix(int k) const { return k == 0 ? T{} : items_[k - 1]; }

  // Reduces i into [0, size); requires size > 0.
  static int Wrap(int i, int size) {
    const int r = i % size;
    return r < 0 ? r + size : r;
  }

  // items_[i] holds the (inclusive) running total of values 0..i.
  std::vector<T> items_;
};
53
54class IntegralLineFit {
55 public:
56 IntegralLineFit(int number_of_points, int min_line_length)
57 : xx_(number_of_points),
58 xy_(number_of_points),
59 yy_(number_of_points),
60 x_(number_of_points),
61 y_(number_of_points),
62 // These are not IntegralArrays.
63 n_(number_of_points),
64 min_len_(min_line_length) {}
65
66 void AddPt(Point pt) {
67 xx_.Add(pt.x * pt.x);
68 xy_.Add(pt.x * pt.y);
69 yy_.Add(pt.y * pt.y);
70 x_.Add(pt.x);
71 y_.Add(pt.y);
72 }
73
74 int GetNForRange(int st, int ed) {
75 int nv = (ed + 1) - st;
76 if (ed < st) {
77 nv += n_;
78 }
79 return nv;
80 }
81
82 float GetLineErrorRate(int st, int ed) {
83 int64_t nv = GetNForRange(st, ed);
84
85 int64_t px = x_.Get(st, ed);
86 int64_t py = y_.Get(st, ed);
87 int64_t pxx = xx_.Get(st, ed);
88 int64_t pxy = xy_.Get(st, ed);
89 int64_t pyy = yy_.Get(st, ed);
90
91 double nvsq = nv * nv;
92 double m_xx = (pxx * nv - px * px) / nvsq;
93 double m_xy = (pxy * nv - px * py) / nvsq;
94 double m_yy = (pyy * nv - py * py) / nvsq;
95
96 double b = m_xx + m_yy;
97 double c = m_xx * m_yy - m_xy * m_xy;
98 return ((b - sqrt(b * b - 4 * c)) / 2.0);
99 }
100
101 float GetErrorLineRange(int st, int ed) {
102 int nv = GetNForRange(st, ed);
103 int j = std::max(min_len_ - nv, 0) / 2;
104 return GetLineErrorRate((st - j + n_) % n_, (ed + 1 + j + n_) % n_);
105 }
106
107 FittedLine FitLine(int st, int ed, Point pst, Point ped) {
108 int nv = GetNForRange(st, ed);
109 // Adjust line out to be at least min_len_.
110 int j = std::max(min_len_ - nv, 0) / 2;
111
112 st = Mod(st - j, n_);
113 ed = Mod(ed + 1 + j, n_);
114 if (nv <= min_len_) {
115 return FittedLine{pst, pst};
116 }
117
118 int64_t px = x_.Get(st, ed);
119 int64_t py = y_.Get(st, ed);
120 int64_t pxx = xx_.Get(st, ed);
121 int64_t pxy = xy_.Get(st, ed);
122 int64_t pyy = yy_.Get(st, ed);
123
124 double nvsq = nv * nv;
125 double m_xx = (pxx * nv - px * px) / nvsq;
126 double m_xy = (pxy * nv - px * py) / nvsq;
127 double m_yy = (pyy * nv - py * py) / nvsq;
128 double m_x = px / ((double)nv);
129 double m_y = py / ((double)nv);
130
131 double b = (m_xx + m_yy) / 2.0;
132 double c = m_xx * m_yy - m_xy * m_xy;
133
134 double eiggen = sqrt(b * b - c);
135 double eigv = b - eiggen;
136
137 double vx = m_xx - eigv;
138 double vy = m_xy;
139 double mag = sqrt(vx * vx + vy * vy);
140 vx /= mag;
141 vy /= mag;
142
143 double av = vx * (pst.x - m_x) + vy * (pst.y - m_y);
144 double bv = vx * (ped.x - m_x) + vy * (ped.y - m_y);
145
146 Point apt = {(int)(m_x + vx * av), (int)(m_y + vy * av)};
147 Point bpt = {(int)(m_x + vx * bv), (int)(m_y + vy * bv)};
148
149 return FittedLine{apt, bpt};
150 }
151
152 private:
153 IntegralArray<int> xx_;
154 IntegralArray<int> xy_;
155 IntegralArray<int> yy_;
156 IntegralArray<int> x_;
157 IntegralArray<int> y_;
158
159 // Number of points in contour.
160 int n_;
161
162 // Minimum line length we will look for.
163 int min_len_;
164};
165
// A queued proposal to join two adjacent segments. `st` must still be the
// first index of its segment and `ed` the last index of its segment when the
// event is processed.
struct JoinEvent {
  int st;
  int ed;

  // Events are ordered purely by the float stored alongside them in a
  // std::pair. The pair's comparison still needs this operator to exist, so
  // every JoinEvent compares equivalent to every other one.
  bool operator<(const JoinEvent & /*unused*/) const {
    constexpr bool kNeverLess = false;
    return kNeverLess;
  }
};
174
175void HierarchicalMerge(ContourNode *stval, std::vector<FittedLine> *fit_lines,
176 float merge_rate, int min_len) {
177 ContourNode *c = stval;
178 // count the number of points in the contour.
179 int n = 0;
180 do {
181 n++;
182 c = c->next;
183 } while (c != stval);
184 IntegralLineFit fit(n, min_len);
185 c = stval;
186 std::vector<Point> pts;
187 do {
188 fit.AddPt(c->pt);
189 pts.push_back(c->pt);
190 c = c->next;
191 } while (c != stval);
192
193 DisjointSet ids(n);
194
195 std::vector<int> sts;
196 sts.reserve(n);
197 std::vector<int> eds;
198 eds.reserve(n);
199 for (int i = 0; i < n; i++) {
200 sts.push_back(i);
201 eds.push_back(i);
202 }
203
204 // Note priority queue takes a pair, so float is used as the priority.
205 std::priority_queue<std::pair<float, JoinEvent>> events;
206 for (int i = 0; i < n; i++) {
207 float err = fit.GetErrorLineRange(i - 1, i);
208 events.push(
209 std::pair<float, JoinEvent>(err, JoinEvent{(i - 1 + n) % n, i}));
210 }
211
212 while (events.size() > 0) {
213 auto event = events.top().second;
214 // Merge the lines that are most like a line.
215 events.pop();
216 int pi1 = ids.Find(event.st);
217 int pi2 = ids.Find(event.ed);
218 int st = sts[pi1];
219 int ed = eds[pi2];
220 if (st == event.st && ed == event.ed && pi1 != pi2) {
221 ids[pi2] = pi1;
222 int pi = sts[ids.Find((st - 1 + n) % n)];
223 int ni = eds[ids.Find((ed + 1 + n) % n)];
224 eds[pi1] = ed;
225 if (pi != st) {
226 float err = fit.GetErrorLineRange(pi, ed);
227 if (err < merge_rate) {
228 events.push(std::pair<float, JoinEvent>(err, JoinEvent{pi, ed}));
229 }
230 }
231 if (ni != ed) {
232 float err = fit.GetErrorLineRange(st, ni);
233 if (err < merge_rate) {
234 events.push(std::pair<float, JoinEvent>(err, JoinEvent{st, ni}));
235 }
236 }
237 }
238 }
239 for (int i = 0; i < n; i++) {
240 if (ids[i] == -1) {
241 int sti = sts[i];
242 int edi = eds[i];
243 if ((edi - sti + n) % n > min_len) {
244 auto line_fit = fit.FitLine(sti, edi, pts[sti], pts[edi]);
245 fit_lines->emplace_back(line_fit);
246 }
247 }
248 }
249}
250
251} // namespace vision
252} // namespace aos