Use a read-only mapping for reading from shared memory

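This makes it much harder for readers to accidentally write to the shared memory: readers now access the queue through a separate PROT_READ mapping, so a stray write through it faults instead of silently modifying the queue.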

Change-Id: I29025f37b0767825e57314cc1221510fd9455b55
diff --git a/aos/events/shm_event_loop.cc b/aos/events/shm_event_loop.cc
index 1ac8656..6013709 100644
--- a/aos/events/shm_event_loop.cc
+++ b/aos/events/shm_event_loop.cc
@@ -65,7 +65,7 @@
   return ShmFolder(channel) + channel->type()->str() + ".v3";
 }
 
-void PageFaultData(char *data, size_t size) {
+void PageFaultDataWrite(char *data, size_t size) {
   // This just has to divide the actual page size. Being smaller will make this
   // a bit slower than necessary, but not much. 1024 is a pretty conservative
   // choice (most pages are probably 4096).
@@ -90,6 +90,18 @@
   }
 }
 
+void PageFaultDataRead(const char *data, size_t size) {
+  // kPageSize just has to evenly divide the actual page size. Being smaller
+  // makes this a bit slower than necessary, but not much. 1024 is a pretty
+  // conservative choice (most pages are probably 4096).
+  static constexpr size_t kPageSize = 1024;
+  const size_t pages = (size + kPageSize - 1) / kPageSize;
+  for (size_t i = 0; i < pages; ++i) {
+    // We need to ensure there's a readable pagetable entry.
+    __atomic_load_n(&data[i * kPageSize], __ATOMIC_RELAXED);
+  }
+}
+
 ipc_lib::LocklessQueueConfiguration MakeQueueConfiguration(
     const Channel *channel, std::chrono::seconds channel_storage_duration) {
   ipc_lib::LocklessQueueConfiguration config;
@@ -150,33 +162,49 @@
 
     data_ = mmap(NULL, size_, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
     PCHECK(data_ != MAP_FAILED);
+    const_data_ = mmap(NULL, size_, PROT_READ, MAP_SHARED, fd, 0);
+    PCHECK(const_data_ != MAP_FAILED);
     PCHECK(close(fd) == 0);
-    PageFaultData(static_cast<char *>(data_), size_);
+    PageFaultDataWrite(static_cast<char *>(data_), size_);
+    PageFaultDataRead(static_cast<const char *>(const_data_), size_);
 
     ipc_lib::InitializeLocklessQueueMemory(memory(), config_);
   }
 
-  ~MMapedQueue() { PCHECK(munmap(data_, size_) == 0); }
+  ~MMapedQueue() {
+    PCHECK(munmap(data_, size_) == 0);
+    PCHECK(munmap(const_cast<void *>(const_data_), size_) == 0);
+  }
 
   ipc_lib::LocklessQueueMemory *memory() const {
     return reinterpret_cast<ipc_lib::LocklessQueueMemory *>(data_);
   }
 
+  const ipc_lib::LocklessQueueMemory *const_memory() const {
+    return reinterpret_cast<const ipc_lib::LocklessQueueMemory *>(const_data_);
+  }
+
   const ipc_lib::LocklessQueueConfiguration &config() const { return config_; }
 
   ipc_lib::LocklessQueue queue() const {
-    return ipc_lib::LocklessQueue(memory(), memory(), config());
+    return ipc_lib::LocklessQueue(const_memory(), memory(), config());
   }
 
-  absl::Span<char> GetSharedMemory() const {
+  absl::Span<char> GetMutableSharedMemory() const {
     return absl::Span<char>(static_cast<char *>(data_), size_);
   }
 
+  absl::Span<const char> GetConstSharedMemory() const {
+    return absl::Span<const char>(static_cast<const char *>(const_data_),
+                                  size_);
+  }
+
  private:
   const ipc_lib::LocklessQueueConfiguration config_;
 
   size_t size_;
   void *data_;
+  const void *const_data_;
 };
 
 const Node *MaybeMyNode(const Configuration *configuration) {
@@ -310,14 +338,18 @@
     watcher_ = std::nullopt;
   }
 
-  absl::Span<char> GetSharedMemory() const {
-    return lockless_queue_memory_.GetSharedMemory();
+  absl::Span<char> GetMutableSharedMemory() {
+    return lockless_queue_memory_.GetMutableSharedMemory();
   }
 
-  absl::Span<char> GetPrivateMemory() const {
-    // Can't usefully expose this for pinning, because the buffer changes
-    // address for each message. Callers who want to work with that should just
-    // grab the whole shared memory buffer instead.
+  absl::Span<const char> GetConstSharedMemory() const {
+    return lockless_queue_memory_.GetConstSharedMemory();
+  }
+
+  absl::Span<const char> GetPrivateMemory() const {
+    if (pin_data()) {
+      return lockless_queue_memory_.GetConstSharedMemory();
+    }
     return absl::Span<char>(
         const_cast<SimpleShmFetcher *>(this)->data_storage_start(),
         LocklessQueueMessageDataSize(lockless_queue_memory_.memory()));
@@ -457,7 +489,7 @@
     return std::make_pair(false, monotonic_clock::min_time);
   }
 
-  absl::Span<char> GetPrivateMemory() const {
+  absl::Span<const char> GetPrivateMemory() const {
     return simple_shm_fetcher_.GetPrivateMemory();
   }
 
@@ -524,7 +556,7 @@
   }
 
   absl::Span<char> GetSharedMemory() const {
-    return lockless_queue_memory_.GetSharedMemory();
+    return lockless_queue_memory_.GetMutableSharedMemory();
   }
 
   int buffer_index() override { return lockless_queue_sender_.buffer_index(); }
@@ -588,8 +620,8 @@
 
   void UnregisterWakeup() { return simple_shm_fetcher_.UnregisterWakeup(); }
 
-  absl::Span<char> GetSharedMemory() const {
-    return simple_shm_fetcher_.GetSharedMemory();
+  absl::Span<const char> GetSharedMemory() const {
+    return simple_shm_fetcher_.GetConstSharedMemory();
   }
 
  private:
@@ -1016,7 +1048,8 @@
   UpdateTimingReport();
 }
 
-absl::Span<char> ShmEventLoop::GetWatcherSharedMemory(const Channel *channel) {
+absl::Span<const char> ShmEventLoop::GetWatcherSharedMemory(
+    const Channel *channel) {
   ShmWatcherState *const watcher_state =
       static_cast<ShmWatcherState *>(GetWatcherState(channel));
   return watcher_state->GetSharedMemory();
@@ -1034,7 +1067,7 @@
   return static_cast<const ShmSender *>(sender)->GetSharedMemory();
 }
 
-absl::Span<char> ShmEventLoop::GetShmFetcherPrivateMemory(
+absl::Span<const char> ShmEventLoop::GetShmFetcherPrivateMemory(
     const aos::RawFetcher *fetcher) const {
   return static_cast<const ShmFetcher *>(fetcher)->GetPrivateMemory();
 }
diff --git a/aos/events/shm_event_loop.h b/aos/events/shm_event_loop.h
index d7ea5e6..8dabcb5 100644
--- a/aos/events/shm_event_loop.h
+++ b/aos/events/shm_event_loop.h
@@ -83,7 +83,7 @@
   // Returns the local mapping of the shared memory used by the watcher on the
   // specified channel. A watcher must be created on this channel before calling
   // this.
-  absl::Span<char> GetWatcherSharedMemory(const Channel *channel);
+  absl::Span<const char> GetWatcherSharedMemory(const Channel *channel);
 
   // Returns the local mapping of the shared memory used by the provided Sender.
   template <typename T>
@@ -93,8 +93,12 @@
 
   // Returns the local mapping of the private memory used by the provided
   // Fetcher to hold messages.
+  //
+  // Note that this may be the entire shared memory region held by this fetcher,
+  // depending on its channel's read_method.
   template <typename T>
-  absl::Span<char> GetFetcherPrivateMemory(aos::Fetcher<T> *fetcher) const {
+  absl::Span<const char> GetFetcherPrivateMemory(
+      aos::Fetcher<T> *fetcher) const {
     return GetShmFetcherPrivateMemory(GetRawFetcher(fetcher));
   }
 
@@ -127,7 +131,7 @@
   absl::Span<char> GetShmSenderSharedMemory(const aos::RawSender *sender) const;
 
   // Private method to access the private memory mapping of a ShmFetcher.
-  absl::Span<char> GetShmFetcherPrivateMemory(
+  absl::Span<const char> GetShmFetcherPrivateMemory(
       const aos::RawFetcher *fetcher) const;
 
   std::vector<std::function<void()>> on_run_;
diff --git a/aos/events/shm_event_loop_test.cc b/aos/events/shm_event_loop_test.cc
index cbb28f9..d9a8872 100644
--- a/aos/events/shm_event_loop_test.cc
+++ b/aos/events/shm_event_loop_test.cc
@@ -99,12 +99,27 @@
   return scheduler == SCHED_FIFO || scheduler == SCHED_RR;
 }
 
+class ShmEventLoopTest : public ::testing::TestWithParam<ReadMethod> {
+ public:
+  ShmEventLoopTest() {
+    if (GetParam() == ReadMethod::PIN) {
+      factory_.PinReads();
+    }
+  }
+
+  ShmEventLoopTestFactory *factory() { return &factory_; }
+
+ private:
+  ShmEventLoopTestFactory factory_;
+};
+
+using ShmEventLoopDeathTest = ShmEventLoopTest;
+
 // Tests that every handler type is realtime and runs.  There are threads
 // involved and it's easy to miss one.
-TEST(ShmEventLoopTest, AllHandlersAreRealtime) {
-  ShmEventLoopTestFactory factory;
-  auto loop = factory.MakePrimary("primary");
-  auto loop2 = factory.Make("loop2");
+TEST_P(ShmEventLoopTest, AllHandlersAreRealtime) {
+  auto loop = factory()->MakePrimary("primary");
+  auto loop2 = factory()->Make("loop2");
 
   loop->SetRuntimeRealtimePriority(1);
 
@@ -114,10 +129,10 @@
   bool did_timer = false;
   bool did_watcher = false;
 
-  auto timer = loop->AddTimer([&did_timer, &factory]() {
+  auto timer = loop->AddTimer([this, &did_timer]() {
     EXPECT_TRUE(IsRealtime());
     did_timer = true;
-    factory.Exit();
+    factory()->Exit();
   });
 
   loop->MakeWatcher("/test", [&did_watcher](const TestMessage &) {
@@ -136,7 +151,7 @@
     msg.Send(builder.Finish());
   });
 
-  factory.Run();
+  factory()->Run();
 
   EXPECT_TRUE(did_onrun);
   EXPECT_TRUE(did_timer);
@@ -145,16 +160,15 @@
 
 // Tests that missing a deadline inside the function still results in PhasedLoop
 // running at the right offset.
-TEST(ShmEventLoopTest, DelayedPhasedLoop) {
-  ShmEventLoopTestFactory factory;
-  auto loop1 = factory.MakePrimary("primary");
+TEST_P(ShmEventLoopTest, DelayedPhasedLoop) {
+  auto loop1 = factory()->MakePrimary("primary");
 
   ::std::vector<::aos::monotonic_clock::time_point> times;
 
   constexpr chrono::milliseconds kOffset = chrono::milliseconds(400);
 
   loop1->AddPhasedLoop(
-      [&times, &loop1, &kOffset, &factory](int count) {
+      [this, &times, &loop1, &kOffset](int count) {
         const ::aos::monotonic_clock::time_point monotonic_now =
             loop1->monotonic_now();
 
@@ -179,7 +193,7 @@
 
         times.push_back(loop1->monotonic_now());
         if (times.size() == 2) {
-          factory.Exit();
+          factory()->Exit();
         }
 
         // Now, add a large delay.  This should push us up to 3 cycles.
@@ -187,15 +201,14 @@
       },
       chrono::seconds(1), kOffset);
 
-  factory.Run();
+  factory()->Run();
 
   EXPECT_EQ(times.size(), 2u);
 }
 
 // Test GetWatcherSharedMemory in a few basic scenarios.
-TEST(ShmEventLoopDeathTest, GetWatcherSharedMemory) {
-  ShmEventLoopTestFactory factory;
-  auto generic_loop1 = factory.MakePrimary("primary");
+TEST_P(ShmEventLoopDeathTest, GetWatcherSharedMemory) {
+  auto generic_loop1 = factory()->MakePrimary("primary");
   ShmEventLoop *const loop1 = static_cast<ShmEventLoop *>(generic_loop1.get());
   const auto channel = configuration::GetChannel(
       loop1->configuration(), "/test", TestMessage::GetFullyQualifiedName(),
@@ -206,31 +219,85 @@
                "No watcher found for channel");
 
   // Then, actually create a watcher, and verify it returns something sane.
-  loop1->MakeWatcher("/test", [](const TestMessage &) {});
-  EXPECT_FALSE(loop1->GetWatcherSharedMemory(channel).empty());
+  absl::Span<const char> shared_memory;
+  bool ran = false;
+  loop1->MakeWatcher("/test", [this, &shared_memory,
+                               &ran](const TestMessage &message) {
+    EXPECT_FALSE(ran);
+    ran = true;
+    // If we're using pinning, then we can verify that the message is actually
+    // in the specified region.
+    if (GetParam() == ReadMethod::PIN) {
+      EXPECT_GE(reinterpret_cast<const char *>(&message),
+                shared_memory.begin());
+      EXPECT_LT(reinterpret_cast<const char *>(&message), shared_memory.end());
+    }
+    factory()->Exit();
+  });
+  shared_memory = loop1->GetWatcherSharedMemory(channel);
+  EXPECT_FALSE(shared_memory.empty());
+
+  auto loop2 = factory()->Make("sender");
+  auto sender = loop2->MakeSender<TestMessage>("/test");
+  generic_loop1->OnRun([&sender]() {
+    auto builder = sender.MakeBuilder();
+    TestMessage::Builder test_builder(*builder.fbb());
+    test_builder.add_value(1);
+    CHECK(builder.Send(test_builder.Finish()));
+  });
+  factory()->Run();
+  EXPECT_TRUE(ran);
 }
 
-TEST(ShmEventLoopTest, GetSenderSharedMemory) {
-  ShmEventLoopTestFactory factory;
-  auto generic_loop1 = factory.MakePrimary("primary");
+TEST_P(ShmEventLoopTest, GetSenderSharedMemory) {
+  auto generic_loop1 = factory()->MakePrimary("primary");
   ShmEventLoop *const loop1 = static_cast<ShmEventLoop *>(generic_loop1.get());
 
-  // check that GetSenderSharedMemory returns non-null/non-empty memory span.
+  // Check that GetSenderSharedMemory returns non-null/non-empty memory span.
   auto sender = loop1->MakeSender<TestMessage>("/test");
-  EXPECT_FALSE(loop1->GetSenderSharedMemory(&sender).empty());
+  const absl::Span<char> shared_memory = loop1->GetSenderSharedMemory(&sender);
+  EXPECT_FALSE(shared_memory.empty());
+
+  auto builder = sender.MakeBuilder();
+  uint8_t *buffer;
+  builder.fbb()->CreateUninitializedVector(5, 1, &buffer);
+  EXPECT_GE(reinterpret_cast<char *>(buffer), shared_memory.begin());
+  EXPECT_LT(reinterpret_cast<char *>(buffer), shared_memory.end());
 }
 
-TEST(ShmEventLoopTest, GetFetcherPrivateMemory) {
-  ShmEventLoopTestFactory factory;
-  auto generic_loop1 = factory.MakePrimary("primary");
+TEST_P(ShmEventLoopTest, GetFetcherPrivateMemory) {
+  auto generic_loop1 = factory()->MakePrimary("primary");
   ShmEventLoop *const loop1 = static_cast<ShmEventLoop *>(generic_loop1.get());
 
-  // check that GetFetcherPrivateMemory returns non-null/non-empty memory span.
+  // Check that GetFetcherPrivateMemory returns non-null/non-empty memory span.
   auto fetcher = loop1->MakeFetcher<TestMessage>("/test");
-  EXPECT_FALSE(loop1->GetFetcherPrivateMemory(&fetcher).empty());
+  const auto private_memory = loop1->GetFetcherPrivateMemory(&fetcher);
+  EXPECT_FALSE(private_memory.empty());
+
+  auto loop2 = factory()->Make("sender");
+  auto sender = loop2->MakeSender<TestMessage>("/test");
+  {
+    auto builder = sender.MakeBuilder();
+    TestMessage::Builder test_builder(*builder.fbb());
+    test_builder.add_value(1);
+    CHECK(builder.Send(test_builder.Finish()));
+  }
+
+  ASSERT_TRUE(fetcher.Fetch());
+  EXPECT_GE(fetcher.context().data, private_memory.begin());
+  EXPECT_LT(fetcher.context().data, private_memory.end());
 }
 
 // TODO(austin): Test that missing a deadline with a timer recovers as expected.
 
+INSTANTIATE_TEST_CASE_P(ShmEventLoopCopyTest, ShmEventLoopTest,
+                        ::testing::Values(ReadMethod::COPY));
+INSTANTIATE_TEST_CASE_P(ShmEventLoopPinTest, ShmEventLoopTest,
+                        ::testing::Values(ReadMethod::PIN));
+INSTANTIATE_TEST_CASE_P(ShmEventLoopCopyDeathTest, ShmEventLoopDeathTest,
+                        ::testing::Values(ReadMethod::COPY));
+INSTANTIATE_TEST_CASE_P(ShmEventLoopPinDeathTest, ShmEventLoopDeathTest,
+                        ::testing::Values(ReadMethod::PIN));
+
 }  // namespace testing
 }  // namespace aos