Add SCTP authentication to sctp_lib
This only works on Linux >= 5.4. When enabled, it uses
a shared key to authenticate messages. The functionality is
controlled by a flag and guarded by a Linux version check.
Performance degradation is minimal, even for smaller messages,
and unnoticeable when measuring overall system performance.
Change-Id: I836e61ec38a0c116fd7244b771437738ccca9828
Signed-off-by: James Kuszmaul <jabukuszmaul+collab@gmail.com>
diff --git a/aos/network/BUILD b/aos/network/BUILD
index d98b43b..6f98b6b 100644
--- a/aos/network/BUILD
+++ b/aos/network/BUILD
@@ -160,6 +160,25 @@
],
)
+# Tests SctpClient and SctpServer exchanging messages, with and without SCTP
+# authentication keys (see sctp_test.cc).
+cc_test(
+ name = "sctp_test",
+ srcs = [
+ "sctp_test.cc",
+ ],
+ tags = [
+ # Fakeroot is required to enable "net.sctp.auth_enable".
+ "requires-fakeroot",
+ ],
+ target_compatible_with = ["@platforms//cpu:x86_64"],
+ deps = [
+ ":sctp_client",
+ ":sctp_lib",
+ ":sctp_server",
+ "//aos/events:epoll",
+ "//aos/testing:googletest",
+ ],
+)
+
cc_library(
name = "sctp_server",
srcs = [
@@ -269,6 +288,7 @@
target_compatible_with = ["@platforms//os:linux"],
deps = [
":message_bridge_server_lib",
+ ":sctp_lib",
"//aos:init",
"//aos:json_to_flatbuffer",
"//aos:sha256",
@@ -351,11 +371,13 @@
target_compatible_with = ["@platforms//os:linux"],
deps = [
":message_bridge_client_lib",
+ ":sctp_lib",
"//aos:init",
"//aos:json_to_flatbuffer",
"//aos:sha256",
"//aos/events:shm_event_loop",
"//aos/logging:dynamic_logging",
+ "//aos/util:file",
],
)
diff --git a/aos/network/message_bridge_client.cc b/aos/network/message_bridge_client.cc
index c3f55ba..afe2a90 100644
--- a/aos/network/message_bridge_client.cc
+++ b/aos/network/message_bridge_client.cc
@@ -2,14 +2,24 @@
#include "aos/init.h"
#include "aos/logging/dynamic_logging.h"
#include "aos/network/message_bridge_client_lib.h"
+#include "aos/network/sctp_lib.h"
#include "aos/sha256.h"
+#include "aos/util/file.h"
DEFINE_string(config, "aos_config.json", "Path to the config.");
DEFINE_int32(rt_priority, -1, "If > 0, run as this RT priority");
+#if HAS_SCTP_AUTH
+DEFINE_string(sctp_auth_key_file, "",
+ "When set, use the provided key for SCTP authentication as "
+ "defined in RFC 4895. The file should be binary-encoded");
+#endif
+
namespace aos {
namespace message_bridge {
+using ::aos::util::ReadFileToVecOrDie;
+
int Main() {
aos::FlatbufferDetachedBuffer<aos::Configuration> config =
aos::configuration::ReadConfig(FLAGS_config);
@@ -19,7 +29,14 @@
event_loop.SetRuntimeRealtimePriority(FLAGS_rt_priority);
}
- MessageBridgeClient app(&event_loop, Sha256(config.span()));
+ std::vector<uint8_t> auth_key;
+#if HAS_SCTP_AUTH
+ if (!FLAGS_sctp_auth_key_file.empty()) {
+ auth_key = ReadFileToVecOrDie(FLAGS_sctp_auth_key_file);
+ }
+#endif
+ MessageBridgeClient app(&event_loop, Sha256(config.span()),
+ std::move(auth_key));
logging::DynamicLogging dynamic_logging(&event_loop);
// TODO(austin): Save messages into a vector to be logged. One file per
diff --git a/aos/network/message_bridge_client_lib.cc b/aos/network/message_bridge_client_lib.cc
index c23f748..0f491fc 100644
--- a/aos/network/message_bridge_client_lib.cc
+++ b/aos/network/message_bridge_client_lib.cc
@@ -100,7 +100,8 @@
aos::ShmEventLoop *const event_loop, std::string_view remote_name,
const Node *my_node, std::string_view local_host,
std::vector<SctpClientChannelState> *channels, int client_index,
- MessageBridgeClientStatus *client_status, std::string_view config_sha256)
+ MessageBridgeClientStatus *client_status, std::string_view config_sha256,
+ std::vector<uint8_t> auth_key)
: event_loop_(event_loop),
connect_message_(MakeConnectMessage(event_loop->configuration(), my_node,
remote_name, event_loop->boot_uuid(),
@@ -111,7 +112,7 @@
client_(remote_node_->hostname()->string_view(), remote_node_->port(),
connect_message_.message().channels_to_transfer()->size() +
kControlStreams(),
- local_host, 0),
+ local_host, 0, std::move(auth_key)),
channels_(channels),
stream_to_channel_(
StreamToChannel(event_loop->configuration(), my_node, remote_node_)),
@@ -364,7 +365,8 @@
}
MessageBridgeClient::MessageBridgeClient(aos::ShmEventLoop *event_loop,
- std::string config_sha256)
+ std::string config_sha256,
+ std::vector<uint8_t> auth_key)
: event_loop_(event_loop),
client_status_(event_loop_),
config_sha256_(std::move(config_sha256)) {
@@ -415,7 +417,7 @@
connections_.emplace_back(new SctpClientConnection(
event_loop, source_node, event_loop->node(), "", &channels_,
client_status_.FindClientIndex(source_node), &client_status_,
- config_sha256_));
+ config_sha256_, auth_key));
}
}
diff --git a/aos/network/message_bridge_client_lib.h b/aos/network/message_bridge_client_lib.h
index a5a69e3..bb3bc1a 100644
--- a/aos/network/message_bridge_client_lib.h
+++ b/aos/network/message_bridge_client_lib.h
@@ -38,7 +38,8 @@
std::vector<SctpClientChannelState> *channels,
int client_index,
MessageBridgeClientStatus *client_status,
- std::string_view config_sha256);
+ std::string_view config_sha256,
+ std::vector<uint8_t> auth_key);
~SctpClientConnection() { event_loop_->epoll()->DeleteFd(client_.fd()); }
@@ -102,7 +103,10 @@
// node.
class MessageBridgeClient {
public:
- MessageBridgeClient(aos::ShmEventLoop *event_loop, std::string config_sha256);
+ // When the `auth_key` byte-vector is non-empty, it will be used as the shared
+ // key to authenticate every channel (See RFC4895 for more info).
+ MessageBridgeClient(aos::ShmEventLoop *event_loop, std::string config_sha256,
+ std::vector<uint8_t> auth_key);
~MessageBridgeClient() {}
diff --git a/aos/network/message_bridge_server.cc b/aos/network/message_bridge_server.cc
index 04b07c3..4daf9c7 100644
--- a/aos/network/message_bridge_server.cc
+++ b/aos/network/message_bridge_server.cc
@@ -5,14 +5,23 @@
#include "aos/init.h"
#include "aos/logging/dynamic_logging.h"
#include "aos/network/message_bridge_server_lib.h"
+#include "aos/network/sctp_lib.h"
#include "aos/sha256.h"
DEFINE_string(config, "aos_config.json", "Path to the config.");
DEFINE_int32(rt_priority, -1, "If > 0, run as this RT priority");
+#if HAS_SCTP_AUTH
+DEFINE_string(sctp_auth_key_file, "",
+ "When set, use the provided key for SCTP authentication as "
+ "defined in RFC 4895");
+#endif
+
namespace aos {
namespace message_bridge {
+using ::aos::util::ReadFileToVecOrDie;
+
int Main() {
aos::FlatbufferDetachedBuffer<aos::Configuration> config =
aos::configuration::ReadConfig(FLAGS_config);
@@ -22,7 +31,14 @@
event_loop.SetRuntimeRealtimePriority(FLAGS_rt_priority);
}
- MessageBridgeServer app(&event_loop, Sha256(config.span()));
+ std::vector<uint8_t> auth_key;
+#if HAS_SCTP_AUTH
+ if (!FLAGS_sctp_auth_key_file.empty()) {
+ auth_key = ReadFileToVecOrDie(FLAGS_sctp_auth_key_file);
+ }
+#endif
+ MessageBridgeServer app(&event_loop, Sha256(config.span()),
+ std::move(auth_key));
logging::DynamicLogging dynamic_logging(&event_loop);
diff --git a/aos/network/message_bridge_server_lib.cc b/aos/network/message_bridge_server_lib.cc
index debfeeb..b925c6b 100644
--- a/aos/network/message_bridge_server_lib.cc
+++ b/aos/network/message_bridge_server_lib.cc
@@ -402,11 +402,12 @@
}
MessageBridgeServer::MessageBridgeServer(aos::ShmEventLoop *event_loop,
- std::string config_sha256)
+ std::string config_sha256,
+ std::vector<uint8_t> auth_key)
: event_loop_(event_loop),
timestamp_loggers_(event_loop_),
server_(max_channels() + kControlStreams(), "",
- event_loop->node()->port()),
+ event_loop->node()->port(), auth_key),
server_status_(event_loop, [this]() { timestamp_state_->SendData(); }),
config_sha256_(std::move(config_sha256)),
allocator_(0) {
diff --git a/aos/network/message_bridge_server_lib.h b/aos/network/message_bridge_server_lib.h
index 6a3fdd8..ed81de6 100644
--- a/aos/network/message_bridge_server_lib.h
+++ b/aos/network/message_bridge_server_lib.h
@@ -167,7 +167,10 @@
// node. It handles the session and dispatches data to the ChannelState.
class MessageBridgeServer {
public:
- MessageBridgeServer(aos::ShmEventLoop *event_loop, std::string config_sha256);
+ // When the `auth_key` byte-vector is non-empty, it will be used as the shared
+ // key to authenticate every channel (See RFC4895 for more info).
+ MessageBridgeServer(aos::ShmEventLoop *event_loop, std::string config_sha256,
+ std::vector<uint8_t> auth_key);
// Delete copy/move constructors explicitly--we internally pass around
// pointers to internal state.
diff --git a/aos/network/message_bridge_test_lib.cc b/aos/network/message_bridge_test_lib.cc
index e818075..226c13a 100644
--- a/aos/network/message_bridge_test_lib.cc
+++ b/aos/network/message_bridge_test_lib.cc
@@ -81,7 +81,8 @@
pi1_server_event_loop->SetRuntimeRealtimePriority(1);
pi1_message_bridge_server = std::make_unique<MessageBridgeServer>(
pi1_server_event_loop.get(),
- server_config_sha256.size() == 0 ? config_sha256 : server_config_sha256);
+ server_config_sha256.size() == 0 ? config_sha256 : server_config_sha256,
+ /*auth_key=*/std::vector<uint8_t>({}));
}
void MessageBridgeParameterizedTest::RunPi1Server(
@@ -115,7 +116,8 @@
std::make_unique<aos::ShmEventLoop>(&config.message());
pi1_client_event_loop->SetRuntimeRealtimePriority(1);
pi1_message_bridge_client = std::make_unique<MessageBridgeClient>(
- pi1_client_event_loop.get(), config_sha256);
+ pi1_client_event_loop.get(), config_sha256,
+ /*auth_key=*/std::vector<uint8_t>({}));
}
void MessageBridgeParameterizedTest::StartPi1Client() {
@@ -168,7 +170,8 @@
std::make_unique<aos::ShmEventLoop>(&config.message());
pi2_server_event_loop->SetRuntimeRealtimePriority(1);
pi2_message_bridge_server = std::make_unique<MessageBridgeServer>(
- pi2_server_event_loop.get(), config_sha256);
+ pi2_server_event_loop.get(), config_sha256,
+ /*auth_key=*/std::vector<uint8_t>({}));
}
void MessageBridgeParameterizedTest::RunPi2Server(
@@ -202,7 +205,8 @@
std::make_unique<aos::ShmEventLoop>(&config.message());
pi2_client_event_loop->SetRuntimeRealtimePriority(1);
pi2_message_bridge_client = std::make_unique<MessageBridgeClient>(
- pi2_client_event_loop.get(), config_sha256);
+ pi2_client_event_loop.get(), config_sha256,
+ /*auth_key=*/std::vector<uint8_t>({}));
}
void MessageBridgeParameterizedTest::RunPi2Client(
diff --git a/aos/network/sctp_client.cc b/aos/network/sctp_client.cc
index a87dfc6..e044396 100644
--- a/aos/network/sctp_client.cc
+++ b/aos/network/sctp_client.cc
@@ -22,8 +22,9 @@
namespace message_bridge {
SctpClient::SctpClient(std::string_view remote_host, int remote_port,
- int streams, std::string_view local_host,
- int local_port) {
+ int streams, std::string_view local_host, int local_port,
+ std::vector<uint8_t> sctp_auth_key)
+ : sctp_(std::move(sctp_auth_key)) {
bool use_ipv6 = Ipv6Enabled();
sockaddr_local_ = ResolveSocket(local_host, local_port, use_ipv6);
sockaddr_remote_ = ResolveSocket(remote_host, remote_port, use_ipv6);
diff --git a/aos/network/sctp_client.h b/aos/network/sctp_client.h
index 5affecc..c6b6324 100644
--- a/aos/network/sctp_client.h
+++ b/aos/network/sctp_client.h
@@ -17,7 +17,8 @@
class SctpClient {
public:
SctpClient(std::string_view remote_host, int remote_port, int streams,
- std::string_view local_host = "0.0.0.0", int local_port = 9971);
+ std::string_view local_host = "0.0.0.0", int local_port = 9971,
+ std::vector<uint8_t> sctp_auth_key = {});
~SctpClient() {}
diff --git a/aos/network/sctp_lib.cc b/aos/network/sctp_lib.cc
index 6030ea7..77cad02 100644
--- a/aos/network/sctp_lib.cc
+++ b/aos/network/sctp_lib.cc
@@ -31,6 +31,25 @@
struct sctp_sndrcvinfo sndrcvinfo;
} _sctp_cmsg_data_t;
+// Returns true if SCTP authentication is available and enabled.
+//
+// When built without HAS_SCTP_AUTH this always returns false. Otherwise it
+// reads the `net.sctp.auth_enable` sysctl through procfs: a missing file logs
+// a warning and counts as disabled, and any value other than 0 or 1 is fatal.
+bool SctpAuthIsEnabled() {
+#if HAS_SCTP_AUTH
+  constexpr const char *kAuthEnablePath = "/proc/sys/net/sctp/auth_enable";
+
+  struct stat current_stat;
+  if (stat(kAuthEnablePath, &current_stat) == -1) {
+    LOG(WARNING) << kAuthEnablePath << " doesn't exist.";
+    return false;
+  }
+
+  const int value = std::stoi(util::ReadFileToStringOrDie(kAuthEnablePath));
+  CHECK(value == 0 || value == 1)
+      << "Unknown auth enable sysctl value: " << value;
+  return value == 1;
+#else
+  return false;
+#endif
+}
+
} // namespace
bool Ipv6Enabled() {
@@ -270,6 +289,44 @@
sizeof(subscribe)) == 0);
}
+ if (!auth_key_.empty()) {
+ CHECK(SctpAuthIsEnabled())
+ << "SCTP Authentication is disabled. Enable it with 'sysctl -w "
+ "net.sctp.auth_enable=1' and try again.";
+#if HAS_SCTP_AUTH
+ // Set up the key with id `1`.
+ sctp_authkey *const authkey =
+ (sctp_authkey *)malloc(sizeof(sctp_authkey) + auth_key_.size());
+ authkey->sca_keynumber = 1;
+ authkey->sca_keylength = auth_key_.size();
+ authkey->sca_assoc_id = SCTP_ALL_ASSOC;
+ memcpy(&authkey->sca_key, auth_key_.data(), auth_key_.size());
+
+ PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_KEY, authkey,
+ sizeof(sctp_authkey) + auth_key_.size()) == 0);
+ free(authkey);
+
+ // Set key `1` as active.
+ struct sctp_authkeyid authkeyid;
+ authkeyid.scact_keynumber = 1;
+ authkeyid.scact_assoc_id = SCTP_ALL_ASSOC;
+ PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_ACTIVE_KEY, &authkeyid,
+ sizeof(authkeyid)) == 0);
+
+ // Set up authentication for data chunks.
+ struct sctp_authchunk authchunk;
+ authchunk.sauth_chunk = 0;
+
+ PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_CHUNK, &authchunk,
+ sizeof(authchunk)) == 0);
+
+ // Disallow the null key.
+ authkeyid.scact_keynumber = 0;
+ PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_DELETE_KEY, &authkeyid,
+ sizeof(authkeyid)) == 0);
+#endif
+ }
+
DoSetMaxSize();
}
@@ -335,6 +392,9 @@
return true;
}
+// Takes ownership of the SCTP authentication key. The socket setup code only
+// configures authentication when this key is non-empty.
+SctpReadWrite::SctpReadWrite(std::vector<uint8_t> auth_key)
+ : auth_key_(std::move(auth_key)) {}
+
void SctpReadWrite::FreeMessage(aos::unique_c_ptr<Message> &&message) {
if (use_pool_) {
free_messages_.emplace_back(std::move(message));
diff --git a/aos/network/sctp_lib.h b/aos/network/sctp_lib.h
index 8eb57d3..f78934a 100644
--- a/aos/network/sctp_lib.h
+++ b/aos/network/sctp_lib.h
@@ -3,6 +3,7 @@
#include <arpa/inet.h>
#include <linux/sctp.h>
+#include <linux/version.h>
#include <memory>
#include <optional>
@@ -15,6 +16,8 @@
#include "aos/unique_malloc_ptr.h"
+// Compile-time switch for SCTP authentication support: true when the Linux
+// headers this is built against are version 5.4 or newer. The expansion is
+// parenthesized so it stays correct inside larger preprocessor expressions
+// such as `#if !HAS_SCTP_AUTH`.
+#define HAS_SCTP_AUTH (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 4, 0))
+
namespace aos {
namespace message_bridge {
@@ -92,7 +95,7 @@
// Manages reading and writing SCTP messages.
class SctpReadWrite {
public:
- SctpReadWrite() = default;
+ SctpReadWrite(std::vector<uint8_t> auth_key = {});
~SctpReadWrite() { CloseSocket(); }
// Opens a new socket.
@@ -161,6 +164,8 @@
bool use_pool_ = false;
std::vector<aos::unique_c_ptr<Message>> free_messages_;
+
+ std::vector<uint8_t> auth_key_;
};
// Returns the max network buffer available for reading for a socket.
diff --git a/aos/network/sctp_perf.cc b/aos/network/sctp_perf.cc
index cce4bed..3bafed1 100644
--- a/aos/network/sctp_perf.cc
+++ b/aos/network/sctp_perf.cc
@@ -6,6 +6,7 @@
#include "aos/events/shm_event_loop.h"
#include "aos/init.h"
#include "aos/network/sctp_client.h"
+#include "aos/network/sctp_lib.h"
#include "aos/network/sctp_server.h"
DEFINE_string(config, "aos_config.json", "Path to the config.");
@@ -21,16 +22,38 @@
DEFINE_uint32(skip_first_n, 10,
"Skip the first 'n' messages when computing statistics.");
+#if HAS_SCTP_AUTH
+DEFINE_string(sctp_auth_key_file, "",
+ "When set, use the provided key for SCTP authentication as "
+ "defined in RFC 4895");
+#endif
+
DECLARE_bool(die_on_malloc);
namespace aos::message_bridge::perf {
+namespace {
+
+using util::ReadFileToVecOrDie;
+
+// Returns the contents of the file named by --sctp_auth_key_file. The result
+// is empty when the flag is unset or when SCTP authentication support is
+// compiled out.
+std::vector<uint8_t> GetSctpAuthKey() {
+  std::vector<uint8_t> key;
+#if HAS_SCTP_AUTH
+  const bool key_file_given = !FLAGS_sctp_auth_key_file.empty();
+  if (key_file_given) {
+    key = ReadFileToVecOrDie(FLAGS_sctp_auth_key_file);
+  }
+#endif
+  return key;
+}
+
+} // namespace
+
namespace chrono = std::chrono;
class Server {
public:
Server(aos::ShmEventLoop *event_loop)
- : event_loop_(event_loop), server_(2, "0.0.0.0", FLAGS_port) {
+ : event_loop_(event_loop),
+ server_(2, "0.0.0.0", FLAGS_port, GetSctpAuthKey()) {
event_loop_->epoll()->OnReadable(server_.fd(),
[this]() { MessageReceived(); });
server_.SetMaxReadSize(FLAGS_rx_size + 100);
@@ -109,7 +132,9 @@
class Client {
public:
Client(aos::ShmEventLoop *event_loop)
- : event_loop_(event_loop), client_(FLAGS_host, FLAGS_port, 2) {
+ : event_loop_(event_loop),
+ client_(FLAGS_host, FLAGS_port, 2, "0.0.0.0", FLAGS_port,
+ GetSctpAuthKey()) {
client_.SetMaxReadSize(FLAGS_rx_size + 100);
client_.SetMaxWriteSize(FLAGS_rx_size + 100);
@@ -196,8 +221,8 @@
double throughput = FLAGS_payload_size * 2.0 / elapsed_secs;
double avg_throughput = FLAGS_payload_size * 2.0 / avg_latency_;
printf(
- "Round trip: %.2fms | %.2f KB/s | Avg RTL: %.2fms | %.2f KB/s | Count: "
- "%d\n",
+ "Round trip: %.2fms | %.2f KB/s | Avg RTL: %.2fms | %.2f KB/s | "
+ "Count: %d\n",
elapsed_secs * 1000, throughput / 1024, avg_latency_ * 1000,
avg_throughput / 1024, count_);
}
diff --git a/aos/network/sctp_server.cc b/aos/network/sctp_server.cc
index fb736de..1bcbebd 100644
--- a/aos/network/sctp_server.cc
+++ b/aos/network/sctp_server.cc
@@ -21,8 +21,9 @@
namespace aos {
namespace message_bridge {
-SctpServer::SctpServer(int streams, std::string_view local_host,
- int local_port) {
+SctpServer::SctpServer(int streams, std::string_view local_host, int local_port,
+ std::vector<uint8_t> sctp_auth_key)
+ : sctp_(std::move(sctp_auth_key)) {
bool use_ipv6 = Ipv6Enabled();
sockaddr_local_ = ResolveSocket(local_host, local_port, use_ipv6);
while (true) {
diff --git a/aos/network/sctp_server.h b/aos/network/sctp_server.h
index 2831a68..a8c70aa 100644
--- a/aos/network/sctp_server.h
+++ b/aos/network/sctp_server.h
@@ -24,7 +24,7 @@
class SctpServer {
public:
SctpServer(int streams, std::string_view local_host = "0.0.0.0",
- int local_port = 9971);
+ int local_port = 9971, std::vector<uint8_t> sctp_auth_key = {});
~SctpServer() {}
diff --git a/aos/network/sctp_test.cc b/aos/network/sctp_test.cc
new file mode 100644
index 0000000..edf13d4
--- /dev/null
+++ b/aos/network/sctp_test.cc
@@ -0,0 +1,298 @@
+#include <unistd.h>
+
+#include <chrono>
+#include <functional>
+
+#include "gflags/gflags.h"
+#include "gmock/gmock-matchers.h"
+#include "gtest/gtest.h"
+
+#include "aos/events/epoll.h"
+#include "aos/network/sctp_client.h"
+#include "aos/network/sctp_lib.h"
+#include "aos/network/sctp_server.h"
+
+DECLARE_bool(disable_ipv6);
+
+namespace aos::message_bridge::testing {
+
+using ::aos::internal::EPoll;
+using ::aos::internal::TimerFd;
+using ::testing::ElementsAre;
+
+using namespace ::std::chrono_literals;
+
+constexpr int kPort = 19423;
+constexpr int kStreams = 1;
+
+namespace {
+// Turns on the `net.sctp.auth_enable` sysctl by shelling out to `sysctl`,
+// trying /usr/sbin/sysctl first and falling back to /sbin/sysctl. Dies if the
+// shell command reports failure. Does nothing when HAS_SCTP_AUTH is false.
+void EnableSctpAuthIfAvailable() {
+#if HAS_SCTP_AUTH
+ CHECK(system("/usr/sbin/sysctl net.sctp.auth_enable=1 || /sbin/sysctl "
+ "net.sctp.auth_enable=1") == 0)
+ << "Couldn't enable sctp authentication.";
+#endif
+}
+} // namespace
+
+// Asynchronous dispatcher for one SCTP endpoint.
+//
+// Wraps an SCTP receiver `T` (SctpServer or SctpClient), registers its file
+// descriptor with the given EPoll, and forwards every message read from it to
+// either the `sctp_notification` callback or the data-message callback. The
+// file descriptor is removed from the EPoll again on destruction.
+template <typename T>
+class SctpReceiver {
+ public:
+  SctpReceiver(
+      EPoll &epoll, T &receiver,
+      std::function<void(T &, const union sctp_notification *)> on_notify,
+      std::function<void(T &, std::vector<uint8_t>)> on_message)
+      : epoll_(epoll),
+        receiver_(receiver),
+        on_notify_(std::move(on_notify)),
+        on_message_(std::move(on_message)) {
+    epoll_.OnReadable(receiver_.fd(), [this]() { Read(); });
+  }
+
+  ~SctpReceiver() { epoll_.DeleteFd(receiver_.fd()); }
+
+ private:
+  // Reads one message from the receiver, hands it to the appropriate handler
+  // and returns the buffer to the receiver. Does nothing if no message was
+  // read.
+  void Read() {
+    aos::unique_c_ptr<Message> incoming = receiver_.Read();
+    if (incoming) {
+      Dispatch(*incoming);
+      receiver_.FreeMessage(std::move(incoming));
+    }
+  }
+
+  // Routes `incoming` by message type: notifications go to `on_notify_`, data
+  // messages are copied into a byte vector for `on_message_`, and an overflow
+  // is fatal.
+  void Dispatch(Message &incoming) {
+    if (incoming.message_type == Message::kNotification) {
+      on_notify_(receiver_, reinterpret_cast<const union sctp_notification *>(
+                                incoming.data()));
+    } else if (incoming.message_type == Message::kMessage) {
+      on_message_(receiver_, std::vector(incoming.data(),
+                                         incoming.data() + incoming.size));
+    } else if (incoming.message_type == Message::kOverflow) {
+      LOG(FATAL) << "Overflow";
+    }
+  }
+
+  EPoll &epoll_;
+  T &receiver_;
+  std::function<void(T &, const union sctp_notification *)> on_notify_;
+  std::function<void(T &, std::vector<uint8_t>)> on_message_;
+};
+
+// Base SctpTest class.
+//
+// Owns an SctpServer listening on kPort, an SctpClient pointed at
+// localhost:kPort, and an epoll loop that routes their traffic to the virtual
+// handlers below. The event loop is run from SetUp(), so the scenario executes
+// there and the TEST_F bodies using these fixtures are empty.
+//
+// The class provides a few virtual methods that should be overridden to define
+// the behavior of the test.
+class SctpTest : public ::testing::Test {
+ public:
+ // `server_key` and `client_key` are passed to the server and client as their
+ // SCTP authentication keys. `timeout` is how long after construction the
+ // timer fires and TimeOut() is called.
+ SctpTest(std::vector<uint8_t> server_key = {},
+ std::vector<uint8_t> client_key = {},
+ std::chrono::milliseconds timeout = 1000ms)
+ : server_(kStreams, "", kPort, std::move(server_key)),
+ client_("localhost", kPort, kStreams, "", 0, std::move(client_key)),
+ client_receiver_(
+ epoll_, client_,
+ [this](SctpClient &client,
+ const union sctp_notification *notification) {
+ HandleNotification(client, notification);
+ },
+ [this](SctpClient &client, std::vector<uint8_t> message) {
+ HandleMessage(client, std::move(message));
+ }),
+ server_receiver_(
+ epoll_, server_,
+ [this](SctpServer &server,
+ const union sctp_notification *notification) {
+ HandleNotification(server, notification);
+ },
+ [this](SctpServer &server, std::vector<uint8_t> message) {
+ HandleMessage(server, std::move(message));
+ }) {
+ // Arm the timer and have the epoll loop call TimeOut() when it fires.
+ timeout_.SetTime(aos::monotonic_clock::now() + timeout,
+ std::chrono::milliseconds::zero());
+ epoll_.OnReadable(timeout_.fd(), [this]() { TimeOut(); });
+ }
+
+ static void SetUpTestSuite() {
+ EnableSctpAuthIfAvailable();
+ // Buildkite seems to have issues with ipv6 sctp sockets...
+ FLAGS_disable_ipv6 = true;
+ }
+
+ // Runs the event loop; the handlers below are invoked from inside this call.
+ void SetUp() override { Run(); }
+
+ protected:
+ // Handles a server notification message.
+ //
+ // The default behaviour is to track the sctp association ID.
+ virtual void HandleNotification(SctpServer &,
+ const union sctp_notification *notification) {
+ if (notification->sn_header.sn_type == SCTP_ASSOC_CHANGE) {
+ assoc_ = notification->sn_assoc_change.sac_assoc_id;
+ }
+ }
+
+ // Handles the client notification message.
+ virtual void HandleNotification(SctpClient &,
+ const union sctp_notification *) {}
+
+ // Handles a server "data" message.
+ virtual void HandleMessage(SctpServer &, std::vector<uint8_t>) {}
+ // Handles a client "data" message.
+ virtual void HandleMessage(SctpClient &, std::vector<uint8_t>) {}
+
+ // Defines the "timeout" behaviour (fail by default).
+ virtual void TimeOut() {
+ // Quit() must come first: FAIL() returns from this function.
+ Quit();
+ FAIL() << "Timer expired";
+ }
+
+ virtual ~SctpTest() {}
+
+ // Quit the test.
+ void Quit() {
+ epoll_.DeleteFd(timeout_.fd());
+ epoll_.Quit();
+ }
+ void Run() { epoll_.Run(); }
+
+ SctpServer server_;
+ SctpClient client_;
+ // Association ID recorded by the default server notification handler; stays
+ // 0 until an SCTP_ASSOC_CHANGE notification has been seen.
+ sctp_assoc_t assoc_ = 0;
+
+ private:
+ // Members are initialized in declaration order: `epoll_` is declared before
+ // the receivers because their constructors register with it.
+ TimerFd timeout_;
+ EPoll epoll_;
+ SctpReceiver<SctpClient> client_receiver_;
+ SctpReceiver<SctpServer> server_receiver_;
+};
+
+// Verifies we can ping the server, and the server replies.
+//
+// The client sends "ping" from the constructor, the server answers "pong" on
+// the tracked association, and the test quits once the client has received
+// it. The protected constructor lets subclasses rerun the same exchange with
+// different authentication keys and timeouts.
+class SctpPingPongTest : public SctpTest {
+ public:
+  SctpPingPongTest() : SctpPingPongTest({}, {}, /*timeout=*/2s) {}
+
+  ~SctpPingPongTest() override {
+    // Check that we got the ping/pong messages.
+    // This isn't strictly necessary as otherwise we would time out and fail
+    // anyway.
+    EXPECT_TRUE(got_ping_);
+    EXPECT_TRUE(got_pong_);
+  }
+
+  void HandleMessage(SctpServer &server,
+                     std::vector<uint8_t> message) override {
+    // Server should receive a ping message.
+    EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
+    got_ping_ = true;
+    ASSERT_NE(assoc_, 0);
+    // Reply with "pong".
+    server.Send("pong", assoc_, 0, 0);
+  }
+
+  void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
+    // Client should receive a "pong" message.
+    EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
+    got_pong_ = true;
+    // We are done.
+    Quit();
+  }
+
+ protected:
+  // Runs the ping/pong exchange with the given server/client authentication
+  // keys (empty means no authentication) and timeout.
+  SctpPingPongTest(std::vector<uint8_t> server_key,
+                   std::vector<uint8_t> client_key,
+                   std::chrono::milliseconds timeout)
+      : SctpTest(std::move(server_key), std::move(client_key), timeout) {
+    // Start by having the client send "ping".
+    client_.Send(0, "ping", 0);
+  }
+
+  bool got_ping_ = false;
+  bool got_pong_ = false;
+};
+
+TEST_F(SctpPingPongTest, Test) {}
+
+#if HAS_SCTP_AUTH
+
+// Same as SctpPingPongTest but with authentication keys. Both keys are the
+// same so it should work the same way.
+class SctpAuthTest : public SctpPingPongTest {
+ public:
+  SctpAuthTest()
+      : SctpPingPongTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
+                         /*timeout=*/20s) {}
+};
+
+TEST_F(SctpAuthTest, Test) {}
+
+// Keys don't match, we should send the `ping` message but the server should
+// never receive it. We then time out as nothing calls Quit.
+class SctpMismatchedAuthTest : public SctpTest {
+ public:
+  SctpMismatchedAuthTest() : SctpTest({1, 2, 3, 4, 5, 6}, {5, 6, 7, 8, 9, 10}) {
+    // Start by having the client send "ping".
+    client_.Send(0, "ping", 0);
+  }
+
+  void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
+    // Stop the event loop before reporting: FAIL() returns from this function,
+    // so anything placed after it would never run.
+    Quit();
+    FAIL() << "Authentication keys don't match. Message should be discarded";
+  }
+
+  // We expect to time out since we never get the message.
+  void TimeOut() override { Quit(); }
+};
+
+TEST_F(SctpMismatchedAuthTest, Test) {}
+
+// The server uses a real key while the client uses the null (empty) key. The
+// client sends `ping`, the server should never receive it, and the test ends
+// when the timer expires.
+class SctpOneNullKeyTest : public SctpTest {
+ public:
+  SctpOneNullKeyTest() : SctpTest({1, 2, 3, 4, 5, 6}, {}) {
+    // Start by having the client send "ping".
+    client_.Send(0, "ping", 0);
+  }
+
+  void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
+    // Stop the event loop before reporting: FAIL() returns from this function,
+    // so anything placed after it would never run.
+    Quit();
+    FAIL() << "Authentication keys don't match. Message should be discarded";
+  }
+
+  // We expect to time out since we never get the message.
+  void TimeOut() override { Quit(); }
+};
+
+TEST_F(SctpOneNullKeyTest, Test) {}
+#endif // HAS_SCTP_AUTH
+
+} // namespace aos::message_bridge::testing
diff --git a/aos/util/file.cc b/aos/util/file.cc
index 591539c..52657a9 100644
--- a/aos/util/file.cc
+++ b/aos/util/file.cc
@@ -47,6 +47,22 @@
return r;
}
+// Reads all of `filename` into a byte vector, 1024 bytes at a time. Dies with
+// the errno description if the file cannot be opened or a read fails.
+std::vector<uint8_t> ReadFileToVecOrDie(const std::string_view filename) {
+  ScopedFD fd(open(::std::string(filename).c_str(), O_RDONLY));
+  PCHECK(fd.get() != -1) << ": opening " << filename;
+
+  std::vector<uint8_t> contents;
+  uint8_t chunk[1024];
+  ssize_t result;
+  do {
+    result = read(fd.get(), chunk, sizeof(chunk));
+    PCHECK(result >= 0) << ": reading from " << filename;
+    // A zero-length read (end of file) appends nothing and ends the loop.
+    contents.insert(contents.end(), chunk, chunk + result);
+  } while (result > 0);
+  return contents;
+}
+
void WriteStringToFileOrDie(const std::string_view filename,
const std::string_view contents,
mode_t permissions) {
diff --git a/aos/util/file.h b/aos/util/file.h
index 0cba867..da5b29f 100644
--- a/aos/util/file.h
+++ b/aos/util/file.h
@@ -29,6 +29,10 @@
std::optional<std::string> MaybeReadFileToString(
const std::string_view filename);
+// Returns the complete contents of filename as a byte vector. LOG(FATAL)s if
+// any errors are encountered.
+std::vector<uint8_t> ReadFileToVecOrDie(const std::string_view filename);
+
// Creates filename if it doesn't exist and sets the contents to contents.
void WriteStringToFileOrDie(const std::string_view filename,
const std::string_view contents,
diff --git a/aos/util/file_test.cc b/aos/util/file_test.cc
index df16d58..ec4bfe4 100644
--- a/aos/util/file_test.cc
+++ b/aos/util/file_test.cc
@@ -4,6 +4,7 @@
#include <optional>
#include <string>
+#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"
#include "aos/realtime.h"
@@ -13,6 +14,8 @@
namespace util {
namespace testing {
+using ::testing::ElementsAre;
+
// Basic test of reading a normal file.
TEST(FileTest, ReadNormalFile) {
const std::string tmpdir(aos::testing::TestTmpDir());
@@ -21,6 +24,15 @@
EXPECT_EQ("contents\n", ReadFileToStringOrDie(test_file));
}
+// Basic test of reading a normal file into a byte vector: the bytes returned
+// by ReadFileToVecOrDie must match what was written, including the trailing
+// newline added by `echo`.
+TEST(FileTest, ReadNormalFileToBytes) {
+ const std::string tmpdir(aos::testing::TestTmpDir());
+ const std::string test_file = tmpdir + "/test_file";
+ ASSERT_EQ(0, system(("echo contents > " + test_file).c_str()));
+ EXPECT_THAT(ReadFileToVecOrDie(test_file),
+ ElementsAre('c', 'o', 'n', 't', 'e', 'n', 't', 's', '\n'));
+}
+
// Tests reading a file with 0 size, among other weird things.
TEST(FileTest, ReadSpecialFile) {
const std::string stat = ReadFileToStringOrDie("/proc/self/stat");