blob: 79cccedab9e1d2cc01f0f2be793fbef868e79008 [file] [log] [blame]
Austin Schuh8c267c72023-11-18 14:05:14 -08001#include "frc971/orin/threshold.h"
2
3#include <stdint.h>
4
5#include "frc971/orin/cuda.h"
6
7namespace frc971 {
8namespace apriltag {
9namespace {
10
11// 1280 -> 2 * 128 * 5
12// 720 -> 2 * 8 * 5 * 9
13//
14// 1456 -> 2 * 8 * 7 * 13
15// 1088 -> 2 * 32 * 17
16
// Extracts a grayscale image from color_image and also writes a copy of it
// decimated by 2 in each dimension.
//
// For every pixel index i in [0, width * height), byte 2 * i of color_image is
// copied to gray_image[i].
// NOTE(review): this presumably means color_image holds 2 bytes per pixel with
// the luma byte first (e.g. YUYV) -- confirm against the caller.
//
// Pixels which sit on both an even row and an even column are additionally
// stored to decimated_image at index (row / 2) * width / 2 + col / 2.
//
// Expects a 1D launch.  Each thread starts at its global thread index and then
// advances by the total number of launched threads until it runs off the end
// of the image, so any grid size covers the whole image.
__global__ void InternalCudaToGreyscaleAndDecimateHalide(
    const uint8_t *color_image, uint8_t *gray_image, uint8_t *decimated_image,
    size_t width, size_t height) {
  const size_t pixel_count = width * height;
  const size_t step = blockDim.x * gridDim.x;

  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < pixel_count;
       i += step) {
    const uint8_t luma = color_image[i * 2];
    gray_image[i] = luma;

    const size_t row = i / width;
    const size_t col = i % width;

    // Only the top left pixel of each 2x2 group survives the decimation, so
    // skip anything with an odd row or an odd column.
    if (((row | col) & 1) != 0) {
      continue;
    }

    decimated_image[(row / 2) * width / 2 + col / 2] = luma;
  }

  // TODO(austin): Figure out how to load contiguous memory reasonably
  // efficiently and max/min over it.

  // TODO(austin): Can we do the threshold here too?  That would be less memory
  // bandwidth consumed...
}
43
// Reduces the 4 pixel values packed in row down to their extremes.  The
// smallest value is returned in .x and the largest in .y.
__forceinline__ __device__ uchar2 minmax(uchar4 row) {
  // Reduce each half of the row first, then merge the two halves.
  const uint8_t low_xy = std::min(row.x, row.y);
  const uint8_t low_zw = std::min(row.z, row.w);
  const uint8_t high_xy = std::max(row.x, row.y);
  const uint8_t high_zw = std::max(row.z, row.w);

  return make_uchar2(std::min(low_xy, low_zw), std::max(high_xy, high_zw));
}
50
// Merges two (min, max) pairs into one.  .x of the result is the smaller of
// the two .x (min) fields, and .y is the larger of the two .y (max) fields.
__forceinline__ __device__ uchar2 minmax(uchar2 val0, uchar2 val1) {
  const unsigned char lowest = std::min(val0.x, val1.x);
  const unsigned char highest = std::max(val0.y, val1.y);
  return make_uchar2(lowest, highest);
}
55
// Converts an (x, y) coordinate into a flat offset for a row-major image which
// has width elements per row.
__forceinline__ __device__ size_t XYToIndex(size_t width, size_t x, size_t y) {
  const size_t row_start = width * y;
  return row_start + x;
}
60
// Computes the min and max pixel value of each 4x4 (16 pixel) block of the
// decimated image.
//
// width and height are the dimensions of the output min/max image, i.e. the
// number of 4x4 blocks in each direction.  decimated_image is therefore read
// with a row stride of width * 4 pixels.
//
// Expects a 2D launch with one thread per output element.  Threads which land
// outside of the output image do nothing.
//
// Each row of a block is loaded as a single uchar4, so the start of every
// block row in decimated_image needs to be 4 byte aligned.
__global__ void InternalBlockMinMax(const uint8_t *decimated_image,
                                    uchar2 *unfiltered_minmax_image,
                                    size_t width, size_t height) {
  const size_t x = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t y = blockIdx.y * blockDim.y + threadIdx.y;

  if (x >= width || y >= height) {
    return;
  }

  const size_t pixels_per_row = width * 4;
  const size_t first_col = x * 4;
  const size_t first_row = y * 4;

  // Start from the top row of the block, then fold the remaining 3 rows in.
  uchar2 extremes = minmax(*reinterpret_cast<const uchar4 *>(
      decimated_image + XYToIndex(pixels_per_row, first_col, first_row)));
  for (int row = 1; row < 4; ++row) {
    const uchar4 pixels = *reinterpret_cast<const uchar4 *>(
        decimated_image +
        XYToIndex(pixels_per_row, first_col, first_row + row));
    extremes = minmax(extremes, minmax(pixels));
  }

  unfiltered_minmax_image[XYToIndex(width, x, y)] = extremes;
}
83
// Widens each min/max entry to cover its 3x3 neighborhood: the output at
// (x, y) is the smallest min and the largest max over the up to 9 entries of
// unfiltered_minmax_image centered on (x, y).  Neighbors which would fall
// outside of the image are ignored.
//
// width and height are the dimensions of both min/max images.
//
// Expects a 2D launch with one thread per output element.  Threads which land
// outside of the image do nothing.
__global__ void InternalBlockFilter(const uchar2 *unfiltered_minmax_image,
                                    uchar2 *minmax_image, size_t width,
                                    size_t height) {
  const size_t x = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t y = blockIdx.y * blockDim.y + threadIdx.y;

  if (x >= width || y >= height) {
    return;
  }

  const ssize_t limit_x = static_cast<ssize_t>(width);
  const ssize_t limit_y = static_cast<ssize_t>(height);

  // Start with min = 255 and max = 0 so that the first entry merged in always
  // replaces both fields.
  uchar2 merged = make_uchar2(255, 0);

#pragma unroll
  for (int dy = -1; dy <= 1; ++dy) {
    const ssize_t neighbor_y = static_cast<ssize_t>(y) + dy;
#pragma unroll
    for (int dx = -1; dx <= 1; ++dx) {
      const ssize_t neighbor_x = static_cast<ssize_t>(x) + dx;

      // Signed coordinates are used so that stepping left of column 0 or
      // above row 0 shows up as a negative number here.
      const bool inside = neighbor_x >= 0 && neighbor_x < limit_x &&
                          neighbor_y >= 0 && neighbor_y < limit_y;
      if (inside) {
        merged = minmax(merged, unfiltered_minmax_image[XYToIndex(
                                    width, neighbor_x, neighbor_y)]);
      }
    }
  }

  minmax_image[XYToIndex(width, x, y)] = merged;
}
121
// Thresholds the decimated image using the filtered per-block min/max values.
//
// width and height are the dimensions of decimated_image and
// thresholded_image.  The min/max entry for pixel (x, y) is read from
// minmax_image at index x / 4 + (y / 4) * width / 4.
//
// For each pixel the output is:
//   127 if (max - min) of its block is below min_white_black_diff,
//   255 if the pixel is greater than min + (max - min) / 2,
//   0   otherwise.
//
// Expects a 1D launch.  Each thread starts at its global thread index and then
// advances by the total number of launched threads until it runs off the end
// of the image, so any grid size covers the whole image.
__global__ void InternalThreshold(const uint8_t *decimated_image,
                                  const uchar2 *minmax_image,
                                  uint8_t *thresholded_image, size_t width,
                                  size_t height, size_t min_white_black_diff) {
  const size_t pixel_count = width * height;
  const size_t step = blockDim.x * gridDim.x;

  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < pixel_count;
       i += step) {
    const size_t x = i % width;
    const size_t y = i / width;

    const uchar2 limits = minmax_image[x / 4 + (y / 4) * width / 4];
    // The subtraction happens on values promoted to int.
    const int spread = limits.y - limits.x;

    // Default to "not enough contrast to decide".
    uint8_t output = 127;
    if (spread >= min_white_black_diff) {
      const uint8_t midpoint = limits.x + spread / 2;
      output = decimated_image[i] > midpoint ? 255 : 0;
    }

    thresholded_image[i] = output;
  }
}
150
151} // namespace
152
// Enqueues the full grayscale -> decimate -> adaptive threshold pipeline on
// stream.
//
// The steps are:
//   1. Write the grayscale image to gray_image and a copy decimated by 2 in
//      each dimension to decimated_image.
//   2. Compute the min/max of each 4x4 block of the decimated image into
//      unfiltered_minmax_image.
//   3. Spread each min/max over its 3x3 neighborhood of blocks into
//      minmax_image.
//   4. Write 127, 0 or 255 for each decimated pixel into thresholded_image
//      based on min_white_black_diff and the midpoint of the block's min/max.
//
// width and height are the dimensions of color_image in pixels and must both
// be multiples of 8.  unfiltered_minmax_image and minmax_image are accessed as
// arrays of uchar2, i.e. 2 bytes per 4x4 block of the decimated image.
//
// An image with a width or height of 0 is a no-op.
void CudaToGreyscaleAndDecimateHalide(
    const uint8_t *color_image, uint8_t *gray_image, uint8_t *decimated_image,
    uint8_t *unfiltered_minmax_image, uint8_t *minmax_image,
    uint8_t *thresholded_image, size_t width, size_t height,
    size_t min_white_black_diff, CudaStream *stream) {
  CHECK((width % 8) == 0);
  CHECK((height % 8) == 0);

  // There is nothing to process for an empty image, and every grid below would
  // have a dimension of 0, which is not a valid launch configuration.
  if (width == 0 || height == 0) {
    return;
  }

  constexpr size_t kThreads = 256;
  {
    // Step one, convert to gray and decimate.
    //
    // Launch roughly a quarter of the threads needed to give each pixel its
    // own thread; the kernel loops until all the pixels are covered.  The
    // division truncates to 0 for small images, so always launch at least 1
    // block.
    size_t blocks = (width * height + kThreads - 1) / kThreads / 4;
    if (blocks == 0) {
      blocks = 1;
    }
    InternalCudaToGreyscaleAndDecimateHalide<<<blocks, kThreads, 0,
                                               stream->get()>>>(
        color_image, gray_image, decimated_image, width, height);
    MaybeCheckAndSynchronize();
  }

  const size_t decimated_width = width / 2;
  const size_t decimated_height = height / 2;

  // Dimensions of the min/max images, 1 entry per 4x4 block of decimated
  // pixels.
  const size_t minmax_width = decimated_width / 4;
  const size_t minmax_height = decimated_height / 4;

  {
    // Step 2, compute a min/max for each block of 4x4 (16) pixels.
    dim3 threads(16, 16, 1);
    dim3 blocks((minmax_width + 15) / 16, (minmax_height + 15) / 16, 1);

    InternalBlockMinMax<<<blocks, threads, 0, stream->get()>>>(
        decimated_image, reinterpret_cast<uchar2 *>(unfiltered_minmax_image),
        minmax_width, minmax_height);
    MaybeCheckAndSynchronize();

    // Step 3, Blur those min/max's a further +- 1 block in each direction using
    // min/max.
    InternalBlockFilter<<<blocks, threads, 0, stream->get()>>>(
        reinterpret_cast<uchar2 *>(unfiltered_minmax_image),
        reinterpret_cast<uchar2 *>(minmax_image), minmax_width, minmax_height);
    MaybeCheckAndSynchronize();
  }

  {
    // Now, write out 127 if the min/max are too close to each other, or 0/255
    // if the pixels are above or below the average of the min/max.
    //
    // As in step one, the kernel loops over the pixels, so clamp the truncated
    // block count to at least 1 to keep the launch valid for small images.
    size_t blocks = (width * height / 4 + kThreads - 1) / kThreads / 4;
    if (blocks == 0) {
      blocks = 1;
    }
    InternalThreshold<<<blocks, kThreads, 0, stream->get()>>>(
        decimated_image, reinterpret_cast<uchar2 *>(minmax_image),
        thresholded_image, decimated_width, decimated_height,
        min_white_black_diff);
    MaybeCheckAndSynchronize();
  }
}
204
205} // namespace apriltag
206} // namespace frc971