Merge changes I47aa260a,I7e858d3a

* changes:
  Add support for writable events on an FD to EPoll
  Expose the underlying shared memory buffers from ShmEventLoop
diff --git a/aos/events/BUILD b/aos/events/BUILD
index a4c24ce..9cfdf10 100644
--- a/aos/events/BUILD
+++ b/aos/events/BUILD
@@ -41,6 +41,16 @@
     ],
 )
 
+cc_test(
+    name = "epoll_test",
+    srcs = ["epoll_test.cc"],
+    deps = [
+        ":epoll",
+        "//aos/testing:googletest",
+        "@com_github_google_glog//:glog",
+    ],
+)
+
 cc_library(
     name = "event_loop",
     srcs = [
@@ -215,6 +225,7 @@
         "//aos/ipc_lib:signalfd",
         "//aos/stl_mutex",
         "//aos/util:phased_loop",
+        "@com_google_absl//absl/types:span",
     ],
 )
 
diff --git a/aos/events/epoll.cc b/aos/events/epoll.cc
index e19cf59..b9e7edb 100644
--- a/aos/events/epoll.cc
+++ b/aos/events/epoll.cc
@@ -71,6 +71,9 @@
 }
 
 void EPoll::Run() {
+  // Reset run_ (presumably cleared by an earlier Quit()) so that Run() can be
+  // called more than once.
+  run_ = true;
   while (true) {
     // Pull a single event out.  Infinite timeout if we are supposed to be
     // running, and 0 length timeout otherwise.  This lets us flush the event
@@ -90,34 +93,68 @@
         return;
       }
     }
     EventData *event_data = static_cast<struct EventData *>(event.data.ptr);
-    if (event.events & (EPOLLIN | EPOLLPRI)) {
+    // Saved up front because in_fn() may destroy event_data (see below).
+    const int event_fd = event_data->fd;
+    if (event.events & kInEvents) {
+      CHECK(event_data->in_fn)
+          << ": No handler registered for input events on " << event_fd;
       event_data->in_fn();
+      if (event.events & kOutEvents) {
+        // in_fn() may have called DeleteFd() on its own fd, which destroys
+        // event_data, so look the fd up again before handling output events.
+        event_data = GetEventData(event_fd);
+      }
     }
+    // Skip output handling if in_fn() removed the fd or disabled its output
+    // events.
+    if ((event.events & kOutEvents) && event_data != nullptr &&
+        (event_data->events & kOutEvents) != 0) {
+      CHECK(event_data->out_fn)
+          << ": No handler registered for output events on " << event_fd;
+      event_data->out_fn();
+    }
   }
 }
 
 void EPoll::Quit() { PCHECK(write(quit_signal_fd_, "q", 1) == 1); }
 
 void EPoll::OnReadable(int fd, ::std::function<void()> function) {
-  ::std::unique_ptr<EventData> event_data(
-      new EventData{fd, ::std::move(function)});
+  EventData *event_data = GetEventData(fd);
+  if (event_data == nullptr) {
+    fns_.emplace_back(std::make_unique<EventData>(fd));
+    event_data = fns_.back().get();
+  } else {
+    CHECK(!event_data->in_fn) << ": Duplicate in functions for " << fd;
+  }
+  event_data->in_fn = ::std::move(function);
+  DoEpollCtl(event_data, event_data->events | kInEvents);
+}
 
-  struct epoll_event event;
-  event.events = EPOLLIN | EPOLLPRI;
-  event.data.ptr = event_data.get();
-  PCHECK(epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &event) == 0)
-      << ": Failed to add fd " << fd;
-  fns_.push_back(::std::move(event_data));
+void EPoll::OnWriteable(int fd, ::std::function<void()> function) {
+  EventData *event_data = GetEventData(fd);
+  if (event_data == nullptr) {
+    fns_.emplace_back(std::make_unique<EventData>(fd));
+    event_data = fns_.back().get();
+  } else {
+    CHECK(!event_data->out_fn) << ": Duplicate out functions for " << fd;
+  }
+  event_data->out_fn = ::std::move(function);
+  DoEpollCtl(event_data, event_data->events | kOutEvents);
 }
 
 // Removes fd from the event loop.
 void EPoll::DeleteFd(int fd) {
   auto element = fns_.begin();
-  PCHECK(epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr) == 0);
   while (fns_.end() != element) {
     if (element->get()->fd == fd) {
-      fns_.erase(element);
+      // Only fds with at least one enabled event are registered with epoll.
+      // DoEpollCtl() already removed the others, so deleting them again would
+      // fail with ENOENT.
+      if (element->get()->events != 0) {
+        PCHECK(epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr) == 0);
+      }
+      fns_.erase(element);
       return;
     }
     ++element;
@@ -125,5 +162,60 @@
   LOG(FATAL) << "fd " << fd << " not found";
 }
 
+// Adds `events` to the set of epoll events enabled for fd, which must already
+// have been registered via OnReadable() or OnWriteable().
+void EPoll::EnableEvents(int fd, uint32_t events) {
+  EventData *const event_data = CHECK_NOTNULL(GetEventData(fd));
+  DoEpollCtl(event_data, event_data->events | events);
+}
+
+// Removes `events` from the set of epoll events enabled for fd. If no events
+// remain, fd is removed from epoll but its handlers stay registered.
+void EPoll::DisableEvents(int fd, uint32_t events) {
+  EventData *const event_data = CHECK_NOTNULL(GetEventData(fd));
+  DoEpollCtl(event_data, event_data->events & ~events);
+}
+
+// Returns the registration for fd, or nullptr if fd has none.
+EPoll::EventData *EPoll::GetEventData(int fd) {
+  const auto iterator = std::find_if(
+      fns_.begin(), fns_.end(),
+      [fd](const std::unique_ptr<EventData> &data) { return data->fd == fd; });
+  if (iterator == fns_.end()) {
+    return nullptr;
+  }
+  return iterator->get();
+}
+
+// Records new_events as the enabled set for event_data and syncs epoll: ADD
+// when going from no events to some, MOD between two non-empty sets, and DEL
+// when going to no events. events == 0 therefore means "not in epoll".
+void EPoll::DoEpollCtl(EventData *event_data, const uint32_t new_events) {
+  const uint32_t old_events = event_data->events;
+  event_data->events = new_events;
+  if (new_events == 0) {
+    if (old_events == 0) {
+      // Not added, and doesn't need to be. Nothing to do here.
+      return;
+    }
+    // It was added, but should now be removed.
+    PCHECK(epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, event_data->fd, nullptr) == 0)
+        << ": Failed to delete epoll fd: " << event_data->fd;
+    return;
+  }
+
+  int operation = EPOLL_CTL_MOD;
+  if (old_events == 0) {
+    // If it wasn't added before, then this is the first time it's being added.
+    operation = EPOLL_CTL_ADD;
+  }
+  struct epoll_event event = {};
+  event.events = event_data->events;
+  event.data.ptr = event_data;
+  PCHECK(epoll_ctl(epoll_fd_, operation, event_data->fd, &event) == 0)
+      << ": Failed to " << (operation == EPOLL_CTL_ADD ? "add" : "modify")
+      << " epoll fd: " << event_data->fd;
+}
+
 }  // namespace internal
 }  // namespace aos
diff --git a/aos/events/epoll.h b/aos/events/epoll.h
index 7bc135c..51a20d0 100644
--- a/aos/events/epoll.h
+++ b/aos/events/epoll.h
@@ -63,25 +63,57 @@
   void Quit();
 
   // Registers a function to be called if the fd becomes readable.
-  // There should only be 1 function registered for each fd.
+  // Only one function may be registered for readability on each fd.
   void OnReadable(int fd, ::std::function<void()> function);
 
+  // Registers a function to be called if the fd becomes writeable.
+  // Only one function may be registered for writability on each fd.
+  void OnWriteable(int fd, ::std::function<void()> function);
+
   // Removes fd from the event loop.
   // All Fds must be cleaned up before this class is destroyed.
   void DeleteFd(int fd);
 
+  // (Re-)enables calling the function registered with OnWriteable() when fd
+  // becomes writeable. fd must already be registered.
+  void EnableWriteable(int fd) { EnableEvents(fd, kOutEvents); }
+
+  // Stops calling the function registered with OnWriteable() for fd, without
+  // unregistering it. Disabling an already-disabled fd has no effect.
+  void DisableWriteable(int fd) { DisableEvents(fd, kOutEvents); }
+
  private:
+  // Structure whose pointer should be returned by epoll.  Makes looking up the
+  // function fast and easy.
+  struct EventData {
+    explicit EventData(int fd_in) : fd(fd_in) {}
+    // We use pointers to these objects as persistent identifiers, so they can't
+    // be moved.
+    EventData(const EventData &) = delete;
+    EventData &operator=(const EventData &) = delete;
+
+    const int fd;
+    uint32_t events = 0;  // Enabled events; 0 means fd is not in epoll.
+    ::std::function<void()> in_fn, out_fn;
+  };
+
+  void EnableEvents(int fd, uint32_t events);
+  void DisableEvents(int fd, uint32_t events);
+
+  EventData *GetEventData(int fd);  // Returns nullptr if fd is unregistered.
+
+  void DoEpollCtl(EventData *event_data, uint32_t new_events);
+
+  // TODO(Brian): Figure out a nicer way to handle EPOLLPRI than lumping it in
+  // with input.
+  static constexpr uint32_t kInEvents = EPOLLIN | EPOLLPRI;
+  static constexpr uint32_t kOutEvents = EPOLLOUT;
+
   ::std::atomic<bool> run_{true};
 
   // Main epoll fd.
   int epoll_fd_;
 
-  // Structure whose pointer should be returned by epoll.  Makes looking up the
-  // function fast and easy.
-  struct EventData {
-    int fd;
-    ::std::function<void()> in_fn;
-  };
   ::std::vector<::std::unique_ptr<EventData>> fns_;
 
   // Pipe pair for handling quit.
diff --git a/aos/events/epoll_test.cc b/aos/events/epoll_test.cc
new file mode 100644
index 0000000..4e6bbbc
--- /dev/null
+++ b/aos/events/epoll_test.cc
@@ -0,0 +1,180 @@
+#include "aos/events/epoll.h"
+
+#include <fcntl.h>
+#include <unistd.h>
+
+#include <chrono>
+#include <string>
+
+#include "gtest/gtest.h"
+#include "glog/logging.h"
+
+namespace aos {
+namespace internal {
+namespace testing {
+
+// A simple wrapper around both ends of a pipe along with some helpers to easily
+// read/write data through it.
+class Pipe {
+ public:
+  Pipe() { PCHECK(pipe2(fds_, O_NONBLOCK) == 0); }
+  // Pipe owns both fds, so copying it would close them twice.
+  Pipe(const Pipe &) = delete;
+  Pipe &operator=(const Pipe &) = delete;
+  ~Pipe() {
+    PCHECK(close(fds_[0]) == 0);
+    PCHECK(close(fds_[1]) == 0);
+  }
+
+  int read_fd() { return fds_[0]; }
+  int write_fd() { return fds_[1]; }
+
+  void Write(const std::string &data) {
+    CHECK_EQ(write(write_fd(), data.data(), data.size()),
+             static_cast<ssize_t>(data.size()));
+  }
+
+  std::string Read(size_t size) {
+    std::string result;
+    result.resize(size);
+    CHECK_EQ(read(read_fd(), result.data(), size), static_cast<ssize_t>(size));
+    return result;
+  }
+
+ private:
+  int fds_[2];
+};
+
+// Test fixture that owns an EPoll and can run it for a fixed duration.
+class EPollTest : public ::testing::Test {
+ public:
+  void RunFor(std::chrono::nanoseconds duration) {
+    TimerFd timerfd;
+    bool did_quit = false;
+    epoll_.OnReadable(timerfd.fd(), [this, &timerfd, &did_quit]() {
+      CHECK(!did_quit);
+      epoll_.Quit();
+      did_quit = true;
+      timerfd.Read();
+    });
+    timerfd.SetTime(monotonic_clock::now() + duration,
+                    monotonic_clock::duration::zero());
+    epoll_.Run();
+    CHECK(did_quit);
+    epoll_.DeleteFd(timerfd.fd());
+  }
+
+  // Tests should avoid relying on ordering for events closer in time than this,
+  // or waiting for longer than this to ensure events happen in order.
+  static constexpr std::chrono::nanoseconds tick_duration() {
+    return std::chrono::milliseconds(50);
+  }
+
+  EPoll epoll_;
+};
+
+// Test that the basics of OnReadable work.
+TEST_F(EPollTest, BasicReadable) {
+  Pipe pipe;
+  bool got_data = false;
+  epoll_.OnReadable(pipe.read_fd(), [&]() {
+    ASSERT_FALSE(got_data);
+    ASSERT_EQ("some", pipe.Read(4));
+    got_data = true;
+  });
+  RunFor(tick_duration());
+  EXPECT_FALSE(got_data);
+
+  pipe.Write("some");
+  RunFor(tick_duration());
+  EXPECT_TRUE(got_data);
+
+  epoll_.DeleteFd(pipe.read_fd());
+}
+
+// Test that the basics of OnWriteable work.
+TEST_F(EPollTest, BasicWriteable) {
+  Pipe pipe;
+  int number_writes = 0;
+  epoll_.OnWriteable(pipe.write_fd(), [&]() {
+    pipe.Write(" ");
+    ++number_writes;
+  });
+
+  // First, fill up the pipe's write buffer.
+  RunFor(tick_duration());
+  EXPECT_GT(number_writes, 0);
+
+  // Now, if we try again, we shouldn't do anything.
+  const int bytes_in_pipe = number_writes;
+  number_writes = 0;
+  RunFor(tick_duration());
+  EXPECT_EQ(number_writes, 0);
+
+  // Empty the pipe, then fill it up again.
+  for (int i = 0; i < bytes_in_pipe; ++i) {
+    ASSERT_EQ(" ", pipe.Read(1));
+  }
+  number_writes = 0;
+  RunFor(tick_duration());
+  EXPECT_EQ(number_writes, bytes_in_pipe);
+
+  epoll_.DeleteFd(pipe.write_fd());
+}
+
+// Tests that registering duplicate handlers or deleting unknown fds dies.
+TEST(EPollDeathTest, InvalidFd) {
+  EPoll epoll;
+  Pipe pipe;
+  epoll.OnReadable(pipe.read_fd(), []() {});
+  EXPECT_DEATH(epoll.OnReadable(pipe.read_fd(), []() {}),
+               "Duplicate in functions");
+  epoll.OnWriteable(pipe.read_fd(), []() {});
+  EXPECT_DEATH(epoll.OnWriteable(pipe.read_fd(), []() {}),
+               "Duplicate out functions");
+
+  epoll.DeleteFd(pipe.read_fd());
+  EXPECT_DEATH(epoll.DeleteFd(pipe.read_fd()), "fd [0-9]+ not found");
+  EXPECT_DEATH(epoll.DeleteFd(pipe.write_fd()), "fd [0-9]+ not found");
+}
+
+// Tests that enabling/disabling a writeable FD works.
+TEST_F(EPollTest, WriteableEnableDisable) {
+  Pipe pipe;
+  int number_writes = 0;
+  epoll_.OnWriteable(pipe.write_fd(), [&]() {
+    pipe.Write(" ");
+    ++number_writes;
+  });
+
+  // First, fill up the pipe's write buffer.
+  RunFor(tick_duration());
+  EXPECT_GT(number_writes, 0);
+
+  // Empty the pipe.
+  const int bytes_in_pipe = number_writes;
+  for (int i = 0; i < bytes_in_pipe; ++i) {
+    ASSERT_EQ(" ", pipe.Read(1));
+  }
+
+  // If we disable writeable checking, then nothing should happen.
+  epoll_.DisableWriteable(pipe.write_fd());
+  number_writes = 0;
+  RunFor(tick_duration());
+  EXPECT_EQ(number_writes, 0);
+
+  // Disabling it again should be a NOP.
+  epoll_.DisableWriteable(pipe.write_fd());
+
+  // And then when we re-enable, it should fill the pipe up again.
+  epoll_.EnableWriteable(pipe.write_fd());
+  number_writes = 0;
+  RunFor(tick_duration());
+  EXPECT_EQ(number_writes, bytes_in_pipe);
+
+  epoll_.DeleteFd(pipe.write_fd());
+}
+
+}  // namespace testing
+}  // namespace internal
+}  // namespace aos
diff --git a/aos/events/event_loop.cc b/aos/events/event_loop.cc
index c6d3755..76b83c9 100644
--- a/aos/events/event_loop.cc
+++ b/aos/events/event_loop.cc
@@ -74,6 +74,18 @@
   return configuration::ChannelIndex(configuration_, channel);
 }
 
+// Linear search over watchers_. ChannelIndex() validates that channel is part
+// of configuration_ before the search starts.
+WatcherState *EventLoop::GetWatcherState(const Channel *channel) {
+  const int channel_index = ChannelIndex(channel);
+  for (const std::unique_ptr<WatcherState> &watcher : watchers_) {
+    if (watcher->channel_index() == channel_index) {
+      return watcher.get();
+    }
+  }
+  LOG(FATAL) << "No watcher found for channel index " << channel_index;
+}
+
 void EventLoop::NewSender(RawSender *sender) {
   senders_.emplace_back(sender);
   UpdateTimingReport();
diff --git a/aos/events/event_loop.h b/aos/events/event_loop.h
index 9618a32..3fe46ec 100644
--- a/aos/events/event_loop.h
+++ b/aos/events/event_loop.h
@@ -512,6 +512,17 @@
   // Validates that channel exists inside configuration_ and finds its index.
   int ChannelIndex(const Channel *channel);
 
+  // Returns the state for the watcher on the corresponding channel. A watcher
+  // must already exist on that channel; otherwise this dies with LOG(FATAL).
+  WatcherState *GetWatcherState(const Channel *channel);
+
+  // Returns the RawSender wrapped by sender, so derived event loops can reach
+  // their own sender implementation.
+  template <typename T>
+  static RawSender *GetRawSender(aos::Sender<T> *sender) {
+    return sender->sender_.get();
+  }
+
   // Context available for watchers, timers, and phased loops.
   Context context_;
 
diff --git a/aos/events/shm_event_loop.cc b/aos/events/shm_event_loop.cc
index c4656d9..b4578cb 100644
--- a/aos/events/shm_event_loop.cc
+++ b/aos/events/shm_event_loop.cc
@@ -124,6 +124,12 @@
 
   const ipc_lib::LocklessQueueConfiguration &config() const { return config_; }
 
+  // Returns the mapped queue memory (data_, size_) as a mutable span.
+  absl::Span<char> GetSharedMemory() const {
+    char *const base = static_cast<char *>(data_);
+    return absl::Span<char>(base, size_);
+  }
+
  private:
   ipc_lib::LocklessQueueConfiguration config_;
 
@@ -298,6 +304,11 @@
 
   void UnregisterWakeup() { lockless_queue_.UnregisterWakeup(); }
 
+  // Exposes the shared memory mapping held in lockless_queue_memory_.
+  absl::Span<char> GetSharedMemory() const {
+    return lockless_queue_memory_.GetSharedMemory();
+  }
+
  private:
   char *data_storage_start() {
     return RoundChannelData(data_storage_.get(), channel_->max_size());
@@ -383,6 +394,11 @@
     return true;
   }
 
+  // Exposes the shared memory mapping held in lockless_queue_memory_.
+  absl::Span<char> GetSharedMemory() const {
+    return lockless_queue_memory_.GetSharedMemory();
+  }
+
  private:
   MMapedQueue lockless_queue_memory_;
   ipc_lib::LocklessQueue lockless_queue_;
@@ -437,6 +453,11 @@
 
   void UnregisterWakeup() { return simple_shm_fetcher_.UnregisterWakeup(); }
 
+  // Exposes the shared memory of the fetcher this object wraps.
+  absl::Span<char> GetSharedMemory() const {
+    return simple_shm_fetcher_.GetSharedMemory();
+  }
+
  private:
   bool has_new_data_ = false;
 
@@ -689,7 +710,8 @@
     ScopedSignalMask mask({SIGINT, SIGHUP, SIGTERM});
     std::unique_lock<stl_mutex> locker(mutex_);
 
-    event_loops_.erase(std::find(event_loops_.begin(), event_loops_.end(), event_loop));
+    event_loops_.erase(
+        std::find(event_loops_.begin(), event_loops_.end(), event_loop));
 
     if (event_loops_.size() == 0u) {
       // The last caller restores the original signal handlers.
@@ -817,6 +839,21 @@
   UpdateTimingReport();
 }
 
+absl::Span<char> ShmEventLoop::GetWatcherSharedMemory(const Channel *channel) {
+  // The downcast assumes every watcher on this loop is an
+  // internal::WatcherState.
+  return static_cast<internal::WatcherState *>(GetWatcherState(channel))
+      ->GetSharedMemory();
+}
+
+absl::Span<char> ShmEventLoop::GetShmSenderSharedMemory(
+    const aos::RawSender *sender) const {
+  // The downcast assumes sender is an internal::ShmSender.
+  const internal::ShmSender *const shm_sender =
+      static_cast<const internal::ShmSender *>(sender);
+  return shm_sender->GetSharedMemory();
+}
+
 pid_t ShmEventLoop::GetTid() { return syscall(SYS_gettid); }
 
 }  // namespace aos
diff --git a/aos/events/shm_event_loop.h b/aos/events/shm_event_loop.h
index 52bb338..fa870b8 100644
--- a/aos/events/shm_event_loop.h
+++ b/aos/events/shm_event_loop.h
@@ -3,6 +3,8 @@
 
 #include <vector>
 
+#include "absl/types/span.h"
+
 #include "aos/events/epoll.h"
 #include "aos/events/event_loop.h"
 #include "aos/events/event_loop_generated.h"
@@ -68,8 +70,22 @@
 
   int priority() const override { return priority_; }
 
+  // Returns the epoll loop used to run the event loop.
   internal::EPoll *epoll() { return &epoll_; }
 
+  // Returns the local mapping of the shared memory used by the watcher on the
+  // specified channel. A watcher must be created on this channel before calling
+  // this.
+  absl::Span<char> GetWatcherSharedMemory(const Channel *channel);
+
+  // Returns the local mapping of the shared memory used by sender. sender
+  // must have been made by a ShmEventLoop, because its RawSender is downcast
+  // to a ShmSender without a runtime check.
+  template <typename T>
+  absl::Span<char> GetSenderSharedMemory(aos::Sender<T> *sender) const {
+    return GetShmSenderSharedMemory(GetRawSender(sender));
+  }
+
  private:
   friend class internal::WatcherState;
   friend class internal::TimerHandlerState;
@@ -82,6 +98,9 @@
   // Returns the TID of the event loop.
   pid_t GetTid() override;
 
+  // Returns the shared memory mapping of sender, which must be a ShmSender.
+  absl::Span<char> GetShmSenderSharedMemory(const aos::RawSender *sender) const;
+
   std::vector<std::function<void()>> on_run_;
   int priority_ = 0;
   std::string name_;
@@ -90,7 +109,6 @@
   internal::EPoll epoll_;
 };
 
-
 }  // namespace aos
 
 #endif  // AOS_EVENTS_SHM_EVENT_LOOP_H_
diff --git a/aos/events/shm_event_loop_test.cc b/aos/events/shm_event_loop_test.cc
index 2edc3fc..984f978 100644
--- a/aos/events/shm_event_loop_test.cc
+++ b/aos/events/shm_event_loop_test.cc
@@ -183,6 +183,34 @@
   EXPECT_EQ(times.size(), 2u);
 }
 
+// Test GetWatcherSharedMemory in a few basic scenarios.
+TEST(ShmEventLoopDeathTest, GetWatcherSharedMemory) {
+  ShmEventLoopTestFactory factory;
+  auto generic_loop1 = factory.MakePrimary("primary");
+  ShmEventLoop *const loop1 = static_cast<ShmEventLoop *>(generic_loop1.get());
+  const auto channel = configuration::GetChannel(
+      loop1->configuration(), "/test", TestMessage::GetFullyQualifiedName(),
+      loop1->name(), loop1->node());
+
+  // First verify it dies when the channel has no watcher yet.
+  EXPECT_DEATH(loop1->GetWatcherSharedMemory(channel),
+               "No watcher found for channel");
+
+  // Then, actually create a watcher, and verify it returns something sane.
+  loop1->MakeWatcher("/test", [](const TestMessage &) {});
+  EXPECT_FALSE(loop1->GetWatcherSharedMemory(channel).empty());
+}
+
+TEST(ShmEventLoopTest, GetSenderSharedMemory) {
+  ShmEventLoopTestFactory factory;
+  auto generic_loop1 = factory.MakePrimary("primary");
+  ShmEventLoop *const loop1 = static_cast<ShmEventLoop *>(generic_loop1.get());
+
+  // Check that GetSenderSharedMemory returns a non-empty memory span.
+  auto sender = loop1->MakeSender<TestMessage>("/test");
+  EXPECT_FALSE(loop1->GetSenderSharedMemory(&sender).empty());
+}
+
 // TODO(austin): Test that missing a deadline with a timer recovers as expected.
 
 }  // namespace testing