Teach ShmEventLoop how to validate some multithreading use cases

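This adds optional runtime checks (a mutex which must be held, or a single
thread which must be used) to ShmEventLoop. They help catch bugs in callers
who use a ShmEventLoop in these limited multithreaded ways.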

Change-Id: Ie04790b2c46f0401430ed4c18a2d10845329623b
Signed-off-by: Austin Schuh <austin.schuh@bluerivertech.com>
diff --git a/aos/events/event_loop.h b/aos/events/event_loop.h
index d23314e..7cb6a5a 100644
--- a/aos/events/event_loop.h
+++ b/aos/events/event_loop.h
@@ -7,6 +7,7 @@
 #include <string>
 #include <string_view>
 
+#include "absl/container/btree_set.h"
 #include "aos/configuration.h"
 #include "aos/configuration_generated.h"
 #include "aos/events/channel_preallocated_allocator.h"
@@ -20,8 +21,6 @@
 #include "aos/time/time.h"
 #include "aos/util/phased_loop.h"
 #include "aos/uuid.h"
-
-#include "absl/container/btree_set.h"
 #include "flatbuffers/flatbuffers.h"
 #include "glog/logging.h"
 
@@ -115,6 +114,7 @@
 
  protected:
   EventLoop *event_loop() { return event_loop_; }
+  const EventLoop *event_loop() const { return event_loop_; }
 
   Context context_;
 
@@ -190,6 +190,7 @@
 
  protected:
   EventLoop *event_loop() { return event_loop_; }
+  const EventLoop *event_loop() const { return event_loop_; }
 
   monotonic_clock::time_point monotonic_sent_time_ = monotonic_clock::min_time;
   realtime_clock::time_point realtime_sent_time_ = realtime_clock::min_time;
@@ -473,8 +474,13 @@
   Ftrace ftrace_;
 };
 
+// Note: the only supported combination of objects to create is:
+//   multiple fetchers, and (one sender or one watcher) per <name, type>
+//   tuple.
 class EventLoop {
  public:
+  // Holds configuration by reference for the lifetime of this object. The
+  // configuration must never be mutated externally in any way.
   EventLoop(const Configuration *configuration);
 
   virtual ~EventLoop();
@@ -495,10 +501,6 @@
     return GetChannel<T>(channel_name) != nullptr;
   }
 
-  // Note, it is supported to create:
-  //   multiple fetchers, and (one sender or one watcher) per <name, type>
-  //   tuple.
-
   // Makes a class that will always fetch the most recent value
   // sent to the provided channel.
   template <typename T>
@@ -596,7 +598,7 @@
 
   // TODO(austin): OnExit for cleanup.
 
-  // Threadsafe.
+  // May be safely called from any thread.
   bool is_running() const { return is_running_.load(); }
 
   // Sets the scheduler priority to run the event loop at.  This may not be
diff --git a/aos/events/shm_event_loop.cc b/aos/events/shm_event_loop.cc
index 63e1cb9..cee1dfa 100644
--- a/aos/events/shm_event_loop.cc
+++ b/aos/events/shm_event_loop.cc
@@ -477,9 +477,13 @@
     simple_shm_fetcher_.RetrieveData();
   }
 
-  ~ShmFetcher() { context_.data = nullptr; }
+  ~ShmFetcher() override {
+    shm_event_loop()->CheckCurrentThread();
+    context_.data = nullptr;
+  }
 
   std::pair<bool, monotonic_clock::time_point> DoFetchNext() override {
+    shm_event_loop()->CheckCurrentThread();
     if (simple_shm_fetcher_.FetchNext()) {
       context_ = simple_shm_fetcher_.context();
       return std::make_pair(true, monotonic_clock::now());
@@ -488,6 +492,7 @@
   }
 
   std::pair<bool, monotonic_clock::time_point> DoFetch() override {
+    shm_event_loop()->CheckCurrentThread();
     if (simple_shm_fetcher_.Fetch()) {
       context_ = simple_shm_fetcher_.context();
       return std::make_pair(true, monotonic_clock::now());
@@ -500,6 +505,10 @@
   }
 
  private:
+  const ShmEventLoop *shm_event_loop() const {
+    return static_cast<const ShmEventLoop *>(event_loop());
+  }
+
   SimpleShmFetcher simple_shm_fetcher_;
 };
 
@@ -517,7 +526,7 @@
             channel)),
         wake_upper_(lockless_queue_memory_.queue()) {}
 
-  ~ShmSender() override {}
+  ~ShmSender() override { shm_event_loop()->CheckCurrentThread(); }
 
   static ipc_lib::LocklessQueueSender VerifySender(
       std::optional<ipc_lib::LocklessQueueSender> sender,
@@ -530,13 +539,20 @@
                << ", too many senders.";
   }
 
-  void *data() override { return lockless_queue_sender_.Data(); }
-  size_t size() override { return lockless_queue_sender_.size(); }
+  void *data() override {
+    shm_event_loop()->CheckCurrentThread();
+    return lockless_queue_sender_.Data();
+  }
+  size_t size() override {
+    shm_event_loop()->CheckCurrentThread();
+    return lockless_queue_sender_.size();
+  }
   bool DoSend(size_t length,
               aos::monotonic_clock::time_point monotonic_remote_time,
               aos::realtime_clock::time_point realtime_remote_time,
               uint32_t remote_queue_index,
               const UUID &remote_boot_uuid) override {
+    shm_event_loop()->CheckCurrentThread();
     CHECK_LE(length, static_cast<size_t>(channel()->max_size()))
         << ": Sent too big a message on "
         << configuration::CleanedChannelToString(channel());
@@ -556,6 +572,7 @@
               aos::realtime_clock::time_point realtime_remote_time,
               uint32_t remote_queue_index,
               const UUID &remote_boot_uuid) override {
+    shm_event_loop()->CheckCurrentThread();
     CHECK_LE(length, static_cast<size_t>(channel()->max_size()))
         << ": Sent too big a message on "
         << configuration::CleanedChannelToString(channel());
@@ -574,9 +591,16 @@
     return lockless_queue_memory_.GetMutableSharedMemory();
   }
 
-  int buffer_index() override { return lockless_queue_sender_.buffer_index(); }
+  int buffer_index() override {
+    shm_event_loop()->CheckCurrentThread();
+    return lockless_queue_sender_.buffer_index();
+  }
 
  private:
+  const ShmEventLoop *shm_event_loop() const {
+    return static_cast<const ShmEventLoop *>(event_loop());
+  }
+
   MMappedQueue lockless_queue_memory_;
   ipc_lib::LocklessQueueSender lockless_queue_sender_;
   ipc_lib::LocklessQueueWakeUpper wake_upper_;
@@ -599,9 +623,13 @@
     }
   }
 
-  ~ShmWatcherState() override { event_loop_->RemoveEvent(&event_); }
+  ~ShmWatcherState() override {
+    event_loop_->CheckCurrentThread();
+    event_loop_->RemoveEvent(&event_);
+  }
 
   void Startup(EventLoop *event_loop) override {
+    event_loop_->CheckCurrentThread();
     simple_shm_fetcher_.PointAtNextQueueIndex();
     CHECK(RegisterWakeup(event_loop->priority()));
   }
@@ -666,6 +694,7 @@
   }
 
   ~ShmTimerHandler() {
+    shm_event_loop_->CheckCurrentThread();
     Disable();
     shm_event_loop_->epoll_.DeleteFd(timerfd_.fd());
   }
@@ -705,6 +734,7 @@
 
   void Setup(monotonic_clock::time_point base,
              monotonic_clock::duration repeat_offset) override {
+    shm_event_loop_->CheckCurrentThread();
     if (event_.valid()) {
       shm_event_loop_->RemoveEvent(&event_);
     }
@@ -717,6 +747,7 @@
   }
 
   void Disable() override {
+    shm_event_loop_->CheckCurrentThread();
     shm_event_loop_->RemoveEvent(&event_);
     timerfd_.Disable();
     disabled_ = true;
@@ -766,6 +797,7 @@
   }
 
   ~ShmPhasedLoopHandler() override {
+    shm_event_loop_->CheckCurrentThread();
     shm_event_loop_->epoll_.DeleteFd(timerfd_.fd());
     shm_event_loop_->RemoveEvent(&event_);
   }
@@ -773,6 +805,7 @@
  private:
   // Reschedules the timer.
   void Schedule(monotonic_clock::time_point sleep_time) override {
+    shm_event_loop_->CheckCurrentThread();
     if (event_.valid()) {
       shm_event_loop_->RemoveEvent(&event_);
     }
@@ -792,6 +825,7 @@
 
 ::std::unique_ptr<RawFetcher> ShmEventLoop::MakeRawFetcher(
     const Channel *channel) {
+  CheckCurrentThread();
   if (!configuration::ChannelIsReadableOnNode(channel, node())) {
     LOG(FATAL) << "Channel { \"name\": \"" << channel->name()->string_view()
                << "\", \"type\": \"" << channel->type()->string_view()
@@ -805,6 +839,7 @@
 
 ::std::unique_ptr<RawSender> ShmEventLoop::MakeRawSender(
     const Channel *channel) {
+  CheckCurrentThread();
   TakeSender(channel);
 
   return ::std::unique_ptr<RawSender>(new ShmSender(shm_base_, this, channel));
@@ -813,6 +848,7 @@
 void ShmEventLoop::MakeRawWatcher(
     const Channel *channel,
     std::function<void(const Context &context, const void *message)> watcher) {
+  CheckCurrentThread();
   TakeWatcher(channel);
 
   NewWatcher(::std::unique_ptr<WatcherState>(
@@ -822,6 +858,7 @@
 void ShmEventLoop::MakeRawNoArgWatcher(
     const Channel *channel,
     std::function<void(const Context &context)> watcher) {
+  CheckCurrentThread();
   TakeWatcher(channel);
 
   NewWatcher(::std::unique_ptr<WatcherState>(new ShmWatcherState(
@@ -831,6 +868,7 @@
 }
 
 TimerHandler *ShmEventLoop::AddTimer(::std::function<void()> callback) {
+  CheckCurrentThread();
   return NewTimer(::std::unique_ptr<TimerHandler>(
       new ShmTimerHandler(this, ::std::move(callback))));
 }
@@ -839,14 +877,28 @@
     ::std::function<void(int)> callback,
     const monotonic_clock::duration interval,
     const monotonic_clock::duration offset) {
+  CheckCurrentThread();
   return NewPhasedLoop(::std::unique_ptr<PhasedLoopHandler>(
       new ShmPhasedLoopHandler(this, ::std::move(callback), interval, offset)));
 }
 
 void ShmEventLoop::OnRun(::std::function<void()> on_run) {
+  CheckCurrentThread();
   on_run_.push_back(::std::move(on_run));
 }
 
+void ShmEventLoop::CheckCurrentThread() const {
+  if (__builtin_expect(check_mutex_ != nullptr, false)) {
+    CHECK(check_mutex_->is_locked())
+        << ": The configured mutex is not locked while calling a "
+           "ShmEventLoop function";
+  }
+  if (__builtin_expect(!!check_tid_, false)) {
+    CHECK_EQ(syscall(SYS_gettid), *check_tid_)
+        << ": Being called from the wrong thread";
+  }
+}
+
 // This is a bit tricky because watchers can generate new events at any time (as
 // long as it's in the past). We want to check the watchers at least once before
 // declaring there are no events to handle, and we want to check them again if
@@ -1021,6 +1073,7 @@
 };
 
 void ShmEventLoop::Run() {
+  CheckCurrentThread();
   SignalHandler::global()->Register(this);
 
   if (watchers_.size() > 0) {
@@ -1100,6 +1153,7 @@
 void ShmEventLoop::Exit() { epoll_.Quit(); }
 
 ShmEventLoop::~ShmEventLoop() {
+  CheckCurrentThread();
   // Force everything with a registered fd with epoll to be destroyed now.
   timers_.clear();
   phased_loops_.clear();
@@ -1109,6 +1163,7 @@
 }
 
 void ShmEventLoop::SetRuntimeRealtimePriority(int priority) {
+  CheckCurrentThread();
   if (is_running()) {
     LOG(FATAL) << "Cannot set realtime priority while running.";
   }
@@ -1116,6 +1171,7 @@
 }
 
 void ShmEventLoop::SetRuntimeAffinity(const cpu_set_t &cpuset) {
+  CheckCurrentThread();
   if (is_running()) {
     LOG(FATAL) << "Cannot set affinity while running.";
   }
@@ -1123,18 +1179,21 @@
 }
 
 void ShmEventLoop::set_name(const std::string_view name) {
+  CheckCurrentThread();
   name_ = std::string(name);
   UpdateTimingReport();
 }
 
 absl::Span<const char> ShmEventLoop::GetWatcherSharedMemory(
     const Channel *channel) {
+  CheckCurrentThread();
   ShmWatcherState *const watcher_state =
       static_cast<ShmWatcherState *>(GetWatcherState(channel));
   return watcher_state->GetSharedMemory();
 }
 
 int ShmEventLoop::NumberBuffers(const Channel *channel) {
+  CheckCurrentThread();
   return MakeQueueConfiguration(
              channel, chrono::ceil<chrono::seconds>(chrono::nanoseconds(
                           configuration()->channel_storage_duration())))
@@ -1143,14 +1202,19 @@
 
 absl::Span<char> ShmEventLoop::GetShmSenderSharedMemory(
     const aos::RawSender *sender) const {
+  CheckCurrentThread();
   return static_cast<const ShmSender *>(sender)->GetSharedMemory();
 }
 
 absl::Span<const char> ShmEventLoop::GetShmFetcherPrivateMemory(
     const aos::RawFetcher *fetcher) const {
+  CheckCurrentThread();
   return static_cast<const ShmFetcher *>(fetcher)->GetPrivateMemory();
 }
 
-pid_t ShmEventLoop::GetTid() { return syscall(SYS_gettid); }
+pid_t ShmEventLoop::GetTid() {
+  CheckCurrentThread();
+  return syscall(SYS_gettid);
+}
 
 }  // namespace aos
diff --git a/aos/events/shm_event_loop.h b/aos/events/shm_event_loop.h
index 845857c..3245f11 100644
--- a/aos/events/shm_event_loop.h
+++ b/aos/events/shm_event_loop.h
@@ -4,11 +4,11 @@
 #include <vector>
 
 #include "absl/types/span.h"
-
 #include "aos/events/epoll.h"
 #include "aos/events/event_loop.h"
 #include "aos/events/event_loop_generated.h"
 #include "aos/ipc_lib/signalfd.h"
+#include "aos/stl_mutex/stl_mutex.h"
 
 DECLARE_string(application_name);
 DECLARE_string(shm_base);
@@ -92,6 +92,7 @@
   // Returns the local mapping of the shared memory used by the provided Sender.
   template <typename T>
   absl::Span<char> GetSenderSharedMemory(aos::Sender<T> *sender) const {
+    CheckCurrentThread();
     return GetShmSenderSharedMemory(GetRawSender(sender));
   }
 
@@ -103,11 +104,30 @@
   template <typename T>
   absl::Span<const char> GetFetcherPrivateMemory(
       aos::Fetcher<T> *fetcher) const {
+    CheckCurrentThread();
     return GetShmFetcherPrivateMemory(GetRawFetcher(fetcher));
   }
 
   int NumberBuffers(const Channel *channel) override;
 
+  // All public-facing APIs will verify this mutex is held when they are called.
+  // For normal use with everything in a single thread, this is unnecessary.
+  //
+  // This is helpful as a safety check when using a ShmEventLoop with external
+  // synchronization across multiple threads. It will NOT reliably catch race
+  // conditions, but if you have a race condition triggered repeatedly it'll
+  // probably catch it eventually.
+  void CheckForMutex(aos::stl_mutex *check_mutex) {
+    check_mutex_ = check_mutex;
+  }
+
+  // All public-facing APIs will verify they are called in this thread.
+  // For normal use with the whole program in a single thread, this is
+  // unnecessary. It's helpful as a safety check for programs with multiple
+  // threads, where the EventLoop should only be interacted with from a single
+  // one.
+  void LockToThread() { check_tid_ = GetTid(); }
+
  private:
   friend class shm_event_loop_internal::ShmWatcherState;
   friend class shm_event_loop_internal::ShmTimerHandler;
@@ -126,6 +146,8 @@
     return result;
   }
 
+  void CheckCurrentThread() const;
+
   void HandleEvent();
 
   // Returns the TID of the event loop.
@@ -151,6 +173,9 @@
   std::string name_;
   const Node *const node_;
 
+  aos::stl_mutex *check_mutex_ = nullptr;
+  std::optional<pid_t> check_tid_;
+
   internal::EPoll epoll_;
 
   // Only set during Run().
diff --git a/aos/stl_mutex/stl_mutex.h b/aos/stl_mutex/stl_mutex.h
index 86a5988..e3a930e 100644
--- a/aos/stl_mutex/stl_mutex.h
+++ b/aos/stl_mutex/stl_mutex.h
@@ -61,6 +61,11 @@
   bool owner_died() const { return owner_died_; }
   void consistent() { owner_died_ = false; }
 
+  // Returns whether this mutex is locked by the current thread. This is very
+  // hard to use reliably; please think very carefully before using it for
+  // anything beyond probabilistic assertion checks.
+  bool is_locked() const { return mutex_islocked(&native_handle_); }
+
  private:
   aos_mutex native_handle_;
 
@@ -82,7 +87,7 @@
   constexpr stl_recursive_mutex() {}
 
   void lock() {
-    if (mutex_islocked(mutex_.native_handle())) {
+    if (mutex_.is_locked()) {
       CHECK(!owner_died());
       ++recursive_locks_;
     } else {
@@ -95,7 +100,7 @@
     }
   }
   bool try_lock() {
-    if (mutex_islocked(mutex_.native_handle())) {
+    if (mutex_.is_locked()) {
       CHECK(!owner_died());
       ++recursive_locks_;
       return true;