blob: 99318e6ea9ae9518f0cd154ad4987158963e8e7a [file] [log] [blame]
Austin Schuhb02d2332023-09-05 22:18:35 -07001#include <iostream>
2
3#include "Halide.h"
4
// Aborts the process with a diagnostic on stderr when the condition `x` is
// false.
//
// `message` has to be a string literal because it is spliced into the printf
// format string.  It may contain printf conversions, which consume the optional
// trailing arguments.  The stringified condition is always printed last.
//
// The do/while wrapper makes an invocation behave like a single statement, so
// it is safe to use in an unbraced if/else.
#define CHECK(x, message, ...)                            \
  do {                                                    \
    if (!(x)) {                                           \
      fprintf(stderr,                                     \
              "assertion failed: " message ": %s\n",      \
              ##__VA_ARGS__, #x);                         \
      abort();                                            \
    }                                                     \
  } while (false)
13
14// This is a Halide "generator". This means it is a binary which generates
15// ahead-of-time optimized functions as directed by command-line arguments.
16// https://halide-lang.org/tutorials/tutorial_lesson_15_generators.html has an
17// introduction to much of the magic in this file.
18namespace frc971 {
19namespace orin {
20namespace {
21
// Constrains the first three dimensions of `buffer_parameter` to a packed,
// interleaved layout:
//   dim 0: stride 3,        extent `cols`, min 0
//   dim 1: stride cols * 3, extent `rows`, min 0
//   dim 2: stride 1,        extent 3,      min 0
//
// `T` only needs to provide dim(int) returning something with set_stride(),
// set_extent(), and set_min().
template <typename T>
void SetRowMajor(T *buffer_parameter, int cols, int rows) {
  // Number of interleaved values stored for each (dim 0, dim 1) position.
  constexpr int kChannels = 3;

  // Applies a stride, an extent, and a min of 0 to one dimension, in that
  // order.
  const auto constrain = [buffer_parameter](int dimension, int stride,
                                            int extent) {
    buffer_parameter->dim(dimension).set_stride(stride);
    buffer_parameter->dim(dimension).set_extent(extent);
    buffer_parameter->dim(dimension).set_min(0);
  };

  constrain(0, kChannels, cols);
  constrain(1, cols * kChannels, rows);
  constrain(2, 1, kChannels);
}
36} // namespace
37
38// Takes an image with y in one plane with a provided stride, and cbcr in
39// another with a provided stride and makes a ycbcr output image.
40class YCbCr : public Halide::Generator<YCbCr> {
41 public:
42 GeneratorParam<int> cols{"cols", 0};
43 GeneratorParam<int> rows{"rows", 0};
44 GeneratorParam<int> ystride{"ystride", 0};
45 GeneratorParam<int> cbcrstride{"cbcrstride", 0};
46
47 Input<Buffer<uint8_t, 2>> input_y{"y"};
48 Input<Buffer<uint8_t, 3>> input_cbcr{"cbcr"};
49 Output<Buffer<uint8_t, 3>> output{"output"};
50
51 Var col{"col"}, row{"row"}, channel{"channel"};
52
53 // Everything is indexed as col, row, channel.
54 void generate() {
55 CHECK(cols > 0, "Must specify a cols");
56 CHECK(rows > 0, "Must specify a rows");
57
58 input_y.dim(0).set_stride(1);
59 input_y.dim(0).set_extent(cols);
60 input_y.dim(0).set_min(0);
61
62 input_y.dim(1).set_stride(ystride);
63 input_y.dim(1).set_extent(rows);
64 input_y.dim(1).set_min(0);
65
66 input_cbcr.dim(0).set_stride(2);
67 input_cbcr.dim(0).set_extent(cols);
68 input_cbcr.dim(0).set_min(0);
69
70 input_cbcr.dim(1).set_stride(cbcrstride);
71 input_cbcr.dim(1).set_extent(rows);
72 input_cbcr.dim(1).set_min(0);
73
74 input_cbcr.dim(2).set_stride(1);
75 input_cbcr.dim(2).set_extent(2);
76 input_cbcr.dim(2).set_min(0);
77
78 output(col, row, channel) =
79 Halide::select(channel == 0, input_y(col, row),
80 Halide::select(channel == 1, input_cbcr(col, row, 0),
81 input_cbcr(col, row, 1)));
82
83 output.reorder(channel, col, row);
84 output.unroll(channel);
85
86 output.vectorize(col, 8);
87 output.unroll(col, 4);
88
89 SetRowMajor(&output, cols, rows);
90 }
91};
92
93} // namespace orin
94} // namespace frc971
95
// Registers frc971::orin::YCbCr with Halide under the generator name "ycbcr",
// which is presumably the name used to select it on the generator binary's
// command line.
HALIDE_REGISTER_GENERATOR(frc971::orin::YCbCr, ycbcr)