Rework SCTP auth pipeline to allow dynamic key change
The SCTP key sharing mechanism won't be using a file to communicate
the active authentication key anymore as we will be receiving
it directly into message bridge through an AOS channel instead.
Change-Id: I46e079b98cbb6a0ed52fca36c67f7fa724ba249c
Signed-off-by: James Kuszmaul <james.kuszmaul@bluerivertech.com>
diff --git a/aos/network/BUILD b/aos/network/BUILD
index 6f98b6b..f683a28 100644
--- a/aos/network/BUILD
+++ b/aos/network/BUILD
@@ -157,6 +157,7 @@
"//aos:unique_malloc_ptr",
"//aos/util:file",
"@com_github_google_glog//:glog",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/aos/network/message_bridge_client.cc b/aos/network/message_bridge_client.cc
index afe2a90..82a59fb 100644
--- a/aos/network/message_bridge_client.cc
+++ b/aos/network/message_bridge_client.cc
@@ -9,12 +9,6 @@
DEFINE_string(config, "aos_config.json", "Path to the config.");
DEFINE_int32(rt_priority, -1, "If > 0, run as this RT priority");
-#if HAS_SCTP_AUTH
-DEFINE_string(sctp_auth_key_file, "",
- "When set, use the provided key for SCTP authentication as "
- "defined in RFC 4895. The file should be binary-encoded");
-#endif
-
namespace aos {
namespace message_bridge {
@@ -29,14 +23,7 @@
event_loop.SetRuntimeRealtimePriority(FLAGS_rt_priority);
}
- std::vector<uint8_t> auth_key;
-#if HAS_SCTP_AUTH
- if (!FLAGS_sctp_auth_key_file.empty()) {
- auth_key = ReadFileToVecOrDie(FLAGS_sctp_auth_key_file);
- }
-#endif
- MessageBridgeClient app(&event_loop, Sha256(config.span()),
- std::move(auth_key));
+ MessageBridgeClient app(&event_loop, Sha256(config.span()));
logging::DynamicLogging dynamic_logging(&event_loop);
// TODO(austin): Save messages into a vector to be logged. One file per
diff --git a/aos/network/message_bridge_client_lib.cc b/aos/network/message_bridge_client_lib.cc
index 0f491fc..c23f748 100644
--- a/aos/network/message_bridge_client_lib.cc
+++ b/aos/network/message_bridge_client_lib.cc
@@ -100,8 +100,7 @@
aos::ShmEventLoop *const event_loop, std::string_view remote_name,
const Node *my_node, std::string_view local_host,
std::vector<SctpClientChannelState> *channels, int client_index,
- MessageBridgeClientStatus *client_status, std::string_view config_sha256,
- std::vector<uint8_t> auth_key)
+ MessageBridgeClientStatus *client_status, std::string_view config_sha256)
: event_loop_(event_loop),
connect_message_(MakeConnectMessage(event_loop->configuration(), my_node,
remote_name, event_loop->boot_uuid(),
@@ -112,7 +111,7 @@
client_(remote_node_->hostname()->string_view(), remote_node_->port(),
connect_message_.message().channels_to_transfer()->size() +
kControlStreams(),
- local_host, 0, std::move(auth_key)),
+ local_host, 0),
channels_(channels),
stream_to_channel_(
StreamToChannel(event_loop->configuration(), my_node, remote_node_)),
@@ -365,8 +364,7 @@
}
MessageBridgeClient::MessageBridgeClient(aos::ShmEventLoop *event_loop,
- std::string config_sha256,
- std::vector<uint8_t> auth_key)
+ std::string config_sha256)
: event_loop_(event_loop),
client_status_(event_loop_),
config_sha256_(std::move(config_sha256)) {
@@ -417,7 +415,7 @@
connections_.emplace_back(new SctpClientConnection(
event_loop, source_node, event_loop->node(), "", &channels_,
client_status_.FindClientIndex(source_node), &client_status_,
- config_sha256_, auth_key));
+ config_sha256_));
}
}
diff --git a/aos/network/message_bridge_client_lib.h b/aos/network/message_bridge_client_lib.h
index bb3bc1a..a5a69e3 100644
--- a/aos/network/message_bridge_client_lib.h
+++ b/aos/network/message_bridge_client_lib.h
@@ -38,8 +38,7 @@
std::vector<SctpClientChannelState> *channels,
int client_index,
MessageBridgeClientStatus *client_status,
- std::string_view config_sha256,
- std::vector<uint8_t> auth_key);
+ std::string_view config_sha256);
~SctpClientConnection() { event_loop_->epoll()->DeleteFd(client_.fd()); }
@@ -103,10 +102,7 @@
// node.
class MessageBridgeClient {
public:
- // When the `auth_key` byte-vector is non-empty, it will be used as the shared
- // key to authenticate every channel (See RFC4895 for more info).
- MessageBridgeClient(aos::ShmEventLoop *event_loop, std::string config_sha256,
- std::vector<uint8_t> auth_key);
+ MessageBridgeClient(aos::ShmEventLoop *event_loop, std::string config_sha256);
~MessageBridgeClient() {}
diff --git a/aos/network/message_bridge_server.cc b/aos/network/message_bridge_server.cc
index 4daf9c7..449333a 100644
--- a/aos/network/message_bridge_server.cc
+++ b/aos/network/message_bridge_server.cc
@@ -11,12 +11,6 @@
DEFINE_string(config, "aos_config.json", "Path to the config.");
DEFINE_int32(rt_priority, -1, "If > 0, run as this RT priority");
-#if HAS_SCTP_AUTH
-DEFINE_string(sctp_auth_key_file, "",
- "When set, use the provided key for SCTP authentication as "
- "defined in RFC 4895");
-#endif
-
namespace aos {
namespace message_bridge {
@@ -31,14 +25,7 @@
event_loop.SetRuntimeRealtimePriority(FLAGS_rt_priority);
}
- std::vector<uint8_t> auth_key;
-#if HAS_SCTP_AUTH
- if (!FLAGS_sctp_auth_key_file.empty()) {
- auth_key = ReadFileToVecOrDie(FLAGS_sctp_auth_key_file);
- }
-#endif
- MessageBridgeServer app(&event_loop, Sha256(config.span()),
- std::move(auth_key));
+ MessageBridgeServer app(&event_loop, Sha256(config.span()));
logging::DynamicLogging dynamic_logging(&event_loop);
diff --git a/aos/network/message_bridge_server_lib.cc b/aos/network/message_bridge_server_lib.cc
index b925c6b..debfeeb 100644
--- a/aos/network/message_bridge_server_lib.cc
+++ b/aos/network/message_bridge_server_lib.cc
@@ -402,12 +402,11 @@
}
MessageBridgeServer::MessageBridgeServer(aos::ShmEventLoop *event_loop,
- std::string config_sha256,
- std::vector<uint8_t> auth_key)
+ std::string config_sha256)
: event_loop_(event_loop),
timestamp_loggers_(event_loop_),
server_(max_channels() + kControlStreams(), "",
- event_loop->node()->port(), auth_key),
+ event_loop->node()->port()),
server_status_(event_loop, [this]() { timestamp_state_->SendData(); }),
config_sha256_(std::move(config_sha256)),
allocator_(0) {
diff --git a/aos/network/message_bridge_server_lib.h b/aos/network/message_bridge_server_lib.h
index ed81de6..6a3fdd8 100644
--- a/aos/network/message_bridge_server_lib.h
+++ b/aos/network/message_bridge_server_lib.h
@@ -167,10 +167,7 @@
// node. It handles the session and dispatches data to the ChannelState.
class MessageBridgeServer {
public:
- // When the `auth_key` byte-vector is non-empty, it will be used as the shared
- // key to authenticate every channel (See RFC4895 for more info).
- MessageBridgeServer(aos::ShmEventLoop *event_loop, std::string config_sha256,
- std::vector<uint8_t> auth_key);
+ MessageBridgeServer(aos::ShmEventLoop *event_loop, std::string config_sha256);
// Delete copy/move constructors explicitly--we internally pass around
// pointers to internal state.
diff --git a/aos/network/message_bridge_test_lib.cc b/aos/network/message_bridge_test_lib.cc
index 226c13a..e818075 100644
--- a/aos/network/message_bridge_test_lib.cc
+++ b/aos/network/message_bridge_test_lib.cc
@@ -81,8 +81,7 @@
pi1_server_event_loop->SetRuntimeRealtimePriority(1);
pi1_message_bridge_server = std::make_unique<MessageBridgeServer>(
pi1_server_event_loop.get(),
- server_config_sha256.size() == 0 ? config_sha256 : server_config_sha256,
- /*auth_key=*/std::vector<uint8_t>({}));
+ server_config_sha256.size() == 0 ? config_sha256 : server_config_sha256);
}
void MessageBridgeParameterizedTest::RunPi1Server(
@@ -116,8 +115,7 @@
std::make_unique<aos::ShmEventLoop>(&config.message());
pi1_client_event_loop->SetRuntimeRealtimePriority(1);
pi1_message_bridge_client = std::make_unique<MessageBridgeClient>(
- pi1_client_event_loop.get(), config_sha256,
- /*auth_key=*/std::vector<uint8_t>({}));
+ pi1_client_event_loop.get(), config_sha256);
}
void MessageBridgeParameterizedTest::StartPi1Client() {
@@ -170,8 +168,7 @@
std::make_unique<aos::ShmEventLoop>(&config.message());
pi2_server_event_loop->SetRuntimeRealtimePriority(1);
pi2_message_bridge_server = std::make_unique<MessageBridgeServer>(
- pi2_server_event_loop.get(), config_sha256,
- /*auth_key=*/std::vector<uint8_t>({}));
+ pi2_server_event_loop.get(), config_sha256);
}
void MessageBridgeParameterizedTest::RunPi2Server(
@@ -205,8 +202,7 @@
std::make_unique<aos::ShmEventLoop>(&config.message());
pi2_client_event_loop->SetRuntimeRealtimePriority(1);
pi2_message_bridge_client = std::make_unique<MessageBridgeClient>(
- pi2_client_event_loop.get(), config_sha256,
- /*auth_key=*/std::vector<uint8_t>({}));
+ pi2_client_event_loop.get(), config_sha256);
}
void MessageBridgeParameterizedTest::RunPi2Client(
diff --git a/aos/network/sctp_client.cc b/aos/network/sctp_client.cc
index e044396..fa1828d 100644
--- a/aos/network/sctp_client.cc
+++ b/aos/network/sctp_client.cc
@@ -23,8 +23,8 @@
SctpClient::SctpClient(std::string_view remote_host, int remote_port,
int streams, std::string_view local_host, int local_port,
- std::vector<uint8_t> sctp_auth_key)
- : sctp_(std::move(sctp_auth_key)) {
+ SctpAuthMethod requested_authentication)
+ : sctp_(requested_authentication) {
bool use_ipv6 = Ipv6Enabled();
sockaddr_local_ = ResolveSocket(local_host, local_port, use_ipv6);
sockaddr_remote_ = ResolveSocket(remote_host, remote_port, use_ipv6);
diff --git a/aos/network/sctp_client.h b/aos/network/sctp_client.h
index c6b6324..06f6b15 100644
--- a/aos/network/sctp_client.h
+++ b/aos/network/sctp_client.h
@@ -5,6 +5,7 @@
#include <cstdlib>
#include <string_view>
+#include "absl/types/span.h"
#include "glog/logging.h"
#include "aos/network/sctp_lib.h"
@@ -18,7 +19,7 @@
public:
SctpClient(std::string_view remote_host, int remote_port, int streams,
std::string_view local_host = "0.0.0.0", int local_port = 9971,
- std::vector<uint8_t> sctp_auth_key = {});
+ SctpAuthMethod requested_authentication = SctpAuthMethod::kNoAuth);
~SctpClient() {}
@@ -61,6 +62,10 @@
sctp_.FreeMessage(std::move(message));
}
+ void SetAuthKey(absl::Span<const uint8_t> auth_key) {
+ sctp_.SetAuthKey(auth_key);
+ }
+
private:
struct sockaddr_storage sockaddr_remote_;
struct sockaddr_storage sockaddr_local_;
diff --git a/aos/network/sctp_lib.cc b/aos/network/sctp_lib.cc
index 77cad02..c3c6e23 100644
--- a/aos/network/sctp_lib.cc
+++ b/aos/network/sctp_lib.cc
@@ -10,7 +10,10 @@
#include <unistd.h>
#include <algorithm>
+#include <cerrno>
+#include <fstream>
#include <string_view>
+#include <vector>
#include "aos/util/file.h"
@@ -31,9 +34,9 @@
struct sctp_sndrcvinfo sndrcvinfo;
} _sctp_cmsg_data_t;
+#if HAS_SCTP_AUTH
// Returns true if SCTP authentication is available and enabled.
bool SctpAuthIsEnabled() {
-#if HAS_SCTP_AUTH
struct stat current_stat;
if (stat("/proc/sys/net/sctp/auth_enable", ¤t_stat) != -1) {
int value = std::stoi(
@@ -45,11 +48,19 @@
LOG(WARNING) << "/proc/sys/net/sctp/auth_enable doesn't exist.";
return false;
}
-#else
- return false;
-#endif
}
+std::vector<uint8_t> GenerateSecureRandomSequence(size_t count) {
+ std::ifstream rng("/dev/random", std::ios::in | std::ios::binary);
+ CHECK(rng) << "Unable to open /dev/random";
+ std::vector<uint8_t> out(count, 0);
+ rng.read(reinterpret_cast<char *>(out.data()), count);
+ CHECK(rng) << "Couldn't read from random device";
+ rng.close();
+ return out;
+}
+#endif
+
} // namespace
bool Ipv6Enabled() {
@@ -289,28 +300,22 @@
sizeof(subscribe)) == 0);
}
- if (!auth_key_.empty()) {
- CHECK(SctpAuthIsEnabled())
- << "SCTP Authentication is disabled. Enable it with 'sysctl -w "
- "net.sctp.auth_enable=1' and try again.";
#if HAS_SCTP_AUTH
- // Set up the key with id `1`.
- sctp_authkey *const authkey =
- (sctp_authkey *)malloc(sizeof(sctp_authkey) + auth_key_.size());
- authkey->sca_keynumber = 1;
- authkey->sca_keylength = auth_key_.size();
- authkey->sca_assoc_id = SCTP_ALL_ASSOC;
- memcpy(&authkey->sca_key, auth_key_.data(), auth_key_.size());
+ if (sctp_authentication_) {
+ CHECK(SctpAuthIsEnabled())
+ << "SCTP Authentication key requested, but authentication isn't "
+ "enabled... Use `sysctl -w net.sctp.auth_enable=1` to enable";
- PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_KEY, authkey,
- sizeof(sctp_authkey) + auth_key_.size()) == 0);
- free(authkey);
+ // Unfortunately there's no way to delete the null key if we don't have
+ // another key active so this is the only way to prevent unauthenticated
+ // traffic until the real shared key is established.
+ SetAuthKey(GenerateSecureRandomSequence(16));
- // Set key `1` as active.
+ // Disallow the null key.
struct sctp_authkeyid authkeyid;
- authkeyid.scact_keynumber = 1;
+ authkeyid.scact_keynumber = 0;
authkeyid.scact_assoc_id = SCTP_ALL_ASSOC;
- PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_ACTIVE_KEY, &authkeyid,
+ PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_DELETE_KEY, &authkeyid,
sizeof(authkeyid)) == 0);
// Set up authentication for data chunks.
@@ -319,13 +324,8 @@
PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_CHUNK, &authchunk,
sizeof(authchunk)) == 0);
-
- // Disallow the null key.
- authkeyid.scact_keynumber = 0;
- PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_DELETE_KEY, &authkeyid,
- sizeof(authkeyid)) == 0);
-#endif
}
+#endif
DoSetMaxSize();
}
@@ -335,6 +335,8 @@
std::optional<struct sockaddr_storage> sockaddr_remote,
sctp_assoc_t snd_assoc_id) {
CHECK(fd_ != -1);
+ LOG_IF(FATAL, sctp_authentication_ && current_key_.empty())
+ << "Expected SCTP authentication but no key active";
struct iovec iov;
iov.iov_base = const_cast<char *>(data.data());
iov.iov_len = data.size();
@@ -392,9 +394,6 @@
return true;
}
-SctpReadWrite::SctpReadWrite(std::vector<uint8_t> auth_key)
- : auth_key_(std::move(auth_key)) {}
-
void SctpReadWrite::FreeMessage(aos::unique_c_ptr<Message> &&message) {
if (use_pool_) {
free_messages_.emplace_back(std::move(message));
@@ -432,6 +431,8 @@
// fragmented. If we do end up with a fragment, then we copy the data out of it.
aos::unique_c_ptr<Message> SctpReadWrite::ReadMessage() {
CHECK(fd_ != -1);
+ LOG_IF(FATAL, sctp_authentication_ && current_key_.empty())
+ << "Expected SCTP authentication but no key active";
while (true) {
aos::unique_c_ptr<Message> result = AcquireMessage();
@@ -697,6 +698,55 @@
return false;
}
+void SctpReadWrite::SetAuthKey(absl::Span<const uint8_t> auth_key) {
+ PCHECK(fd_ != -1);
+ if (auth_key.empty()) {
+ return;
+ }
+ // We are already using the key, nothing to do.
+ if (auth_key == current_key_) {
+ return;
+ }
+#if !(HAS_SCTP_AUTH)
+ LOG(FATAL) << "SCTP Authentication key requested, but authentication isn't "
+ "available... You may need a newer kernel";
+#else
+ LOG_IF(FATAL, !SctpAuthIsEnabled())
+ << "SCTP Authentication key requested, but authentication isn't "
+ "enabled... Use `sysctl -w net.sctp.auth_enable=1` to enable";
+ // Set up the key with id `1`.
+ std::unique_ptr<sctp_authkey> authkey(
+ (sctp_authkey *)malloc(sizeof(sctp_authkey) + auth_key.size()));
+
+ authkey->sca_keynumber = 1;
+ authkey->sca_keylength = auth_key.size();
+ authkey->sca_assoc_id = SCTP_ALL_ASSOC;
+ memcpy(&authkey->sca_key, auth_key.data(), auth_key.size());
+
+ if (setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_KEY, authkey.get(),
+ sizeof(sctp_authkey) + auth_key.size()) != 0) {
+ if (errno == EACCES) {
+ // TODO(adam.snaider): Figure out why this fails when expected nodes are
+ // not connected.
+ PLOG_EVERY_N(ERROR, 100) << "Setting authentication key failed";
+ return;
+ } else {
+ PLOG(FATAL) << "Setting authentication key failed";
+ }
+ }
+
+ // Set key `1` as active.
+ struct sctp_authkeyid authkeyid;
+ authkeyid.scact_keynumber = 1;
+ authkeyid.scact_assoc_id = SCTP_ALL_ASSOC;
+ if (setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_ACTIVE_KEY, &authkeyid,
+ sizeof(authkeyid)) != 0) {
+ PLOG(FATAL) << "Setting key id `1` as active failed";
+ }
+ current_key_.assign(auth_key.begin(), auth_key.end());
+#endif
+} // namespace message_bridge
+
void Message::LogRcvInfo() const {
LOG(INFO) << "\tSNDRCV (stream=" << header.rcvinfo.rcv_sid
<< " ssn=" << header.rcvinfo.rcv_ssn
diff --git a/aos/network/sctp_lib.h b/aos/network/sctp_lib.h
index f78934a..0d021a9 100644
--- a/aos/network/sctp_lib.h
+++ b/aos/network/sctp_lib.h
@@ -11,6 +11,7 @@
#include <string_view>
#include <vector>
+#include "absl/types/span.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
@@ -21,6 +22,8 @@
namespace aos {
namespace message_bridge {
+constexpr bool HasSctpAuth() { return HAS_SCTP_AUTH; }
+
// Check if ipv6 is enabled.
// If we don't try IPv6, and omit AI_ADDRCONFIG when resolving addresses, the
// library will happily resolve nodes to IPv6 IPs that can't be used. If we add
@@ -92,10 +95,31 @@
// Gets and logs the contents of the sctp_status message.
void LogSctpStatus(int fd, sctp_assoc_t assoc_id);
+// Authentication method used for the SCTP socket.
+enum class SctpAuthMethod {
+ // Use unauthenticated sockets.
+ kNoAuth,
+ // Use RFC4895 authentication for SCTP.
+ kAuth,
+};
+
// Manages reading and writing SCTP messages.
class SctpReadWrite {
public:
- SctpReadWrite(std::vector<uint8_t> auth_key = {});
+ // When `requested_authentication` is kAuth, it will use SCTP authentication
+ // if it's provided by the kernel. Note that this will ignore the value of
+ // `requested_authentication` if the kernel is too old and will fall back to
+ // an unauthenticated channel.
+ SctpReadWrite(
+ SctpAuthMethod requested_authentication = SctpAuthMethod::kNoAuth)
+ : sctp_authentication_(HasSctpAuth() ? requested_authentication ==
+ SctpAuthMethod::kAuth
+ : false) {
+ LOG_IF(WARNING,
+ requested_authentication == SctpAuthMethod::kAuth && !HasSctpAuth())
+ << "SCTP authentication requested but not provided by the kernel... "
+ "You may need a newer kernel";
+ }
~SctpReadWrite() { CloseSocket(); }
// Opens a new socket.
@@ -142,6 +166,9 @@
// Allocates messages for the pool. SetMaxSize must be set first.
void SetPoolSize(size_t pool_size);
+ // Set the active authentication key to `auth_key`.
+ void SetAuthKey(absl::Span<const uint8_t> auth_key);
+
private:
aos::unique_c_ptr<Message> AcquireMessage();
@@ -165,7 +192,9 @@
bool use_pool_ = false;
std::vector<aos::unique_c_ptr<Message>> free_messages_;
- std::vector<uint8_t> auth_key_;
+ // Use SCTP authentication (RFC4895).
+ bool sctp_authentication_;
+ std::vector<uint8_t> current_key_;
};
// Returns the max network buffer available for reading for a socket.
diff --git a/aos/network/sctp_perf.cc b/aos/network/sctp_perf.cc
index 3bafed1..5201f47 100644
--- a/aos/network/sctp_perf.cc
+++ b/aos/network/sctp_perf.cc
@@ -22,11 +22,9 @@
DEFINE_uint32(skip_first_n, 10,
"Skip the first 'n' messages when computing statistics.");
-#if HAS_SCTP_AUTH
DEFINE_string(sctp_auth_key_file, "",
"When set, use the provided key for SCTP authentication as "
"defined in RFC 4895");
-#endif
DECLARE_bool(die_on_malloc);
@@ -36,13 +34,16 @@
using util::ReadFileToVecOrDie;
+SctpAuthMethod SctpAuthMethod() {
+ return FLAGS_sctp_auth_key_file.empty() ? SctpAuthMethod::kNoAuth
+ : SctpAuthMethod::kAuth;
+}
+
std::vector<uint8_t> GetSctpAuthKey() {
-#if HAS_SCTP_AUTH
- if (!FLAGS_sctp_auth_key_file.empty()) {
- return ReadFileToVecOrDie(FLAGS_sctp_auth_key_file);
+ if (SctpAuthMethod() == SctpAuthMethod::kNoAuth) {
+ return {};
}
-#endif
- return {};
+ return ReadFileToVecOrDie(FLAGS_sctp_auth_key_file);
}
} // namespace
@@ -53,7 +54,8 @@
public:
Server(aos::ShmEventLoop *event_loop)
: event_loop_(event_loop),
- server_(2, "0.0.0.0", FLAGS_port, GetSctpAuthKey()) {
+ server_(2, "0.0.0.0", FLAGS_port, SctpAuthMethod()) {
+ server_.SetAuthKey(GetSctpAuthKey());
event_loop_->epoll()->OnReadable(server_.fd(),
[this]() { MessageReceived(); });
server_.SetMaxReadSize(FLAGS_rx_size + 100);
@@ -134,7 +136,8 @@
Client(aos::ShmEventLoop *event_loop)
: event_loop_(event_loop),
client_(FLAGS_host, FLAGS_port, 2, "0.0.0.0", FLAGS_port,
- GetSctpAuthKey()) {
+ SctpAuthMethod()) {
+ client_.SetAuthKey(GetSctpAuthKey());
client_.SetMaxReadSize(FLAGS_rx_size + 100);
client_.SetMaxWriteSize(FLAGS_rx_size + 100);
diff --git a/aos/network/sctp_server.cc b/aos/network/sctp_server.cc
index 1bcbebd..f90b21b 100644
--- a/aos/network/sctp_server.cc
+++ b/aos/network/sctp_server.cc
@@ -22,8 +22,8 @@
namespace message_bridge {
SctpServer::SctpServer(int streams, std::string_view local_host, int local_port,
- std::vector<uint8_t> sctp_auth_key)
- : sctp_(std::move(sctp_auth_key)) {
+ SctpAuthMethod requested_authentication)
+ : sctp_(requested_authentication) {
bool use_ipv6 = Ipv6Enabled();
sockaddr_local_ = ResolveSocket(local_host, local_port, use_ipv6);
while (true) {
diff --git a/aos/network/sctp_server.h b/aos/network/sctp_server.h
index a8c70aa..e641fd8 100644
--- a/aos/network/sctp_server.h
+++ b/aos/network/sctp_server.h
@@ -13,6 +13,7 @@
#include <cstring>
#include <memory>
+#include "absl/types/span.h"
#include "glog/logging.h"
#include "aos/network/sctp_lib.h"
@@ -24,7 +25,8 @@
class SctpServer {
public:
SctpServer(int streams, std::string_view local_host = "0.0.0.0",
- int local_port = 9971, std::vector<uint8_t> sctp_auth_key = {});
+ int local_port = 9971,
+ SctpAuthMethod requested_authentication = SctpAuthMethod::kNoAuth);
~SctpServer() {}
@@ -63,6 +65,10 @@
void SetPoolSize(size_t pool_size) { sctp_.SetPoolSize(pool_size); }
+ void SetAuthKey(absl::Span<const uint8_t> auth_key) {
+ sctp_.SetAuthKey(auth_key);
+ }
+
private:
struct sockaddr_storage sockaddr_local_;
SctpReadWrite sctp_;
diff --git a/aos/network/sctp_test.cc b/aos/network/sctp_test.cc
index edf13d4..8e332e4 100644
--- a/aos/network/sctp_test.cc
+++ b/aos/network/sctp_test.cc
@@ -93,9 +93,10 @@
public:
SctpTest(std::vector<uint8_t> server_key = {},
std::vector<uint8_t> client_key = {},
+ SctpAuthMethod requested_authentication = SctpAuthMethod::kNoAuth,
std::chrono::milliseconds timeout = 1000ms)
- : server_(kStreams, "", kPort, std::move(server_key)),
- client_("localhost", kPort, kStreams, "", 0, std::move(client_key)),
+ : server_(kStreams, "", kPort, requested_authentication),
+ client_("localhost", kPort, kStreams, "", 0, requested_authentication),
client_receiver_(
epoll_, client_,
[this](SctpClient &client,
@@ -114,6 +115,8 @@
[this](SctpServer &server, std::vector<uint8_t> message) {
HandleMessage(server, std::move(message));
}) {
+ server_.SetAuthKey(server_key);
+ client_.SetAuthKey(client_key);
timeout_.SetTime(aos::monotonic_clock::now() + timeout,
std::chrono::milliseconds::zero());
epoll_.OnReadable(timeout_.fd(), [this]() { TimeOut(); });
@@ -176,7 +179,8 @@
// Verifies we can ping the server, and the server replies.
class SctpPingPongTest : public SctpTest {
public:
- SctpPingPongTest() : SctpTest({}, {}, /*timeout=*/2s) {
+ SctpPingPongTest()
+ : SctpTest({}, {}, SctpAuthMethod::kNoAuth, /*timeout=*/2s) {
// Start by having the client send "ping".
client_.Send(0, "ping", 0);
}
@@ -220,7 +224,7 @@
class SctpAuthTest : public SctpTest {
public:
SctpAuthTest()
- : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
+ : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, SctpAuthMethod::kAuth,
/*timeout*/ 20s) {
// Start by having the client send "ping".
client_.Send(0, "ping", 0);
@@ -254,11 +258,54 @@
TEST_F(SctpAuthTest, Test) {}
+// Tests that we can dynamically change the SCTP authentication key used.
+class SctpChangingAuthKeysTest : public SctpTest {
+ public:
+ SctpChangingAuthKeysTest()
+ : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
+ SctpAuthMethod::kAuth) {
+ // Start by having the client send "ping".
+ client_.SetAuthKey({5, 4, 3, 2, 1});
+ server_.SetAuthKey({5, 4, 3, 2, 1});
+ client_.Send(0, "ping", 0);
+ }
+
+ void HandleMessage(SctpServer &server,
+ std::vector<uint8_t> message) override {
+ // Server should receive a ping message.
+ EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
+ got_ping_ = true;
+ ASSERT_NE(assoc_, 0);
+ // Reply with "pong".
+ server.Send("pong", assoc_, 0, 0);
+ }
+ void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
+ // Client should receive a "pong" message.
+ EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
+ got_pong_ = true;
+ // We are done.
+ Quit();
+ }
+
+ ~SctpChangingAuthKeysTest() {
+ EXPECT_TRUE(got_ping_);
+ EXPECT_TRUE(got_pong_);
+ }
+
+ protected:
+ bool got_ping_ = false;
+ bool got_pong_ = false;
+};
+
+TEST_F(SctpChangingAuthKeysTest, Test) {}
+
// Keys don't match, we should send the `ping` message but the server should
// never receive it. We then time out as nothing calls Quit.
class SctpMismatchedAuthTest : public SctpTest {
public:
- SctpMismatchedAuthTest() : SctpTest({1, 2, 3, 4, 5, 6}, {5, 6, 7, 8, 9, 10}) {
+ SctpMismatchedAuthTest()
+ : SctpTest({1, 2, 3, 4, 5, 6}, {5, 6, 7, 8, 9, 10},
+ SctpAuthMethod::kAuth) {
// Start by having the client send "ping".
client_.Send(0, "ping", 0);
}
@@ -278,7 +325,8 @@
// see the same behaviour.
class SctpOneNullKeyTest : public SctpTest {
public:
- SctpOneNullKeyTest() : SctpTest({1, 2, 3, 4, 5, 6}, {}) {
+ SctpOneNullKeyTest()
+ : SctpTest({1, 2, 3, 4, 5, 6}, {}, SctpAuthMethod::kAuth) {
// Start by having the client send "ping".
client_.Send(0, "ping", 0);
}
@@ -293,6 +341,27 @@
};
TEST_F(SctpOneNullKeyTest, Test) {}
+
+// If we want SCTP authentication but we don't set the auth keys, we shouldn't
+// be able to send data.
+class SctpAuthKeysNotSet : public SctpTest {
+ public:
+ SctpAuthKeysNotSet() : SctpTest({}, {}, SctpAuthMethod::kAuth) {
+ // Start by having the client send "ping".
+ client_.Send(0, "ping", 0);
+ }
+
+ void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
+ FAIL() << "Haven't setup authentication keys. Should not get message.";
+ Quit();
+ }
+
+ // We expect to time out since we never get the message.
+ void TimeOut() override { Quit(); }
+};
+
+TEST_F(SctpAuthKeysNotSet, Test) {}
+
#endif // HAS_SCTP_AUTH
} // namespace aos::message_bridge::testing