Use a read-only mapping for reading from shared memory
This makes it a lot harder for readers to accidentally write to the shared memory.
Change-Id: I29025f37b0767825e57314cc1221510fd9455b55
diff --git a/aos/events/shm_event_loop.cc b/aos/events/shm_event_loop.cc
index 1ac8656..6013709 100644
--- a/aos/events/shm_event_loop.cc
+++ b/aos/events/shm_event_loop.cc
@@ -65,7 +65,7 @@
return ShmFolder(channel) + channel->type()->str() + ".v3";
}
-void PageFaultData(char *data, size_t size) {
+void PageFaultDataWrite(char *data, size_t size) {
// This just has to divide the actual page size. Being smaller will make this
// a bit slower than necessary, but not much. 1024 is a pretty conservative
// choice (most pages are probably 4096).
@@ -90,6 +90,18 @@
}
}
+void PageFaultDataRead(const char *data, size_t size) {
+ // Reads one byte from every kPageSize-sized stride of [data, data + size).
+ // kPageSize just has to divide the actual page size; being smaller is only a
+ // bit slower. 1024 is a conservative choice (most pages are probably 4096).
+ static constexpr size_t kPageSize = 1024;
+ const size_t pages = (size + kPageSize - 1) / kPageSize;
+ for (size_t i = 0; i < pages; ++i) {
+ // The load's value is discarded; we only need a readable pagetable entry.
+ __atomic_load_n(&data[i * kPageSize], __ATOMIC_RELAXED);
+ }
+}
+
ipc_lib::LocklessQueueConfiguration MakeQueueConfiguration(
const Channel *channel, std::chrono::seconds channel_storage_duration) {
ipc_lib::LocklessQueueConfiguration config;
@@ -150,33 +162,49 @@
data_ = mmap(NULL, size_, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
PCHECK(data_ != MAP_FAILED);
+ const_data_ = mmap(NULL, size_, PROT_READ, MAP_SHARED, fd, 0);
+ PCHECK(const_data_ != MAP_FAILED);
PCHECK(close(fd) == 0);
- PageFaultData(static_cast<char *>(data_), size_);
+ PageFaultDataWrite(static_cast<char *>(data_), size_);
+ PageFaultDataRead(static_cast<const char *>(const_data_), size_);
ipc_lib::InitializeLocklessQueueMemory(memory(), config_);
}
- ~MMapedQueue() { PCHECK(munmap(data_, size_) == 0); }
+ ~MMapedQueue() {
+ PCHECK(munmap(data_, size_) == 0);
+ PCHECK(munmap(const_cast<void *>(const_data_), size_) == 0);
+ }
ipc_lib::LocklessQueueMemory *memory() const {
return reinterpret_cast<ipc_lib::LocklessQueueMemory *>(data_);
}
+ const ipc_lib::LocklessQueueMemory *const_memory() const {
+ return reinterpret_cast<const ipc_lib::LocklessQueueMemory *>(const_data_);
+ }
+
const ipc_lib::LocklessQueueConfiguration &config() const { return config_; }
ipc_lib::LocklessQueue queue() const {
- return ipc_lib::LocklessQueue(memory(), memory(), config());
+ return ipc_lib::LocklessQueue(const_memory(), memory(), config());
}
- absl::Span<char> GetSharedMemory() const {
+ absl::Span<char> GetMutableSharedMemory() const {
return absl::Span<char>(static_cast<char *>(data_), size_);
}
+ absl::Span<const char> GetConstSharedMemory() const {
+ return absl::Span<const char>(static_cast<const char *>(const_data_),
+ size_);
+ }
+
private:
const ipc_lib::LocklessQueueConfiguration config_;
size_t size_;
void *data_;
+ const void *const_data_;
};
const Node *MaybeMyNode(const Configuration *configuration) {
@@ -310,14 +338,18 @@
watcher_ = std::nullopt;
}
- absl::Span<char> GetSharedMemory() const {
- return lockless_queue_memory_.GetSharedMemory();
+ absl::Span<char> GetMutableSharedMemory() {
+ return lockless_queue_memory_.GetMutableSharedMemory();
}
- absl::Span<char> GetPrivateMemory() const {
- // Can't usefully expose this for pinning, because the buffer changes
- // address for each message. Callers who want to work with that should just
- // grab the whole shared memory buffer instead.
+ absl::Span<const char> GetConstSharedMemory() const {
+ return lockless_queue_memory_.GetConstSharedMemory();
+ }
+
+ absl::Span<const char> GetPrivateMemory() const {
+ if (pin_data()) {
+ return lockless_queue_memory_.GetConstSharedMemory();
+ }
return absl::Span<char>(
const_cast<SimpleShmFetcher *>(this)->data_storage_start(),
LocklessQueueMessageDataSize(lockless_queue_memory_.memory()));
@@ -457,7 +489,7 @@
return std::make_pair(false, monotonic_clock::min_time);
}
- absl::Span<char> GetPrivateMemory() const {
+ absl::Span<const char> GetPrivateMemory() const {
return simple_shm_fetcher_.GetPrivateMemory();
}
@@ -524,7 +556,7 @@
}
absl::Span<char> GetSharedMemory() const {
- return lockless_queue_memory_.GetSharedMemory();
+ return lockless_queue_memory_.GetMutableSharedMemory();
}
int buffer_index() override { return lockless_queue_sender_.buffer_index(); }
@@ -588,8 +620,8 @@
void UnregisterWakeup() { return simple_shm_fetcher_.UnregisterWakeup(); }
- absl::Span<char> GetSharedMemory() const {
- return simple_shm_fetcher_.GetSharedMemory();
+ absl::Span<const char> GetSharedMemory() const {
+ return simple_shm_fetcher_.GetConstSharedMemory();
}
private:
@@ -1016,7 +1048,8 @@
UpdateTimingReport();
}
-absl::Span<char> ShmEventLoop::GetWatcherSharedMemory(const Channel *channel) {
+absl::Span<const char> ShmEventLoop::GetWatcherSharedMemory(
+ const Channel *channel) {
ShmWatcherState *const watcher_state =
static_cast<ShmWatcherState *>(GetWatcherState(channel));
return watcher_state->GetSharedMemory();
@@ -1034,7 +1067,7 @@
return static_cast<const ShmSender *>(sender)->GetSharedMemory();
}
-absl::Span<char> ShmEventLoop::GetShmFetcherPrivateMemory(
+absl::Span<const char> ShmEventLoop::GetShmFetcherPrivateMemory(
const aos::RawFetcher *fetcher) const {
return static_cast<const ShmFetcher *>(fetcher)->GetPrivateMemory();
}
diff --git a/aos/events/shm_event_loop.h b/aos/events/shm_event_loop.h
index d7ea5e6..8dabcb5 100644
--- a/aos/events/shm_event_loop.h
+++ b/aos/events/shm_event_loop.h
@@ -83,7 +83,7 @@
// Returns the local mapping of the shared memory used by the watcher on the
// specified channel. A watcher must be created on this channel before calling
// this.
- absl::Span<char> GetWatcherSharedMemory(const Channel *channel);
+ absl::Span<const char> GetWatcherSharedMemory(const Channel *channel);
// Returns the local mapping of the shared memory used by the provided Sender.
template <typename T>
@@ -93,8 +93,12 @@
// Returns the local mapping of the private memory used by the provided
// Fetcher to hold messages.
+ //
+ // Note that this may be the entire shared memory region held by this fetcher,
+ // depending on its channel's read_method.
template <typename T>
- absl::Span<char> GetFetcherPrivateMemory(aos::Fetcher<T> *fetcher) const {
+ absl::Span<const char> GetFetcherPrivateMemory(
+ aos::Fetcher<T> *fetcher) const {
return GetShmFetcherPrivateMemory(GetRawFetcher(fetcher));
}
@@ -127,7 +131,7 @@
absl::Span<char> GetShmSenderSharedMemory(const aos::RawSender *sender) const;
// Private method to access the private memory mapping of a ShmFetcher.
- absl::Span<char> GetShmFetcherPrivateMemory(
+ absl::Span<const char> GetShmFetcherPrivateMemory(
const aos::RawFetcher *fetcher) const;
std::vector<std::function<void()>> on_run_;
diff --git a/aos/events/shm_event_loop_test.cc b/aos/events/shm_event_loop_test.cc
index cbb28f9..d9a8872 100644
--- a/aos/events/shm_event_loop_test.cc
+++ b/aos/events/shm_event_loop_test.cc
@@ -99,12 +99,27 @@
return scheduler == SCHED_FIFO || scheduler == SCHED_RR;
}
+class ShmEventLoopTest : public ::testing::TestWithParam<ReadMethod> {
+ public:
+ ShmEventLoopTest() {
+ if (GetParam() == ReadMethod::PIN) {
+ factory_.PinReads();
+ }
+ }
+
+ ShmEventLoopTestFactory *factory() { return &factory_; }
+
+ private:
+ ShmEventLoopTestFactory factory_;
+};
+
+using ShmEventLoopDeathTest = ShmEventLoopTest;
+
// Tests that every handler type is realtime and runs. There are threads
// involved and it's easy to miss one.
-TEST(ShmEventLoopTest, AllHandlersAreRealtime) {
- ShmEventLoopTestFactory factory;
- auto loop = factory.MakePrimary("primary");
- auto loop2 = factory.Make("loop2");
+TEST_P(ShmEventLoopTest, AllHandlersAreRealtime) {
+ auto loop = factory()->MakePrimary("primary");
+ auto loop2 = factory()->Make("loop2");
loop->SetRuntimeRealtimePriority(1);
@@ -114,10 +129,10 @@
bool did_timer = false;
bool did_watcher = false;
- auto timer = loop->AddTimer([&did_timer, &factory]() {
+ auto timer = loop->AddTimer([this, &did_timer]() {
EXPECT_TRUE(IsRealtime());
did_timer = true;
- factory.Exit();
+ factory()->Exit();
});
loop->MakeWatcher("/test", [&did_watcher](const TestMessage &) {
@@ -136,7 +151,7 @@
msg.Send(builder.Finish());
});
- factory.Run();
+ factory()->Run();
EXPECT_TRUE(did_onrun);
EXPECT_TRUE(did_timer);
@@ -145,16 +160,15 @@
// Tests that missing a deadline inside the function still results in PhasedLoop
// running at the right offset.
-TEST(ShmEventLoopTest, DelayedPhasedLoop) {
- ShmEventLoopTestFactory factory;
- auto loop1 = factory.MakePrimary("primary");
+TEST_P(ShmEventLoopTest, DelayedPhasedLoop) {
+ auto loop1 = factory()->MakePrimary("primary");
::std::vector<::aos::monotonic_clock::time_point> times;
constexpr chrono::milliseconds kOffset = chrono::milliseconds(400);
loop1->AddPhasedLoop(
- [&times, &loop1, &kOffset, &factory](int count) {
+ [this, &times, &loop1, &kOffset](int count) {
const ::aos::monotonic_clock::time_point monotonic_now =
loop1->monotonic_now();
@@ -179,7 +193,7 @@
times.push_back(loop1->monotonic_now());
if (times.size() == 2) {
- factory.Exit();
+ factory()->Exit();
}
// Now, add a large delay. This should push us up to 3 cycles.
@@ -187,15 +201,14 @@
},
chrono::seconds(1), kOffset);
- factory.Run();
+ factory()->Run();
EXPECT_EQ(times.size(), 2u);
}
// Test GetWatcherSharedMemory in a few basic scenarios.
-TEST(ShmEventLoopDeathTest, GetWatcherSharedMemory) {
- ShmEventLoopTestFactory factory;
- auto generic_loop1 = factory.MakePrimary("primary");
+TEST_P(ShmEventLoopDeathTest, GetWatcherSharedMemory) {
+ auto generic_loop1 = factory()->MakePrimary("primary");
ShmEventLoop *const loop1 = static_cast<ShmEventLoop *>(generic_loop1.get());
const auto channel = configuration::GetChannel(
loop1->configuration(), "/test", TestMessage::GetFullyQualifiedName(),
@@ -206,31 +219,85 @@
"No watcher found for channel");
// Then, actually create a watcher, and verify it returns something sane.
- loop1->MakeWatcher("/test", [](const TestMessage &) {});
- EXPECT_FALSE(loop1->GetWatcherSharedMemory(channel).empty());
+ absl::Span<const char> shared_memory;
+ bool ran = false;
+ loop1->MakeWatcher("/test", [this, &shared_memory,
+ &ran](const TestMessage &message) {
+ EXPECT_FALSE(ran);
+ ran = true;
+ // If we're using pinning, then we can verify that the message is actually
+ // in the specified region.
+ if (GetParam() == ReadMethod::PIN) {
+ EXPECT_GE(reinterpret_cast<const char *>(&message),
+ shared_memory.begin());
+ EXPECT_LT(reinterpret_cast<const char *>(&message), shared_memory.end());
+ }
+ factory()->Exit();
+ });
+ shared_memory = loop1->GetWatcherSharedMemory(channel);
+ EXPECT_FALSE(shared_memory.empty());
+
+ auto loop2 = factory()->Make("sender");
+ auto sender = loop2->MakeSender<TestMessage>("/test");
+ generic_loop1->OnRun([&sender]() {
+ auto builder = sender.MakeBuilder();
+ TestMessage::Builder test_builder(*builder.fbb());
+ test_builder.add_value(1);
+ CHECK(builder.Send(test_builder.Finish()));
+ });
+ factory()->Run();
+ EXPECT_TRUE(ran);
}
-TEST(ShmEventLoopTest, GetSenderSharedMemory) {
- ShmEventLoopTestFactory factory;
- auto generic_loop1 = factory.MakePrimary("primary");
+TEST_P(ShmEventLoopTest, GetSenderSharedMemory) {
+ auto generic_loop1 = factory()->MakePrimary("primary");
ShmEventLoop *const loop1 = static_cast<ShmEventLoop *>(generic_loop1.get());
- // check that GetSenderSharedMemory returns non-null/non-empty memory span.
+ // Check that GetSenderSharedMemory returns non-null/non-empty memory span.
auto sender = loop1->MakeSender<TestMessage>("/test");
- EXPECT_FALSE(loop1->GetSenderSharedMemory(&sender).empty());
+ const absl::Span<char> shared_memory = loop1->GetSenderSharedMemory(&sender);
+ EXPECT_FALSE(shared_memory.empty());
+
+ auto builder = sender.MakeBuilder();
+ uint8_t *buffer;
+ builder.fbb()->CreateUninitializedVector(5, 1, &buffer);
+ EXPECT_GE(reinterpret_cast<char *>(buffer), shared_memory.begin());
+ EXPECT_LT(reinterpret_cast<char *>(buffer), shared_memory.end());
}
-TEST(ShmEventLoopTest, GetFetcherPrivateMemory) {
- ShmEventLoopTestFactory factory;
- auto generic_loop1 = factory.MakePrimary("primary");
+TEST_P(ShmEventLoopTest, GetFetcherPrivateMemory) {
+ auto generic_loop1 = factory()->MakePrimary("primary");
ShmEventLoop *const loop1 = static_cast<ShmEventLoop *>(generic_loop1.get());
- // check that GetFetcherPrivateMemory returns non-null/non-empty memory span.
+ // Check that GetFetcherPrivateMemory returns non-null/non-empty memory span.
auto fetcher = loop1->MakeFetcher<TestMessage>("/test");
- EXPECT_FALSE(loop1->GetFetcherPrivateMemory(&fetcher).empty());
+ const auto private_memory = loop1->GetFetcherPrivateMemory(&fetcher);
+ EXPECT_FALSE(private_memory.empty());
+
+ auto loop2 = factory()->Make("sender");
+ auto sender = loop2->MakeSender<TestMessage>("/test");
+ {
+ auto builder = sender.MakeBuilder();
+ TestMessage::Builder test_builder(*builder.fbb());
+ test_builder.add_value(1);
+ CHECK(builder.Send(test_builder.Finish()));
+ }
+
+ ASSERT_TRUE(fetcher.Fetch());
+ EXPECT_GE(fetcher.context().data, private_memory.begin());
+ EXPECT_LT(fetcher.context().data, private_memory.end());
}
// TODO(austin): Test that missing a deadline with a timer recovers as expected.
+INSTANTIATE_TEST_CASE_P(ShmEventLoopCopyTest, ShmEventLoopTest,
+ ::testing::Values(ReadMethod::COPY));
+INSTANTIATE_TEST_CASE_P(ShmEventLoopPinTest, ShmEventLoopTest,
+ ::testing::Values(ReadMethod::PIN));
+INSTANTIATE_TEST_CASE_P(ShmEventLoopCopyDeathTest, ShmEventLoopDeathTest,
+ ::testing::Values(ReadMethod::COPY));
+INSTANTIATE_TEST_CASE_P(ShmEventLoopPinDeathTest, ShmEventLoopDeathTest,
+ ::testing::Values(ReadMethod::PIN));
+
} // namespace testing
} // namespace aos