Rework SCTP auth pipeline to allow dynamic key change

The SCTP key sharing mechanism won't be using a file to communicate
the active authentication key anymore as we will be receiving
it directly into message bridge through an AOS channel instead.

Change-Id: I46e079b98cbb6a0ed52fca36c67f7fa724ba249c
Signed-off-by: James Kuszmaul <james.kuszmaul@bluerivertech.com>
diff --git a/aos/network/BUILD b/aos/network/BUILD
index 6f98b6b..f683a28 100644
--- a/aos/network/BUILD
+++ b/aos/network/BUILD
@@ -157,6 +157,7 @@
         "//aos:unique_malloc_ptr",
         "//aos/util:file",
         "@com_github_google_glog//:glog",
+        "@com_google_absl//absl/types:span",
     ],
 )
 
diff --git a/aos/network/message_bridge_client.cc b/aos/network/message_bridge_client.cc
index afe2a90..82a59fb 100644
--- a/aos/network/message_bridge_client.cc
+++ b/aos/network/message_bridge_client.cc
@@ -9,12 +9,6 @@
 DEFINE_string(config, "aos_config.json", "Path to the config.");
 DEFINE_int32(rt_priority, -1, "If > 0, run as this RT priority");
 
-#if HAS_SCTP_AUTH
-DEFINE_string(sctp_auth_key_file, "",
-              "When set, use the provided key for SCTP authentication as "
-              "defined in RFC 4895. The file should be binary-encoded");
-#endif
-
 namespace aos {
 namespace message_bridge {
 
@@ -29,14 +23,7 @@
     event_loop.SetRuntimeRealtimePriority(FLAGS_rt_priority);
   }
 
-  std::vector<uint8_t> auth_key;
-#if HAS_SCTP_AUTH
-  if (!FLAGS_sctp_auth_key_file.empty()) {
-    auth_key = ReadFileToVecOrDie(FLAGS_sctp_auth_key_file);
-  }
-#endif
-  MessageBridgeClient app(&event_loop, Sha256(config.span()),
-                          std::move(auth_key));
+  MessageBridgeClient app(&event_loop, Sha256(config.span()));
 
   logging::DynamicLogging dynamic_logging(&event_loop);
   // TODO(austin): Save messages into a vector to be logged.  One file per
diff --git a/aos/network/message_bridge_client_lib.cc b/aos/network/message_bridge_client_lib.cc
index 0f491fc..c23f748 100644
--- a/aos/network/message_bridge_client_lib.cc
+++ b/aos/network/message_bridge_client_lib.cc
@@ -100,8 +100,7 @@
     aos::ShmEventLoop *const event_loop, std::string_view remote_name,
     const Node *my_node, std::string_view local_host,
     std::vector<SctpClientChannelState> *channels, int client_index,
-    MessageBridgeClientStatus *client_status, std::string_view config_sha256,
-    std::vector<uint8_t> auth_key)
+    MessageBridgeClientStatus *client_status, std::string_view config_sha256)
     : event_loop_(event_loop),
       connect_message_(MakeConnectMessage(event_loop->configuration(), my_node,
                                           remote_name, event_loop->boot_uuid(),
@@ -112,7 +111,7 @@
       client_(remote_node_->hostname()->string_view(), remote_node_->port(),
               connect_message_.message().channels_to_transfer()->size() +
                   kControlStreams(),
-              local_host, 0, std::move(auth_key)),
+              local_host, 0),
       channels_(channels),
       stream_to_channel_(
           StreamToChannel(event_loop->configuration(), my_node, remote_node_)),
@@ -365,8 +364,7 @@
 }
 
 MessageBridgeClient::MessageBridgeClient(aos::ShmEventLoop *event_loop,
-                                         std::string config_sha256,
-                                         std::vector<uint8_t> auth_key)
+                                         std::string config_sha256)
     : event_loop_(event_loop),
       client_status_(event_loop_),
       config_sha256_(std::move(config_sha256)) {
@@ -417,7 +415,7 @@
     connections_.emplace_back(new SctpClientConnection(
         event_loop, source_node, event_loop->node(), "", &channels_,
         client_status_.FindClientIndex(source_node), &client_status_,
-        config_sha256_, auth_key));
+        config_sha256_));
   }
 }
 
diff --git a/aos/network/message_bridge_client_lib.h b/aos/network/message_bridge_client_lib.h
index bb3bc1a..a5a69e3 100644
--- a/aos/network/message_bridge_client_lib.h
+++ b/aos/network/message_bridge_client_lib.h
@@ -38,8 +38,7 @@
                        std::vector<SctpClientChannelState> *channels,
                        int client_index,
                        MessageBridgeClientStatus *client_status,
-                       std::string_view config_sha256,
-                       std::vector<uint8_t> auth_key);
+                       std::string_view config_sha256);
 
   ~SctpClientConnection() { event_loop_->epoll()->DeleteFd(client_.fd()); }
 
@@ -103,10 +102,7 @@
 // node.
 class MessageBridgeClient {
  public:
-  // When the `auth_key` byte-vector is non-empty, it will be used as the shared
-  // key to authenticate every channel (See RFC4895 for more info).
-  MessageBridgeClient(aos::ShmEventLoop *event_loop, std::string config_sha256,
-                      std::vector<uint8_t> auth_key);
+  MessageBridgeClient(aos::ShmEventLoop *event_loop, std::string config_sha256);
 
   ~MessageBridgeClient() {}
 
diff --git a/aos/network/message_bridge_server.cc b/aos/network/message_bridge_server.cc
index 4daf9c7..449333a 100644
--- a/aos/network/message_bridge_server.cc
+++ b/aos/network/message_bridge_server.cc
@@ -11,12 +11,6 @@
 DEFINE_string(config, "aos_config.json", "Path to the config.");
 DEFINE_int32(rt_priority, -1, "If > 0, run as this RT priority");
 
-#if HAS_SCTP_AUTH
-DEFINE_string(sctp_auth_key_file, "",
-              "When set, use the provided key for SCTP authentication as "
-              "defined in RFC 4895");
-#endif
-
 namespace aos {
 namespace message_bridge {
 
@@ -31,14 +25,7 @@
     event_loop.SetRuntimeRealtimePriority(FLAGS_rt_priority);
   }
 
-  std::vector<uint8_t> auth_key;
-#if HAS_SCTP_AUTH
-  if (!FLAGS_sctp_auth_key_file.empty()) {
-    auth_key = ReadFileToVecOrDie(FLAGS_sctp_auth_key_file);
-  }
-#endif
-  MessageBridgeServer app(&event_loop, Sha256(config.span()),
-                          std::move(auth_key));
+  MessageBridgeServer app(&event_loop, Sha256(config.span()));
 
   logging::DynamicLogging dynamic_logging(&event_loop);
 
diff --git a/aos/network/message_bridge_server_lib.cc b/aos/network/message_bridge_server_lib.cc
index b925c6b..debfeeb 100644
--- a/aos/network/message_bridge_server_lib.cc
+++ b/aos/network/message_bridge_server_lib.cc
@@ -402,12 +402,11 @@
 }
 
 MessageBridgeServer::MessageBridgeServer(aos::ShmEventLoop *event_loop,
-                                         std::string config_sha256,
-                                         std::vector<uint8_t> auth_key)
+                                         std::string config_sha256)
     : event_loop_(event_loop),
       timestamp_loggers_(event_loop_),
       server_(max_channels() + kControlStreams(), "",
-              event_loop->node()->port(), auth_key),
+              event_loop->node()->port()),
       server_status_(event_loop, [this]() { timestamp_state_->SendData(); }),
       config_sha256_(std::move(config_sha256)),
       allocator_(0) {
diff --git a/aos/network/message_bridge_server_lib.h b/aos/network/message_bridge_server_lib.h
index ed81de6..6a3fdd8 100644
--- a/aos/network/message_bridge_server_lib.h
+++ b/aos/network/message_bridge_server_lib.h
@@ -167,10 +167,7 @@
 // node.  It handles the session and dispatches data to the ChannelState.
 class MessageBridgeServer {
  public:
-  // When the `auth_key` byte-vector is non-empty, it will be used as the shared
-  // key to authenticate every channel (See RFC4895 for more info).
-  MessageBridgeServer(aos::ShmEventLoop *event_loop, std::string config_sha256,
-                      std::vector<uint8_t> auth_key);
+  MessageBridgeServer(aos::ShmEventLoop *event_loop, std::string config_sha256);
 
   // Delete copy/move constructors explicitly--we internally pass around
   // pointers to internal state.
diff --git a/aos/network/message_bridge_test_lib.cc b/aos/network/message_bridge_test_lib.cc
index 226c13a..e818075 100644
--- a/aos/network/message_bridge_test_lib.cc
+++ b/aos/network/message_bridge_test_lib.cc
@@ -81,8 +81,7 @@
   pi1_server_event_loop->SetRuntimeRealtimePriority(1);
   pi1_message_bridge_server = std::make_unique<MessageBridgeServer>(
       pi1_server_event_loop.get(),
-      server_config_sha256.size() == 0 ? config_sha256 : server_config_sha256,
-      /*auth_key=*/std::vector<uint8_t>({}));
+      server_config_sha256.size() == 0 ? config_sha256 : server_config_sha256);
 }
 
 void MessageBridgeParameterizedTest::RunPi1Server(
@@ -116,8 +115,7 @@
       std::make_unique<aos::ShmEventLoop>(&config.message());
   pi1_client_event_loop->SetRuntimeRealtimePriority(1);
   pi1_message_bridge_client = std::make_unique<MessageBridgeClient>(
-      pi1_client_event_loop.get(), config_sha256,
-      /*auth_key=*/std::vector<uint8_t>({}));
+      pi1_client_event_loop.get(), config_sha256);
 }
 
 void MessageBridgeParameterizedTest::StartPi1Client() {
@@ -170,8 +168,7 @@
       std::make_unique<aos::ShmEventLoop>(&config.message());
   pi2_server_event_loop->SetRuntimeRealtimePriority(1);
   pi2_message_bridge_server = std::make_unique<MessageBridgeServer>(
-      pi2_server_event_loop.get(), config_sha256,
-      /*auth_key=*/std::vector<uint8_t>({}));
+      pi2_server_event_loop.get(), config_sha256);
 }
 
 void MessageBridgeParameterizedTest::RunPi2Server(
@@ -205,8 +202,7 @@
       std::make_unique<aos::ShmEventLoop>(&config.message());
   pi2_client_event_loop->SetRuntimeRealtimePriority(1);
   pi2_message_bridge_client = std::make_unique<MessageBridgeClient>(
-      pi2_client_event_loop.get(), config_sha256,
-      /*auth_key=*/std::vector<uint8_t>({}));
+      pi2_client_event_loop.get(), config_sha256);
 }
 
 void MessageBridgeParameterizedTest::RunPi2Client(
diff --git a/aos/network/sctp_client.cc b/aos/network/sctp_client.cc
index e044396..fa1828d 100644
--- a/aos/network/sctp_client.cc
+++ b/aos/network/sctp_client.cc
@@ -23,8 +23,8 @@
 
 SctpClient::SctpClient(std::string_view remote_host, int remote_port,
                        int streams, std::string_view local_host, int local_port,
-                       std::vector<uint8_t> sctp_auth_key)
-    : sctp_(std::move(sctp_auth_key)) {
+                       SctpAuthMethod requested_authentication)
+    : sctp_(requested_authentication) {
   bool use_ipv6 = Ipv6Enabled();
   sockaddr_local_ = ResolveSocket(local_host, local_port, use_ipv6);
   sockaddr_remote_ = ResolveSocket(remote_host, remote_port, use_ipv6);
diff --git a/aos/network/sctp_client.h b/aos/network/sctp_client.h
index c6b6324..06f6b15 100644
--- a/aos/network/sctp_client.h
+++ b/aos/network/sctp_client.h
@@ -5,6 +5,7 @@
 #include <cstdlib>
 #include <string_view>
 
+#include "absl/types/span.h"
 #include "glog/logging.h"
 
 #include "aos/network/sctp_lib.h"
@@ -18,7 +19,7 @@
  public:
   SctpClient(std::string_view remote_host, int remote_port, int streams,
              std::string_view local_host = "0.0.0.0", int local_port = 9971,
-             std::vector<uint8_t> sctp_auth_key = {});
+             SctpAuthMethod requested_authentication = SctpAuthMethod::kNoAuth);
 
   ~SctpClient() {}
 
@@ -61,6 +62,10 @@
     sctp_.FreeMessage(std::move(message));
   }
 
+  void SetAuthKey(absl::Span<const uint8_t> auth_key) {
+    sctp_.SetAuthKey(auth_key);
+  }
+
  private:
   struct sockaddr_storage sockaddr_remote_;
   struct sockaddr_storage sockaddr_local_;
diff --git a/aos/network/sctp_lib.cc b/aos/network/sctp_lib.cc
index 77cad02..c3c6e23 100644
--- a/aos/network/sctp_lib.cc
+++ b/aos/network/sctp_lib.cc
@@ -10,7 +10,10 @@
 #include <unistd.h>
 
 #include <algorithm>
+#include <cerrno>
+#include <fstream>
 #include <string_view>
+#include <vector>
 
 #include "aos/util/file.h"
 
@@ -31,9 +34,9 @@
   struct sctp_sndrcvinfo sndrcvinfo;
 } _sctp_cmsg_data_t;
 
+#if HAS_SCTP_AUTH
 // Returns true if SCTP authentication is available and enabled.
 bool SctpAuthIsEnabled() {
-#if HAS_SCTP_AUTH
   struct stat current_stat;
   if (stat("/proc/sys/net/sctp/auth_enable", &current_stat) != -1) {
     int value = std::stoi(
@@ -45,11 +48,19 @@
     LOG(WARNING) << "/proc/sys/net/sctp/auth_enable doesn't exist.";
     return false;
   }
-#else
-  return false;
-#endif
 }
 
+std::vector<uint8_t> GenerateSecureRandomSequence(size_t count) {
+  std::ifstream rng("/dev/random", std::ios::in | std::ios::binary);
+  CHECK(rng) << "Unable to open /dev/random";
+  std::vector<uint8_t> out(count, 0);
+  rng.read(reinterpret_cast<char *>(out.data()), count);
+  CHECK(rng) << "Couldn't read from random device";
+  rng.close();
+  return out;
+}
+#endif
+
 }  // namespace
 
 bool Ipv6Enabled() {
@@ -289,28 +300,22 @@
                       sizeof(subscribe)) == 0);
   }
 
-  if (!auth_key_.empty()) {
-    CHECK(SctpAuthIsEnabled())
-        << "SCTP Authentication is disabled. Enable it with 'sysctl -w "
-           "net.sctp.auth_enable=1' and try again.";
 #if HAS_SCTP_AUTH
-    // Set up the key with id `1`.
-    sctp_authkey *const authkey =
-        (sctp_authkey *)malloc(sizeof(sctp_authkey) + auth_key_.size());
-    authkey->sca_keynumber = 1;
-    authkey->sca_keylength = auth_key_.size();
-    authkey->sca_assoc_id = SCTP_ALL_ASSOC;
-    memcpy(&authkey->sca_key, auth_key_.data(), auth_key_.size());
+  if (sctp_authentication_) {
+    CHECK(SctpAuthIsEnabled())
+        << "SCTP Authentication key requested, but authentication isn't "
+           "enabled... Use `sysctl -w net.sctp.auth_enable=1` to enable";
 
-    PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_KEY, authkey,
-                      sizeof(sctp_authkey) + auth_key_.size()) == 0);
-    free(authkey);
+    // Unfortunately there's no way to delete the null key if we don't have
+    // another key active so this is the only way to prevent unauthenticated
+    // traffic until the real shared key is established.
+    SetAuthKey(GenerateSecureRandomSequence(16));
 
-    // Set key `1` as active.
+    // Disallow the null key.
     struct sctp_authkeyid authkeyid;
-    authkeyid.scact_keynumber = 1;
+    authkeyid.scact_keynumber = 0;
     authkeyid.scact_assoc_id = SCTP_ALL_ASSOC;
-    PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_ACTIVE_KEY, &authkeyid,
+    PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_DELETE_KEY, &authkeyid,
                       sizeof(authkeyid)) == 0);
 
     // Set up authentication for data chunks.
@@ -319,13 +324,8 @@
 
     PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_CHUNK, &authchunk,
                       sizeof(authchunk)) == 0);
-
-    // Disallow the null key.
-    authkeyid.scact_keynumber = 0;
-    PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_DELETE_KEY, &authkeyid,
-                      sizeof(authkeyid)) == 0);
-#endif
   }
+#endif
 
   DoSetMaxSize();
 }
@@ -335,6 +335,8 @@
     std::optional<struct sockaddr_storage> sockaddr_remote,
     sctp_assoc_t snd_assoc_id) {
   CHECK(fd_ != -1);
+  LOG_IF(FATAL, sctp_authentication_ && current_key_.empty())
+      << "Expected SCTP authentication but no key active";
   struct iovec iov;
   iov.iov_base = const_cast<char *>(data.data());
   iov.iov_len = data.size();
@@ -392,9 +394,6 @@
   return true;
 }
 
-SctpReadWrite::SctpReadWrite(std::vector<uint8_t> auth_key)
-    : auth_key_(std::move(auth_key)) {}
-
 void SctpReadWrite::FreeMessage(aos::unique_c_ptr<Message> &&message) {
   if (use_pool_) {
     free_messages_.emplace_back(std::move(message));
@@ -432,6 +431,8 @@
 // fragmented. If we do end up with a fragment, then we copy the data out of it.
 aos::unique_c_ptr<Message> SctpReadWrite::ReadMessage() {
   CHECK(fd_ != -1);
+  LOG_IF(FATAL, sctp_authentication_ && current_key_.empty())
+      << "Expected SCTP authentication but no key active";
 
   while (true) {
     aos::unique_c_ptr<Message> result = AcquireMessage();
@@ -697,6 +698,55 @@
   return false;
 }
 
+void SctpReadWrite::SetAuthKey(absl::Span<const uint8_t> auth_key) {
+  PCHECK(fd_ != -1);
+  if (auth_key.empty()) {
+    return;
+  }
+  // We are already using the key, nothing to do.
+  if (auth_key == current_key_) {
+    return;
+  }
+#if !(HAS_SCTP_AUTH)
+  LOG(FATAL) << "SCTP Authentication key requested, but authentication isn't "
+                "available... You may need a newer kernel";
+#else
+  LOG_IF(FATAL, !SctpAuthIsEnabled())
+      << "SCTP Authentication key requested, but authentication isn't "
+         "enabled... Use `sysctl -w net.sctp.auth_enable=1` to enable";
+  // Set up the key with id `1`.
+  std::unique_ptr<sctp_authkey> authkey(
+      (sctp_authkey *)malloc(sizeof(sctp_authkey) + auth_key.size()));
+
+  authkey->sca_keynumber = 1;
+  authkey->sca_keylength = auth_key.size();
+  authkey->sca_assoc_id = SCTP_ALL_ASSOC;
+  memcpy(&authkey->sca_key, auth_key.data(), auth_key.size());
+
+  if (setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_KEY, authkey.get(),
+                 sizeof(sctp_authkey) + auth_key.size()) != 0) {
+    if (errno == EACCES) {
+      // TODO(adam.snaider): Figure out why this fails when expected nodes are
+      // not connected.
+      PLOG_EVERY_N(ERROR, 100) << "Setting authentication key failed";
+      return;
+    } else {
+      PLOG(FATAL) << "Setting authentication key failed";
+    }
+  }
+
+  // Set key `1` as active.
+  struct sctp_authkeyid authkeyid;
+  authkeyid.scact_keynumber = 1;
+  authkeyid.scact_assoc_id = SCTP_ALL_ASSOC;
+  if (setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_ACTIVE_KEY, &authkeyid,
+                 sizeof(authkeyid)) != 0) {
+    PLOG(FATAL) << "Setting key id `1` as active failed";
+  }
+  current_key_.assign(auth_key.begin(), auth_key.end());
+#endif
+}  // namespace message_bridge
+
 void Message::LogRcvInfo() const {
   LOG(INFO) << "\tSNDRCV (stream=" << header.rcvinfo.rcv_sid
             << " ssn=" << header.rcvinfo.rcv_ssn
diff --git a/aos/network/sctp_lib.h b/aos/network/sctp_lib.h
index f78934a..0d021a9 100644
--- a/aos/network/sctp_lib.h
+++ b/aos/network/sctp_lib.h
@@ -11,6 +11,7 @@
 #include <string_view>
 #include <vector>
 
+#include "absl/types/span.h"
 #include "gflags/gflags.h"
 #include "glog/logging.h"
 
@@ -21,6 +22,8 @@
 namespace aos {
 namespace message_bridge {
 
+constexpr bool HasSctpAuth() { return HAS_SCTP_AUTH; }
+
 // Check if ipv6 is enabled.
 // If we don't try IPv6, and omit AI_ADDRCONFIG when resolving addresses, the
 // library will happily resolve nodes to IPv6 IPs that can't be used. If we add
@@ -92,10 +95,31 @@
 // Gets and logs the contents of the sctp_status message.
 void LogSctpStatus(int fd, sctp_assoc_t assoc_id);
 
+// Authentication method used for the SCTP socket.
+enum class SctpAuthMethod {
+  // Use unauthenticated sockets.
+  kNoAuth,
+  // Use RFC4895 authentication for SCTP.
+  kAuth,
+};
+
 // Manages reading and writing SCTP messages.
 class SctpReadWrite {
  public:
-  SctpReadWrite(std::vector<uint8_t> auth_key = {});
+  // When `requested_authentication` is kAuth, it will use SCTP authentication
+  // if it's provided by the kernel. Note that this will ignore the value of
+  // `requested_authentication` if the kernel is too old and will fall back to
+  // an unauthenticated channel.
+  SctpReadWrite(
+      SctpAuthMethod requested_authentication = SctpAuthMethod::kNoAuth)
+      : sctp_authentication_(HasSctpAuth() ? requested_authentication ==
+                                                 SctpAuthMethod::kAuth
+                                           : false) {
+    LOG_IF(WARNING,
+           requested_authentication == SctpAuthMethod::kAuth && !HasSctpAuth())
+        << "SCTP authentication requested but not provided by the kernel... "
+           "You may need a newer kernel";
+  }
   ~SctpReadWrite() { CloseSocket(); }
 
   // Opens a new socket.
@@ -142,6 +166,9 @@
   // Allocates messages for the pool.  SetMaxSize must be set first.
   void SetPoolSize(size_t pool_size);
 
+  // Set the active authentication key to `auth_key`.
+  void SetAuthKey(absl::Span<const uint8_t> auth_key);
+
  private:
   aos::unique_c_ptr<Message> AcquireMessage();
 
@@ -165,7 +192,9 @@
   bool use_pool_ = false;
   std::vector<aos::unique_c_ptr<Message>> free_messages_;
 
-  std::vector<uint8_t> auth_key_;
+  // Use SCTP authentication (RFC4895).
+  bool sctp_authentication_;
+  std::vector<uint8_t> current_key_;
 };
 
 // Returns the max network buffer available for reading for a socket.
diff --git a/aos/network/sctp_perf.cc b/aos/network/sctp_perf.cc
index 3bafed1..5201f47 100644
--- a/aos/network/sctp_perf.cc
+++ b/aos/network/sctp_perf.cc
@@ -22,11 +22,9 @@
 DEFINE_uint32(skip_first_n, 10,
               "Skip the first 'n' messages when computing statistics.");
 
-#if HAS_SCTP_AUTH
 DEFINE_string(sctp_auth_key_file, "",
               "When set, use the provided key for SCTP authentication as "
               "defined in RFC 4895");
-#endif
 
 DECLARE_bool(die_on_malloc);
 
@@ -36,13 +34,16 @@
 
 using util::ReadFileToVecOrDie;
 
+SctpAuthMethod SctpAuthMethod() {
+  return FLAGS_sctp_auth_key_file.empty() ? SctpAuthMethod::kNoAuth
+                                          : SctpAuthMethod::kAuth;
+}
+
 std::vector<uint8_t> GetSctpAuthKey() {
-#if HAS_SCTP_AUTH
-  if (!FLAGS_sctp_auth_key_file.empty()) {
-    return ReadFileToVecOrDie(FLAGS_sctp_auth_key_file);
+  if (SctpAuthMethod() == SctpAuthMethod::kNoAuth) {
+    return {};
   }
-#endif
-  return {};
+  return ReadFileToVecOrDie(FLAGS_sctp_auth_key_file);
 }
 
 }  // namespace
@@ -53,7 +54,8 @@
  public:
   Server(aos::ShmEventLoop *event_loop)
       : event_loop_(event_loop),
-        server_(2, "0.0.0.0", FLAGS_port, GetSctpAuthKey()) {
+        server_(2, "0.0.0.0", FLAGS_port, SctpAuthMethod()) {
+    server_.SetAuthKey(GetSctpAuthKey());
     event_loop_->epoll()->OnReadable(server_.fd(),
                                      [this]() { MessageReceived(); });
     server_.SetMaxReadSize(FLAGS_rx_size + 100);
@@ -134,7 +136,8 @@
   Client(aos::ShmEventLoop *event_loop)
       : event_loop_(event_loop),
         client_(FLAGS_host, FLAGS_port, 2, "0.0.0.0", FLAGS_port,
-                GetSctpAuthKey()) {
+                SctpAuthMethod()) {
+    client_.SetAuthKey(GetSctpAuthKey());
     client_.SetMaxReadSize(FLAGS_rx_size + 100);
     client_.SetMaxWriteSize(FLAGS_rx_size + 100);
 
diff --git a/aos/network/sctp_server.cc b/aos/network/sctp_server.cc
index 1bcbebd..f90b21b 100644
--- a/aos/network/sctp_server.cc
+++ b/aos/network/sctp_server.cc
@@ -22,8 +22,8 @@
 namespace message_bridge {
 
 SctpServer::SctpServer(int streams, std::string_view local_host, int local_port,
-                       std::vector<uint8_t> sctp_auth_key)
-    : sctp_(std::move(sctp_auth_key)) {
+                       SctpAuthMethod requested_authentication)
+    : sctp_(requested_authentication) {
   bool use_ipv6 = Ipv6Enabled();
   sockaddr_local_ = ResolveSocket(local_host, local_port, use_ipv6);
   while (true) {
diff --git a/aos/network/sctp_server.h b/aos/network/sctp_server.h
index a8c70aa..e641fd8 100644
--- a/aos/network/sctp_server.h
+++ b/aos/network/sctp_server.h
@@ -13,6 +13,7 @@
 #include <cstring>
 #include <memory>
 
+#include "absl/types/span.h"
 #include "glog/logging.h"
 
 #include "aos/network/sctp_lib.h"
@@ -24,7 +25,8 @@
 class SctpServer {
  public:
   SctpServer(int streams, std::string_view local_host = "0.0.0.0",
-             int local_port = 9971, std::vector<uint8_t> sctp_auth_key = {});
+             int local_port = 9971,
+             SctpAuthMethod requested_authentication = SctpAuthMethod::kNoAuth);
 
   ~SctpServer() {}
 
@@ -63,6 +65,10 @@
 
   void SetPoolSize(size_t pool_size) { sctp_.SetPoolSize(pool_size); }
 
+  void SetAuthKey(absl::Span<const uint8_t> auth_key) {
+    sctp_.SetAuthKey(auth_key);
+  }
+
  private:
   struct sockaddr_storage sockaddr_local_;
   SctpReadWrite sctp_;
diff --git a/aos/network/sctp_test.cc b/aos/network/sctp_test.cc
index edf13d4..8e332e4 100644
--- a/aos/network/sctp_test.cc
+++ b/aos/network/sctp_test.cc
@@ -93,9 +93,10 @@
  public:
   SctpTest(std::vector<uint8_t> server_key = {},
            std::vector<uint8_t> client_key = {},
+           SctpAuthMethod requested_authentication = SctpAuthMethod::kNoAuth,
            std::chrono::milliseconds timeout = 1000ms)
-      : server_(kStreams, "", kPort, std::move(server_key)),
-        client_("localhost", kPort, kStreams, "", 0, std::move(client_key)),
+      : server_(kStreams, "", kPort, requested_authentication),
+        client_("localhost", kPort, kStreams, "", 0, requested_authentication),
         client_receiver_(
             epoll_, client_,
             [this](SctpClient &client,
@@ -114,6 +115,8 @@
             [this](SctpServer &server, std::vector<uint8_t> message) {
               HandleMessage(server, std::move(message));
             }) {
+    server_.SetAuthKey(server_key);
+    client_.SetAuthKey(client_key);
     timeout_.SetTime(aos::monotonic_clock::now() + timeout,
                      std::chrono::milliseconds::zero());
     epoll_.OnReadable(timeout_.fd(), [this]() { TimeOut(); });
@@ -176,7 +179,8 @@
 // Verifies we can ping the server, and the server replies.
 class SctpPingPongTest : public SctpTest {
  public:
-  SctpPingPongTest() : SctpTest({}, {}, /*timeout=*/2s) {
+  SctpPingPongTest()
+      : SctpTest({}, {}, SctpAuthMethod::kNoAuth, /*timeout=*/2s) {
     // Start by having the client send "ping".
     client_.Send(0, "ping", 0);
   }
@@ -220,7 +224,7 @@
 class SctpAuthTest : public SctpTest {
  public:
   SctpAuthTest()
-      : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
+      : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, SctpAuthMethod::kAuth,
                  /*timeout*/ 20s) {
     // Start by having the client send "ping".
     client_.Send(0, "ping", 0);
@@ -254,11 +258,54 @@
 
 TEST_F(SctpAuthTest, Test) {}
 
+// Tests that we can dynamically change the SCTP authentication key used.
+class SctpChangingAuthKeysTest : public SctpTest {
+ public:
+  SctpChangingAuthKeysTest()
+      : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
+                 SctpAuthMethod::kAuth) {
+    // Start by having the client send "ping".
+    client_.SetAuthKey({5, 4, 3, 2, 1});
+    server_.SetAuthKey({5, 4, 3, 2, 1});
+    client_.Send(0, "ping", 0);
+  }
+
+  void HandleMessage(SctpServer &server,
+                     std::vector<uint8_t> message) override {
+    // Server should receive a ping message.
+    EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
+    got_ping_ = true;
+    ASSERT_NE(assoc_, 0);
+    // Reply with "pong".
+    server.Send("pong", assoc_, 0, 0);
+  }
+  void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
+    // Client should receive a "pong" message.
+    EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
+    got_pong_ = true;
+    // We are done.
+    Quit();
+  }
+
+  ~SctpChangingAuthKeysTest() {
+    EXPECT_TRUE(got_ping_);
+    EXPECT_TRUE(got_pong_);
+  }
+
+ protected:
+  bool got_ping_ = false;
+  bool got_pong_ = false;
+};
+
+TEST_F(SctpChangingAuthKeysTest, Test) {}
+
 // Keys don't match, we should send the `ping` message but the server should
 // never receive it. We then time out as nothing calls Quit.
 class SctpMismatchedAuthTest : public SctpTest {
  public:
-  SctpMismatchedAuthTest() : SctpTest({1, 2, 3, 4, 5, 6}, {5, 6, 7, 8, 9, 10}) {
+  SctpMismatchedAuthTest()
+      : SctpTest({1, 2, 3, 4, 5, 6}, {5, 6, 7, 8, 9, 10},
+                 SctpAuthMethod::kAuth) {
     // Start by having the client send "ping".
     client_.Send(0, "ping", 0);
   }
@@ -278,7 +325,8 @@
 // see the same behaviour.
 class SctpOneNullKeyTest : public SctpTest {
  public:
-  SctpOneNullKeyTest() : SctpTest({1, 2, 3, 4, 5, 6}, {}) {
+  SctpOneNullKeyTest()
+      : SctpTest({1, 2, 3, 4, 5, 6}, {}, SctpAuthMethod::kAuth) {
     // Start by having the client send "ping".
     client_.Send(0, "ping", 0);
   }
@@ -293,6 +341,27 @@
 };
 
 TEST_F(SctpOneNullKeyTest, Test) {}
+
+// If we want SCTP authentication but we don't set the auth keys, we shouldn't
+// be able to send data.
+class SctpAuthKeysNotSet : public SctpTest {
+ public:
+  SctpAuthKeysNotSet() : SctpTest({}, {}, SctpAuthMethod::kAuth) {
+    // Start by having the client send "ping".
+    client_.Send(0, "ping", 0);
+  }
+
+  void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
+    FAIL() << "Haven't setup authentication keys. Should not get message.";
+    Quit();
+  }
+
+  // We expect to time out since we never get the message.
+  void TimeOut() override { Quit(); }
+};
+
+TEST_F(SctpAuthKeysNotSet, Test) {}
+
 #endif  // HAS_SCTP_AUTH
 
 }  // namespace aos::message_bridge::testing