blob: 374647405631b90fcda0419cd2e84cd14c0a983f [file] [log] [blame]
#include "Halide.h"
// Generation-time assertion used by the generators below.
//
// If the condition `x` evaluates to false, prints
//   "assertion failed: <message>: <text of x>\n"
// to stderr and calls abort(). `message` must be a string literal: it is
// concatenated into the fprintf format string, so it may contain printf
// conversion specifiers whose arguments are passed as the trailing `...`
// (they are consumed before the stringified condition, which fills the final
// "%s"). `##__VA_ARGS__` is the GCC/Clang extension that drops the preceding
// comma when no extra arguments are given.
#define CHECK(x, message, ...) \
do { \
if (!(x)) { \
fprintf(stderr, "assertion failed: " message ": %s\n", ##__VA_ARGS__, \
#x); \
abort(); \
} \
} while (0)
// This is a Halide "generator". This means it is a binary which generates
// ahead-of-time optimized functions as directed by command-line arguments.
// https://halide-lang.org/tutorials/tutorial_lesson_15_generators.html has an
// introduction to much of the magic in this file.
namespace frc971::vision {
namespace {
// Returns a Func named `name` implementing a symmetric 1-dimensional
// convolution of `in` (for example, one pass of a separable gaussian blur):
//
//   func(col, row) = kernel[0] * in(col, row) +
//       sum_{i=1..radius} kernel[i] * (in(col - i * col_step, row - i * row_step) +
//                                      in(col + i * col_step, row + i * row_step))
//
// (col_step, row_step) picks the direction of the convolution, e.g. (1, 0)
// walks along the col axis and (0, 1) along the row axis.
//
// `kernel` holds the centre tap followed by the one-sided taps, so it must
// contain at least radius + 1 entries; std::vector::at makes a too-short
// kernel fail while generating instead of reading past the vector.
Halide::Func GenerateBlur(const std::string &name, Halide::Func in,
                          int col_step, int row_step, int radius,
                          const std::vector<float> &kernel, Halide::Var col,
                          Halide::Var row) {
  Halide::Expr expr = kernel.at(0) * in(col, row);
  for (int i = 1; i <= radius; ++i) {
    // Tap i weights both neighbours at distance i (the kernel is symmetric).
    expr += kernel.at(i) * (in(col - i * col_step, row - i * row_step) +
                            in(col + i * col_step, row + i * row_step));
  }
  Halide::Func func(name);
  func(col, row) = expr;
  return func;
}
// Pins a 2-D buffer parameter to a dense, zero-based, row-major layout:
// dimension 0 (columns) is contiguous with `cols` elements, and dimension 1
// (rows) has `rows` entries spaced `cols` elements apart.
//
// T is any type exposing dim(i) with chainable set_stride / set_extent /
// set_min setters (e.g. a generator Input/Output buffer).
template <typename T>
void SetRowMajor(T *buffer_parameter, int cols, int rows) {
  auto &buf = *buffer_parameter;
  // Innermost dimension: neighbouring columns are adjacent in memory.
  buf.dim(0).set_stride(1).set_extent(cols).set_min(0);
  // Outer dimension: each row starts one full row of columns later.
  buf.dim(1).set_stride(cols).set_extent(rows).set_min(0);
}
} // namespace
// Generator that copies index 0 of the last dimension of a two-way,
// byte-interleaved 8-bit image into `output`, and a copy of that plane
// keeping only every other column and row into `decimated_output`.
class DecimateGenerator : public Halide::Generator<DecimateGenerator> {
 public:
  // Size, in pixels, of `input` and `output`; fixed at generation time.
  GeneratorParam<int> cols{"cols", 0};
  GeneratorParam<int> rows{"rows", 0};
  // Indexed as (col, row, k) with k in [0, 2); the memory layout is pinned
  // in generate().
  Input<Buffer<uint8_t>> input{"input", 3};
  // cols x rows, row-major: output(c, r) == input(c, r, 0).
  Output<Buffer<uint8_t>> output{"output", 2};
  // (cols / 2) x (rows / 2), row-major: decimated_output(c, r) ==
  // output(2 * c, 2 * r).
  Output<Buffer<uint8_t>> decimated_output{"decimated_output", 2};
  Var col{"col"}, row{"row"};

  void generate() {
    CHECK(cols > 0, "Must specify a cols");
    CHECK(rows > 0, "Must specify a rows");

    // Packed layout: the two interleaved bytes of a pixel are adjacent
    // (stride 1 in dim 2), pixels are 2 bytes apart, and rows are cols * 2
    // bytes apart. All dimensions start at 0.
    input.dim(0).set_stride(2);
    input.dim(0).set_extent(cols);
    input.dim(0).set_min(0);
    input.dim(1).set_stride(cols * 2);
    input.dim(1).set_extent(rows);
    input.dim(1).set_min(0);
    input.dim(2).set_stride(1);
    input.dim(2).set_extent(2);
    input.dim(2).set_min(0);

    output(col, row) = input(col, row, 0);
    decimated_output(col, row) = output(col * 2, row * 2);

    // decimated_output is a pipeline output and a *consumer* of `output`, so
    // it stays computed at root (the default). Halide only allows compute_at
    // inside the loops of a Func's consumers, and pipeline outputs must be
    // compute_root, so it cannot be scheduled compute_at(output, ...).
    // NOTE(review): with the default tail strategy this vectorization
    // presumably requires cols / 2 >= 16 -- confirm for small images.
    decimated_output.vectorize(col, 16);

    SetRowMajor(&output, cols, rows);
    SetRowMajor(&decimated_output, cols / 2, rows / 2);
  }
};
// Generator for an adaptive (local-contrast) binarization of an 8-bit,
// two-dimensional image.
//
// The image is cut into tile_size x tile_size tiles. Each tile's min/max
// pixel value is computed and then widened to the min/max over its 3x3
// neighbourhood of tiles (neighbour indices are clamped to the range of full
// tiles). Every output pixel is then:
//   127 if its tile's widened (max - min) is below min_contrast,
//   255 if the input pixel is greater than min + (max - min) / 2,
//   0   otherwise.
class ThresholdGenerator : public Halide::Generator<ThresholdGenerator> {
 public:
  // Image size in pixels, fixed at generation time. Each must be at least
  // the tile size so there is at least one full tile along each axis.
  GeneratorParam<int> rows{"rows", 0};
  GeneratorParam<int> cols{"cols", 0};
  // Indexed as (col, row); `output` is made row-major in generate().
  Input<Buffer<uint8_t>> input{"input", 2};
  Output<Buffer<uint8_t>> output{"output", 2};
  Var x{"x"}, y{"y"};
  Func threshold{"threshold"}, threshold_max{"threshold_max"},
      threshold_min{"threshold_min"},
      convoluted_threshold_max{"convoluted_threshold_max"},
      convoluted_threshold_min{"convoluted_threshold_min"};

  void generate() {
    CHECK(cols > 0, "Columns must be more than 0");
    CHECK(rows > 0, "Rows must be more than 0");
    // Side length, in pixels, of the square tiles statistics are taken over.
    const int tile_size = 4;
    // Tiles whose neighbourhood spans a smaller range of values than this are
    // treated as having no usable contrast and produce 127.
    const int min_contrast = 5;
    // With less than one full tile along an axis, cols / tile_size - 1 (or
    // rows / tile_size - 1) is negative, so there is no valid tile to clamp
    // to and tile reads fall outside `input`. The tile_size x tile_size
    // tiling of `output` also needs at least that much extent.
    CHECK(cols >= tile_size, "Columns must be at least the tile size (%d)",
          tile_size);
    CHECK(rows >= tile_size, "Rows must be at least the tile size (%d)",
          tile_size);

    // Extrema of each tile_size x tile_size block of pixels.
    RDom r(0, tile_size, 0, tile_size);
    threshold_max(x, y) =
        maximum(input(r.x + x * tile_size, r.y + y * tile_size));
    threshold_min(x, y) =
        minimum(input(r.x + x * tile_size, r.y + y * tile_size));

    // Widen each tile's extrema over its 3x3 tile neighbourhood, clamping
    // neighbour indices to the full tiles that exist.
    RDom r_conv(-1, 3, -1, 3);
    convoluted_threshold_max(x, y) =
        maximum(threshold_max(clamp(x + r_conv.x, 0, cols / tile_size - 1),
                              clamp(y + r_conv.y, 0, rows / tile_size - 1)));
    convoluted_threshold_min(x, y) =
        minimum(threshold_min(clamp(x + r_conv.x, 0, cols / tile_size - 1),
                              clamp(y + r_conv.y, 0, rows / tile_size - 1)));

    // Midpoint of the widened range. Written as min + (max - min) / 2 so the
    // uint8 arithmetic cannot overflow (max >= min by construction).
    threshold(x, y) =
        convoluted_threshold_min(x, y) +
        (convoluted_threshold_max(x, y) - convoluted_threshold_min(x, y)) / 2;

    // Every pixel uses the statistics of the tile that contains it.
    Expr tile_x = x / tile_size;
    Expr tile_y = y / tile_size;
    Expr local_max = convoluted_threshold_max(tile_x, tile_y);
    Expr local_min = convoluted_threshold_min(tile_x, tile_y);
    output(x, y) = select(local_max - local_min < min_contrast,
                          Expr(static_cast<uint8_t>(127)),
                          select(input(x, y) > threshold(tile_x, tile_y),
                                 Expr(static_cast<uint8_t>(255)),
                                 Expr(static_cast<uint8_t>(0))));

    SetRowMajor(&output, cols, rows);
    Var xi, yi;
    output.compute_root().tile(x, y, xi, yi, tile_size, tile_size);
    threshold.compute_root();
    convoluted_threshold_min.compute_root();
    convoluted_threshold_max.compute_root();
    threshold_min.compute_root();
    threshold_max.compute_root();
  }
};
} // namespace frc971::vision
// TODO(austin): Combine the functions and optimize for device/host and all that
// jazz.
// Register each generator class with Halide's generator registry under the
// name given as the second argument; that name is what selects the generator
// when the generator binary is run.
HALIDE_REGISTER_GENERATOR(frc971::vision::DecimateGenerator, decimate_generator)
HALIDE_REGISTER_GENERATOR(frc971::vision::ThresholdGenerator,
threshold_generator)