Create an ExitHandle interface
This is a safer alternative to our existing pattern of capturing the
pointer in a lambda, and it's more Rust-friendly.
Change-Id: Id0f3fe5a2badcf1a4ae871d0cc7c3ff48d1c22f8
Signed-off-by: Brian Silverman <bsilver16384@gmail.com>
diff --git a/aos/events/event_loop.h b/aos/events/event_loop.h
index 3ecd93f..3ceb240 100644
--- a/aos/events/event_loop.h
+++ b/aos/events/event_loop.h
@@ -862,6 +862,28 @@
absl::btree_set<const Channel *> taken_watchers_, taken_senders_;
};
+// Interface for terminating execution of an EventLoop.
+//
+// Prefer this over binding a lambda to an Exit() method when passing ownership
+// in complicated ways because implementations should have assertions to catch
+// it outliving the object it's referring to, instead of having a
+// use-after-free.
+//
+// This is not exposed by EventLoop directly because different EventLoop
+// implementations provide this functionality at different scopes, or possibly
+// not at all.
+class ExitHandle {
+ public:
+ ExitHandle() = default;
+ virtual ~ExitHandle() = default;
+
+ // Exits some set of event loops. Details depend on the implementation.
+ //
+ // This means no more events will be processed, but any currently being
+ // processed will finish.
+ virtual void Exit() = 0;
+};
+
} // namespace aos
#include "aos/events/event_loop_tmpl.h" // IWYU pragma: export
diff --git a/aos/events/shm_event_loop.cc b/aos/events/shm_event_loop.cc
index a43de78..467353e 100644
--- a/aos/events/shm_event_loop.cc
+++ b/aos/events/shm_event_loop.cc
@@ -509,6 +509,22 @@
SimpleShmFetcher simple_shm_fetcher_;
};
+class ShmExitHandle : public ExitHandle {
+ public:
+ ShmExitHandle(ShmEventLoop *event_loop) : event_loop_(event_loop) {
+ ++event_loop_->exit_handle_count_;
+ }
+ ~ShmExitHandle() override {
+ CHECK_GT(event_loop_->exit_handle_count_, 0);
+ --event_loop_->exit_handle_count_;
+ }
+
+ void Exit() override { event_loop_->Exit(); }
+
+ private:
+ ShmEventLoop *const event_loop_;
+};
+
class ShmSender : public RawSender {
public:
explicit ShmSender(std::string_view shm_base, EventLoop *event_loop,
@@ -1171,6 +1187,10 @@
void ShmEventLoop::Exit() { epoll_.Quit(); }
+std::unique_ptr<ExitHandle> ShmEventLoop::MakeExitHandle() {
+ return std::make_unique<ShmExitHandle>(this);
+}
+
ShmEventLoop::~ShmEventLoop() {
CheckCurrentThread();
// Force everything with a registered fd with epoll to be destroyed now.
@@ -1179,6 +1199,8 @@
watchers_.clear();
CHECK(!is_running()) << ": ShmEventLoop destroyed while running";
+ CHECK_EQ(0, exit_handle_count_)
+ << ": All ExitHandles must be destroyed before the ShmEventLoop";
}
void ShmEventLoop::SetRuntimeRealtimePriority(int priority) {
diff --git a/aos/events/shm_event_loop.h b/aos/events/shm_event_loop.h
index e51f21b..425d334 100644
--- a/aos/events/shm_event_loop.h
+++ b/aos/events/shm_event_loop.h
@@ -22,6 +22,7 @@
class ShmSender;
class SimpleShmFetcher;
class ShmFetcher;
+class ShmExitHandle;
} // namespace shm_event_loop_internal
@@ -48,6 +49,8 @@
// Exits the event loop. Async safe.
void Exit();
+ std::unique_ptr<ExitHandle> MakeExitHandle();
+
aos::monotonic_clock::time_point monotonic_now() const override {
return aos::monotonic_clock::now();
}
@@ -138,6 +141,7 @@
friend class shm_event_loop_internal::ShmSender;
friend class shm_event_loop_internal::SimpleShmFetcher;
friend class shm_event_loop_internal::ShmFetcher;
+ friend class shm_event_loop_internal::ShmExitHandle;
using EventLoop::SendTimingReport;
@@ -165,6 +169,8 @@
const UUID boot_uuid_;
+ int exit_handle_count_ = 0;
+
// Capture the --shm_base flag at construction time. This makes it much
// easier to make different shared memory regions for doing things like
// multi-node tests.
diff --git a/aos/events/shm_event_loop_test.cc b/aos/events/shm_event_loop_test.cc
index f4107b0..f119479 100644
--- a/aos/events/shm_event_loop_test.cc
+++ b/aos/events/shm_event_loop_test.cc
@@ -425,6 +425,14 @@
TestNextMessageNotAvailableNoRun(true);
}
+// Test that an ExitHandle outliving its EventLoop is caught.
+TEST_P(ShmEventLoopDeathTest, ExitHandleOutlivesEventLoop) {
+ auto loop1 = factory()->MakePrimary("loop1");
+ auto exit_handle = static_cast<ShmEventLoop *>(loop1.get())->MakeExitHandle();
+ EXPECT_DEATH(loop1.reset(),
+ "All ExitHandles must be destroyed before the ShmEventLoop");
+}
+
// TODO(austin): Test that missing a deadline with a timer recovers as expected.
INSTANTIATE_TEST_SUITE_P(ShmEventLoopCopyTest, ShmEventLoopTest,
diff --git a/aos/events/simulated_event_loop.cc b/aos/events/simulated_event_loop.cc
index e0ed7a0..88504a5 100644
--- a/aos/events/simulated_event_loop.cc
+++ b/aos/events/simulated_event_loop.cc
@@ -156,6 +156,23 @@
SimulatedChannel *simulated_channel_ = nullptr;
};
+class SimulatedFactoryExitHandle : public ExitHandle {
+ public:
+ SimulatedFactoryExitHandle(SimulatedEventLoopFactory *factory)
+ : factory_(factory) {
+ ++factory_->exit_handle_count_;
+ }
+ ~SimulatedFactoryExitHandle() override {
+ CHECK_GT(factory_->exit_handle_count_, 0);
+ --factory_->exit_handle_count_;
+ }
+
+ void Exit() override { factory_->Exit(); }
+
+ private:
+ SimulatedEventLoopFactory *const factory_;
+};
+
class SimulatedChannel {
public:
explicit SimulatedChannel(const Channel *channel,
@@ -1273,7 +1290,10 @@
}
}
-SimulatedEventLoopFactory::~SimulatedEventLoopFactory() {}
+SimulatedEventLoopFactory::~SimulatedEventLoopFactory() {
+ CHECK_EQ(0, exit_handle_count_)
+ << ": All ExitHandles must be destroyed before the factory";
+}
NodeEventLoopFactory *SimulatedEventLoopFactory::GetNodeEventLoopFactory(
std::string_view node) {
@@ -1455,6 +1475,10 @@
void SimulatedEventLoopFactory::Exit() { scheduler_scheduler_.Exit(); }
+std::unique_ptr<ExitHandle> SimulatedEventLoopFactory::MakeExitHandle() {
+ return std::make_unique<SimulatedFactoryExitHandle>(this);
+}
+
void SimulatedEventLoopFactory::DisableForwarding(const Channel *channel) {
CHECK(bridge_) << ": Can't disable forwarding without a message bridge.";
bridge_->DisableForwarding(channel);
diff --git a/aos/events/simulated_event_loop.h b/aos/events/simulated_event_loop.h
index 9200044..7b6eaa7 100644
--- a/aos/events/simulated_event_loop.h
+++ b/aos/events/simulated_event_loop.h
@@ -27,6 +27,7 @@
class NodeEventLoopFactory;
class SimulatedEventLoop;
+class SimulatedFactoryExitHandle;
namespace message_bridge {
class SimulatedMessageBridge;
}
@@ -94,6 +95,8 @@
// loop handler.
void Exit();
+ std::unique_ptr<ExitHandle> MakeExitHandle();
+
const std::vector<const Node *> &nodes() const { return nodes_; }
// Sets the simulated send delay for all messages sent within a single node.
@@ -135,6 +138,7 @@
private:
friend class NodeEventLoopFactory;
+ friend class SimulatedFactoryExitHandle;
const Configuration *const configuration_;
EventSchedulerScheduler scheduler_scheduler_;
@@ -147,6 +151,8 @@
std::vector<std::unique_ptr<NodeEventLoopFactory>> node_factories_;
std::vector<const Node *> nodes_;
+
+ int exit_handle_count_ = 0;
};
// This class holds all the state required to be a single node.
diff --git a/aos/events/simulated_event_loop_test.cc b/aos/events/simulated_event_loop_test.cc
index 1a35b69..eddacb1 100644
--- a/aos/events/simulated_event_loop_test.cc
+++ b/aos/events/simulated_event_loop_test.cc
@@ -2076,6 +2076,19 @@
EXPECT_DEATH({ factory.RunFor(dt * 2); }, "Event loop");
}
+// Test that an ExitHandle outliving its factory is caught.
+TEST(SimulatedEventLoopDeathTest, ExitHandleOutlivesFactory) {
+ aos::FlatbufferDetachedBuffer<aos::Configuration> config =
+ aos::configuration::ReadConfig(
+ ArtifactPath("aos/events/multinode_pingpong_test_split_config.json"));
+ auto factory = std::make_unique<SimulatedEventLoopFactory>(&config.message());
+ NodeEventLoopFactory *pi1 = factory->GetNodeEventLoopFactory("pi1");
+ std::unique_ptr<EventLoop> loop = pi1->MakeEventLoop("foo");
+ auto exit_handle = factory->MakeExitHandle();
+ EXPECT_DEATH(factory.reset(),
+ "All ExitHandles must be destroyed before the factory");
+}
+
// Tests that messages don't survive a reboot of a node.
TEST(SimulatedEventLoopTest, ChannelClearedOnReboot) {
aos::FlatbufferDetachedBuffer<aos::Configuration> config =