blob: 374647405631b90fcda0419cd2e84cd14c0a983f [file] [log] [blame]
Maxwell Henderson5e7696f2023-01-29 12:14:27 -08001#include "Halide.h"
2
// Aborts the process with a diagnostic on stderr when `x` is false.
//
// `message` must be a string literal: it is pasted directly into the fprintf
// format string, so any extra arguments fill its conversion specifiers. The
// stringified condition is printed after the message. `x` is evaluated exactly
// once.
#define CHECK(x, message, ...)                                            \
  do {                                                                    \
    if (x) {                                                              \
      break;                                                              \
    }                                                                     \
    fprintf(stderr, "assertion failed: " message ": %s\n", ##__VA_ARGS__, \
            #x);                                                          \
    abort();                                                              \
  } while (0)
11
// This file defines Halide "generators". A generator is compiled into a binary
// which, as directed by command-line arguments, emits ahead-of-time optimized
// functions. https://halide-lang.org/tutorials/tutorial_lesson_15_generators.html
// introduces much of the machinery used in this file.
16
Stephan Pleinesf63bde82024-01-13 15:59:33 -080017namespace frc971::vision {
Maxwell Henderson5e7696f2023-01-29 12:14:27 -080018namespace {
19
20// Returns a function implementating a 1-dimensional gaussian blur convolution.
21Halide::Func GenerateBlur(std::string name, Halide::Func in, int col_step,
22 int row_step, int radius, std::vector<float> kernel,
23 Halide::Var col, Halide::Var row) {
24 Halide::Expr expr = kernel[0] * in(col, row);
25 for (int i = 1; i <= radius; ++i) {
26 expr += kernel[0] * (in(col - i * col_step, row - i * row_step) +
27 in(col + i * col_step, row + i * row_step));
28 }
29 Halide::Func func(name);
30 func(col, row) = expr;
31 return func;
32}
33
// Constrains a 2-D buffer parameter (anything exposing dim(i) with
// set_stride/set_extent/set_min, e.g. an Input/Output<Buffer>) to a dense
// row-major layout whose dimensions both start at 0:
//   dim(0): stride 1,    extent `cols`
//   dim(1): stride `cols`, extent `rows`
template <typename T>
void SetRowMajor(T *buffer_parameter, int cols, int rows) {
  // Pins dimension `index` to [0, extent) with the given element stride.
  const auto constrain = [buffer_parameter](int index, int stride,
                                            int extent) {
    buffer_parameter->dim(index).set_stride(stride);
    buffer_parameter->dim(index).set_extent(extent);
    buffer_parameter->dim(index).set_min(0);
  };
  constrain(0, 1, cols);     // Neighboring columns are adjacent in memory.
  constrain(1, cols, rows);  // Consecutive rows start `cols` elements apart.
}
43
44} // namespace
45
46class DecimateGenerator : public Halide::Generator<DecimateGenerator> {
47 public:
48 GeneratorParam<int> cols{"cols", 0};
49 GeneratorParam<int> rows{"rows", 0};
50
51 Input<Buffer<uint8_t>> input{"input", 3};
52 Output<Buffer<uint8_t>> output{"output", 2};
53 Output<Buffer<uint8_t>> decimated_output{"decimated_output", 2};
54
55 Var col{"col"}, row{"row"};
56
57 void generate() {
58 CHECK(cols > 0, "Must specify a cols");
59 CHECK(rows > 0, "Must specify a rows");
60
61 input.dim(0).set_stride(2);
62 input.dim(0).set_extent(cols);
63 input.dim(0).set_min(0);
64
65 input.dim(1).set_stride(cols * 2);
66 input.dim(1).set_extent(rows);
67 input.dim(1).set_min(0);
68
69 input.dim(2).set_stride(1);
70 input.dim(2).set_extent(2);
71 input.dim(2).set_min(0);
72
73 output(col, row) = input(col, row, 0);
74 decimated_output(col, row) = output(col * 2, row * 2);
75
76 decimated_output.compute_at(output, col);
77
78 decimated_output.vectorize(col, 16);
79
80 SetRowMajor(&output, cols, rows);
81
82 SetRowMajor(&decimated_output, cols / 2, rows / 2);
83 }
84};
85
86class ThresholdGenerator : public Halide::Generator<ThresholdGenerator> {
87 public:
88 GeneratorParam<int> rows{"rows", 0};
89 GeneratorParam<int> cols{"cols", 0};
90
91 Input<Buffer<uint8_t>> input{"input", 2};
92 Output<Buffer<uint8_t>> output{"output", 2};
93
94 Var x{"x"}, y{"y"};
95
96 Func threshold{"threshold"}, threshold_max{"threshold_max"},
97 threshold_min{"threshold_min"},
98 convoluted_threshold_max{"convoluted_threshold_max"},
99 convoluted_threshold_min{"convoluted_threshold_min"};
100
101 void generate() {
102 CHECK(cols > 0, "Columns must be more than 0");
103 CHECK(rows > 0, "Rows must be more than 0");
104
105 const int tile_size = 4;
106
107 RDom r(0, tile_size, 0, tile_size);
108
109 threshold_max(x, y) =
110 maximum(input(r.x + x * tile_size, r.y + y * tile_size));
111 threshold_min(x, y) =
112 minimum(input(r.x + x * tile_size, r.y + y * tile_size));
113
114 RDom r_conv(-1, 3, -1, 3);
115
116 convoluted_threshold_max(x, y) =
117 maximum(threshold_max(clamp(x + r_conv.x, 0, cols / tile_size - 1),
118 clamp(y + r_conv.y, 0, rows / tile_size - 1)));
119
120 convoluted_threshold_min(x, y) =
121 minimum(threshold_min(clamp(x + r_conv.x, 0, cols / tile_size - 1),
122 clamp(y + r_conv.y, 0, rows / tile_size - 1)));
123
124 threshold(x, y) =
125 convoluted_threshold_min(x, y) +
126 (convoluted_threshold_max(x, y) - convoluted_threshold_min(x, y)) / 2;
127
128 output(x, y) =
129 select(convoluted_threshold_max(x / tile_size, y / tile_size) -
130 convoluted_threshold_min(x / tile_size, y / tile_size) <
131 5,
132 Expr((uint8_t)(127)),
133 select(input(x, y) > threshold(x / tile_size, y / tile_size),
134 Expr((uint8_t)(255)), Expr((uint8_t)(0))));
135
136 SetRowMajor(&output, cols, rows);
137
138 Var xi, yi;
139
140 output.compute_root().tile(x, y, xi, yi, tile_size, tile_size);
141 threshold.compute_root();
142 convoluted_threshold_min.compute_root();
143 convoluted_threshold_max.compute_root();
144 threshold_min.compute_root();
145 threshold_max.compute_root();
146 }
147};
148
Stephan Pleinesf63bde82024-01-13 15:59:33 -0800149} // namespace frc971::vision
Maxwell Henderson5e7696f2023-01-29 12:14:27 -0800150
151// TODO(austin): Combine the functions and optimize for device/host and all that
152// jazz.
153HALIDE_REGISTER_GENERATOR(frc971::vision::DecimateGenerator, decimate_generator)
154HALIDE_REGISTER_GENERATOR(frc971::vision::ThresholdGenerator,
155 threshold_generator)