Add SCTP authentication to sctp_lib

This only works on Linux >= 5.4. When enabled, a shared key is used to
authenticate messages (RFC 4895). The functionality is controlled by a
flag and guarded by a Linux kernel version check.

Performance degradation is minimal, even for smaller messages, and is
unnoticeable when measuring overall system performance.

Change-Id: I836e61ec38a0c116fd7244b771437738ccca9828
Signed-off-by: James Kuszmaul <jabukuszmaul+collab@gmail.com>
diff --git a/aos/network/BUILD b/aos/network/BUILD
index d98b43b..6f98b6b 100644
--- a/aos/network/BUILD
+++ b/aos/network/BUILD
@@ -160,6 +160,25 @@
     ],
 )
 
+cc_test(
+    name = "sctp_test",
+    srcs = [
+        "sctp_test.cc",
+    ],
+    tags = [
+        # Fakeroot is required to enable "net.sctp.auth_enable".
+        "requires-fakeroot",
+    ],
+    target_compatible_with = ["@platforms//cpu:x86_64"],
+    deps = [
+        ":sctp_client",
+        ":sctp_lib",
+        ":sctp_server",
+        "//aos/events:epoll",
+        "//aos/testing:googletest",
+    ],
+)
+
 cc_library(
     name = "sctp_server",
     srcs = [
@@ -269,6 +288,7 @@
     target_compatible_with = ["@platforms//os:linux"],
     deps = [
         ":message_bridge_server_lib",
+        ":sctp_lib",  # NOTE(review): add "//aos/util:file" here too?
         "//aos:init",
         "//aos:json_to_flatbuffer",
         "//aos:sha256",
@@ -351,11 +371,13 @@
     target_compatible_with = ["@platforms//os:linux"],
     deps = [
         ":message_bridge_client_lib",
+        ":sctp_lib",  # For HAS_SCTP_AUTH.
         "//aos:init",
         "//aos:json_to_flatbuffer",
         "//aos:sha256",
         "//aos/events:shm_event_loop",
         "//aos/logging:dynamic_logging",
+        "//aos/util:file",  # For ReadFileToVecOrDie.
     ],
 )
 
diff --git a/aos/network/message_bridge_client.cc b/aos/network/message_bridge_client.cc
index c3f55ba..afe2a90 100644
--- a/aos/network/message_bridge_client.cc
+++ b/aos/network/message_bridge_client.cc
@@ -2,14 +2,24 @@
 #include "aos/init.h"
 #include "aos/logging/dynamic_logging.h"
 #include "aos/network/message_bridge_client_lib.h"
+#include "aos/network/sctp_lib.h"
 #include "aos/sha256.h"
+#include "aos/util/file.h"
 
 DEFINE_string(config, "aos_config.json", "Path to the config.");
 DEFINE_int32(rt_priority, -1, "If > 0, run as this RT priority");
 
+#if HAS_SCTP_AUTH
+DEFINE_string(sctp_auth_key_file, "",
+              "When set, use the provided key for SCTP authentication as "
+              "defined in RFC 4895. The file should be binary-encoded");
+#endif  // HAS_SCTP_AUTH
+
 namespace aos {
 namespace message_bridge {
 
+using ::aos::util::ReadFileToVecOrDie;
+
 int Main() {
   aos::FlatbufferDetachedBuffer<aos::Configuration> config =
       aos::configuration::ReadConfig(FLAGS_config);
@@ -19,7 +29,14 @@
     event_loop.SetRuntimeRealtimePriority(FLAGS_rt_priority);
   }
 
-  MessageBridgeClient app(&event_loop, Sha256(config.span()));
+  std::vector<uint8_t> auth_key;  // Left empty => SCTP auth is disabled.
+#if HAS_SCTP_AUTH
+  if (!FLAGS_sctp_auth_key_file.empty()) {
+    auth_key = ReadFileToVecOrDie(FLAGS_sctp_auth_key_file);  // Raw bytes.
+  }
+#endif  // HAS_SCTP_AUTH
+  MessageBridgeClient app(&event_loop, Sha256(config.span()),
+                          std::move(auth_key));
 
   logging::DynamicLogging dynamic_logging(&event_loop);
   // TODO(austin): Save messages into a vector to be logged.  One file per
diff --git a/aos/network/message_bridge_client_lib.cc b/aos/network/message_bridge_client_lib.cc
index c23f748..0f491fc 100644
--- a/aos/network/message_bridge_client_lib.cc
+++ b/aos/network/message_bridge_client_lib.cc
@@ -100,7 +100,8 @@
     aos::ShmEventLoop *const event_loop, std::string_view remote_name,
     const Node *my_node, std::string_view local_host,
     std::vector<SctpClientChannelState> *channels, int client_index,
-    MessageBridgeClientStatus *client_status, std::string_view config_sha256)
+    MessageBridgeClientStatus *client_status, std::string_view config_sha256,
+    std::vector<uint8_t> auth_key)
     : event_loop_(event_loop),
       connect_message_(MakeConnectMessage(event_loop->configuration(), my_node,
                                           remote_name, event_loop->boot_uuid(),
@@ -111,7 +112,7 @@
       client_(remote_node_->hostname()->string_view(), remote_node_->port(),
               connect_message_.message().channels_to_transfer()->size() +
                   kControlStreams(),
-              local_host, 0),
+              local_host, 0, std::move(auth_key)),  // Ownership -> SctpClient.
       channels_(channels),
       stream_to_channel_(
           StreamToChannel(event_loop->configuration(), my_node, remote_node_)),
@@ -364,7 +365,8 @@
 }
 
 MessageBridgeClient::MessageBridgeClient(aos::ShmEventLoop *event_loop,
-                                         std::string config_sha256)
+                                         std::string config_sha256,
+                                         std::vector<uint8_t> auth_key)
     : event_loop_(event_loop),
       client_status_(event_loop_),
       config_sha256_(std::move(config_sha256)) {
@@ -415,7 +417,7 @@
     connections_.emplace_back(new SctpClientConnection(
         event_loop, source_node, event_loop->node(), "", &channels_,
         client_status_.FindClientIndex(source_node), &client_status_,
-        config_sha256_));
+        config_sha256_, auth_key));  // Copied, not moved.
   }
 }
 
diff --git a/aos/network/message_bridge_client_lib.h b/aos/network/message_bridge_client_lib.h
index a5a69e3..bb3bc1a 100644
--- a/aos/network/message_bridge_client_lib.h
+++ b/aos/network/message_bridge_client_lib.h
@@ -38,7 +38,8 @@
                        std::vector<SctpClientChannelState> *channels,
                        int client_index,
                        MessageBridgeClientStatus *client_status,
-                       std::string_view config_sha256);
+                       std::string_view config_sha256,
+                       std::vector<uint8_t> auth_key);  // Empty: no auth.
 
   ~SctpClientConnection() { event_loop_->epoll()->DeleteFd(client_.fd()); }
 
@@ -102,7 +103,10 @@
 // node.
 class MessageBridgeClient {
  public:
-  MessageBridgeClient(aos::ShmEventLoop *event_loop, std::string config_sha256);
+  // When `auth_key` is non-empty, it is used as the shared key to authenticate
+  // every connection's SCTP DATA chunks (see RFC 4895).
+  MessageBridgeClient(aos::ShmEventLoop *event_loop, std::string config_sha256,
+                      std::vector<uint8_t> auth_key);
 
   ~MessageBridgeClient() {}
 
diff --git a/aos/network/message_bridge_server.cc b/aos/network/message_bridge_server.cc
index 04b07c3..4daf9c7 100644
--- a/aos/network/message_bridge_server.cc
+++ b/aos/network/message_bridge_server.cc
@@ -5,14 +5,23 @@
 #include "aos/init.h"
 #include "aos/logging/dynamic_logging.h"
 #include "aos/network/message_bridge_server_lib.h"
+#include "aos/network/sctp_lib.h"  // For HAS_SCTP_AUTH.
 #include "aos/sha256.h"
 
 DEFINE_string(config, "aos_config.json", "Path to the config.");
 DEFINE_int32(rt_priority, -1, "If > 0, run as this RT priority");
 
+#if HAS_SCTP_AUTH
+DEFINE_string(sctp_auth_key_file, "",
+              "When set, use the provided key for SCTP authentication as "
+              "defined in RFC 4895");
+#endif  // HAS_SCTP_AUTH
+
 namespace aos {
 namespace message_bridge {
 
+using ::aos::util::ReadFileToVecOrDie;  // NOTE(review): needs aos/util/file.h.
+
 int Main() {
   aos::FlatbufferDetachedBuffer<aos::Configuration> config =
       aos::configuration::ReadConfig(FLAGS_config);
@@ -22,7 +31,14 @@
     event_loop.SetRuntimeRealtimePriority(FLAGS_rt_priority);
   }
 
-  MessageBridgeServer app(&event_loop, Sha256(config.span()));
+  std::vector<uint8_t> auth_key;  // Left empty => SCTP auth is disabled.
+#if HAS_SCTP_AUTH
+  if (!FLAGS_sctp_auth_key_file.empty()) {
+    auth_key = ReadFileToVecOrDie(FLAGS_sctp_auth_key_file);
+  }
+#endif  // HAS_SCTP_AUTH
+  MessageBridgeServer app(&event_loop, Sha256(config.span()),
+                          std::move(auth_key));
 
   logging::DynamicLogging dynamic_logging(&event_loop);
 
diff --git a/aos/network/message_bridge_server_lib.cc b/aos/network/message_bridge_server_lib.cc
index debfeeb..b925c6b 100644
--- a/aos/network/message_bridge_server_lib.cc
+++ b/aos/network/message_bridge_server_lib.cc
@@ -402,11 +402,12 @@
 }
 
 MessageBridgeServer::MessageBridgeServer(aos::ShmEventLoop *event_loop,
-                                         std::string config_sha256)
+                                         std::string config_sha256,
+                                         std::vector<uint8_t> auth_key)
     : event_loop_(event_loop),
       timestamp_loggers_(event_loop_),
       server_(max_channels() + kControlStreams(), "",
-              event_loop->node()->port()),
+              event_loop->node()->port(), auth_key),  // Copied into server_.
       server_status_(event_loop, [this]() { timestamp_state_->SendData(); }),
       config_sha256_(std::move(config_sha256)),
       allocator_(0) {
diff --git a/aos/network/message_bridge_server_lib.h b/aos/network/message_bridge_server_lib.h
index 6a3fdd8..ed81de6 100644
--- a/aos/network/message_bridge_server_lib.h
+++ b/aos/network/message_bridge_server_lib.h
@@ -167,7 +167,10 @@
 // node.  It handles the session and dispatches data to the ChannelState.
 class MessageBridgeServer {
  public:
-  MessageBridgeServer(aos::ShmEventLoop *event_loop, std::string config_sha256);
+  // When `auth_key` is non-empty, it is used as the shared key to authenticate
+  // SCTP DATA chunks on every association (see RFC 4895).
+  MessageBridgeServer(aos::ShmEventLoop *event_loop, std::string config_sha256,
+                      std::vector<uint8_t> auth_key);
 
   // Delete copy/move constructors explicitly--we internally pass around
   // pointers to internal state.
diff --git a/aos/network/message_bridge_test_lib.cc b/aos/network/message_bridge_test_lib.cc
index e818075..226c13a 100644
--- a/aos/network/message_bridge_test_lib.cc
+++ b/aos/network/message_bridge_test_lib.cc
@@ -81,7 +81,8 @@
   pi1_server_event_loop->SetRuntimeRealtimePriority(1);
   pi1_message_bridge_server = std::make_unique<MessageBridgeServer>(
       pi1_server_event_loop.get(),
-      server_config_sha256.size() == 0 ? config_sha256 : server_config_sha256);
+      server_config_sha256.size() == 0 ? config_sha256 : server_config_sha256,
+      /*auth_key=*/std::vector<uint8_t>());
 }
 
 void MessageBridgeParameterizedTest::RunPi1Server(
@@ -115,7 +116,8 @@
       std::make_unique<aos::ShmEventLoop>(&config.message());
   pi1_client_event_loop->SetRuntimeRealtimePriority(1);
   pi1_message_bridge_client = std::make_unique<MessageBridgeClient>(
-      pi1_client_event_loop.get(), config_sha256);
+      pi1_client_event_loop.get(), config_sha256,
+      /*auth_key=*/std::vector<uint8_t>());
 }
 
 void MessageBridgeParameterizedTest::StartPi1Client() {
@@ -168,7 +170,8 @@
       std::make_unique<aos::ShmEventLoop>(&config.message());
   pi2_server_event_loop->SetRuntimeRealtimePriority(1);
   pi2_message_bridge_server = std::make_unique<MessageBridgeServer>(
-      pi2_server_event_loop.get(), config_sha256);
+      pi2_server_event_loop.get(), config_sha256,
+      /*auth_key=*/std::vector<uint8_t>());
 }
 
 void MessageBridgeParameterizedTest::RunPi2Server(
@@ -202,7 +205,8 @@
       std::make_unique<aos::ShmEventLoop>(&config.message());
   pi2_client_event_loop->SetRuntimeRealtimePriority(1);
   pi2_message_bridge_client = std::make_unique<MessageBridgeClient>(
-      pi2_client_event_loop.get(), config_sha256);
+      pi2_client_event_loop.get(), config_sha256,
+      /*auth_key=*/std::vector<uint8_t>());
 }
 
 void MessageBridgeParameterizedTest::RunPi2Client(
diff --git a/aos/network/sctp_client.cc b/aos/network/sctp_client.cc
index a87dfc6..e044396 100644
--- a/aos/network/sctp_client.cc
+++ b/aos/network/sctp_client.cc
@@ -22,8 +22,9 @@
 namespace message_bridge {
 
 SctpClient::SctpClient(std::string_view remote_host, int remote_port,
-                       int streams, std::string_view local_host,
-                       int local_port) {
+                       int streams, std::string_view local_host, int local_port,
+                       std::vector<uint8_t> sctp_auth_key)
+    : sctp_(std::move(sctp_auth_key)) {  // Empty key: SCTP auth disabled.
   bool use_ipv6 = Ipv6Enabled();
   sockaddr_local_ = ResolveSocket(local_host, local_port, use_ipv6);
   sockaddr_remote_ = ResolveSocket(remote_host, remote_port, use_ipv6);
diff --git a/aos/network/sctp_client.h b/aos/network/sctp_client.h
index 5affecc..c6b6324 100644
--- a/aos/network/sctp_client.h
+++ b/aos/network/sctp_client.h
@@ -17,7 +17,9 @@
 class SctpClient {
  public:
+  // A non-empty `sctp_auth_key` enables SCTP authentication (RFC 4895).
   SctpClient(std::string_view remote_host, int remote_port, int streams,
-             std::string_view local_host = "0.0.0.0", int local_port = 9971);
+             std::string_view local_host = "0.0.0.0", int local_port = 9971,
+             std::vector<uint8_t> sctp_auth_key = {});
 
   ~SctpClient() {}
 
diff --git a/aos/network/sctp_lib.cc b/aos/network/sctp_lib.cc
index 6030ea7..77cad02 100644
--- a/aos/network/sctp_lib.cc
+++ b/aos/network/sctp_lib.cc
@@ -31,6 +31,25 @@
   struct sctp_sndrcvinfo sndrcvinfo;
 } _sctp_cmsg_data_t;
 
+// Returns true if SCTP authentication is available and enabled.
+bool SctpAuthIsEnabled() {
+#if HAS_SCTP_AUTH
+  struct stat current_stat;
+  if (stat("/proc/sys/net/sctp/auth_enable", &current_stat) != -1) {
+    int value = std::stoi(
+        util::ReadFileToStringOrDie("/proc/sys/net/sctp/auth_enable"));
+    CHECK(value == 0 || value == 1)
+        << "Unknown auth enable sysctl value: " << value;
+    return value == 1;
+  } else {
+    LOG(WARNING) << "/proc/sys/net/sctp/auth_enable doesn't exist.";
+    return false;
+  }
+#else
+  return false;
+#endif  // HAS_SCTP_AUTH
+}
+
 }  // namespace
 
 bool Ipv6Enabled() {
@@ -270,6 +289,44 @@
                       sizeof(subscribe)) == 0);
   }
 
+  if (!auth_key_.empty()) {
+    CHECK(SctpAuthIsEnabled())
+        << "SCTP Authentication is disabled. Enable it with 'sysctl -w "
+           "net.sctp.auth_enable=1' and try again.";
+#if HAS_SCTP_AUTH
+    // Set up key `1`; the vector owns the variable-length sctp_authkey.
+    std::vector<uint8_t> authkey_buffer(sizeof(sctp_authkey) +
+                                        auth_key_.size());
+    sctp_authkey *const authkey =
+        reinterpret_cast<sctp_authkey *>(authkey_buffer.data());
+    authkey->sca_keynumber = 1;
+    authkey->sca_keylength = auth_key_.size();
+    authkey->sca_assoc_id = SCTP_ALL_ASSOC;
+    memcpy(authkey->sca_key, auth_key_.data(), auth_key_.size());
+    PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_KEY, authkey,
+                      authkey_buffer.size()) == 0);
+
+    // Set key `1` as active.
+    struct sctp_authkeyid authkeyid;
+    authkeyid.scact_keynumber = 1;
+    authkeyid.scact_assoc_id = SCTP_ALL_ASSOC;
+    PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_ACTIVE_KEY, &authkeyid,
+                      sizeof(authkeyid)) == 0);
+
+    // Set up authentication for data chunks.
+    struct sctp_authchunk authchunk;
+    authchunk.sauth_chunk = 0;  // Chunk type 0 is DATA (RFC 4960).
+
+    PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_CHUNK, &authchunk,
+                      sizeof(authchunk)) == 0);
+
+    // Disallow the null key.
+    authkeyid.scact_keynumber = 0;
+    PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_AUTH_DELETE_KEY, &authkeyid,
+                      sizeof(authkeyid)) == 0);
+#endif  // HAS_SCTP_AUTH
+  }
+
   DoSetMaxSize();
 }
 
@@ -335,6 +392,9 @@
   return true;
 }
 
+SctpReadWrite::SctpReadWrite(std::vector<uint8_t> auth_key)
+    : auth_key_(std::move(auth_key)) {}
+
 void SctpReadWrite::FreeMessage(aos::unique_c_ptr<Message> &&message) {
   if (use_pool_) {
     free_messages_.emplace_back(std::move(message));
diff --git a/aos/network/sctp_lib.h b/aos/network/sctp_lib.h
index 8eb57d3..f78934a 100644
--- a/aos/network/sctp_lib.h
+++ b/aos/network/sctp_lib.h
@@ -3,6 +3,7 @@
 
 #include <arpa/inet.h>
 #include <linux/sctp.h>
+#include <linux/version.h>
 
 #include <memory>
 #include <optional>
@@ -15,6 +16,9 @@
 
 #include "aos/unique_malloc_ptr.h"
 
+// True when building against Linux >= 5.4 headers (needed for SCTP auth).
+#define HAS_SCTP_AUTH (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 4, 0))
+
 namespace aos {
 namespace message_bridge {
 
@@ -92,7 +96,8 @@
 // Manages reading and writing SCTP messages.
 class SctpReadWrite {
  public:
-  SctpReadWrite() = default;
+  // A non-empty `auth_key` enables SCTP authentication (RFC 4895).
+  explicit SctpReadWrite(std::vector<uint8_t> auth_key = {});
   ~SctpReadWrite() { CloseSocket(); }
 
   // Opens a new socket.
@@ -161,6 +166,9 @@
 
   bool use_pool_ = false;
   std::vector<aos::unique_c_ptr<Message>> free_messages_;
+
+  // Shared key for SCTP authentication; empty disables authentication.
+  std::vector<uint8_t> auth_key_;
 };
 
 // Returns the max network buffer available for reading for a socket.
diff --git a/aos/network/sctp_perf.cc b/aos/network/sctp_perf.cc
index cce4bed..3bafed1 100644
--- a/aos/network/sctp_perf.cc
+++ b/aos/network/sctp_perf.cc
@@ -6,6 +6,7 @@
 #include "aos/events/shm_event_loop.h"
 #include "aos/init.h"
 #include "aos/network/sctp_client.h"
+#include "aos/network/sctp_lib.h"
 #include "aos/network/sctp_server.h"
 
 DEFINE_string(config, "aos_config.json", "Path to the config.");
@@ -21,16 +22,38 @@
 DEFINE_uint32(skip_first_n, 10,
               "Skip the first 'n' messages when computing statistics.");
 
+#if HAS_SCTP_AUTH
+DEFINE_string(sctp_auth_key_file, "",
+              "When set, use the provided key for SCTP authentication as "
+              "defined in RFC 4895");
+#endif  // HAS_SCTP_AUTH
+
 DECLARE_bool(die_on_malloc);
 
 namespace aos::message_bridge::perf {
 
+namespace {
+
+using util::ReadFileToVecOrDie;  // NOTE(review): aos/util/file.h not included.
+
+std::vector<uint8_t> GetSctpAuthKey() {  // Empty when no key file is given.
+#if HAS_SCTP_AUTH
+  if (!FLAGS_sctp_auth_key_file.empty()) {
+    return ReadFileToVecOrDie(FLAGS_sctp_auth_key_file);
+  }
+#endif  // HAS_SCTP_AUTH
+  return {};
+}
+
+}  // namespace
+
 namespace chrono = std::chrono;
 
 class Server {
  public:
   Server(aos::ShmEventLoop *event_loop)
-      : event_loop_(event_loop), server_(2, "0.0.0.0", FLAGS_port) {
+      : event_loop_(event_loop),
+        server_(2, "0.0.0.0", FLAGS_port, GetSctpAuthKey()) {
     event_loop_->epoll()->OnReadable(server_.fd(),
                                      [this]() { MessageReceived(); });
     server_.SetMaxReadSize(FLAGS_rx_size + 100);
@@ -109,7 +132,9 @@
 class Client {
  public:
   Client(aos::ShmEventLoop *event_loop)
-      : event_loop_(event_loop), client_(FLAGS_host, FLAGS_port, 2) {
+      : event_loop_(event_loop),
+        client_(FLAGS_host, FLAGS_port, 2, "0.0.0.0", /*local_port=*/9971,
+                GetSctpAuthKey()) {  // 9971 == SctpClient's default.
     client_.SetMaxReadSize(FLAGS_rx_size + 100);
     client_.SetMaxWriteSize(FLAGS_rx_size + 100);
 
@@ -196,8 +221,8 @@
     double throughput = FLAGS_payload_size * 2.0 / elapsed_secs;
     double avg_throughput = FLAGS_payload_size * 2.0 / avg_latency_;
     printf(
-        "Round trip: %.2fms | %.2f KB/s | Avg RTL: %.2fms | %.2f KB/s | Count: "
-        "%d\n",
+        "Round trip: %.2fms | %.2f KB/s | Avg RTL: %.2fms | %.2f KB/s | "
+        "Count: %d\n",
         elapsed_secs * 1000, throughput / 1024, avg_latency_ * 1000,
         avg_throughput / 1024, count_);
   }
diff --git a/aos/network/sctp_server.cc b/aos/network/sctp_server.cc
index fb736de..1bcbebd 100644
--- a/aos/network/sctp_server.cc
+++ b/aos/network/sctp_server.cc
@@ -21,8 +21,9 @@
 namespace aos {
 namespace message_bridge {
 
-SctpServer::SctpServer(int streams, std::string_view local_host,
-                       int local_port) {
+SctpServer::SctpServer(int streams, std::string_view local_host, int local_port,
+                       std::vector<uint8_t> sctp_auth_key)
+    : sctp_(std::move(sctp_auth_key)) {  // Empty key: SCTP auth disabled.
   bool use_ipv6 = Ipv6Enabled();
   sockaddr_local_ = ResolveSocket(local_host, local_port, use_ipv6);
   while (true) {
diff --git a/aos/network/sctp_server.h b/aos/network/sctp_server.h
index 2831a68..a8c70aa 100644
--- a/aos/network/sctp_server.h
+++ b/aos/network/sctp_server.h
@@ -24,7 +24,8 @@
 class SctpServer {
  public:
+  // A non-empty `sctp_auth_key` enables SCTP authentication (RFC 4895).
   SctpServer(int streams, std::string_view local_host = "0.0.0.0",
-             int local_port = 9971);
+             int local_port = 9971, std::vector<uint8_t> sctp_auth_key = {});
 
   ~SctpServer() {}
 
diff --git a/aos/network/sctp_test.cc b/aos/network/sctp_test.cc
new file mode 100644
index 0000000..edf13d4
--- /dev/null
+++ b/aos/network/sctp_test.cc
@@ -0,0 +1,298 @@
+#include <unistd.h>
+
+#include <chrono>
+#include <functional>
+
+#include "gflags/gflags.h"
+#include "gmock/gmock-matchers.h"
+#include "gtest/gtest.h"
+
+#include "aos/events/epoll.h"
+#include "aos/network/sctp_client.h"
+#include "aos/network/sctp_lib.h"
+#include "aos/network/sctp_server.h"
+
+DECLARE_bool(disable_ipv6);
+
+namespace aos::message_bridge::testing {
+
+using ::aos::internal::EPoll;
+using ::aos::internal::TimerFd;
+using ::testing::ElementsAre;
+
+using namespace ::std::chrono_literals;
+
+constexpr int kPort = 19423;  // Port the test SctpServer listens on.
+constexpr int kStreams = 1;  // Streams on both the server and the client.
+
+namespace {
+void EnableSctpAuthIfAvailable() {  // Runs sysctl; needs (fake)root.
+#if HAS_SCTP_AUTH
+  CHECK(system("/usr/sbin/sysctl net.sctp.auth_enable=1 || /sbin/sysctl "
+               "net.sctp.auth_enable=1") == 0)
+      << "Couldn't enable sctp authentication.";
+#endif  // HAS_SCTP_AUTH
+}
+}  // namespace
+
+// Asynchronous SCTP dispatcher. Given an SCTP endpoint (SctpServer or
+// SctpClient), a notification handler, and a data-message handler, it reads
+// from the endpoint whenever epoll reports it readable and routes the result.
+template <typename T>
+class SctpReceiver {
+ public:
+  SctpReceiver(
+      EPoll &epoll, T &receiver,
+      std::function<void(T &, const union sctp_notification *)> on_notify,
+      std::function<void(T &, std::vector<uint8_t>)> on_message)
+      : epoll_(epoll),
+        receiver_(receiver),
+        on_notify_(std::move(on_notify)),
+        on_message_(std::move(on_message)) {
+    epoll_.OnReadable(receiver_.fd(), [this]() { Read(); });
+  }
+
+  ~SctpReceiver() { epoll_.DeleteFd(receiver_.fd()); }
+
+ private:
+  // Reads one message and routes it to the appropriate handler.
+  void Read() {
+    aos::unique_c_ptr<Message> message = receiver_.Read();
+    if (!message) {
+      return;
+    }
+
+    switch (message->message_type) {
+      case Message::kNotification: {
+        const union sctp_notification *notification =
+            reinterpret_cast<const union sctp_notification *>(message->data());
+        on_notify_(receiver_, notification);
+        break;
+      }
+      case Message::kMessage:
+        on_message_(receiver_, std::vector(message->data(),
+                                           message->data() + message->size));
+        break;
+      case Message::kOverflow:
+        LOG(FATAL) << "Overflow";  // Does not return.
+    }
+    receiver_.FreeMessage(std::move(message));  // Pooled if pooling is on.
+  }
+
+  EPoll &epoll_;
+  T &receiver_;
+  std::function<void(T &, const union sctp_notification *)> on_notify_;
+  std::function<void(T &, std::vector<uint8_t>)> on_message_;
+};
+
+// Base SctpTest class.
+//
+// Subclasses override the virtual handlers below to define the behavior of the
+// test; SetUp() runs the event loop until Quit() is called.
+class SctpTest : public ::testing::Test {
+ public:
+  SctpTest(std::vector<uint8_t> server_key = {},
+           std::vector<uint8_t> client_key = {},
+           std::chrono::milliseconds timeout = 1000ms)
+      : server_(kStreams, "", kPort, std::move(server_key)),
+        client_("localhost", kPort, kStreams, "", 0, std::move(client_key)),
+        client_receiver_(
+            epoll_, client_,
+            [this](SctpClient &client,
+                   const union sctp_notification *notification) {
+              HandleNotification(client, notification);
+            },
+            [this](SctpClient &client, std::vector<uint8_t> message) {
+              HandleMessage(client, std::move(message));
+            }),
+        server_receiver_(
+            epoll_, server_,
+            [this](SctpServer &server,
+                   const union sctp_notification *notification) {
+              HandleNotification(server, notification);
+            },
+            [this](SctpServer &server, std::vector<uint8_t> message) {
+              HandleMessage(server, std::move(message));
+            }) {
+    timeout_.SetTime(aos::monotonic_clock::now() + timeout,
+                     std::chrono::milliseconds::zero());
+    epoll_.OnReadable(timeout_.fd(), [this]() { TimeOut(); });
+  }
+
+  static void SetUpTestSuite() {
+    EnableSctpAuthIfAvailable();
+    // Buildkite seems to have issues with ipv6 sctp sockets...
+    FLAGS_disable_ipv6 = true;
+  }
+
+  void SetUp() override { Run(); }  // Blocks until Quit() is called.
+
+ protected:
+  // Handles a server notification message.
+  //
+  // The default behaviour is to track the sctp association ID.
+  virtual void HandleNotification(SctpServer &,
+                                  const union sctp_notification *notification) {
+    if (notification->sn_header.sn_type == SCTP_ASSOC_CHANGE) {
+      assoc_ = notification->sn_assoc_change.sac_assoc_id;
+    }
+  }
+
+  // Handles the client notification message.
+  virtual void HandleNotification(SctpClient &,
+                                  const union sctp_notification *) {}
+
+  // Handles a server "data" message.
+  virtual void HandleMessage(SctpServer &, std::vector<uint8_t>) {}
+  // Handles a client "data" message.
+  virtual void HandleMessage(SctpClient &, std::vector<uint8_t>) {}
+
+  // Defines the "timeout" behaviour (fail by default).
+  virtual void TimeOut() {
+    Quit();
+    FAIL() << "Timer expired";
+  }
+
+  virtual ~SctpTest() {}
+
+  // Cancels the timeout and stops the event loop, ending Run().
+  void Quit() {
+    epoll_.DeleteFd(timeout_.fd());
+    epoll_.Quit();
+  }
+  void Run() { epoll_.Run(); }
+
+  SctpServer server_;
+  SctpClient client_;
+  sctp_assoc_t assoc_ = 0;
+
+ private:
+  TimerFd timeout_;
+  EPoll epoll_;  // Declared before the receivers, which register with it.
+  SctpReceiver<SctpClient> client_receiver_;
+  SctpReceiver<SctpServer> server_receiver_;
+};
+
+// Without auth keys: verifies the client can ping and the server replies.
+class SctpPingPongTest : public SctpTest {
+ public:
+  SctpPingPongTest() : SctpTest({}, {}, /*timeout=*/2s) {
+    // Start by having the client send "ping".
+    client_.Send(0, "ping", 0);
+  }
+
+  void HandleMessage(SctpServer &server,
+                     std::vector<uint8_t> message) override {
+    // Server should receive a ping message.
+    EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
+    got_ping_ = true;
+    ASSERT_NE(assoc_, 0);  // Set on SCTP_ASSOC_CHANGE by the base class.
+    // Reply with "pong".
+    server.Send("pong", assoc_, 0, 0);
+  }
+
+  void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
+    // Client should receive a "pong" message.
+    EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
+    got_pong_ = true;
+    // We are done.
+    Quit();
+  }
+  ~SctpPingPongTest() {
+    // Check that we got the ping/pong messages.
+    // This isn't strictly necessary as otherwise we would time out and fail
+    // anyway.
+    EXPECT_TRUE(got_ping_);
+    EXPECT_TRUE(got_pong_);
+  }
+
+ protected:
+  bool got_ping_ = false;
+  bool got_pong_ = false;
+};
+
+TEST_F(SctpPingPongTest, Test) {}
+
+#if HAS_SCTP_AUTH
+
+// Same as SctpPingPongTest, but both endpoints use the same authentication
+// key, so the exchange should succeed just as it does without one.
+class SctpAuthTest : public SctpTest {
+ public:
+  SctpAuthTest()
+      : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
+                 /*timeout=*/20s) {
+    // Start by having the client send "ping".
+    client_.Send(0, "ping", 0);
+  }
+
+  void HandleMessage(SctpServer &server,
+                     std::vector<uint8_t> message) override {
+    // Server should receive a ping message.
+    EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
+    got_ping_ = true;
+    ASSERT_NE(assoc_, 0);  // Set on SCTP_ASSOC_CHANGE by the base class.
+    // Reply with "pong".
+    server.Send("pong", assoc_, 0, 0);
+  }
+  void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
+    // Client should receive a "pong" message.
+    EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
+    got_pong_ = true;
+    // We are done.
+    Quit();
+  }
+  ~SctpAuthTest() {  // Both halves of the exchange must have happened.
+    EXPECT_TRUE(got_ping_);
+    EXPECT_TRUE(got_pong_);
+  }
+
+ protected:
+  bool got_ping_ = false;
+  bool got_pong_ = false;
+};
+
+TEST_F(SctpAuthTest, Test) {}
+
+// The keys don't match: the client sends `ping`, but the server must never
+// receive it. The test passes when the timeout fires with no delivery.
+class SctpMismatchedAuthTest : public SctpTest {
+ public:
+  SctpMismatchedAuthTest() : SctpTest({1, 2, 3, 4, 5, 6}, {5, 6, 7, 8, 9, 10}) {
+    // Start by having the client send "ping".
+    client_.Send(0, "ping", 0);
+  }
+
+  void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
+    Quit();  // FAIL() returns from this function, so stop the loop first.
+    FAIL() << "Authentication keys don't match. Message should be discarded";
+  }
+
+  // We expect to time out since we never get the message.
+  void TimeOut() override { Quit(); }
+};
+
+TEST_F(SctpMismatchedAuthTest, Test) {}
+
+// Same as SctpMismatchedAuthTest, but the client configures no key (leaving
+// only the default null key). The server must still discard the `ping`.
+class SctpOneNullKeyTest : public SctpTest {
+ public:
+  SctpOneNullKeyTest() : SctpTest({1, 2, 3, 4, 5, 6}, {}) {
+    // Start by having the client send "ping".
+    client_.Send(0, "ping", 0);
+  }
+
+  void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
+    Quit();  // FAIL() returns from this function, so stop the loop first.
+    FAIL() << "Authentication keys don't match. Message should be discarded";
+  }
+
+  // We expect to time out since we never get the message.
+  void TimeOut() override { Quit(); }
+};
+
+TEST_F(SctpOneNullKeyTest, Test) {}
+#endif  // HAS_SCTP_AUTH
+
+}  // namespace aos::message_bridge::testing
diff --git a/aos/util/file.cc b/aos/util/file.cc
index 591539c..52657a9 100644
--- a/aos/util/file.cc
+++ b/aos/util/file.cc
@@ -47,6 +47,22 @@
   return r;
 }
 
+std::vector<uint8_t> ReadFileToVecOrDie(const std::string_view filename) {
+  std::vector<uint8_t> r;
+  ScopedFD fd(open(::std::string(filename).c_str(), O_RDONLY));
+  PCHECK(fd.get() != -1) << ": opening " << filename;
+  uint8_t buffer[1024];
+  while (true) {
+    const ssize_t result = read(fd.get(), buffer, sizeof(buffer));
+    PCHECK(result >= 0) << ": reading from " << filename;
+    if (result == 0) {
+      break;  // EOF.
+    }
+    r.insert(r.end(), buffer, buffer + result);  // Append the whole chunk.
+  }
+  return r;
+}
+
 void WriteStringToFileOrDie(const std::string_view filename,
                             const std::string_view contents,
                             mode_t permissions) {
diff --git a/aos/util/file.h b/aos/util/file.h
index 0cba867..da5b29f 100644
--- a/aos/util/file.h
+++ b/aos/util/file.h
@@ -29,6 +29,10 @@
 std::optional<std::string> MaybeReadFileToString(
     const std::string_view filename);
 
+// Returns the complete contents of filename as raw bytes (no text decoding).
+// LOG(FATAL)s if the file cannot be opened or read.
+std::vector<uint8_t> ReadFileToVecOrDie(const std::string_view filename);
+
 // Creates filename if it doesn't exist and sets the contents to contents.
 void WriteStringToFileOrDie(const std::string_view filename,
                             const std::string_view contents,
diff --git a/aos/util/file_test.cc b/aos/util/file_test.cc
index df16d58..ec4bfe4 100644
--- a/aos/util/file_test.cc
+++ b/aos/util/file_test.cc
@@ -4,6 +4,7 @@
 #include <optional>
 #include <string>
 
+#include "gmock/gmock-matchers.h"
 #include "gtest/gtest.h"
 
 #include "aos/realtime.h"
@@ -13,6 +14,8 @@
 namespace util {
 namespace testing {
 
+using ::testing::ElementsAre;
+
 // Basic test of reading a normal file.
 TEST(FileTest, ReadNormalFile) {
   const std::string tmpdir(aos::testing::TestTmpDir());
@@ -21,6 +24,15 @@
   EXPECT_EQ("contents\n", ReadFileToStringOrDie(test_file));
 }
 
+// Basic test of reading a normal file into a byte vector.
+TEST(FileTest, ReadNormalFileToBytes) {
+  const std::string tmpdir(aos::testing::TestTmpDir());
+  const std::string test_file = tmpdir + "/test_file";
+  ASSERT_EQ(0, system(("echo contents > " + test_file).c_str()));
+  EXPECT_THAT(ReadFileToVecOrDie(test_file),
+              ElementsAre('c', 'o', 'n', 't', 'e', 'n', 't', 's', '\n'));
+}
+
 // Tests reading a file with 0 size, among other weird things.
 TEST(FileTest, ReadSpecialFile) {
   const std::string stat = ReadFileToStringOrDie("/proc/self/stat");