Add sctp authentication to sctp_lib

This only works for linux >=5.4. When enabled, it will use
a shared key to authenticate messages. The functionality is
controlled by a flag and behind a linux version check.

Performance degradation is minimal, even for smaller messages
and unnoticeable when measuring overall system performance.

Change-Id: I836e61ec38a0c116fd7244b771437738ccca9828
Signed-off-by: James Kuszmaul <jabukuszmaul+collab@gmail.com>
diff --git a/aos/network/sctp_test.cc b/aos/network/sctp_test.cc
new file mode 100644
index 0000000..edf13d4
--- /dev/null
+++ b/aos/network/sctp_test.cc
@@ -0,0 +1,298 @@
+#include <unistd.h>
+
+#include <chrono>
+#include <functional>
+
+#include "gflags/gflags.h"
+#include "gmock/gmock-matchers.h"
+#include "gtest/gtest.h"
+
+#include "aos/events/epoll.h"
+#include "aos/network/sctp_client.h"
+#include "aos/network/sctp_lib.h"
+#include "aos/network/sctp_server.h"
+
+DECLARE_bool(disable_ipv6);
+
+namespace aos::message_bridge::testing {
+
+using ::aos::internal::EPoll;
+using ::aos::internal::TimerFd;
+using ::testing::ElementsAre;
+
+using namespace ::std::chrono_literals;
+
+constexpr int kPort = 19423;
+constexpr int kStreams = 1;
+
+namespace {
+void EnableSctpAuthIfAvailable() {
+#if HAS_SCTP_AUTH
+  CHECK(system("/usr/sbin/sysctl net.sctp.auth_enable=1 || /sbin/sysctl "
+               "net.sctp.auth_enable=1") == 0)
+      << "Couldn't enable sctp authentication.";
+#endif
+}
+}  // namespace
+
+// An asynchronous SCTP handler. It takes an SCTP receiver (a.k.a SctpServer or
+// SctpClient), and an `sctp_notification` handler and a `message` handler. It
+// asynchronously routes incoming messages to the appropriate handler.
+template <typename T>
+class SctpReceiver {
+ public:
+  SctpReceiver(
+      EPoll &epoll, T &receiver,
+      std::function<void(T &, const union sctp_notification *)> on_notify,
+      std::function<void(T &, std::vector<uint8_t>)> on_message)
+      : epoll_(epoll),
+        receiver_(receiver),
+        on_notify_(std::move(on_notify)),
+        on_message_(std::move(on_message)) {
+    epoll_.OnReadable(receiver_.fd(), [this]() { Read(); });
+  }
+
+  ~SctpReceiver() { epoll_.DeleteFd(receiver_.fd()); }
+
+ private:
+  // Handles an incoming message by routing it to the apropriate handler.
+  void Read() {
+    aos::unique_c_ptr<Message> message = receiver_.Read();
+    if (!message) {
+      return;
+    }
+
+    switch (message->message_type) {
+      case Message::kNotification: {
+        const union sctp_notification *notification =
+            reinterpret_cast<const union sctp_notification *>(message->data());
+        on_notify_(receiver_, notification);
+        break;
+      }
+      case Message::kMessage:
+        on_message_(receiver_, std::vector(message->data(),
+                                           message->data() + message->size));
+        break;
+      case Message::kOverflow:
+        LOG(FATAL) << "Overflow";
+    }
+    receiver_.FreeMessage(std::move(message));
+  }
+
+  EPoll &epoll_;
+  T &receiver_;
+  std::function<void(T &, const union sctp_notification *)> on_notify_;
+  std::function<void(T &, std::vector<uint8_t>)> on_message_;
+};
+
+// Base SctpTest class.
+//
+// The class provides a few virtual methods that should be overriden to define
+// the behavior of the test.
+class SctpTest : public ::testing::Test {
+ public:
+  SctpTest(std::vector<uint8_t> server_key = {},
+           std::vector<uint8_t> client_key = {},
+           std::chrono::milliseconds timeout = 1000ms)
+      : server_(kStreams, "", kPort, std::move(server_key)),
+        client_("localhost", kPort, kStreams, "", 0, std::move(client_key)),
+        client_receiver_(
+            epoll_, client_,
+            [this](SctpClient &client,
+                   const union sctp_notification *notification) {
+              HandleNotification(client, notification);
+            },
+            [this](SctpClient &client, std::vector<uint8_t> message) {
+              HandleMessage(client, std::move(message));
+            }),
+        server_receiver_(
+            epoll_, server_,
+            [this](SctpServer &server,
+                   const union sctp_notification *notification) {
+              HandleNotification(server, notification);
+            },
+            [this](SctpServer &server, std::vector<uint8_t> message) {
+              HandleMessage(server, std::move(message));
+            }) {
+    timeout_.SetTime(aos::monotonic_clock::now() + timeout,
+                     std::chrono::milliseconds::zero());
+    epoll_.OnReadable(timeout_.fd(), [this]() { TimeOut(); });
+  }
+
+  static void SetUpTestSuite() {
+    EnableSctpAuthIfAvailable();
+    // Buildkite seems to have issues with ipv6 sctp sockets...
+    FLAGS_disable_ipv6 = true;
+  }
+
+  void SetUp() override { Run(); }
+
+ protected:
+  // Handles a server notification message.
+  //
+  // The default behaviour is to track the sctp association ID.
+  virtual void HandleNotification(SctpServer &,
+                                  const union sctp_notification *notification) {
+    if (notification->sn_header.sn_type == SCTP_ASSOC_CHANGE) {
+      assoc_ = notification->sn_assoc_change.sac_assoc_id;
+    }
+  }
+
+  // Handles the client notification message.
+  virtual void HandleNotification(SctpClient &,
+                                  const union sctp_notification *) {}
+
+  // Handles a server "data" message.
+  virtual void HandleMessage(SctpServer &, std::vector<uint8_t>) {}
+  // Handles a client "data" message.
+  virtual void HandleMessage(SctpClient &, std::vector<uint8_t>) {}
+
+  // Defines the "timeout" behaviour (fail by default).
+  virtual void TimeOut() {
+    Quit();
+    FAIL() << "Timer expired";
+  }
+
+  virtual ~SctpTest() {}
+
+  // Quit the test.
+  void Quit() {
+    epoll_.DeleteFd(timeout_.fd());
+    epoll_.Quit();
+  }
+  void Run() { epoll_.Run(); }
+
+  SctpServer server_;
+  SctpClient client_;
+  sctp_assoc_t assoc_ = 0;
+
+ private:
+  TimerFd timeout_;
+  EPoll epoll_;
+  SctpReceiver<SctpClient> client_receiver_;
+  SctpReceiver<SctpServer> server_receiver_;
+};
+
+// Verifies we can ping the server, and the server replies.
+class SctpPingPongTest : public SctpTest {
+ public:
+  SctpPingPongTest() : SctpTest({}, {}, /*timeout=*/2s) {
+    // Start by having the client send "ping".
+    client_.Send(0, "ping", 0);
+  }
+
+  void HandleMessage(SctpServer &server,
+                     std::vector<uint8_t> message) override {
+    // Server should receive a ping message.
+    EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
+    got_ping_ = true;
+    ASSERT_NE(assoc_, 0);
+    // Reply with "pong".
+    server.Send("pong", assoc_, 0, 0);
+  }
+
+  void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
+    // Client should receive a "pong" message.
+    EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
+    got_pong_ = true;
+    // We are done.
+    Quit();
+  }
+  ~SctpPingPongTest() {
+    // Check that we got the ping/pong messages.
+    // This isn't strictly necessary as otherwise we would time out and fail
+    // anyway.
+    EXPECT_TRUE(got_ping_);
+    EXPECT_TRUE(got_pong_);
+  }
+
+ protected:
+  bool got_ping_ = false;
+  bool got_pong_ = false;
+};
+
+TEST_F(SctpPingPongTest, Test) {}
+
+#if HAS_SCTP_AUTH
+
+// Same as SctpPingPongTest but with authentication keys. Both keys are the
+// same so it should work the same way.
+class SctpAuthTest : public SctpTest {
+ public:
+  SctpAuthTest()
+      : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
+                 /*timeout*/ 20s) {
+    // Start by having the client send "ping".
+    client_.Send(0, "ping", 0);
+  }
+
+  void HandleMessage(SctpServer &server,
+                     std::vector<uint8_t> message) override {
+    // Server should receive a ping message.
+    EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
+    got_ping_ = true;
+    ASSERT_NE(assoc_, 0);
+    // Reply with "pong".
+    server.Send("pong", assoc_, 0, 0);
+  }
+  void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
+    // Client should receive a "pong" message.
+    EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
+    got_pong_ = true;
+    // We are done.
+    Quit();
+  }
+  ~SctpAuthTest() {
+    EXPECT_TRUE(got_ping_);
+    EXPECT_TRUE(got_pong_);
+  }
+
+ protected:
+  bool got_ping_ = false;
+  bool got_pong_ = false;
+};
+
+TEST_F(SctpAuthTest, Test) {}
+
+// Keys don't match, we should send the `ping` message but the server should
+// never receive it. We then time out as nothing calls Quit.
+class SctpMismatchedAuthTest : public SctpTest {
+ public:
+  SctpMismatchedAuthTest() : SctpTest({1, 2, 3, 4, 5, 6}, {5, 6, 7, 8, 9, 10}) {
+    // Start by having the client send "ping".
+    client_.Send(0, "ping", 0);
+  }
+
+  void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
+    FAIL() << "Authentication keys don't match. Message should be discarded";
+    Quit();
+  }
+
+  // We expect to time out since we never get the message.
+  void TimeOut() override { Quit(); }
+};
+
+TEST_F(SctpMismatchedAuthTest, Test) {}
+
+// Same as SctpMismatchedAuthTest but the client uses the null key. We should
+// see the same behaviour.
+class SctpOneNullKeyTest : public SctpTest {
+ public:
+  SctpOneNullKeyTest() : SctpTest({1, 2, 3, 4, 5, 6}, {}) {
+    // Start by having the client send "ping".
+    client_.Send(0, "ping", 0);
+  }
+
+  void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
+    FAIL() << "Authentication keys don't match. Message should be discarded";
+    Quit();
+  }
+
+  // We expect to time out since we never get the message.
+  void TimeOut() override { Quit(); }
+};
+
+TEST_F(SctpOneNullKeyTest, Test) {}
+#endif  // HAS_SCTP_AUTH
+
+}  // namespace aos::message_bridge::testing