Merge "Drop support for armhf"
diff --git a/README.md b/README.md
index 0971119..29b57b2 100644
--- a/README.md
+++ b/README.md
@@ -121,11 +121,11 @@
```
3. Once you hear back from Stephan, test SSH.
```console
-ssh [user]@build.frc971.org -p 2222 -i ~/.ssh/id_971_ed25519
+ssh REPLACE_WITH_YOUR_USERNAME@build.frc971.org -p 2222 -i ~/.ssh/id_971_ed25519
```
4. If that doesnt work, then send the error msg to #coding However, if it does then use the `exit` command and then SSH tunnel.
```console
-ssh [user]@build.frc971.org -p 2222 -i ~/.ssh/id_971_ed25519 -L 9971:127.0.0.1:3389
+ssh REPLACE_WITH_YOUR_USERNAME@build.frc971.org -p 2222 -i ~/.ssh/id_971_ed25519 -L 9971:127.0.0.1:3389
```
5. So at this point you run the Remote Desktop app in Windows. Once
you get there, all you need to do is put `127.0.0.1:9971` for the
diff --git a/aos/BUILD b/aos/BUILD
index a5fb6c8..d7a9b14 100644
--- a/aos/BUILD
+++ b/aos/BUILD
@@ -177,6 +177,7 @@
visibility = ["//visibility:public"],
deps = [
":realtime",
+ ":uuid",
"//aos:die",
"//aos/logging:implementations",
],
@@ -209,6 +210,7 @@
visibility = ["//visibility:public"],
deps = [
":thread_local",
+ ":uuid",
"@com_github_google_glog//:glog",
],
)
@@ -277,6 +279,7 @@
":flatbuffers",
":json_to_flatbuffer",
"//aos:unique_malloc_ptr",
+ "//aos/ipc_lib:index",
"//aos/network:team_number",
"//aos/util:file",
"@com_github_google_glog//:glog",
@@ -712,6 +715,18 @@
)
cc_test(
+ name = "uuid_collision_test",
+ timeout = "eternal",
+ srcs = ["uuid_collision_test.cc"],
+ shard_count = 2,
+ target_compatible_with = ["@platforms//os:linux"],
+ deps = [
+ ":uuid",
+ "//aos/testing:googletest",
+ ],
+)
+
+cc_test(
name = "uuid_test",
srcs = ["uuid_test.cc"],
target_compatible_with = ["@platforms//os:linux"],
diff --git a/aos/config.bzl b/aos/config.bzl
index 4aa6b4a..d281027 100644
--- a/aos/config.bzl
+++ b/aos/config.bzl
@@ -1,4 +1,5 @@
load("//tools/build_rules:label.bzl", "expand_label")
+load("//tools/build_rules:select.bzl", "address_size_select")
AosConfigInfo = provider(fields = [
"transitive_flatbuffers",
@@ -9,6 +10,10 @@
_aos_config(
name = name,
src = src,
+ flags = address_size_select({
+ "32": ["--max_queue_size_override=0xffff"],
+ "64": ["--max_queue_size_override=0xffffffff"],
+ }),
config_json = name + ".json",
config_stripped = name + ".stripped.json",
config_binary = name + ".bfbs",
@@ -38,7 +43,7 @@
ctx.actions.run(
outputs = [config, stripped_config, binary_config],
inputs = all_files,
- arguments = [
+ arguments = ctx.attr.flags + [
config.path,
stripped_config.path,
binary_config.path,
@@ -74,6 +79,9 @@
mandatory = True,
allow_files = True,
),
+ "flags": attr.string_list(
+ doc = "Additional flags to pass to config_flattener.",
+ ),
"deps": attr.label_list(
providers = [AosConfigInfo],
),
diff --git a/aos/configuration.cc b/aos/configuration.cc
index d100321..70e7b48 100644
--- a/aos/configuration.cc
+++ b/aos/configuration.cc
@@ -23,11 +23,18 @@
#include "aos/configuration_generated.h"
#include "aos/flatbuffer_merge.h"
+#include "aos/ipc_lib/index.h"
#include "aos/json_to_flatbuffer.h"
#include "aos/network/team_number.h"
#include "aos/unique_malloc_ptr.h"
#include "aos/util/file.h"
+DEFINE_uint32(max_queue_size_override, 0,
+ "If nonzero, this is the max number of elements in a queue to "
+ "enforce. If zero, use the number that the processor that this "
+ "application is compiled for can support. This is mostly useful "
+ "for config validation, and shouldn't be touched.");
+
namespace aos {
namespace {
namespace chrono = std::chrono;
@@ -364,10 +371,13 @@
}
CHECK_LT(QueueSize(&config.message(), c) + QueueScratchBufferSize(c),
- std::numeric_limits<uint16_t>::max())
+ FLAGS_max_queue_size_override != 0
+ ? FLAGS_max_queue_size_override
+ : std::numeric_limits<
+ ipc_lib::QueueIndex::PackedIndexType>::max())
<< ": More messages/second configured than the queue can hold on "
<< CleanedChannelToString(c) << ", " << c->frequency() << "hz for "
- << config.message().channel_storage_duration() << "ns";
+ << ChannelStorageDuration(&config.message(), c).count() << "ns";
if (c->has_logger_nodes()) {
// Confirm that we don't have duplicate logger nodes.
@@ -968,7 +978,7 @@
std::map<std::string_view, flatbuffers::Offset<reflection::Schema>>
schema_cache;
- CHECK_EQ(Channel::MiniReflectTypeTable()->num_elems, 13u)
+ CHECK_EQ(Channel::MiniReflectTypeTable()->num_elems, 14u)
<< ": Merging logic needs to be updated when the number of channel "
"fields changes.";
@@ -1065,6 +1075,10 @@
if (c->has_num_readers()) {
channel_builder.add_num_readers(c->num_readers());
}
+ if (c->has_channel_storage_duration()) {
+ channel_builder.add_channel_storage_duration(
+ c->channel_storage_duration());
+ }
channel_offsets.emplace_back(channel_builder.Finish());
}
channels_offset = fbb.CreateVector(channel_offsets);
@@ -1609,12 +1623,22 @@
return result;
}
-int QueueSize(const Configuration *config, const Channel *channel) {
- return QueueSize(channel->frequency(),
- chrono::nanoseconds(config->channel_storage_duration()));
+chrono::nanoseconds ChannelStorageDuration(const Configuration *config,
+ const Channel *channel) {
+ CHECK(channel != nullptr);
+ if (channel->has_channel_storage_duration()) {
+ return chrono::nanoseconds(channel->channel_storage_duration());
+ }
+ return chrono::nanoseconds(config->channel_storage_duration());
}
-int QueueSize(size_t frequency, chrono::nanoseconds channel_storage_duration) {
+size_t QueueSize(const Configuration *config, const Channel *channel) {
+ return QueueSize(channel->frequency(),
+ ChannelStorageDuration(config, channel));
+}
+
+size_t QueueSize(size_t frequency,
+ chrono::nanoseconds channel_storage_duration) {
// Use integer arithmetic and round up at all cost.
return static_cast<int>(
(999999999 + static_cast<int64_t>(frequency) *
diff --git a/aos/configuration.fbs b/aos/configuration.fbs
index 5283d9e..0b42f5b 100644
--- a/aos/configuration.fbs
+++ b/aos/configuration.fbs
@@ -105,6 +105,10 @@
//
// Currently, this must be set if and only if read_method is PIN.
num_readers:int (id: 12);
+
+ // Storage duration for this channel, in nanoseconds. For just this channel,
+ // this overrides channel_storage_duration below in Configuration.
+ channel_storage_duration:long = 2000000000 (id: 13);
}
// Table to support renaming channel names.
diff --git a/aos/configuration.h b/aos/configuration.h
index 06f4db4..e35c0bf 100644
--- a/aos/configuration.h
+++ b/aos/configuration.h
@@ -209,9 +209,9 @@
const Application *application);
// Returns the number of messages in the queue.
-int QueueSize(const Configuration *config, const Channel *channel);
-int QueueSize(size_t frequency,
- std::chrono::nanoseconds channel_storage_duration);
+size_t QueueSize(const Configuration *config, const Channel *channel);
+size_t QueueSize(size_t frequency,
+ std::chrono::nanoseconds channel_storage_duration);
// Returns the number of scratch buffers in the queue.
int QueueScratchBufferSize(const Channel *channel);
@@ -231,6 +231,10 @@
GetSchemaDetachedBuffer(const Configuration *config,
std::string_view schema_type);
+// Returns the channel's own storage duration if set, else the config-wide one.
+std::chrono::nanoseconds ChannelStorageDuration(const Configuration *config,
+ const Channel *channel);
+
// Adds the specified channel to the config and returns the new, merged, config.
// The channel name is derived from the specified name, the type and schema from
// the provided schema, the source node from the specified node, and all other
diff --git a/aos/events/BUILD b/aos/events/BUILD
index 08e3998..a65c71a 100644
--- a/aos/events/BUILD
+++ b/aos/events/BUILD
@@ -96,6 +96,20 @@
)
cc_library(
+ name = "context",
+ hdrs = [
+ "context.h",
+ ],
+ target_compatible_with = ["@platforms//os:linux"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//aos:flatbuffers",
+ "//aos:uuid",
+ "//aos/time",
+ ],
+)
+
+cc_library(
name = "event_loop",
srcs = [
"event_loop.cc",
@@ -109,6 +123,7 @@
target_compatible_with = ["@platforms//os:linux"],
visibility = ["//visibility:public"],
deps = [
+ ":context",
":event_loop_fbs",
":timing_statistics",
"//aos:configuration",
diff --git a/aos/events/context.h b/aos/events/context.h
new file mode 100644
index 0000000..6648949
--- /dev/null
+++ b/aos/events/context.h
@@ -0,0 +1,70 @@
+#ifndef AOS_EVENTS_CONTEXT_H_
+#define AOS_EVENTS_CONTEXT_H_
+
+#include "aos/flatbuffers.h"
+#include "aos/time/time.h"
+#include "aos/uuid.h"
+
+namespace aos {
+
+// Struct available on Watchers, Fetchers, Timers, and PhasedLoops with context
+// about the current message.
+struct Context {
+ // Time that the message was sent on this node, or the timer was triggered.
+ monotonic_clock::time_point monotonic_event_time;
+ // Realtime the message was sent on this node. This is set to min_time for
+ // Timers and PhasedLoops.
+ realtime_clock::time_point realtime_event_time;
+
+ // The rest are only valid for Watchers and Fetchers.
+
+ // For a single-node configuration, these two are identical to *_event_time.
+ // In a multinode configuration, these are the times that the message was
+ // sent on the original node.
+ monotonic_clock::time_point monotonic_remote_time;
+ realtime_clock::time_point realtime_remote_time;
+
+ // Index in the queue.
+ uint32_t queue_index;
+ // Index into the remote queue. Useful to determine if data was lost. In a
+ // single-node configuration, this will match queue_index.
+ uint32_t remote_queue_index;
+
+ // Size of the data sent.
+ size_t size;
+ // Pointer to the data.
+ const void *data;
+
+ // Index of the message buffer. This will be in [0, NumberBuffers) on
+ // read_method=PIN channels, and -1 for other channels.
+ //
+ // This only tells you about the underlying storage for this message, not
+ // anything about its position in the queue. This is only useful for advanced
+ // zero-copy use cases, on read_method=PIN channels.
+ //
+ // This will uniquely identify a message on this channel at a point in time.
+ // For senders, this point in time is while the sender has the message. With
+ // read_method==PIN, this point in time includes while the caller has access
+ // to this context. For other read_methods, this point in time may be before
+ // the caller has access to this context, which makes this pretty useless.
+ int buffer_index;
+
+ // UUID of the remote node which sent this message, or this node in the case
+ // of events which are local to this node.
+ UUID source_boot_uuid = UUID::Zero();
+
+ // Efficiently copies the flatbuffer into a FlatbufferVector, allocating
+ // memory in the process. It is vital that T matches the type of the
+ // underlying flatbuffer.
+ template <typename T>
+ FlatbufferVector<T> CopyFlatBuffer() const {
+ ResizeableBuffer buffer;
+ buffer.resize(size);
+ memcpy(buffer.data(), data, size);
+ return FlatbufferVector<T>(std::move(buffer));
+ }
+};
+
+} // namespace aos
+
+#endif // AOS_EVENTS_CONTEXT_H_
diff --git a/aos/events/event_loop.cc b/aos/events/event_loop.cc
index e84763e..072630f 100644
--- a/aos/events/event_loop.cc
+++ b/aos/events/event_loop.cc
@@ -298,7 +298,7 @@
// of the buffer. We only have to care because we are using this in a very
// raw fashion.
CHECK_LE(timing_report_.span().size(), timing_report_sender_->size())
- << ": Timing report bigger than the sender size.";
+ << ": Timing report bigger than the sender size for " << name() << ".";
std::copy(timing_report_.span().data(),
timing_report_.span().data() + timing_report_.span().size(),
reinterpret_cast<uint8_t *>(timing_report_sender_->data()) +
diff --git a/aos/events/event_loop.h b/aos/events/event_loop.h
index e53bc21..17e3e00 100644
--- a/aos/events/event_loop.h
+++ b/aos/events/event_loop.h
@@ -15,6 +15,7 @@
#include "aos/configuration.h"
#include "aos/configuration_generated.h"
#include "aos/events/channel_preallocated_allocator.h"
+#include "aos/events/context.h"
#include "aos/events/event_loop_event.h"
#include "aos/events/event_loop_generated.h"
#include "aos/events/timing_statistics.h"
@@ -34,64 +35,6 @@
class EventLoop;
class WatcherState;
-// Struct available on Watchers, Fetchers, Timers, and PhasedLoops with context
-// about the current message.
-struct Context {
- // Time that the message was sent on this node, or the timer was triggered.
- monotonic_clock::time_point monotonic_event_time;
- // Realtime the message was sent on this node. This is set to min_time for
- // Timers and PhasedLoops.
- realtime_clock::time_point realtime_event_time;
-
- // The rest are only valid for Watchers and Fetchers.
-
- // For a single-node configuration, these two are identical to *_event_time.
- // In a multinode configuration, these are the times that the message was
- // sent on the original node.
- monotonic_clock::time_point monotonic_remote_time;
- realtime_clock::time_point realtime_remote_time;
-
- // Index in the queue.
- uint32_t queue_index;
- // Index into the remote queue. Useful to determine if data was lost. In a
- // single-node configuration, this will match queue_index.
- uint32_t remote_queue_index;
-
- // Size of the data sent.
- size_t size;
- // Pointer to the data.
- const void *data;
-
- // Index of the message buffer. This will be in [0, NumberBuffers) on
- // read_method=PIN channels, and -1 for other channels.
- //
- // This only tells you about the underlying storage for this message, not
- // anything about its position in the queue. This is only useful for advanced
- // zero-copy use cases, on read_method=PIN channels.
- //
- // This will uniquely identify a message on this channel at a point in time.
- // For senders, this point in time is while the sender has the message. With
- // read_method==PIN, this point in time includes while the caller has access
- // to this context. For other read_methods, this point in time may be before
- // the caller has access to this context, which makes this pretty useless.
- int buffer_index;
-
- // UUID of the remote node which sent this message, or this node in the case
- // of events which are local to this node.
- UUID source_boot_uuid = UUID::Zero();
-
- // Efficiently copies the flatbuffer into a FlatbufferVector, allocating
- // memory in the process. It is vital that T matches the type of the
- // underlying flatbuffer.
- template <typename T>
- FlatbufferVector<T> CopyFlatBuffer() const {
- ResizeableBuffer buffer;
- buffer.resize(size);
- memcpy(buffer.data(), data, size);
- return FlatbufferVector<T>(std::move(buffer));
- }
-};
-
// Raw version of fetcher. Contains a local variable that the fetcher will
// update. This is used for reflection and as an interface to implement typed
// fetchers.
diff --git a/aos/events/event_loop_param_test.cc b/aos/events/event_loop_param_test.cc
index 7e79a8b..0559aae 100644
--- a/aos/events/event_loop_param_test.cc
+++ b/aos/events/event_loop_param_test.cc
@@ -2241,7 +2241,9 @@
// Sanity check channel frequencies to ensure that we've designed the test
// correctly.
ASSERT_EQ(800, sender.channel()->frequency());
- ASSERT_EQ(2000000000, loop1->configuration()->channel_storage_duration());
+ ASSERT_EQ(2000000000, configuration::ChannelStorageDuration(
+ loop1->configuration(), sender.channel())
+ .count());
constexpr int kMaxAllowedMessages = 800 * 2;
constexpr int kSendMessages = kMaxAllowedMessages * 2;
constexpr int kDroppedMessages = kSendMessages - kMaxAllowedMessages;
@@ -3195,15 +3197,8 @@
}
int TestChannelQueueSize(EventLoop *event_loop) {
- const int frequency = TestChannelFrequency(event_loop);
- const auto channel_storage_duration = std::chrono::nanoseconds(
- event_loop->configuration()->channel_storage_duration());
- const int queue_size =
- frequency * std::chrono::duration_cast<std::chrono::duration<double>>(
- channel_storage_duration)
- .count();
-
- return queue_size;
+ return configuration::QueueSize(event_loop->configuration(),
+ event_loop->GetChannel<TestMessage>("/test"));
}
RawSender::Error SendTestMessage(aos::Sender<TestMessage> &sender) {
@@ -3244,10 +3239,9 @@
});
const auto kRepeatOffset = std::chrono::milliseconds(1);
- const auto base_offset =
- std::chrono::nanoseconds(
- event_loop->configuration()->channel_storage_duration()) -
- (kRepeatOffset * (queue_size / 2));
+ const auto base_offset = configuration::ChannelStorageDuration(
+ event_loop->configuration(), sender.channel()) -
+ (kRepeatOffset * (queue_size / 2));
event_loop->OnRun([&event_loop, &timer, &base_offset, &kRepeatOffset]() {
timer->Schedule(event_loop->monotonic_now() + base_offset, kRepeatOffset);
});
@@ -3271,8 +3265,8 @@
const std::chrono::milliseconds kInterval = std::chrono::milliseconds(10);
const monotonic_clock::duration channel_storage_duration =
- std::chrono::nanoseconds(
- event_loop->configuration()->channel_storage_duration());
+ configuration::ChannelStorageDuration(event_loop->configuration(),
+ sender.channel());
const int queue_size = TestChannelQueueSize(event_loop.get());
int msgs_sent = 0;
@@ -3338,5 +3332,24 @@
EXPECT_EQ(SendTestMessage(sender2), RawSender::Error::kMessagesSentTooFast);
}
+// Tests that a longer storage duration stores more messages.
+TEST_P(AbstractEventLoopTest, SendingTooFastWithLongDuration) {
+ auto loop1 = MakePrimary();
+
+ auto sender1 = loop1->MakeSender<TestMessage>("/test3");
+
+ // Send queue_size messages from the single sender.
+ const int queue_size =
+ configuration::QueueSize(loop1->configuration(), sender1.channel());
+ EXPECT_EQ(queue_size, 100 * 10);
+ for (int i = 0; i < queue_size; i++) {
+ ASSERT_EQ(SendTestMessage(sender1), RawSender::Error::kOk);
+ }
+
+ // Since queue_size messages have been sent, and little time has elapsed,
+ // this should return an error.
+ EXPECT_EQ(SendTestMessage(sender1), RawSender::Error::kMessagesSentTooFast);
+}
+
} // namespace testing
} // namespace aos
diff --git a/aos/events/event_loop_param_test.h b/aos/events/event_loop_param_test.h
index fe85732..d466a1e 100644
--- a/aos/events/event_loop_param_test.h
+++ b/aos/events/event_loop_param_test.h
@@ -47,6 +47,11 @@
{
"name": "/test2",
"type": "aos.TestMessage"
+ },
+ {
+ "name": "/test3",
+ "type": "aos.TestMessage",
+ "channel_storage_duration": 10000000000
}
]
})config",
@@ -97,6 +102,11 @@
{
"name": "/test2",
"type": "aos.TestMessage"
+ },
+ {
+ "name": "/test3",
+ "type": "aos.TestMessage",
+ "channel_storage_duration": 10000000000
}
]
})config",
@@ -139,6 +149,13 @@
"type": "aos.TestMessage",
"read_method": "PIN",
"num_readers": 10
+ },
+ {
+ "name": "/test3",
+ "type": "aos.TestMessage",
+ "read_method": "PIN",
+ "num_readers": 10,
+ "channel_storage_duration": 10000000000
}
]
})config";
diff --git a/aos/events/logging/BUILD b/aos/events/logging/BUILD
index 4e58923..2516e6c 100644
--- a/aos/events/logging/BUILD
+++ b/aos/events/logging/BUILD
@@ -88,7 +88,7 @@
srcs = ["s3_file_operations.cc"],
hdrs = ["s3_file_operations.h"],
deps = [
- ":file_operations",
+ ":log_backend",
":s3_fetcher",
],
)
@@ -390,6 +390,7 @@
":logger_fbs",
"//aos:uuid",
"@com_github_google_flatbuffers//:flatbuffers",
+ "@com_github_google_glog//:glog",
],
)
diff --git a/aos/events/logging/file_operations.cc b/aos/events/logging/file_operations.cc
index 75910e5..04e695e 100644
--- a/aos/events/logging/file_operations.cc
+++ b/aos/events/logging/file_operations.cc
@@ -11,13 +11,16 @@
absl::EndsWith(filename, ".bfbs.sz");
}
-void LocalFileOperations::FindLogs(std::vector<std::string> *files) {
- auto MaybeAddFile = [&files](std::string_view filename) {
+void LocalFileOperations::FindLogs(std::vector<File> *files) {
+ auto MaybeAddFile = [&files](std::string_view filename, size_t size) {
if (!IsValidFilename(filename)) {
VLOG(1) << "Ignoring " << filename << " with invalid extension.";
} else {
VLOG(1) << "Found log " << filename;
- files->emplace_back(filename);
+ files->emplace_back(File{
+ .name = std::string(filename),
+ .size = size,
+ });
}
};
if (std::filesystem::is_directory(filename_)) {
@@ -28,10 +31,10 @@
VLOG(1) << file << " is not file.";
continue;
}
- MaybeAddFile(file.path().string());
+ MaybeAddFile(file.path().string(), file.file_size());
}
} else {
- MaybeAddFile(filename_);
+ MaybeAddFile(filename_, std::filesystem::file_size(filename_));
}
}
diff --git a/aos/events/logging/file_operations.h b/aos/events/logging/file_operations.h
index faf63cb..538ae60 100644
--- a/aos/events/logging/file_operations.h
+++ b/aos/events/logging/file_operations.h
@@ -14,10 +14,15 @@
// associated with either a single file or directory that contains log files.
class FileOperations {
public:
+ struct File {
+ std::string name;
+ size_t size;
+ };
+
virtual ~FileOperations() = default;
virtual bool Exists() = 0;
- virtual void FindLogs(std::vector<std::string> *files) = 0;
+ virtual void FindLogs(std::vector<File> *files) = 0;
};
// Implements FileOperations with standard POSIX filesystem APIs. These work on
@@ -29,7 +34,7 @@
bool Exists() override { return std::filesystem::exists(filename_); }
- void FindLogs(std::vector<std::string> *files) override;
+ void FindLogs(std::vector<File> *files) override;
private:
std::string filename_;
diff --git a/aos/events/logging/log_backend.cc b/aos/events/logging/log_backend.cc
index 6cd9a7d..f8a6846 100644
--- a/aos/events/logging/log_backend.cc
+++ b/aos/events/logging/log_backend.cc
@@ -334,32 +334,35 @@
base_name_(base_name),
separator_(base_name_.back() == '/' ? "" : "_") {}
-std::unique_ptr<LogSink> FileBackend::RequestFile(std::string_view id) {
+std::unique_ptr<LogSink> FileBackend::RequestFile(const std::string_view id) {
const std::string filename = absl::StrCat(base_name_, separator_, id);
return std::make_unique<FileHandler>(filename, supports_odirect_);
}
-std::vector<std::string> FileBackend::ListFiles() const {
+std::vector<FileBackend::File> FileBackend::ListFiles() const {
std::filesystem::path directory(base_name_);
if (!is_directory(directory)) {
directory = directory.parent_path();
}
internal::LocalFileOperations operations(directory.string());
- std::vector<std::string> files;
+ std::vector<internal::FileOperations::File> files;
operations.FindLogs(&files);
- std::vector<std::string> names;
+ std::vector<File> names;
const std::string prefix = absl::StrCat(base_name_, separator_);
for (const auto &file : files) {
- CHECK(absl::StartsWith(file, prefix))
- << ": File " << file << ", prefix " << prefix;
- names.push_back(file.substr(prefix.size()));
+ CHECK(absl::StartsWith(file.name, prefix))
+ << ": File " << file.name << ", prefix " << prefix;
+ names.emplace_back(File{
+ .name = file.name.substr(prefix.size()),
+ .size = file.size,
+ });
}
return names;
}
std::unique_ptr<DataDecoder> FileBackend::GetDecoder(
- std::string_view id) const {
+ const std::string_view id) const {
const std::string filename = absl::StrCat(base_name_, separator_, id);
CHECK(std::filesystem::exists(filename));
return std::make_unique<DummyDecoder>(filename);
@@ -372,7 +375,7 @@
separator_(base_name_.back() == '/' ? "" : "_") {}
std::unique_ptr<LogSink> RenamableFileBackend::RequestFile(
- std::string_view id) {
+ const std::string_view id) {
const std::string filename =
absl::StrCat(base_name_, separator_, id, temp_suffix_);
return std::make_unique<RenamableFileHandler>(this, filename,
diff --git a/aos/events/logging/log_backend.h b/aos/events/logging/log_backend.h
index e75b632..3a47f42 100644
--- a/aos/events/logging/log_backend.h
+++ b/aos/events/logging/log_backend.h
@@ -238,6 +238,27 @@
int flags_ = 0;
};
+// Interface to decouple reading of logs and media (file system, memory or S3).
+class LogSource {
+ public:
+ struct File {
+ std::string name;
+ size_t size;
+ };
+
+ virtual ~LogSource() = default;
+
+ // Provides a list of readable sources for log reading.
+ virtual std::vector<File> ListFiles() const = 0;
+
+ // Entry point for reading the content of a log file.
+ virtual std::unique_ptr<DataDecoder> GetDecoder(
+ const std::string_view id) const = 0;
+ std::unique_ptr<DataDecoder> GetDecoder(const LogSource::File &id) const {
+ return GetDecoder(id.name);
+ }
+};
+
// Interface to decouple log writing and media (file system or memory). It is
// handy to use for tests.
class LogBackend {
@@ -247,20 +268,11 @@
// Request file-like object from the log backend. It maybe a file on a disk or
// in memory. id is usually generated by log namer and looks like name of the
// file within a log folder.
- virtual std::unique_ptr<LogSink> RequestFile(std::string_view id) = 0;
-};
+ virtual std::unique_ptr<LogSink> RequestFile(const std::string_view id) = 0;
-// Interface to decouple reading of logs and media (file system, memory or S3).
-class LogSource {
- public:
- virtual ~LogSource() = default;
-
- // Provides a list of readable sources for log reading.
- virtual std::vector<std::string> ListFiles() const = 0;
-
- // Entry point for reading the content of log file.
- virtual std::unique_ptr<DataDecoder> GetDecoder(
- std::string_view id) const = 0;
+ std::unique_ptr<LogSink> RequestFile(const LogSource::File &id) {
+ return RequestFile(id.name);
+ }
};
// Implements requests log files from file system.
@@ -271,13 +283,14 @@
~FileBackend() override = default;
// Request file from a file system. It is not open yet.
- std::unique_ptr<LogSink> RequestFile(std::string_view id) override;
+ std::unique_ptr<LogSink> RequestFile(const std::string_view id) override;
// List all files that looks like log files under base_name.
- std::vector<std::string> ListFiles() const override;
+ std::vector<File> ListFiles() const override;
// Open decoder to read the content of the file.
- std::unique_ptr<DataDecoder> GetDecoder(std::string_view id) const override;
+ std::unique_ptr<DataDecoder> GetDecoder(
+ const std::string_view id) const override;
private:
const bool supports_odirect_;
@@ -309,7 +322,7 @@
~RenamableFileBackend() = default;
// Request file from a file system. It is not open yet.
- std::unique_ptr<LogSink> RequestFile(std::string_view id) override;
+ std::unique_ptr<LogSink> RequestFile(const std::string_view id) override;
// TODO (Alexei): it is called by Logger, and left here for compatibility.
// Logger should not call it.
diff --git a/aos/events/logging/log_backend_test.cc b/aos/events/logging/log_backend_test.cc
index 08b6f1e..d3c83cc 100644
--- a/aos/events/logging/log_backend_test.cc
+++ b/aos/events/logging/log_backend_test.cc
@@ -25,6 +25,8 @@
}
} // namespace
+MATCHER_P(FileEq, o, "") { return arg.name == o.name && arg.size == o.size; }
+
TEST(LogBackendTest, CreateSimpleFile) {
const std::string logevent = aos::testing::TestTmpDir() + "/logevent/";
const std::string filename = "test.bfbs";
@@ -37,7 +39,11 @@
EXPECT_EQ(file->Close(), WriteCode::kOk);
EXPECT_TRUE(std::filesystem::exists(logevent + filename));
- EXPECT_THAT(backend.ListFiles(), ::testing::ElementsAre(filename));
+ EXPECT_THAT(backend.ListFiles(),
+ ::testing::ElementsAre(FileEq(LogSource::File{
+ .name = filename,
+ .size = 4,
+ })));
auto decoder = backend.GetDecoder(filename);
std::vector<uint8_t> buffer;
diff --git a/aos/events/logging/log_config_extractor.cc b/aos/events/logging/log_config_extractor.cc
index da0c65f..f2df5b3 100644
--- a/aos/events/logging/log_config_extractor.cc
+++ b/aos/events/logging/log_config_extractor.cc
@@ -98,11 +98,8 @@
WriteFlatbufferToJson(output_path + ".json", config);
LOG(INFO) << "Done writing json to " << output_path << ".json";
} else {
- const std::vector<std::string> unsorted_logfiles =
- aos::logger::FindLogs(argc, argv);
-
const std::vector<aos::logger::LogFile> logfiles =
- aos::logger::SortParts(unsorted_logfiles);
+ aos::logger::SortParts(aos::logger::FindLogs(argc, argv));
WriteConfig(logfiles[0].config.get(), output_path);
}
diff --git a/aos/events/logging/log_namer.cc b/aos/events/logging/log_namer.cc
index 0e914d9..28223b2 100644
--- a/aos/events/logging/log_namer.cc
+++ b/aos/events/logging/log_namer.cc
@@ -730,7 +730,7 @@
if (!data_writer.second.writer) continue;
data_writer.second.writer->WriteStatistics()->ResetStats();
}
- if (data_writer_) {
+ if (data_writer_ != nullptr && data_writer_->writer != nullptr) {
data_writer_->writer->WriteStatistics()->ResetStats();
}
max_write_time_ = std::chrono::nanoseconds::zero();
diff --git a/aos/events/logging/log_namer.h b/aos/events/logging/log_namer.h
index 8c62559..8d69fbb 100644
--- a/aos/events/logging/log_namer.h
+++ b/aos/events/logging/log_namer.h
@@ -9,6 +9,7 @@
#include "absl/container/btree_map.h"
#include "flatbuffers/flatbuffers.h"
+#include "glog/logging.h"
#include "aos/events/logging/logfile_utils.h"
#include "aos/events/logging/logger_generated.h"
@@ -346,6 +347,7 @@
bool ran_out_of_space() const {
return accumulate_data_writers<bool>(
ran_out_of_space_, [](bool x, const NewDataWriter &data_writer) {
+ CHECK_NOTNULL(data_writer.writer);
return x ||
(data_writer.writer && data_writer.writer->ran_out_of_space());
});
@@ -358,6 +360,7 @@
size_t maximum_total_bytes() const {
return accumulate_data_writers<size_t>(
0, [](size_t x, const NewDataWriter &data_writer) {
+ CHECK_NOTNULL(data_writer.writer);
return std::max(x, data_writer.writer->total_bytes());
});
}
@@ -374,6 +377,7 @@
return accumulate_data_writers(
max_write_time_,
[](std::chrono::nanoseconds x, const NewDataWriter &data_writer) {
+ CHECK_NOTNULL(data_writer.writer);
return std::max(
x, data_writer.writer->WriteStatistics()->max_write_time());
});
@@ -383,6 +387,7 @@
std::make_tuple(max_write_time_bytes_, max_write_time_),
[](std::tuple<int, std::chrono::nanoseconds> x,
const NewDataWriter &data_writer) {
+ CHECK_NOTNULL(data_writer.writer);
if (data_writer.writer->WriteStatistics()->max_write_time() >
std::get<1>(x)) {
return std::make_tuple(
@@ -397,6 +402,7 @@
std::make_tuple(max_write_time_messages_, max_write_time_),
[](std::tuple<int, std::chrono::nanoseconds> x,
const NewDataWriter &data_writer) {
+ CHECK_NOTNULL(data_writer.writer);
if (data_writer.writer->WriteStatistics()->max_write_time() >
std::get<1>(x)) {
return std::make_tuple(
@@ -411,12 +417,14 @@
return accumulate_data_writers(
total_write_time_,
[](std::chrono::nanoseconds x, const NewDataWriter &data_writer) {
+ CHECK_NOTNULL(data_writer.writer);
return x + data_writer.writer->WriteStatistics()->total_write_time();
});
}
int total_write_count() const {
return accumulate_data_writers(
total_write_count_, [](int x, const NewDataWriter &data_writer) {
+ CHECK_NOTNULL(data_writer.writer);
return x + data_writer.writer->WriteStatistics()->total_write_count();
});
}
@@ -430,6 +438,7 @@
int total_write_bytes() const {
return accumulate_data_writers(
total_write_bytes_, [](int x, const NewDataWriter &data_writer) {
+ CHECK_NOTNULL(data_writer.writer);
return x + data_writer.writer->WriteStatistics()->total_write_bytes();
});
}
@@ -463,10 +472,11 @@
T accumulate_data_writers(T t, BinaryOperation op) const {
for (const std::pair<const Channel *const, NewDataWriter> &data_writer :
data_writers_) {
- if (!data_writer.second.writer) continue;
- t = op(std::move(t), data_writer.second);
+ if (data_writer.second.writer != nullptr) {
+ t = op(std::move(t), data_writer.second);
+ }
}
- if (data_writer_) {
+ if (data_writer_ != nullptr && data_writer_->writer != nullptr) {
t = op(std::move(t), *data_writer_);
}
return t;
diff --git a/aos/events/logging/log_reader.cc b/aos/events/logging/log_reader.cc
index 874fe43..e3dd904 100644
--- a/aos/events/logging/log_reader.cc
+++ b/aos/events/logging/log_reader.cc
@@ -98,6 +98,10 @@
std::string_view new_name,
std::string_view new_type,
flatbuffers::FlatBufferBuilder *fbb) {
+ CHECK_EQ(Channel::MiniReflectTypeTable()->num_elems, 14u)
+ << ": Merging logic needs to be updated when the number of channel "
+ "fields changes.";
+
flatbuffers::Offset<flatbuffers::String> name_offset =
fbb->CreateSharedString(new_name.empty() ? c->name()->string_view()
: new_name);
@@ -148,6 +152,9 @@
if (c->has_num_readers()) {
channel_builder.add_num_readers(c->num_readers());
}
+ if (c->has_channel_storage_duration()) {
+ channel_builder.add_channel_storage_duration(c->channel_storage_duration());
+ }
return channel_builder.Finish();
}
@@ -1475,7 +1482,7 @@
fbb.ForceDefaults(true);
std::vector<flatbuffers::Offset<Channel>> channel_offsets;
- CHECK_EQ(Channel::MiniReflectTypeTable()->num_elems, 13u)
+ CHECK_EQ(Channel::MiniReflectTypeTable()->num_elems, 14u)
<< ": Merging logic needs to be updated when the number of channel "
"fields changes.";
@@ -1547,6 +1554,10 @@
if (c->has_frequency()) {
channel_builder.add_frequency(c->frequency());
}
+ if (c->has_channel_storage_duration()) {
+ channel_builder.add_channel_storage_duration(
+ c->channel_storage_duration());
+ }
channel_offsets.emplace_back(channel_builder.Finish());
}
break;
diff --git a/aos/events/logging/log_reader_utils_test.cc b/aos/events/logging/log_reader_utils_test.cc
index 911c4eb..b61c9de 100644
--- a/aos/events/logging/log_reader_utils_test.cc
+++ b/aos/events/logging/log_reader_utils_test.cc
@@ -134,10 +134,10 @@
util::WriteStringToFileOrDie(log_file, "test");
internal::LocalFileOperations file_op(log_file);
EXPECT_TRUE(file_op.Exists());
- std::vector<std::string> logs;
+ std::vector<internal::LocalFileOperations::File> logs;
file_op.FindLogs(&logs);
ASSERT_EQ(logs.size(), 1);
- EXPECT_EQ(logs.front(), log_file);
+ EXPECT_EQ(logs.front().name, log_file);
}
// Verify that it is OK to list folder with log file.
@@ -148,10 +148,10 @@
util::WriteStringToFileOrDie(log_file, "test");
internal::LocalFileOperations file_op(log_folder);
EXPECT_TRUE(file_op.Exists());
- std::vector<std::string> logs;
+ std::vector<internal::LocalFileOperations::File> logs;
file_op.FindLogs(&logs);
ASSERT_EQ(logs.size(), 1);
- EXPECT_EQ(logs.front(), log_file);
+ EXPECT_EQ(logs.front().name, log_file);
}
} // namespace aos::logger::testing
diff --git a/aos/events/logging/log_replayer.cc b/aos/events/logging/log_replayer.cc
index 7231a91..b69df4a 100644
--- a/aos/events/logging/log_replayer.cc
+++ b/aos/events/logging/log_replayer.cc
@@ -52,11 +52,8 @@
namespace aos::logger {
int Main(int argc, char *argv[]) {
- const std::vector<std::string> unsorted_logfiles =
- aos::logger::FindLogs(argc, argv);
-
const std::vector<aos::logger::LogFile> logfiles =
- aos::logger::SortParts(unsorted_logfiles);
+ aos::logger::SortParts(aos::logger::FindLogs(argc, argv));
aos::logger::LogReader config_reader(logfiles);
aos::FlatbufferDetachedBuffer<aos::Configuration> config =
diff --git a/aos/events/logging/log_stats.cc b/aos/events/logging/log_stats.cc
index ea265b5..bbed18f 100644
--- a/aos/events/logging/log_stats.cc
+++ b/aos/events/logging/log_stats.cc
@@ -146,7 +146,7 @@
current_message_time_ = context.monotonic_event_time;
channel_storage_duration_messages_.push(current_message_time_);
while (channel_storage_duration_messages_.front() +
- std::chrono::nanoseconds(config_->channel_storage_duration()) <=
+ aos::configuration::ChannelStorageDuration(config_, channel_) <=
current_message_time_) {
channel_storage_duration_messages_.pop();
}
@@ -188,7 +188,10 @@
double max_messages_per_sec() const {
return max_messages_per_period_ /
std::min(SecondsActive(),
- 1e-9 * config_->channel_storage_duration());
+ std::chrono::duration<double>(
+ aos::configuration::ChannelStorageDuration(config_,
+ channel_))
+ .count());
}
size_t avg_message_size() const {
return total_message_size_ / total_num_messages_;
@@ -284,16 +287,8 @@
LOG(FATAL) << "Expected at least 1 logfile as an argument.";
}
- // find logfiles
- std::vector<std::string> unsorted_logfiles =
- aos::logger::FindLogs(argc, argv);
-
- // sort logfiles
- const std::vector<aos::logger::LogFile> logfiles =
- aos::logger::SortParts(unsorted_logfiles);
-
- // open logfiles
- aos::logger::LogReader reader(logfiles);
+ aos::logger::LogReader reader(
+ aos::logger::SortParts(aos::logger::FindLogs(argc, argv)));
LogfileStats logfile_stats;
std::vector<ChannelStats> channel_stats;
diff --git a/aos/events/logging/log_writer.cc b/aos/events/logging/log_writer.cc
index d129073..db6bcc1 100644
--- a/aos/events/logging/log_writer.cc
+++ b/aos/events/logging/log_writer.cc
@@ -231,18 +231,23 @@
}
}
+aos::SizePrefixedFlatbufferDetachedBuffer<LogFileHeader> PackConfiguration(
+ const Configuration *const configuration) {
+ flatbuffers::FlatBufferBuilder fbb;
+ flatbuffers::Offset<aos::Configuration> configuration_offset =
+ CopyFlatBuffer(configuration, &fbb);
+ LogFileHeader::Builder log_file_header_builder(fbb);
+ log_file_header_builder.add_configuration(configuration_offset);
+ fbb.FinishSizePrefixed(log_file_header_builder.Finish());
+ return fbb.Release();
+}
+
std::string Logger::WriteConfiguration(LogNamer *log_namer) {
std::string config_sha256;
if (separate_config_) {
- flatbuffers::FlatBufferBuilder fbb;
- flatbuffers::Offset<aos::Configuration> configuration_offset =
- CopyFlatBuffer(configuration_, &fbb);
- LogFileHeader::Builder log_file_header_builder(fbb);
- log_file_header_builder.add_configuration(configuration_offset);
- fbb.FinishSizePrefixed(log_file_header_builder.Finish());
- aos::SizePrefixedFlatbufferDetachedBuffer<LogFileHeader> config_header(
- fbb.Release());
+ aos::SizePrefixedFlatbufferDetachedBuffer<LogFileHeader> config_header =
+ PackConfiguration(configuration_);
config_sha256 = Sha256(config_header.span());
LOG(INFO) << "Config sha256 of " << config_sha256;
log_namer->WriteConfiguration(&config_header, config_sha256);
diff --git a/aos/events/logging/log_writer.h b/aos/events/logging/log_writer.h
index 114e874..837a02a 100644
--- a/aos/events/logging/log_writer.h
+++ b/aos/events/logging/log_writer.h
@@ -20,6 +20,11 @@
namespace aos {
namespace logger {
+// Packs the provided configuration into the separate config LogFileHeader
+// container.
+aos::SizePrefixedFlatbufferDetachedBuffer<LogFileHeader> PackConfiguration(
+ const Configuration *const configuration);
+
// Logs all channels available in the event loop to disk every 100 ms.
// Start by logging one message per channel to capture any state and
// configuration that is sent rately on a channel and would affect execution.
diff --git a/aos/events/logging/logfile_sorting.cc b/aos/events/logging/logfile_sorting.cc
index a343356..59d18d0 100644
--- a/aos/events/logging/logfile_sorting.cc
+++ b/aos/events/logging/logfile_sorting.cc
@@ -110,18 +110,19 @@
} // namespace
-void FindLogs(std::vector<std::string> *files, std::string filename) {
+void FindLogs(std::vector<internal::FileOperations::File> *files,
+ std::string filename) {
MakeFileOperations(filename)->FindLogs(files);
}
-std::vector<std::string> FindLogs(std::string filename) {
- std::vector<std::string> files;
+std::vector<internal::FileOperations::File> FindLogs(std::string filename) {
+ std::vector<internal::FileOperations::File> files;
FindLogs(&files, filename);
return files;
}
-std::vector<std::string> FindLogs(int argc, char **argv) {
- std::vector<std::string> found_logfiles;
+std::vector<internal::FileOperations::File> FindLogs(int argc, char **argv) {
+ std::vector<internal::FileOperations::File> found_logfiles;
for (int i = 1; i < argc; i++) {
std::string filename = argv[i];
@@ -320,8 +321,9 @@
std::vector<UnsortedOldParts> old_parts;
// Populates the class's datastructures from the input file list.
- void PopulateFromFiles(ReadersPool *readers,
- const std::vector<std::string> &parts);
+ void PopulateFromFiles(
+ ReadersPool *readers,
+ const std::vector<internal::FileOperations::File> &parts);
// Wrangles everything into a map of boot uuids -> boot counts.
MapBoots ComputeBootCounts();
@@ -342,18 +344,19 @@
std::vector<LogFile> SortParts();
};
-void PartsSorter::PopulateFromFiles(ReadersPool *readers,
- const std::vector<std::string> &parts) {
+void PartsSorter::PopulateFromFiles(
+ ReadersPool *readers,
+ const std::vector<internal::FileOperations::File> &parts) {
// Now extract everything into our datastructures above for sorting.
- for (const std::string &part : parts) {
- SpanReader *reader = readers->BorrowReader(part);
+ for (const internal::FileOperations::File &part : parts) {
+ SpanReader *reader = readers->BorrowReader(part.name);
std::optional<SizePrefixedFlatbufferVector<LogFileHeader>> log_header =
ReadHeader(reader);
if (!log_header) {
if (!FLAGS_quiet_sorting) {
- LOG(WARNING) << "Skipping " << part << " without a header";
+ LOG(WARNING) << "Skipping " << part.name << " without a header";
}
- corrupted.emplace_back(part);
+ corrupted.emplace_back(part.name);
continue;
}
@@ -433,11 +436,12 @@
continue;
}
- VLOG(1) << "Header " << FlatbufferToJson(log_header.value()) << " " << part;
+ VLOG(1) << "Header " << FlatbufferToJson(log_header.value()) << " "
+ << part.name;
if (configuration_sha256.empty()) {
CHECK(log_header->message().has_configuration())
- << ": Failed to find header on " << part;
+ << ": Failed to find header on " << part.name;
// If we don't have a configuration_sha256, we need to have a
// configuration directly inside the header. This ends up being a bit
// unwieldy to deal with, so let's instead copy the configuration, hash
@@ -464,7 +468,7 @@
configuration_sha256 = std::move(config_copy_sha256);
} else {
CHECK(!log_header->message().has_configuration())
- << ": Found header where one shouldn't be on " << part;
+ << ": Found header where one shouldn't be on " << part.name;
config_sha256_list.emplace(configuration_sha256);
}
@@ -475,12 +479,12 @@
!log_header->message().has_parts_index() &&
!log_header->message().has_node()) {
std::optional<SizePrefixedFlatbufferVector<MessageHeader>> first_message =
- ReadNthMessage(part, 0);
+ ReadNthMessage(part.name, 0);
if (!first_message) {
if (!FLAGS_quiet_sorting) {
- LOG(WARNING) << "Skipping " << part << " without any messages";
+ LOG(WARNING) << "Skipping " << part.name << " without any messages";
}
- corrupted.emplace_back(part);
+ corrupted.emplace_back(part.name);
continue;
}
const monotonic_clock::time_point first_message_time(
@@ -500,11 +504,11 @@
old_parts.back().parts.realtime_start_time = realtime_start_time;
old_parts.back().parts.config_sha256 = configuration_sha256;
old_parts.back().unsorted_parts.emplace_back(
- std::make_pair(first_message_time, part));
+ std::make_pair(first_message_time, part.name));
old_parts.back().name = name;
} else {
result->unsorted_parts.emplace_back(
- std::make_pair(first_message_time, part));
+ std::make_pair(first_message_time, part.name));
CHECK_EQ(result->name, name);
CHECK_EQ(result->parts.config_sha256, configuration_sha256);
}
@@ -890,7 +894,7 @@
CHECK_EQ(it->second.realtime_start_time, realtime_start_time);
}
- it->second.parts.emplace_back(std::make_pair(part, parts_index));
+ it->second.parts.emplace_back(std::make_pair(part.name, parts_index));
}
}
@@ -1985,16 +1989,39 @@
}
std::vector<LogFile> SortParts(const std::vector<std::string> &parts) {
- LogReadersPool readers;
- PartsSorter sorter;
- sorter.PopulateFromFiles(&readers, parts);
- return sorter.SortParts();
+ std::vector<internal::FileOperations::File> full_parts;
+ full_parts.reserve(parts.size());
+ for (const auto &p : parts) {
+ full_parts.emplace_back(internal::FileOperations::File{
+ .name = p,
+ .size = 0u,
+ });
+ }
+ return SortParts(full_parts);
}
std::vector<LogFile> SortParts(const LogSource &log_source) {
LogReadersPool readers(&log_source);
PartsSorter sorter;
- sorter.PopulateFromFiles(&readers, log_source.ListFiles());
+ std::vector<LogSource::File> files = log_source.ListFiles();
+ std::vector<internal::FileOperations::File> arg;
+ arg.reserve(files.size());
+ for (LogSource::File &f : files) {
+ arg.emplace_back(internal::FileOperations::File{
+ .name = std::move(f.name),
+ .size = f.size,
+ });
+ }
+ sorter.PopulateFromFiles(&readers, arg);
+ return sorter.SortParts();
+}
+
+std::vector<LogFile> SortParts(
+ const std::vector<internal::FileOperations::File> &files) {
+ LogReadersPool readers;
+ PartsSorter sorter;
+
+ sorter.PopulateFromFiles(&readers, files);
return sorter.SortParts();
}
diff --git a/aos/events/logging/logfile_sorting.h b/aos/events/logging/logfile_sorting.h
index 7461e89..b7f8297 100644
--- a/aos/events/logging/logfile_sorting.h
+++ b/aos/events/logging/logfile_sorting.h
@@ -10,6 +10,7 @@
#include <vector>
#include "aos/configuration.h"
+#include "aos/events/logging/file_operations.h"
#include "aos/events/logging/log_backend.h"
#include "aos/time/time.h"
#include "aos/uuid.h"
@@ -136,20 +137,23 @@
// Takes a bunch of parts and sorts them based on part_uuid and part_index.
std::vector<LogFile> SortParts(const std::vector<std::string> &parts);
+std::vector<LogFile> SortParts(
+ const std::vector<internal::FileOperations::File> &parts);
// Sort parts of a single log.
std::vector<LogFile> SortParts(const LogSource &log_source);
// Recursively searches the file/folder for .bfbs and .bfbs.xz files and adds
// them to the vector.
-void FindLogs(std::vector<std::string> *files, std::string filename);
+void FindLogs(std::vector<internal::FileOperations::File> *files,
+ std::string filename);
// Recursively searches the file/folder for .bfbs and .bfbs.xz files and returns
// them in a vector.
-std::vector<std::string> FindLogs(std::string filename);
+std::vector<internal::FileOperations::File> FindLogs(std::string filename);
// Recursively searches for logfiles in argv[1] and onward.
-std::vector<std::string> FindLogs(int argc, char **argv);
+std::vector<internal::FileOperations::File> FindLogs(int argc, char **argv);
// Proxy container to bind log parts with log source. It helps with reading logs
// from virtual media such as memory or S3.
diff --git a/aos/events/logging/logfile_utils.cc b/aos/events/logging/logfile_utils.cc
index 36bac42..ec6bff3 100644
--- a/aos/events/logging/logfile_utils.cc
+++ b/aos/events/logging/logfile_utils.cc
@@ -111,6 +111,26 @@
*os << "null";
}
}
+
+// A dummy LogSink implementation that handles the special case when we create
+// a DetachedBufferWriter when there's no space left on the system. The
+// DetachedBufferWriter frequently dereferences log_sink_, so we want a class
+// here that effectively refuses to do anything meaningful.
+class OutOfDiskSpaceLogSink : public LogSink {
+ public:
+ WriteCode OpenForWrite() override { return WriteCode::kOutOfSpace; }
+ WriteCode Close() override { return WriteCode::kOk; }
+ bool is_open() const override { return false; }
+ WriteResult Write(
+ const absl::Span<const absl::Span<const uint8_t>> &) override {
+ return WriteResult{
+ .code = WriteCode::kOutOfSpace,
+ .messages_written = 0,
+ };
+ }
+ std::string_view name() const override { return "<out_of_disk_space>"; }
+};
+
} // namespace
DetachedBufferWriter::DetachedBufferWriter(std::unique_ptr<LogSink> log_sink,
@@ -123,6 +143,10 @@
}
}
+DetachedBufferWriter::DetachedBufferWriter(already_out_of_space_t)
+ : DetachedBufferWriter(std::make_unique<OutOfDiskSpaceLogSink>(), nullptr) {
+}
+
DetachedBufferWriter::~DetachedBufferWriter() {
Close();
if (ran_out_of_space_) {
diff --git a/aos/events/logging/logfile_utils.h b/aos/events/logging/logfile_utils.h
index 9dc7d88..18902d3 100644
--- a/aos/events/logging/logfile_utils.h
+++ b/aos/events/logging/logfile_utils.h
@@ -56,11 +56,11 @@
std::unique_ptr<DataEncoder> encoder);
// Creates a dummy instance which won't even open a file. It will act as if
// opening the file ran out of space immediately.
- DetachedBufferWriter(already_out_of_space_t) : ran_out_of_space_(true) {}
+ DetachedBufferWriter(already_out_of_space_t);
DetachedBufferWriter(DetachedBufferWriter &&other);
DetachedBufferWriter(const DetachedBufferWriter &) = delete;
- ~DetachedBufferWriter();
+ virtual ~DetachedBufferWriter();
DetachedBufferWriter &operator=(DetachedBufferWriter &&other);
DetachedBufferWriter &operator=(const DetachedBufferWriter &) = delete;
diff --git a/aos/events/logging/logger_test.cc b/aos/events/logging/logger_test.cc
index f2cd702..5990a7a 100644
--- a/aos/events/logging/logger_test.cc
+++ b/aos/events/logging/logger_test.cc
@@ -545,10 +545,8 @@
});
constexpr std::chrono::microseconds kSendPeriod{10};
- const int max_legal_messages =
- ping_sender.channel()->frequency() *
- event_loop_factory.configuration()->channel_storage_duration() /
- 1000000000;
+ const int max_legal_messages = configuration::QueueSize(
+ event_loop_factory.configuration(), ping_sender.channel());
ping_spammer_event_loop->OnRun(
[&ping_spammer_event_loop, kSendPeriod, timer_handler]() {
diff --git a/aos/events/logging/multinode_logger_test_lib.h b/aos/events/logging/multinode_logger_test_lib.h
index a2aa4af..e207179 100644
--- a/aos/events/logging/multinode_logger_test_lib.h
+++ b/aos/events/logging/multinode_logger_test_lib.h
@@ -60,13 +60,13 @@
};
constexpr std::string_view kCombinedConfigSha1() {
- return "a72e2a1e21ac07b27648825151ff9b436fd80b62254839d4ac47ee3400fa9dc1";
+ return "d018002a9b780d45a69172a1e5dd1d6df49a7c6c63b9bae9125cdc0458ddc6ca";
}
constexpr std::string_view kSplitConfigSha1() {
- return "6e585268f58791591f48b1e6d00564f49e6dcec46d18c4809ec49d94afbb3b1c";
+ return "562f80087c0e95d9304127c4cb46962659b4bfc11def84253c67702b4213e6cf";
}
constexpr std::string_view kReloggedSplitConfigSha1() {
- return "6aa4cbc21e2382ea8b9ef0145e9031bf542827e29b93995dd5e203ed0c198ef7";
+ return "cb560559ee3111d7c67314e3e1a5fd7fc88e8b4cfd9d15ea71c8d1cae1c0480b";
}
LoggerState MakeLoggerState(NodeEventLoopFactory *node,
diff --git a/aos/events/logging/s3_fetcher.cc b/aos/events/logging/s3_fetcher.cc
index f44a20b..4207286 100644
--- a/aos/events/logging/s3_fetcher.cc
+++ b/aos/events/logging/s3_fetcher.cc
@@ -179,22 +179,25 @@
get_next_chunk_ = GetS3Client().GetObjectCallable(get_request);
}
-std::vector<std::string> ListS3Objects(std::string_view url) {
+std::vector<std::pair<std::string, size_t>> ListS3Objects(
+ std::string_view url) {
Aws::S3::Model::ListObjectsV2Request list_request;
const ObjectName object_name = ParseUrl(url);
list_request.SetBucket(object_name.bucket);
list_request.SetPrefix(object_name.key);
Aws::S3::Model::ListObjectsV2Outcome list_outcome =
GetS3Client().ListObjectsV2(list_request);
- std::vector<std::string> result;
+ std::vector<std::pair<std::string, size_t>> result;
while (true) {
CHECK(list_outcome.IsSuccess()) << ": Listing objects for " << url
<< " failed: " << list_outcome.GetError();
auto &list_result = list_outcome.GetResult();
for (const Aws::S3::Model::Object &object : list_result.GetContents()) {
- result.push_back(absl::StrCat("s3://", list_outcome.GetResult().GetName(),
- "/", object.GetKey()));
- VLOG(2) << "got " << result.back();
+ result.emplace_back(
+ absl::StrCat("s3://", list_outcome.GetResult().GetName(), "/",
+ object.GetKey()),
+ object.GetSize());
+ VLOG(2) << "got " << result.back().first;
}
if (!list_result.GetIsTruncated()) {
break;
diff --git a/aos/events/logging/s3_fetcher.h b/aos/events/logging/s3_fetcher.h
index 41bc114..73584cb 100644
--- a/aos/events/logging/s3_fetcher.h
+++ b/aos/events/logging/s3_fetcher.h
@@ -54,8 +54,8 @@
ObjectName ParseUrl(std::string_view url);
// Does an S3 object listing with the given URL prefix. Returns the URLs for all
-// the objects under it.
-std::vector<std::string> ListS3Objects(std::string_view url);
+// the objects under it, each paired with that object's size.
+std::vector<std::pair<std::string, size_t>> ListS3Objects(std::string_view url);
} // namespace aos::logger
diff --git a/aos/events/logging/s3_file_operations.cc b/aos/events/logging/s3_file_operations.cc
index 8764d4a..129b7e2 100644
--- a/aos/events/logging/s3_file_operations.cc
+++ b/aos/events/logging/s3_file_operations.cc
@@ -4,17 +4,30 @@
namespace aos::logger::internal {
-S3FileOperations::S3FileOperations(std::string_view url)
- : object_urls_(ListS3Objects(url)) {}
+std::vector<FileOperations::File> Convert(
+ std::vector<std::pair<std::string, size_t>> &&input) {
+ std::vector<FileOperations::File> result;
+ result.reserve(input.size());
+ for (std::pair<std::string, size_t> &i : input) {
+ result.emplace_back(FileOperations::File{
+ .name = std::move(i.first),
+ .size = i.second,
+ });
+ }
+ return result;
+}
-void S3FileOperations::FindLogs(std::vector<std::string> *files) {
+S3FileOperations::S3FileOperations(std::string_view url)
+ : object_urls_(Convert(ListS3Objects(url))) {}
+
+void S3FileOperations::FindLogs(std::vector<File> *files) {
// We already have a recursive listing, so just grab all the objects from
// there.
- for (const std::string &object_url : object_urls_) {
- if (IsValidFilename(object_url)) {
+ for (const File &object_url : object_urls_) {
+ if (IsValidFilename(object_url.name)) {
files->push_back(object_url);
}
}
}
-} // namespace aos::logger::internal
\ No newline at end of file
+} // namespace aos::logger::internal
diff --git a/aos/events/logging/s3_file_operations.h b/aos/events/logging/s3_file_operations.h
index ae94f62..e75e56b 100644
--- a/aos/events/logging/s3_file_operations.h
+++ b/aos/events/logging/s3_file_operations.h
@@ -11,10 +11,10 @@
bool Exists() final { return !object_urls_.empty(); }
- void FindLogs(std::vector<std::string> *files) final;
+ void FindLogs(std::vector<File> *files) final;
private:
- const std::vector<std::string> object_urls_;
+ const std::vector<File> object_urls_;
};
} // namespace aos::logger::internal
diff --git a/aos/events/logging/single_node_merge.cc b/aos/events/logging/single_node_merge.cc
index afb88af..aee2c5e 100644
--- a/aos/events/logging/single_node_merge.cc
+++ b/aos/events/logging/single_node_merge.cc
@@ -20,8 +20,7 @@
namespace chrono = std::chrono;
int Main(int argc, char **argv) {
- const std::vector<std::string> unsorted_logfiles = FindLogs(argc, argv);
- const LogFilesContainer log_files(SortParts(unsorted_logfiles));
+ const LogFilesContainer log_files(SortParts(FindLogs(argc, argv)));
const Configuration *config = log_files.config();
// Haven't tested this on a single node log, and don't really see a need to
diff --git a/aos/events/logging/timestamp_extractor.cc b/aos/events/logging/timestamp_extractor.cc
index 72f9de8..0b22f93 100644
--- a/aos/events/logging/timestamp_extractor.cc
+++ b/aos/events/logging/timestamp_extractor.cc
@@ -18,8 +18,7 @@
namespace chrono = std::chrono;
int Main(int argc, char **argv) {
- const std::vector<std::string> unsorted_logfiles = FindLogs(argc, argv);
- const LogFilesContainer log_files(SortParts(unsorted_logfiles));
+ const LogFilesContainer log_files(SortParts(FindLogs(argc, argv)));
const Configuration *config = log_files.config();
CHECK(configuration::MultiNode(config))
diff --git a/aos/events/shm_event_loop.cc b/aos/events/shm_event_loop.cc
index 09cd488..8ed6e0e 100644
--- a/aos/events/shm_event_loop.cc
+++ b/aos/events/shm_event_loop.cc
@@ -40,8 +40,18 @@
DEFINE_string(shm_base, "/dev/shm/aos",
"Directory to place queue backing mmaped files in.");
+// This value is affected by the umask of the process which is calling it
+// and is set to the user's value by default (check yours running `umask` on
+// the command line).
+// Any file mode requested is transformed using: mode & ~umask and the default
+// umask is 0022 (allow any permissions for the user, don't allow writes for
+// groups or others).
+// See https://man7.org/linux/man-pages/man2/umask.2.html for more details.
+// WITH THE DEFAULT UMASK YOU WON'T ACTUALLY GET THESE PERMISSIONS :)
DEFINE_uint32(permissions, 0770,
- "Permissions to make shared memory files and folders.");
+ "Permissions to make shared memory files and folders, "
+ "affected by the process's umask. "
+ "See shm_event_loop.cc for more details.");
DEFINE_string(application_name, Filename(program_invocation_name),
"The application name");
@@ -216,7 +226,8 @@
queue_index.index(), &context_.monotonic_event_time,
&context_.realtime_event_time, &context_.monotonic_remote_time,
&context_.realtime_remote_time, &context_.remote_queue_index,
- &context_.source_boot_uuid, &context_.size, copy_buffer);
+ &context_.source_boot_uuid, &context_.size, copy_buffer,
+ std::ref(should_fetch_));
if (read_result == ipc_lib::LocklessQueueReader::Result::GOOD) {
if (pin_data()) {
@@ -310,6 +321,11 @@
std::optional<ipc_lib::LocklessQueuePinner> pinner_;
Context context_;
+
+ // Pre-allocated should_fetch function so we don't allocate.
+ std::function<bool(const Context &)> should_fetch_ = [](const Context &) {
+ return true;
+ };
};
class ShmFetcher : public RawFetcher {
@@ -379,12 +395,12 @@
: RawSender(event_loop, channel),
lockless_queue_memory_(shm_base, FLAGS_permissions,
event_loop->configuration(), channel),
- lockless_queue_sender_(VerifySender(
- ipc_lib::LocklessQueueSender::Make(
- lockless_queue_memory_.queue(),
- std::chrono::nanoseconds(
- event_loop->configuration()->channel_storage_duration())),
- channel)),
+ lockless_queue_sender_(
+ VerifySender(ipc_lib::LocklessQueueSender::Make(
+ lockless_queue_memory_.queue(),
+ configuration::ChannelStorageDuration(
+ event_loop->configuration(), channel)),
+ channel)),
wake_upper_(lockless_queue_memory_.queue()) {}
~ShmSender() override { shm_event_loop()->CheckCurrentThread(); }
diff --git a/aos/events/shm_event_loop_test.cc b/aos/events/shm_event_loop_test.cc
index 5d58c29..6b40b9d 100644
--- a/aos/events/shm_event_loop_test.cc
+++ b/aos/events/shm_event_loop_test.cc
@@ -26,12 +26,12 @@
}
// Clean up anything left there before.
- unlink((FLAGS_shm_base + "/test/aos.TestMessage.v4").c_str());
- unlink((FLAGS_shm_base + "/test1/aos.TestMessage.v4").c_str());
- unlink((FLAGS_shm_base + "/test2/aos.TestMessage.v4").c_str());
- unlink((FLAGS_shm_base + "/test2/aos.TestMessage.v4").c_str());
- unlink((FLAGS_shm_base + "/aos/aos.timing.Report.v4").c_str());
- unlink((FLAGS_shm_base + "/aos/aos.logging.LogMessageFbs.v4").c_str());
+ unlink((FLAGS_shm_base + "/test/aos.TestMessage.v5").c_str());
+ unlink((FLAGS_shm_base + "/test1/aos.TestMessage.v5").c_str());
+ unlink((FLAGS_shm_base + "/test2/aos.TestMessage.v5").c_str());
+ unlink((FLAGS_shm_base + "/test2/aos.TestMessage.v5").c_str());
+ unlink((FLAGS_shm_base + "/aos/aos.timing.Report.v5").c_str());
+ unlink((FLAGS_shm_base + "/aos/aos.logging.LogMessageFbs.v5").c_str());
}
~ShmEventLoopTestFactory() { FLAGS_override_hostname = ""; }
diff --git a/aos/events/simulated_event_loop.cc b/aos/events/simulated_event_loop.cc
index b17e5ff..7e469f5 100644
--- a/aos/events/simulated_event_loop.cc
+++ b/aos/events/simulated_event_loop.cc
@@ -149,6 +149,12 @@
channel_storage_duration_(channel_storage_duration),
next_queue_index_(ipc_lib::QueueIndex::Zero(number_buffers())),
scheduler_(scheduler) {
+ // Gut check that things fit. Configuration validation should have caught
+ // this before we get here.
+ CHECK_LT(static_cast<size_t>(number_buffers()),
+ std::numeric_limits<
+ decltype(available_buffer_indices_)::value_type>::max())
+ << configuration::CleanedChannelToString(channel);
available_buffer_indices_.resize(number_buffers());
for (int i = 0; i < number_buffers(); ++i) {
available_buffer_indices_[i] = i;
@@ -291,7 +297,7 @@
// replay) and we want to prevent new senders from being accidentally created.
bool allow_new_senders_ = true;
- std::vector<uint16_t> available_buffer_indices_;
+ std::vector<ipc_lib::QueueIndex::PackedIndexType> available_buffer_indices_;
const EventScheduler *scheduler_;
@@ -808,8 +814,8 @@
->emplace(SimpleChannel(channel),
std::unique_ptr<SimulatedChannel>(new SimulatedChannel(
channel,
- std::chrono::nanoseconds(
- configuration()->channel_storage_duration()),
+ configuration::ChannelStorageDuration(
+ configuration(), channel),
scheduler_)))
.first;
}
diff --git a/aos/events/simulated_network_bridge.cc b/aos/events/simulated_network_bridge.cc
index 35cc3a6..c8a8150 100644
--- a/aos/events/simulated_network_bridge.cc
+++ b/aos/events/simulated_network_bridge.cc
@@ -515,7 +515,7 @@
if (channel == timestamp_channel) {
source_event_loop->second.SetSendData(
- [captured_delayers = delayers.get()](const Context &) {
+ [captured_delayers = delayers.get()]() {
for (std::unique_ptr<RawMessageDelayer> &delayer :
captured_delayers->v) {
delayer->Schedule();
diff --git a/aos/events/simulated_network_bridge.h b/aos/events/simulated_network_bridge.h
index 9565029..faa6398 100644
--- a/aos/events/simulated_network_bridge.h
+++ b/aos/events/simulated_network_bridge.h
@@ -99,7 +99,7 @@
void SetEventLoop(std::unique_ptr<aos::EventLoop> loop);
- void SetSendData(std::function<void(const Context &)> fn) {
+ void SetSendData(std::function<void()> fn) {
CHECK(!fn_);
fn_ = std::move(fn);
if (server_status) {
@@ -209,7 +209,7 @@
std::vector<std::pair<const Channel *, DelayersVector *>> delayer_watchers_;
- std::function<void(const Context &)> fn_;
+ std::function<void()> fn_;
NodeEventLoopFactory *node_factory_;
std::unique_ptr<aos::EventLoop> event_loop;
diff --git a/aos/init.cc b/aos/init.cc
index 7d4840c..3ca7314 100644
--- a/aos/init.cc
+++ b/aos/init.cc
@@ -15,6 +15,7 @@
#include "glog/logging.h"
#include "aos/realtime.h"
+#include "aos/uuid.h"
DEFINE_bool(coredump, false, "If true, write core dumps on failure.");
@@ -37,6 +38,10 @@
}
RegisterMallocHook();
+ // Ensure that the random number generator for the UUID code is initialized
+ // (it does some potentially expensive random number generation).
+ UUID::Random();
+
initialized = true;
}
diff --git a/aos/ipc_lib/BUILD b/aos/ipc_lib/BUILD
index cc46dac..147e0ba 100644
--- a/aos/ipc_lib/BUILD
+++ b/aos/ipc_lib/BUILD
@@ -141,6 +141,32 @@
)
cc_library(
+ name = "index32",
+ srcs = ["index.cc"],
+ hdrs = ["index.h"],
+ defines = [
+ "AOS_QUEUE_ATOMIC_SIZE=32",
+ ],
+ target_compatible_with = ["@platforms//os:linux"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":shm_observers",
+ "@com_github_google_glog//:glog",
+ ],
+)
+
+cc_test(
+ name = "index32_test",
+ srcs = ["index_test.cc"],
+ target_compatible_with = ["@platforms//os:linux"],
+ deps = [
+ ":index32",
+ "//aos/testing:googletest",
+ "@com_github_google_glog//:glog",
+ ],
+)
+
+cc_library(
name = "index",
srcs = ["index.cc"],
hdrs = ["index.h"],
@@ -183,6 +209,7 @@
"//aos:configuration",
"//aos:realtime",
"//aos:uuid",
+ "//aos/events:context",
"//aos/time",
"//aos/util:compiler_memory_barrier",
"@com_github_google_glog//:glog",
diff --git a/aos/ipc_lib/index.h b/aos/ipc_lib/index.h
index 4814075..434da7c 100644
--- a/aos/ipc_lib/index.h
+++ b/aos/ipc_lib/index.h
@@ -1,6 +1,7 @@
#ifndef AOS_IPC_LIB_INDEX_H_
#define AOS_IPC_LIB_INDEX_H_
+#include <stdint.h>
#include <sys/types.h>
#include <atomic>
@@ -10,6 +11,18 @@
#include "aos/ipc_lib/shm_observers.h"
+#ifndef AOS_QUEUE_ATOMIC_SIZE
+#if UINTPTR_MAX == 0xffffffff
+#define AOS_QUEUE_ATOMIC_SIZE 32
+/* 32-bit */
+#elif UINTPTR_MAX == 0xffffffffffffffff
+#define AOS_QUEUE_ATOMIC_SIZE 64
+/* 64-bit */
+#else
+#error "Unknown pointer size"
+#endif
+#endif
+
namespace aos {
namespace ipc_lib {
@@ -40,13 +53,19 @@
// Structure for holding the index into the queue.
class QueueIndex {
public:
+#if AOS_QUEUE_ATOMIC_SIZE == 64
+ typedef uint32_t PackedIndexType;
+#else
+ typedef uint16_t PackedIndexType;
+#endif
+
// Returns an invalid queue element which uses a reserved value.
- static QueueIndex Invalid() { return QueueIndex(0xffffffff, 0); }
+ static QueueIndex Invalid() { return QueueIndex(sentinal_value(), 0); }
// Returns a queue element pointing to 0.
static QueueIndex Zero(uint32_t count) { return QueueIndex(0, count); }
// Returns true if the index is valid.
- bool valid() const { return index_ != 0xffffffff; }
+ bool valid() const { return index_ != sentinal_value(); }
// Returns the modulo base used to wrap to avoid overlapping with the reserved
// number.
@@ -88,10 +107,14 @@
return QueueIndex(index, count_);
}
- // Returns true if the lowest 16 bits of the queue index from the Index could
- // plausibly match this queue index.
- bool IsPlausible(uint16_t queue_index) const {
- return valid() && (queue_index == static_cast<uint16_t>(index_ & 0xffff));
+ // Returns true if the lowest bits of the queue index from the Index could
+ // plausibly match this queue index. The number of bits matched depends on
+ // the size of atomics in use.
+ bool IsPlausible(PackedIndexType queue_index) const {
+ return valid() &&
+ (queue_index ==
+ static_cast<PackedIndexType>(
+ index_ & std::numeric_limits<PackedIndexType>::max()));
}
bool operator==(const QueueIndex other) const {
@@ -170,21 +193,39 @@
// Structure holding the queue index and the index into the message list.
class Index {
public:
+#if AOS_QUEUE_ATOMIC_SIZE == 64
+ typedef uint64_t IndexType;
+ typedef uint32_t MessageIndexType;
+#else
+ typedef uint32_t IndexType;
+ typedef uint16_t MessageIndexType;
+#endif
+ typedef QueueIndex::PackedIndexType PackedIndexType;
+
// Constructs an Index. queue_index is the QueueIndex of this message, and
// message_index is the index into the messages structure.
- Index(QueueIndex queue_index, uint16_t message_index)
+ Index(QueueIndex queue_index, MessageIndexType message_index)
: Index(queue_index.index_, message_index) {}
- Index(uint32_t queue_index, uint16_t message_index)
- : index_((queue_index & 0xffff) |
- (static_cast<uint32_t>(message_index) << 16)) {
+ Index(uint32_t queue_index, MessageIndexType message_index)
+ : index_(static_cast<IndexType>(
+ queue_index & std::numeric_limits<PackedIndexType>::max()) |
+ (static_cast<IndexType>(message_index)
+ << std::numeric_limits<PackedIndexType>::digits)) {
CHECK_LE(message_index, MaxMessages());
}
// Index of this message in the message array.
- uint16_t message_index() const { return (index_ >> 16) & 0xffff; }
+ MessageIndexType message_index() const {
+ return (index_ >> std::numeric_limits<PackedIndexType>::digits) &
+ std::numeric_limits<MessageIndexType>::max();
+ }
- // Lowest 16 bits of the queue index of this message in the queue.
- uint16_t queue_index() const { return index_ & 0xffff; }
+ // Lowest bits of the queue index of this message in the queue. This will
+ // be either 16 or 32 bits, depending on whether we have 32 or 64 bit atomics
+ // under the covers.
+ PackedIndexType queue_index() const {
+ return index_ & std::numeric_limits<PackedIndexType>::max();
+ }
// Returns true if the provided queue index plausibly represents this Index.
bool IsPlausible(QueueIndex queue_index) const {
@@ -197,29 +238,33 @@
bool valid() const { return index_ != sentinal_value(); }
// Returns the raw Index. This should only be used for debug.
- uint32_t get() const { return index_; }
+ IndexType get() const { return index_; }
// Returns the maximum number of messages we can store before overflowing.
- static constexpr uint16_t MaxMessages() { return 0xfffe; }
+ static constexpr MessageIndexType MaxMessages() {
+ return std::numeric_limits<MessageIndexType>::max() - 1;
+ }
bool operator==(const Index other) const { return other.index_ == index_; }
bool operator!=(const Index other) const { return other.index_ != index_; }
// Returns a string representing the index.
- ::std::string DebugString() const;
+ std::string DebugString() const;
private:
- Index(uint32_t index) : index_(index) {}
+ Index(IndexType index) : index_(index) {}
friend class AtomicIndex;
- static constexpr uint32_t sentinal_value() { return 0xffffffffu; }
+ static constexpr IndexType sentinal_value() {
+ return std::numeric_limits<IndexType>::max();
+ }
- // Note: a value of 0xffffffff is a sentinal to represent an invalid entry.
+ // Note: a value of all 1 bits is a sentinal to represent an invalid entry.
// This works because we would need to have a queue index of 0x*ffff, *and*
// have 0xffff messages in the message list. That constraint is easy to
// enforce by limiting the max messages.
- uint32_t index_;
+ IndexType index_;
};
// Atomic storage for setting and getting Index objects.
@@ -257,7 +302,7 @@
}
private:
- ::std::atomic<uint32_t> index_;
+ ::std::atomic<Index::IndexType> index_;
};
} // namespace ipc_lib
diff --git a/aos/ipc_lib/index_test.cc b/aos/ipc_lib/index_test.cc
index a288ec9..dca75a0 100644
--- a/aos/ipc_lib/index_test.cc
+++ b/aos/ipc_lib/index_test.cc
@@ -133,6 +133,37 @@
EXPECT_EQ(index.message_index(), 11);
}
+#if AOS_QUEUE_ATOMIC_SIZE == 64
+// Tests that the 64 bit plausible has sane behavior.
+TEST(IndexTest, TestPlausible) {
+ QueueIndex five = QueueIndex::Zero(100).IncrementBy(5);
+ QueueIndex ffff = QueueIndex::Zero(100).IncrementBy(0xffff);
+
+ // Tests some various combinations of indices.
+ for (int i = 0; i < 100; ++i) {
+ Index index(five, i);
+ EXPECT_EQ(index.queue_index(), 5 + i * 0x10000);
+
+ EXPECT_TRUE(index.IsPlausible(five));
+
+ EXPECT_EQ(index.message_index(), i);
+
+ five = five.IncrementBy(0x10000);
+ }
+
+ // Tests that a queue index with a value of 0xffff doesn't match an invalid
+ // index.
+ for (int i = 0; i < 100; ++i) {
+ Index index(ffff, i);
+ EXPECT_EQ(index.queue_index(), 0xffff);
+
+ EXPECT_TRUE(index.IsPlausible(ffff));
+ EXPECT_FALSE(index.IsPlausible(QueueIndex::Invalid()));
+
+ EXPECT_EQ(index.message_index(), i);
+ }
+}
+#else
// Tests that Plausible behaves.
TEST(IndexTest, TestPlausible) {
QueueIndex five = QueueIndex::Zero(100).IncrementBy(5);
@@ -162,6 +193,16 @@
EXPECT_EQ(index.message_index(), i);
}
}
+#endif
+
+// Tests that the max message size makes sense.
+TEST(IndexTest, TestMaxMessages) {
+#if AOS_QUEUE_ATOMIC_SIZE == 64
+ EXPECT_EQ(Index::MaxMessages(), 0xfffffffe);
+#else
+ EXPECT_EQ(Index::MaxMessages(), 0xfffe);
+#endif
+}
} // namespace testing
} // namespace ipc_lib
diff --git a/aos/ipc_lib/lockless_queue.cc b/aos/ipc_lib/lockless_queue.cc
index 3d2c0d1..01706ba 100644
--- a/aos/ipc_lib/lockless_queue.cc
+++ b/aos/ipc_lib/lockless_queue.cc
@@ -627,7 +627,8 @@
::aos::ipc_lib::Sender *s = memory->GetSender(i);
// Nobody else can possibly be touching these because we haven't set
// initialized to true yet.
- s->scratch_index.RelaxedStore(Index(0xffff, i + memory->queue_size()));
+ s->scratch_index.RelaxedStore(
+ Index(QueueIndex::Invalid(), i + memory->queue_size()));
s->to_replace.RelaxedInvalidate();
}
@@ -636,7 +637,8 @@
// Nobody else can possibly be touching these because we haven't set
// initialized to true yet.
pinner->scratch_index.RelaxedStore(
- Index(0xffff, i + memory->num_senders() + memory->queue_size()));
+ Index(QueueIndex::Invalid(),
+ i + memory->num_senders() + memory->queue_size()));
pinner->pinned.Invalidate();
}
@@ -1267,7 +1269,7 @@
monotonic_clock::time_point *monotonic_remote_time,
realtime_clock::time_point *realtime_remote_time,
uint32_t *remote_queue_index, UUID *source_boot_uuid, size_t *length,
- char *data) const {
+ char *data, std::function<bool(const Context &)> should_read) const {
const size_t queue_size = memory_->queue_size();
// Build up the QueueIndex.
@@ -1341,34 +1343,80 @@
// Then read the data out. Copy it all out to be deterministic and so we can
// make length be from either end.
- *monotonic_sent_time = m->header.monotonic_sent_time;
- *realtime_sent_time = m->header.realtime_sent_time;
- if (m->header.remote_queue_index == 0xffffffffu) {
- *remote_queue_index = queue_index.index();
+ if (!should_read) {
+ *monotonic_sent_time = m->header.monotonic_sent_time;
+ *realtime_sent_time = m->header.realtime_sent_time;
+ if (m->header.remote_queue_index == 0xffffffffu) {
+ *remote_queue_index = queue_index.index();
+ } else {
+ *remote_queue_index = m->header.remote_queue_index;
+ }
+ *monotonic_remote_time = m->header.monotonic_remote_time;
+ *realtime_remote_time = m->header.realtime_remote_time;
+ *source_boot_uuid = m->header.source_boot_uuid;
+ *length = m->header.length;
} else {
- *remote_queue_index = m->header.remote_queue_index;
+ // Cache the header results so we don't modify the outputs unless the filter
+ // function says "go".
+ Context context;
+ context.monotonic_event_time = m->header.monotonic_sent_time;
+ context.realtime_event_time = m->header.realtime_sent_time;
+ context.monotonic_remote_time = m->header.monotonic_remote_time;
+ context.realtime_remote_time = m->header.realtime_remote_time;
+ context.queue_index = queue_index.index();
+ if (m->header.remote_queue_index == 0xffffffffu) {
+ context.remote_queue_index = context.queue_index;
+ } else {
+ context.remote_queue_index = m->header.remote_queue_index;
+ }
+ context.source_boot_uuid = m->header.source_boot_uuid;
+ context.size = m->header.length;
+ context.data = nullptr;
+ context.buffer_index = -1;
+
+ // And finally, confirm that the message *still* points to the queue index
+ // we want. This means it didn't change out from under us. If something
+ // changed out from under us, we were reading it much too late in its
+ // lifetime.
+ aos_compiler_memory_barrier();
+ const QueueIndex final_queue_index = m->header.queue_index.Load(queue_size);
+ if (final_queue_index != queue_index) {
+ VLOG(3) << "Changed out from under us. Reading " << std::hex
+ << queue_index.index() << ", finished with "
+ << final_queue_index.index() << ", delta: " << std::dec
+ << (final_queue_index.index() - queue_index.index());
+ return Result::OVERWROTE;
+ }
+
+ // We now know that the context is safe to use. See if we are supposed to
+ // take the message or not.
+ if (!should_read(context)) {
+ return Result::FILTERED;
+ }
+
+ // And now take it.
+ *monotonic_sent_time = context.monotonic_event_time;
+ *realtime_sent_time = context.realtime_event_time;
+ *remote_queue_index = context.remote_queue_index;
+ *monotonic_remote_time = context.monotonic_remote_time;
+ *realtime_remote_time = context.realtime_remote_time;
+ *source_boot_uuid = context.source_boot_uuid;
+ *length = context.size;
}
- *monotonic_remote_time = m->header.monotonic_remote_time;
- *realtime_remote_time = m->header.realtime_remote_time;
- *source_boot_uuid = m->header.source_boot_uuid;
if (data) {
memcpy(data, m->data(memory_->message_data_size()),
memory_->message_data_size());
- }
- *length = m->header.length;
- // And finally, confirm that the message *still* points to the queue index we
- // want. This means it didn't change out from under us.
- // If something changed out from under us, we were reading it much too late in
- // it's lifetime.
- aos_compiler_memory_barrier();
- const QueueIndex final_queue_index = m->header.queue_index.Load(queue_size);
- if (final_queue_index != queue_index) {
- VLOG(3) << "Changed out from under us. Reading " << std::hex
- << queue_index.index() << ", finished with "
- << final_queue_index.index() << ", delta: " << std::dec
- << (final_queue_index.index() - queue_index.index());
- return Result::OVERWROTE;
+ // Check again since we touched the message again.
+ aos_compiler_memory_barrier();
+ const QueueIndex final_queue_index = m->header.queue_index.Load(queue_size);
+ if (final_queue_index != queue_index) {
+ VLOG(3) << "Changed out from under us. Reading " << std::hex
+ << queue_index.index() << ", finished with "
+ << final_queue_index.index() << ", delta: " << std::dec
+ << (final_queue_index.index() - queue_index.index());
+ return Result::OVERWROTE;
+ }
}
return Result::GOOD;
@@ -1400,7 +1448,7 @@
// Prints out the mutex state. Not safe to use while the mutex is being
// changed.
-::std::string PrintMutex(aos_mutex *mutex) {
+::std::string PrintMutex(const aos_mutex *mutex) {
::std::stringstream s;
s << "aos_mutex(" << ::std::hex << mutex->futex;
@@ -1418,7 +1466,7 @@
} // namespace
-void PrintLocklessQueueMemory(LocklessQueueMemory *memory) {
+void PrintLocklessQueueMemory(const LocklessQueueMemory *memory) {
const size_t queue_size = memory->queue_size();
::std::cout << "LocklessQueueMemory (" << memory << ") {" << ::std::endl;
::std::cout << " aos_mutex queue_setup_lock = "
@@ -1452,7 +1500,7 @@
::std::cout << " Message messages[" << memory->num_messages() << "] {"
<< ::std::endl;
for (size_t i = 0; i < memory->num_messages(); ++i) {
- Message *m = memory->GetMessage(Index(i, i));
+ const Message *m = memory->GetMessage(Index(i, i));
::std::cout << " [" << i << "] -> Message 0x" << std::hex
<< (reinterpret_cast<uintptr_t>(
memory->GetMessage(Index(i, i))) -
@@ -1484,8 +1532,9 @@
::std::cout << " }" << ::std::endl;
const bool corrupt = CheckBothRedzones(memory, m);
if (corrupt) {
- absl::Span<char> pre_redzone = m->PreRedzone(memory->message_data_size());
- absl::Span<char> post_redzone =
+ absl::Span<const char> pre_redzone =
+ m->PreRedzone(memory->message_data_size());
+ absl::Span<const char> post_redzone =
m->PostRedzone(memory->message_data_size(), memory->message_size());
::std::cout << " pre-redzone: \""
@@ -1514,7 +1563,7 @@
::std::cout << " Sender senders[" << memory->num_senders() << "] {"
<< ::std::endl;
for (size_t i = 0; i < memory->num_senders(); ++i) {
- Sender *s = memory->GetSender(i);
+ const Sender *s = memory->GetSender(i);
::std::cout << " [" << i << "] -> Sender {" << ::std::endl;
::std::cout << " aos_mutex tid = " << PrintMutex(&s->tid)
<< ::std::endl;
@@ -1529,7 +1578,7 @@
::std::cout << " Pinner pinners[" << memory->num_pinners() << "] {"
<< ::std::endl;
for (size_t i = 0; i < memory->num_pinners(); ++i) {
- Pinner *p = memory->GetPinner(i);
+ const Pinner *p = memory->GetPinner(i);
::std::cout << " [" << i << "] -> Pinner {" << ::std::endl;
::std::cout << " aos_mutex tid = " << PrintMutex(&p->tid)
<< ::std::endl;
@@ -1545,7 +1594,7 @@
::std::cout << " Watcher watchers[" << memory->num_watchers() << "] {"
<< ::std::endl;
for (size_t i = 0; i < memory->num_watchers(); ++i) {
- Watcher *w = memory->GetWatcher(i);
+ const Watcher *w = memory->GetWatcher(i);
::std::cout << " [" << i << "] -> Watcher {" << ::std::endl;
::std::cout << " aos_mutex tid = " << PrintMutex(&w->tid)
<< ::std::endl;
diff --git a/aos/ipc_lib/lockless_queue.h b/aos/ipc_lib/lockless_queue.h
index 9cc97c0..2cafb48 100644
--- a/aos/ipc_lib/lockless_queue.h
+++ b/aos/ipc_lib/lockless_queue.h
@@ -10,6 +10,7 @@
#include "absl/types/span.h"
+#include "aos/events/context.h"
#include "aos/ipc_lib/aos_sync.h"
#include "aos/ipc_lib/data_alignment.h"
#include "aos/ipc_lib/index.h"
@@ -406,7 +407,19 @@
class LocklessQueueReader {
public:
- enum class Result { TOO_OLD, GOOD, NOTHING_NEW, OVERWROTE };
+ enum class Result {
+ // Message we read was too old and no longer is in the queue.
+ TOO_OLD,
+ // Success!
+ GOOD,
+ // The message is in the future and we haven't written it yet.
+ NOTHING_NEW,
+ // There is a message, but should_read() returned false so we didn't fetch
+ // it.
+ FILTERED,
+ // The message got overwritten while we were reading it.
+ OVERWROTE,
+ };
LocklessQueueReader(LocklessQueue queue) : memory_(queue.const_memory()) {
queue.Initialize();
@@ -416,7 +429,8 @@
// NOTHING_NEW until that gets overwritten with new data. If you ask for an
// element newer than QueueSize() from the current message, we consider it
// behind by a large amount and return TOO_OLD. If the message is modified
- // out from underneath us as we read it, return OVERWROTE.
+ // out from underneath us as we read it, return OVERWROTE. If we found a new
+ // message, but the filter function returned false, return FILTERED.
//
// data may be nullptr to indicate the data should not be copied.
Result Read(uint32_t queue_index,
@@ -425,7 +439,8 @@
monotonic_clock::time_point *monotonic_remote_time,
realtime_clock::time_point *realtime_remote_time,
uint32_t *remote_queue_index, UUID *source_boot_uuid,
- size_t *length, char *data) const;
+ size_t *length, char *data,
+ std::function<bool(const Context &context)> should_read) const;
// Returns the index to the latest queue message. Returns empty_queue_index()
// if there are no messages in the queue. Do note that this index wraps if
@@ -454,7 +469,7 @@
// before and after a time with a binary search.
// Prints to stdout the data inside the queue for debugging.
-void PrintLocklessQueueMemory(LocklessQueueMemory *memory);
+void PrintLocklessQueueMemory(const LocklessQueueMemory *memory);
} // namespace ipc_lib
} // namespace aos
diff --git a/aos/ipc_lib/lockless_queue_death_test.cc b/aos/ipc_lib/lockless_queue_death_test.cc
index 3ed3099..c37dc63 100644
--- a/aos/ipc_lib/lockless_queue_death_test.cc
+++ b/aos/ipc_lib/lockless_queue_death_test.cc
@@ -692,6 +692,11 @@
// increments.
char last_data = '0';
int i = 0;
+
+ std::function<bool(const Context &)> should_read = [](const Context &) {
+ return true;
+ };
+
while (true) {
monotonic_clock::time_point monotonic_sent_time;
realtime_clock::time_point realtime_sent_time;
@@ -702,10 +707,11 @@
char read_data[1024];
size_t length;
- LocklessQueueReader::Result read_result = reader.Read(
- i, &monotonic_sent_time, &realtime_sent_time,
- &monotonic_remote_time, &realtime_remote_time,
- &remote_queue_index, &source_boot_uuid, &length, &(read_data[0]));
+ LocklessQueueReader::Result read_result =
+ reader.Read(i, &monotonic_sent_time, &realtime_sent_time,
+ &monotonic_remote_time, &realtime_remote_time,
+ &remote_queue_index, &source_boot_uuid, &length,
+ &(read_data[0]), std::ref(should_read));
if (read_result != LocklessQueueReader::Result::GOOD) {
if (read_result == LocklessQueueReader::Result::TOO_OLD) {
diff --git a/aos/ipc_lib/lockless_queue_memory.h b/aos/ipc_lib/lockless_queue_memory.h
index b3f9468..713d9cd 100644
--- a/aos/ipc_lib/lockless_queue_memory.h
+++ b/aos/ipc_lib/lockless_queue_memory.h
@@ -119,6 +119,14 @@
SizeOfSenders() + pinner_index * sizeof(Pinner));
}
+ const Sender *GetSender(size_t sender_index) const {
+ static_assert(alignof(Sender) <= kDataAlignment,
+ "kDataAlignment is too small");
+ return reinterpret_cast<const Sender *>(
+ &data[0] + SizeOfQueue() + SizeOfMessages() + SizeOfWatchers() +
+ sender_index * sizeof(Sender));
+ }
+
Sender *GetSender(size_t sender_index) {
static_assert(alignof(Sender) <= kDataAlignment,
"kDataAlignment is too small");
diff --git a/aos/ipc_lib/lockless_queue_test.cc b/aos/ipc_lib/lockless_queue_test.cc
index 34e2762..93ae2a3 100644
--- a/aos/ipc_lib/lockless_queue_test.cc
+++ b/aos/ipc_lib/lockless_queue_test.cc
@@ -234,9 +234,14 @@
LocklessQueueSender::Make(queue(), kChannelStorageDuration).value();
LocklessQueueReader reader(queue());
- time::PhasedLoop loop(std::chrono::microseconds(1), monotonic_clock::now());
+ time::PhasedLoop loop(kChannelStorageDuration / (config_.queue_size - 1),
+ monotonic_clock::now());
+ std::function<bool(const Context &)> should_read = [](const Context &) {
+ return true;
+ };
+
// Send enough messages to wrap.
- for (int i = 0; i < 20000; ++i) {
+ for (int i = 0; i < 2 * static_cast<int>(config_.queue_size); ++i) {
// Confirm that the queue index makes sense given the number of sends.
EXPECT_EQ(reader.LatestIndex().index(),
i == 0 ? QueueIndex::Invalid().index() : i - 1);
@@ -244,7 +249,7 @@
// Send a trivial piece of data.
char data[100];
size_t s = snprintf(data, sizeof(data), "foobar%d", i);
- EXPECT_EQ(sender.Send(data, s, monotonic_clock::min_time,
+ ASSERT_EQ(sender.Send(data, s, monotonic_clock::min_time,
realtime_clock::min_time, 0xffffffffu, UUID::Zero(),
nullptr, nullptr, nullptr),
LocklessQueueSender::Result::GOOD);
@@ -272,12 +277,12 @@
LocklessQueueReader::Result read_result = reader.Read(
index.index(), &monotonic_sent_time, &realtime_sent_time,
&monotonic_remote_time, &realtime_remote_time, &remote_queue_index,
- &source_boot_uuid, &length, &(read_data[0]));
+ &source_boot_uuid, &length, &(read_data[0]), std::ref(should_read));
// This should either return GOOD, or TOO_OLD if it is before the start of
// the queue.
if (read_result != LocklessQueueReader::Result::GOOD) {
- EXPECT_EQ(read_result, LocklessQueueReader::Result::TOO_OLD);
+ ASSERT_EQ(read_result, LocklessQueueReader::Result::TOO_OLD);
}
loop.SleepUntilNext();
@@ -291,6 +296,8 @@
::std::mt19937 generator(0);
::std::uniform_int_distribution<> write_wrap_count_distribution(0, 10);
::std::bernoulli_distribution race_reads_distribution;
+ ::std::bernoulli_distribution set_should_read_distribution;
+ ::std::bernoulli_distribution should_read_result_distribution;
::std::bernoulli_distribution wrap_writes_distribution;
const chrono::seconds print_frequency(FLAGS_print_rate);
@@ -304,12 +311,15 @@
monotonic_clock::time_point next_print_time = start_time + print_frequency;
uint64_t messages = 0;
for (int i = 0; i < FLAGS_min_iterations || monotonic_now < end_time; ++i) {
- bool race_reads = race_reads_distribution(generator);
+ const bool race_reads = race_reads_distribution(generator);
+ const bool set_should_read = set_should_read_distribution(generator);
+ const bool should_read_result = should_read_result_distribution(generator);
int write_wrap_count = write_wrap_count_distribution(generator);
if (!wrap_writes_distribution(generator)) {
write_wrap_count = 0;
}
- EXPECT_NO_FATAL_FAILURE(racer.RunIteration(race_reads, write_wrap_count))
+ EXPECT_NO_FATAL_FAILURE(racer.RunIteration(
+ race_reads, write_wrap_count, set_should_read, should_read_result))
<< ": Running with race_reads: " << race_reads
<< ", and write_wrap_count " << write_wrap_count << " and on iteration "
<< i;
@@ -391,7 +401,7 @@
std::chrono::milliseconds(500),
false});
- EXPECT_NO_FATAL_FAILURE(racer.RunIteration(false, 0));
+ EXPECT_NO_FATAL_FAILURE(racer.RunIteration(false, 0, true, true));
}
// // Send enough messages to wrap the 32 bit send counter.
@@ -401,7 +411,7 @@
QueueRacer racer(queue(), 1, kNumMessages);
const monotonic_clock::time_point start_time = monotonic_clock::now();
- EXPECT_NO_FATAL_FAILURE(racer.RunIteration(false, 0));
+ EXPECT_NO_FATAL_FAILURE(racer.RunIteration(false, 0, false, true));
const monotonic_clock::time_point monotonic_now = monotonic_clock::now();
double elapsed_seconds = chrono::duration_cast<chrono::duration<double>>(
monotonic_now - start_time)
diff --git a/aos/ipc_lib/memory_mapped_queue.cc b/aos/ipc_lib/memory_mapped_queue.cc
index c2d8d44..79f27d9 100644
--- a/aos/ipc_lib/memory_mapped_queue.cc
+++ b/aos/ipc_lib/memory_mapped_queue.cc
@@ -17,7 +17,7 @@
std::string ShmPath(std::string_view shm_base, const Channel *channel) {
CHECK(channel->has_type());
- return ShmFolder(shm_base, channel) + channel->type()->str() + ".v4";
+ return ShmFolder(shm_base, channel) + channel->type()->str() + ".v5";
}
void PageFaultDataWrite(char *data, size_t size) {
@@ -67,6 +67,15 @@
// copy.
config.num_pinners = channel->num_readers();
config.queue_size = configuration::QueueSize(configuration, channel);
+ CHECK_LT(config.queue_size,
+ std::numeric_limits<QueueIndex::PackedIndexType>::max())
+ << ": More messages/second configured than the queue can hold on "
+ << configuration::CleanedChannelToString(channel) << ", "
+ << channel->frequency() << "hz for "
+ << std::chrono::duration<double>(
+ configuration::ChannelStorageDuration(configuration, channel))
+ .count()
+ << "sec";
config.message_data_size = channel->max_size();
return config;
diff --git a/aos/ipc_lib/queue_racer.cc b/aos/ipc_lib/queue_racer.cc
index 7c0408d..0b8f1a6 100644
--- a/aos/ipc_lib/queue_racer.cc
+++ b/aos/ipc_lib/queue_racer.cc
@@ -13,7 +13,7 @@
namespace {
struct ThreadPlusCount {
- int thread;
+ uint64_t thread;
uint64_t count;
};
@@ -47,7 +47,8 @@
Reset();
}
-void QueueRacer::RunIteration(bool race_reads, int write_wrap_count) {
+void QueueRacer::RunIteration(bool race_reads, int write_wrap_count,
+ bool set_should_read, bool should_read_result) {
const bool will_wrap = num_messages_ * num_threads_ *
static_cast<uint64_t>(1 + write_wrap_count) >
queue_.config().queue_size;
@@ -197,7 +198,10 @@
++started_writes_;
auto result =
sender.Send(sizeof(ThreadPlusCount), aos::monotonic_clock::min_time,
- aos::realtime_clock::min_time, 0xffffffff, UUID::Zero(),
+ aos::realtime_clock::min_time, 0xffffffff,
+ UUID::FromSpan(absl::Span<const uint8_t>(
+ reinterpret_cast<const uint8_t *>(&tpc),
+ sizeof(ThreadPlusCount))),
nullptr, nullptr, nullptr);
CHECK(std::find(expected_send_results_.begin(),
@@ -237,7 +241,8 @@
}
if (check_writes_and_reads_) {
- CheckReads(race_reads, write_wrap_count, &threads);
+ CheckReads(race_reads, write_wrap_count, &threads, set_should_read,
+ should_read_result);
}
// Reap all the threads.
@@ -276,7 +281,8 @@
}
void QueueRacer::CheckReads(bool race_reads, int write_wrap_count,
- ::std::vector<ThreadState> *threads) {
+ ::std::vector<ThreadState> *threads,
+ bool set_should_read, bool should_read_result) {
// Now read back the results to double check.
LocklessQueueReader reader(queue_);
const bool will_wrap = num_messages_ * num_threads_ * (1 + write_wrap_count) >
@@ -290,6 +296,15 @@
LocklessQueueSize(queue_.memory());
}
+ std::function<bool(const Context &)> nop;
+
+ Context fetched_context;
+ std::function<bool(const Context &)> should_read =
+ [&should_read_result, &fetched_context](const Context &context) {
+ fetched_context = context;
+ return should_read_result;
+ };
+
for (uint64_t i = initial_i;
i < (1 + write_wrap_count) * num_messages_ * num_threads_; ++i) {
monotonic_clock::time_point monotonic_sent_time;
@@ -308,8 +323,17 @@
LocklessQueueReader::Result read_result = reader.Read(
wrapped_i, &monotonic_sent_time, &realtime_sent_time,
&monotonic_remote_time, &realtime_remote_time, &remote_queue_index,
- &source_boot_uuid, &length, &(read_data[0]));
+ &source_boot_uuid, &length, &(read_data[0]),
+ set_should_read ? std::ref(should_read) : std::ref(nop));
+ // The code in lockless_queue.cc reads everything but data, checks that the
+ // header hasn't changed, then reads the data. So, if we succeed and both
+ // end up not being corrupted, then we've confirmed everything works.
+ //
+ // Feed in both combos of should_read and whether or not to return true or
+ // false from should_read. By capturing the header values inside the
+ // callback, we can also verify the state in the middle of the process to
+ // make sure we have the right boundaries.
if (race_reads) {
if (read_result == LocklessQueueReader::Result::NOTHING_NEW) {
--i;
@@ -322,22 +346,54 @@
continue;
}
}
- // Every message should be good.
- ASSERT_EQ(read_result, LocklessQueueReader::Result::GOOD) << ": i is " << i;
+
+ if (!set_should_read) {
+ // Every message should be good.
+ ASSERT_EQ(read_result, LocklessQueueReader::Result::GOOD)
+ << ": i is " << i;
+ } else {
+ if (should_read_result) {
+ ASSERT_EQ(read_result, LocklessQueueReader::Result::GOOD)
+ << ": i is " << i;
+
+ ASSERT_EQ(monotonic_sent_time, fetched_context.monotonic_event_time);
+ ASSERT_EQ(realtime_sent_time, fetched_context.realtime_event_time);
+ ASSERT_EQ(monotonic_remote_time, fetched_context.monotonic_remote_time);
+ ASSERT_EQ(realtime_remote_time, fetched_context.realtime_remote_time);
+ ASSERT_EQ(source_boot_uuid, fetched_context.source_boot_uuid);
+ ASSERT_EQ(remote_queue_index, fetched_context.remote_queue_index);
+ ASSERT_EQ(length, fetched_context.size);
+
+ ASSERT_EQ(
+ absl::Span<const uint8_t>(
+ reinterpret_cast<const uint8_t *>(
+ read_data + LocklessQueueMessageDataSize(queue_.memory()) -
+ length),
+ length),
+ source_boot_uuid.span());
+ } else {
+ ASSERT_EQ(read_result, LocklessQueueReader::Result::FILTERED);
+ monotonic_sent_time = fetched_context.monotonic_event_time;
+ realtime_sent_time = fetched_context.realtime_event_time;
+ monotonic_remote_time = fetched_context.monotonic_remote_time;
+ realtime_remote_time = fetched_context.realtime_remote_time;
+ source_boot_uuid = fetched_context.source_boot_uuid;
+ remote_queue_index = fetched_context.remote_queue_index;
+ length = fetched_context.size;
+ }
+ }
// And, confirm that time never went backwards.
ASSERT_GT(monotonic_sent_time, last_monotonic_sent_time);
last_monotonic_sent_time = monotonic_sent_time;
- EXPECT_EQ(monotonic_remote_time, aos::monotonic_clock::min_time);
- EXPECT_EQ(realtime_remote_time, aos::realtime_clock::min_time);
- EXPECT_EQ(source_boot_uuid, UUID::Zero());
+ ASSERT_EQ(monotonic_remote_time, aos::monotonic_clock::min_time);
+ ASSERT_EQ(realtime_remote_time, aos::realtime_clock::min_time);
ThreadPlusCount tpc;
- ASSERT_EQ(length, sizeof(ThreadPlusCount));
- memcpy(&tpc,
- read_data + LocklessQueueMessageDataSize(queue_.memory()) - length,
- sizeof(ThreadPlusCount));
+ ASSERT_EQ(source_boot_uuid.span().size(), sizeof(ThreadPlusCount));
+ memcpy(&tpc, source_boot_uuid.span().data(),
+ source_boot_uuid.span().size());
if (will_wrap) {
// The queue won't chang out from under us, so we should get some amount
diff --git a/aos/ipc_lib/queue_racer.h b/aos/ipc_lib/queue_racer.h
index 3e5ca94..87f2cce 100644
--- a/aos/ipc_lib/queue_racer.h
+++ b/aos/ipc_lib/queue_racer.h
@@ -48,7 +48,12 @@
// necesitates a loser check at the end.
//
// If both are set, run an even looser test.
- void RunIteration(bool race_reads, int write_wrap_count);
+ //
+ // set_should_read is used to determine if we should pass in a valid
+ // should_read function, and should_read_result is the return result of that
+ // function.
+ void RunIteration(bool race_reads, int write_wrap_count, bool set_should_read,
+ bool should_read_result);
size_t CurrentIndex() {
return LocklessQueueReader(queue_).LatestIndex().index();
@@ -64,7 +69,8 @@
// clean up all the threads. Otherwise we get an assert on the way out of
// RunIteration instead of getting all the way back to gtest.
void CheckReads(bool race_reads, int write_wrap_count,
- ::std::vector<ThreadState> *threads);
+ ::std::vector<ThreadState> *threads, bool set_should_read,
+ bool should_read_result);
LocklessQueue queue_;
const uint64_t num_threads_;
@@ -80,6 +86,14 @@
::std::atomic<uint64_t> started_writes_;
// Number of writes completed.
::std::atomic<uint64_t> finished_writes_;
+
+ std::function<bool(uint32_t, monotonic_clock::time_point,
+ realtime_clock::time_point, monotonic_clock::time_point,
+ realtime_clock::time_point, uint32_t, UUID, size_t)>
+ should_read_ = [](uint32_t, monotonic_clock::time_point,
+ realtime_clock::time_point, monotonic_clock::time_point,
+ realtime_clock::time_point, uint32_t, UUID,
+ size_t) { return true; };
};
} // namespace ipc_lib
diff --git a/aos/network/BUILD b/aos/network/BUILD
index d8d5756..75173e7 100644
--- a/aos/network/BUILD
+++ b/aos/network/BUILD
@@ -88,6 +88,31 @@
)
flatbuffer_cc_library(
+ name = "sctp_config_fbs",
+ srcs = ["sctp_config.fbs"],
+ gen_reflections = 1,
+)
+
+cc_static_flatbuffer(
+ name = "sctp_config_schema",
+ function = "aos::message_bridge::SctpConfig",
+ target = ":sctp_config_fbs_reflection_out",
+)
+
+flatbuffer_cc_library(
+ name = "sctp_config_request_fbs",
+ srcs = ["sctp_config_request.fbs"],
+ gen_reflections = 1,
+)
+
+cc_static_flatbuffer(
+ name = "sctp_config_request_schema",
+ function = "aos::message_bridge::SctpConfigRequest",
+ target = ":sctp_config_request_fbs_reflection_out",
+ visibility = ["//visibility:public"],
+)
+
+flatbuffer_cc_library(
name = "message_bridge_server_fbs",
srcs = ["message_bridge_server.fbs"],
gen_reflections = 1,
@@ -156,8 +181,27 @@
deps = [
"//aos:unique_malloc_ptr",
"//aos/util:file",
- "//third_party/lksctp-tools:sctp",
"@com_github_google_glog//:glog",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_test(
+ name = "sctp_test",
+ srcs = [
+ "sctp_test.cc",
+ ],
+ tags = [
+ # Fakeroot is required to enable "net.sctp.auth_enable".
+ "requires-fakeroot",
+ ],
+ target_compatible_with = ["@platforms//cpu:x86_64"],
+ deps = [
+ ":sctp_client",
+ ":sctp_lib",
+ ":sctp_server",
+ "//aos/events:epoll",
+ "//aos/testing:googletest",
],
)
@@ -175,7 +219,6 @@
target_compatible_with = ["@platforms//os:linux"],
deps = [
":sctp_lib",
- "//third_party/lksctp-tools:sctp",
],
)
@@ -253,6 +296,8 @@
":message_bridge_server_status",
":remote_data_fbs",
":remote_message_fbs",
+ ":sctp_config_fbs",
+ ":sctp_config_request_fbs",
":sctp_lib",
":sctp_server",
":timestamp_channel",
@@ -260,7 +305,6 @@
"//aos:unique_malloc_ptr",
"//aos/events:shm_event_loop",
"//aos/events/logging:log_reader",
- "//third_party/lksctp-tools:sctp",
],
)
@@ -272,6 +316,7 @@
target_compatible_with = ["@platforms//os:linux"],
deps = [
":message_bridge_server_lib",
+ ":sctp_lib",
"//aos:init",
"//aos:json_to_flatbuffer",
"//aos:sha256",
@@ -294,7 +339,6 @@
target_compatible_with = ["@platforms//os:linux"],
deps = [
":sctp_lib",
- "//third_party/lksctp-tools:sctp",
],
)
@@ -338,6 +382,8 @@
":remote_data_fbs",
":remote_message_fbs",
":sctp_client",
+ ":sctp_config_fbs",
+ ":sctp_config_request_fbs",
":timestamp_fbs",
"//aos/events:shm_event_loop",
"//aos/events/logging:log_reader",
@@ -355,11 +401,13 @@
target_compatible_with = ["@platforms//os:linux"],
deps = [
":message_bridge_client_lib",
+ ":sctp_lib",
"//aos:init",
"//aos:json_to_flatbuffer",
"//aos:sha256",
"//aos/events:shm_event_loop",
"//aos/logging:dynamic_logging",
+ "//aos/util:file",
],
)
@@ -368,6 +416,8 @@
src = "message_bridge_test_combined_timestamps_common.json",
flatbuffers = [
":remote_message_fbs",
+ ":sctp_config_fbs",
+ ":sctp_config_request_fbs",
"//aos/events:ping_fbs",
"//aos/events:pong_fbs",
"//aos/network:message_bridge_client_fbs",
@@ -383,6 +433,8 @@
src = "message_bridge_test_common.json",
flatbuffers = [
":remote_message_fbs",
+ ":sctp_config_fbs",
+ ":sctp_config_request_fbs",
"//aos/events:ping_fbs",
"//aos/events:pong_fbs",
"//aos/network:message_bridge_client_fbs",
@@ -408,6 +460,50 @@
deps = ["//aos/events:aos_config"],
)
+cc_library(
+ name = "message_bridge_test_lib",
+ testonly = True,
+ srcs = ["message_bridge_test_lib.cc"],
+ hdrs = ["message_bridge_test_lib.h"],
+ deps = [
+ ":message_bridge_client_lib",
+ ":message_bridge_server_lib",
+ "//aos:json_to_flatbuffer",
+ "//aos:sha256",
+ "//aos/events:ping_fbs",
+ "//aos/events:pong_fbs",
+ "//aos/events:shm_event_loop",
+ "//aos/ipc_lib:event",
+ "//aos/testing:googletest",
+ "//aos/testing:path",
+ ],
+)
+
+cc_test(
+ name = "message_bridge_retry_test",
+ srcs = [
+ "message_bridge_retry_test.cc",
+ ],
+ data = [
+ ":message_bridge_test_common_config",
+ ],
+ # Somewhat flaky due to relying on the behavior & timing of the host system's network stack.
+ flaky = True,
+ target_compatible_with = ["@platforms//os:linux"],
+ deps = [
+ ":message_bridge_server_lib",
+ ":message_bridge_test_lib",
+ "//aos:json_to_flatbuffer",
+ "//aos:sha256",
+ "//aos/events:ping_fbs",
+ "//aos/events:pong_fbs",
+ "//aos/events:shm_event_loop",
+ "//aos/ipc_lib:event",
+ "//aos/testing:googletest",
+ "//aos/testing:path",
+ ],
+)
+
cc_test(
name = "message_bridge_test",
srcs = [
@@ -418,11 +514,12 @@
":message_bridge_test_common_config",
],
flaky = True,
- shard_count = 10,
+ shard_count = 16,
target_compatible_with = ["@platforms//os:linux"],
deps = [
":message_bridge_client_lib",
":message_bridge_server_lib",
+ ":message_bridge_test_lib",
"//aos:json_to_flatbuffer",
"//aos:sha256",
"//aos/events:ping_fbs",
diff --git a/aos/network/log_web_proxy_main.cc b/aos/network/log_web_proxy_main.cc
index f93c5a2..12276c3 100644
--- a/aos/network/log_web_proxy_main.cc
+++ b/aos/network/log_web_proxy_main.cc
@@ -25,11 +25,8 @@
int main(int argc, char **argv) {
aos::InitGoogle(&argc, &argv);
- const std::vector<std::string> unsorted_logfiles =
- aos::logger::FindLogs(argc, argv);
-
const std::vector<aos::logger::LogFile> logfiles =
- aos::logger::SortParts(unsorted_logfiles);
+ aos::logger::SortParts(aos::logger::FindLogs(argc, argv));
aos::logger::LogReader reader(logfiles);
diff --git a/aos/network/message_bridge_client.cc b/aos/network/message_bridge_client.cc
index c3f55ba..ef727eb 100644
--- a/aos/network/message_bridge_client.cc
+++ b/aos/network/message_bridge_client.cc
@@ -2,14 +2,21 @@
#include "aos/init.h"
#include "aos/logging/dynamic_logging.h"
#include "aos/network/message_bridge_client_lib.h"
+#include "aos/network/sctp_lib.h"
#include "aos/sha256.h"
+#include "aos/util/file.h"
DEFINE_string(config, "aos_config.json", "Path to the config.");
DEFINE_int32(rt_priority, -1, "If > 0, run as this RT priority");
+DEFINE_bool(
+ wants_sctp_authentication, false,
+ "When set, try to use SCTP authentication if provided by the kernel");
namespace aos {
namespace message_bridge {
+using ::aos::util::ReadFileToVecOrDie;
+
int Main() {
aos::FlatbufferDetachedBuffer<aos::Configuration> config =
aos::configuration::ReadConfig(FLAGS_config);
@@ -19,7 +26,10 @@
event_loop.SetRuntimeRealtimePriority(FLAGS_rt_priority);
}
- MessageBridgeClient app(&event_loop, Sha256(config.span()));
+ MessageBridgeClient app(&event_loop, Sha256(config.span()),
+ FLAGS_wants_sctp_authentication
+ ? SctpAuthMethod::kAuth
+ : SctpAuthMethod::kNoAuth);
logging::DynamicLogging dynamic_logging(&event_loop);
// TODO(austin): Save messages into a vector to be logged. One file per
diff --git a/aos/network/message_bridge_client_lib.cc b/aos/network/message_bridge_client_lib.cc
index c23f748..8df81c6 100644
--- a/aos/network/message_bridge_client_lib.cc
+++ b/aos/network/message_bridge_client_lib.cc
@@ -13,10 +13,14 @@
#include "aos/network/message_bridge_protocol.h"
#include "aos/network/remote_data_generated.h"
#include "aos/network/sctp_client.h"
+#include "aos/network/sctp_config_generated.h"
+#include "aos/network/sctp_config_request_generated.h"
#include "aos/network/timestamp_generated.h"
#include "aos/unique_malloc_ptr.h"
#include "aos/util/file.h"
+DECLARE_bool(use_sctp_authentication);
+
// This application receives messages from another node and re-publishes them on
// this node.
//
@@ -30,6 +34,9 @@
namespace {
namespace chrono = std::chrono;
+// How often we should poll for the active SCTP authentication key.
+constexpr chrono::seconds kRefreshAuthKeyPeriod{3};
+
std::vector<int> StreamToChannel(const Configuration *config,
const Node *my_node, const Node *other_node) {
std::vector<int> stream_to_channel;
@@ -100,7 +107,8 @@
aos::ShmEventLoop *const event_loop, std::string_view remote_name,
const Node *my_node, std::string_view local_host,
std::vector<SctpClientChannelState> *channels, int client_index,
- MessageBridgeClientStatus *client_status, std::string_view config_sha256)
+ MessageBridgeClientStatus *client_status, std::string_view config_sha256,
+ SctpAuthMethod requested_authentication)
: event_loop_(event_loop),
connect_message_(MakeConnectMessage(event_loop->configuration(), my_node,
remote_name, event_loop->boot_uuid(),
@@ -111,7 +119,7 @@
client_(remote_node_->hostname()->string_view(), remote_node_->port(),
connect_message_.message().channels_to_transfer()->size() +
kControlStreams(),
- local_host, 0),
+ local_host, 0, requested_authentication),
channels_(channels),
stream_to_channel_(
StreamToChannel(event_loop->configuration(), my_node, remote_node_)),
@@ -363,13 +371,34 @@
<< " cumtsn=" << message->header.rcvinfo.rcv_cumtsn << ")";
}
-MessageBridgeClient::MessageBridgeClient(aos::ShmEventLoop *event_loop,
- std::string config_sha256)
+MessageBridgeClient::MessageBridgeClient(
+ aos::ShmEventLoop *event_loop, std::string config_sha256,
+ SctpAuthMethod requested_authentication)
: event_loop_(event_loop),
client_status_(event_loop_),
- config_sha256_(std::move(config_sha256)) {
+ config_sha256_(std::move(config_sha256)),
+ refresh_key_timer_(event_loop->AddTimer([this]() { RequestAuthKey(); })),
+ sctp_config_request_(event_loop_->MakeSender<SctpConfigRequest>("/aos")) {
std::string_view node_name = event_loop->node()->name()->string_view();
+ // Set up the SCTP configuration watcher and timer.
+ if (requested_authentication == SctpAuthMethod::kAuth && HasSctpAuth()) {
+ event_loop->MakeWatcher("/aos", [this](const SctpConfig &config) {
+ if (config.has_key()) {
+ for (auto &conn : connections_) {
+ conn->SetAuthKey(*config.key());
+ }
+ }
+ });
+
+ // We poll in case the SCTP authentication key has changed.
+ refresh_key_timer_->set_name("refresh_key");
+ event_loop_->OnRun([this]() {
+ refresh_key_timer_->Schedule(event_loop_->monotonic_now(),
+ kRefreshAuthKeyPeriod);
+ });
+ }
+
// Find all the channels which are supposed to be delivered to us.
channels_.resize(event_loop_->configuration()->channels()->size());
int channel_index = 0;
@@ -415,9 +444,16 @@
connections_.emplace_back(new SctpClientConnection(
event_loop, source_node, event_loop->node(), "", &channels_,
client_status_.FindClientIndex(source_node), &client_status_,
- config_sha256_));
+ config_sha256_, requested_authentication));
}
}
+void MessageBridgeClient::RequestAuthKey() {
+ // Build and send a SctpConfigRequest with request_key set to true.
+ aos::Sender<SctpConfigRequest>::Builder message =
+ sctp_config_request_.MakeBuilder();
+ SctpConfigRequest::Builder request = message.MakeBuilder<SctpConfigRequest>();
+ request.add_request_key(true);
+ message.CheckOk(message.Send(request.Finish()));
+}
+
} // namespace message_bridge
} // namespace aos
diff --git a/aos/network/message_bridge_client_lib.h b/aos/network/message_bridge_client_lib.h
index a5a69e3..e4b5e84 100644
--- a/aos/network/message_bridge_client_lib.h
+++ b/aos/network/message_bridge_client_lib.h
@@ -10,6 +10,7 @@
#include "aos/network/message_bridge_client_generated.h"
#include "aos/network/message_bridge_client_status.h"
#include "aos/network/sctp_client.h"
+#include "aos/network/sctp_config_request_generated.h"
#include "aos/network/sctp_lib.h"
namespace aos {
@@ -38,10 +39,15 @@
std::vector<SctpClientChannelState> *channels,
int client_index,
MessageBridgeClientStatus *client_status,
- std::string_view config_sha256);
+ std::string_view config_sha256,
+ SctpAuthMethod requested_authentication);
~SctpClientConnection() { event_loop_->epoll()->DeleteFd(client_.fd()); }
+ void SetAuthKey(absl::Span<const uint8_t> auth_key) {
+ client_.SetAuthKey(auth_key);
+ }
+
private:
// Reads a message from the socket. Could be a notification.
void MessageReceived();
@@ -102,11 +108,15 @@
// node.
class MessageBridgeClient {
public:
- MessageBridgeClient(aos::ShmEventLoop *event_loop, std::string config_sha256);
+ MessageBridgeClient(aos::ShmEventLoop *event_loop, std::string config_sha256,
+ SctpAuthMethod requested_authentication);
~MessageBridgeClient() {}
private:
+ // Sends a request for the currently active authentication key.
+ void RequestAuthKey();
+
// Event loop to schedule everything on.
aos::ShmEventLoop *event_loop_;
@@ -119,6 +129,12 @@
std::vector<std::unique_ptr<SctpClientConnection>> connections_;
std::string config_sha256_;
+
+ // We use this timer to poll the active authentication key.
+ aos::TimerHandler *refresh_key_timer_;
+
+ // Used to request the current sctp settings to be used.
+ aos::Sender<SctpConfigRequest> sctp_config_request_;
};
} // namespace message_bridge
diff --git a/aos/network/message_bridge_retry_test.cc b/aos/network/message_bridge_retry_test.cc
new file mode 100644
index 0000000..4110a76
--- /dev/null
+++ b/aos/network/message_bridge_retry_test.cc
@@ -0,0 +1,141 @@
+#include <chrono>
+#include <thread>
+
+#include "absl/strings/str_cat.h"
+#include "gtest/gtest.h"
+
+#include "aos/events/ping_generated.h"
+#include "aos/events/pong_generated.h"
+#include "aos/ipc_lib/event.h"
+#include "aos/network/message_bridge_client_lib.h"
+#include "aos/network/message_bridge_protocol.h"
+#include "aos/network/message_bridge_server_lib.h"
+#include "aos/network/message_bridge_test_lib.h"
+#include "aos/network/team_number.h"
+#include "aos/sha256.h"
+#include "aos/testing/path.h"
+#include "aos/util/file.h"
+
+DECLARE_int32(force_wmem_max);
+
+namespace aos {
+
+namespace message_bridge {
+namespace testing {
+
+void SendPing(aos::Sender<examples::Ping> *sender, int value) {
+ aos::Sender<examples::Ping>::Builder builder = sender->MakeBuilder();
+ // Artificially inflate message size by adding a bunch of padding.
+ builder.fbb()->CreateVector(std::vector<int>(1000, 0));
+ examples::Ping::Builder ping_builder = builder.MakeBuilder<examples::Ping>();
+ ping_builder.add_value(value);
+ builder.CheckOk(builder.Send(ping_builder.Finish()));
+}
+
+// Test that if we fill up the kernel buffers then the message bridge code does
+// indeed trigger (and succeed at triggering) its retry logic. Separated from
+// the normal message bridge tests because triggering this originally seemed
+// likely to be prone to extreme flakiness depending on the platform it is run
+// on. In practice, it actually seems to be *more* reliable than the normal
+// message_bridge_test, so we kept it separate.
+TEST_P(MessageBridgeParameterizedTest, ReliableRetries) {
+ // Set an absurdly small wmem max. This will help to trigger retries.
+ FLAGS_force_wmem_max = 1024;
+ OnPi1();
+
+ FLAGS_application_name = "sender";
+ aos::ShmEventLoop send_event_loop(&config.message());
+ aos::Sender<examples::Ping> ping_sender =
+ send_event_loop.MakeSender<examples::Ping>("/test");
+ SendPing(&ping_sender, 1);
+ aos::Fetcher<ServerStatistics> pi1_server_statistics_fetcher =
+ send_event_loop.MakeFetcher<ServerStatistics>("/aos");
+
+ MakePi1Server();
+ MakePi1Client();
+
+ // Now do it for "raspberrypi2", the client.
+ OnPi2();
+
+ MakePi2Server();
+
+ aos::ShmEventLoop receive_event_loop(&config.message());
+ aos::Fetcher<examples::Ping> ping_fetcher =
+ receive_event_loop.MakeFetcher<examples::Ping>("/test");
+ aos::Fetcher<ClientStatistics> pi2_client_statistics_fetcher =
+ receive_event_loop.MakeFetcher<ClientStatistics>("/pi2/aos");
+
+ // Before everything starts up, confirm there is no message.
+ EXPECT_FALSE(ping_fetcher.Fetch());
+
+ // Spin up the persistent pieces.
+ StartPi1Server();
+ StartPi1Client();
+ StartPi2Server();
+
+ {
+ constexpr size_t kNumPingMessages = 25;
+ // Now, spin up a client for 2 seconds.
+ MakePi2Client();
+ StartPi2Client();
+
+ std::this_thread::sleep_for(std::chrono::seconds(2));
+
+ for (size_t i = 0; i < kNumPingMessages; ++i) {
+ SendPing(&ping_sender, i);
+ }
+
+ // Give plenty of time for retries to succeed.
+ std::this_thread::sleep_for(std::chrono::seconds(5));
+
+ StopPi2Client();
+
+ // Confirm there is no detected duplicate packet.
+ EXPECT_TRUE(pi2_client_statistics_fetcher.Fetch());
+ EXPECT_GT(pi2_client_statistics_fetcher->connections()
+ ->Get(0)
+ ->received_packets(),
+ kNumPingMessages);
+ EXPECT_EQ(pi2_client_statistics_fetcher->connections()
+ ->Get(0)
+ ->duplicate_packets(),
+ 0u);
+
+ EXPECT_EQ(pi2_client_statistics_fetcher->connections()
+ ->Get(0)
+ ->partial_deliveries(),
+ 0u);
+
+ // Check that we received the reliable message that was sent before
+ // starting.
+ EXPECT_TRUE(ping_fetcher.FetchNext());
+ EXPECT_EQ(ping_fetcher->value(), 1);
+
+ // Check that we got all the messages sent while running.
+ for (size_t i = 0; i < kNumPingMessages; ++i) {
+ EXPECT_TRUE(ping_fetcher.FetchNext());
+ EXPECT_EQ(ping_fetcher->value(), i);
+ }
+
+ EXPECT_TRUE(pi1_server_statistics_fetcher.Fetch());
+ EXPECT_GT(
+ pi1_server_statistics_fetcher->connections()->Get(0)->sent_packets(),
+ kNumPingMessages);
+ EXPECT_GT(
+ pi1_server_statistics_fetcher->connections()->Get(0)->retry_count(), 0u)
+ << FlatbufferToJson(pi1_server_statistics_fetcher.get());
+ }
+
+ // Shut everyone else down.
+ StopPi1Client();
+ StopPi2Server();
+ StopPi1Server();
+}
+
+INSTANTIATE_TEST_SUITE_P(MessageBridgeTests, MessageBridgeParameterizedTest,
+ ::testing::Values(Param{
+ "message_bridge_test_common_config.json", false}));
+
+} // namespace testing
+} // namespace message_bridge
+} // namespace aos
diff --git a/aos/network/message_bridge_server.cc b/aos/network/message_bridge_server.cc
index 04b07c3..be6cc8e 100644
--- a/aos/network/message_bridge_server.cc
+++ b/aos/network/message_bridge_server.cc
@@ -5,14 +5,20 @@
#include "aos/init.h"
#include "aos/logging/dynamic_logging.h"
#include "aos/network/message_bridge_server_lib.h"
+#include "aos/network/sctp_lib.h"
#include "aos/sha256.h"
DEFINE_string(config, "aos_config.json", "Path to the config.");
DEFINE_int32(rt_priority, -1, "If > 0, run as this RT priority");
+DEFINE_bool(
+ wants_sctp_authentication, false,
+ "When set, try to use SCTP authentication if provided by the kernel");
namespace aos {
namespace message_bridge {
+using ::aos::util::ReadFileToVecOrDie;
+
int Main() {
aos::FlatbufferDetachedBuffer<aos::Configuration> config =
aos::configuration::ReadConfig(FLAGS_config);
@@ -22,7 +28,10 @@
event_loop.SetRuntimeRealtimePriority(FLAGS_rt_priority);
}
- MessageBridgeServer app(&event_loop, Sha256(config.span()));
+ MessageBridgeServer app(&event_loop, Sha256(config.span()),
+ FLAGS_wants_sctp_authentication
+ ? SctpAuthMethod::kAuth
+ : SctpAuthMethod::kNoAuth);
logging::DynamicLogging dynamic_logging(&event_loop);
diff --git a/aos/network/message_bridge_server.fbs b/aos/network/message_bridge_server.fbs
index e936828..b05d933 100644
--- a/aos/network/message_bridge_server.fbs
+++ b/aos/network/message_bridge_server.fbs
@@ -13,6 +13,8 @@
// Total number of messages that were dropped while sending (e.g.,
// those dropped by the kernel).
dropped_packets:uint (id: 2);
+ // Count of the total number of retries attempted on this channel.
+ retry_count:uint (id: 3);
}
// State of the connection.
@@ -62,6 +64,10 @@
// Statistics for every channel being forwarded to this node. Ordering is arbitrary;
// the channels are identified by an index in the ServerChannelStatistics.
channels:[ServerChannelStatistics] (id: 10);
+
+ // Total number of retries attempted on all channels. Typically due to kernel
+ // send buffers filling up.
+ retry_count:uint (id: 11);
}
// Statistics for all connections to all the clients.
diff --git a/aos/network/message_bridge_server_lib.cc b/aos/network/message_bridge_server_lib.cc
index 22b2f3f..9ac062e 100644
--- a/aos/network/message_bridge_server_lib.cc
+++ b/aos/network/message_bridge_server_lib.cc
@@ -13,13 +13,52 @@
#include "aos/network/message_bridge_server_generated.h"
#include "aos/network/remote_data_generated.h"
#include "aos/network/remote_message_generated.h"
+#include "aos/network/sctp_config_generated.h"
#include "aos/network/sctp_server.h"
#include "aos/network/timestamp_channel.h"
+// For retrying sends on reliable channels, we will do an additive backoff
+// strategy where we start at FLAGS_min_retry_period_ms and then add
+// FLAGS_retry_period_additive_backoff_ms every time the retry fails, up until
+// FLAGS_max_retry_period_ms. These numbers are somewhat arbitrarily chosen.
+// For the minimum retry period, choose 10ms as that is slow enough that the
+// kernel should have had time to clear its buffers, while being fast enough
+// that hopefully it is a relatively minor blip for anything that isn't
+// timing-critical (and timing-critical things that hit the retry logic are
+// probably in trouble).
+DEFINE_uint32(min_retry_period_ms, 10,
+ "Minimum retry timer period--the additive backoff will start at "
+ "this period, in milliseconds.");
+// Amount of backoff to add every time a retry fails. Chosen semi-arbitrarily;
+// 100ms is large enough that the backoff actually does increase at a reasonable
+// rate, while preventing the period from growing so fast that it can readily
+// take multiple seconds for a retry to occur.
+DEFINE_uint32(retry_period_additive_backoff_ms, 100,
+ "Amount of time to add to the retry period every time a retry "
+ "fails, in milliseconds.");
+// Max out retry period at 10 seconds---this is generally a much longer
+// timescale than anything normally happening on our systems, while still being
+// short enough that the retries will regularly happen (basically, the maximum
+// should be short enough that a human trying to debug issues with the system
+// will still see the retries regularly happening as they debug, rather than
+// having to wait minutes or hours for a retry to occur).
+DEFINE_uint32(max_retry_period_ms, 10000,
+ "Maximum retry timer period--the additive backoff will not "
+ "exceed this period, in milliseconds.");
+
+DEFINE_int32(force_wmem_max, -1,
+ "If set to a nonnegative number, the wmem buffer size to use, in "
+ "bytes. Intended solely for testing purposes.");
+
+DECLARE_bool(use_sctp_authentication);
+
namespace aos {
namespace message_bridge {
namespace chrono = std::chrono;
+// How often we should poll for the active SCTP authentication key.
+constexpr chrono::seconds kRefreshAuthKeyPeriod{3};
+
bool ChannelState::Matches(const Channel *other_channel) {
return channel_->name()->string_view() ==
other_channel->name()->string_view() &&
@@ -28,9 +67,9 @@
}
flatbuffers::FlatBufferBuilder ChannelState::PackContext(
- FixedAllocator *allocator, const Context &context) {
+ const Context &context) {
flatbuffers::FlatBufferBuilder fbb(
- channel_->max_size() + kHeaderSizeOverhead(), allocator);
+ channel_->max_size() + kHeaderSizeOverhead(), allocator_);
fbb.ForceDefaults(true);
VLOG(2) << "Found " << peers_.size() << " peers on channel "
<< channel_->name()->string_view() << " "
@@ -60,34 +99,137 @@
return fbb;
}
-void ChannelState::SendData(SctpServer *server, FixedAllocator *allocator,
- const Context &context) {
- // TODO(austin): I don't like allocating this buffer when we are just freeing
- // it at the end of the function.
- flatbuffers::FlatBufferBuilder fbb = PackContext(allocator, context);
+ChannelState::ChannelState(aos::EventLoop *event_loop, const Channel *channel,
+ int channel_index, SctpServer *server,
+ FixedAllocator *allocator)
+ : event_loop_(event_loop),
+ channel_index_(channel_index),
+ channel_(channel),
+ server_(server),
+ allocator_(allocator),
+ last_message_fetcher_(event_loop->MakeRawFetcher(channel)),
+ retry_timer_(event_loop->AddTimer([this]() { SendData(); })),
+ retry_period_(std::chrono::milliseconds(FLAGS_min_retry_period_ms)) {
+ retry_timer_->set_name(absl::StrFormat("retry%d", channel_index));
+}
- // TODO(austin): Track which connections need to be reliable and handle
- // resending properly.
+bool ChannelState::PeerReadyToFetchNext(const Peer &peer,
+ const Context &context) const {
+ // A disconnected peer is never waited on, and with no data on the Fetcher
+ // there is nothing that we could still be trying to send to this peer.
+ if (peer.sac_assoc_id == 0 || context.data == nullptr) {
+ return true;
+ }
+ // Otherwise the peer is ready only once the message currently on the Fetcher
+ // has been recorded as sent to it; an empty last_sent_index means nothing
+ // has been sent yet, so we can't be ready to fetch the next message.
+ return peer.last_sent_index.has_value() &&
+ peer.last_sent_index.value() == context.queue_index;
+}
+
+bool ChannelState::ReadyToFetchNext() const {
+ for (const Peer &peer : peers_) {
+ if (!PeerReadyToFetchNext(peer, last_message_fetcher_->context())) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool ChannelState::AnyNodeConnected() const {
+ for (const Peer &peer : peers_) {
+ if (peer.sac_assoc_id != 0) {
+ return true;
+ }
+ }
+ return false;
+}
+
+void ChannelState::SendData() {
+ // The goal of this logic is to make it so that we continually send out any
+ // message data available on the current channel, until we reach a point where
+ // either (a) we run out of messages to send or (b) sends start to fail.
+ do {
+ if (ReadyToFetchNext()) {
+ retry_period_ = std::chrono::milliseconds(FLAGS_min_retry_period_ms);
+ if (!last_message_fetcher_->FetchNext()) {
+ return;
+ }
+ }
+ } while (TrySendData(last_message_fetcher_->context()));
+}
+
+bool ChannelState::TrySendData(const Context &context) {
+ CHECK(context.data != nullptr)
+ << configuration::StrippedChannelToString(channel_);
+ // TODO(austin): I don't like allocating this buffer when we are just
+ // freeing it at the end of the function.
+ flatbuffers::FlatBufferBuilder fbb = PackContext(context);
+
size_t sent_count = 0;
bool logged_remotely = false;
+ bool retry_required = false;
+ VLOG(1) << "Send for " << configuration::StrippedChannelToString(channel_)
+ << " with " << context.queue_index << " and data " << context.data;
for (Peer &peer : peers_) {
+ if (PeerReadyToFetchNext(peer, context)) {
+ VLOG(1) << "Skipping send for "
+ << configuration::StrippedChannelToString(channel_) << " to "
+ << FlatbufferToJson(peer.connection) << " with queue index of "
+ << context.queue_index;
+ // Either:
+ // * We already sent on this connection; we do not need to do anything
+ // further.
+ // * The peer in question is not connected.
+ continue;
+ }
logged_remotely = logged_remotely || peer.logged_remotely;
+ const int time_to_live_ms = peer.connection->time_to_live() / 1000000;
+ CHECK((time_to_live_ms == 0) == (peer.connection->time_to_live() == 0))
+ << ": TTLs below 1ms are not supported, as they would get rounded "
+ "down "
+ "to zero, which is used to indicate a reliable connection.";
+
if (peer.sac_assoc_id != 0 &&
- server->Send(std::string_view(
- reinterpret_cast<const char *>(fbb.GetBufferPointer()),
- fbb.GetSize()),
- peer.sac_assoc_id, peer.stream,
- peer.connection->time_to_live() / 1000000)) {
+ server_->Send(std::string_view(reinterpret_cast<const char *>(
+ fbb.GetBufferPointer()),
+ fbb.GetSize()),
+ peer.sac_assoc_id, peer.stream, time_to_live_ms)) {
peer.server_status->AddSentPacket(peer.node_index, channel_);
if (peer.logged_remotely) {
++sent_count;
}
+ peer.last_sent_index = context.queue_index;
} else {
- peer.server_status->AddDroppedPacket(peer.node_index, channel_);
+ if (time_to_live_ms == 0) {
+ // Reliable connection that failed to send; set a retry timer.
+ retry_required = true;
+ peer.server_status->AddPacketRetry(peer.node_index, channel_);
+ } else {
+ // Unreliable connection that failed to send; losses
+ // are permitted, so just mark it as sent.
+ peer.server_status->AddDroppedPacket(peer.node_index, channel_);
+ peer.last_sent_index = context.queue_index;
+ }
}
}
+ if (retry_required) {
+ retry_timer_->Schedule(event_loop_->monotonic_now() + retry_period_);
+ retry_period_ = std::min(
+ retry_period_ +
+ std::chrono::milliseconds(FLAGS_retry_period_additive_backoff_ms),
+ std::chrono::milliseconds(FLAGS_max_retry_period_ms));
+ }
+
if (logged_remotely) {
if (sent_count == 0) {
VLOG(1)
@@ -110,6 +252,7 @@
// TODO(austin): ~10 MB chunks on disk and push them over the logging
// channel? Threadsafe disk backed queue object which can handle restarts
// and flushes. Whee.
+ return !retry_required;
}
void ChannelState::HandleDelivery(sctp_assoc_t rcv_assoc_id, uint16_t /*ssn*/,
@@ -181,13 +324,17 @@
}
int ChannelState::NodeDisconnected(sctp_assoc_t assoc_id) {
- VLOG(1) << "Disconnected " << assoc_id;
+ VLOG(1) << "Disconnected " << assoc_id << " for "
+ << configuration::StrippedChannelToString(channel_);
for (ChannelState::Peer &peer : peers_) {
if (peer.sac_assoc_id == assoc_id) {
// TODO(austin): This will not handle multiple clients from
// a single node. But that should be rare.
peer.sac_assoc_id = 0;
peer.stream = 0;
+ // We do not guarantee the consistent delivery of reliable channels
+ // through node disconnects.
+ peer.last_sent_index.reset();
return peer.node_index;
}
}
@@ -195,8 +342,7 @@
}
int ChannelState::NodeConnected(const Node *node, sctp_assoc_t assoc_id,
- int stream, SctpServer *server,
- FixedAllocator *allocator,
+ int stream,
aos::monotonic_clock::time_point monotonic_now,
std::vector<sctp_assoc_t> *reconnected) {
VLOG(1) << "Channel " << channel_->name()->string_view() << " "
@@ -226,60 +372,71 @@
<< " already connected on " << peer.sac_assoc_id
<< " aborting old connection and switching to " << assoc_id;
}
- server->Abort(peer.sac_assoc_id);
+ server_->Abort(peer.sac_assoc_id);
+ // sac_assoc_id will be set again before NodeConnected() exits, but we
+ // clear it here so that AnyNodeConnected() (or any other similar
+ // function calls) observe the node as having had an aborted
+ // connection.
+ peer.sac_assoc_id = 0;
}
}
+ // Clear the last sent index to force a retry of any reliable channels.
+ peer.last_sent_index.reset();
+
+ if (!AnyNodeConnected()) {
+ // If no one else is connected yet, reset the Fetcher.
+ last_message_fetcher_->Fetch();
+ retry_period_ = std::chrono::milliseconds(FLAGS_min_retry_period_ms);
+ }
+ // Unreliable channels aren't supposed to send out the latest fetched
+ // message.
+ if (peer.connection->time_to_live() != 0 &&
+ last_message_fetcher_->context().data != nullptr) {
+ peer.last_sent_index = last_message_fetcher_->context().queue_index;
+ }
peer.sac_assoc_id = assoc_id;
peer.stream = stream;
peer.server_status->Connect(peer.node_index, monotonic_now);
- server->SetStreamPriority(assoc_id, stream, peer.connection->priority());
- if (last_message_fetcher_ && peer.connection->time_to_live() == 0) {
- last_message_fetcher_->Fetch();
- VLOG(1) << "Got a connection on a reliable channel "
- << configuration::StrippedChannelToString(
- last_message_fetcher_->channel())
- << ", sending? "
- << (last_message_fetcher_->context().data != nullptr);
- if (last_message_fetcher_->context().data != nullptr) {
- // SendData sends to all... Only send to the new one.
- flatbuffers::FlatBufferBuilder fbb =
- PackContext(allocator, last_message_fetcher_->context());
-
- if (server->Send(std::string_view(reinterpret_cast<const char *>(
- fbb.GetBufferPointer()),
- fbb.GetSize()),
- peer.sac_assoc_id, peer.stream,
- peer.connection->time_to_live() / 1000000)) {
- peer.server_status->AddSentPacket(peer.node_index, channel_);
- } else {
- peer.server_status->AddDroppedPacket(peer.node_index, channel_);
- }
- }
- }
+ server_->SetStreamPriority(assoc_id, stream, peer.connection->priority());
+ SendData();
return peer.node_index;
}
}
return -1;
}
-MessageBridgeServer::MessageBridgeServer(aos::ShmEventLoop *event_loop,
- std::string config_sha256)
+MessageBridgeServer::MessageBridgeServer(
+ aos::ShmEventLoop *event_loop, std::string config_sha256,
+ SctpAuthMethod requested_authentication)
: event_loop_(event_loop),
timestamp_loggers_(event_loop_),
server_(max_channels() + kControlStreams(), "",
- event_loop->node()->port()),
- server_status_(event_loop,
- [this](const Context &context) {
- timestamp_state_->SendData(&server_, &allocator_,
- context);
- }),
+ event_loop->node()->port(), requested_authentication),
+ server_status_(event_loop, [this]() { timestamp_state_->SendData(); }),
config_sha256_(std::move(config_sha256)),
- allocator_(0) {
+ allocator_(0),
+ refresh_key_timer_(event_loop->AddTimer([this]() { RequestAuthKey(); })),
+ sctp_config_request_(event_loop_->MakeSender<SctpConfigRequest>("/aos")) {
CHECK_EQ(config_sha256_.size(), 64u) << ": Wrong length sha256sum";
CHECK(event_loop_->node() != nullptr) << ": No nodes configured.";
+ // Set up the SCTP configuration watcher and timer.
+ if (requested_authentication == SctpAuthMethod::kAuth && HasSctpAuth()) {
+ event_loop_->MakeWatcher("/aos", [this](const SctpConfig &config) {
+ if (config.has_key()) {
+ server_.SetAuthKey(*config.key());
+ }
+ });
+
+ // We poll in case the SCTP authentication key has changed.
+ refresh_key_timer_->set_name("refresh_key");
+ event_loop_->OnRun([this]() {
+ refresh_key_timer_->Schedule(event_loop_->monotonic_now(),
+ kRefreshAuthKeyPeriod);
+ });
+ }
// Start out with a decent size big enough to hold timestamps.
size_t max_size = 204;
@@ -333,12 +490,10 @@
if (configuration::ChannelIsForwardedFromNode(channel,
event_loop_->node())) {
- bool any_reliable = false;
for (const Connection *connection : *channel->destination_nodes()) {
if (connection->time_to_live() == 0) {
reliable_buffer_size +=
static_cast<size_t>(channel->max_size() + kHeaderSizeOverhead());
- any_reliable = true;
}
}
@@ -351,8 +506,7 @@
max_channel_buffer_size);
std::unique_ptr<ChannelState> state(new ChannelState{
- channel, channel_index,
- any_reliable ? event_loop_->MakeRawFetcher(channel) : nullptr});
+ event_loop, channel, channel_index, &server_, &allocator_});
for (const Connection *connection : *channel->destination_nodes()) {
const Node *other_node = configuration::GetNode(
@@ -380,11 +534,8 @@
if (channel != timestamp_channel) {
// Call SendData for every message.
ChannelState *state_ptr = state.get();
- event_loop_->MakeRawWatcher(
- channel, [this, state_ptr](const Context &context,
- const void * /*message*/) {
- state_ptr->SendData(&server_, &allocator_, context);
- });
+ event_loop_->MakeRawNoArgWatcher(
+ channel, [state_ptr](const Context &) { state_ptr->SendData(); });
} else {
for (const Connection *connection : *channel->destination_nodes()) {
CHECK_GE(connection->time_to_live(), 1000u);
@@ -394,8 +545,8 @@
}
channels_.emplace_back(std::move(state));
} else if (channel == timestamp_channel) {
- std::unique_ptr<ChannelState> state(
- new ChannelState{channel, channel_index, nullptr});
+ std::unique_ptr<ChannelState> state(new ChannelState{
+ event_loop_, channel, channel_index, &server_, &allocator_});
for (const Connection *connection : *channel->destination_nodes()) {
CHECK_GE(connection->time_to_live(), 1000u);
}
@@ -415,8 +566,12 @@
LOG(INFO) << "Reliable buffer size for all clients is "
<< reliable_buffer_size;
server_.SetMaxReadSize(max_size);
- server_.SetMaxWriteSize(
- std::max(max_channel_buffer_size, reliable_buffer_size));
+ if (FLAGS_force_wmem_max >= 0) {
+ server_.SetMaxWriteSize(FLAGS_force_wmem_max);
+ } else {
+ server_.SetMaxWriteSize(
+ std::max(max_channel_buffer_size, reliable_buffer_size));
+ }
// Since we are doing interleaving mode 1, we will see at most 1 message being
// delivered at a time for an association. That means, if a message is
@@ -624,8 +779,7 @@
if (channel_state->Matches(channel)) {
node_index = channel_state->NodeConnected(
connect->node(), message->header.rcvinfo.rcv_assoc_id,
- channel_index, &server_, &allocator_, monotonic_now,
- &reconnected_);
+ channel_index, monotonic_now, &reconnected_);
CHECK_NE(node_index, -1)
<< ": Failed to find node "
<< aos::FlatbufferToJson(connect->node()) << " for connection "
@@ -691,5 +845,12 @@
}
}
+void MessageBridgeServer::RequestAuthKey() {
+ auto sender = sctp_config_request_.MakeBuilder();
+ auto builder = sender.MakeBuilder<SctpConfigRequest>();
+ builder.add_request_key(true);
+ sender.CheckOk(sender.Send(builder.Finish()));
+}
+
} // namespace message_bridge
} // namespace aos
diff --git a/aos/network/message_bridge_server_lib.h b/aos/network/message_bridge_server_lib.h
index d74a847..b8377c6 100644
--- a/aos/network/message_bridge_server_lib.h
+++ b/aos/network/message_bridge_server_lib.h
@@ -15,6 +15,7 @@
#include "aos/network/message_bridge_server_status.h"
#include "aos/network/remote_data_generated.h"
#include "aos/network/remote_message_generated.h"
+#include "aos/network/sctp_config_request_generated.h"
#include "aos/network/sctp_server.h"
#include "aos/network/timestamp_channel.h"
#include "aos/network/timestamp_generated.h"
@@ -28,11 +29,9 @@
// new message from the event loop.
class ChannelState {
public:
- ChannelState(const Channel *channel, int channel_index,
- std::unique_ptr<aos::RawFetcher> last_message_fetcher)
- : channel_index_(channel_index),
- channel_(channel),
- last_message_fetcher_(std::move(last_message_fetcher)) {}
+ ChannelState(aos::EventLoop *event_loop, const Channel *channel,
+ int channel_index, SctpServer *server,
+ FixedAllocator *allocator);
// Class to encapsulate all the state per client on a channel. A client may
// be subscribed to multiple channels.
@@ -61,6 +60,12 @@
// If true, this message will be logged on a receiving node. We need to
// keep it around to log it locally if that fails.
bool logged_remotely = false;
+
+ // Last "successfully" sent message for this connection. For reliable
+ // connections, this being set to a value will indicate that the message was
+ // truly successfully sent. For unreliable connections, this will get set as
+ // soon as we've attempted to send it.
+ std::optional<size_t> last_sent_index = std::nullopt;
};
// Needs to be called when a node (might have) disconnected.
@@ -70,7 +75,6 @@
// reconnects.
int NodeDisconnected(sctp_assoc_t assoc_id);
int NodeConnected(const Node *node, sctp_assoc_t assoc_id, int stream,
- SctpServer *server, FixedAllocator *allocator,
aos::monotonic_clock::time_point monotonic_now,
std::vector<sctp_assoc_t> *reconnected);
@@ -84,13 +88,12 @@
// channel.
bool Matches(const Channel *other_channel);
- // Sends the data in context using the provided server.
- void SendData(SctpServer *server, FixedAllocator *allocator,
- const Context &context);
+ // Sends as much data on this channel as is possible using the internal
+ // fetcher.
+ void SendData();
// Packs a context into a size prefixed message header for transmission.
- flatbuffers::FlatBufferBuilder PackContext(FixedAllocator *allocator,
- const Context &context);
+ flatbuffers::FlatBufferBuilder PackContext(const Context &context);
// Handles reception of delivery times.
void HandleDelivery(sctp_assoc_t rcv_assoc_id, uint16_t ssn,
@@ -99,21 +102,81 @@
MessageBridgeServerStatus *server_status);
private:
+ // When sending a message, we must guarantee that reliable messages make it to
+ // their destinations. Unfortunately, we cannot purely rely on the kernel to
+ // provide this guarantee, as the internal send buffer can fill up, resulting
+ // in Send() calls failing. To guarantee that we end up sending reliable
+ // messages, we do the following:
+ // * For channels with no reliable connections, we send the message and do not
+ // retry if the kernel rejects it.
+ // * For channels with at least one reliable connection:
+ // * We will always attempt to retry failed sends on reliable connections
+ // (if a channel has mixed reliable/unreliable connections, the unreliable
+ // connections are not retried).
+ // * Until we have successfully sent message X on every single reliable
+ // connection, we will not progress to sending X+1 on *any* connection.
+ // This reduces the number of Fetchers that we must maintain for each
+ // channel.
+ // * If a given client node is not connected (or becomes disconnected), then
+ // it will be ignored and will not block the progression of sending of
+ // reliable messages to other nodes (connection state is tracked through
+ // Peer::sac_assoc_id).
+ // * Retries will be performed with an additive backoff up to a set
+ // maximum. The backoff duration resets once the retry succeeds.
+ // * If we fall so far behind that the Fetcher drops off the end of the
+ // queue, then we kill the message bridge.
+
+ // Returns false if a retry will be required for the message in question, and
+ // true if it was sent "successfully" (note that for unreliable messages, we
+ // may drop the message but still return true here).
+ bool TrySendData(const Context &context);
+
+ // Returns true if all of the peer connections are in a state where we are
+ // permitted to progress to sending the next message. As described above, this
+ // will never block on any unreliable connections, but will not return true
+ // until every reliable connection has successfully sent the currently fetched
+ // message.
+ bool ReadyToFetchNext() const;
+ // Returns true if the given peer can move to the next message (used by
+ // ReadyToFetchNext()).
+ bool PeerReadyToFetchNext(const Peer &peer, const Context &context) const;
+
+ bool AnyNodeConnected() const;
+
+ aos::EventLoop *const event_loop_;
const int channel_index_;
const Channel *const channel_;
+ SctpServer *server_;
+ FixedAllocator *allocator_;
+
std::vector<Peer> peers_;
- // A fetcher to use to send the last message when a node connects and is
- // reliable.
+ // A fetcher to use to send the message. For reliable channels this is
+ // used both on startup to fetch the latest message as well as to
+ // support retries of messages. For unreliable channels, we use the
+ // Fetcher to minimize the diff with the reliable codepath, but it
+ // provides no utility over just using a Watcher directly.
std::unique_ptr<aos::RawFetcher> last_message_fetcher_;
+ // For reliable channels, the timer to use to retry sends on said channel.
+ aos::TimerHandler *retry_timer_;
+ // Current retry period.
+ std::chrono::milliseconds retry_period_;
};
// This encapsulates the state required to talk to *all* the clients from this
// node. It handles the session and dispatches data to the ChannelState.
class MessageBridgeServer {
public:
- MessageBridgeServer(aos::ShmEventLoop *event_loop, std::string config_sha256);
+ MessageBridgeServer(aos::ShmEventLoop *event_loop, std::string config_sha256,
+ SctpAuthMethod requested_authentication);
+
+ // Delete copy/move constructors explicitly--we internally pass around
+ // pointers to internal state.
+ MessageBridgeServer(MessageBridgeServer &&) = delete;
+ MessageBridgeServer(const MessageBridgeServer &) = delete;
+ MessageBridgeServer &operator=(MessageBridgeServer &&) = delete;
+ MessageBridgeServer &operator=(const MessageBridgeServer &) = delete;
~MessageBridgeServer() { event_loop_->epoll()->DeleteFd(server_.fd()); }
@@ -141,6 +204,9 @@
return event_loop_->configuration()->channels()->size();
}
+ // Sends a request for the currently active authentication key.
+ void RequestAuthKey();
+
// Event loop to schedule everything on.
aos::ShmEventLoop *event_loop_;
@@ -163,6 +229,12 @@
std::vector<sctp_assoc_t> reconnected_;
FixedAllocator allocator_;
+
+ // We use this timer to poll the active authentication key.
+ aos::TimerHandler *refresh_key_timer_;
+
+ // Used to request the current sctp settings to be used.
+ aos::Sender<SctpConfigRequest> sctp_config_request_;
};
} // namespace message_bridge
diff --git a/aos/network/message_bridge_server_status.cc b/aos/network/message_bridge_server_status.cc
index 5607d4a..22270d9 100644
--- a/aos/network/message_bridge_server_status.cc
+++ b/aos/network/message_bridge_server_status.cc
@@ -49,6 +49,7 @@
connection_builder.add_node(node_offset);
connection_builder.add_state(State::DISCONNECTED);
connection_builder.add_dropped_packets(0);
+ connection_builder.add_retry_count(0);
connection_builder.add_sent_packets(0);
connection_builder.add_monotonic_offset(0);
connection_builder.add_partial_deliveries(0);
@@ -90,7 +91,7 @@
} // namespace
MessageBridgeServerStatus::MessageBridgeServerStatus(
- aos::EventLoop *event_loop, std::function<void(const Context &)> send_data)
+ aos::EventLoop *event_loop, std::function<void()> send_data)
: event_loop_(event_loop),
sender_(event_loop_->MakeSender<ServerStatistics>("/aos")),
statistics_(MakeServerStatistics(
@@ -137,6 +138,7 @@
configuration::ChannelIndex(event_loop_->configuration(), channel);
initial_statistics.sent_packets = 0;
initial_statistics.dropped_packets = 0;
+ initial_statistics.retry_count = 0;
channel_statistics[channel] = initial_statistics;
}
@@ -182,6 +184,7 @@
connection->mutate_sent_packets(connection->sent_packets() + 1);
node.channel_statistics[channel].sent_packets++;
}
+
void MessageBridgeServerStatus::AddDroppedPacket(int node_index,
const aos::Channel *channel) {
CHECK(nodes_[node_index].has_value());
@@ -191,6 +194,14 @@
node.channel_statistics[channel].dropped_packets++;
}
+void MessageBridgeServerStatus::AddPacketRetry(int node_index,
+ const aos::Channel *channel) {
+ NodeState &node = nodes_[node_index].value();
+ ServerConnection *connection = node.server_connection;
+ connection->mutate_retry_count(connection->retry_count() + 1);
+ node.channel_statistics[channel].retry_count++;
+}
+
void MessageBridgeServerStatus::SetBootUUID(int node_index,
const UUID &boot_uuid) {
nodes_[node_index].value().boot_uuid = boot_uuid;
@@ -281,6 +292,7 @@
server_connection_builder.add_dropped_packets(
connection->dropped_packets());
server_connection_builder.add_sent_packets(connection->sent_packets());
+ server_connection_builder.add_retry_count(connection->retry_count());
server_connection_builder.add_partial_deliveries(
PartialDeliveries(node_index));
server_connection_builder.add_channels(channels_offset);
@@ -465,18 +477,10 @@
timestamp_failure_counter_.Count(err);
// Reply only if we successfully sent the timestamp
if (err == RawSender::Error::kOk) {
- Context context;
- context.monotonic_event_time = timestamp_sender_.monotonic_sent_time();
- context.realtime_event_time = timestamp_sender_.realtime_sent_time();
- context.queue_index = timestamp_sender_.sent_queue_index();
- context.size = timestamp_copy.span().size();
- context.source_boot_uuid = event_loop_->boot_uuid();
- context.data = timestamp_copy.span().data();
-
// Since we are building up the timestamp to send here, we need to trigger
// the SendData call ourselves.
if (send_data_) {
- send_data_(context);
+ send_data_();
}
}
}
diff --git a/aos/network/message_bridge_server_status.h b/aos/network/message_bridge_server_status.h
index d6edb89..3945d57 100644
--- a/aos/network/message_bridge_server_status.h
+++ b/aos/network/message_bridge_server_status.h
@@ -51,9 +51,9 @@
uint32_t partial_deliveries = 0;
};
- MessageBridgeServerStatus(aos::EventLoop *event_loop,
- std::function<void(const Context &)> send_data =
- std::function<void(const Context &)>());
+ MessageBridgeServerStatus(
+ aos::EventLoop *event_loop,
+ std::function<void()> send_data = std::function<void()>());
MessageBridgeServerStatus(const MessageBridgeServerStatus &) = delete;
MessageBridgeServerStatus(MessageBridgeServerStatus &&) = delete;
@@ -61,7 +61,7 @@
delete;
MessageBridgeServerStatus &operator=(MessageBridgeServerStatus &&) = delete;
- void set_send_data(std::function<void(const Context &)> send_data) {
+ void set_send_data(std::function<void()> send_data) {
send_data_ = send_data;
}
@@ -97,6 +97,7 @@
// node_index must be a valid client node.
void AddSentPacket(int node_index, const aos::Channel *channel);
void AddDroppedPacket(int node_index, const aos::Channel *channel);
+ void AddPacketRetry(int node_index, const aos::Channel *channel);
// Returns the ServerConnection message which is updated by the server.
ServerConnection *FindServerConnection(std::string_view node_name);
@@ -148,7 +149,7 @@
aos::monotonic_clock::time_point last_statistics_send_time_ =
aos::monotonic_clock::min_time;
- std::function<void(const Context &)> send_data_;
+ std::function<void()> send_data_;
bool send_ = true;
diff --git a/aos/network/message_bridge_test.cc b/aos/network/message_bridge_test.cc
index 991d618..4eb5e8b 100644
--- a/aos/network/message_bridge_test.cc
+++ b/aos/network/message_bridge_test.cc
@@ -10,323 +10,21 @@
#include "aos/network/message_bridge_client_lib.h"
#include "aos/network/message_bridge_protocol.h"
#include "aos/network/message_bridge_server_lib.h"
+#include "aos/network/message_bridge_test_lib.h"
#include "aos/network/team_number.h"
#include "aos/sha256.h"
#include "aos/testing/path.h"
#include "aos/util/file.h"
-DECLARE_string(boot_uuid);
-
namespace aos {
-void SetShmBase(const std::string_view base);
namespace message_bridge {
namespace testing {
-using aos::testing::ArtifactPath;
-
-namespace chrono = std::chrono;
-
-std::string ShmBase(const std::string_view node) {
- const char *tmpdir_c_str = getenv("TEST_TMPDIR");
- if (tmpdir_c_str != nullptr) {
- return absl::StrCat(tmpdir_c_str, "/", node);
- } else {
- return absl::StrCat("/dev/shm/", node);
- }
-}
-
-void DoSetShmBase(const std::string_view node) {
- aos::SetShmBase(ShmBase(node));
-}
-
-// Class to manage starting and stopping a thread with an event loop in it. The
-// thread is guarenteed to be running before the constructor exits.
-class ThreadedEventLoopRunner {
- public:
- ThreadedEventLoopRunner(aos::ShmEventLoop *event_loop)
- : event_loop_(event_loop), my_thread_([this]() {
- LOG(INFO) << "Started " << event_loop_->name();
- event_loop_->OnRun([this]() { event_.Set(); });
- event_loop_->Run();
- }) {
- event_.Wait();
- }
-
- ~ThreadedEventLoopRunner() { Exit(); }
-
- void Exit() {
- if (my_thread_.joinable()) {
- event_loop_->Exit();
- my_thread_.join();
- my_thread_ = std::thread();
- }
- }
-
- private:
- aos::Event event_;
- aos::ShmEventLoop *event_loop_;
- std::thread my_thread_;
-};
-
-// Parameters to run all the tests with.
-struct Param {
- // The config file to use.
- std::string config;
- // If true, the RemoteMessage channel should be shared between all the remote
- // channels. If false, there will be 1 RemoteMessage channel per remote
- // channel.
- bool shared;
-};
-
-class MessageBridgeParameterizedTest
- : public ::testing::TestWithParam<struct Param> {
- public:
- MessageBridgeParameterizedTest()
- : config(aos::configuration::ReadConfig(
- ArtifactPath(absl::StrCat("aos/network/", GetParam().config)))),
- config_sha256(Sha256(config.span())),
- pi1_boot_uuid_(UUID::Random()),
- pi2_boot_uuid_(UUID::Random()) {
- util::UnlinkRecursive(ShmBase("pi1"));
- util::UnlinkRecursive(ShmBase("pi2"));
- }
-
- bool shared() const { return GetParam().shared; }
-
- void OnPi1() {
- DoSetShmBase("pi1");
- FLAGS_override_hostname = "raspberrypi";
- FLAGS_boot_uuid = pi1_boot_uuid_.ToString();
- }
-
- void OnPi2() {
- DoSetShmBase("pi2");
- FLAGS_override_hostname = "raspberrypi2";
- FLAGS_boot_uuid = pi2_boot_uuid_.ToString();
- }
-
- void MakePi1Server(std::string server_config_sha256 = "") {
- OnPi1();
- FLAGS_application_name = "pi1_message_bridge_server";
- pi1_server_event_loop =
- std::make_unique<aos::ShmEventLoop>(&config.message());
- pi1_server_event_loop->SetRuntimeRealtimePriority(1);
- pi1_message_bridge_server = std::make_unique<MessageBridgeServer>(
- pi1_server_event_loop.get(), server_config_sha256.size() == 0
- ? config_sha256
- : server_config_sha256);
- }
-
- void RunPi1Server(chrono::nanoseconds duration) {
- // Set up a shutdown callback.
- aos::TimerHandler *const quit = pi1_server_event_loop->AddTimer(
- [this]() { pi1_server_event_loop->Exit(); });
- pi1_server_event_loop->OnRun([this, quit, duration]() {
- // Stop between timestamps, not exactly on them.
- quit->Schedule(pi1_server_event_loop->monotonic_now() + duration);
- });
-
- pi1_server_event_loop->Run();
- }
-
- void StartPi1Server() {
- pi1_server_thread =
- std::make_unique<ThreadedEventLoopRunner>(pi1_server_event_loop.get());
- }
-
- void StopPi1Server() {
- pi1_server_thread.reset();
- pi1_message_bridge_server.reset();
- pi1_server_event_loop.reset();
- }
-
- void MakePi1Client() {
- OnPi1();
- FLAGS_application_name = "pi1_message_bridge_client";
- pi1_client_event_loop =
- std::make_unique<aos::ShmEventLoop>(&config.message());
- pi1_client_event_loop->SetRuntimeRealtimePriority(1);
- pi1_message_bridge_client = std::make_unique<MessageBridgeClient>(
- pi1_client_event_loop.get(), config_sha256);
- }
-
- void StartPi1Client() {
- pi1_client_thread =
- std::make_unique<ThreadedEventLoopRunner>(pi1_client_event_loop.get());
- }
-
- void StopPi1Client() {
- pi1_client_thread.reset();
- pi1_message_bridge_client.reset();
- pi1_client_event_loop.reset();
- }
-
- void MakePi1Test() {
- OnPi1();
- FLAGS_application_name = "test1";
- pi1_test_event_loop =
- std::make_unique<aos::ShmEventLoop>(&config.message());
-
- pi1_test_event_loop->MakeWatcher(
- "/pi1/aos", [](const ServerStatistics &stats) {
- VLOG(1) << "/pi1/aos ServerStatistics " << FlatbufferToJson(&stats);
- });
-
- pi1_test_event_loop->MakeWatcher(
- "/pi1/aos", [](const ClientStatistics &stats) {
- VLOG(1) << "/pi1/aos ClientStatistics " << FlatbufferToJson(&stats);
- });
-
- pi1_test_event_loop->MakeWatcher(
- "/pi1/aos", [](const Timestamp ×tamp) {
- VLOG(1) << "/pi1/aos Timestamp " << FlatbufferToJson(×tamp);
- });
- pi1_test_event_loop->MakeWatcher(
- "/pi2/aos", [this](const Timestamp ×tamp) {
- VLOG(1) << "/pi2/aos Timestamp " << FlatbufferToJson(×tamp);
- EXPECT_EQ(pi1_test_event_loop->context().source_boot_uuid,
- pi2_boot_uuid_);
- });
- }
-
- void StartPi1Test() {
- pi1_test_thread =
- std::make_unique<ThreadedEventLoopRunner>(pi1_test_event_loop.get());
- }
-
- void StopPi1Test() { pi1_test_thread.reset(); }
-
- void MakePi2Server() {
- OnPi2();
- FLAGS_application_name = "pi2_message_bridge_server";
- pi2_server_event_loop =
- std::make_unique<aos::ShmEventLoop>(&config.message());
- pi2_server_event_loop->SetRuntimeRealtimePriority(1);
- pi2_message_bridge_server = std::make_unique<MessageBridgeServer>(
- pi2_server_event_loop.get(), config_sha256);
- }
-
- void RunPi2Server(chrono::nanoseconds duration) {
- // Set up a shutdown callback.
- aos::TimerHandler *const quit = pi2_server_event_loop->AddTimer(
- [this]() { pi2_server_event_loop->Exit(); });
- pi2_server_event_loop->OnRun([this, quit, duration]() {
- // Stop between timestamps, not exactly on them.
- quit->Schedule(pi2_server_event_loop->monotonic_now() + duration);
- });
-
- pi2_server_event_loop->Run();
- }
-
- void StartPi2Server() {
- pi2_server_thread =
- std::make_unique<ThreadedEventLoopRunner>(pi2_server_event_loop.get());
- }
-
- void StopPi2Server() {
- pi2_server_thread.reset();
- pi2_message_bridge_server.reset();
- pi2_server_event_loop.reset();
- }
-
- void MakePi2Client() {
- OnPi2();
- FLAGS_application_name = "pi2_message_bridge_client";
- pi2_client_event_loop =
- std::make_unique<aos::ShmEventLoop>(&config.message());
- pi2_client_event_loop->SetRuntimeRealtimePriority(1);
- pi2_message_bridge_client = std::make_unique<MessageBridgeClient>(
- pi2_client_event_loop.get(), config_sha256);
- }
-
- void RunPi2Client(chrono::nanoseconds duration) {
- // Run for 5 seconds to make sure we have time to estimate the offset.
- aos::TimerHandler *const quit = pi2_client_event_loop->AddTimer(
- [this]() { pi2_client_event_loop->Exit(); });
- pi2_client_event_loop->OnRun([this, quit, duration]() {
- // Stop between timestamps, not exactly on them.
- quit->Schedule(pi2_client_event_loop->monotonic_now() + duration);
- });
-
- // And go!
- pi2_client_event_loop->Run();
- }
-
- void StartPi2Client() {
- pi2_client_thread =
- std::make_unique<ThreadedEventLoopRunner>(pi2_client_event_loop.get());
- }
-
- void StopPi2Client() {
- pi2_client_thread.reset();
- pi2_message_bridge_client.reset();
- pi2_client_event_loop.reset();
- }
-
- void MakePi2Test() {
- OnPi2();
- FLAGS_application_name = "test2";
- pi2_test_event_loop =
- std::make_unique<aos::ShmEventLoop>(&config.message());
-
- pi2_test_event_loop->MakeWatcher(
- "/pi2/aos", [](const ServerStatistics &stats) {
- VLOG(1) << "/pi2/aos ServerStatistics " << FlatbufferToJson(&stats);
- });
-
- pi2_test_event_loop->MakeWatcher(
- "/pi2/aos", [](const ClientStatistics &stats) {
- VLOG(1) << "/pi2/aos ClientStatistics " << FlatbufferToJson(&stats);
- });
-
- pi2_test_event_loop->MakeWatcher(
- "/pi1/aos", [this](const Timestamp ×tamp) {
- VLOG(1) << "/pi1/aos Timestamp " << FlatbufferToJson(×tamp);
- EXPECT_EQ(pi2_test_event_loop->context().source_boot_uuid,
- pi1_boot_uuid_);
- });
- pi2_test_event_loop->MakeWatcher(
- "/pi2/aos", [](const Timestamp ×tamp) {
- VLOG(1) << "/pi2/aos Timestamp " << FlatbufferToJson(×tamp);
- });
- }
-
- void StartPi2Test() {
- pi2_test_thread =
- std::make_unique<ThreadedEventLoopRunner>(pi2_test_event_loop.get());
- }
-
- void StopPi2Test() { pi2_test_thread.reset(); }
-
- aos::FlatbufferDetachedBuffer<aos::Configuration> config;
- std::string config_sha256;
-
- const UUID pi1_boot_uuid_;
- const UUID pi2_boot_uuid_;
-
- std::unique_ptr<aos::ShmEventLoop> pi1_server_event_loop;
- std::unique_ptr<MessageBridgeServer> pi1_message_bridge_server;
- std::unique_ptr<ThreadedEventLoopRunner> pi1_server_thread;
-
- std::unique_ptr<aos::ShmEventLoop> pi1_client_event_loop;
- std::unique_ptr<MessageBridgeClient> pi1_message_bridge_client;
- std::unique_ptr<ThreadedEventLoopRunner> pi1_client_thread;
-
- std::unique_ptr<aos::ShmEventLoop> pi1_test_event_loop;
- std::unique_ptr<ThreadedEventLoopRunner> pi1_test_thread;
-
- std::unique_ptr<aos::ShmEventLoop> pi2_server_event_loop;
- std::unique_ptr<MessageBridgeServer> pi2_message_bridge_server;
- std::unique_ptr<ThreadedEventLoopRunner> pi2_server_thread;
-
- std::unique_ptr<aos::ShmEventLoop> pi2_client_event_loop;
- std::unique_ptr<MessageBridgeClient> pi2_message_bridge_client;
- std::unique_ptr<ThreadedEventLoopRunner> pi2_client_thread;
-
- std::unique_ptr<aos::ShmEventLoop> pi2_test_event_loop;
- std::unique_ptr<ThreadedEventLoopRunner> pi2_test_thread;
-};
+// Note: All of these tests spin up ShmEventLoops in separate threads to allow
+// us to run the "real" message bridge. This requires extra threading and timing
+// coordination to make happen, which is the reason for some of the extra
+// complexity in these tests.
// Test that we can send a ping message over sctp and receive it.
TEST_P(MessageBridgeParameterizedTest, PingPong) {
@@ -913,7 +611,7 @@
StopPi2Client();
}
- // Shut everyone else down
+ // Shut everyone else down.
StopPi1Server();
StopPi1Client();
StopPi2Server();
@@ -1073,7 +771,7 @@
StopPi2Server();
}
- // Shut everyone else down
+ // Shut everyone else down.
StopPi1Server();
StopPi1Client();
StopPi2Client();
@@ -1129,6 +827,8 @@
const size_t ping_channel_index = configuration::ChannelIndex(
receive_event_loop.configuration(), ping_fetcher.channel());
+ // ping_timestamp_count is accessed from multiple threads (the Watcher that
+ // triggers it is in a separate thread), so make it atomic.
std::atomic<int> ping_timestamp_count{0};
const std::string channel_name =
shared() ? "/pi1/aos/remote_timestamps/pi2"
@@ -1150,7 +850,7 @@
EXPECT_FALSE(ping_fetcher.Fetch());
EXPECT_FALSE(unreliable_ping_fetcher.Fetch());
- // Spin up the persistant pieces.
+ // Spin up the persistent pieces.
StartPi1Server();
StartPi1Client();
StartPi2Server();
@@ -1161,7 +861,7 @@
&pi1_remote_timestamp_event_loop);
{
- // Now, spin up a client for 2 seconds.
+ // Now spin up a client for 2 seconds.
MakePi2Client();
RunPi2Client(chrono::milliseconds(2050));
@@ -1210,7 +910,7 @@
StopPi2Client();
}
- // Shut everyone else down
+ // Shut everyone else down.
StopPi1Client();
StopPi2Server();
pi1_remote_timestamp_thread.reset();
@@ -1258,6 +958,8 @@
const size_t ping_channel_index = configuration::ChannelIndex(
receive_event_loop.configuration(), ping_fetcher.channel());
+ // ping_timestamp_count is accessed from multiple threads (the Watcher that
+ // triggers it is in a separate thread), so make it atomic.
std::atomic<int> ping_timestamp_count{0};
const std::string channel_name =
shared() ? "/pi1/aos/remote_timestamps/pi2"
@@ -1279,7 +981,7 @@
EXPECT_FALSE(ping_fetcher.Fetch());
EXPECT_FALSE(unreliable_ping_fetcher.Fetch());
- // Spin up the persistant pieces.
+ // Spin up the persistent pieces.
StartPi1Client();
StartPi2Server();
StartPi2Client();
@@ -1339,13 +1041,144 @@
StopPi1Server();
}
- // Shut everyone else down
+ // Shut everyone else down.
StopPi1Client();
StopPi2Server();
StopPi2Client();
pi1_remote_timestamp_thread.reset();
}
+// Tests that when multiple reliable messages are sent while the client is
+// restarting, only the final one of those messages makes it to the client.
+// This ensures that we handle a disconnecting & reconnecting client correctly
+// in the server reliable connection retry logic.
+TEST_P(MessageBridgeParameterizedTest, ReliableSentDuringClientReboot) {
+ OnPi1();
+
+ FLAGS_application_name = "sender";
+ aos::ShmEventLoop send_event_loop(&config.message());
+ aos::Sender<examples::Ping> ping_sender =
+ send_event_loop.MakeSender<examples::Ping>("/test");
+ size_t ping_index = 0;
+ SendPing(&ping_sender, ++ping_index);
+
+ MakePi1Server();
+ MakePi1Client();
+
+ FLAGS_application_name = "pi1_timestamp";
+ aos::ShmEventLoop pi1_remote_timestamp_event_loop(&config.message());
+
+ // Now do it for "raspberrypi2", the client.
+ OnPi2();
+
+ MakePi2Server();
+
+ aos::ShmEventLoop receive_event_loop(&config.message());
+ aos::Fetcher<examples::Ping> ping_fetcher =
+ receive_event_loop.MakeFetcher<examples::Ping>("/test");
+ aos::Fetcher<ClientStatistics> pi2_client_statistics_fetcher =
+ receive_event_loop.MakeFetcher<ClientStatistics>("/pi2/aos");
+
+ const size_t ping_channel_index = configuration::ChannelIndex(
+ receive_event_loop.configuration(), ping_fetcher.channel());
+
+ // ping_timestamp_count is accessed from multiple threads (the Watcher that
+ // triggers it is in a separate thread), so make it atomic.
+ std::atomic<int> ping_timestamp_count{0};
+ const std::string channel_name =
+ shared() ? "/pi1/aos/remote_timestamps/pi2"
+ : "/pi1/aos/remote_timestamps/pi2/test/aos-examples-Ping";
+ pi1_remote_timestamp_event_loop.MakeWatcher(
+ channel_name, [this, channel_name, ping_channel_index,
+ &ping_timestamp_count](const RemoteMessage &header) {
+ VLOG(1) << channel_name << " RemoteMessage "
+ << aos::FlatbufferToJson(&header);
+ EXPECT_TRUE(header.has_boot_uuid());
+ if (shared() && header.channel_index() != ping_channel_index) {
+ return;
+ }
+ CHECK_EQ(header.channel_index(), ping_channel_index);
+ ++ping_timestamp_count;
+ });
+
+ // Before everything starts up, confirm there is no message.
+ EXPECT_FALSE(ping_fetcher.Fetch());
+
+ // Spin up the persistent pieces.
+ StartPi1Server();
+ StartPi1Client();
+ StartPi2Server();
+
+ // Runner that starts the timestamp counting thread and waits for it to start.
+ std::unique_ptr<ThreadedEventLoopRunner> pi1_remote_timestamp_thread =
+ std::make_unique<ThreadedEventLoopRunner>(
+ &pi1_remote_timestamp_event_loop);
+
+ {
+ // Now, spin up a client for 2 seconds.
+ MakePi2Client();
+
+ RunPi2Client(chrono::milliseconds(2050));
+
+ // Confirm there is no detected duplicate packet.
+ EXPECT_TRUE(pi2_client_statistics_fetcher.Fetch());
+ EXPECT_EQ(pi2_client_statistics_fetcher->connections()
+ ->Get(0)
+ ->duplicate_packets(),
+ 0u);
+
+ EXPECT_EQ(pi2_client_statistics_fetcher->connections()
+ ->Get(0)
+ ->partial_deliveries(),
+ 0u);
+
+ EXPECT_TRUE(ping_fetcher.Fetch());
+ EXPECT_EQ(ping_timestamp_count, 1);
+
+ StopPi2Client();
+ }
+
+ // Send some reliable messages while the client is dead. Only the final one
+ // should make it through.
+ while (ping_index < 10) {
+ SendPing(&ping_sender, ++ping_index);
+ }
+
+ {
+ // Now, spin up a client for 5 seconds.
+ MakePi2Client();
+
+ RunPi2Client(chrono::milliseconds(5050));
+
+ // No duplicate packets should have appeared.
+ EXPECT_TRUE(pi2_client_statistics_fetcher.Fetch());
+ EXPECT_EQ(pi2_client_statistics_fetcher->connections()
+ ->Get(0)
+ ->duplicate_packets(),
+ 0u);
+
+ EXPECT_EQ(pi2_client_statistics_fetcher->connections()
+ ->Get(0)
+ ->partial_deliveries(),
+ 0u);
+
+ EXPECT_EQ(ping_timestamp_count, 2);
+ // We should have gotten precisely one more ping message--the latest one
+ // sent should've made it, but no previous ones.
+ EXPECT_TRUE(ping_fetcher.FetchNext());
+ EXPECT_EQ(ping_index, ping_fetcher->value());
+ EXPECT_FALSE(ping_fetcher.FetchNext());
+
+ StopPi2Client();
+ }
+
+ // Shut everyone else down.
+ StopPi1Client();
+ StopPi2Server();
+ pi1_remote_timestamp_thread.reset();
+ StopPi1Server();
+}
+
// Test that differing config sha256's result in no connection.
TEST_P(MessageBridgeParameterizedTest, MismatchedSha256) {
// This is rather annoying to set up. We need to start up a client and
@@ -1441,7 +1274,7 @@
StopPi2Client();
}
- // Shut everyone else down
+ // Shut everyone else down.
StopPi1Server();
StopPi1Client();
StopPi2Server();
@@ -1590,7 +1423,7 @@
StopPi2Client();
}
- // Shut everyone else down
+ // Shut everyone else down.
StopPi1Server();
StopPi1Client();
StopPi2Server();
diff --git a/aos/network/message_bridge_test_combined_timestamps_common.json b/aos/network/message_bridge_test_combined_timestamps_common.json
index 13a0514..a99c37c 100644
--- a/aos/network/message_bridge_test_combined_timestamps_common.json
+++ b/aos/network/message_bridge_test_combined_timestamps_common.json
@@ -100,6 +100,38 @@
"max_size": 2048
},
{
+ "name": "/pi1/aos",
+ "type": "aos.message_bridge.SctpConfig",
+ "source_node": "pi1",
+ "frequency": 10,
+ "num_senders": 1,
+ "max_size": 256
+ },
+ {
+ "name": "/pi2/aos",
+ "type": "aos.message_bridge.SctpConfig",
+ "source_node": "pi2",
+ "frequency": 10,
+ "num_senders": 1,
+ "max_size": 256
+ },
+ {
+ "name": "/pi1/aos",
+ "type": "aos.message_bridge.SctpConfigRequest",
+ "source_node": "pi1",
+ "frequency": 1,
+ "num_senders": 2,
+ "max_size": 32
+ },
+ {
+ "name": "/pi2/aos",
+ "type": "aos.message_bridge.SctpConfigRequest",
+ "source_node": "pi2",
+ "frequency": 1,
+ "num_senders": 2,
+ "max_size": 32
+ },
+ {
"name": "/test",
"type": "aos.examples.Ping",
"source_node": "pi1",
diff --git a/aos/network/message_bridge_test_common.json b/aos/network/message_bridge_test_common.json
index 9bb0863..09d19c8 100644
--- a/aos/network/message_bridge_test_common.json
+++ b/aos/network/message_bridge_test_common.json
@@ -121,6 +121,38 @@
"max_size": 2048
},
{
+ "name": "/pi1/aos",
+ "type": "aos.message_bridge.SctpConfig",
+ "source_node": "pi1",
+ "frequency": 10,
+ "num_senders": 1,
+ "max_size": 256
+ },
+ {
+ "name": "/pi2/aos",
+ "type": "aos.message_bridge.SctpConfig",
+ "source_node": "pi2",
+ "frequency": 10,
+ "num_senders": 1,
+ "max_size": 256
+ },
+ {
+ "name": "/pi1/aos",
+ "type": "aos.message_bridge.SctpConfigRequest",
+ "source_node": "pi1",
+ "frequency": 1,
+ "num_senders": 2,
+ "max_size": 32
+ },
+ {
+ "name": "/pi2/aos",
+ "type": "aos.message_bridge.SctpConfigRequest",
+ "source_node": "pi2",
+ "frequency": 1,
+ "num_senders": 2,
+ "max_size": 32
+ },
+ {
"name": "/test",
"type": "aos.examples.Ping",
"source_node": "pi1",
diff --git a/aos/network/message_bridge_test_lib.cc b/aos/network/message_bridge_test_lib.cc
new file mode 100644
index 0000000..d5a5f8f
--- /dev/null
+++ b/aos/network/message_bridge_test_lib.cc
@@ -0,0 +1,266 @@
+#include "aos/network/message_bridge_test_lib.h"
+
+DECLARE_string(boot_uuid);
+
+namespace aos {
+void SetShmBase(const std::string_view base);
+
+namespace message_bridge::testing {
+
+namespace chrono = std::chrono;
+using aos::testing::ArtifactPath;
+
+std::string ShmBase(const std::string_view node) {
+ const char *const tmpdir_c_str = getenv("TEST_TMPDIR");
+ if (tmpdir_c_str != nullptr) {
+ return absl::StrCat(tmpdir_c_str, "/", node);
+ } else {
+ return absl::StrCat("/dev/shm/", node);
+ }
+}
+
+void DoSetShmBase(const std::string_view node) {
+ aos::SetShmBase(ShmBase(node));
+}
+
+ThreadedEventLoopRunner::ThreadedEventLoopRunner(aos::ShmEventLoop *event_loop)
+ : event_loop_(event_loop), my_thread_([this]() {
+ LOG(INFO) << "Started " << event_loop_->name();
+ event_loop_->OnRun([this]() { event_.Set(); });
+ event_loop_->Run();
+ }) {
+ event_.Wait();
+}
+
+ThreadedEventLoopRunner::~ThreadedEventLoopRunner() { Exit(); }
+
+void ThreadedEventLoopRunner::Exit() {
+ if (my_thread_.joinable()) {
+ event_loop_->Exit();
+ my_thread_.join();
+ my_thread_ = std::thread();
+ }
+}
+
+MessageBridgeParameterizedTest::MessageBridgeParameterizedTest()
+ : config(aos::configuration::ReadConfig(
+ ArtifactPath(absl::StrCat("aos/network/", GetParam().config)))),
+ config_sha256(Sha256(config.span())),
+ pi1_boot_uuid_(UUID::Random()),
+ pi2_boot_uuid_(UUID::Random()) {
+ // Make sure that we clean up all the shared memory queues so that we cannot
+ // inadvertently be influenced by other tests or by previously run AOS
+ // applications (in a fully sharded test running inside the bazel sandbox,
+ // this should not matter).
+ util::UnlinkRecursive(ShmBase("pi1"));
+ util::UnlinkRecursive(ShmBase("pi2"));
+}
+
+bool MessageBridgeParameterizedTest::shared() const {
+ return GetParam().shared;
+}
+
+void MessageBridgeParameterizedTest::OnPi1() {
+ DoSetShmBase("pi1");
+ FLAGS_override_hostname = "raspberrypi";
+ FLAGS_boot_uuid = pi1_boot_uuid_.ToString();
+}
+
+void MessageBridgeParameterizedTest::OnPi2() {
+ DoSetShmBase("pi2");
+ FLAGS_override_hostname = "raspberrypi2";
+ FLAGS_boot_uuid = pi2_boot_uuid_.ToString();
+}
+
+void MessageBridgeParameterizedTest::MakePi1Server(
+ std::string server_config_sha256) {
+ OnPi1();
+ FLAGS_application_name = "pi1_message_bridge_server";
+ pi1_server_event_loop =
+ std::make_unique<aos::ShmEventLoop>(&config.message());
+ pi1_server_event_loop->SetRuntimeRealtimePriority(1);
+ pi1_message_bridge_server = std::make_unique<MessageBridgeServer>(
+ pi1_server_event_loop.get(),
+ server_config_sha256.size() == 0 ? config_sha256 : server_config_sha256,
+ SctpAuthMethod::kNoAuth);
+}
+
+void MessageBridgeParameterizedTest::RunPi1Server(
+ chrono::nanoseconds duration) {
+ // Set up a shutdown callback.
+ aos::TimerHandler *const quit = pi1_server_event_loop->AddTimer(
+ [this]() { pi1_server_event_loop->Exit(); });
+ pi1_server_event_loop->OnRun([this, quit, duration]() {
+ // Stop between timestamps, not exactly on them.
+ quit->Schedule(pi1_server_event_loop->monotonic_now() + duration);
+ });
+
+ pi1_server_event_loop->Run();
+}
+
+void MessageBridgeParameterizedTest::StartPi1Server() {
+ pi1_server_thread =
+ std::make_unique<ThreadedEventLoopRunner>(pi1_server_event_loop.get());
+}
+
+void MessageBridgeParameterizedTest::StopPi1Server() {
+ pi1_server_thread.reset();
+ pi1_message_bridge_server.reset();
+ pi1_server_event_loop.reset();
+}
+
+void MessageBridgeParameterizedTest::MakePi1Client() {
+ OnPi1();
+ FLAGS_application_name = "pi1_message_bridge_client";
+ pi1_client_event_loop =
+ std::make_unique<aos::ShmEventLoop>(&config.message());
+ pi1_client_event_loop->SetRuntimeRealtimePriority(1);
+ pi1_message_bridge_client = std::make_unique<MessageBridgeClient>(
+ pi1_client_event_loop.get(), config_sha256, SctpAuthMethod::kNoAuth);
+}
+
+void MessageBridgeParameterizedTest::StartPi1Client() {
+ pi1_client_thread =
+ std::make_unique<ThreadedEventLoopRunner>(pi1_client_event_loop.get());
+}
+
+void MessageBridgeParameterizedTest::StopPi1Client() {
+ pi1_client_thread.reset();
+ pi1_message_bridge_client.reset();
+ pi1_client_event_loop.reset();
+}
+
+void MessageBridgeParameterizedTest::MakePi1Test() {
+ OnPi1();
+ FLAGS_application_name = "test1";
+ pi1_test_event_loop = std::make_unique<aos::ShmEventLoop>(&config.message());
+
+ pi1_test_event_loop->MakeWatcher(
+ "/pi1/aos", [](const ServerStatistics &stats) {
+ VLOG(1) << "/pi1/aos ServerStatistics " << FlatbufferToJson(&stats);
+ });
+
+ pi1_test_event_loop->MakeWatcher(
+ "/pi1/aos", [](const ClientStatistics &stats) {
+ VLOG(1) << "/pi1/aos ClientStatistics " << FlatbufferToJson(&stats);
+ });
+
+ pi1_test_event_loop->MakeWatcher("/pi1/aos", [](const Timestamp &timestamp) {
+ VLOG(1) << "/pi1/aos Timestamp " << FlatbufferToJson(&timestamp);
+ });
+ pi1_test_event_loop->MakeWatcher("/pi2/aos", [this](
+ const Timestamp &timestamp) {
+ VLOG(1) << "/pi2/aos Timestamp " << FlatbufferToJson(&timestamp);
+ EXPECT_EQ(pi1_test_event_loop->context().source_boot_uuid, pi2_boot_uuid_);
+ });
+}
+
+void MessageBridgeParameterizedTest::StartPi1Test() {
+ pi1_test_thread =
+ std::make_unique<ThreadedEventLoopRunner>(pi1_test_event_loop.get());
+}
+
+void MessageBridgeParameterizedTest::StopPi1Test() { pi1_test_thread.reset(); }
+
+void MessageBridgeParameterizedTest::MakePi2Server() {
+ OnPi2();
+ FLAGS_application_name = "pi2_message_bridge_server";
+ pi2_server_event_loop =
+ std::make_unique<aos::ShmEventLoop>(&config.message());
+ pi2_server_event_loop->SetRuntimeRealtimePriority(1);
+ pi2_message_bridge_server = std::make_unique<MessageBridgeServer>(
+ pi2_server_event_loop.get(), config_sha256, SctpAuthMethod::kNoAuth);
+}
+
+void MessageBridgeParameterizedTest::RunPi2Server(
+ chrono::nanoseconds duration) {
+ // Schedule a shutdown callback.
+ aos::TimerHandler *const quit = pi2_server_event_loop->AddTimer(
+ [this]() { pi2_server_event_loop->Exit(); });
+ pi2_server_event_loop->OnRun([this, quit, duration]() {
+ // Stop between timestamps, not exactly on them.
+ quit->Schedule(pi2_server_event_loop->monotonic_now() + duration);
+ });
+
+ pi2_server_event_loop->Run();
+}
+
+void MessageBridgeParameterizedTest::StartPi2Server() {
+ pi2_server_thread =
+ std::make_unique<ThreadedEventLoopRunner>(pi2_server_event_loop.get());
+}
+
+void MessageBridgeParameterizedTest::StopPi2Server() {
+ pi2_server_thread.reset();
+ pi2_message_bridge_server.reset();
+ pi2_server_event_loop.reset();
+}
+
+void MessageBridgeParameterizedTest::MakePi2Client() {
+ OnPi2();
+ FLAGS_application_name = "pi2_message_bridge_client";
+ pi2_client_event_loop =
+ std::make_unique<aos::ShmEventLoop>(&config.message());
+ pi2_client_event_loop->SetRuntimeRealtimePriority(1);
+ pi2_message_bridge_client = std::make_unique<MessageBridgeClient>(
+ pi2_client_event_loop.get(), config_sha256, SctpAuthMethod::kNoAuth);
+}
+
+void MessageBridgeParameterizedTest::RunPi2Client(
+ chrono::nanoseconds duration) {
+ // Schedule a timer that exits the event loop `duration` after it starts.
+ aos::TimerHandler *const quit = pi2_client_event_loop->AddTimer(
+ [this]() { pi2_client_event_loop->Exit(); });
+ pi2_client_event_loop->OnRun([this, quit, duration]() {
+ // Stop between timestamps, not exactly on them.
+ quit->Schedule(pi2_client_event_loop->monotonic_now() + duration);
+ });
+
+ // And go!
+ pi2_client_event_loop->Run();
+}
+
+void MessageBridgeParameterizedTest::StartPi2Client() {
+ pi2_client_thread =
+ std::make_unique<ThreadedEventLoopRunner>(pi2_client_event_loop.get());
+}
+
+void MessageBridgeParameterizedTest::StopPi2Client() {
+ pi2_client_thread.reset();
+ pi2_message_bridge_client.reset();
+ pi2_client_event_loop.reset();
+}
+
+void MessageBridgeParameterizedTest::MakePi2Test() {
+ OnPi2();
+ FLAGS_application_name = "test2";
+ pi2_test_event_loop = std::make_unique<aos::ShmEventLoop>(&config.message());
+
+ pi2_test_event_loop->MakeWatcher(
+ "/pi2/aos", [](const ServerStatistics &stats) {
+ VLOG(1) << "/pi2/aos ServerStatistics " << FlatbufferToJson(&stats);
+ });
+
+ pi2_test_event_loop->MakeWatcher(
+ "/pi2/aos", [](const ClientStatistics &stats) {
+ VLOG(1) << "/pi2/aos ClientStatistics " << FlatbufferToJson(&stats);
+ });
+
+ pi2_test_event_loop->MakeWatcher("/pi1/aos", [this](
+ const Timestamp &timestamp) {
+ VLOG(1) << "/pi1/aos Timestamp " << FlatbufferToJson(&timestamp);
+ EXPECT_EQ(pi2_test_event_loop->context().source_boot_uuid, pi1_boot_uuid_);
+ });
+ pi2_test_event_loop->MakeWatcher("/pi2/aos", [](const Timestamp &timestamp) {
+ VLOG(1) << "/pi2/aos Timestamp " << FlatbufferToJson(&timestamp);
+ });
+}
+
+void MessageBridgeParameterizedTest::StartPi2Test() {
+ pi2_test_thread =
+ std::make_unique<ThreadedEventLoopRunner>(pi2_test_event_loop.get());
+}
+
+void MessageBridgeParameterizedTest::StopPi2Test() { pi2_test_thread.reset(); }
+} // namespace message_bridge::testing
+} // namespace aos
diff --git a/aos/network/message_bridge_test_lib.h b/aos/network/message_bridge_test_lib.h
new file mode 100644
index 0000000..047cff5
--- /dev/null
+++ b/aos/network/message_bridge_test_lib.h
@@ -0,0 +1,137 @@
+#ifndef AOS_NETWORK_MESSAGE_BRIDGE_TEST_LIB_H_
+#define AOS_NETWORK_MESSAGE_BRIDGE_TEST_LIB_H_
+#include <chrono>
+#include <thread>
+
+#include "absl/strings/str_cat.h"
+#include "gtest/gtest.h"
+
+#include "aos/events/ping_generated.h"
+#include "aos/events/pong_generated.h"
+#include "aos/ipc_lib/event.h"
+#include "aos/network/message_bridge_client_lib.h"
+#include "aos/network/message_bridge_protocol.h"
+#include "aos/network/message_bridge_server_lib.h"
+#include "aos/network/team_number.h"
+#include "aos/sha256.h"
+#include "aos/testing/path.h"
+#include "aos/util/file.h"
+namespace aos::message_bridge::testing {
+
+namespace chrono = std::chrono;
+
+// Class to manage starting and stopping a thread with an event loop in it. The
+// thread is guaranteed to be running before the constructor exits.
+class ThreadedEventLoopRunner {
+ public:
+ ThreadedEventLoopRunner(aos::ShmEventLoop *event_loop);
+
+ ~ThreadedEventLoopRunner();
+
+ void Exit();
+
+ private:
+ aos::Event event_;
+ aos::ShmEventLoop *event_loop_;
+ std::thread my_thread_;
+};
+
+// Parameters to run all the tests with.
+struct Param {
+ // The config file to use.
+ std::string config;
+ // If true, the RemoteMessage channel should be shared between all the remote
+ // channels. If false, there will be 1 RemoteMessage channel per remote
+ // channel.
+ bool shared;
+};
+
+class MessageBridgeParameterizedTest
+ : public ::testing::TestWithParam<struct Param> {
+ protected:
+ MessageBridgeParameterizedTest();
+
+ bool shared() const;
+
+ // OnPi* sets the global state necessary to pretend that a ShmEventLoop is on
+ // the requisite system.
+ void OnPi1();
+
+ void OnPi2();
+
+ void MakePi1Server(std::string server_config_sha256 = "");
+
+ void RunPi1Server(chrono::nanoseconds duration);
+
+ void StartPi1Server();
+
+ void StopPi1Server();
+
+ void MakePi1Client();
+
+ void StartPi1Client();
+
+ void StopPi1Client();
+
+ void MakePi1Test();
+
+ void StartPi1Test();
+
+ void StopPi1Test();
+
+ void MakePi2Server();
+
+ void RunPi2Server(chrono::nanoseconds duration);
+
+ void StartPi2Server();
+
+ void StopPi2Server();
+
+ void MakePi2Client();
+
+ void RunPi2Client(chrono::nanoseconds duration);
+
+ void StartPi2Client();
+
+ void StopPi2Client();
+
+ void MakePi2Test();
+
+ void StartPi2Test();
+
+ void StopPi2Test();
+
+ gflags::FlagSaver flag_saver_;
+
+ aos::FlatbufferDetachedBuffer<aos::Configuration> config;
+ std::string config_sha256;
+
+ const UUID pi1_boot_uuid_;
+ const UUID pi2_boot_uuid_;
+
+ std::unique_ptr<aos::ShmEventLoop> pi1_server_event_loop;
+ std::unique_ptr<MessageBridgeServer> pi1_message_bridge_server;
+ std::unique_ptr<ThreadedEventLoopRunner> pi1_server_thread;
+
+ std::unique_ptr<aos::ShmEventLoop> pi1_client_event_loop;
+ std::unique_ptr<MessageBridgeClient> pi1_message_bridge_client;
+ std::unique_ptr<ThreadedEventLoopRunner> pi1_client_thread;
+
+ std::unique_ptr<aos::ShmEventLoop> pi1_test_event_loop;
+ std::unique_ptr<ThreadedEventLoopRunner> pi1_test_thread;
+
+ std::unique_ptr<aos::ShmEventLoop> pi2_server_event_loop;
+ std::unique_ptr<MessageBridgeServer> pi2_message_bridge_server;
+ std::unique_ptr<ThreadedEventLoopRunner> pi2_server_thread;
+
+ std::unique_ptr<aos::ShmEventLoop> pi2_client_event_loop;
+ std::unique_ptr<MessageBridgeClient> pi2_message_bridge_client;
+ std::unique_ptr<ThreadedEventLoopRunner> pi2_client_thread;
+
+ std::unique_ptr<aos::ShmEventLoop> pi2_test_event_loop;
+ std::unique_ptr<ThreadedEventLoopRunner> pi2_test_thread;
+};
+
+} // namespace aos::message_bridge::testing
+
+#endif // AOS_NETWORK_MESSAGE_BRIDGE_TEST_LIB_H_
diff --git a/aos/network/sctp_client.cc b/aos/network/sctp_client.cc
index b7b7a93..fa1828d 100644
--- a/aos/network/sctp_client.cc
+++ b/aos/network/sctp_client.cc
@@ -1,8 +1,8 @@
#include "aos/network/sctp_client.h"
#include <arpa/inet.h>
+#include <linux/sctp.h>
#include <net/if.h>
-#include <netinet/sctp.h>
#include <sys/socket.h>
#include <cstdlib>
@@ -22,8 +22,9 @@
namespace message_bridge {
SctpClient::SctpClient(std::string_view remote_host, int remote_port,
- int streams, std::string_view local_host,
- int local_port) {
+ int streams, std::string_view local_host, int local_port,
+ SctpAuthMethod requested_authentication)
+ : sctp_(requested_authentication) {
bool use_ipv6 = Ipv6Enabled();
sockaddr_local_ = ResolveSocket(local_host, local_port, use_ipv6);
sockaddr_remote_ = ResolveSocket(remote_host, remote_port, use_ipv6);
diff --git a/aos/network/sctp_client.h b/aos/network/sctp_client.h
index 5affecc..06f6b15 100644
--- a/aos/network/sctp_client.h
+++ b/aos/network/sctp_client.h
@@ -5,6 +5,7 @@
#include <cstdlib>
#include <string_view>
+#include "absl/types/span.h"
#include "glog/logging.h"
#include "aos/network/sctp_lib.h"
@@ -17,7 +18,8 @@
class SctpClient {
public:
SctpClient(std::string_view remote_host, int remote_port, int streams,
- std::string_view local_host = "0.0.0.0", int local_port = 9971);
+ std::string_view local_host = "0.0.0.0", int local_port = 9971,
+ SctpAuthMethod requested_authentication = SctpAuthMethod::kNoAuth);
~SctpClient() {}
@@ -60,6 +62,10 @@
sctp_.FreeMessage(std::move(message));
}
+ void SetAuthKey(absl::Span<const uint8_t> auth_key) {
+ sctp_.SetAuthKey(auth_key);
+ }
+
private:
struct sockaddr_storage sockaddr_remote_;
struct sockaddr_storage sockaddr_local_;
diff --git a/aos/network/sctp_config.fbs b/aos/network/sctp_config.fbs
new file mode 100644
index 0000000..4c1819b
--- /dev/null
+++ b/aos/network/sctp_config.fbs
@@ -0,0 +1,10 @@
+namespace aos.message_bridge;
+
+// SCTP Configuration options for message bridge.
+table SctpConfig {
+ // The authentication key to use.
+ key:[ubyte] (id: 0);
+}
+
+root_type SctpConfig;
+
diff --git a/aos/network/sctp_config_request.fbs b/aos/network/sctp_config_request.fbs
new file mode 100644
index 0000000..196589b
--- /dev/null
+++ b/aos/network/sctp_config_request.fbs
@@ -0,0 +1,10 @@
+namespace aos.message_bridge;
+
+// SCTP configuration requests for message bridge.
+table SctpConfigRequest {
+ // When set, the authentication key is being requested.
+ request_key:bool (id: 0);
+}
+
+root_type SctpConfigRequest;
+
diff --git a/aos/network/sctp_lib.cc b/aos/network/sctp_lib.cc
index 2658271..c3c6e23 100644
--- a/aos/network/sctp_lib.cc
+++ b/aos/network/sctp_lib.cc
@@ -1,15 +1,19 @@
#include "aos/network/sctp_lib.h"
#include <arpa/inet.h>
+#include <linux/sctp.h>
#include <net/if.h>
#include <netdb.h>
-#include <netinet/sctp.h>
+#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <algorithm>
+#include <cerrno>
+#include <fstream>
#include <string_view>
+#include <vector>
#include "aos/util/file.h"
@@ -30,6 +34,33 @@
struct sctp_sndrcvinfo sndrcvinfo;
} _sctp_cmsg_data_t;
+#if HAS_SCTP_AUTH
+// Returns true if SCTP authentication is available and enabled.
+bool SctpAuthIsEnabled() {
+ struct stat current_stat;
+ if (stat("/proc/sys/net/sctp/auth_enable", &current_stat) != -1) {
+ int value = std::stoi(
+ util::ReadFileToStringOrDie("/proc/sys/net/sctp/auth_enable"));
+ CHECK(value == 0 || value == 1)
+ << "Unknown auth enable sysctl value: " << value;
+ return value == 1;
+ } else {
+ LOG(WARNING) << "/proc/sys/net/sctp/auth_enable doesn't exist.";
+ return false;
+ }
+}
+
+std::vector<uint8_t> GenerateSecureRandomSequence(size_t count) {
+ std::ifstream rng("/dev/random", std::ios::in | std::ios::binary);
+ CHECK(rng) << "Unable to open /dev/random";
+ std::vector<uint8_t> out(count, 0);
+ rng.read(reinterpret_cast<char *>(out.data()), count);
+ CHECK(rng) << "Couldn't read from random device";
+ rng.close();
+ return out;
+}
+#endif
+
} // namespace
bool Ipv6Enabled() {
@@ -211,7 +242,7 @@
status.sstat_assoc_id = assoc_id;
socklen_t size = sizeof(status);
- const int result = getsockopt(fd, SOL_SCTP, SCTP_STATUS,
+ const int result = getsockopt(fd, IPPROTO_SCTP, SCTP_STATUS,
reinterpret_cast<void *>(&status), &size);
if (result == -1 && errno == EINVAL) {
LOG(INFO) << "sctp_status) not associated";
@@ -265,10 +296,37 @@
subscribe.sctp_association_event = 1;
subscribe.sctp_stream_change_event = 1;
subscribe.sctp_partial_delivery_event = 1;
- PCHECK(setsockopt(fd(), SOL_SCTP, SCTP_EVENTS, (char *)&subscribe,
+ PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_EVENTS, (char *)&subscribe,
sizeof(subscribe)) == 0);
}
+#if HAS_SCTP_AUTH
+ if (sctp_authentication_) {
+ CHECK(SctpAuthIsEnabled())
+ << "SCTP Authentication key requested, but authentication isn't "
+ "enabled... Use `sysctl -w net.sctp.auth_enable=1` to enable";
+
+ // Unfortunately there's no way to delete the null key if we don't have
+ // another key active so this is the only way to prevent unauthenticated
+ // traffic until the real shared key is established.
+ SetAuthKey(GenerateSecureRandomSequence(16));
+
+ // Disallow the null key.
+ struct sctp_authkeyid authkeyid;
+ authkeyid.scact_keynumber = 0;
+ authkeyid.scact_assoc_id = SCTP_ALL_ASSOC;
+ PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_DELETE_KEY, &authkeyid,
+ sizeof(authkeyid)) == 0);
+
+ // Set up authentication for data chunks.
+ struct sctp_authchunk authchunk;
+ authchunk.sauth_chunk = 0;
+
+ PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_CHUNK, &authchunk,
+ sizeof(authchunk)) == 0);
+ }
+#endif
+
DoSetMaxSize();
}
@@ -277,6 +335,8 @@
std::optional<struct sockaddr_storage> sockaddr_remote,
sctp_assoc_t snd_assoc_id) {
CHECK(fd_ != -1);
+ LOG_IF(FATAL, sctp_authentication_ && current_key_.empty())
+ << "Expected SCTP authentication but no key active";
struct iovec iov;
iov.iov_base = const_cast<char *>(data.data());
iov.iov_len = data.size();
@@ -371,6 +431,8 @@
// fragmented. If we do end up with a fragment, then we copy the data out of it.
aos::unique_c_ptr<Message> SctpReadWrite::ReadMessage() {
CHECK(fd_ != -1);
+ LOG_IF(FATAL, sctp_authentication_ && current_key_.empty())
+ << "Expected SCTP authentication but no key active";
while (true) {
aos::unique_c_ptr<Message> result = AcquireMessage();
@@ -636,6 +698,55 @@
return false;
}
+void SctpReadWrite::SetAuthKey(absl::Span<const uint8_t> auth_key) {
+ CHECK(fd_ != -1);
+ if (auth_key.empty()) {
+ return;
+ }
+ // We are already using the key, nothing to do.
+ if (auth_key == current_key_) {
+ return;
+ }
+#if !(HAS_SCTP_AUTH)
+ LOG(FATAL) << "SCTP Authentication key requested, but authentication isn't "
+ "available... You may need a newer kernel";
+#else
+ LOG_IF(FATAL, !SctpAuthIsEnabled())
+ << "SCTP Authentication key requested, but authentication isn't "
+ "enabled... Use `sysctl -w net.sctp.auth_enable=1` to enable";
+ // Set up the key with id `1` (malloc'd, so it must be released via free()).
+ aos::unique_c_ptr<sctp_authkey> authkey(static_cast<sctp_authkey *>(
+ malloc(sizeof(sctp_authkey) + auth_key.size())));
+ CHECK(authkey != nullptr) << "Failed to allocate the SCTP auth key";
+ authkey->sca_keynumber = 1;
+ authkey->sca_keylength = auth_key.size();
+ authkey->sca_assoc_id = SCTP_ALL_ASSOC;
+ memcpy(&authkey->sca_key, auth_key.data(), auth_key.size());
+
+ if (setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_KEY, authkey.get(),
+ sizeof(sctp_authkey) + auth_key.size()) != 0) {
+ if (errno == EACCES) {
+ // TODO(adam.snaider): Figure out why this fails when expected nodes are
+ // not connected.
+ PLOG_EVERY_N(ERROR, 100) << "Setting authentication key failed";
+ return;
+ } else {
+ PLOG(FATAL) << "Setting authentication key failed";
+ }
+ }
+
+ // Set key `1` as active.
+ struct sctp_authkeyid authkeyid;
+ authkeyid.scact_keynumber = 1;
+ authkeyid.scact_assoc_id = SCTP_ALL_ASSOC;
+ if (setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_ACTIVE_KEY, &authkeyid,
+ sizeof(authkeyid)) != 0) {
+ PLOG(FATAL) << "Setting key id `1` as active failed";
+ }
+ current_key_.assign(auth_key.begin(), auth_key.end());
+#endif
+}
+
void Message::LogRcvInfo() const {
LOG(INFO) << "\tSNDRCV (stream=" << header.rcvinfo.rcv_sid
<< " ssn=" << header.rcvinfo.rcv_ssn
diff --git a/aos/network/sctp_lib.h b/aos/network/sctp_lib.h
index 14509e4..0d021a9 100644
--- a/aos/network/sctp_lib.h
+++ b/aos/network/sctp_lib.h
@@ -2,7 +2,8 @@
#define AOS_NETWORK_SCTP_LIB_H_
#include <arpa/inet.h>
-#include <netinet/sctp.h>
+#include <linux/sctp.h>
+#include <linux/version.h>
#include <memory>
#include <optional>
@@ -10,14 +11,19 @@
#include <string_view>
#include <vector>
+#include "absl/types/span.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "aos/unique_malloc_ptr.h"
+#define HAS_SCTP_AUTH (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 4, 0))
+
namespace aos {
namespace message_bridge {
+constexpr bool HasSctpAuth() { return HAS_SCTP_AUTH; }
+
// Check if ipv6 is enabled.
// If we don't try IPv6, and omit AI_ADDRCONFIG when resolving addresses, the
// library will happily resolve nodes to IPv6 IPs that can't be used. If we add
@@ -89,10 +95,31 @@
// Gets and logs the contents of the sctp_status message.
void LogSctpStatus(int fd, sctp_assoc_t assoc_id);
+// Authentication method used for the SCTP socket.
+enum class SctpAuthMethod {
+ // Use unauthenticated sockets.
+ kNoAuth,
+ // Use RFC4895 authentication for SCTP.
+ kAuth,
+};
+
// Manages reading and writing SCTP messages.
class SctpReadWrite {
public:
- SctpReadWrite() = default;
+ // When `requested_authentication` is kAuth, it will use SCTP authentication
+ // if it's provided by the kernel. Note that this will ignore the value of
+ // `requested_authentication` if the kernel is too old and will fall back to
+ // an unauthenticated channel.
+ SctpReadWrite(
+ SctpAuthMethod requested_authentication = SctpAuthMethod::kNoAuth)
+ : sctp_authentication_(HasSctpAuth() ? requested_authentication ==
+ SctpAuthMethod::kAuth
+ : false) {
+ LOG_IF(WARNING,
+ requested_authentication == SctpAuthMethod::kAuth && !HasSctpAuth())
+ << "SCTP authentication requested but not provided by the kernel... "
+ "You may need a newer kernel";
+ }
~SctpReadWrite() { CloseSocket(); }
// Opens a new socket.
@@ -139,6 +166,9 @@
// Allocates messages for the pool. SetMaxSize must be set first.
void SetPoolSize(size_t pool_size);
+ // Set the active authentication key to `auth_key`.
+ void SetAuthKey(absl::Span<const uint8_t> auth_key);
+
private:
aos::unique_c_ptr<Message> AcquireMessage();
@@ -161,6 +191,10 @@
bool use_pool_ = false;
std::vector<aos::unique_c_ptr<Message>> free_messages_;
+
+ // Use SCTP authentication (RFC4895).
+ bool sctp_authentication_;
+ std::vector<uint8_t> current_key_;
};
// Returns the max network buffer available for reading for a socket.
diff --git a/aos/network/sctp_perf.cc b/aos/network/sctp_perf.cc
index cce4bed..5201f47 100644
--- a/aos/network/sctp_perf.cc
+++ b/aos/network/sctp_perf.cc
@@ -6,6 +6,7 @@
#include "aos/events/shm_event_loop.h"
#include "aos/init.h"
#include "aos/network/sctp_client.h"
+#include "aos/network/sctp_lib.h"
#include "aos/network/sctp_server.h"
DEFINE_string(config, "aos_config.json", "Path to the config.");
@@ -21,16 +22,40 @@
DEFINE_uint32(skip_first_n, 10,
"Skip the first 'n' messages when computing statistics.");
+DEFINE_string(sctp_auth_key_file, "",
+ "When set, use the provided key for SCTP authentication as "
+ "defined in RFC 4895");
+
DECLARE_bool(die_on_malloc);
namespace aos::message_bridge::perf {
+namespace {
+
+using util::ReadFileToVecOrDie;
+
+SctpAuthMethod SctpAuthMethod() {
+ return FLAGS_sctp_auth_key_file.empty() ? SctpAuthMethod::kNoAuth
+ : SctpAuthMethod::kAuth;
+}
+
+std::vector<uint8_t> GetSctpAuthKey() {
+ if (SctpAuthMethod() == SctpAuthMethod::kNoAuth) {
+ return {};
+ }
+ return ReadFileToVecOrDie(FLAGS_sctp_auth_key_file);
+}
+
+} // namespace
+
namespace chrono = std::chrono;
class Server {
public:
Server(aos::ShmEventLoop *event_loop)
- : event_loop_(event_loop), server_(2, "0.0.0.0", FLAGS_port) {
+ : event_loop_(event_loop),
+ server_(2, "0.0.0.0", FLAGS_port, SctpAuthMethod()) {
+ server_.SetAuthKey(GetSctpAuthKey());
event_loop_->epoll()->OnReadable(server_.fd(),
[this]() { MessageReceived(); });
server_.SetMaxReadSize(FLAGS_rx_size + 100);
@@ -109,7 +134,10 @@
class Client {
public:
Client(aos::ShmEventLoop *event_loop)
- : event_loop_(event_loop), client_(FLAGS_host, FLAGS_port, 2) {
+ : event_loop_(event_loop),
+ client_(FLAGS_host, FLAGS_port, 2, "0.0.0.0", FLAGS_port,
+ SctpAuthMethod()) {
+ client_.SetAuthKey(GetSctpAuthKey());
client_.SetMaxReadSize(FLAGS_rx_size + 100);
client_.SetMaxWriteSize(FLAGS_rx_size + 100);
@@ -196,8 +224,8 @@
double throughput = FLAGS_payload_size * 2.0 / elapsed_secs;
double avg_throughput = FLAGS_payload_size * 2.0 / avg_latency_;
printf(
- "Round trip: %.2fms | %.2f KB/s | Avg RTL: %.2fms | %.2f KB/s | Count: "
- "%d\n",
+ "Round trip: %.2fms | %.2f KB/s | Avg RTL: %.2fms | %.2f KB/s | "
+ "Count: %d\n",
elapsed_secs * 1000, throughput / 1024, avg_latency_ * 1000,
avg_throughput / 1024, count_);
}
diff --git a/aos/network/sctp_server.cc b/aos/network/sctp_server.cc
index cfc8512..f90b21b 100644
--- a/aos/network/sctp_server.cc
+++ b/aos/network/sctp_server.cc
@@ -1,10 +1,10 @@
#include "aos/network/sctp_server.h"
#include <arpa/inet.h>
+#include <linux/sctp.h>
#include <net/if.h>
#include <netdb.h>
#include <netinet/in.h>
-#include <netinet/sctp.h>
#include <sys/socket.h>
#include <cstdio>
@@ -21,8 +21,9 @@
namespace aos {
namespace message_bridge {
-SctpServer::SctpServer(int streams, std::string_view local_host,
- int local_port) {
+SctpServer::SctpServer(int streams, std::string_view local_host, int local_port,
+ SctpAuthMethod requested_authentication)
+ : sctp_(requested_authentication) {
bool use_ipv6 = Ipv6Enabled();
sockaddr_local_ = ResolveSocket(local_host, local_port, use_ipv6);
while (true) {
diff --git a/aos/network/sctp_server.h b/aos/network/sctp_server.h
index 6d0b44e..e641fd8 100644
--- a/aos/network/sctp_server.h
+++ b/aos/network/sctp_server.h
@@ -2,10 +2,10 @@
#define AOS_NETWORK_SCTP_SERVER_H_
#include <arpa/inet.h>
+#include <linux/sctp.h>
#include <net/if.h>
#include <netdb.h>
#include <netinet/in.h>
-#include <netinet/sctp.h>
#include <sys/socket.h>
#include <cstdio>
@@ -13,6 +13,7 @@
#include <cstring>
#include <memory>
+#include "absl/types/span.h"
#include "glog/logging.h"
#include "aos/network/sctp_lib.h"
@@ -24,7 +25,8 @@
class SctpServer {
public:
SctpServer(int streams, std::string_view local_host = "0.0.0.0",
- int local_port = 9971);
+ int local_port = 9971,
+ SctpAuthMethod requested_authentication = SctpAuthMethod::kNoAuth);
~SctpServer() {}
@@ -63,6 +65,10 @@
void SetPoolSize(size_t pool_size) { sctp_.SetPoolSize(pool_size); }
+ void SetAuthKey(absl::Span<const uint8_t> auth_key) {
+ sctp_.SetAuthKey(auth_key);
+ }
+
private:
struct sockaddr_storage sockaddr_local_;
SctpReadWrite sctp_;
diff --git a/aos/network/sctp_test.cc b/aos/network/sctp_test.cc
new file mode 100644
index 0000000..8e332e4
--- /dev/null
+++ b/aos/network/sctp_test.cc
@@ -0,0 +1,367 @@
+#include <unistd.h>
+
+#include <chrono>
+#include <functional>
+
+#include "gflags/gflags.h"
+#include "gmock/gmock-matchers.h"
+#include "gtest/gtest.h"
+
+#include "aos/events/epoll.h"
+#include "aos/network/sctp_client.h"
+#include "aos/network/sctp_lib.h"
+#include "aos/network/sctp_server.h"
+
+DECLARE_bool(disable_ipv6);
+
+namespace aos::message_bridge::testing {
+
+using ::aos::internal::EPoll;
+using ::aos::internal::TimerFd;
+using ::testing::ElementsAre;
+
+using namespace ::std::chrono_literals;
+
+constexpr int kPort = 19423;
+constexpr int kStreams = 1;
+
+namespace {
+// Turns on kernel SCTP authentication (net.sctp.auth_enable=1) via sysctl when
+// the build defines HAS_SCTP_AUTH to a non-zero value; otherwise this is a
+// no-op. CHECK-fails if the shell command exits non-zero.
+void EnableSctpAuthIfAvailable() {
+#if HAS_SCTP_AUTH
+ // Two candidate sysctl locations are tried; the shell `||` only runs the
+ // second one when the first fails.
+ CHECK(system("/usr/sbin/sysctl net.sctp.auth_enable=1 || /sbin/sysctl "
+ "net.sctp.auth_enable=1") == 0)
+ << "Couldn't enable sctp authentication.";
+#endif
+}
+} // namespace
+
+// An asynchronous SCTP handler. It takes an SCTP receiver (a.k.a SctpServer or
+// SctpClient), and an `sctp_notification` handler and a `message` handler. It
+// asynchronously routes incoming messages to the appropriate handler.
+//
+// Both `epoll` and `receiver` are held by reference, so they must outlive this
+// object.
+template <typename T>
+class SctpReceiver {
+ public:
+ // Registers `receiver`'s fd with `epoll` so that Read() runs whenever the fd
+ // becomes readable. `on_notify` is called for notification messages and
+ // `on_message` for data messages.
+ SctpReceiver(
+ EPoll &epoll, T &receiver,
+ std::function<void(T &, const union sctp_notification *)> on_notify,
+ std::function<void(T &, std::vector<uint8_t>)> on_message)
+ : epoll_(epoll),
+ receiver_(receiver),
+ on_notify_(std::move(on_notify)),
+ on_message_(std::move(on_message)) {
+ epoll_.OnReadable(receiver_.fd(), [this]() { Read(); });
+ }
+
+ // Unregisters the receiver's fd from the epoll instance.
+ ~SctpReceiver() { epoll_.DeleteFd(receiver_.fd()); }
+
+ private:
+ // Handles an incoming message by routing it to the appropriate handler.
+ void Read() {
+ aos::unique_c_ptr<Message> message = receiver_.Read();
+ // Nothing to dispatch if the receiver did not produce a message.
+ if (!message) {
+ return;
+ }
+
+ switch (message->message_type) {
+ case Message::kNotification: {
+ // The notification pointer refers into `message`, which is handed to
+ // FreeMessage() below, so the handler must not keep it.
+ const union sctp_notification *notification =
+ reinterpret_cast<const union sctp_notification *>(message->data());
+ on_notify_(receiver_, notification);
+ break;
+ }
+ case Message::kMessage:
+ // The payload is copied into a vector so the handler owns its bytes.
+ on_message_(receiver_, std::vector(message->data(),
+ message->data() + message->size));
+ break;
+ case Message::kOverflow:
+ LOG(FATAL) << "Overflow";
+ }
+ // Hand the message back to the receiver once the handler is done with it.
+ receiver_.FreeMessage(std::move(message));
+ }
+
+ EPoll &epoll_;
+ T &receiver_;
+ std::function<void(T &, const union sctp_notification *)> on_notify_;
+ std::function<void(T &, std::vector<uint8_t>)> on_message_;
+};
+
+// Base SctpTest class.
+//
+// The class provides a few virtual methods that should be overridden to define
+// the behavior of the test.
+//
+// The fixture owns an SctpServer and an SctpClient on kPort, hooks both up to
+// an EPoll, and runs that event loop from SetUp() until Quit() is called.
+// Anything a derived constructor sends is therefore issued before the loop
+// starts running.
+class SctpTest : public ::testing::Test {
+ public:
+ // `server_key` and `client_key` are passed to SetAuthKey() on the server and
+ // client respectively. `requested_authentication` is forwarded to both
+ // endpoints. `timeout` is how long after construction TimeOut() fires.
+ SctpTest(std::vector<uint8_t> server_key = {},
+ std::vector<uint8_t> client_key = {},
+ SctpAuthMethod requested_authentication = SctpAuthMethod::kNoAuth,
+ std::chrono::milliseconds timeout = 1000ms)
+ : server_(kStreams, "", kPort, requested_authentication),
+ client_("localhost", kPort, kStreams, "", 0, requested_authentication),
+ client_receiver_(
+ epoll_, client_,
+ [this](SctpClient &client,
+ const union sctp_notification *notification) {
+ HandleNotification(client, notification);
+ },
+ [this](SctpClient &client, std::vector<uint8_t> message) {
+ HandleMessage(client, std::move(message));
+ }),
+ server_receiver_(
+ epoll_, server_,
+ [this](SctpServer &server,
+ const union sctp_notification *notification) {
+ HandleNotification(server, notification);
+ },
+ [this](SctpServer &server, std::vector<uint8_t> message) {
+ HandleMessage(server, std::move(message));
+ }) {
+ server_.SetAuthKey(server_key);
+ client_.SetAuthKey(client_key);
+ // Arm the timeout relative to now. The second argument is zero --
+ // presumably the repeat interval, i.e. a one-shot timer (TODO confirm
+ // against TimerFd::SetTime).
+ timeout_.SetTime(aos::monotonic_clock::now() + timeout,
+ std::chrono::milliseconds::zero());
+ epoll_.OnReadable(timeout_.fd(), [this]() { TimeOut(); });
+ }
+
+ static void SetUpTestSuite() {
+ EnableSctpAuthIfAvailable();
+ // Buildkite seems to have issues with ipv6 sctp sockets...
+ FLAGS_disable_ipv6 = true;
+ }
+
+ // Runs the event loop; the whole scenario plays out here, so the TEST_F
+ // bodies themselves are empty.
+ void SetUp() override { Run(); }
+
+ protected:
+ // Handles a server notification message.
+ //
+ // The default behaviour is to track the sctp association ID.
+ virtual void HandleNotification(SctpServer &,
+ const union sctp_notification *notification) {
+ if (notification->sn_header.sn_type == SCTP_ASSOC_CHANGE) {
+ assoc_ = notification->sn_assoc_change.sac_assoc_id;
+ }
+ }
+
+ // Handles the client notification message.
+ virtual void HandleNotification(SctpClient &,
+ const union sctp_notification *) {}
+
+ // Handles a server "data" message.
+ virtual void HandleMessage(SctpServer &, std::vector<uint8_t>) {}
+ // Handles a client "data" message.
+ virtual void HandleMessage(SctpClient &, std::vector<uint8_t>) {}
+
+ // Defines the "timeout" behaviour (fail by default).
+ // Quit() is called first because FAIL() returns from the function.
+ virtual void TimeOut() {
+ Quit();
+ FAIL() << "Timer expired";
+ }
+
+ virtual ~SctpTest() {}
+
+ // Quit the test.
+ // Removes the timeout fd from the epoll set and stops the event loop.
+ void Quit() {
+ epoll_.DeleteFd(timeout_.fd());
+ epoll_.Quit();
+ }
+ void Run() { epoll_.Run(); }
+
+ SctpServer server_;
+ SctpClient client_;
+ // Association ID recorded by the default server notification handler; stays 0
+ // until an SCTP_ASSOC_CHANGE notification is seen.
+ sctp_assoc_t assoc_ = 0;
+
+ private:
+ // Declaration order matters: members are initialized in this order, and the
+ // receivers take a reference to epoll_ (and use it) in their constructors.
+ TimerFd timeout_;
+ EPoll epoll_;
+ SctpReceiver<SctpClient> client_receiver_;
+ SctpReceiver<SctpServer> server_receiver_;
+};
+
+// Verifies we can ping the server, and the server replies.
+//
+// Flow: the client sends "ping", the server answers "pong" on the recorded
+// association, and the client stops the event loop when "pong" arrives.
+class SctpPingPongTest : public SctpTest {
+ public:
+ SctpPingPongTest()
+ : SctpTest({}, {}, SctpAuthMethod::kNoAuth, /*timeout=*/2s) {
+ // Start by having the client send "ping".
+ client_.Send(0, "ping", 0);
+ }
+
+ void HandleMessage(SctpServer &server,
+ std::vector<uint8_t> message) override {
+ // Server should receive a ping message.
+ EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
+ got_ping_ = true;
+ // The reply is addressed by association ID, which the base class records
+ // from the server's SCTP_ASSOC_CHANGE notification. On failure ASSERT_NE
+ // returns, so no reply is sent.
+ ASSERT_NE(assoc_, 0);
+ // Reply with "pong".
+ server.Send("pong", assoc_, 0, 0);
+ }
+
+ void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
+ // Client should receive a "pong" message.
+ EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
+ got_pong_ = true;
+ // We are done.
+ Quit();
+ }
+ ~SctpPingPongTest() {
+ // Check that we got the ping/pong messages.
+ // This isn't strictly necessary as otherwise we would time out and fail
+ // anyway.
+ EXPECT_TRUE(got_ping_);
+ EXPECT_TRUE(got_pong_);
+ }
+
+ protected:
+ // Set once the server has received "ping" / the client has received "pong".
+ bool got_ping_ = false;
+ bool got_pong_ = false;
+};
+
+TEST_F(SctpPingPongTest, Test) {}
+
+#if HAS_SCTP_AUTH
+
+// Same as SctpPingPongTest but with authentication keys. Both keys are the
+// same so it should work the same way.
+class SctpAuthTest : public SctpTest {
+ public:
+ SctpAuthTest()
+ : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, SctpAuthMethod::kAuth,
+ /*timeout=*/ 20s) {
+ // Start by having the client send "ping".
+ client_.Send(0, "ping", 0);
+ }
+
+ void HandleMessage(SctpServer &server,
+ std::vector<uint8_t> message) override {
+ // Server should receive a ping message.
+ EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
+ got_ping_ = true;
+ // The reply needs the association ID recorded by the base class.
+ ASSERT_NE(assoc_, 0);
+ // Reply with "pong".
+ server.Send("pong", assoc_, 0, 0);
+ }
+ void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
+ // Client should receive a "pong" message.
+ EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
+ got_pong_ = true;
+ // We are done.
+ Quit();
+ }
+ ~SctpAuthTest() {
+ // Both legs of the exchange must have completed.
+ EXPECT_TRUE(got_ping_);
+ EXPECT_TRUE(got_pong_);
+ }
+
+ protected:
+ // Set once the server has received "ping" / the client has received "pong".
+ bool got_ping_ = false;
+ bool got_pong_ = false;
+};
+
+TEST_F(SctpAuthTest, Test) {}
+
+// Tests that we can dynamically change the SCTP authentication key used.
+class SctpChangingAuthKeysTest : public SctpTest {
+ public:
+ SctpChangingAuthKeysTest()
+ : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
+ SctpAuthMethod::kAuth) {
+ // Replace the key set by the base constructor with a new shared key on
+ // both endpoints, then have the client send "ping".
+ client_.SetAuthKey({5, 4, 3, 2, 1});
+ server_.SetAuthKey({5, 4, 3, 2, 1});
+ client_.Send(0, "ping", 0);
+ }
+
+ void HandleMessage(SctpServer &server,
+ std::vector<uint8_t> message) override {
+ // Server should receive a ping message.
+ EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
+ got_ping_ = true;
+ // The reply needs the association ID recorded by the base class.
+ ASSERT_NE(assoc_, 0);
+ // Reply with "pong".
+ server.Send("pong", assoc_, 0, 0);
+ }
+ void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
+ // Client should receive a "pong" message.
+ EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
+ got_pong_ = true;
+ // We are done.
+ Quit();
+ }
+
+ ~SctpChangingAuthKeysTest() {
+ // Both legs of the exchange must have completed with the new key.
+ EXPECT_TRUE(got_ping_);
+ EXPECT_TRUE(got_pong_);
+ }
+
+ protected:
+ // Set once the server has received "ping" / the client has received "pong".
+ bool got_ping_ = false;
+ bool got_pong_ = false;
+};
+
+TEST_F(SctpChangingAuthKeysTest, Test) {}
+
+// Keys don't match, we should send the `ping` message but the server should
+// never receive it. We then time out, and the overridden TimeOut() ends the
+// test without failing it.
+class SctpMismatchedAuthTest : public SctpTest {
+ public:
+ SctpMismatchedAuthTest()
+ : SctpTest({1, 2, 3, 4, 5, 6}, {5, 6, 7, 8, 9, 10},
+ SctpAuthMethod::kAuth) {
+ // Start by having the client send "ping".
+ client_.Send(0, "ping", 0);
+ }
+
+ // Any data message reaching the server is a test failure.
+ void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
+ // Stop the event loop before failing: FAIL() returns from this function, so
+ // a Quit() placed after it would never run.
+ Quit();
+ FAIL() << "Authentication keys don't match. Message should be discarded";
+ }
+
+ // We expect to time out since we never get the message.
+ void TimeOut() override { Quit(); }
+};
+
+TEST_F(SctpMismatchedAuthTest, Test) {}
+
+// Same as SctpMismatchedAuthTest but the client uses the null key. We should
+// see the same behaviour.
+class SctpOneNullKeyTest : public SctpTest {
+ public:
+ SctpOneNullKeyTest()
+ : SctpTest({1, 2, 3, 4, 5, 6}, {}, SctpAuthMethod::kAuth) {
+ // Start by having the client send "ping".
+ client_.Send(0, "ping", 0);
+ }
+
+ // Any data message reaching the server is a test failure.
+ void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
+ // Stop the event loop before failing: FAIL() returns from this function, so
+ // a Quit() placed after it would never run.
+ Quit();
+ FAIL() << "Authentication keys don't match. Message should be discarded";
+ }
+
+ // We expect to time out since we never get the message.
+ void TimeOut() override { Quit(); }
+};
+
+TEST_F(SctpOneNullKeyTest, Test) {}
+
+// If we want SCTP authentication but we don't set the auth keys, we shouldn't
+// be able to send data.
+class SctpAuthKeysNotSet : public SctpTest {
+ public:
+ SctpAuthKeysNotSet() : SctpTest({}, {}, SctpAuthMethod::kAuth) {
+ // Start by having the client send "ping".
+ client_.Send(0, "ping", 0);
+ }
+
+ // Any data message reaching the server is a test failure.
+ void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
+ // Stop the event loop before failing: FAIL() returns from this function, so
+ // a Quit() placed after it would never run.
+ Quit();
+ FAIL() << "Haven't setup authentication keys. Should not get message.";
+ }
+
+ // We expect to time out since we never get the message.
+ void TimeOut() override { Quit(); }
+};
+
+TEST_F(SctpAuthKeysNotSet, Test) {}
+
+#endif // HAS_SCTP_AUTH
+
+} // namespace aos::message_bridge::testing
diff --git a/aos/realtime.cc b/aos/realtime.cc
index 14849a4..8b34ada 100644
--- a/aos/realtime.cc
+++ b/aos/realtime.cc
@@ -19,6 +19,7 @@
#include "glog/raw_logging.h"
#include "aos/thread_local.h"
+#include "aos/uuid.h"
DEFINE_bool(
die_on_malloc, true,
@@ -184,6 +185,10 @@
}
void SetCurrentThreadRealtimePriority(int priority) {
+ // Ensure that we won't get expensive reads of /dev/random when the realtime
+ // scheduler is running.
+ UUID::Random();
+
if (FLAGS_skip_realtime_scheduler) {
LOG(WARNING) << "Ignoring request to switch to the RT scheduler due to "
"--skip_realtime_scheduler.";
diff --git a/aos/starter/starter_test.cc b/aos/starter/starter_test.cc
index fcbf287..3070f8e 100644
--- a/aos/starter/starter_test.cc
+++ b/aos/starter/starter_test.cc
@@ -20,6 +20,23 @@
namespace aos {
namespace starter {
+// Runs a Starter on a background thread for the lifetime of this object.
+//
+// The constructor blocks until the starter's event loop has invoked its OnRun
+// callback. The destructor joins the thread, so it blocks until
+// `starter->Run()` returns. `starter` is not owned and must outlive this
+// object.
+class ThreadedStarterRunner {
+ public:
+ // explicit: a Starter* should never silently convert into a running thread.
+ explicit ThreadedStarterRunner(Starter *starter)
+ : my_thread_([this, starter]() {
+ starter->event_loop()->OnRun([this]() { event_.Set(); });
+ starter->Run();
+ }) {
+ event_.Wait();
+ }
+
+ ~ThreadedStarterRunner() { my_thread_.join(); }
+
+ private:
+ // event_ must stay declared before my_thread_: members are initialized in
+ // declaration order and the thread body uses event_ as soon as it starts.
+ aos::Event event_;
+ std::thread my_thread_;
+};
+
class StarterdTest : public ::testing::Test {
public:
StarterdTest() {
@@ -161,26 +178,17 @@
SetupStarterCleanup(&starter);
- Event starter_started;
- std::thread starterd_thread([&starter, &starter_started] {
- starter.event_loop()->OnRun(
- [&starter_started]() { starter_started.Set(); });
- starter.Run();
- });
- starter_started.Wait();
+ ThreadedStarterRunner starterd_thread(&starter);
- Event client_started;
- std::thread client_thread([&client_loop, &client_started] {
- client_loop.OnRun([&client_started]() { client_started.Set(); });
- client_loop.Run();
- });
- client_started.Wait();
+ aos::Event event;
+ client_loop.OnRun([&event]() { event.Set(); });
+ std::thread client_thread([&client_loop] { client_loop.Run(); });
+ event.Wait();
watcher_loop.Run();
test_done_ = true;
client_thread.join();
ASSERT_TRUE(success);
- starterd_thread.join();
}
INSTANTIATE_TEST_SUITE_P(
@@ -270,18 +278,10 @@
SetupStarterCleanup(&starter);
- Event starter_started;
- std::thread starterd_thread([&starter, &starter_started] {
- starter.event_loop()->OnRun(
- [&starter_started]() { starter_started.Set(); });
- starter.Run();
- });
- starter_started.Wait();
+ ThreadedStarterRunner starterd_thread(&starter);
watcher_loop.Run();
test_done_ = true;
-
- starterd_thread.join();
}
TEST_F(StarterdTest, Autostart) {
@@ -365,18 +365,10 @@
SetupStarterCleanup(&starter);
- Event starter_started;
- std::thread starterd_thread([&starter, &starter_started] {
- starter.event_loop()->OnRun(
- [&starter_started]() { starter_started.Set(); });
- starter.Run();
- });
- starter_started.Wait();
+ ThreadedStarterRunner starterd_thread(&starter);
watcher_loop.Run();
test_done_ = true;
-
- starterd_thread.join();
}
// Tests that starterd respects autorestart.
@@ -462,18 +454,10 @@
SetupStarterCleanup(&starter);
- Event starter_started;
- std::thread starterd_thread([&starter, &starter_started] {
- starter.event_loop()->OnRun(
- [&starter_started]() { starter_started.Set(); });
- starter.Run();
- });
- starter_started.Wait();
+ ThreadedStarterRunner starterd_thread(&starter);
watcher_loop.Run();
test_done_ = true;
-
- starterd_thread.join();
}
TEST_F(StarterdTest, StarterChainTest) {
@@ -579,17 +563,11 @@
// run `starter.Run()` in a thread to simulate it running on
// another process.
- Event started;
- std::thread starterd_thread([&starter, &started] {
- starter.event_loop()->OnRun([&started]() { started.Set(); });
- starter.Run();
- });
+ ThreadedStarterRunner starterd_thread(&starter);
- started.Wait();
client_loop.Run();
EXPECT_TRUE(success);
ASSERT_FALSE(starter.event_loop()->is_running());
- starterd_thread.join();
}
} // namespace starter
diff --git a/aos/starter/subprocess.cc b/aos/starter/subprocess.cc
index b7a1cf6..b1320f4 100644
--- a/aos/starter/subprocess.cc
+++ b/aos/starter/subprocess.cc
@@ -77,15 +77,24 @@
SignalListener::SignalListener(aos::ShmEventLoop *loop,
std::function<void(signalfd_siginfo)> callback)
- : SignalListener(loop, callback,
+ : SignalListener(loop->epoll(), std::move(callback)) {}
+
+// Listens for the default signal set below on `epoll`, delegating to the
+// constructor that takes an explicit signal list. `callback` is moved along
+// rather than copied, matching the other delegating constructors.
+SignalListener::SignalListener(aos::internal::EPoll *epoll,
+ std::function<void(signalfd_siginfo)> callback)
+ : SignalListener(epoll, std::move(callback),
{SIGHUP, SIGINT, SIGQUIT, SIGABRT, SIGFPE, SIGSEGV,
SIGPIPE, SIGTERM, SIGBUS, SIGXCPU, SIGCHLD}) {}
SignalListener::SignalListener(aos::ShmEventLoop *loop,
std::function<void(signalfd_siginfo)> callback,
std::initializer_list<unsigned int> signals)
- : loop_(loop), callback_(std::move(callback)), signalfd_(signals) {
- loop->epoll()->OnReadable(signalfd_.fd(), [this] {
+ : SignalListener(loop->epoll(), std::move(callback), std::move(signals)) {}
+
+SignalListener::SignalListener(aos::internal::EPoll *epoll,
+ std::function<void(signalfd_siginfo)> callback,
+ std::initializer_list<unsigned int> signals)
+ : epoll_(epoll), callback_(std::move(callback)), signalfd_(signals) {
+ epoll_->OnReadable(signalfd_.fd(), [this] {
signalfd_siginfo info = signalfd_.Read();
if (info.ssi_signo == 0) {
@@ -97,7 +106,7 @@
});
}
-SignalListener::~SignalListener() { loop_->epoll()->DeleteFd(signalfd_.fd()); }
+SignalListener::~SignalListener() { epoll_->DeleteFd(signalfd_.fd()); }
Application::Application(std::string_view name,
std::string_view executable_name,
@@ -123,7 +132,7 @@
pipe_timer_(event_loop_->AddTimer([this]() { FetchOutputs(); })),
child_status_handler_(
event_loop_->AddTimer([this]() { MaybeHandleSignal(); })),
- on_change_(on_change),
+ on_change_({on_change}),
quiet_flag_(quiet_flag) {
event_loop_->OnRun([this]() {
// Every second poll to check if the child is dead. This is used as a
@@ -205,7 +214,7 @@
stdout_pipes_.write.reset();
stderr_pipes_.write.reset();
}
- on_change_();
+ OnChange();
return;
}
@@ -343,7 +352,7 @@
stop_timer_->Schedule(event_loop_->monotonic_now() +
std::chrono::seconds(1));
queue_restart_ = restart;
- on_change_();
+ OnChange();
break;
}
case aos::starter::State::WAITING: {
@@ -354,7 +363,7 @@
DoStart();
} else {
status_ = aos::starter::State::STOPPED;
- on_change_();
+ OnChange();
}
break;
}
@@ -384,7 +393,7 @@
std::chrono::seconds(3));
start_timer_->Disable();
stop_timer_->Disable();
- on_change_();
+ OnChange();
}
std::vector<char *> Application::CArgs() {
@@ -553,7 +562,7 @@
QueueStart();
} else {
status_ = aos::starter::State::STOPPED;
- on_change_();
+ OnChange();
}
break;
}
@@ -571,7 +580,7 @@
QueueStart();
} else {
status_ = aos::starter::State::STOPPED;
- on_change_();
+ OnChange();
}
break;
}
@@ -584,7 +593,7 @@
// Disable force stop timer since the process already died
stop_timer_->Disable();
- on_change_();
+ OnChange();
if (terminating_) {
return true;
}
@@ -606,4 +615,10 @@
return false;
}
+// Invokes every callback stored in on_change_, in the order they are stored.
+void Application::OnChange() {
+ for (const std::function<void()> &callback : on_change_) {
+ callback();
+ }
+}
+
} // namespace aos::starter
diff --git a/aos/starter/subprocess.h b/aos/starter/subprocess.h
index 60732c3..ff62117 100644
--- a/aos/starter/subprocess.h
+++ b/aos/starter/subprocess.h
@@ -21,14 +21,19 @@
public:
SignalListener(aos::ShmEventLoop *loop,
std::function<void(signalfd_siginfo)> callback);
+ SignalListener(aos::internal::EPoll *epoll,
+ std::function<void(signalfd_siginfo)> callback);
SignalListener(aos::ShmEventLoop *loop,
std::function<void(signalfd_siginfo)> callback,
std::initializer_list<unsigned int> signals);
+ SignalListener(aos::internal::EPoll *epoll,
+ std::function<void(signalfd_siginfo)> callback,
+ std::initializer_list<unsigned int> signals);
~SignalListener();
private:
- aos::ShmEventLoop *loop_;
+ aos::internal::EPoll *epoll_;
std::function<void(signalfd_siginfo)> callback_;
aos::ipc_lib::SignalFd signalfd_;
@@ -91,6 +96,13 @@
void Terminate();
+ // Registers an additional callback to be run whenever the application
+ // changes state. Previously registered callbacks are kept; `fn` is appended
+ // after them.
+ void AddOnChange(std::function<void()> fn) {
+ on_change_.push_back(std::move(fn));
+ }
+
void set_args(std::vector<std::string> args);
void set_capture_stdout(bool capture);
void set_capture_stderr(bool capture);
@@ -127,6 +139,8 @@
void QueueStart();
+ void OnChange();
+
// Copy flatbuffer vector of strings to vector of std::string.
static std::vector<std::string> FbsVectorToVector(
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> &v);
@@ -178,7 +192,7 @@
aos::TimerHandler *start_timer_, *restart_timer_, *stop_timer_, *pipe_timer_,
*child_status_handler_;
- std::function<void()> on_change_;
+ std::vector<std::function<void()>> on_change_;
std::unique_ptr<MemoryCGroup> memory_cgroup_;
diff --git a/aos/util/file.cc b/aos/util/file.cc
index 4e2d1cd..52657a9 100644
--- a/aos/util/file.cc
+++ b/aos/util/file.cc
@@ -7,6 +7,7 @@
#include <sys/types.h>
#include <unistd.h>
+#include <optional>
#include <string_view>
+#include <utility>
#if __has_feature(memory_sanitizer)
#include <sanitizer/msan_interface.h>
@@ -17,18 +18,47 @@
namespace aos {
namespace util {
-::std::string ReadFileToStringOrDie(const std::string_view filename) {
- ::std::string r;
+// Returns the complete contents of `filename`. Dies (PCHECK) if
+// MaybeReadFileToString() reports a failure.
+std::string ReadFileToStringOrDie(const std::string_view filename) {
+ std::optional<std::string> r = MaybeReadFileToString(filename);
+ PCHECK(r.has_value()) << "Failed to read " << filename << " to string";
+ // Move the contents out of the optional; `r.value()` on an lvalue would copy
+ // the whole file a second time.
+ return std::move(r).value();
+}
+
+// Reads `filename` to the end and returns its contents. If the file cannot be
+// opened or a read fails, logs the error with PLOG(ERROR) and returns
+// std::nullopt instead of dying.
+std::optional<std::string> MaybeReadFileToString(
+ const std::string_view filename) {
+ ScopedFD fd(open(::std::string(filename).c_str(), O_RDONLY));
+ if (fd.get() == -1) {
+ PLOG(ERROR) << "Failed to open " << filename;
+ return std::nullopt;
+ }
+
+ std::string contents;
+ for (;;) {
+ char chunk[1024];
+ const ssize_t bytes_read = read(fd.get(), chunk, sizeof(chunk));
+ if (bytes_read < 0) {
+ PLOG(ERROR) << "Failed to read from " << filename;
+ return std::nullopt;
+ }
+ if (bytes_read == 0) {
+ // read() returning 0 means end of file.
+ return contents;
+ }
+ contents.append(chunk, bytes_read);
+ }
+}
+
+// Returns the complete contents of `filename` as bytes. Dies (PCHECK) if the
+// file cannot be opened or a read fails.
+std::vector<uint8_t> ReadFileToVecOrDie(const std::string_view filename) {
+ std::vector<uint8_t> r;
ScopedFD fd(open(::std::string(filename).c_str(), O_RDONLY));
PCHECK(fd.get() != -1) << ": opening " << filename;
while (true) {
- char buffer[1024];
+ uint8_t buffer[1024];
const ssize_t result = read(fd.get(), buffer, sizeof(buffer));
PCHECK(result >= 0) << ": reading from " << filename;
if (result == 0) {
break;
}
- r.append(buffer, result);
+ // Append the whole chunk with one range insert instead of a per-byte
+ // push_back through std::back_inserter.
+ r.insert(r.end(), buffer, buffer + result);
}
return r;
}
diff --git a/aos/util/file.h b/aos/util/file.h
index 3b2231a..da5b29f 100644
--- a/aos/util/file.h
+++ b/aos/util/file.h
@@ -22,7 +22,16 @@
// Returns the complete contents of filename. LOG(FATAL)s if any errors are
// encountered.
-::std::string ReadFileToStringOrDie(const std::string_view filename);
+std::string ReadFileToStringOrDie(const std::string_view filename);
+
+// Returns the complete contents of filename, or nullopt if any errors are
+// encountered. Unlike ReadFileToStringOrDie, this never dies.
+std::optional<std::string> MaybeReadFileToString(
+ const std::string_view filename);
+
+// Returns the complete contents of filename as a byte vector. LOG(FATAL)s if
+// any errors are encountered.
+std::vector<uint8_t> ReadFileToVecOrDie(const std::string_view filename);
// Creates filename if it doesn't exist and sets the contents to contents.
void WriteStringToFileOrDie(const std::string_view filename,
diff --git a/aos/util/file_test.cc b/aos/util/file_test.cc
index 712447e..ec4bfe4 100644
--- a/aos/util/file_test.cc
+++ b/aos/util/file_test.cc
@@ -1,8 +1,10 @@
#include "aos/util/file.h"
#include <cstdlib>
+#include <optional>
#include <string>
+#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"
#include "aos/realtime.h"
@@ -12,22 +14,57 @@
namespace util {
namespace testing {
+using ::testing::ElementsAre;
+
// Basic test of reading a normal file.
TEST(FileTest, ReadNormalFile) {
- const ::std::string tmpdir(aos::testing::TestTmpDir());
- const ::std::string test_file = tmpdir + "/test_file";
+ const std::string tmpdir(aos::testing::TestTmpDir());
+ const std::string test_file = tmpdir + "/test_file";
ASSERT_EQ(0, system(("echo contents > " + test_file).c_str()));
EXPECT_EQ("contents\n", ReadFileToStringOrDie(test_file));
}
+// Basic test of reading a normal file into a byte vector.
+TEST(FileTest, ReadNormalFileToBytes) {
+ const std::string tmpdir(aos::testing::TestTmpDir());
+ const std::string test_file = tmpdir + "/test_file";
+ ASSERT_EQ(0, system(("echo contents > " + test_file).c_str()));
+ EXPECT_THAT(ReadFileToVecOrDie(test_file),
+ ElementsAre('c', 'o', 'n', 't', 'e', 'n', 't', 's', '\n'));
+}
+
// Tests reading a file with 0 size, among other weird things.
TEST(FileTest, ReadSpecialFile) {
- const ::std::string stat = ReadFileToStringOrDie("/proc/self/stat");
+ const std::string stat = ReadFileToStringOrDie("/proc/self/stat");
EXPECT_EQ('\n', stat[stat.size() - 1]);
- const ::std::string my_pid = ::std::to_string(getpid());
+ const std::string my_pid = ::std::to_string(getpid());
EXPECT_EQ(my_pid, stat.substr(0, my_pid.size()));
}
+// Basic test of maybe reading a normal file.
+TEST(FileTest, MaybeReadNormalFile) {
+ const std::string tmpdir(aos::testing::TestTmpDir());
+ const std::string test_file = tmpdir + "/test_file";
+ ASSERT_EQ(0, system(("echo contents > " + test_file).c_str()));
+ EXPECT_EQ("contents\n", MaybeReadFileToString(test_file).value());
+}
+
+// Tests maybe reading a file with 0 size, among other weird things.
+TEST(FileTest, MaybeReadSpecialFile) {
+ const std::optional<std::string> stat =
+ MaybeReadFileToString("/proc/self/stat");
+ ASSERT_TRUE(stat.has_value());
+ EXPECT_EQ('\n', (*stat)[stat->size() - 1]);
+ const std::string my_pid = std::to_string(getpid());
+ EXPECT_EQ(my_pid, stat->substr(0, my_pid.size()));
+}
+
+// Tests maybe reading a non-existent file, and not fatally erroring.
+TEST(FileTest, MaybeReadNonexistentFile) {
+ const std::optional<std::string> contents = MaybeReadFileToString("/dne");
+ ASSERT_FALSE(contents.has_value());
+}
+
// Tests that the PathExists function works under normal conditions.
TEST(FileTest, PathExistsTest) {
const std::string tmpdir(aos::testing::TestTmpDir());
diff --git a/aos/uuid.cc b/aos/uuid.cc
index 67075b6..f794ec8 100644
--- a/aos/uuid.cc
+++ b/aos/uuid.cc
@@ -66,12 +66,39 @@
} // namespace
+namespace internal {
+// Returns a std::mt19937 seeded from a std::seed_seq built out of
+// kSeedsRequired values drawn from std::random_device, rather than from a
+// single random_device value.
+std::mt19937 FullySeededRandomGenerator() {
+ // Total bits that the mt19937 has internally that we could plausibly
+ // initialize with.
+ // The internal state ends up being ~1200 bytes, which is significantly more
+ // than the 128 bits we want for UUIDs, but since we should only need to
+ // generate this randomness once, it should be fine.
+ // If the performance cost ends up causing issues, then we can revisit the
+ // need to *fully* seed the twister.
+ constexpr size_t kInternalEntropy =
+ std::mt19937::state_size * sizeof(std::mt19937::result_type);
+ // Number, rounded up, of random values required.
+ constexpr size_t kSeedsRequired =
+ ((kInternalEntropy - 1) / sizeof(std::random_device::result_type)) + 1;
+ std::random_device random_device;
+// Older LLVM libstdc++'s just return 0 for the random device entropy.
+// NOTE(review): the guard below keys off the clang compiler version, not the
+// standard library in use -- confirm that is the intended condition.
+#if !defined(__clang__) || (__clang_major__ > 13)
+ // Sanity check that the device claims full entropy for every bit it returns.
+ CHECK_EQ(sizeof(std::random_device::result_type) * 8, random_device.entropy())
+ << ": Does your random_device actually support generating entropy?";
+#endif
+ std::array<std::random_device::result_type, kSeedsRequired> random_data;
+ // std::ref so that std::generate calls this random_device rather than a copy.
+ std::generate(std::begin(random_data), std::end(random_data),
+ std::ref(random_device));
+ std::seed_seq seeds(std::begin(random_data), std::end(random_data));
+ return std::mt19937(seeds);
+}
+} // namespace internal
+
UUID UUID::Random() {
- std::random_device rd;
- std::mt19937 gen(rd());
+ // thread_local to guarantee safe use of the generator itself.
+ thread_local std::mt19937 gen(internal::FullySeededRandomGenerator());
std::uniform_int_distribution<> dis(0, 255);
- std::uniform_int_distribution<> dis2(8, 11);
UUID result;
for (size_t i = 0; i < kDataSize; ++i) {
result.data_[i] = dis(gen);
diff --git a/aos/uuid.h b/aos/uuid.h
index 1b63ac3..371eb6d 100644
--- a/aos/uuid.h
+++ b/aos/uuid.h
@@ -3,6 +3,7 @@
#include <array>
#include <ostream>
+#include <random>
#include <string>
#include "absl/types/span.h"
@@ -20,6 +21,9 @@
static constexpr size_t kDataSize = 16;
// Returns a randomly generated UUID. This is known as a UUID4.
+ // The first Random() call in a thread will tend to be slightly slower than
+ // the rest so that it can seed the pseudo-random number generator used
+ // internally.
static UUID Random();
// Returns a uuid with all '0's.
@@ -95,6 +99,13 @@
std::ostream &operator<<(std::ostream &os, const UUID &uuid);
+namespace internal {
+// Initializes a mt19937 with as much entropy as it can take (rather than just a
+// 32-bit value from std::random_device).
+// Exposed for testing purposes.
+std::mt19937 FullySeededRandomGenerator();
+} // namespace internal
+
} // namespace aos
#endif // AOS_EVENTS_LOGGING_UUID_H_
diff --git a/aos/uuid_collision_test.cc b/aos/uuid_collision_test.cc
new file mode 100644
index 0000000..05bbccd
--- /dev/null
+++ b/aos/uuid_collision_test.cc
@@ -0,0 +1,44 @@
+#include <set>
+#include <unordered_set>
+
+#include "glog/logging.h"
+#include "gtest/gtest.h"
+
+#include "aos/uuid.h"
+
+namespace aos {
+namespace testing {
+
+// Tests that modest numbers of UUID::Random() calls cannot create UUID
+// collisions (to test that we have not *completely* messed up the random number
+// generation).
+TEST(UUIDTest, CollisionTest) {
+ std::set<UUID> uuids;
+ // When we only had ~32 bits of randomness in our UUIDs, we could generate
+ // issues with only ~sqrt(2 ** 32) (aka 2 ** 16) UUIDs.
+ // Just go up to 2 ** 22, since too much longer just makes this test take
+ // obnoxiously long.
+ for (size_t ii = 0; ii < (1UL << 22); ++ii) {
+ // insert() reports whether the element was new, so one lookup both detects
+ // a collision and records the UUID (instead of count() followed by
+ // insert()).
+ ASSERT_TRUE(uuids.insert(UUID::Random()).second) << ii;
+ }
+}
+
+// Tests that our random seed generation for the mt19937 does not trivially
+// collide.
+TEST(UUIDTest, SeedInitializationTest) {
+ std::uniform_int_distribution<uint64_t> distribution(0);
+ std::set<uint64_t> values;
+ // This test takes significantly longer than the above due to needing to query
+ // std::random_device substantially. However, covering a range of 2 ** 18
+ // should readily catch things if we are accidentally using 32-bit seeds.
+ for (size_t ii = 0; ii < (1UL << 18); ++ii) {
+ // A freshly seeded generator each iteration: only its first output is
+ // compared, so this exercises the seeding rather than the twister.
+ std::mt19937 twister = internal::FullySeededRandomGenerator();
+ const uint64_t value = distribution(twister);
+ // insert() reports whether the value was new, so one lookup both detects a
+ // collision and records the value.
+ ASSERT_TRUE(values.insert(value).second) << ii;
+ }
+}
+} // namespace testing
+} // namespace aos
diff --git a/frc971/analysis/log_to_match.cc b/frc971/analysis/log_to_match.cc
index b23e8ad..cbe7e50 100644
--- a/frc971/analysis/log_to_match.cc
+++ b/frc971/analysis/log_to_match.cc
@@ -7,9 +7,8 @@
int main(int argc, char **argv) {
aos::InitGoogle(&argc, &argv);
- std::vector<std::string> unsorted_logfiles =
- aos::logger::FindLogs(argc, argv);
- aos::logger::LogReader reader(aos::logger::SortParts(unsorted_logfiles));
+ aos::logger::LogReader reader(
+ aos::logger::SortParts(aos::logger::FindLogs(argc, argv)));
reader.Register();
const aos::Node *roborio =
aos::configuration::GetNode(reader.configuration(), "roborio");
diff --git a/y2020/control_loops/drivetrain/drivetrain_replay.cc b/y2020/control_loops/drivetrain/drivetrain_replay.cc
index 0035fb3..326c184 100644
--- a/y2020/control_loops/drivetrain/drivetrain_replay.cc
+++ b/y2020/control_loops/drivetrain/drivetrain_replay.cc
@@ -39,13 +39,9 @@
const aos::FlatbufferDetachedBuffer<aos::Configuration> config =
aos::configuration::ReadConfig(FLAGS_config);
- // find logfiles
- std::vector<std::string> unsorted_logfiles =
- aos::logger::FindLogs(argc, argv);
-
// sort logfiles
const std::vector<aos::logger::LogFile> logfiles =
- aos::logger::SortParts(unsorted_logfiles);
+ aos::logger::SortParts(aos::logger::FindLogs(argc, argv));
// open logfiles
aos::logger::LogReader reader(logfiles, &config.message());
diff --git a/y2020/vision/viewer_replay.cc b/y2020/vision/viewer_replay.cc
index a818859..03777bc 100644
--- a/y2020/vision/viewer_replay.cc
+++ b/y2020/vision/viewer_replay.cc
@@ -17,11 +17,9 @@
namespace {
void ViewerMain(int argc, char *argv[]) {
- std::vector<std::string> unsorted_logfiles =
- aos::logger::FindLogs(argc, argv);
-
// open logfiles
- aos::logger::LogReader reader(aos::logger::SortParts(unsorted_logfiles));
+ aos::logger::LogReader reader(
+ aos::logger::SortParts(aos::logger::FindLogs(argc, argv)));
reader.Register();
const aos::Node *node = nullptr;
if (aos::configuration::MultiNode(reader.configuration())) {
diff --git a/y2022/localizer/localizer_replay.cc b/y2022/localizer/localizer_replay.cc
index e706421..a3baf3b 100644
--- a/y2022/localizer/localizer_replay.cc
+++ b/y2022/localizer/localizer_replay.cc
@@ -29,13 +29,9 @@
const aos::FlatbufferDetachedBuffer<aos::Configuration> config =
aos::configuration::ReadConfig(FLAGS_config);
- // find logfiles
- std::vector<std::string> unsorted_logfiles =
- aos::logger::FindLogs(argc, argv);
-
// sort logfiles
const std::vector<aos::logger::LogFile> logfiles =
- aos::logger::SortParts(unsorted_logfiles);
+ aos::logger::SortParts(aos::logger::FindLogs(argc, argv));
// open logfiles
aos::logger::LogReader reader(logfiles, &config.message());
diff --git a/y2022/vision/viewer_replay.cc b/y2022/vision/viewer_replay.cc
index 66087e5..7e8c800 100644
--- a/y2022/vision/viewer_replay.cc
+++ b/y2022/vision/viewer_replay.cc
@@ -47,10 +47,9 @@
data->match_start = monotonic_clock::min_time;
data->match_end = monotonic_clock::min_time;
- std::vector<std::string> unsorted_logfiles =
- aos::logger::FindLogs(FLAGS_roborio_log);
// Open logfiles
- aos::logger::LogReader reader(aos::logger::SortParts(unsorted_logfiles));
+ aos::logger::LogReader reader(
+ aos::logger::SortParts(aos::logger::FindLogs(FLAGS_roborio_log)));
reader.Register();
const aos::Node *roborio =
aos::configuration::GetNode(reader.configuration(), "roborio");
@@ -149,11 +148,9 @@
<< "Can't use match timestamps if match never ended";
}
- std::vector<std::string> unsorted_logfiles =
- aos::logger::FindLogs(FLAGS_logger_pi_log);
-
// Open logfiles
- aos::logger::LogReader reader(aos::logger::SortParts(unsorted_logfiles));
+ aos::logger::LogReader reader(
+ aos::logger::SortParts(aos::logger::FindLogs(FLAGS_logger_pi_log)));
reader.Register();
const aos::Node *pi =
aos::configuration::GetNode(reader.configuration(), FLAGS_pi);
diff --git a/y2023/localizer/localizer_replay.cc b/y2023/localizer/localizer_replay.cc
index 27fdd85..859b77e 100644
--- a/y2023/localizer/localizer_replay.cc
+++ b/y2023/localizer/localizer_replay.cc
@@ -25,13 +25,9 @@
const aos::FlatbufferDetachedBuffer<aos::Configuration> config =
aos::configuration::ReadConfig(FLAGS_config);
- // find logfiles
- std::vector<std::string> unsorted_logfiles =
- aos::logger::FindLogs(argc, argv);
-
// sort logfiles
const std::vector<aos::logger::LogFile> logfiles =
- aos::logger::SortParts(unsorted_logfiles);
+ aos::logger::SortParts(aos::logger::FindLogs(argc, argv));
// open logfiles
aos::logger::LogReader reader(logfiles, &config.message());
diff --git a/y2023/vision/target_mapping.cc b/y2023/vision/target_mapping.cc
index 985a412..2ca6f66 100644
--- a/y2023/vision/target_mapping.cc
+++ b/y2023/vision/target_mapping.cc
@@ -398,9 +398,6 @@
}
void MappingMain(int argc, char *argv[]) {
- std::vector<std::string> unsorted_logfiles =
- aos::logger::FindLogs(argc, argv);
-
std::vector<DataAdapter::TimestampedDetection> timestamped_target_detections;
std::optional<aos::FlatbufferDetachedBuffer<aos::Configuration>> config =
@@ -410,7 +407,7 @@
// Open logfiles
aos::logger::LogReader reader(
- aos::logger::SortParts(unsorted_logfiles),
+ aos::logger::SortParts(aos::logger::FindLogs(argc, argv)),
config.has_value() ? &config->message() : nullptr);
TargetMapperReplay mapper_replay(&reader);