blob: 4ac389a0cecb43219588ad3abd49cdeb39562cf6 [file] [log] [blame]
Maxwell Henderson5e7696f2023-01-29 12:14:27 -08001#include "Halide.h"
2
// Aborts the process if condition `x` does not hold. `message` must be a string
// literal; it is used as a printf-style format string (filled in from any extra
// arguments) and is followed by the stringified condition.
#define CHECK(x, message, ...)                                               \
  do {                                                                       \
    if (x) {                                                                 \
      break; /* Condition holds: nothing to report. */                       \
    }                                                                        \
    fprintf(stderr, "assertion failed: " message ": %s\n", ##__VA_ARGS__,    \
            #x);                                                             \
    abort();                                                                 \
  } while (0)
11
// This is a Halide "generator": a binary that produces ahead-of-time optimized
// functions as directed by its command-line arguments.
// https://halide-lang.org/tutorials/tutorial_lesson_15_generators.html has an
// introduction to much of the magic in this file.
16
17namespace frc971 {
18namespace vision {
19namespace {
20
21// Returns a function implementating a 1-dimensional gaussian blur convolution.
22Halide::Func GenerateBlur(std::string name, Halide::Func in, int col_step,
23 int row_step, int radius, std::vector<float> kernel,
24 Halide::Var col, Halide::Var row) {
25 Halide::Expr expr = kernel[0] * in(col, row);
26 for (int i = 1; i <= radius; ++i) {
27 expr += kernel[0] * (in(col - i * col_step, row - i * row_step) +
28 in(col + i * col_step, row + i * row_step));
29 }
30 Halide::Func func(name);
31 func(col, row) = expr;
32 return func;
33}
34
35template <typename T>
36void SetRowMajor(T *buffer_parameter, int cols, int rows) {
37 buffer_parameter->dim(0).set_stride(1);
38 buffer_parameter->dim(0).set_extent(cols);
39 buffer_parameter->dim(0).set_min(0);
40 buffer_parameter->dim(1).set_stride(cols);
41 buffer_parameter->dim(1).set_extent(rows);
42 buffer_parameter->dim(1).set_min(0);
43}
44
45} // namespace
46
47class DecimateGenerator : public Halide::Generator<DecimateGenerator> {
48 public:
49 GeneratorParam<int> cols{"cols", 0};
50 GeneratorParam<int> rows{"rows", 0};
51
52 Input<Buffer<uint8_t>> input{"input", 3};
53 Output<Buffer<uint8_t>> output{"output", 2};
54 Output<Buffer<uint8_t>> decimated_output{"decimated_output", 2};
55
56 Var col{"col"}, row{"row"};
57
58 void generate() {
59 CHECK(cols > 0, "Must specify a cols");
60 CHECK(rows > 0, "Must specify a rows");
61
62 input.dim(0).set_stride(2);
63 input.dim(0).set_extent(cols);
64 input.dim(0).set_min(0);
65
66 input.dim(1).set_stride(cols * 2);
67 input.dim(1).set_extent(rows);
68 input.dim(1).set_min(0);
69
70 input.dim(2).set_stride(1);
71 input.dim(2).set_extent(2);
72 input.dim(2).set_min(0);
73
74 output(col, row) = input(col, row, 0);
75 decimated_output(col, row) = output(col * 2, row * 2);
76
77 decimated_output.compute_at(output, col);
78
79 decimated_output.vectorize(col, 16);
80
81 SetRowMajor(&output, cols, rows);
82
83 SetRowMajor(&decimated_output, cols / 2, rows / 2);
84 }
85};
86
87class ThresholdGenerator : public Halide::Generator<ThresholdGenerator> {
88 public:
89 GeneratorParam<int> rows{"rows", 0};
90 GeneratorParam<int> cols{"cols", 0};
91
92 Input<Buffer<uint8_t>> input{"input", 2};
93 Output<Buffer<uint8_t>> output{"output", 2};
94
95 Var x{"x"}, y{"y"};
96
97 Func threshold{"threshold"}, threshold_max{"threshold_max"},
98 threshold_min{"threshold_min"},
99 convoluted_threshold_max{"convoluted_threshold_max"},
100 convoluted_threshold_min{"convoluted_threshold_min"};
101
102 void generate() {
103 CHECK(cols > 0, "Columns must be more than 0");
104 CHECK(rows > 0, "Rows must be more than 0");
105
106 const int tile_size = 4;
107
108 RDom r(0, tile_size, 0, tile_size);
109
110 threshold_max(x, y) =
111 maximum(input(r.x + x * tile_size, r.y + y * tile_size));
112 threshold_min(x, y) =
113 minimum(input(r.x + x * tile_size, r.y + y * tile_size));
114
115 RDom r_conv(-1, 3, -1, 3);
116
117 convoluted_threshold_max(x, y) =
118 maximum(threshold_max(clamp(x + r_conv.x, 0, cols / tile_size - 1),
119 clamp(y + r_conv.y, 0, rows / tile_size - 1)));
120
121 convoluted_threshold_min(x, y) =
122 minimum(threshold_min(clamp(x + r_conv.x, 0, cols / tile_size - 1),
123 clamp(y + r_conv.y, 0, rows / tile_size - 1)));
124
125 threshold(x, y) =
126 convoluted_threshold_min(x, y) +
127 (convoluted_threshold_max(x, y) - convoluted_threshold_min(x, y)) / 2;
128
129 output(x, y) =
130 select(convoluted_threshold_max(x / tile_size, y / tile_size) -
131 convoluted_threshold_min(x / tile_size, y / tile_size) <
132 5,
133 Expr((uint8_t)(127)),
134 select(input(x, y) > threshold(x / tile_size, y / tile_size),
135 Expr((uint8_t)(255)), Expr((uint8_t)(0))));
136
137 SetRowMajor(&output, cols, rows);
138
139 Var xi, yi;
140
141 output.compute_root().tile(x, y, xi, yi, tile_size, tile_size);
142 threshold.compute_root();
143 convoluted_threshold_min.compute_root();
144 convoluted_threshold_max.compute_root();
145 threshold_min.compute_root();
146 threshold_max.compute_root();
147 }
148};
149
150} // namespace vision
151} // namespace frc971
152
153// TODO(austin): Combine the functions and optimize for device/host and all that
154// jazz.
155HALIDE_REGISTER_GENERATOR(frc971::vision::DecimateGenerator, decimate_generator)
156HALIDE_REGISTER_GENERATOR(frc971::vision::ThresholdGenerator,
157 threshold_generator)