Merge changes I47aa260a,I7e858d3a
* changes:
Add support for writable events on an FD to EPoll
Expose the underlying shared memory buffers from ShmEventLoop
diff --git a/aos/events/BUILD b/aos/events/BUILD
index a4c24ce..9cfdf10 100644
--- a/aos/events/BUILD
+++ b/aos/events/BUILD
@@ -41,6 +41,16 @@
],
)
+cc_test(
+ name = "epoll_test",
+ srcs = ["epoll_test.cc"],
+ deps = [
+ ":epoll",
+ "//aos/testing:googletest",
+ "@com_github_google_glog//:glog",
+ ],
+)
+
cc_library(
name = "event_loop",
srcs = [
@@ -215,6 +225,7 @@
"//aos/ipc_lib:signalfd",
"//aos/stl_mutex",
"//aos/util:phased_loop",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/aos/events/epoll.cc b/aos/events/epoll.cc
index e19cf59..b9e7edb 100644
--- a/aos/events/epoll.cc
+++ b/aos/events/epoll.cc
@@ -71,6 +71,7 @@
}
void EPoll::Run() {
+ run_ = true;
while (true) {
// Pull a single event out. Infinite timeout if we are supposed to be
// running, and 0 length timeout otherwise. This lets us flush the event
@@ -90,34 +91,61 @@
return;
}
}
- EventData *event_data = static_cast<struct EventData *>(event.data.ptr);
- if (event.events & (EPOLLIN | EPOLLPRI)) {
+ EventData *const event_data = static_cast<struct EventData *>(event.data.ptr);
+ if (event.events & kInEvents) {
+ CHECK(event_data->in_fn)
+ << ": No handler registered for input events on " << event_data->fd;
event_data->in_fn();
}
+ // TODO(review): if in_fn calls DeleteFd() on this fd, event_data has been
+ // freed by now and must not be dereferenced below.
+ if (event.events & kOutEvents) {
+ CHECK(event_data->out_fn)
+ << ": No handler registered for output events on " << event_data->fd;
+ event_data->out_fn();
+ }
}
}
void EPoll::Quit() { PCHECK(write(quit_signal_fd_, "q", 1) == 1); }
void EPoll::OnReadable(int fd, ::std::function<void()> function) {
- ::std::unique_ptr<EventData> event_data(
- new EventData{fd, ::std::move(function)});
+ EventData *event_data = GetEventData(fd);
+ if (event_data == nullptr) {
+ fns_.emplace_back(std::make_unique<EventData>(fd));
+ event_data = fns_.back().get();
+ } else {
+ CHECK(!event_data->in_fn) << ": Duplicate in functions for " << fd;
+ }
+ event_data->in_fn = ::std::move(function);
+ DoEpollCtl(event_data, event_data->events | kInEvents);
+}
- struct epoll_event event;
- event.events = EPOLLIN | EPOLLPRI;
- event.data.ptr = event_data.get();
- PCHECK(epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &event) == 0)
- << ": Failed to add fd " << fd;
- fns_.push_back(::std::move(event_data));
+void EPoll::OnWriteable(int fd, ::std::function<void()> function) {
+ EventData *event_data = GetEventData(fd);
+ if (event_data == nullptr) {
+ fns_.emplace_back(std::make_unique<EventData>(fd));
+ event_data = fns_.back().get();
+ } else {
+ CHECK(!event_data->out_fn) << ": Duplicate out functions for " << fd;
+ }
+ event_data->out_fn = ::std::move(function);
+ DoEpollCtl(event_data, event_data->events | kOutEvents);
}
// Removes fd from the event loop.
void EPoll::DeleteFd(int fd) {
auto element = fns_.begin();
- PCHECK(epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr) == 0);
while (fns_.end() != element) {
if (element->get()->fd == fd) {
+ // If every event was disabled, DoEpollCtl already removed fd from epoll
+ // and EPOLL_CTL_DEL would fail with ENOENT, so only remove it while it is
+ // still registered. Read events before erase() frees the EventData.
+ if (element->get()->events != 0) {
+ PCHECK(epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr) == 0)
+ << ": Failed to remove fd " << fd;
+ }
fns_.erase(element);
return;
}
++element;
@@ -125,5 +153,58 @@
LOG(FATAL) << "fd " << fd << " not found";
}
+// Adds events to the set monitored for fd. fd must already be registered.
+void EPoll::EnableEvents(int fd, uint32_t events) {
+ EventData *const event_data = CHECK_NOTNULL(GetEventData(fd));
+ DoEpollCtl(event_data, event_data->events | events);
+}
+
+// Removes events from the set monitored for fd. fd must already be registered.
+void EPoll::DisableEvents(int fd, uint32_t events) {
+ EventData *const event_data = CHECK_NOTNULL(GetEventData(fd));
+ DoEpollCtl(event_data, event_data->events & ~events);
+}
+
+// Returns the EventData registered for fd, or nullptr if there is none.
+EPoll::EventData *EPoll::GetEventData(int fd) {
+ const auto iterator = std::find_if(
+ fns_.begin(), fns_.end(),
+ [fd](const std::unique_ptr<EventData> &data) { return data->fd == fd; });
+ if (iterator == fns_.end()) {
+ return nullptr;
+ }
+ return iterator->get();
+}
+
+// Records new_events on event_data and brings epoll in sync with it: the fd is
+// added when it goes from no events to some, removed when it goes to none, and
+// modified otherwise.
+void EPoll::DoEpollCtl(EventData *event_data, const uint32_t new_events) {
+ const uint32_t old_events = event_data->events;
+ event_data->events = new_events;
+ if (new_events == 0) {
+ if (old_events == 0) {
+ // Not added, and doesn't need to be. Nothing to do here.
+ return;
+ }
+ // It was added, but should now be removed.
+ PCHECK(epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, event_data->fd, nullptr) == 0)
+ << ": Failed to remove fd " << event_data->fd;
+ return;
+ }
+
+ int operation = EPOLL_CTL_MOD;
+ if (old_events == 0) {
+ // If it wasn't added before, then this is the first time it's being added.
+ operation = EPOLL_CTL_ADD;
+ }
+ struct epoll_event event;
+ event.events = event_data->events;
+ event.data.ptr = event_data;
+ PCHECK(epoll_ctl(epoll_fd_, operation, event_data->fd, &event) == 0)
+ << ": Failed to " << (operation == EPOLL_CTL_ADD ? "add" : "modify")
+ << " epoll fd: " << event_data->fd;
+}
+
} // namespace internal
} // namespace aos
diff --git a/aos/events/epoll.h b/aos/events/epoll.h
index 7bc135c..51a20d0 100644
--- a/aos/events/epoll.h
+++ b/aos/events/epoll.h
@@ -63,25 +63,57 @@
void Quit();
// Registers a function to be called if the fd becomes readable.
- // There should only be 1 function registered for each fd.
+ // Only one function may be registered for readability on each fd.
void OnReadable(int fd, ::std::function<void()> function);
+ // Registers a function to be called if the fd becomes writeable.
+ // Only one function may be registered for writability on each fd.
+ void OnWriteable(int fd, ::std::function<void()> function);
+
// Removes fd from the event loop.
// All Fds must be cleaned up before this class is destroyed.
void DeleteFd(int fd);
+ // Enables calling the existing function registered for fd when it becomes
+ // writeable. fd must already have a function registered via OnWriteable.
+ void EnableWriteable(int fd) { EnableEvents(fd, kOutEvents); }
+
+ // Disables calling the existing function registered for fd when it becomes
+ // writeable. Disabling when already disabled has no further effect.
+ void DisableWriteable(int fd) { DisableEvents(fd, kOutEvents); }
+
private:
+ // Structure whose pointer should be returned by epoll. Makes looking up the
+ // function fast and easy.
+ struct EventData {
+ explicit EventData(int fd_in) : fd(fd_in) {}
+ // We use pointers to these objects as persistent identifiers, so they can't
+ // be moved.
+ EventData(const EventData &) = delete;
+ EventData &operator=(const EventData &) = delete;
+
+ const int fd;
+ uint32_t events = 0;
+ ::std::function<void()> in_fn, out_fn;
+ };
+
+ void EnableEvents(int fd, uint32_t events);
+ void DisableEvents(int fd, uint32_t events);
+
+ EventData *GetEventData(int fd);
+
+ void DoEpollCtl(EventData *event_data, uint32_t new_events);
+
+ // TODO(Brian): Figure out a nicer way to handle EPOLLPRI than lumping it in
+ // with input.
+ static constexpr uint32_t kInEvents = EPOLLIN | EPOLLPRI;
+ static constexpr uint32_t kOutEvents = EPOLLOUT;
+
::std::atomic<bool> run_{true};
// Main epoll fd.
int epoll_fd_;
- // Structure whose pointer should be returned by epoll. Makes looking up the
- // function fast and easy.
- struct EventData {
- int fd;
- ::std::function<void()> in_fn;
- };
::std::vector<::std::unique_ptr<EventData>> fns_;
// Pipe pair for handling quit.
diff --git a/aos/events/epoll_test.cc b/aos/events/epoll_test.cc
new file mode 100644
index 0000000..4e6bbbc
--- /dev/null
+++ b/aos/events/epoll_test.cc
@@ -0,0 +1,178 @@
+#include "aos/events/epoll.h"
+
+#include <fcntl.h>
+#include <unistd.h>
+
+#include <chrono>
+#include <string>
+
+#include "gtest/gtest.h"
+#include "glog/logging.h"
+
+namespace aos {
+namespace internal {
+namespace testing {
+
+// A simple wrapper around both ends of a pipe along with some helpers to easily
+// read/write data through it. Owns and closes both fds, so it is not copyable.
+class Pipe {
+ public:
+ Pipe() { PCHECK(pipe2(fds_, O_NONBLOCK) == 0); }
+ Pipe(const Pipe &) = delete;
+ Pipe &operator=(const Pipe &) = delete;
+ ~Pipe() {
+ PCHECK(close(fds_[0]) == 0);
+ PCHECK(close(fds_[1]) == 0);
+ }
+
+ int read_fd() { return fds_[0]; }
+ int write_fd() { return fds_[1]; }
+
+ void Write(const std::string &data) {
+ CHECK_EQ(write(write_fd(), data.data(), data.size()),
+ static_cast<ssize_t>(data.size()));
+ }
+
+ std::string Read(size_t size) {
+ std::string result;
+ result.resize(size);
+ CHECK_EQ(read(read_fd(), result.data(), size), static_cast<ssize_t>(size));
+ return result;
+ }
+
+ private:
+ int fds_[2];
+};
+
+class EPollTest : public ::testing::Test {
+ public:
+ // Runs epoll_ for about duration, stopping via a timerfd that calls Quit().
+ void RunFor(std::chrono::nanoseconds duration) {
+ TimerFd timerfd;
+ bool did_quit = false;
+ epoll_.OnReadable(timerfd.fd(), [this, &timerfd, &did_quit]() {
+ CHECK(!did_quit);
+ epoll_.Quit();
+ did_quit = true;
+ timerfd.Read();
+ });
+ timerfd.SetTime(monotonic_clock::now() + duration,
+ monotonic_clock::duration::zero());
+ epoll_.Run();
+ CHECK(did_quit);
+ epoll_.DeleteFd(timerfd.fd());
+ }
+
+ // Tests should avoid relying on ordering for events closer in time than this,
+ // or waiting for longer than this to ensure events happen in order.
+ static constexpr std::chrono::nanoseconds tick_duration() {
+ return std::chrono::milliseconds(50);
+ }
+
+ EPoll epoll_;
+};
+
+// Test that the basics of OnReadable work.
+TEST_F(EPollTest, BasicReadable) {
+ Pipe pipe;
+ bool got_data = false;
+ epoll_.OnReadable(pipe.read_fd(), [&]() {
+ ASSERT_FALSE(got_data);
+ ASSERT_EQ("some", pipe.Read(4));
+ got_data = true;
+ });
+ RunFor(tick_duration());
+ EXPECT_FALSE(got_data);
+
+ pipe.Write("some");
+ RunFor(tick_duration());
+ EXPECT_TRUE(got_data);
+
+ epoll_.DeleteFd(pipe.read_fd());
+}
+
+// Test that the basics of OnWriteable work.
+TEST_F(EPollTest, BasicWriteable) {
+ Pipe pipe;
+ int number_writes = 0;
+ epoll_.OnWriteable(pipe.write_fd(), [&]() {
+ pipe.Write(" ");
+ ++number_writes;
+ });
+
+ // First, fill up the pipe's write buffer.
+ RunFor(tick_duration());
+ EXPECT_GT(number_writes, 0);
+
+ // Now, if we try again, we shouldn't do anything.
+ const int bytes_in_pipe = number_writes;
+ number_writes = 0;
+ RunFor(tick_duration());
+ EXPECT_EQ(number_writes, 0);
+
+ // Empty the pipe, then fill it up again.
+ for (int i = 0; i < bytes_in_pipe; ++i) {
+ ASSERT_EQ(" ", pipe.Read(1));
+ }
+ number_writes = 0;
+ RunFor(tick_duration());
+ EXPECT_EQ(number_writes, bytes_in_pipe);
+
+ epoll_.DeleteFd(pipe.write_fd());
+}
+
+TEST(EPollDeathTest, InvalidFd) {
+ EPoll epoll;
+ Pipe pipe;
+ epoll.OnReadable(pipe.read_fd(), []() {});
+ EXPECT_DEATH(epoll.OnReadable(pipe.read_fd(), []() {}),
+ "Duplicate in functions");
+ epoll.OnWriteable(pipe.read_fd(), []() {});
+ EXPECT_DEATH(epoll.OnWriteable(pipe.read_fd(), []() {}),
+ "Duplicate out functions");
+
+ epoll.DeleteFd(pipe.read_fd());
+ EXPECT_DEATH(epoll.DeleteFd(pipe.read_fd()), "fd [0-9]+ not found");
+ EXPECT_DEATH(epoll.DeleteFd(pipe.write_fd()), "fd [0-9]+ not found");
+}
+
+// Tests that enabling/disabling a writeable FD works.
+TEST_F(EPollTest, WriteableEnableDisable) {
+ Pipe pipe;
+ int number_writes = 0;
+ epoll_.OnWriteable(pipe.write_fd(), [&]() {
+ pipe.Write(" ");
+ ++number_writes;
+ });
+
+ // First, fill up the pipe's write buffer.
+ RunFor(tick_duration());
+ EXPECT_GT(number_writes, 0);
+
+ // Empty the pipe.
+ const int bytes_in_pipe = number_writes;
+ for (int i = 0; i < bytes_in_pipe; ++i) {
+ ASSERT_EQ(" ", pipe.Read(1));
+ }
+
+ // If we disable writeable checking, then nothing should happen.
+ epoll_.DisableWriteable(pipe.write_fd());
+ number_writes = 0;
+ RunFor(tick_duration());
+ EXPECT_EQ(number_writes, 0);
+
+ // Disabling it again should be a NOP.
+ epoll_.DisableWriteable(pipe.write_fd());
+
+ // And then when we re-enable, it should fill the pipe up again.
+ epoll_.EnableWriteable(pipe.write_fd());
+ number_writes = 0;
+ RunFor(tick_duration());
+ EXPECT_EQ(number_writes, bytes_in_pipe);
+
+ epoll_.DeleteFd(pipe.write_fd());
+}
+
+} // namespace testing
+} // namespace internal
+} // namespace aos
diff --git a/aos/events/event_loop.cc b/aos/events/event_loop.cc
index c6d3755..76b83c9 100644
--- a/aos/events/event_loop.cc
+++ b/aos/events/event_loop.cc
@@ -74,6 +74,19 @@
return configuration::ChannelIndex(configuration_, channel);
}
+// Scans the registered watchers for the one attached to channel. It is fatal
+// if no watcher has been created on that channel.
+WatcherState *EventLoop::GetWatcherState(const Channel *channel) {
+ const int wanted_index = ChannelIndex(channel);
+ for (auto it = watchers_.begin(); it != watchers_.end(); ++it) {
+ WatcherState *const candidate = it->get();
+ if (candidate->channel_index() == wanted_index) {
+ return candidate;
+ }
+ }
+ LOG(FATAL) << "No watcher found for channel";
+}
+
void EventLoop::NewSender(RawSender *sender) {
senders_.emplace_back(sender);
UpdateTimingReport();
diff --git a/aos/events/event_loop.h b/aos/events/event_loop.h
index 9618a32..3fe46ec 100644
--- a/aos/events/event_loop.h
+++ b/aos/events/event_loop.h
@@ -512,6 +512,16 @@
// Validates that channel exists inside configuration_ and finds its index.
int ChannelIndex(const Channel *channel);
+ // Returns the state for the watcher on the corresponding channel. This
+ // watcher must exist before calling this; it is fatal otherwise.
+ WatcherState *GetWatcherState(const Channel *channel);
+
+ // Returns the RawSender that the given Sender wraps.
+ template <typename T>
+ static RawSender *GetRawSender(aos::Sender<T> *sender) {
+ return sender->sender_.get();
+ }
+
// Context available for watchers, timers, and phased loops.
Context context_;
diff --git a/aos/events/shm_event_loop.cc b/aos/events/shm_event_loop.cc
index c4656d9..b4578cb 100644
--- a/aos/events/shm_event_loop.cc
+++ b/aos/events/shm_event_loop.cc
@@ -124,6 +124,10 @@
const ipc_lib::LocklessQueueConfiguration &config() const { return config_; }
+ // Returns the whole mapping (data_ through data_ + size_) as a mutable span.
+ absl::Span<char> GetSharedMemory() const {
+ return absl::Span<char>(static_cast<char *>(data_), size_);
+ }
+
private:
ipc_lib::LocklessQueueConfiguration config_;
@@ -298,6 +302,10 @@
void UnregisterWakeup() { lockless_queue_.UnregisterWakeup(); }
+ // Returns the shared memory mapping held by lockless_queue_memory_.
+ absl::Span<char> GetSharedMemory() const {
+ return lockless_queue_memory_.GetSharedMemory();
+ }
+
private:
char *data_storage_start() {
return RoundChannelData(data_storage_.get(), channel_->max_size());
@@ -383,6 +391,10 @@
return true;
}
+ // Returns the shared memory mapping held by lockless_queue_memory_.
+ absl::Span<char> GetSharedMemory() const {
+ return lockless_queue_memory_.GetSharedMemory();
+ }
+
private:
MMapedQueue lockless_queue_memory_;
ipc_lib::LocklessQueue lockless_queue_;
@@ -437,6 +449,10 @@
void UnregisterWakeup() { return simple_shm_fetcher_.UnregisterWakeup(); }
+ // Returns the shared memory mapping used by simple_shm_fetcher_.
+ absl::Span<char> GetSharedMemory() const {
+ return simple_shm_fetcher_.GetSharedMemory();
+ }
+
private:
bool has_new_data_ = false;
@@ -689,7 +705,8 @@
ScopedSignalMask mask({SIGINT, SIGHUP, SIGTERM});
std::unique_lock<stl_mutex> locker(mutex_);
- event_loops_.erase(std::find(event_loops_.begin(), event_loops_.end(), event_loop));
+ event_loops_.erase(
+ std::find(event_loops_.begin(), event_loops_.end(), event_loop));
if (event_loops_.size() == 0u) {
// The last caller restores the original signal handlers.
@@ -817,6 +834,17 @@
UpdateTimingReport();
}
+// Returns the shared memory mapping used by this loop's watcher on channel.
+// GetWatcherState() is fatal if no watcher exists on channel.
+absl::Span<char> ShmEventLoop::GetWatcherSharedMemory(const Channel *channel) {
+ internal::WatcherState *const watcher_state =
+ static_cast<internal::WatcherState *>(GetWatcherState(channel));
+ return watcher_state->GetSharedMemory();
+}
+
+// NOTE(review): the static_cast assumes sender is an internal::ShmSender, i.e.
+// that it was created by a ShmEventLoop; confirm callers cannot pass others.
+absl::Span<char> ShmEventLoop::GetShmSenderSharedMemory(
+ const aos::RawSender *sender) const {
+ return static_cast<const internal::ShmSender *>(sender)->GetSharedMemory();
+}
+
pid_t ShmEventLoop::GetTid() { return syscall(SYS_gettid); }
} // namespace aos
diff --git a/aos/events/shm_event_loop.h b/aos/events/shm_event_loop.h
index 52bb338..fa870b8 100644
--- a/aos/events/shm_event_loop.h
+++ b/aos/events/shm_event_loop.h
@@ -3,6 +3,8 @@
#include <vector>
+#include "absl/types/span.h"
+
#include "aos/events/epoll.h"
#include "aos/events/event_loop.h"
#include "aos/events/event_loop_generated.h"
@@ -68,8 +70,20 @@
int priority() const override { return priority_; }
+ // Returns the epoll loop used to run the event loop.
internal::EPoll *epoll() { return &epoll_; }
+ // Returns the local mapping of the shared memory used by the watcher on the
+ // specified channel. A watcher must be created on this channel before calling
+ // this; the lookup is fatal otherwise.
+ absl::Span<char> GetWatcherSharedMemory(const Channel *channel);
+
+ // Returns the local mapping of the shared memory used by the provided Sender.
+ // NOTE(review): sender is presumably expected to come from this event loop.
+ template <typename T>
+ absl::Span<char> GetSenderSharedMemory(aos::Sender<T> *sender) const {
+ return GetShmSenderSharedMemory(GetRawSender(sender));
+ }
+
private:
friend class internal::WatcherState;
friend class internal::TimerHandlerState;
@@ -82,6 +96,9 @@
// Returns the TID of the event loop.
pid_t GetTid() override;
+ // Returns the shared memory mapping of sender, which must be a ShmSender.
+ absl::Span<char> GetShmSenderSharedMemory(const aos::RawSender *sender) const;
+
std::vector<std::function<void()>> on_run_;
int priority_ = 0;
std::string name_;
@@ -90,7 +107,6 @@
internal::EPoll epoll_;
};
-
} // namespace aos
#endif // AOS_EVENTS_SHM_EVENT_LOOP_H_
diff --git a/aos/events/shm_event_loop_test.cc b/aos/events/shm_event_loop_test.cc
index 2edc3fc..984f978 100644
--- a/aos/events/shm_event_loop_test.cc
+++ b/aos/events/shm_event_loop_test.cc
@@ -183,6 +183,34 @@
EXPECT_EQ(times.size(), 2u);
}
+// Test GetWatcherSharedMemory in a few basic scenarios.
+TEST(ShmEventLoopDeathTest, GetWatcherSharedMemory) {
+ ShmEventLoopTestFactory factory;
+ auto generic_loop1 = factory.MakePrimary("primary");
+ ShmEventLoop *const loop1 = static_cast<ShmEventLoop *>(generic_loop1.get());
+ const auto channel = configuration::GetChannel(
+ loop1->configuration(), "/test", TestMessage::GetFullyQualifiedName(),
+ loop1->name(), loop1->node());
+
+ // First verify it handles an invalid channel reasonably.
+ EXPECT_DEATH(loop1->GetWatcherSharedMemory(channel),
+ "No watcher found for channel");
+
+ // Then, actually create a watcher, and verify it returns something sane.
+ loop1->MakeWatcher("/test", [](const TestMessage &) {});
+ EXPECT_FALSE(loop1->GetWatcherSharedMemory(channel).empty());
+}
+
+// Test that GetSenderSharedMemory exposes the mapping of a newly made sender.
+TEST(ShmEventLoopTest, GetSenderSharedMemory) {
+ ShmEventLoopTestFactory factory;
+ auto generic_loop1 = factory.MakePrimary("primary");
+ ShmEventLoop *const loop1 = static_cast<ShmEventLoop *>(generic_loop1.get());
+
+ // Check that GetSenderSharedMemory returns a non-empty memory span.
+ auto sender = loop1->MakeSender<TestMessage>("/test");
+ EXPECT_FALSE(loop1->GetSenderSharedMemory(&sender).empty());
+}
+
// TODO(austin): Test that missing a deadline with a timer recovers as expected.
} // namespace testing