Expose the underlying shared memory buffers from ShmEventLoop

For advanced use cases, this is handy to set up these memory regions
specially.

Change-Id: I7e858d3af5f3fa51f980e0e4cab5dfcfa7b83fd9
diff --git a/aos/events/event_loop.cc b/aos/events/event_loop.cc
index e368df2..ec23afc 100644
--- a/aos/events/event_loop.cc
+++ b/aos/events/event_loop.cc
@@ -81,6 +81,16 @@
   return std::distance(configuration()->channels()->begin(), c);
 }
 
+WatcherState *EventLoop::GetWatcherState(const Channel *channel) {
+  const int channel_index = ChannelIndex(channel);
+  for (const std::unique_ptr<WatcherState> &watcher : watchers_) {
+    if (watcher->channel_index() == channel_index) {
+      return watcher.get();
+    }
+  }
+  LOG(FATAL) << "No watcher found for channel";
+}
+
 void EventLoop::NewSender(RawSender *sender) {
   senders_.emplace_back(sender);
   UpdateTimingReport();
diff --git a/aos/events/event_loop.h b/aos/events/event_loop.h
index 4a12096..ed365c3 100644
--- a/aos/events/event_loop.h
+++ b/aos/events/event_loop.h
@@ -1,4 +1,5 @@
 #ifndef AOS_EVENTS_EVENT_LOOP_H_
+
 #define AOS_EVENTS_EVENT_LOOP_H_
 
 #include <atomic>
@@ -482,6 +483,16 @@
   // Validates that channel exists inside configuration_ and finds its index.
   int ChannelIndex(const Channel *channel);
 
+  // Returns the state for the watcher on the corresponding channel. This
+  // watcher must exist before calling this.
+  WatcherState *GetWatcherState(const Channel *channel);
+
+  // Returns a Sender's protected RawSender
+  template <typename T>
+  static RawSender *GetRawSender(aos::Sender<T> *sender) {
+    return sender->sender_.get();
+  }
+
   // Context available for watchers, timers, and phased loops.
   Context context_;
 
diff --git a/aos/events/shm_event_loop.cc b/aos/events/shm_event_loop.cc
index cc11520..c125792 100644
--- a/aos/events/shm_event_loop.cc
+++ b/aos/events/shm_event_loop.cc
@@ -123,6 +123,10 @@
 
   const ipc_lib::LocklessQueueConfiguration &config() const { return config_; }
 
+  absl::Span<char> GetSharedMemory() const {
+    return absl::Span<char>(static_cast<char *>(data_), size_);
+  }
+
  private:
   void MkdirP(std::string_view path) {
     auto last_slash_pos = path.find_last_of("/");
@@ -314,6 +318,10 @@
 
   void UnregisterWakeup() { lockless_queue_.UnregisterWakeup(); }
 
+  absl::Span<char> GetSharedMemory() const {
+    return lockless_queue_memory_.GetSharedMemory();
+  }
+
  private:
   const Channel *const channel_;
   MMapedQueue lockless_queue_memory_;
@@ -402,6 +410,10 @@
     return true;
   }
 
+  absl::Span<char> GetSharedMemory() const {
+    return lockless_queue_memory_.GetSharedMemory();
+  }
+
  private:
   MMapedQueue lockless_queue_memory_;
   ipc_lib::LocklessQueue lockless_queue_;
@@ -456,6 +468,10 @@
 
   void UnregisterWakeup() { return simple_shm_fetcher_.UnregisterWakeup(); }
 
+  absl::Span<char> GetSharedMemory() const {
+    return simple_shm_fetcher_.GetSharedMemory();
+  }
+
  private:
   bool has_new_data_ = false;
 
@@ -708,7 +724,8 @@
     ScopedSignalMask mask({SIGINT, SIGHUP, SIGTERM});
     std::unique_lock<stl_mutex> locker(mutex_);
 
-    event_loops_.erase(std::find(event_loops_.begin(), event_loops_.end(), event_loop));
+    event_loops_.erase(
+        std::find(event_loops_.begin(), event_loops_.end(), event_loop));
 
     if (event_loops_.size() == 0u) {
       // The last caller restores the original signal handlers.
@@ -836,6 +853,17 @@
   UpdateTimingReport();
 }
 
+absl::Span<char> ShmEventLoop::GetWatcherSharedMemory(const Channel *channel) {
+  internal::WatcherState *const watcher_state =
+      static_cast<internal::WatcherState *>(GetWatcherState(channel));
+  return watcher_state->GetSharedMemory();
+}
+
+absl::Span<char> ShmEventLoop::GetShmSenderSharedMemory(
+    const aos::RawSender *sender) const {
+  return static_cast<const internal::ShmSender *>(sender)->GetSharedMemory();
+}
+
 pid_t ShmEventLoop::GetTid() { return syscall(SYS_gettid); }
 
 }  // namespace aos
diff --git a/aos/events/shm_event_loop.h b/aos/events/shm_event_loop.h
index 52bb338..fa870b8 100644
--- a/aos/events/shm_event_loop.h
+++ b/aos/events/shm_event_loop.h
@@ -3,6 +3,8 @@
 
 #include <vector>
 
+#include "absl/types/span.h"
+
 #include "aos/events/epoll.h"
 #include "aos/events/event_loop.h"
 #include "aos/events/event_loop_generated.h"
@@ -68,8 +70,20 @@
 
   int priority() const override { return priority_; }
 
+  // Returns the epoll loop used to run the event loop.
   internal::EPoll *epoll() { return &epoll_; }
 
+  // Returns the local mapping of the shared memory used by the watcher on the
+  // specified channel. A watcher must be created on this channel before calling
+  // this.
+  absl::Span<char> GetWatcherSharedMemory(const Channel *channel);
+
+  // Returns the local mapping of the shared memory used by the provided Sender
+  template <typename T>
+  absl::Span<char> GetSenderSharedMemory(aos::Sender<T> *sender) const {
+    return GetShmSenderSharedMemory(GetRawSender(sender));
+  }
+
  private:
   friend class internal::WatcherState;
   friend class internal::TimerHandlerState;
@@ -82,6 +96,9 @@
   // Returns the TID of the event loop.
   pid_t GetTid() override;
 
+  // Private method to access the shared memory mapping of a ShmSender
+  absl::Span<char> GetShmSenderSharedMemory(const aos::RawSender *sender) const;
+
   std::vector<std::function<void()>> on_run_;
   int priority_ = 0;
   std::string name_;
@@ -90,7 +107,6 @@
   internal::EPoll epoll_;
 };
 
-
 }  // namespace aos
 
 #endif  // AOS_EVENTS_SHM_EVENT_LOOP_H_
diff --git a/aos/events/shm_event_loop_test.cc b/aos/events/shm_event_loop_test.cc
index d33f496..949d7db 100644
--- a/aos/events/shm_event_loop_test.cc
+++ b/aos/events/shm_event_loop_test.cc
@@ -181,6 +181,34 @@
   EXPECT_EQ(times.size(), 2u);
 }
 
+// Test GetWatcherSharedMemory in a few basic scenarios.
+TEST(ShmEventLoopDeathTest, GetWatcherSharedMemory) {
+  ShmEventLoopTestFactory factory;
+  auto generic_loop1 = factory.MakePrimary("primary");
+  ShmEventLoop *const loop1 = static_cast<ShmEventLoop *>(generic_loop1.get());
+  const auto channel = configuration::GetChannel(
+      loop1->configuration(), "/test", TestMessage::GetFullyQualifiedName(),
+      loop1->name(), loop1->node());
+
+  // First verify it handles an invalid channel reasonably.
+  EXPECT_DEATH(loop1->GetWatcherSharedMemory(channel),
+               "No watcher found for channel");
+
+  // Then, actually create a watcher, and verify it returns something sane.
+  loop1->MakeWatcher("/test", [](const TestMessage &) {});
+  EXPECT_FALSE(loop1->GetWatcherSharedMemory(channel).empty());
+}
+
+TEST(ShmEventLoopTest, GetSenderSharedMemory) {
+  ShmEventLoopTestFactory factory;
+  auto generic_loop1 = factory.MakePrimary("primary");
+  ShmEventLoop *const loop1 = static_cast<ShmEventLoop *>(generic_loop1.get());
+
+  // check that GetSenderSharedMemory returns non-null/non-empty memory span
+  auto sender = loop1->MakeSender<TestMessage>("/test");
+  EXPECT_FALSE(loop1->GetSenderSharedMemory(&sender).empty());
+}
+
 // TODO(austin): Test that missing a deadline with a timer recovers as expected.
 
 }  // namespace testing