blob: 4254b716bd16e1fbb5e5604eaece7ba634c52431 [file] [log] [blame]
Austin Schuh8c267c72023-11-18 14:05:14 -08001#include "frc971/orin/threshold.h"
2
3#include <stdint.h>
4
5#include "frc971/orin/cuda.h"
6
Stephan Pleinesf63bde82024-01-13 15:59:33 -08007namespace frc971::apriltag {
Austin Schuh8c267c72023-11-18 14:05:14 -08008namespace {
9
10// 1280 -> 2 * 128 * 5
11// 720 -> 2 * 8 * 5 * 9
12//
13// 1456 -> 2 * 8 * 7 * 13
14// 1088 -> 2 * 32 * 17
15
// Produces the greyscale image and a 2x-decimated copy of it.
//
// For each pixel index p in [0, width * height), gray_image[p] receives the
// byte color_image[2 * p] (every other input byte; presumably the luma byte of
// a packed 2-bytes-per-pixel format -- TODO confirm against the camera
// pipeline). Pixels lying on an even row AND an even column are also written to
// decimated_image, whose row stride is width / 2 (width is assumed even).
//
// Launch: 1D grid of any size. Each thread starts at its global thread index
// and advances by blockDim.x * gridDim.x until it runs off the image.
__global__ void InternalCudaToGreyscaleAndDecimateHalide(
    const uint8_t *color_image, uint8_t *gray_image, uint8_t *decimated_image,
    size_t width, size_t height) {
  const size_t pixel_count = width * height;
  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < pixel_count;
       idx += blockDim.x * gridDim.x) {
    const uint8_t grey = color_image[idx * 2];
    gray_image[idx] = grey;

    const size_t y = idx / width;
    const size_t x = idx - y * width;

    // Keep only pixels whose row and column are both even.
    const bool keep = ((y | x) & 1) == 0;
    if (keep) {
      decimated_image[(y / 2) * width / 2 + x / 2] = grey;
    }
  }

  // TODO(austin): Figure out how to load contiguous memory reasonably
  // efficiently and max/min over it.

  // TODO(austin): Can we do the threshold here too? That would be less memory
  // bandwidth consumed...
}
42
// Reduces four packed pixels to one (min, max) pair: the result's .x is the
// smallest of row.x..row.w and its .y is the largest.
__forceinline__ __device__ uchar2 minmax(uchar4 row) {
  const uint8_t lo_xy = std::min(row.x, row.y);
  const uint8_t lo_zw = std::min(row.z, row.w);
  const uint8_t hi_xy = std::max(row.x, row.y);
  const uint8_t hi_zw = std::max(row.z, row.w);
  return make_uchar2(std::min(lo_xy, lo_zw), std::max(hi_xy, hi_zw));
}
49
// Merges two (min, max) pairs: .x of the result is the smaller of the two
// minimums and .y is the larger of the two maximums.
__forceinline__ __device__ uchar2 minmax(uchar2 val0, uchar2 val1) {
  const uint8_t merged_min = std::min(val0.x, val1.x);
  const uint8_t merged_max = std::max(val0.y, val1.y);
  return make_uchar2(merged_min, merged_max);
}
54
// Converts an (x, y) coordinate in a row-major image that is `width` elements
// wide into a flat element offset.
__forceinline__ __device__ size_t XYToIndex(size_t width, size_t x, size_t y) {
  const size_t row_start = y * width;
  return row_start + x;
}
59
// Computes the min and max pixel value of each 4x4 tile of `decimated_image`.
//
// `width` and `height` are measured in tiles, so the decimated image is
// 4 * width pixels wide and 4 * height pixels tall. Thread (x, y) of the 2D
// launch handles tile (x, y): it reads four rows of four pixels through uchar4
// loads (so each row start must be 4-byte aligned) and writes one uchar2
// (.x = min, .y = max) to unfiltered_minmax_image[y * width + x]. Threads
// outside the tile grid exit immediately.
__global__ void InternalBlockMinMax(const uint8_t *decimated_image,
                                    uchar2 *unfiltered_minmax_image,
                                    size_t width, size_t height) {
  const size_t tile_x = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t tile_y = blockIdx.y * blockDim.y + threadIdx.y;

  const bool in_bounds = tile_x < width && tile_y < height;
  if (!in_bounds) {
    return;
  }

  // Row stride of the decimated image, in pixels.
  const size_t pixel_stride = width * 4;

  uchar2 row_minmax[4];
  for (int r = 0; r < 4; ++r) {
    const uint8_t *row_ptr =
        decimated_image + XYToIndex(pixel_stride, tile_x * 4, tile_y * 4 + r);
    row_minmax[r] = minmax(*reinterpret_cast<const uchar4 *>(row_ptr));
  }

  const uchar2 top = minmax(row_minmax[0], row_minmax[1]);
  const uchar2 bottom = minmax(row_minmax[2], row_minmax[3]);
  unfiltered_minmax_image[XYToIndex(width, tile_x, tile_y)] =
      minmax(top, bottom);
}
82
// Widens each tile's (min, max) so it covers the 3x3 neighbourhood of tiles
// centred on it, and writes the result to minmax_image.
//
// Thread (x, y) of the 2D launch merges every unfiltered_minmax_image entry
// within one tile of (x, y), including (x, y) itself. Neighbours outside the
// width x height tile grid are skipped, so border tiles merge fewer entries.
// Threads outside the grid exit immediately.
__global__ void InternalBlockFilter(const uchar2 *unfiltered_minmax_image,
                                    uchar2 *minmax_image, size_t width,
                                    size_t height) {
  const size_t x = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t y = blockIdx.y * blockDim.y + threadIdx.y;

  if (x >= width || y >= height) {
    return;
  }

  // Identity element for the merge: any real (min, max) pair replaces it.
  uchar2 merged = make_uchar2(255, 0);

#pragma unroll
  for (int dy = -1; dy <= 1; ++dy) {
    const ssize_t ny = static_cast<ssize_t>(y) + dy;
    if (ny < 0 || ny >= static_cast<ssize_t>(height)) {
      continue;
    }
#pragma unroll
    for (int dx = -1; dx <= 1; ++dx) {
      const ssize_t nx = static_cast<ssize_t>(x) + dx;
      if (nx < 0 || nx >= static_cast<ssize_t>(width)) {
        continue;
      }
      merged =
          minmax(merged, unfiltered_minmax_image[XYToIndex(width, nx, ny)]);
    }
  }

  minmax_image[XYToIndex(width, x, y)] = merged;
}
120
// Binarizes `decimated_image` (width x height pixels) against per-tile
// thresholds.
//
// Each pixel looks up the (min, max) of the 4x4 tile containing it in
// `minmax_image`, which is width / 4 tiles wide. If max - min is below
// `min_white_black_diff`, the output is 127. Otherwise the output is 255 when
// the pixel is strictly greater than min + (max - min) / 2, and 0 when it is
// not.
//
// Launch: 1D grid of any size. Each thread advances by blockDim.x * gridDim.x.
__global__ void InternalThreshold(const uint8_t *decimated_image,
                                  const uchar2 *minmax_image,
                                  uint8_t *thresholded_image, size_t width,
                                  size_t height, size_t min_white_black_diff) {
  const size_t pixel_count = width * height;
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < pixel_count;
       i += blockDim.x * gridDim.x) {
    const size_t row = i / width;
    const size_t col = i % width;

    const uchar2 bounds = minmax_image[col / 4 + (row / 4) * width / 4];
    const int contrast = bounds.y - bounds.x;

    // 127 marks "not enough contrast to decide".
    uint8_t out = 127;
    if (static_cast<size_t>(contrast) >= min_white_black_diff) {
      const uint8_t midpoint = bounds.x + contrast / 2;
      out = decimated_image[i] > midpoint ? 255 : 0;
    }

    thresholded_image[i] = out;
  }
}
149
150} // namespace
151
// Enqueues the greyscale -> decimate -> local min/max -> threshold pipeline on
// `stream`.
//
// Steps (all on the same stream, so they run in order):
//   1. gray_image[p] = color_image[2 * p] for every pixel p. Pixels on even
//      rows and even columns are also copied into the
//      (width / 2) x (height / 2) decimated_image.
//   2. Write the (min, max) of every 4x4 tile of decimated_image into
//      unfiltered_minmax_image.
//   3. Merge each tile's (min, max) with its 3x3 neighbourhood of tiles into
//      minmax_image.
//   4. Threshold decimated_image into thresholded_image. Pixels in tiles whose
//      max - min is below min_white_black_diff become 127. Otherwise a pixel
//      becomes 255 if it is above the tile midpoint and 0 if not.
//
// Preconditions:
//   - width and height must be multiples of 8 (CHECKed), so the decimated
//     image tiles exactly into 4x4 blocks.
//   - unfiltered_minmax_image and minmax_image are reinterpreted as uchar2
//     arrays of (width / 8) x (height / 8) entries.
//   - decimated_image is read with uchar4 loads and so must be 4-byte aligned.
//
// An empty image (width or height 0) is a no-op. This function only enqueues
// work: callers should synchronize `stream` before reading outputs on the host.
// MaybeCheckAndSynchronize may additionally synchronize; see its definition.
void CudaToGreyscaleAndDecimateHalide(
    const uint8_t *color_image, uint8_t *gray_image, uint8_t *decimated_image,
    uint8_t *unfiltered_minmax_image, uint8_t *minmax_image,
    uint8_t *thresholded_image, size_t width, size_t height,
    size_t min_white_black_diff, CudaStream *stream) {
  CHECK((width % 8) == 0);
  CHECK((height % 8) == 0);

  // An empty image would produce zero-sized grids below, which CUDA rejects as
  // an invalid launch configuration. There is nothing to compute anyway.
  if (width == 0 || height == 0) {
    return;
  }

  constexpr size_t kThreads = 256;

  const size_t decimated_width = width / 2;
  const size_t decimated_height = height / 2;
  // Size of the decimated image measured in 4x4 tiles.
  const size_t tiles_wide = decimated_width / 4;
  const size_t tiles_high = decimated_height / 4;

  {
    // Step one, convert to gray and decimate. The kernel loops with a stride of
    // blockDim.x * gridDim.x, so a quarter of a thread-per-pixel grid is
    // enough. That division rounds to zero for images smaller than
    // 4 * kThreads pixels, and a zero-block grid fails to launch, so always
    // launch at least one block.
    size_t num_blocks = (width * height + kThreads - 1) / kThreads / 4;
    if (num_blocks == 0) {
      num_blocks = 1;
    }
    InternalCudaToGreyscaleAndDecimateHalide<<<num_blocks, kThreads, 0,
                                               stream->get()>>>(
        color_image, gray_image, decimated_image, width, height);
    MaybeCheckAndSynchronize();
  }

  {
    // Step 2, compute a min/max for each block of 4x4 (16) pixels. One thread
    // per tile; tiles_wide/tiles_high are >= 1 here because width and height
    // are non-zero multiples of 8.
    dim3 threads(16, 16, 1);
    dim3 blocks((tiles_wide + 15) / 16, (tiles_high + 15) / 16, 1);

    InternalBlockMinMax<<<blocks, threads, 0, stream->get()>>>(
        decimated_image, reinterpret_cast<uchar2 *>(unfiltered_minmax_image),
        tiles_wide, tiles_high);
    MaybeCheckAndSynchronize();

    // Step 3, blur those min/max's a further +- 1 block in each direction
    // using min/max.
    InternalBlockFilter<<<blocks, threads, 0, stream->get()>>>(
        reinterpret_cast<uchar2 *>(unfiltered_minmax_image),
        reinterpret_cast<uchar2 *>(minmax_image), tiles_wide, tiles_high);
    MaybeCheckAndSynchronize();
  }

  {
    // Step 4, write out 127 if the min/max are too close to each other, or
    // 0/255 if the pixels are above or below the average of the min/max. This
    // kernel also strides over the image, so the same quarter-grid sizing
    // (clamped to >= 1 block) applies.
    size_t num_blocks =
        (decimated_width * decimated_height + kThreads - 1) / kThreads / 4;
    if (num_blocks == 0) {
      num_blocks = 1;
    }
    InternalThreshold<<<num_blocks, kThreads, 0, stream->get()>>>(
        decimated_image, reinterpret_cast<uchar2 *>(minmax_image),
        thresholded_image, decimated_width, decimated_height,
        min_white_black_diff);
    MaybeCheckAndSynchronize();
  }
}
203
Stephan Pleinesf63bde82024-01-13 15:59:33 -0800204} // namespace frc971::apriltag