Teach ShmEventLoop how to validate some multithreading use cases
This helps catch bugs in callers who do these limited kinds of
multithreading.
Change-Id: Ie04790b2c46f0401430ed4c18a2d10845329623b
Signed-off-by: Austin Schuh <austin.schuh@bluerivertech.com>
diff --git a/aos/events/event_loop.h b/aos/events/event_loop.h
index d23314e..7cb6a5a 100644
--- a/aos/events/event_loop.h
+++ b/aos/events/event_loop.h
@@ -7,6 +7,7 @@
#include <string>
#include <string_view>
+#include "absl/container/btree_set.h"
#include "aos/configuration.h"
#include "aos/configuration_generated.h"
#include "aos/events/channel_preallocated_allocator.h"
@@ -20,8 +21,6 @@
#include "aos/time/time.h"
#include "aos/util/phased_loop.h"
#include "aos/uuid.h"
-
-#include "absl/container/btree_set.h"
#include "flatbuffers/flatbuffers.h"
#include "glog/logging.h"
@@ -115,6 +114,7 @@
protected:
EventLoop *event_loop() { return event_loop_; }
+ const EventLoop *event_loop() const { return event_loop_; }
Context context_;
@@ -190,6 +190,7 @@
protected:
EventLoop *event_loop() { return event_loop_; }
+ const EventLoop *event_loop() const { return event_loop_; }
monotonic_clock::time_point monotonic_sent_time_ = monotonic_clock::min_time;
realtime_clock::time_point realtime_sent_time_ = realtime_clock::min_time;
@@ -473,8 +474,13 @@
Ftrace ftrace_;
};
+// Note, it is supported to create only:
+// multiple fetchers, and (one sender or one watcher) per <name, type>
+// tuple.
class EventLoop {
public:
+ // Holds configuration by reference for the lifetime of this object. It may
+ // never be mutated externally in any way.
EventLoop(const Configuration *configuration);
virtual ~EventLoop();
@@ -495,10 +501,6 @@
return GetChannel<T>(channel_name) != nullptr;
}
- // Note, it is supported to create:
- // multiple fetchers, and (one sender or one watcher) per <name, type>
- // tuple.
-
// Makes a class that will always fetch the most recent value
// sent to the provided channel.
template <typename T>
@@ -596,7 +598,7 @@
// TODO(austin): OnExit for cleanup.
- // Threadsafe.
+ // May be safely called from any thread.
bool is_running() const { return is_running_.load(); }
// Sets the scheduler priority to run the event loop at. This may not be
diff --git a/aos/events/shm_event_loop.cc b/aos/events/shm_event_loop.cc
index 63e1cb9..cee1dfa 100644
--- a/aos/events/shm_event_loop.cc
+++ b/aos/events/shm_event_loop.cc
@@ -477,9 +477,13 @@
simple_shm_fetcher_.RetrieveData();
}
- ~ShmFetcher() { context_.data = nullptr; }
+ ~ShmFetcher() override {
+ shm_event_loop()->CheckCurrentThread();
+ context_.data = nullptr;
+ }
std::pair<bool, monotonic_clock::time_point> DoFetchNext() override {
+ shm_event_loop()->CheckCurrentThread();
if (simple_shm_fetcher_.FetchNext()) {
context_ = simple_shm_fetcher_.context();
return std::make_pair(true, monotonic_clock::now());
@@ -488,6 +492,7 @@
}
std::pair<bool, monotonic_clock::time_point> DoFetch() override {
+ shm_event_loop()->CheckCurrentThread();
if (simple_shm_fetcher_.Fetch()) {
context_ = simple_shm_fetcher_.context();
return std::make_pair(true, monotonic_clock::now());
@@ -500,6 +505,10 @@
}
private:
+ const ShmEventLoop *shm_event_loop() const {
+ return static_cast<const ShmEventLoop *>(event_loop());
+ }
+
SimpleShmFetcher simple_shm_fetcher_;
};
@@ -517,7 +526,7 @@
channel)),
wake_upper_(lockless_queue_memory_.queue()) {}
- ~ShmSender() override {}
+ ~ShmSender() override { shm_event_loop()->CheckCurrentThread(); }
static ipc_lib::LocklessQueueSender VerifySender(
std::optional<ipc_lib::LocklessQueueSender> sender,
@@ -530,13 +539,20 @@
<< ", too many senders.";
}
- void *data() override { return lockless_queue_sender_.Data(); }
- size_t size() override { return lockless_queue_sender_.size(); }
+ void *data() override {
+ shm_event_loop()->CheckCurrentThread();
+ return lockless_queue_sender_.Data();
+ }
+ size_t size() override {
+ shm_event_loop()->CheckCurrentThread();
+ return lockless_queue_sender_.size();
+ }
bool DoSend(size_t length,
aos::monotonic_clock::time_point monotonic_remote_time,
aos::realtime_clock::time_point realtime_remote_time,
uint32_t remote_queue_index,
const UUID &remote_boot_uuid) override {
+ shm_event_loop()->CheckCurrentThread();
CHECK_LE(length, static_cast<size_t>(channel()->max_size()))
<< ": Sent too big a message on "
<< configuration::CleanedChannelToString(channel());
@@ -556,6 +572,7 @@
aos::realtime_clock::time_point realtime_remote_time,
uint32_t remote_queue_index,
const UUID &remote_boot_uuid) override {
+ shm_event_loop()->CheckCurrentThread();
CHECK_LE(length, static_cast<size_t>(channel()->max_size()))
<< ": Sent too big a message on "
<< configuration::CleanedChannelToString(channel());
@@ -574,9 +591,16 @@
return lockless_queue_memory_.GetMutableSharedMemory();
}
- int buffer_index() override { return lockless_queue_sender_.buffer_index(); }
+ int buffer_index() override {
+ shm_event_loop()->CheckCurrentThread();
+ return lockless_queue_sender_.buffer_index();
+ }
private:
+ const ShmEventLoop *shm_event_loop() const {
+ return static_cast<const ShmEventLoop *>(event_loop());
+ }
+
MMappedQueue lockless_queue_memory_;
ipc_lib::LocklessQueueSender lockless_queue_sender_;
ipc_lib::LocklessQueueWakeUpper wake_upper_;
@@ -599,9 +623,13 @@
}
}
- ~ShmWatcherState() override { event_loop_->RemoveEvent(&event_); }
+ ~ShmWatcherState() override {
+ event_loop_->CheckCurrentThread();
+ event_loop_->RemoveEvent(&event_);
+ }
void Startup(EventLoop *event_loop) override {
+ event_loop_->CheckCurrentThread();
simple_shm_fetcher_.PointAtNextQueueIndex();
CHECK(RegisterWakeup(event_loop->priority()));
}
@@ -666,6 +694,7 @@
}
~ShmTimerHandler() {
+ shm_event_loop_->CheckCurrentThread();
Disable();
shm_event_loop_->epoll_.DeleteFd(timerfd_.fd());
}
@@ -705,6 +734,7 @@
void Setup(monotonic_clock::time_point base,
monotonic_clock::duration repeat_offset) override {
+ shm_event_loop_->CheckCurrentThread();
if (event_.valid()) {
shm_event_loop_->RemoveEvent(&event_);
}
@@ -717,6 +747,7 @@
}
void Disable() override {
+ shm_event_loop_->CheckCurrentThread();
shm_event_loop_->RemoveEvent(&event_);
timerfd_.Disable();
disabled_ = true;
@@ -766,6 +797,7 @@
}
~ShmPhasedLoopHandler() override {
+ shm_event_loop_->CheckCurrentThread();
shm_event_loop_->epoll_.DeleteFd(timerfd_.fd());
shm_event_loop_->RemoveEvent(&event_);
}
@@ -773,6 +805,7 @@
private:
// Reschedules the timer.
void Schedule(monotonic_clock::time_point sleep_time) override {
+ shm_event_loop_->CheckCurrentThread();
if (event_.valid()) {
shm_event_loop_->RemoveEvent(&event_);
}
@@ -792,6 +825,7 @@
::std::unique_ptr<RawFetcher> ShmEventLoop::MakeRawFetcher(
const Channel *channel) {
+ CheckCurrentThread();
if (!configuration::ChannelIsReadableOnNode(channel, node())) {
LOG(FATAL) << "Channel { \"name\": \"" << channel->name()->string_view()
<< "\", \"type\": \"" << channel->type()->string_view()
@@ -805,6 +839,7 @@
::std::unique_ptr<RawSender> ShmEventLoop::MakeRawSender(
const Channel *channel) {
+ CheckCurrentThread();
TakeSender(channel);
return ::std::unique_ptr<RawSender>(new ShmSender(shm_base_, this, channel));
@@ -813,6 +848,7 @@
void ShmEventLoop::MakeRawWatcher(
const Channel *channel,
std::function<void(const Context &context, const void *message)> watcher) {
+ CheckCurrentThread();
TakeWatcher(channel);
NewWatcher(::std::unique_ptr<WatcherState>(
@@ -822,6 +858,7 @@
void ShmEventLoop::MakeRawNoArgWatcher(
const Channel *channel,
std::function<void(const Context &context)> watcher) {
+ CheckCurrentThread();
TakeWatcher(channel);
NewWatcher(::std::unique_ptr<WatcherState>(new ShmWatcherState(
@@ -831,6 +868,7 @@
}
TimerHandler *ShmEventLoop::AddTimer(::std::function<void()> callback) {
+ CheckCurrentThread();
return NewTimer(::std::unique_ptr<TimerHandler>(
new ShmTimerHandler(this, ::std::move(callback))));
}
@@ -839,14 +877,28 @@
::std::function<void(int)> callback,
const monotonic_clock::duration interval,
const monotonic_clock::duration offset) {
+ CheckCurrentThread();
return NewPhasedLoop(::std::unique_ptr<PhasedLoopHandler>(
new ShmPhasedLoopHandler(this, ::std::move(callback), interval, offset)));
}
void ShmEventLoop::OnRun(::std::function<void()> on_run) {
+ CheckCurrentThread();
on_run_.push_back(::std::move(on_run));
}
+void ShmEventLoop::CheckCurrentThread() const {
+ if (__builtin_expect(check_mutex_ != nullptr, false)) {
+ CHECK(check_mutex_->is_locked())
+ << ": The configured mutex is not locked while calling a "
+ "ShmEventLoop function";
+ }
+ if (__builtin_expect(!!check_tid_, false)) {
+ CHECK_EQ(syscall(SYS_gettid), *check_tid_)
+ << ": Being called from the wrong thread";
+ }
+}
+
// This is a bit tricky because watchers can generate new events at any time (as
// long as it's in the past). We want to check the watchers at least once before
// declaring there are no events to handle, and we want to check them again if
@@ -1021,6 +1073,7 @@
};
void ShmEventLoop::Run() {
+ CheckCurrentThread();
SignalHandler::global()->Register(this);
if (watchers_.size() > 0) {
@@ -1100,6 +1153,7 @@
void ShmEventLoop::Exit() { epoll_.Quit(); }
ShmEventLoop::~ShmEventLoop() {
+ CheckCurrentThread();
// Force everything with a registered fd with epoll to be destroyed now.
timers_.clear();
phased_loops_.clear();
@@ -1109,6 +1163,7 @@
}
void ShmEventLoop::SetRuntimeRealtimePriority(int priority) {
+ CheckCurrentThread();
if (is_running()) {
LOG(FATAL) << "Cannot set realtime priority while running.";
}
@@ -1116,6 +1171,7 @@
}
void ShmEventLoop::SetRuntimeAffinity(const cpu_set_t &cpuset) {
+ CheckCurrentThread();
if (is_running()) {
LOG(FATAL) << "Cannot set affinity while running.";
}
@@ -1123,18 +1179,21 @@
}
void ShmEventLoop::set_name(const std::string_view name) {
+ CheckCurrentThread();
name_ = std::string(name);
UpdateTimingReport();
}
absl::Span<const char> ShmEventLoop::GetWatcherSharedMemory(
const Channel *channel) {
+ CheckCurrentThread();
ShmWatcherState *const watcher_state =
static_cast<ShmWatcherState *>(GetWatcherState(channel));
return watcher_state->GetSharedMemory();
}
int ShmEventLoop::NumberBuffers(const Channel *channel) {
+ CheckCurrentThread();
return MakeQueueConfiguration(
channel, chrono::ceil<chrono::seconds>(chrono::nanoseconds(
configuration()->channel_storage_duration())))
@@ -1143,14 +1202,19 @@
absl::Span<char> ShmEventLoop::GetShmSenderSharedMemory(
const aos::RawSender *sender) const {
+ CheckCurrentThread();
return static_cast<const ShmSender *>(sender)->GetSharedMemory();
}
absl::Span<const char> ShmEventLoop::GetShmFetcherPrivateMemory(
const aos::RawFetcher *fetcher) const {
+ CheckCurrentThread();
return static_cast<const ShmFetcher *>(fetcher)->GetPrivateMemory();
}
-pid_t ShmEventLoop::GetTid() { return syscall(SYS_gettid); }
+pid_t ShmEventLoop::GetTid() {
+ CheckCurrentThread();
+ return syscall(SYS_gettid);
+}
} // namespace aos
diff --git a/aos/events/shm_event_loop.h b/aos/events/shm_event_loop.h
index 845857c..3245f11 100644
--- a/aos/events/shm_event_loop.h
+++ b/aos/events/shm_event_loop.h
@@ -4,11 +4,11 @@
#include <vector>
#include "absl/types/span.h"
-
#include "aos/events/epoll.h"
#include "aos/events/event_loop.h"
#include "aos/events/event_loop_generated.h"
#include "aos/ipc_lib/signalfd.h"
+#include "aos/stl_mutex/stl_mutex.h"
DECLARE_string(application_name);
DECLARE_string(shm_base);
@@ -92,6 +92,7 @@
// Returns the local mapping of the shared memory used by the provided Sender.
template <typename T>
absl::Span<char> GetSenderSharedMemory(aos::Sender<T> *sender) const {
+ CheckCurrentThread();
return GetShmSenderSharedMemory(GetRawSender(sender));
}
@@ -103,11 +104,30 @@
template <typename T>
absl::Span<const char> GetFetcherPrivateMemory(
aos::Fetcher<T> *fetcher) const {
+ CheckCurrentThread();
return GetShmFetcherPrivateMemory(GetRawFetcher(fetcher));
}
int NumberBuffers(const Channel *channel) override;
+ // All public-facing APIs will verify this mutex is held when they are called.
+ // For normal use with everything in a single thread, this is unnecessary.
+ //
+ // This is helpful as a safety check when using a ShmEventLoop with external
+ // synchronization across multiple threads. It will NOT reliably catch race
+ // conditions, but if you have a race condition triggered repeatedly it'll
+ // probably catch it eventually.
+ void CheckForMutex(aos::stl_mutex *check_mutex) {
+ check_mutex_ = check_mutex;
+ }
+
+ // All public-facing APIs will verify they are called in this thread.
+ // For normal use with the whole program in a single thread, this is
+ // unnecessary. It's helpful as a safety check for programs with multiple
+ // threads, where the EventLoop should only be interacted with from a single
+ // one.
+ void LockToThread() { check_tid_ = GetTid(); }
+
private:
friend class shm_event_loop_internal::ShmWatcherState;
friend class shm_event_loop_internal::ShmTimerHandler;
@@ -126,6 +146,8 @@
return result;
}
+ void CheckCurrentThread() const;
+
void HandleEvent();
// Returns the TID of the event loop.
@@ -151,6 +173,9 @@
std::string name_;
const Node *const node_;
+ aos::stl_mutex *check_mutex_ = nullptr;
+ std::optional<pid_t> check_tid_;
+
internal::EPoll epoll_;
// Only set during Run().
diff --git a/aos/stl_mutex/stl_mutex.h b/aos/stl_mutex/stl_mutex.h
index 86a5988..e3a930e 100644
--- a/aos/stl_mutex/stl_mutex.h
+++ b/aos/stl_mutex/stl_mutex.h
@@ -61,6 +61,11 @@
bool owner_died() const { return owner_died_; }
void consistent() { owner_died_ = false; }
+ // Returns whether this mutex is locked by the current thread. This is very
+ // hard to use reliably, please think very carefully before using it for
+ // anything beyond probabilistic assertion checks.
+ bool is_locked() const { return mutex_islocked(&native_handle_); }
+
private:
aos_mutex native_handle_;
@@ -82,7 +87,7 @@
constexpr stl_recursive_mutex() {}
void lock() {
- if (mutex_islocked(mutex_.native_handle())) {
+ if (mutex_.is_locked()) {
CHECK(!owner_died());
++recursive_locks_;
} else {
@@ -95,7 +100,7 @@
}
}
bool try_lock() {
- if (mutex_islocked(mutex_.native_handle())) {
+ if (mutex_.is_locked()) {
CHECK(!owner_died());
++recursive_locks_;
return true;