Shard image thresholding
Sharding thresholding across a pool of 4 threads brings it down to ~9ms per frame.
Change-Id: If5b0b51105d3d9e8a2435b1f077e46eeb9f3e94a
diff --git a/aos/vision/blob/threshold.cc b/aos/vision/blob/threshold.cc
index 36dcafe..46221b5 100644
--- a/aos/vision/blob/threshold.cc
+++ b/aos/vision/blob/threshold.cc
@@ -30,6 +30,7 @@
// The per-channel (YUYV) values in the current chunk.
uint8_t chunk_channels[2 * kChunkSize];
memcpy(&chunk_channels[0], current_row + x * kChunkSize * 2, 2 * kChunkSize);
+ __builtin_prefetch(current_row + (x + 1) * kChunkSize * 2);
for (int i = 0; i < kChunkSize; ++i) {
if ((chunk_channels[i * 2] > value) != in_range) {
@@ -51,5 +52,74 @@
return RangeImage(0, std::move(result));
}
+FastYuyvYPooledThresholder::FastYuyvYPooledThresholder() {
+  states_.fill(ThreadState::kWaitingForInputData);  // Every worker starts idle.
+  for (int n = 0; n < kThreads; ++n) {
+    threads_[n] = std::thread(&FastYuyvYPooledThresholder::RunThread, this, n);
+  }
+}
+
+FastYuyvYPooledThresholder::~FastYuyvYPooledThresholder() {
+  {
+    // Set the flag under the lock so a waiting worker cannot miss it.
+    std::unique_lock<std::mutex> shutdown_lock(mutex_);
+    quit_ = true;
+    condition_variable_.notify_all();
+  }
+  // Workers return from RunThread() once they see quit_ while idle.
+  for (std::thread &worker : threads_) worker.join();
+}
+
+// Runs one frame through the worker pool and returns the merged ranges.
+RangeImage FastYuyvYPooledThresholder::Threshold(ImageFormat fmt,
+                                                 const char *data,
+                                                 uint8_t value) {
+  // Stage the job; workers read these only after being moved to kProcessing.
+  input_format_ = fmt;
+  input_data_ = data;
+  input_value_ = value;
+  {
+    std::unique_lock<std::mutex> job_lock(mutex_);
+    // Start every worker, then sleep until each has reported itself idle.
+    states_.fill(ThreadState::kProcessing);
+    condition_variable_.notify_all();
+    condition_variable_.wait(job_lock, [this]() { return AllThreadsDone(); });
+  }
+  // Concatenate the per-worker rows in worker-index order.
+  std::vector<std::vector<ImageRange>> merged_rows;
+  merged_rows.reserve(fmt.h);
+  for (RangeImage &shard : outputs_) {
+    merged_rows.insert(merged_rows.end(), shard.begin(), shard.end());
+  }
+  return RangeImage(0, std::move(merged_rows));
+}
+
+void FastYuyvYPooledThresholder::RunThread(int i) {
+  for (;;) {
+    {
+      // Sleep until this worker is given a job or the pool is shutting down.
+      std::unique_lock<std::mutex> wait_lock(mutex_);
+      condition_variable_.wait(wait_lock, [this, i]() {
+        return quit_ || states_[i] != ThreadState::kWaitingForInputData;
+      });
+      // Still idle after waking means quit_ was set, so exit the thread.
+      if (states_[i] == ThreadState::kWaitingForInputData) return;
+    }
+    // Worker i handles the i-th of kThreads equal-height bands of rows.
+    ImageFormat shard_format = input_format_;
+    CHECK_EQ(shard_format.h % kThreads, 0);
+    shard_format.h /= kThreads;
+    const char *const shard_data =
+        input_data_ + shard_format.w * 2 * shard_format.h * i;
+    outputs_[i] = FastYuyvYThreshold(shard_format, shard_data, input_value_);
+    {
+      // Report completion; Threshold() is waiting on this notification.
+      std::unique_lock<std::mutex> done_lock(mutex_);
+      states_[i] = ThreadState::kWaitingForInputData;
+      condition_variable_.notify_all();
+    }
+  }
+}
+
} // namespace vision
} // namespace aos
diff --git a/aos/vision/blob/threshold.h b/aos/vision/blob/threshold.h
index 9891722..8251b3a 100644
--- a/aos/vision/blob/threshold.h
+++ b/aos/vision/blob/threshold.h
@@ -1,6 +1,10 @@
#ifndef AOS_VISION_BLOB_THRESHOLD_H_
#define AOS_VISION_BLOB_THRESHOLD_H_
+#include <condition_variable>
+#include <mutex>
+#include <thread>
+
#include "aos/vision/blob/range_image.h"
#include "aos/vision/image/image_types.h"
@@ -80,6 +84,58 @@
// This is implemented via some tricky bit shuffling that goes fast.
RangeImage FastYuyvYThreshold(ImageFormat fmt, const char *data, uint8_t value);
+// Thresholds images by sharding each frame across a fixed pool of threads.
+class FastYuyvYPooledThresholder {
+ public:
+  // Size of the pool; each frame is split evenly between this many workers.
+  static constexpr int kThreads = 4;
+
+  // Starts the worker threads.
+  FastYuyvYPooledThresholder();
+  // Signals the workers to quit, then joins them.
+  ~FastYuyvYPooledThresholder();
+
+  // Thresholds data on all workers, blocks until they finish, and returns the
+  // merged result. fmt.h must be a multiple of kThreads.
+  RangeImage Threshold(ImageFormat fmt, const char *data, uint8_t value);
+
+ private:
+  enum class ThreadState {
+    // Idle. A worker puts itself here after finishing its share of a frame.
+    kWaitingForInputData,
+    // Busy. Threshold() puts every worker here once the input is stored.
+    kProcessing,
+  };
+
+  // Body of worker thread number index.
+  void RunThread(int index);
+
+  // Whether every worker is idle. Called with mutex_ held.
+  bool AllThreadsDone() const {
+    for (int i = 0; i < kThreads; ++i) {
+      if (states_[i] != ThreadState::kWaitingForInputData) return false;
+    }
+    return true;
+  }
+
+  std::array<std::thread, kThreads> threads_;
+  // Guards states_ and quit_; paired with condition_variable_.
+  std::mutex mutex_;
+  // Notified whenever states_ or quit_ changes.
+  std::condition_variable condition_variable_;
+  // Set by the destructor to make idle workers exit.
+  bool quit_ = false;
+
+  // Per-worker state, indexed like threads_.
+  std::array<ThreadState, kThreads> states_;
+
+  // Not guarded by mutex_ directly: the states_ handshake orders the accesses.
+  ImageFormat input_format_;
+  const char *input_data_;
+  uint8_t input_value_;
+  std::array<RangeImage, kThreads> outputs_;
+};
+
} // namespace vision
} // namespace aos