Expose the underlying shared memory buffers from ShmEventLoop
For advanced use cases, this is handy to set up these memory regions
specially.
Change-Id: I7e858d3af5f3fa51f980e0e4cab5dfcfa7b83fd9
diff --git a/aos/events/event_loop.cc b/aos/events/event_loop.cc
index e368df2..ec23afc 100644
--- a/aos/events/event_loop.cc
+++ b/aos/events/event_loop.cc
@@ -81,6 +81,16 @@
return std::distance(configuration()->channels()->begin(), c);
}
+WatcherState *EventLoop::GetWatcherState(const Channel *channel) {
+ const int channel_index = ChannelIndex(channel);
+ for (const std::unique_ptr<WatcherState> &watcher : watchers_) {
+ if (watcher->channel_index() == channel_index) {
+ return watcher.get();
+ }
+ }
+ LOG(FATAL) << "No watcher found for channel";
+}
+
void EventLoop::NewSender(RawSender *sender) {
senders_.emplace_back(sender);
UpdateTimingReport();
diff --git a/aos/events/event_loop.h b/aos/events/event_loop.h
index 4a12096..ed365c3 100644
--- a/aos/events/event_loop.h
+++ b/aos/events/event_loop.h
@@ -1,4 +1,5 @@
#ifndef AOS_EVENTS_EVENT_LOOP_H_
+
#define AOS_EVENTS_EVENT_LOOP_H_
#include <atomic>
@@ -482,6 +483,16 @@
// Validates that channel exists inside configuration_ and finds its index.
int ChannelIndex(const Channel *channel);
+ // Returns the state for the watcher on the corresponding channel. This
+ // watcher must exist before calling this.
+ WatcherState *GetWatcherState(const Channel *channel);
+
+ // Returns a Sender's protected RawSender
+ template <typename T>
+ static RawSender *GetRawSender(aos::Sender<T> *sender) {
+ return sender->sender_.get();
+ }
+
// Context available for watchers, timers, and phased loops.
Context context_;
diff --git a/aos/events/shm_event_loop.cc b/aos/events/shm_event_loop.cc
index cc11520..c125792 100644
--- a/aos/events/shm_event_loop.cc
+++ b/aos/events/shm_event_loop.cc
@@ -123,6 +123,10 @@
const ipc_lib::LocklessQueueConfiguration &config() const { return config_; }
+ absl::Span<char> GetSharedMemory() const {
+ return absl::Span<char>(static_cast<char *>(data_), size_);
+ }
+
private:
void MkdirP(std::string_view path) {
auto last_slash_pos = path.find_last_of("/");
@@ -314,6 +318,10 @@
void UnregisterWakeup() { lockless_queue_.UnregisterWakeup(); }
+ absl::Span<char> GetSharedMemory() const {
+ return lockless_queue_memory_.GetSharedMemory();
+ }
+
private:
const Channel *const channel_;
MMapedQueue lockless_queue_memory_;
@@ -402,6 +410,10 @@
return true;
}
+ absl::Span<char> GetSharedMemory() const {
+ return lockless_queue_memory_.GetSharedMemory();
+ }
+
private:
MMapedQueue lockless_queue_memory_;
ipc_lib::LocklessQueue lockless_queue_;
@@ -456,6 +468,10 @@
void UnregisterWakeup() { return simple_shm_fetcher_.UnregisterWakeup(); }
+ absl::Span<char> GetSharedMemory() const {
+ return simple_shm_fetcher_.GetSharedMemory();
+ }
+
private:
bool has_new_data_ = false;
@@ -708,7 +724,8 @@
ScopedSignalMask mask({SIGINT, SIGHUP, SIGTERM});
std::unique_lock<stl_mutex> locker(mutex_);
- event_loops_.erase(std::find(event_loops_.begin(), event_loops_.end(), event_loop));
+ event_loops_.erase(
+ std::find(event_loops_.begin(), event_loops_.end(), event_loop));
if (event_loops_.size() == 0u) {
// The last caller restores the original signal handlers.
@@ -836,6 +853,17 @@
UpdateTimingReport();
}
+absl::Span<char> ShmEventLoop::GetWatcherSharedMemory(const Channel *channel) {
+ internal::WatcherState *const watcher_state =
+ static_cast<internal::WatcherState *>(GetWatcherState(channel));
+ return watcher_state->GetSharedMemory();
+}
+
+absl::Span<char> ShmEventLoop::GetShmSenderSharedMemory(
+ const aos::RawSender *sender) const {
+ return static_cast<const internal::ShmSender *>(sender)->GetSharedMemory();
+}
+
pid_t ShmEventLoop::GetTid() { return syscall(SYS_gettid); }
} // namespace aos
diff --git a/aos/events/shm_event_loop.h b/aos/events/shm_event_loop.h
index 52bb338..fa870b8 100644
--- a/aos/events/shm_event_loop.h
+++ b/aos/events/shm_event_loop.h
@@ -3,6 +3,8 @@
#include <vector>
+#include "absl/types/span.h"
+
#include "aos/events/epoll.h"
#include "aos/events/event_loop.h"
#include "aos/events/event_loop_generated.h"
@@ -68,8 +70,20 @@
int priority() const override { return priority_; }
+ // Returns the epoll loop used to run the event loop.
internal::EPoll *epoll() { return &epoll_; }
+ // Returns the local mapping of the shared memory used by the watcher on the
+ // specified channel. A watcher must be created on this channel before calling
+ // this.
+ absl::Span<char> GetWatcherSharedMemory(const Channel *channel);
+
+ // Returns the local mapping of the shared memory used by the provided Sender
+ template <typename T>
+ absl::Span<char> GetSenderSharedMemory(aos::Sender<T> *sender) const {
+ return GetShmSenderSharedMemory(GetRawSender(sender));
+ }
+
private:
friend class internal::WatcherState;
friend class internal::TimerHandlerState;
@@ -82,6 +96,9 @@
// Returns the TID of the event loop.
pid_t GetTid() override;
+ // Private method to access the shared memory mapping of a ShmSender
+ absl::Span<char> GetShmSenderSharedMemory(const aos::RawSender *sender) const;
+
std::vector<std::function<void()>> on_run_;
int priority_ = 0;
std::string name_;
@@ -90,7 +107,6 @@
internal::EPoll epoll_;
};
-
} // namespace aos
#endif // AOS_EVENTS_SHM_EVENT_LOOP_H_
diff --git a/aos/events/shm_event_loop_test.cc b/aos/events/shm_event_loop_test.cc
index d33f496..949d7db 100644
--- a/aos/events/shm_event_loop_test.cc
+++ b/aos/events/shm_event_loop_test.cc
@@ -181,6 +181,34 @@
EXPECT_EQ(times.size(), 2u);
}
+// Test GetWatcherSharedMemory in a few basic scenarios.
+TEST(ShmEventLoopDeathTest, GetWatcherSharedMemory) {
+ ShmEventLoopTestFactory factory;
+ auto generic_loop1 = factory.MakePrimary("primary");
+ ShmEventLoop *const loop1 = static_cast<ShmEventLoop *>(generic_loop1.get());
+ const auto channel = configuration::GetChannel(
+ loop1->configuration(), "/test", TestMessage::GetFullyQualifiedName(),
+ loop1->name(), loop1->node());
+
+ // First verify it handles an invalid channel reasonably.
+ EXPECT_DEATH(loop1->GetWatcherSharedMemory(channel),
+ "No watcher found for channel");
+
+ // Then, actually create a watcher, and verify it returns something sane.
+ loop1->MakeWatcher("/test", [](const TestMessage &) {});
+ EXPECT_FALSE(loop1->GetWatcherSharedMemory(channel).empty());
+}
+
+TEST(ShmEventLoopTest, GetSenderSharedMemory) {
+ ShmEventLoopTestFactory factory;
+ auto generic_loop1 = factory.MakePrimary("primary");
+ ShmEventLoop *const loop1 = static_cast<ShmEventLoop *>(generic_loop1.get());
+
+ // check that GetSenderSharedMemory returns non-null/non-empty memory span
+ auto sender = loop1->MakeSender<TestMessage>("/test");
+ EXPECT_FALSE(loop1->GetSenderSharedMemory(&sender).empty());
+}
+
// TODO(austin): Test that missing a deadline with a timer recovers as expected.
} // namespace testing