Make attaching poll-based external event loops to EPoll easier

Way easier than tracking handlers for each type: just expose the
underlying epoll API of "enable these events" directly.

Change-Id: I711fe1b3da7690ef424bf50cdd087ca87c02503b
Signed-off-by: Austin Schuh <austin.schuh@bluerivertech.com>
diff --git a/aos/events/epoll.cc b/aos/events/epoll.cc
index 7fff6ef..8cb553b 100644
--- a/aos/events/epoll.cc
+++ b/aos/events/epoll.cc
@@ -4,6 +4,7 @@
 #include <sys/epoll.h>
 #include <sys/timerfd.h>
 #include <unistd.h>
+
 #include <atomic>
 #include <vector>
 
@@ -109,21 +110,7 @@
   }
 
   EventData *const event_data = static_cast<struct EventData *>(event.data.ptr);
-  if (event.events & kInEvents) {
-    CHECK(event_data->in_fn)
-        << ": No handler registered for input events on " << event_data->fd;
-    event_data->in_fn();
-  }
-  if (event.events & kOutEvents) {
-    CHECK(event_data->out_fn)
-        << ": No handler registered for output events on " << event_data->fd;
-    event_data->out_fn();
-  }
-  if (event.events & kErrorEvents) {
-    CHECK(event_data->err_fn)
-        << ": No handler registered for error events on " << event_data->fd;
-    event_data->err_fn();
-  }
+  event_data->DoCallbacks(event.events);
   return true;
 }
 
@@ -139,39 +126,50 @@
 void EPoll::OnReadable(int fd, ::std::function<void()> function) {
   EventData *event_data = GetEventData(fd);
   if (event_data == nullptr) {
-    fns_.emplace_back(std::make_unique<EventData>(fd));
+    fns_.emplace_back(std::make_unique<InOutEventData>(fd));
     event_data = fns_.back().get();
   } else {
-    CHECK(!event_data->in_fn) << ": Duplicate in functions for " << fd;
+    CHECK(!static_cast<InOutEventData *>(event_data)->in_fn)
+        << ": Duplicate in functions for " << fd;
   }
-  event_data->in_fn = ::std::move(function);
+  static_cast<InOutEventData *>(event_data)->in_fn = ::std::move(function);
   DoEpollCtl(event_data, event_data->events | kInEvents);
 }
 
 void EPoll::OnError(int fd, ::std::function<void()> function) {
   EventData *event_data = GetEventData(fd);
   if (event_data == nullptr) {
-    fns_.emplace_back(std::make_unique<EventData>(fd));
+    fns_.emplace_back(std::make_unique<InOutEventData>(fd));
     event_data = fns_.back().get();
   } else {
-    CHECK(!event_data->err_fn) << ": Duplicate in functions for " << fd;
+    CHECK(!static_cast<InOutEventData *>(event_data)->err_fn)
+        << ": Duplicate error functions for " << fd;
   }
-  event_data->err_fn = ::std::move(function);
+  static_cast<InOutEventData *>(event_data)->err_fn = ::std::move(function);
   DoEpollCtl(event_data, event_data->events | kErrorEvents);
 }
 
 void EPoll::OnWriteable(int fd, ::std::function<void()> function) {
   EventData *event_data = GetEventData(fd);
   if (event_data == nullptr) {
-    fns_.emplace_back(std::make_unique<EventData>(fd));
+    fns_.emplace_back(std::make_unique<InOutEventData>(fd));
     event_data = fns_.back().get();
   } else {
-    CHECK(!event_data->out_fn) << ": Duplicate out functions for " << fd;
+    CHECK(!static_cast<InOutEventData *>(event_data)->out_fn)
+        << ": Duplicate out functions for " << fd;
   }
-  event_data->out_fn = ::std::move(function);
+  static_cast<InOutEventData *>(event_data)->out_fn = ::std::move(function);
   DoEpollCtl(event_data, event_data->events | kOutEvents);
 }
 
+void EPoll::OnEvents(int fd, ::std::function<void(uint32_t)> function) {
+  if (GetEventData(fd) != nullptr) {
+    LOG(FATAL) << "May not replace OnEvents handlers";
+  }
+  fns_.emplace_back(std::make_unique<SingleEventData>(fd));
+  static_cast<SingleEventData *>(fns_.back().get())->fn = std::move(function);
+}
+
 void EPoll::ForgetClosedFd(int fd) {
   auto element = fns_.begin();
   while (fns_.end() != element) {
@@ -184,6 +182,10 @@
   LOG(FATAL) << "fd " << fd << " not found";
 }
 
+void EPoll::SetEvents(int fd, uint32_t events) {
+  DoEpollCtl(CHECK_NOTNULL(GetEventData(fd)), events);
+}
+
 // Removes fd from the event loop.
 void EPoll::DeleteFd(int fd) {
   auto element = fns_.begin();
@@ -199,6 +201,21 @@
   LOG(FATAL) << "fd " << fd << " not found";
 }
 
+void EPoll::InOutEventData::DoCallbacks(uint32_t events) {
+  if (events & kInEvents) {
+    CHECK(in_fn) << ": No handler registered for input events on " << fd;
+    in_fn();
+  }
+  if (events & kOutEvents) {
+    CHECK(out_fn) << ": No handler registered for output events on " << fd;
+    out_fn();
+  }
+  if (events & kErrorEvents) {
+    CHECK(err_fn) << ": No handler registered for error events on " << fd;
+    err_fn();
+  }
+}
+
 void EPoll::EnableEvents(int fd, uint32_t events) {
   EventData *const event_data = CHECK_NOTNULL(GetEventData(fd));
   DoEpollCtl(event_data, event_data->events | events);
@@ -221,12 +238,14 @@
 
 void EPoll::DoEpollCtl(EventData *event_data, const uint32_t new_events) {
   const uint32_t old_events = event_data->events;
+  if (old_events == new_events) {
+    // Shortcut without calling into the kernel. This happens often with
+    // external event loop integrations that are emulating poll, so make it
+    // fast.
+    return;
+  }
   event_data->events = new_events;
   if (new_events == 0) {
-    if (old_events == 0) {
-      // Not added, and doesn't need to be. Nothing to do here.
-      return;
-    }
     // It was added, but should now be removed.
     PCHECK(epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, event_data->fd, nullptr) == 0);
     return;
diff --git a/aos/events/epoll.h b/aos/events/epoll.h
index 4bcedf1..2b7eb76 100644
--- a/aos/events/epoll.h
+++ b/aos/events/epoll.h
@@ -5,6 +5,7 @@
 #include <sys/epoll.h>
 #include <sys/timerfd.h>
 #include <unistd.h>
+
 #include <atomic>
 #include <functional>
 #include <vector>
@@ -68,21 +69,34 @@
   // Quits.  Async safe.
   void Quit();
 
-  // Called before waiting on the epoll file descriptor.
+  // Adds a function which will be called before waiting on the epoll file
+  // descriptor.
   void BeforeWait(std::function<void()> function);
 
   // Registers a function to be called if the fd becomes readable.
   // Only one function may be registered for readability on each fd.
+  // A fd may be registered exclusively with OnReadable/OnWriteable/OnError OR
+  // OnEvents.
   void OnReadable(int fd, ::std::function<void()> function);
 
   // Registers a function to be called if the fd reports an error.
   // Only one function may be registered for errors on each fd.
+  // A fd may be registered exclusively with OnReadable/OnWriteable/OnError OR
+  // OnEvents.
   void OnError(int fd, ::std::function<void()> function);
 
   // Registers a function to be called if the fd becomes writeable.
   // Only one function may be registered for writability on each fd.
+  // A fd may be registered exclusively with OnReadable/OnWriteable/OnError OR
+  // OnEvents.
   void OnWriteable(int fd, ::std::function<void()> function);
 
+  // Registers a function to be called when the configured events occur on fd.
+  // Which events occur will be passed to the function.
+  // A fd may be registered exclusively with OnReadable/OnWriteable/OnError OR
+  // OnEvents.
+  void OnEvents(int fd, ::std::function<void(uint32_t)> function);
+
   // Removes fd from the event loop.
   // All Fds must be cleaned up before this class is destroyed.
   void DeleteFd(int fd);
@@ -100,19 +114,50 @@
   // writeable.
   void DisableWriteable(int fd) { DisableEvents(fd, kOutEvents); }
 
+  // Sets the epoll events for the given fd. Be careful using this with
+  // OnReadable/OnWriteable/OnError: enabled events which fire with no handler
+  // registered will result in a crash.
+  void SetEvents(int fd, uint32_t events);
+
+  // Returns whether we're currently running. This changes to false when we
+  // start draining events to finish.
+  bool should_run() const { return run_; }
+
  private:
   // Structure whose pointer should be returned by epoll.  Makes looking up the
   // function fast and easy.
   struct EventData {
     EventData(int fd_in) : fd(fd_in) {}
+    virtual ~EventData() = default;
+
     // We use pointers to these objects as persistent identifiers, so they can't
     // be moved.
     EventData(const EventData &) = delete;
     EventData &operator=(const EventData &) = delete;
 
+    // Calls the appropriate callbacks when events are returned from the kernel.
+    virtual void DoCallbacks(uint32_t events) = 0;
+
     const int fd;
     uint32_t events = 0;
+  };
+
+  struct InOutEventData : public EventData {
+    InOutEventData(int fd) : EventData(fd) {}
+    ~InOutEventData() override = default;
+
     std::function<void()> in_fn, out_fn, err_fn;
+
+    void DoCallbacks(uint32_t events) override;
+  };
+
+  struct SingleEventData : public EventData {
+    SingleEventData(int fd) : EventData(fd) {}
+    ~SingleEventData() override = default;
+
+    std::function<void(uint32_t)> fn;
+
+    void DoCallbacks(uint32_t events) override { fn(events); }
   };
 
   void EnableEvents(int fd, uint32_t events);