Maxwell Henderson | 5e7696f | 2023-01-29 12:14:27 -0800 | [diff] [blame] | 1 | #include "Halide.h" |
| 2 | |
// Generator-time assertion. When `x` is false this prints
//   assertion failed: <message, formatted with any extra args>: <text of x>
// to stderr and aborts, so a misconfigured generator run fails loudly.
// `message` must be a string literal because it is pasted between other
// literals to form the printf format string.
#define CHECK(x, message, ...)                                   \
  do {                                                           \
    if (x) {                                                     \
      break;                                                     \
    }                                                            \
    fprintf(stderr, "assertion failed: " message ": %s\n",       \
            ##__VA_ARGS__, #x);                                  \
    abort();                                                     \
  } while (0)
| 11 | |
| 12 | // This is a Halide "generator". This means it is a binary which generates |
| 13 | // ahead-of-time optimized functions as directed by command-line arguments. |
| 14 | // https://halide-lang.org/tutorials/tutorial_lesson_15_generators.html has an |
| 15 | // introduction to much of the magic in this file. |
| 16 | |
Stephan Pleines | f63bde8 | 2024-01-13 15:59:33 -0800 | [diff] [blame] | 17 | namespace frc971::vision { |
Maxwell Henderson | 5e7696f | 2023-01-29 12:14:27 -0800 | [diff] [blame] | 18 | namespace { |
| 19 | |
| 20 | // Returns a function implementating a 1-dimensional gaussian blur convolution. |
| 21 | Halide::Func GenerateBlur(std::string name, Halide::Func in, int col_step, |
| 22 | int row_step, int radius, std::vector<float> kernel, |
| 23 | Halide::Var col, Halide::Var row) { |
| 24 | Halide::Expr expr = kernel[0] * in(col, row); |
| 25 | for (int i = 1; i <= radius; ++i) { |
| 26 | expr += kernel[0] * (in(col - i * col_step, row - i * row_step) + |
| 27 | in(col + i * col_step, row + i * row_step)); |
| 28 | } |
| 29 | Halide::Func func(name); |
| 30 | func(col, row) = expr; |
| 31 | return func; |
| 32 | } |
| 33 | |
// Constrains a two-dimensional buffer parameter to a dense, zero-based,
// row-major layout of `rows` x `cols` elements. Dimension 0 (columns) is
// contiguous, and dimension 1 (rows) advances by `cols` elements.
//
// T is any type whose dim(i) exposes set_stride/set_extent/set_min, such as a
// generator Output buffer. Dimension 0 is configured before dimension 1, and
// each dimension is set in the order stride, extent, min.
template <typename T>
void SetRowMajor(T *buffer_parameter, int cols, int rows) {
  // One entry per dimension: {index, stride, extent}. Every minimum is 0.
  const int layout[2][3] = {
      {0, 1, cols},
      {1, cols, rows},
  };
  for (const auto &[index, stride, extent] : layout) {
    buffer_parameter->dim(index).set_stride(stride);
    buffer_parameter->dim(index).set_extent(extent);
    buffer_parameter->dim(index).set_min(0);
  }
}
| 43 | |
| 44 | } // namespace |
| 45 | |
| 46 | class DecimateGenerator : public Halide::Generator<DecimateGenerator> { |
| 47 | public: |
| 48 | GeneratorParam<int> cols{"cols", 0}; |
| 49 | GeneratorParam<int> rows{"rows", 0}; |
| 50 | |
| 51 | Input<Buffer<uint8_t>> input{"input", 3}; |
| 52 | Output<Buffer<uint8_t>> output{"output", 2}; |
| 53 | Output<Buffer<uint8_t>> decimated_output{"decimated_output", 2}; |
| 54 | |
| 55 | Var col{"col"}, row{"row"}; |
| 56 | |
| 57 | void generate() { |
| 58 | CHECK(cols > 0, "Must specify a cols"); |
| 59 | CHECK(rows > 0, "Must specify a rows"); |
| 60 | |
| 61 | input.dim(0).set_stride(2); |
| 62 | input.dim(0).set_extent(cols); |
| 63 | input.dim(0).set_min(0); |
| 64 | |
| 65 | input.dim(1).set_stride(cols * 2); |
| 66 | input.dim(1).set_extent(rows); |
| 67 | input.dim(1).set_min(0); |
| 68 | |
| 69 | input.dim(2).set_stride(1); |
| 70 | input.dim(2).set_extent(2); |
| 71 | input.dim(2).set_min(0); |
| 72 | |
| 73 | output(col, row) = input(col, row, 0); |
| 74 | decimated_output(col, row) = output(col * 2, row * 2); |
| 75 | |
| 76 | decimated_output.compute_at(output, col); |
| 77 | |
| 78 | decimated_output.vectorize(col, 16); |
| 79 | |
| 80 | SetRowMajor(&output, cols, rows); |
| 81 | |
| 82 | SetRowMajor(&decimated_output, cols / 2, rows / 2); |
| 83 | } |
| 84 | }; |
| 85 | |
| 86 | class ThresholdGenerator : public Halide::Generator<ThresholdGenerator> { |
| 87 | public: |
| 88 | GeneratorParam<int> rows{"rows", 0}; |
| 89 | GeneratorParam<int> cols{"cols", 0}; |
| 90 | |
| 91 | Input<Buffer<uint8_t>> input{"input", 2}; |
| 92 | Output<Buffer<uint8_t>> output{"output", 2}; |
| 93 | |
| 94 | Var x{"x"}, y{"y"}; |
| 95 | |
| 96 | Func threshold{"threshold"}, threshold_max{"threshold_max"}, |
| 97 | threshold_min{"threshold_min"}, |
| 98 | convoluted_threshold_max{"convoluted_threshold_max"}, |
| 99 | convoluted_threshold_min{"convoluted_threshold_min"}; |
| 100 | |
| 101 | void generate() { |
| 102 | CHECK(cols > 0, "Columns must be more than 0"); |
| 103 | CHECK(rows > 0, "Rows must be more than 0"); |
| 104 | |
| 105 | const int tile_size = 4; |
| 106 | |
| 107 | RDom r(0, tile_size, 0, tile_size); |
| 108 | |
| 109 | threshold_max(x, y) = |
| 110 | maximum(input(r.x + x * tile_size, r.y + y * tile_size)); |
| 111 | threshold_min(x, y) = |
| 112 | minimum(input(r.x + x * tile_size, r.y + y * tile_size)); |
| 113 | |
| 114 | RDom r_conv(-1, 3, -1, 3); |
| 115 | |
| 116 | convoluted_threshold_max(x, y) = |
| 117 | maximum(threshold_max(clamp(x + r_conv.x, 0, cols / tile_size - 1), |
| 118 | clamp(y + r_conv.y, 0, rows / tile_size - 1))); |
| 119 | |
| 120 | convoluted_threshold_min(x, y) = |
| 121 | minimum(threshold_min(clamp(x + r_conv.x, 0, cols / tile_size - 1), |
| 122 | clamp(y + r_conv.y, 0, rows / tile_size - 1))); |
| 123 | |
| 124 | threshold(x, y) = |
| 125 | convoluted_threshold_min(x, y) + |
| 126 | (convoluted_threshold_max(x, y) - convoluted_threshold_min(x, y)) / 2; |
| 127 | |
| 128 | output(x, y) = |
| 129 | select(convoluted_threshold_max(x / tile_size, y / tile_size) - |
| 130 | convoluted_threshold_min(x / tile_size, y / tile_size) < |
| 131 | 5, |
| 132 | Expr((uint8_t)(127)), |
| 133 | select(input(x, y) > threshold(x / tile_size, y / tile_size), |
| 134 | Expr((uint8_t)(255)), Expr((uint8_t)(0)))); |
| 135 | |
| 136 | SetRowMajor(&output, cols, rows); |
| 137 | |
| 138 | Var xi, yi; |
| 139 | |
| 140 | output.compute_root().tile(x, y, xi, yi, tile_size, tile_size); |
| 141 | threshold.compute_root(); |
| 142 | convoluted_threshold_min.compute_root(); |
| 143 | convoluted_threshold_max.compute_root(); |
| 144 | threshold_min.compute_root(); |
| 145 | threshold_max.compute_root(); |
| 146 | } |
| 147 | }; |
| 148 | |
Stephan Pleines | f63bde8 | 2024-01-13 15:59:33 -0800 | [diff] [blame] | 149 | } // namespace frc971::vision |
Maxwell Henderson | 5e7696f | 2023-01-29 12:14:27 -0800 | [diff] [blame] | 150 | |
| 151 | // TODO(austin): Combine the functions and optimize for device/host and all that |
| 152 | // jazz. |
| 153 | HALIDE_REGISTER_GENERATOR(frc971::vision::DecimateGenerator, decimate_generator) |
| 154 | HALIDE_REGISTER_GENERATOR(frc971::vision::ThresholdGenerator, |
| 155 | threshold_generator) |