Add support for writable events on an FD to EPoll

This allows writing a TCP server with it, for example.

Change-Id: I47aa260aec8a88791f783d0236fc926bbe271192
diff --git a/aos/events/BUILD b/aos/events/BUILD
index a183cae..6b4af79 100644
--- a/aos/events/BUILD
+++ b/aos/events/BUILD
@@ -41,6 +41,16 @@
     ],
 )
 
+cc_test(
+    name = "epoll_test",
+    srcs = ["epoll_test.cc"],
+    deps = [
+        ":epoll",
+        "//aos/testing:googletest",
+        "@com_github_google_glog//:glog",
+    ],
+)
+
 cc_library(
     name = "event_loop",
     srcs = [
@@ -214,6 +224,7 @@
         "//aos/ipc_lib:signalfd",
         "//aos/stl_mutex",
         "//aos/util:phased_loop",
+        "@com_google_absl//absl/base",
     ],
 )
 
diff --git a/aos/events/epoll.cc b/aos/events/epoll.cc
index e19cf59..b9e7edb 100644
--- a/aos/events/epoll.cc
+++ b/aos/events/epoll.cc
@@ -71,6 +71,7 @@
 }
 
 void EPoll::Run() {
+  run_ = true;
   while (true) {
     // Pull a single event out.  Infinite timeout if we are supposed to be
     // running, and 0 length timeout otherwise.  This lets us flush the event
@@ -90,34 +91,53 @@
         return;
       }
     }
-    EventData *event_data = static_cast<struct EventData *>(event.data.ptr);
-    if (event.events & (EPOLLIN | EPOLLPRI)) {
+    EventData *const event_data = static_cast<struct EventData *>(event.data.ptr);
+    if (event.events & kInEvents) {
+      CHECK(event_data->in_fn)
+          << ": No handler registered for input events on " << event_data->fd;
       event_data->in_fn();
     }
+    if (event.events & kOutEvents) {
+      CHECK(event_data->out_fn)
+          << ": No handler registered for output events on " << event_data->fd;
+      event_data->out_fn();
+    }
   }
 }
 
 void EPoll::Quit() { PCHECK(write(quit_signal_fd_, "q", 1) == 1); }
 
 void EPoll::OnReadable(int fd, ::std::function<void()> function) {
-  ::std::unique_ptr<EventData> event_data(
-      new EventData{fd, ::std::move(function)});
+  EventData *event_data = GetEventData(fd);
+  if (event_data == nullptr) {
+    fns_.emplace_back(std::make_unique<EventData>(fd));
+    event_data = fns_.back().get();
+  } else {
+    CHECK(!event_data->in_fn) << ": Duplicate in functions for " << fd;
+  }
+  event_data->in_fn = ::std::move(function);
+  DoEpollCtl(event_data, event_data->events | kInEvents);
+}
 
-  struct epoll_event event;
-  event.events = EPOLLIN | EPOLLPRI;
-  event.data.ptr = event_data.get();
-  PCHECK(epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &event) == 0)
-      << ": Failed to add fd " << fd;
-  fns_.push_back(::std::move(event_data));
+void EPoll::OnWriteable(int fd, ::std::function<void()> function) {
+  EventData *event_data = GetEventData(fd);
+  if (event_data == nullptr) {
+    fns_.emplace_back(std::make_unique<EventData>(fd));
+    event_data = fns_.back().get();
+  } else {
+    CHECK(!event_data->out_fn) << ": Duplicate out functions for " << fd;
+  }
+  event_data->out_fn = ::std::move(function);
+  DoEpollCtl(event_data, event_data->events | kOutEvents);
 }
 
 // Removes fd from the event loop.
 void EPoll::DeleteFd(int fd) {
   auto element = fns_.begin();
-  PCHECK(epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr) == 0);
   while (fns_.end() != element) {
     if (element->get()->fd == fd) {
       fns_.erase(element);
+      PCHECK(epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr) == 0);
       return;
     }
     ++element;
@@ -125,5 +145,50 @@
   LOG(FATAL) << "fd " << fd << " not found";
 }
 
+void EPoll::EnableEvents(int fd, uint32_t events) {
+  EventData *const event_data = CHECK_NOTNULL(GetEventData(fd));
+  DoEpollCtl(event_data, event_data->events | events);
+}
+
+void EPoll::DisableEvents(int fd, uint32_t events) {
+  EventData *const event_data = CHECK_NOTNULL(GetEventData(fd));
+  DoEpollCtl(event_data, event_data->events & ~events);
+}
+
+EPoll::EventData *EPoll::GetEventData(int fd) {
+  const auto iterator = std::find_if(
+      fns_.begin(), fns_.end(),
+      [fd](const std::unique_ptr<EventData> &data) { return data->fd == fd; });
+  if (iterator == fns_.end()) {
+    return nullptr;
+  }
+  return iterator->get();
+}
+
+void EPoll::DoEpollCtl(EventData *event_data, const uint32_t new_events) {
+  const uint32_t old_events = event_data->events;
+  event_data->events = new_events;
+  if (new_events == 0) {
+    if (old_events == 0) {
+      // Not added, and doesn't need to be. Nothing to do here.
+      return;
+    }
+    // It was added, but should now be removed.
+    PCHECK(epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, event_data->fd, nullptr) == 0);
+    return;
+  }
+
+  int operation = EPOLL_CTL_MOD;
+  if (old_events == 0) {
+    // If it wasn't added before, then this is the first time it's being added.
+    operation = EPOLL_CTL_ADD;
+  }
+  struct epoll_event event;
+  event.events = event_data->events;
+  event.data.ptr = event_data;
+  PCHECK(epoll_ctl(epoll_fd_, operation, event_data->fd, &event) == 0)
+      << ": Failed to " << operation << " epoll fd: " << event_data->fd;
+}
+
 }  // namespace internal
 }  // namespace aos
diff --git a/aos/events/epoll.h b/aos/events/epoll.h
index 7bc135c..51a20d0 100644
--- a/aos/events/epoll.h
+++ b/aos/events/epoll.h
@@ -63,25 +63,57 @@
   void Quit();
 
   // Registers a function to be called if the fd becomes readable.
-  // There should only be 1 function registered for each fd.
+  // Only one function may be registered for readability on each fd.
   void OnReadable(int fd, ::std::function<void()> function);
 
+  // Registers a function to be called if the fd becomes writeable.
+  // Only one function may be registered for writability on each fd.
+  void OnWriteable(int fd, ::std::function<void()> function);
+
   // Removes fd from the event loop.
   // All Fds must be cleaned up before this class is destroyed.
   void DeleteFd(int fd);
 
+  // Enables calling the existing function registered for fd when it becomes
+  // writeable.
+  void EnableWriteable(int fd) { EnableEvents(fd, kOutEvents); }
+
+  // Disables calling the existing function registered for fd when it becomes
+  // writeable.
+  void DisableWriteable(int fd) { DisableEvents(fd, kOutEvents); }
+
  private:
+  // Structure whose pointer should be returned by epoll.  Makes looking up the
+  // function fast and easy.
+  struct EventData {
+    EventData(int fd_in) : fd(fd_in) {}
+    // We use pointers to these objects as persistent identifiers, so they can't
+    // be moved.
+    EventData(const EventData &) = delete;
+    EventData &operator=(const EventData &) = delete;
+
+    const int fd;
+    uint32_t events = 0;
+    ::std::function<void()> in_fn, out_fn;
+  };
+
+  void EnableEvents(int fd, uint32_t events);
+  void DisableEvents(int fd, uint32_t events);
+
+  EventData *GetEventData(int fd);
+
+  void DoEpollCtl(EventData *event_data, uint32_t new_events);
+
+  // TODO(Brian): Figure out a nicer way to handle EPOLLPRI than lumping it in
+  // with input.
+  static constexpr uint32_t kInEvents = EPOLLIN | EPOLLPRI;
+  static constexpr uint32_t kOutEvents = EPOLLOUT;
+
   ::std::atomic<bool> run_{true};
 
   // Main epoll fd.
   int epoll_fd_;
 
-  // Structure whose pointer should be returned by epoll.  Makes looking up the
-  // function fast and easy.
-  struct EventData {
-    int fd;
-    ::std::function<void()> in_fn;
-  };
   ::std::vector<::std::unique_ptr<EventData>> fns_;
 
   // Pipe pair for handling quit.
diff --git a/aos/events/epoll_test.cc b/aos/events/epoll_test.cc
new file mode 100644
index 0000000..4e6bbbc
--- /dev/null
+++ b/aos/events/epoll_test.cc
@@ -0,0 +1,172 @@
+#include "aos/events/epoll.h"
+
+#include <fcntl.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "glog/logging.h"
+
+namespace aos {
+namespace internal {
+namespace testing {
+
+// A simple wrapper around both ends of a pipe along with some helpers to easily
+// read/write data through it.
+class Pipe {
+ public:
+  Pipe() { PCHECK(pipe2(fds_, O_NONBLOCK) == 0); }
+  ~Pipe() {
+    PCHECK(close(fds_[0]) == 0);
+    PCHECK(close(fds_[1]) == 0);
+  }
+
+  int read_fd() { return fds_[0]; }
+  int write_fd() { return fds_[1]; }
+
+  void Write(const std::string &data) {
+    CHECK_EQ(write(write_fd(), data.data(), data.size()),
+             static_cast<ssize_t>(data.size()));
+  }
+
+  std::string Read(size_t size) {
+    std::string result;
+    result.resize(size);
+    CHECK_EQ(read(read_fd(), result.data(), size), static_cast<ssize_t>(size));
+    return result;
+  }
+
+ private:
+  int fds_[2];
+};
+
+class EPollTest : public ::testing::Test {
+  public:
+   void RunFor(std::chrono::nanoseconds duration) {
+     TimerFd timerfd;
+     bool did_quit = false;
+     epoll_.OnReadable(timerfd.fd(), [this, &timerfd, &did_quit]() {
+       CHECK(!did_quit);
+       epoll_.Quit();
+       did_quit = true;
+       timerfd.Read();
+     });
+     timerfd.SetTime(monotonic_clock::now() + duration,
+                     monotonic_clock::duration::zero());
+     epoll_.Run();
+     CHECK(did_quit);
+     epoll_.DeleteFd(timerfd.fd());
+   }
+
+  // Tests should avoid relying on ordering for events closer in time than this,
+  // or waiting for longer than this to ensure events happen in order.
+  static constexpr std::chrono::nanoseconds tick_duration() {
+    return std::chrono::milliseconds(50);
+  }
+
+   EPoll epoll_;
+};
+
+// Test that the basics of OnReadable work.
+TEST_F(EPollTest, BasicReadable) {
+  Pipe pipe;
+  bool got_data = false;
+  epoll_.OnReadable(pipe.read_fd(), [&]() {
+    ASSERT_FALSE(got_data);
+    ASSERT_EQ("some", pipe.Read(4));
+    got_data = true;
+  });
+  RunFor(tick_duration());
+  EXPECT_FALSE(got_data);
+
+  pipe.Write("some");
+  RunFor(tick_duration());
+  EXPECT_TRUE(got_data);
+
+  epoll_.DeleteFd(pipe.read_fd());
+}
+
+// Test that the basics of OnWriteable work.
+TEST_F(EPollTest, BasicWriteable) {
+  Pipe pipe;
+  int number_writes = 0;
+  epoll_.OnWriteable(pipe.write_fd(), [&]() {
+    pipe.Write(" ");
+    ++number_writes;
+  });
+
+  // First, fill up the pipe's write buffer.
+  RunFor(tick_duration());
+  EXPECT_GT(number_writes, 0);
+
+  // Now, if we try again, we shouldn't do anything.
+  const int bytes_in_pipe = number_writes;
+  number_writes = 0;
+  RunFor(tick_duration());
+  EXPECT_EQ(number_writes, 0);
+
+  // Empty the pipe, then fill it up again.
+  for (int i = 0; i < bytes_in_pipe; ++i) {
+    ASSERT_EQ(" ", pipe.Read(1));
+  }
+  number_writes = 0;
+  RunFor(tick_duration());
+  EXPECT_EQ(number_writes, bytes_in_pipe);
+
+  epoll_.DeleteFd(pipe.write_fd());
+}
+
+TEST(EPollDeathTest, InvalidFd) {
+  EPoll epoll;
+  Pipe pipe;
+  epoll.OnReadable(pipe.read_fd(), []() {});
+  EXPECT_DEATH(epoll.OnReadable(pipe.read_fd(), []() {}),
+               "Duplicate in functions");
+  epoll.OnWriteable(pipe.read_fd(), []() {});
+  EXPECT_DEATH(epoll.OnWriteable(pipe.read_fd(), []() {}),
+               "Duplicate out functions");
+
+  epoll.DeleteFd(pipe.read_fd());
+  EXPECT_DEATH(epoll.DeleteFd(pipe.read_fd()), "fd [0-9]+ not found");
+  EXPECT_DEATH(epoll.DeleteFd(pipe.write_fd()), "fd [0-9]+ not found");
+}
+
+// Tests that enabling/disabling a writeable FD works.
+TEST_F(EPollTest, WriteableEnableDisable) {
+  Pipe pipe;
+  int number_writes = 0;
+  epoll_.OnWriteable(pipe.write_fd(), [&]() {
+    pipe.Write(" ");
+    ++number_writes;
+  });
+
+  // First, fill up the pipe's write buffer.
+  RunFor(tick_duration());
+  EXPECT_GT(number_writes, 0);
+
+  // Empty the pipe.
+  const int bytes_in_pipe = number_writes;
+  for (int i = 0; i < bytes_in_pipe; ++i) {
+    ASSERT_EQ(" ", pipe.Read(1));
+  }
+
+  // If we disable writeable checking, then nothing should happen.
+  epoll_.DisableWriteable(pipe.write_fd());
+  number_writes = 0;
+  RunFor(tick_duration());
+  EXPECT_EQ(number_writes, 0);
+
+  // Disabling it again should be a NOP.
+  epoll_.DisableWriteable(pipe.write_fd());
+
+  // And then when we re-enable, it should fill the pipe up again.
+  epoll_.EnableWriteable(pipe.write_fd());
+  number_writes = 0;
+  RunFor(tick_duration());
+  EXPECT_EQ(number_writes, bytes_in_pipe);
+
+  epoll_.DeleteFd(pipe.write_fd());
+}
+
+}  // namespace testing
+}  // namespace internal
+}  // namespace aos