Merge "Add a simple multi-node test as an example"
diff --git a/aos/configuration.cc b/aos/configuration.cc
index d794a37..76704a4 100644
--- a/aos/configuration.cc
+++ b/aos/configuration.cc
@@ -10,10 +10,14 @@
#include <map>
#include <set>
+#include <string>
#include <string_view>
+#include <vector>
#include "absl/container/btree_set.h"
#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
#include "aos/configuration_generated.h"
#include "aos/flatbuffer_merge.h"
#include "aos/json_to_flatbuffer.h"
@@ -131,10 +135,33 @@
return buffer;
}
+// Normalizes an import path so it can be appended to an extra import path.
+// Empty components (e.g. from "//" or a leading '/') and "." components are
+// dropped, and each ".." cancels the component before it. CHECK-fails if a
+// ".." has nothing left to cancel, i.e. the path would climb above where it
+// started (this covers both "../a" and "a/../../b").
+std::string RemoveDotDots(const std::string_view filename) {
+  const std::vector<std::string> split = absl::StrSplit(filename, '/');
+  // Components kept so far; ".." pops from the back, which keeps this linear
+  // instead of erasing from the middle of a vector for every "..".
+  std::vector<std::string> components;
+  components.reserve(split.size());
+  for (const std::string &component : split) {
+    if (component.empty() || component == ".") {
+      continue;
+    }
+    if (component == "..") {
+      CHECK(!components.empty())
+          << ": Import path may not use .. to go above its starting "
+             "directory: "
+          << filename;
+      components.pop_back();
+      continue;
+    }
+    components.push_back(component);
+  }
+  return absl::StrJoin(components, "/");
+}
+
FlatbufferDetachedBuffer<Configuration> ReadConfig(
const std::string_view path, absl::btree_set<std::string> *visited_paths,
const std::vector<std::string_view> &extra_import_paths) {
std::string binary_path = MaybeReplaceExtension(path, ".json", ".bfbs");
+ VLOG(1) << "Looking up: " << path << ", starting with: " << binary_path;
bool binary_path_exists = util::PathExists(binary_path);
std::string raw_path(path);
// For each .json file, look and see if we can find a .bfbs file next to it
@@ -151,13 +178,15 @@
bool found_path = false;
for (const auto &import_path : extra_import_paths) {
- raw_path = std::string(import_path) + "/" + std::string(path);
+ raw_path = std::string(import_path) + "/" + RemoveDotDots(path);
binary_path = MaybeReplaceExtension(raw_path, ".json", ".bfbs");
+ VLOG(1) << "Checking: " << binary_path;
binary_path_exists = util::PathExists(binary_path);
if (binary_path_exists) {
found_path = true;
break;
}
+ VLOG(1) << "Checking: " << raw_path;
if (util::PathExists(raw_path)) {
found_path = true;
break;
@@ -560,11 +589,11 @@
FlatbufferDetachedBuffer<Configuration> ReadConfig(
const std::string_view path,
- const std::vector<std::string_view> &import_paths) {
+ const std::vector<std::string_view> &extra_import_paths) {
// We only want to read a file once. So track the visited files in a set.
absl::btree_set<std::string> visited_paths;
FlatbufferDetachedBuffer<Configuration> read_config =
- ReadConfig(path, &visited_paths, import_paths);
+ ReadConfig(path, &visited_paths, extra_import_paths);
// If we only read one file, and it had a .bfbs extension, it has to be a
// fully formatted config. Do a quick verification and return it.
diff --git a/aos/events/channel_preallocated_allocator.h b/aos/events/channel_preallocated_allocator.h
index c4f0eca..76fb694 100644
--- a/aos/events/channel_preallocated_allocator.h
+++ b/aos/events/channel_preallocated_allocator.h
@@ -62,11 +62,12 @@
uint8_t *reallocate_downward(uint8_t * /*old_p*/, size_t /*old_size*/,
size_t new_size, size_t /*in_use_back*/,
size_t /*in_use_front*/) override {
- LOG(FATAL) << "Requested " << new_size << " bytes, max size "
- << channel_->max_size() << " for channel "
- << configuration::CleanedChannelToString(channel_)
- << ". Increase the memory reserved to at least " << new_size
- << ".";
+ LOG(FATAL)
+ << "Requested " << new_size
+ << " bytes (includes extra for room to grow even more), max size "
+ << channel_->max_size() << " for channel "
+ << configuration::CleanedChannelToString(channel_)
+ << ". Increase the memory reserved to at least " << new_size << ".";
return nullptr;
}
diff --git a/aos/events/epoll.cc b/aos/events/epoll.cc
index 1c4427b..8cb553b 100644
--- a/aos/events/epoll.cc
+++ b/aos/events/epoll.cc
@@ -4,6 +4,7 @@
#include <sys/epoll.h>
#include <sys/timerfd.h>
#include <unistd.h>
+
#include <atomic>
#include <vector>
@@ -109,62 +110,66 @@
}
EventData *const event_data = static_cast<struct EventData *>(event.data.ptr);
- if (event.events & kInEvents) {
- CHECK(event_data->in_fn)
- << ": No handler registered for input events on " << event_data->fd;
- event_data->in_fn();
- }
- if (event.events & kOutEvents) {
- CHECK(event_data->out_fn)
- << ": No handler registered for output events on " << event_data->fd;
- event_data->out_fn();
- }
- if (event.events & kErrorEvents) {
- CHECK(event_data->err_fn)
- << ": No handler registered for error events on " << event_data->fd;
- event_data->err_fn();
- }
+ event_data->DoCallbacks(event.events);
return true;
}
-void EPoll::Quit() { PCHECK(write(quit_signal_fd_, "q", 1) == 1); }
+// Requests that Run() stop by writing one byte to the quit pipe. The write is
+// skipped once run_ is already false.
+void EPoll::Quit() {
+  // Shortcut to break us out of infinite loops. We might write more than once
+  // to the pipe, but we'll stop once the first is read on the other end.
+  if (run_) {
+    PCHECK(write(quit_signal_fd_, "q", 1) == 1);
+  }
+}
+// Sets the readability handler for fd, creating its InOutEventData on first
+// registration, and enables kInEvents for it.
void EPoll::OnReadable(int fd, ::std::function<void()> function) {
EventData *event_data = GetEventData(fd);
if (event_data == nullptr) {
- fns_.emplace_back(std::make_unique<EventData>(fd));
+ fns_.emplace_back(std::make_unique<InOutEventData>(fd));
event_data = fns_.back().get();
} else {
- CHECK(!event_data->in_fn) << ": Duplicate in functions for " << fd;
+  // An fd registered through OnEvents holds a SingleEventData, so the
+  // static_casts below would be undefined behavior without this check.
+  CHECK(dynamic_cast<InOutEventData *>(event_data) != nullptr)
+      << ": fd " << fd
+      << " is registered with OnEvents and may not also use OnReadable";
+  CHECK(!static_cast<InOutEventData *>(event_data)->in_fn)
+      << ": Duplicate in functions for " << fd;
}
- event_data->in_fn = ::std::move(function);
+ static_cast<InOutEventData *>(event_data)->in_fn = ::std::move(function);
DoEpollCtl(event_data, event_data->events | kInEvents);
}
+// Sets the error handler for fd, creating its InOutEventData on first
+// registration, and enables kErrorEvents for it.
void EPoll::OnError(int fd, ::std::function<void()> function) {
EventData *event_data = GetEventData(fd);
if (event_data == nullptr) {
- fns_.emplace_back(std::make_unique<EventData>(fd));
+ fns_.emplace_back(std::make_unique<InOutEventData>(fd));
event_data = fns_.back().get();
} else {
- CHECK(!event_data->err_fn) << ": Duplicate in functions for " << fd;
+  // An fd registered through OnEvents holds a SingleEventData, so the
+  // static_casts below would be undefined behavior without this check.
+  CHECK(dynamic_cast<InOutEventData *>(event_data) != nullptr)
+      << ": fd " << fd
+      << " is registered with OnEvents and may not also use OnError";
+  CHECK(!static_cast<InOutEventData *>(event_data)->err_fn)
+      << ": Duplicate error functions for " << fd;
}
- event_data->err_fn = ::std::move(function);
+ static_cast<InOutEventData *>(event_data)->err_fn = ::std::move(function);
DoEpollCtl(event_data, event_data->events | kErrorEvents);
}
+// Sets the writeability handler for fd, creating its InOutEventData on first
+// registration, and enables kOutEvents for it.
void EPoll::OnWriteable(int fd, ::std::function<void()> function) {
EventData *event_data = GetEventData(fd);
if (event_data == nullptr) {
- fns_.emplace_back(std::make_unique<EventData>(fd));
+ fns_.emplace_back(std::make_unique<InOutEventData>(fd));
event_data = fns_.back().get();
} else {
- CHECK(!event_data->out_fn) << ": Duplicate out functions for " << fd;
+  // An fd registered through OnEvents holds a SingleEventData, so the
+  // static_casts below would be undefined behavior without this check.
+  CHECK(dynamic_cast<InOutEventData *>(event_data) != nullptr)
+      << ": fd " << fd
+      << " is registered with OnEvents and may not also use OnWriteable";
+  CHECK(!static_cast<InOutEventData *>(event_data)->out_fn)
+      << ": Duplicate out functions for " << fd;
}
- event_data->out_fn = ::std::move(function);
+ static_cast<InOutEventData *>(event_data)->out_fn = ::std::move(function);
DoEpollCtl(event_data, event_data->events | kOutEvents);
}
+// Registers `function` to receive the raw epoll event mask for fd. This does
+// not add fd to epoll; events must be enabled separately (e.g. SetEvents()).
+// Fatal if fd already has any handlers registered, whether from a previous
+// OnEvents call or from OnReadable/OnWriteable/OnError.
+void EPoll::OnEvents(int fd, ::std::function<void(uint32_t)> function) {
+  if (GetEventData(fd) != nullptr) {
+    LOG(FATAL) << "fd " << fd
+               << " already has handlers registered; OnEvents may not "
+                  "replace them";
+  }
+  auto event_data = std::make_unique<SingleEventData>(fd);
+  event_data->fn = std::move(function);
+  fns_.emplace_back(std::move(event_data));
+}
+
void EPoll::ForgetClosedFd(int fd) {
auto element = fns_.begin();
while (fns_.end() != element) {
@@ -177,6 +182,10 @@
LOG(FATAL) << "fd " << fd << " not found";
}
+// Replaces the whole epoll event mask for fd with `events`. fd must already
+// have an entry (created by one of the On* functions); CHECK-fails otherwise.
+void EPoll::SetEvents(int fd, uint32_t events) {
+  EventData *const event_data = CHECK_NOTNULL(GetEventData(fd));
+  DoEpollCtl(event_data, events);
+}
+
// Removes fd from the event loop.
void EPoll::DeleteFd(int fd) {
auto element = fns_.begin();
@@ -192,6 +201,21 @@
LOG(FATAL) << "fd " << fd << " not found";
}
+// Dispatches each event class present in `events` to its handler, in the
+// order input, output, error. CHECK-fails if a class fires with no handler.
+void EPoll::InOutEventData::DoCallbacks(uint32_t events) {
+  const bool readable = (events & kInEvents) != 0;
+  const bool writeable = (events & kOutEvents) != 0;
+  const bool errored = (events & kErrorEvents) != 0;
+
+  if (readable) {
+    CHECK(in_fn) << ": No handler registered for input events on " << fd;
+    in_fn();
+  }
+  if (writeable) {
+    CHECK(out_fn) << ": No handler registered for output events on " << fd;
+    out_fn();
+  }
+  if (errored) {
+    CHECK(err_fn) << ": No handler registered for error events on " << fd;
+    err_fn();
+  }
+}
+
void EPoll::EnableEvents(int fd, uint32_t events) {
EventData *const event_data = CHECK_NOTNULL(GetEventData(fd));
DoEpollCtl(event_data, event_data->events | events);
@@ -214,12 +238,14 @@
void EPoll::DoEpollCtl(EventData *event_data, const uint32_t new_events) {
const uint32_t old_events = event_data->events;
+ if (old_events == new_events) {
+ // Shortcut without calling into the kernel. This happens often with
+ // external event loop integrations that are emulating poll, so make it
+ // fast.
+ return;
+ }
event_data->events = new_events;
if (new_events == 0) {
- if (old_events == 0) {
- // Not added, and doesn't need to be. Nothing to do here.
- return;
- }
// It was added, but should now be removed.
PCHECK(epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, event_data->fd, nullptr) == 0);
return;
diff --git a/aos/events/epoll.h b/aos/events/epoll.h
index 4bcedf1..2b7eb76 100644
--- a/aos/events/epoll.h
+++ b/aos/events/epoll.h
@@ -5,6 +5,7 @@
#include <sys/epoll.h>
#include <sys/timerfd.h>
#include <unistd.h>
+
#include <atomic>
#include <functional>
#include <vector>
@@ -68,21 +69,34 @@
// Quits. Async safe.
void Quit();
- // Called before waiting on the epoll file descriptor.
+ // Adds a function which will be called before waiting on the epoll file
+ // descriptor.
void BeforeWait(std::function<void()> function);
// Registers a function to be called if the fd becomes readable.
// Only one function may be registered for readability on each fd.
+ // A fd may be registered exclusively with OnReadable/OnWriteable/OnError OR
+ // OnEvents.
void OnReadable(int fd, ::std::function<void()> function);
// Registers a function to be called if the fd reports an error.
// Only one function may be registered for errors on each fd.
+ // A fd may be registered exclusively with OnReadable/OnWriteable/OnError OR
+ // OnEvents.
void OnError(int fd, ::std::function<void()> function);
// Registers a function to be called if the fd becomes writeable.
// Only one function may be registered for writability on each fd.
+ // A fd may be registered exclusively with OnReadable/OnWriteable/OnError OR
+ // OnEvents.
void OnWriteable(int fd, ::std::function<void()> function);
+ // Registers a function to be called when the configured events occur on fd.
+ // Which events occur will be passed to the function.
+ // A fd may be registered exclusively with OnReadable/OnWriteable/OnError OR
+ // OnEvents.
+ void OnEvents(int fd, ::std::function<void(uint32_t)> function);
+
// Removes fd from the event loop.
// All Fds must be cleaned up before this class is destroyed.
void DeleteFd(int fd);
@@ -100,19 +114,50 @@
// writeable.
void DisableWriteable(int fd) { DisableEvents(fd, kOutEvents); }
+ // Sets the epoll events for the given fd. Be careful using this with
+ // OnReadable/OnWriteable/OnError: enabled events which fire with no handler
+ // registered will result in a crash.
+ void SetEvents(int fd, uint32_t events);
+
+ // Returns whether we're currently running. This changes to false when we
+ // start draining events to finish.
+ bool should_run() const { return run_; }
+
private:
// Structure whose pointer should be returned by epoll. Makes looking up the
// function fast and easy.
struct EventData {
EventData(int fd_in) : fd(fd_in) {}
+ virtual ~EventData() = default;
+
// We use pointers to these objects as persistent identifiers, so they can't
// be moved.
EventData(const EventData &) = delete;
EventData &operator=(const EventData &) = delete;
+ // Calls the appropriate callbacks when events are returned from the kernel.
+ virtual void DoCallbacks(uint32_t events) = 0;
+
const int fd;
uint32_t events = 0;
+ };
+
+  // Handlers for an fd registered through OnReadable/OnWriteable/OnError.
+  // Each event class has its own callback; unset ones are empty.
+  struct InOutEventData final : public EventData {
+    explicit InOutEventData(int fd) : EventData(fd) {}
+    ~InOutEventData() override = default;
+
+    std::function<void()> in_fn, out_fn, err_fn;
+
+    void DoCallbacks(uint32_t events) override;
+  };
+
+  // Handler for an fd registered through OnEvents: one callback receives the
+  // full event mask.
+  struct SingleEventData final : public EventData {
+    explicit SingleEventData(int fd) : EventData(fd) {}
+    ~SingleEventData() override = default;
+
+    std::function<void(uint32_t)> fn;
+
+    void DoCallbacks(uint32_t events) override { fn(events); }
};
void EnableEvents(int fd, uint32_t events);
diff --git a/aos/events/epoll_test.cc b/aos/events/epoll_test.cc
index 66460c2..053a09b 100644
--- a/aos/events/epoll_test.cc
+++ b/aos/events/epoll_test.cc
@@ -3,8 +3,8 @@
#include <fcntl.h>
#include <unistd.h>
-#include "gtest/gtest.h"
#include "glog/logging.h"
+#include "gtest/gtest.h"
namespace aos {
namespace internal {
@@ -48,22 +48,22 @@
};
class EPollTest : public ::testing::Test {
- public:
- void RunFor(std::chrono::nanoseconds duration) {
- TimerFd timerfd;
- bool did_quit = false;
- epoll_.OnReadable(timerfd.fd(), [this, &timerfd, &did_quit]() {
- CHECK(!did_quit);
- epoll_.Quit();
- did_quit = true;
- timerfd.Read();
- });
- timerfd.SetTime(monotonic_clock::now() + duration,
- monotonic_clock::duration::zero());
- epoll_.Run();
- CHECK(did_quit);
- epoll_.DeleteFd(timerfd.fd());
- }
+ public:
+ void RunFor(std::chrono::nanoseconds duration) {
+ TimerFd timerfd;
+ bool did_quit = false;
+ epoll_.OnReadable(timerfd.fd(), [this, &timerfd, &did_quit]() {
+ CHECK(!did_quit);
+ epoll_.Quit();
+ did_quit = true;
+ timerfd.Read();
+ });
+ timerfd.SetTime(monotonic_clock::now() + duration,
+ monotonic_clock::duration::zero());
+ epoll_.Run();
+ CHECK(did_quit);
+ epoll_.DeleteFd(timerfd.fd());
+ }
// Tests should avoid relying on ordering for events closer in time than this,
// or waiting for longer than this to ensure events happen in order.
@@ -71,7 +71,7 @@
return std::chrono::milliseconds(50);
}
- EPoll epoll_;
+ EPoll epoll_;
};
// Test that the basics of OnReadable work.
@@ -201,6 +201,11 @@
epoll_.DeleteFd(pipe.write_fd());
}
+// Tests that calling Quit() from a BeforeWait callback makes Run() return.
+TEST_F(EPollTest, QuitInBeforeWait) {
+  int before_wait_calls = 0;
+  epoll_.BeforeWait([this, &before_wait_calls]() {
+    ++before_wait_calls;
+    epoll_.Quit();
+  });
+  epoll_.Run();
+  // Run() returning is the main check; also confirm the callback really ran.
+  EXPECT_GE(before_wait_calls, 1);
+}
+
} // namespace testing
} // namespace internal
} // namespace aos
diff --git a/aos/events/event_loop.h b/aos/events/event_loop.h
index d23314e..7cb6a5a 100644
--- a/aos/events/event_loop.h
+++ b/aos/events/event_loop.h
@@ -7,6 +7,7 @@
#include <string>
#include <string_view>
+#include "absl/container/btree_set.h"
#include "aos/configuration.h"
#include "aos/configuration_generated.h"
#include "aos/events/channel_preallocated_allocator.h"
@@ -20,8 +21,6 @@
#include "aos/time/time.h"
#include "aos/util/phased_loop.h"
#include "aos/uuid.h"
-
-#include "absl/container/btree_set.h"
#include "flatbuffers/flatbuffers.h"
#include "glog/logging.h"
@@ -115,6 +114,7 @@
protected:
EventLoop *event_loop() { return event_loop_; }
+ const EventLoop *event_loop() const { return event_loop_; }
Context context_;
@@ -190,6 +190,7 @@
protected:
EventLoop *event_loop() { return event_loop_; }
+ const EventLoop *event_loop() const { return event_loop_; }
monotonic_clock::time_point monotonic_sent_time_ = monotonic_clock::min_time;
realtime_clock::time_point realtime_sent_time_ = realtime_clock::min_time;
@@ -473,8 +474,13 @@
Ftrace ftrace_;
};
+// Note, it is supported to create only:
+// multiple fetchers, and (one sender or one watcher) per <name, type>
+// tuple.
class EventLoop {
public:
+ // Holds configuration by reference for the lifetime of this object. It may
+ // never be mutated externally in any way.
EventLoop(const Configuration *configuration);
virtual ~EventLoop();
@@ -495,10 +501,6 @@
return GetChannel<T>(channel_name) != nullptr;
}
- // Note, it is supported to create:
- // multiple fetchers, and (one sender or one watcher) per <name, type>
- // tuple.
-
// Makes a class that will always fetch the most recent value
// sent to the provided channel.
template <typename T>
@@ -596,7 +598,7 @@
// TODO(austin): OnExit for cleanup.
- // Threadsafe.
+ // May be safely called from any thread.
bool is_running() const { return is_running_.load(); }
// Sets the scheduler priority to run the event loop at. This may not be
diff --git a/aos/events/logging/log_cat.cc b/aos/events/logging/log_cat.cc
index 5c69f07..342a491 100644
--- a/aos/events/logging/log_cat.cc
+++ b/aos/events/logging/log_cat.cc
@@ -39,6 +39,8 @@
"confirming they can be parsed.");
DEFINE_bool(print_parts_only, false,
"If true, only print out the results of logfile sorting.");
+DEFINE_bool(channels, false,
+ "If true, print out all the configured channels for this log.");
// Print the flatbuffer out to stdout, both to remove the unnecessary cruft from
// glog and to allow the user to readily redirect just the logged output
@@ -227,6 +229,15 @@
aos::logger::LogReader reader(logfiles);
+ if (FLAGS_channels) {
+ const aos::Configuration *config = reader.configuration();
+ for (const aos::Channel *channel : *config->channels()) {
+ std::cout << channel->name()->c_str() << " " << channel->type()->c_str()
+ << '\n';
+ }
+ return 0;
+ }
+
aos::FastStringBuilder builder;
aos::SimulatedEventLoopFactory event_loop_factory(reader.configuration());
diff --git a/aos/events/shm_event_loop.cc b/aos/events/shm_event_loop.cc
index 63e1cb9..835210c 100644
--- a/aos/events/shm_event_loop.cc
+++ b/aos/events/shm_event_loop.cc
@@ -477,9 +477,13 @@
simple_shm_fetcher_.RetrieveData();
}
- ~ShmFetcher() { context_.data = nullptr; }
+ ~ShmFetcher() override {
+ shm_event_loop()->CheckCurrentThread();
+ context_.data = nullptr;
+ }
std::pair<bool, monotonic_clock::time_point> DoFetchNext() override {
+ shm_event_loop()->CheckCurrentThread();
if (simple_shm_fetcher_.FetchNext()) {
context_ = simple_shm_fetcher_.context();
return std::make_pair(true, monotonic_clock::now());
@@ -488,6 +492,7 @@
}
std::pair<bool, monotonic_clock::time_point> DoFetch() override {
+ shm_event_loop()->CheckCurrentThread();
if (simple_shm_fetcher_.Fetch()) {
context_ = simple_shm_fetcher_.context();
return std::make_pair(true, monotonic_clock::now());
@@ -500,6 +505,10 @@
}
private:
+ const ShmEventLoop *shm_event_loop() const {
+ return static_cast<const ShmEventLoop *>(event_loop());
+ }
+
SimpleShmFetcher simple_shm_fetcher_;
};
@@ -517,7 +526,7 @@
channel)),
wake_upper_(lockless_queue_memory_.queue()) {}
- ~ShmSender() override {}
+ ~ShmSender() override { shm_event_loop()->CheckCurrentThread(); }
static ipc_lib::LocklessQueueSender VerifySender(
std::optional<ipc_lib::LocklessQueueSender> sender,
@@ -530,13 +539,20 @@
<< ", too many senders.";
}
- void *data() override { return lockless_queue_sender_.Data(); }
- size_t size() override { return lockless_queue_sender_.size(); }
+ void *data() override {
+ shm_event_loop()->CheckCurrentThread();
+ return lockless_queue_sender_.Data();
+ }
+ size_t size() override {
+ shm_event_loop()->CheckCurrentThread();
+ return lockless_queue_sender_.size();
+ }
bool DoSend(size_t length,
aos::monotonic_clock::time_point monotonic_remote_time,
aos::realtime_clock::time_point realtime_remote_time,
uint32_t remote_queue_index,
const UUID &remote_boot_uuid) override {
+ shm_event_loop()->CheckCurrentThread();
CHECK_LE(length, static_cast<size_t>(channel()->max_size()))
<< ": Sent too big a message on "
<< configuration::CleanedChannelToString(channel());
@@ -547,7 +563,8 @@
<< ": Somebody wrote outside the buffer of their message on channel "
<< configuration::CleanedChannelToString(channel());
- wake_upper_.Wakeup(event_loop()->priority());
+ wake_upper_.Wakeup(event_loop()->is_running() ? event_loop()->priority()
+ : 0);
return true;
}
@@ -556,6 +573,7 @@
aos::realtime_clock::time_point realtime_remote_time,
uint32_t remote_queue_index,
const UUID &remote_boot_uuid) override {
+ shm_event_loop()->CheckCurrentThread();
CHECK_LE(length, static_cast<size_t>(channel()->max_size()))
<< ": Sent too big a message on "
<< configuration::CleanedChannelToString(channel());
@@ -565,7 +583,8 @@
&monotonic_sent_time_, &realtime_sent_time_, &sent_queue_index_))
<< ": Somebody wrote outside the buffer of their message on channel "
<< configuration::CleanedChannelToString(channel());
- wake_upper_.Wakeup(event_loop()->priority());
+ wake_upper_.Wakeup(event_loop()->is_running() ? event_loop()->priority()
+ : 0);
// TODO(austin): Return an error if we send too fast.
return true;
}
@@ -574,9 +593,16 @@
return lockless_queue_memory_.GetMutableSharedMemory();
}
- int buffer_index() override { return lockless_queue_sender_.buffer_index(); }
+ int buffer_index() override {
+ shm_event_loop()->CheckCurrentThread();
+ return lockless_queue_sender_.buffer_index();
+ }
private:
+ const ShmEventLoop *shm_event_loop() const {
+ return static_cast<const ShmEventLoop *>(event_loop());
+ }
+
MMappedQueue lockless_queue_memory_;
ipc_lib::LocklessQueueSender lockless_queue_sender_;
ipc_lib::LocklessQueueWakeUpper wake_upper_;
@@ -599,9 +625,13 @@
}
}
- ~ShmWatcherState() override { event_loop_->RemoveEvent(&event_); }
+ ~ShmWatcherState() override {
+ event_loop_->CheckCurrentThread();
+ event_loop_->RemoveEvent(&event_);
+ }
void Startup(EventLoop *event_loop) override {
+ event_loop_->CheckCurrentThread();
simple_shm_fetcher_.PointAtNextQueueIndex();
CHECK(RegisterWakeup(event_loop->priority()));
}
@@ -666,6 +696,7 @@
}
~ShmTimerHandler() {
+ shm_event_loop_->CheckCurrentThread();
Disable();
shm_event_loop_->epoll_.DeleteFd(timerfd_.fd());
}
@@ -705,6 +736,7 @@
void Setup(monotonic_clock::time_point base,
monotonic_clock::duration repeat_offset) override {
+ shm_event_loop_->CheckCurrentThread();
if (event_.valid()) {
shm_event_loop_->RemoveEvent(&event_);
}
@@ -717,6 +749,7 @@
}
void Disable() override {
+ shm_event_loop_->CheckCurrentThread();
shm_event_loop_->RemoveEvent(&event_);
timerfd_.Disable();
disabled_ = true;
@@ -766,6 +799,7 @@
}
~ShmPhasedLoopHandler() override {
+ shm_event_loop_->CheckCurrentThread();
shm_event_loop_->epoll_.DeleteFd(timerfd_.fd());
shm_event_loop_->RemoveEvent(&event_);
}
@@ -773,6 +807,7 @@
private:
// Reschedules the timer.
void Schedule(monotonic_clock::time_point sleep_time) override {
+ shm_event_loop_->CheckCurrentThread();
if (event_.valid()) {
shm_event_loop_->RemoveEvent(&event_);
}
@@ -792,6 +827,7 @@
::std::unique_ptr<RawFetcher> ShmEventLoop::MakeRawFetcher(
const Channel *channel) {
+ CheckCurrentThread();
if (!configuration::ChannelIsReadableOnNode(channel, node())) {
LOG(FATAL) << "Channel { \"name\": \"" << channel->name()->string_view()
<< "\", \"type\": \"" << channel->type()->string_view()
@@ -805,6 +841,7 @@
::std::unique_ptr<RawSender> ShmEventLoop::MakeRawSender(
const Channel *channel) {
+ CheckCurrentThread();
TakeSender(channel);
return ::std::unique_ptr<RawSender>(new ShmSender(shm_base_, this, channel));
@@ -813,6 +850,7 @@
void ShmEventLoop::MakeRawWatcher(
const Channel *channel,
std::function<void(const Context &context, const void *message)> watcher) {
+ CheckCurrentThread();
TakeWatcher(channel);
NewWatcher(::std::unique_ptr<WatcherState>(
@@ -822,6 +860,7 @@
void ShmEventLoop::MakeRawNoArgWatcher(
const Channel *channel,
std::function<void(const Context &context)> watcher) {
+ CheckCurrentThread();
TakeWatcher(channel);
NewWatcher(::std::unique_ptr<WatcherState>(new ShmWatcherState(
@@ -831,6 +870,7 @@
}
TimerHandler *ShmEventLoop::AddTimer(::std::function<void()> callback) {
+ CheckCurrentThread();
return NewTimer(::std::unique_ptr<TimerHandler>(
new ShmTimerHandler(this, ::std::move(callback))));
}
@@ -839,14 +879,28 @@
::std::function<void(int)> callback,
const monotonic_clock::duration interval,
const monotonic_clock::duration offset) {
+ CheckCurrentThread();
return NewPhasedLoop(::std::unique_ptr<PhasedLoopHandler>(
new ShmPhasedLoopHandler(this, ::std::move(callback), interval, offset)));
}
void ShmEventLoop::OnRun(::std::function<void()> on_run) {
+ CheckCurrentThread();
on_run_.push_back(::std::move(on_run));
}
+// Enforces the optional safety checks configured by CheckForMutex() and
+// LockToThread(). Both are off by default, so each branch is hinted unlikely.
+void ShmEventLoop::CheckCurrentThread() const {
+  const bool mutex_check_enabled = check_mutex_ != nullptr;
+  if (__builtin_expect(mutex_check_enabled, false)) {
+    CHECK(check_mutex_->is_locked())
+        << ": The configured mutex is not locked while calling a "
+           "ShmEventLoop function";
+  }
+  const bool thread_check_enabled = check_tid_.has_value();
+  if (__builtin_expect(thread_check_enabled, false)) {
+    CHECK_EQ(syscall(SYS_gettid), *check_tid_)
+        << ": Being called from the wrong thread";
+  }
+}
+
// This is a bit tricky because watchers can generate new events at any time (as
// long as it's in the past). We want to check the watchers at least once before
// declaring there are no events to handle, and we want to check them again if
@@ -1021,6 +1075,7 @@
};
void ShmEventLoop::Run() {
+ CheckCurrentThread();
SignalHandler::global()->Register(this);
if (watchers_.size() > 0) {
@@ -1100,6 +1155,7 @@
void ShmEventLoop::Exit() { epoll_.Quit(); }
ShmEventLoop::~ShmEventLoop() {
+ CheckCurrentThread();
// Force everything with a registered fd with epoll to be destroyed now.
timers_.clear();
phased_loops_.clear();
@@ -1109,6 +1165,7 @@
}
void ShmEventLoop::SetRuntimeRealtimePriority(int priority) {
+ CheckCurrentThread();
if (is_running()) {
LOG(FATAL) << "Cannot set realtime priority while running.";
}
@@ -1116,6 +1173,7 @@
}
void ShmEventLoop::SetRuntimeAffinity(const cpu_set_t &cpuset) {
+ CheckCurrentThread();
if (is_running()) {
LOG(FATAL) << "Cannot set affinity while running.";
}
@@ -1123,18 +1181,21 @@
}
void ShmEventLoop::set_name(const std::string_view name) {
+ CheckCurrentThread();
name_ = std::string(name);
UpdateTimingReport();
}
absl::Span<const char> ShmEventLoop::GetWatcherSharedMemory(
const Channel *channel) {
+ CheckCurrentThread();
ShmWatcherState *const watcher_state =
static_cast<ShmWatcherState *>(GetWatcherState(channel));
return watcher_state->GetSharedMemory();
}
int ShmEventLoop::NumberBuffers(const Channel *channel) {
+ CheckCurrentThread();
return MakeQueueConfiguration(
channel, chrono::ceil<chrono::seconds>(chrono::nanoseconds(
configuration()->channel_storage_duration())))
@@ -1143,14 +1204,19 @@
absl::Span<char> ShmEventLoop::GetShmSenderSharedMemory(
const aos::RawSender *sender) const {
+ CheckCurrentThread();
return static_cast<const ShmSender *>(sender)->GetSharedMemory();
}
absl::Span<const char> ShmEventLoop::GetShmFetcherPrivateMemory(
const aos::RawFetcher *fetcher) const {
+ CheckCurrentThread();
return static_cast<const ShmFetcher *>(fetcher)->GetPrivateMemory();
}
-pid_t ShmEventLoop::GetTid() { return syscall(SYS_gettid); }
+// Returns the kernel thread id (gettid) of the calling thread, after running
+// the configured thread/mutex safety checks.
+pid_t ShmEventLoop::GetTid() {
+  CheckCurrentThread();
+  const pid_t tid = static_cast<pid_t>(syscall(SYS_gettid));
+  return tid;
+}
} // namespace aos
diff --git a/aos/events/shm_event_loop.h b/aos/events/shm_event_loop.h
index 845857c..3245f11 100644
--- a/aos/events/shm_event_loop.h
+++ b/aos/events/shm_event_loop.h
@@ -4,11 +4,11 @@
#include <vector>
#include "absl/types/span.h"
-
#include "aos/events/epoll.h"
#include "aos/events/event_loop.h"
#include "aos/events/event_loop_generated.h"
#include "aos/ipc_lib/signalfd.h"
+#include "aos/stl_mutex/stl_mutex.h"
DECLARE_string(application_name);
DECLARE_string(shm_base);
@@ -92,6 +92,7 @@
// Returns the local mapping of the shared memory used by the provided Sender.
template <typename T>
absl::Span<char> GetSenderSharedMemory(aos::Sender<T> *sender) const {
+ CheckCurrentThread();
return GetShmSenderSharedMemory(GetRawSender(sender));
}
@@ -103,11 +104,30 @@
template <typename T>
absl::Span<const char> GetFetcherPrivateMemory(
aos::Fetcher<T> *fetcher) const {
+ CheckCurrentThread();
return GetShmFetcherPrivateMemory(GetRawFetcher(fetcher));
}
int NumberBuffers(const Channel *channel) override;
+ // All public-facing APIs will verify this mutex is held when they are called.
+ // For normal use with everything in a single thread, this is unnecessary.
+ //
+ // This is helpful as a safety check when using a ShmEventLoop with external
+ // synchronization across multiple threads. It will NOT reliably catch race
+ // conditions, but if you have a race condition triggered repeatedly it'll
+ // probably catch it eventually.
+ void CheckForMutex(aos::stl_mutex *check_mutex) {
+ check_mutex_ = check_mutex;
+ }
+
+ // All public-facing APIs will verify they are called in this thread.
+ // For normal use with the whole program in a single thread, this is
+ // unnecessary. It's helpful as a safety check for programs with multiple
+ // threads, where the EventLoop should only be interacted with from a single
+ // one.
+ void LockToThread() { check_tid_ = GetTid(); }
+
private:
friend class shm_event_loop_internal::ShmWatcherState;
friend class shm_event_loop_internal::ShmTimerHandler;
@@ -126,6 +146,8 @@
return result;
}
+ void CheckCurrentThread() const;
+
void HandleEvent();
// Returns the TID of the event loop.
@@ -151,6 +173,9 @@
std::string name_;
const Node *const node_;
+ aos::stl_mutex *check_mutex_ = nullptr;
+ std::optional<pid_t> check_tid_;
+
internal::EPoll epoll_;
// Only set during Run().
diff --git a/aos/events/shm_event_loop_test.cc b/aos/events/shm_event_loop_test.cc
index 5524597..4c12213 100644
--- a/aos/events/shm_event_loop_test.cc
+++ b/aos/events/shm_event_loop_test.cc
@@ -3,13 +3,12 @@
#include <string_view>
#include "aos/events/event_loop_param_test.h"
+#include "aos/events/test_message_generated.h"
+#include "aos/network/team_number.h"
#include "aos/realtime.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
-#include "aos/events/test_message_generated.h"
-#include "aos/network/team_number.h"
-
namespace aos {
namespace testing {
namespace {
@@ -123,6 +122,36 @@
using ShmEventLoopDeathTest = ShmEventLoopTest;
+// Tests that we don't leave the calling thread realtime when calling Send
+// before Run.
+TEST_P(ShmEventLoopTest, SendBeforeRun) {
+  auto loop = factory()->MakePrimary("primary");
+  loop->SetRuntimeRealtimePriority(1);
+
+  auto loop2 = factory()->Make("loop2");
+  loop2->SetRuntimeRealtimePriority(2);
+  loop2->MakeWatcher("/test", [](const TestMessage &) {});
+  ShmEventLoop *const shm_loop2 = static_cast<ShmEventLoop *>(loop2.get());
+  // Need the other one running for its watcher to record in SHM that it wants
+  // wakers to boost their priority, so leave it running in a thread for this
+  // test.
+  std::thread loop2_thread([shm_loop2]() { shm_loop2->Run(); });
+  std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+  auto sender = loop->MakeSender<TestMessage>("/test");
+  const auto send_one_message = [&sender]() {
+    aos::Sender<TestMessage>::Builder msg = sender.MakeBuilder();
+    TestMessage::Builder builder = msg.MakeBuilder<TestMessage>();
+    builder.add_value(200);
+    msg.Send(builder.Finish());
+  };
+
+  EXPECT_FALSE(IsRealtime());
+  send_one_message();
+  EXPECT_FALSE(IsRealtime());
+
+  shm_loop2->Exit();
+  loop2_thread.join();
+}
+
// Tests that every handler type is realtime and runs. There are threads
// involved and it's easy to miss one.
TEST_P(ShmEventLoopTest, AllHandlersAreRealtime) {
diff --git a/aos/network/message_bridge_client_lib.cc b/aos/network/message_bridge_client_lib.cc
index 08350a8..1044b2f 100644
--- a/aos/network/message_bridge_client_lib.cc
+++ b/aos/network/message_bridge_client_lib.cc
@@ -186,24 +186,26 @@
}
void SctpClientConnection::SendConnect() {
+ VLOG(1) << "Sending Connect";
// Try to send the connect message. If that fails, retry.
- if (!client_.Send(kConnectStream(),
- std::string_view(reinterpret_cast<const char *>(
- connect_message_.span().data()),
- connect_message_.span().size()),
- 0)) {
+ if (client_.Send(kConnectStream(),
+ std::string_view(reinterpret_cast<const char *>(
+ connect_message_.span().data()),
+ connect_message_.span().size()),
+ 0)) {
+ ScheduleConnectTimeout();
+ } else {
NodeDisconnected();
}
}
void SctpClientConnection::NodeConnected(sctp_assoc_t assoc_id) {
- connect_timer_->Disable();
+ ScheduleConnectTimeout();
// We want to tell the kernel to schedule the packets on this new stream with
// the priority scheduler. This only needs to be done once per stream.
client_.SetPriorityScheduler(assoc_id);
- remote_assoc_id_ = assoc_id;
connection_->mutate_state(State::CONNECTED);
client_status_->SampleReset(client_index_);
}
@@ -212,13 +214,14 @@
connect_timer_->Setup(
event_loop_->monotonic_now() + chrono::milliseconds(100),
chrono::milliseconds(100));
- remote_assoc_id_ = 0;
connection_->mutate_state(State::DISCONNECTED);
connection_->mutate_monotonic_offset(0);
client_status_->SampleReset(client_index_);
}
void SctpClientConnection::HandleData(const Message *message) {
+ ScheduleConnectTimeout();
+
const RemoteData *remote_data =
flatbuffers::GetSizePrefixedRoot<RemoteData>(message->data());
diff --git a/aos/network/message_bridge_client_lib.h b/aos/network/message_bridge_client_lib.h
index 2b48906..0552f79 100644
--- a/aos/network/message_bridge_client_lib.h
+++ b/aos/network/message_bridge_client_lib.h
@@ -53,6 +53,14 @@
void NodeDisconnected();
void HandleData(const Message *message);
+  // Schedules connect_timer_ for a ways in the future. If one of our messages
+  // gets dropped, the server might be waiting for this, so if we don't hear
+  // from the server for a while we'll try sending it again.
+  //
+  // The deadline is based on monotonic_now() rather than
+  // context().monotonic_event_time. This is called from SendConnect,
+  // NodeConnected and HandleData, and not all of those are necessarily
+  // running inside a callback whose context() describes the current event
+  // (NOTE(review): confirm for raw epoll handlers). A stale time would put
+  // the deadline in the past and resend immediately. This matches how
+  // NodeDisconnected schedules the same timer.
+  void ScheduleConnectTimeout() {
+    connect_timer_->Setup(event_loop_->monotonic_now() +
+                          std::chrono::seconds(1));
+  }
+
// Event loop to register the server on.
aos::ShmEventLoop *const event_loop_;
@@ -81,10 +89,6 @@
// Timer which fires to handle reconnections.
aos::TimerHandler *connect_timer_;
- // id of the server once known. This is only valid if connection_ says
- // connected.
- sctp_assoc_t remote_assoc_id_ = 0;
-
// ClientConnection statistics message to modify. This will be published
// periodicially.
MessageBridgeClientStatus *client_status_;
diff --git a/aos/stl_mutex/stl_mutex.h b/aos/stl_mutex/stl_mutex.h
index 86a5988..e3a930e 100644
--- a/aos/stl_mutex/stl_mutex.h
+++ b/aos/stl_mutex/stl_mutex.h
@@ -61,6 +61,11 @@
bool owner_died() const { return owner_died_; }
void consistent() { owner_died_ = false; }
+ // Returns whether this mutex is locked by the current thread. This is very
+ // hard to use reliably, please think very carefully before using it for
+ // anything beyond probabilistic assertion checks.
+ bool is_locked() const { return mutex_islocked(&native_handle_); }
+
private:
aos_mutex native_handle_;
@@ -82,7 +87,7 @@
constexpr stl_recursive_mutex() {}
void lock() {
- if (mutex_islocked(mutex_.native_handle())) {
+ if (mutex_.is_locked()) {
CHECK(!owner_died());
++recursive_locks_;
} else {
@@ -95,7 +100,7 @@
}
}
bool try_lock() {
- if (mutex_islocked(mutex_.native_handle())) {
+ if (mutex_.is_locked()) {
CHECK(!owner_died());
++recursive_locks_;
return true;
diff --git a/y2021_bot3/control_loops/python/drivetrain.py b/y2021_bot3/control_loops/python/drivetrain.py
index e5b001a..fdd08c2 100644
--- a/y2021_bot3/control_loops/python/drivetrain.py
+++ b/y2021_bot3/control_loops/python/drivetrain.py
@@ -16,10 +16,11 @@
J=6.0,
mass=58.0,
# TODO(austin): Measure radius a bit better.
- robot_radius=0.7 / 2.0,
- wheel_radius=6.0 * 0.0254 / 2.0,
+ robot_radius= 0.39,
+ wheel_radius= 3/39.37,
motor_type=control_loop.Falcon(),
- G=(8.0 / 70.0) * (17.0 / 24.0),
+ num_motors = 3,
+ G=8.0 / 80.0,
q_pos=0.24,
q_vel=2.5,
efficiency=0.80,