Create an ExitHandle interface

This is a safer alternative to our existing pattern of capturing the
event loop (or factory) pointer in a lambda that calls Exit(), and it's
more Rust-friendly.

Change-Id: Id0f3fe5a2badcf1a4ae871d0cc7c3ff48d1c22f8
Signed-off-by: Brian Silverman <bsilver16384@gmail.com>
diff --git a/aos/events/event_loop.h b/aos/events/event_loop.h
index 3ecd93f..3ceb240 100644
--- a/aos/events/event_loop.h
+++ b/aos/events/event_loop.h
@@ -862,6 +862,31 @@
   absl::btree_set<const Channel *> taken_watchers_, taken_senders_;
 };
 
+// Interface for terminating execution of an EventLoop.
+//
+// Prefer this over binding a lambda to an Exit() method when passing ownership
+// in complicated ways: implementations should have assertions that catch a
+// handle outliving the object it refers to, instead of silently allowing a
+// use-after-free.
+//
+// This is not exposed by EventLoop directly because different EventLoop
+// implementations provide this functionality at different scopes, or possibly
+// not at all.
+class ExitHandle {
+ public:
+  ExitHandle() = default;
+  // Not copyable: implementations may track how many handles are alive to
+  // implement the assertions above, and a copy would bypass that tracking.
+  ExitHandle(const ExitHandle &) = delete;
+  ExitHandle &operator=(const ExitHandle &) = delete;
+  virtual ~ExitHandle() = default;
+
+  // Exits some set of event loops. Details depend on the implementation.
+  //
+  // This means no more events will be processed, but any currently being
+  // processed will finish.
+  virtual void Exit() = 0;
+};
 }  // namespace aos
 
 #include "aos/events/event_loop_tmpl.h"  // IWYU pragma: export
diff --git a/aos/events/shm_event_loop.cc b/aos/events/shm_event_loop.cc
index a43de78..467353e 100644
--- a/aos/events/shm_event_loop.cc
+++ b/aos/events/shm_event_loop.cc
@@ -509,6 +509,31 @@
   SimpleShmFetcher simple_shm_fetcher_;
 };
 
+// ExitHandle which exits a single ShmEventLoop.
+//
+// Each live handle is counted in ShmEventLoop::exit_handle_count_, which the
+// ShmEventLoop destructor CHECKs is zero.
+class ShmExitHandle final : public ExitHandle {
+ public:
+  explicit ShmExitHandle(ShmEventLoop *event_loop) : event_loop_(event_loop) {
+    ++event_loop_->exit_handle_count_;
+  }
+  // Not copyable: a copy would not be counted, but its destructor would still
+  // decrement the count.
+  ShmExitHandle(const ShmExitHandle &) = delete;
+  ShmExitHandle &operator=(const ShmExitHandle &) = delete;
+  ~ShmExitHandle() override {
+    CHECK_GT(event_loop_->exit_handle_count_, 0);
+    --event_loop_->exit_handle_count_;
+  }
+
+  // Exits the referenced event loop.
+  void Exit() override { event_loop_->Exit(); }
+
+ private:
+  ShmEventLoop *const event_loop_;
+};
+
 class ShmSender : public RawSender {
  public:
   explicit ShmSender(std::string_view shm_base, EventLoop *event_loop,
@@ -1171,6 +1187,12 @@
 
 void ShmEventLoop::Exit() { epoll_.Quit(); }
 
+// The returned handle is counted in exit_handle_count_ until it is destroyed,
+// and ~ShmEventLoop() CHECKs that the count is zero.
+std::unique_ptr<ExitHandle> ShmEventLoop::MakeExitHandle() {
+  return std::make_unique<shm_event_loop_internal::ShmExitHandle>(this);
+}
+
 ShmEventLoop::~ShmEventLoop() {
   CheckCurrentThread();
   // Force everything with a registered fd with epoll to be destroyed now.
@@ -1179,6 +1199,8 @@
   watchers_.clear();
 
   CHECK(!is_running()) << ": ShmEventLoop destroyed while running";
+  CHECK_EQ(0, exit_handle_count_)
+      << ": All ExitHandles must be destroyed before the ShmEventLoop";
 }
 
 void ShmEventLoop::SetRuntimeRealtimePriority(int priority) {
diff --git a/aos/events/shm_event_loop.h b/aos/events/shm_event_loop.h
index e51f21b..425d334 100644
--- a/aos/events/shm_event_loop.h
+++ b/aos/events/shm_event_loop.h
@@ -22,6 +22,7 @@
 class ShmSender;
 class SimpleShmFetcher;
 class ShmFetcher;
+class ShmExitHandle;
 
 }  // namespace shm_event_loop_internal
 
@@ -48,6 +49,12 @@
   // Exits the event loop.  Async safe.
   void Exit();
 
+  // Returns a handle whose Exit() calls Exit() on this event loop. Unlike
+  // Exit(), this allocates, so it is not async safe.
+  //
+  // All handles must be destroyed before this event loop; it CHECKs this.
+  std::unique_ptr<ExitHandle> MakeExitHandle();
+
   aos::monotonic_clock::time_point monotonic_now() const override {
     return aos::monotonic_clock::now();
   }
@@ -138,6 +141,7 @@
   friend class shm_event_loop_internal::ShmSender;
   friend class shm_event_loop_internal::SimpleShmFetcher;
   friend class shm_event_loop_internal::ShmFetcher;
+  friend class shm_event_loop_internal::ShmExitHandle;
 
   using EventLoop::SendTimingReport;
 
@@ -165,6 +169,8 @@
 
   const UUID boot_uuid_;
 
+  int exit_handle_count_ = 0;
+
   // Capture the --shm_base flag at construction time.  This makes it much
   // easier to make different shared memory regions for doing things like
   // multi-node tests.
diff --git a/aos/events/shm_event_loop_test.cc b/aos/events/shm_event_loop_test.cc
index f4107b0..f119479 100644
--- a/aos/events/shm_event_loop_test.cc
+++ b/aos/events/shm_event_loop_test.cc
@@ -425,6 +425,14 @@
   TestNextMessageNotAvailableNoRun(true);
 }
 
+// Tests that destroying a ShmEventLoop with a live ExitHandle CHECK-fails.
+TEST_P(ShmEventLoopDeathTest, ExitHandleOutlivesEventLoop) {
+  auto loop1 = factory()->MakePrimary("loop1");
+  auto exit_handle = static_cast<ShmEventLoop *>(loop1.get())->MakeExitHandle();
+  EXPECT_DEATH(loop1.reset(),
+               "All ExitHandles must be destroyed before the ShmEventLoop");
+}
+
 // TODO(austin): Test that missing a deadline with a timer recovers as expected.
 
 INSTANTIATE_TEST_SUITE_P(ShmEventLoopCopyTest, ShmEventLoopTest,
diff --git a/aos/events/simulated_event_loop.cc b/aos/events/simulated_event_loop.cc
index e0ed7a0..88504a5 100644
--- a/aos/events/simulated_event_loop.cc
+++ b/aos/events/simulated_event_loop.cc
@@ -156,6 +156,33 @@
   SimulatedChannel *simulated_channel_ = nullptr;
 };
 
+// ExitHandle which calls SimulatedEventLoopFactory::Exit().
+//
+// Each live handle is counted in the factory's exit_handle_count_, which
+// ~SimulatedEventLoopFactory() CHECKs is zero.
+class SimulatedFactoryExitHandle final : public ExitHandle {
+ public:
+  explicit SimulatedFactoryExitHandle(SimulatedEventLoopFactory *factory)
+      : factory_(factory) {
+    ++factory_->exit_handle_count_;
+  }
+  // Not copyable: a copy would not be counted, but its destructor would still
+  // decrement the count.
+  SimulatedFactoryExitHandle(const SimulatedFactoryExitHandle &) = delete;
+  SimulatedFactoryExitHandle &operator=(const SimulatedFactoryExitHandle &) =
+      delete;
+  ~SimulatedFactoryExitHandle() override {
+    CHECK_GT(factory_->exit_handle_count_, 0);
+    --factory_->exit_handle_count_;
+  }
+
+  // Exits the referenced factory.
+  void Exit() override { factory_->Exit(); }
+
+ private:
+  SimulatedEventLoopFactory *const factory_;
+};
+
 class SimulatedChannel {
  public:
   explicit SimulatedChannel(const Channel *channel,
@@ -1273,7 +1290,12 @@
   }
 }
 
-SimulatedEventLoopFactory::~SimulatedEventLoopFactory() {}
+SimulatedEventLoopFactory::~SimulatedEventLoopFactory() {
+  // Every SimulatedFactoryExitHandle holds a raw pointer back to this factory,
+  // so one still alive at this point would dangle.
+  CHECK_EQ(0, exit_handle_count_)
+      << ": All ExitHandles must be destroyed before the factory";
+}
 
 NodeEventLoopFactory *SimulatedEventLoopFactory::GetNodeEventLoopFactory(
     std::string_view node) {
@@ -1455,6 +1475,12 @@
 
 void SimulatedEventLoopFactory::Exit() { scheduler_scheduler_.Exit(); }
 
+// The returned handle is counted in exit_handle_count_ until it is destroyed,
+// and ~SimulatedEventLoopFactory() CHECKs that the count is zero.
+std::unique_ptr<ExitHandle> SimulatedEventLoopFactory::MakeExitHandle() {
+  return std::make_unique<SimulatedFactoryExitHandle>(this);
+}
+
 void SimulatedEventLoopFactory::DisableForwarding(const Channel *channel) {
   CHECK(bridge_) << ": Can't disable forwarding without a message bridge.";
   bridge_->DisableForwarding(channel);
diff --git a/aos/events/simulated_event_loop.h b/aos/events/simulated_event_loop.h
index 9200044..7b6eaa7 100644
--- a/aos/events/simulated_event_loop.h
+++ b/aos/events/simulated_event_loop.h
@@ -27,6 +27,7 @@
 
 class NodeEventLoopFactory;
 class SimulatedEventLoop;
+class SimulatedFactoryExitHandle;
 namespace message_bridge {
 class SimulatedMessageBridge;
 }
@@ -94,6 +95,12 @@
   // loop handler.
   void Exit();
 
+  // Returns a handle whose Exit() calls Exit() on this factory.
+  //
+  // All handles must be destroyed before this factory; its destructor CHECKs
+  // this.
+  std::unique_ptr<ExitHandle> MakeExitHandle();
+
   const std::vector<const Node *> &nodes() const { return nodes_; }
 
   // Sets the simulated send delay for all messages sent within a single node.
@@ -135,6 +138,7 @@
 
  private:
   friend class NodeEventLoopFactory;
+  friend class SimulatedFactoryExitHandle;
 
   const Configuration *const configuration_;
   EventSchedulerScheduler scheduler_scheduler_;
@@ -147,6 +151,8 @@
   std::vector<std::unique_ptr<NodeEventLoopFactory>> node_factories_;
 
   std::vector<const Node *> nodes_;
+
+  int exit_handle_count_ = 0;
 };
 
 // This class holds all the state required to be a single node.
diff --git a/aos/events/simulated_event_loop_test.cc b/aos/events/simulated_event_loop_test.cc
index 1a35b69..eddacb1 100644
--- a/aos/events/simulated_event_loop_test.cc
+++ b/aos/events/simulated_event_loop_test.cc
@@ -2076,6 +2076,19 @@
   EXPECT_DEATH({ factory.RunFor(dt * 2); }, "Event loop");
 }
 
+// Tests that destroying a factory with a live ExitHandle CHECK-fails.
+TEST(SimulatedEventLoopDeathTest, ExitHandleOutlivesFactory) {
+  aos::FlatbufferDetachedBuffer<aos::Configuration> config =
+      aos::configuration::ReadConfig(
+          ArtifactPath("aos/events/multinode_pingpong_test_split_config.json"));
+  auto factory = std::make_unique<SimulatedEventLoopFactory>(&config.message());
+  NodeEventLoopFactory *pi1 = factory->GetNodeEventLoopFactory("pi1");
+  std::unique_ptr<EventLoop> loop = pi1->MakeEventLoop("foo");
+  auto exit_handle = factory->MakeExitHandle();
+  EXPECT_DEATH(factory.reset(),
+               "All ExitHandles must be destroyed before the factory");
+}
+
 // Tests that messages don't survive a reboot of a node.
 TEST(SimulatedEventLoopTest, ChannelClearedOnReboot) {
   aos::FlatbufferDetachedBuffer<aos::Configuration> config =