Add AOS message to update sctp auth key in message_bridge

Message bridge now listens to /aos aos.message_bridge.SctpConfig
channel to update the SCTP authentication key. This change only takes
effect if `--wants_sctp_authentication=true` and the kernel supports
SCTP authentication.

Change-Id: Id9e743471668cd26892a12e4fc9b03a73445df10
Signed-off-by: James Kuszmaul <james.kuszmaul@bluerivertech.com>
diff --git a/aos/network/BUILD b/aos/network/BUILD
index f683a28..75173e7 100644
--- a/aos/network/BUILD
+++ b/aos/network/BUILD
@@ -88,6 +88,31 @@
 )
 
 flatbuffer_cc_library(
+    name = "sctp_config_fbs",
+    srcs = ["sctp_config.fbs"],
+    gen_reflections = 1,
+)
+
+cc_static_flatbuffer(
+    name = "sctp_config_schema",
+    function = "aos::message_bridge::SctpConfig",
+    target = ":sctp_config_fbs_reflection_out",
+)
+
+flatbuffer_cc_library(
+    name = "sctp_config_request_fbs",
+    srcs = ["sctp_config_request.fbs"],
+    gen_reflections = 1,
+)
+
+cc_static_flatbuffer(
+    name = "sctp_config_request_schema",
+    function = "aos::message_bridge::SctpConfigRequest",
+    target = ":sctp_config_request_fbs_reflection_out",
+    visibility = ["//visibility:public"],
+)
+
+flatbuffer_cc_library(
     name = "message_bridge_server_fbs",
     srcs = ["message_bridge_server.fbs"],
     gen_reflections = 1,
@@ -271,6 +296,8 @@
         ":message_bridge_server_status",
         ":remote_data_fbs",
         ":remote_message_fbs",
+        ":sctp_config_fbs",
+        ":sctp_config_request_fbs",
         ":sctp_lib",
         ":sctp_server",
         ":timestamp_channel",
@@ -355,6 +382,8 @@
         ":remote_data_fbs",
         ":remote_message_fbs",
         ":sctp_client",
+        ":sctp_config_fbs",
+        ":sctp_config_request_fbs",
         ":timestamp_fbs",
         "//aos/events:shm_event_loop",
         "//aos/events/logging:log_reader",
@@ -387,6 +416,8 @@
     src = "message_bridge_test_combined_timestamps_common.json",
     flatbuffers = [
         ":remote_message_fbs",
+        ":sctp_config_fbs",
+        ":sctp_config_request_fbs",
         "//aos/events:ping_fbs",
         "//aos/events:pong_fbs",
         "//aos/network:message_bridge_client_fbs",
@@ -402,6 +433,8 @@
     src = "message_bridge_test_common.json",
     flatbuffers = [
         ":remote_message_fbs",
+        ":sctp_config_fbs",
+        ":sctp_config_request_fbs",
         "//aos/events:ping_fbs",
         "//aos/events:pong_fbs",
         "//aos/network:message_bridge_client_fbs",
diff --git a/aos/network/message_bridge_client.cc b/aos/network/message_bridge_client.cc
index 82a59fb..ef727eb 100644
--- a/aos/network/message_bridge_client.cc
+++ b/aos/network/message_bridge_client.cc
@@ -8,6 +8,9 @@
 
 DEFINE_string(config, "aos_config.json", "Path to the config.");
 DEFINE_int32(rt_priority, -1, "If > 0, run as this RT priority");
+DEFINE_bool(
+    wants_sctp_authentication, false,
+    "When set, try to use SCTP authentication if provided by the kernel");
 
 namespace aos {
 namespace message_bridge {
@@ -23,7 +26,10 @@
     event_loop.SetRuntimeRealtimePriority(FLAGS_rt_priority);
   }
 
-  MessageBridgeClient app(&event_loop, Sha256(config.span()));
+  MessageBridgeClient app(&event_loop, Sha256(config.span()),
+                          FLAGS_wants_sctp_authentication
+                              ? SctpAuthMethod::kAuth
+                              : SctpAuthMethod::kNoAuth);
 
   logging::DynamicLogging dynamic_logging(&event_loop);
   // TODO(austin): Save messages into a vector to be logged.  One file per
diff --git a/aos/network/message_bridge_client_lib.cc b/aos/network/message_bridge_client_lib.cc
index c23f748..8df81c6 100644
--- a/aos/network/message_bridge_client_lib.cc
+++ b/aos/network/message_bridge_client_lib.cc
@@ -13,10 +13,14 @@
 #include "aos/network/message_bridge_protocol.h"
 #include "aos/network/remote_data_generated.h"
 #include "aos/network/sctp_client.h"
+#include "aos/network/sctp_config_generated.h"
+#include "aos/network/sctp_config_request_generated.h"
 #include "aos/network/timestamp_generated.h"
 #include "aos/unique_malloc_ptr.h"
 #include "aos/util/file.h"
 
+DECLARE_bool(use_sctp_authentication);
+
 // This application receives messages from another node and re-publishes them on
 // this node.
 //
@@ -30,6 +34,9 @@
 namespace {
 namespace chrono = std::chrono;
 
+// How often we should poll for the active SCTP authentication key.
+constexpr chrono::seconds kRefreshAuthKeyPeriod{3};
+
 std::vector<int> StreamToChannel(const Configuration *config,
                                  const Node *my_node, const Node *other_node) {
   std::vector<int> stream_to_channel;
@@ -100,7 +107,8 @@
     aos::ShmEventLoop *const event_loop, std::string_view remote_name,
     const Node *my_node, std::string_view local_host,
     std::vector<SctpClientChannelState> *channels, int client_index,
-    MessageBridgeClientStatus *client_status, std::string_view config_sha256)
+    MessageBridgeClientStatus *client_status, std::string_view config_sha256,
+    SctpAuthMethod requested_authentication)
     : event_loop_(event_loop),
       connect_message_(MakeConnectMessage(event_loop->configuration(), my_node,
                                           remote_name, event_loop->boot_uuid(),
@@ -111,7 +119,7 @@
       client_(remote_node_->hostname()->string_view(), remote_node_->port(),
               connect_message_.message().channels_to_transfer()->size() +
                   kControlStreams(),
-              local_host, 0),
+              local_host, 0, requested_authentication),
       channels_(channels),
       stream_to_channel_(
           StreamToChannel(event_loop->configuration(), my_node, remote_node_)),
@@ -363,13 +371,34 @@
           << " cumtsn=" << message->header.rcvinfo.rcv_cumtsn << ")";
 }
 
-MessageBridgeClient::MessageBridgeClient(aos::ShmEventLoop *event_loop,
-                                         std::string config_sha256)
+MessageBridgeClient::MessageBridgeClient(
+    aos::ShmEventLoop *event_loop, std::string config_sha256,
+    SctpAuthMethod requested_authentication)
     : event_loop_(event_loop),
       client_status_(event_loop_),
-      config_sha256_(std::move(config_sha256)) {
+      config_sha256_(std::move(config_sha256)),
+      refresh_key_timer_(event_loop->AddTimer([this]() { RequestAuthKey(); })),
+      sctp_config_request_(event_loop_->MakeSender<SctpConfigRequest>("/aos")) {
   std::string_view node_name = event_loop->node()->name()->string_view();
 
+  // Set up the SCTP configuration watcher and timer.
+  if (requested_authentication == SctpAuthMethod::kAuth && HasSctpAuth()) {
+    event_loop->MakeWatcher("/aos", [this](const SctpConfig &config) {
+      if (config.has_key()) {
+        for (auto &conn : connections_) {
+          conn->SetAuthKey(*config.key());
+        }
+      }
+    });
+
+    // We poll in case the SCTP authentication key has changed.
+    refresh_key_timer_->set_name("refresh_key");
+    event_loop_->OnRun([this]() {
+      refresh_key_timer_->Schedule(event_loop_->monotonic_now(),
+                                   kRefreshAuthKeyPeriod);
+    });
+  }
+
   // Find all the channels which are supposed to be delivered to us.
   channels_.resize(event_loop_->configuration()->channels()->size());
   int channel_index = 0;
@@ -415,9 +444,16 @@
     connections_.emplace_back(new SctpClientConnection(
         event_loop, source_node, event_loop->node(), "", &channels_,
         client_status_.FindClientIndex(source_node), &client_status_,
-        config_sha256_));
+        config_sha256_, requested_authentication));
   }
 }
 
+void MessageBridgeClient::RequestAuthKey() {
+  auto sender = sctp_config_request_.MakeBuilder();
+  auto builder = sender.MakeBuilder<SctpConfigRequest>();
+  builder.add_request_key(true);
+  sender.CheckOk(sender.Send(builder.Finish()));
+}
+
 }  // namespace message_bridge
 }  // namespace aos
diff --git a/aos/network/message_bridge_client_lib.h b/aos/network/message_bridge_client_lib.h
index a5a69e3..e4b5e84 100644
--- a/aos/network/message_bridge_client_lib.h
+++ b/aos/network/message_bridge_client_lib.h
@@ -10,6 +10,7 @@
 #include "aos/network/message_bridge_client_generated.h"
 #include "aos/network/message_bridge_client_status.h"
 #include "aos/network/sctp_client.h"
+#include "aos/network/sctp_config_request_generated.h"
 #include "aos/network/sctp_lib.h"
 
 namespace aos {
@@ -38,10 +39,15 @@
                        std::vector<SctpClientChannelState> *channels,
                        int client_index,
                        MessageBridgeClientStatus *client_status,
-                       std::string_view config_sha256);
+                       std::string_view config_sha256,
+                       SctpAuthMethod requested_authentication);
 
   ~SctpClientConnection() { event_loop_->epoll()->DeleteFd(client_.fd()); }
 
+  void SetAuthKey(absl::Span<const uint8_t> auth_key) {
+    client_.SetAuthKey(auth_key);
+  }
+
  private:
   // Reads a message from the socket.  Could be a notification.
   void MessageReceived();
@@ -102,11 +108,15 @@
 // node.
 class MessageBridgeClient {
  public:
-  MessageBridgeClient(aos::ShmEventLoop *event_loop, std::string config_sha256);
+  MessageBridgeClient(aos::ShmEventLoop *event_loop, std::string config_sha256,
+                      SctpAuthMethod requested_authentication);
 
   ~MessageBridgeClient() {}
 
  private:
+  // Sends a request for the currently active authentication key.
+  void RequestAuthKey();
+
   // Event loop to schedule everything on.
   aos::ShmEventLoop *event_loop_;
 
@@ -119,6 +129,12 @@
   std::vector<std::unique_ptr<SctpClientConnection>> connections_;
 
   std::string config_sha256_;
+
+  // We use this timer to poll the active authentication key.
+  aos::TimerHandler *refresh_key_timer_;
+
+  // Used to request the current sctp settings to be used.
+  aos::Sender<SctpConfigRequest> sctp_config_request_;
 };
 
 }  // namespace message_bridge
diff --git a/aos/network/message_bridge_server.cc b/aos/network/message_bridge_server.cc
index 449333a..be6cc8e 100644
--- a/aos/network/message_bridge_server.cc
+++ b/aos/network/message_bridge_server.cc
@@ -10,6 +10,9 @@
 
 DEFINE_string(config, "aos_config.json", "Path to the config.");
 DEFINE_int32(rt_priority, -1, "If > 0, run as this RT priority");
+DEFINE_bool(
+    wants_sctp_authentication, false,
+    "When set, try to use SCTP authentication if provided by the kernel");
 
 namespace aos {
 namespace message_bridge {
@@ -25,7 +28,10 @@
     event_loop.SetRuntimeRealtimePriority(FLAGS_rt_priority);
   }
 
-  MessageBridgeServer app(&event_loop, Sha256(config.span()));
+  MessageBridgeServer app(&event_loop, Sha256(config.span()),
+                          FLAGS_wants_sctp_authentication
+                              ? SctpAuthMethod::kAuth
+                              : SctpAuthMethod::kNoAuth);
 
   logging::DynamicLogging dynamic_logging(&event_loop);
 
diff --git a/aos/network/message_bridge_server_lib.cc b/aos/network/message_bridge_server_lib.cc
index debfeeb..9ac062e 100644
--- a/aos/network/message_bridge_server_lib.cc
+++ b/aos/network/message_bridge_server_lib.cc
@@ -13,6 +13,7 @@
 #include "aos/network/message_bridge_server_generated.h"
 #include "aos/network/remote_data_generated.h"
 #include "aos/network/remote_message_generated.h"
+#include "aos/network/sctp_config_generated.h"
 #include "aos/network/sctp_server.h"
 #include "aos/network/timestamp_channel.h"
 
@@ -49,10 +50,15 @@
              "If set to a nonnegative numbers, the wmem buffer size to use, in "
              "bytes. Intended solely for testing purposes.");
 
+DECLARE_bool(use_sctp_authentication);
+
 namespace aos {
 namespace message_bridge {
 namespace chrono = std::chrono;
 
+// How often we should poll for the active SCTP authentication key.
+constexpr chrono::seconds kRefreshAuthKeyPeriod{3};
+
 bool ChannelState::Matches(const Channel *other_channel) {
   return channel_->name()->string_view() ==
              other_channel->name()->string_view() &&
@@ -401,18 +407,36 @@
   return -1;
 }
 
-MessageBridgeServer::MessageBridgeServer(aos::ShmEventLoop *event_loop,
-                                         std::string config_sha256)
+MessageBridgeServer::MessageBridgeServer(
+    aos::ShmEventLoop *event_loop, std::string config_sha256,
+    SctpAuthMethod requested_authentication)
     : event_loop_(event_loop),
       timestamp_loggers_(event_loop_),
       server_(max_channels() + kControlStreams(), "",
-              event_loop->node()->port()),
+              event_loop->node()->port(), requested_authentication),
       server_status_(event_loop, [this]() { timestamp_state_->SendData(); }),
       config_sha256_(std::move(config_sha256)),
-      allocator_(0) {
+      allocator_(0),
+      refresh_key_timer_(event_loop->AddTimer([this]() { RequestAuthKey(); })),
+      sctp_config_request_(event_loop_->MakeSender<SctpConfigRequest>("/aos")) {
   CHECK_EQ(config_sha256_.size(), 64u) << ": Wrong length sha256sum";
   CHECK(event_loop_->node() != nullptr) << ": No nodes configured.";
 
+  // Set up the SCTP configuration watcher and timer.
+  if (requested_authentication == SctpAuthMethod::kAuth && HasSctpAuth()) {
+    event_loop_->MakeWatcher("/aos", [this](const SctpConfig &config) {
+      if (config.has_key()) {
+        server_.SetAuthKey(*config.key());
+      }
+    });
+
+    // We poll in case the SCTP authentication key has changed.
+    refresh_key_timer_->set_name("refresh_key");
+    event_loop_->OnRun([this]() {
+      refresh_key_timer_->Schedule(event_loop_->monotonic_now(),
+                                   kRefreshAuthKeyPeriod);
+    });
+  }
   // Start out with a decent size big enough to hold timestamps.
   size_t max_size = 204;
 
@@ -821,5 +845,12 @@
   }
 }
 
+void MessageBridgeServer::RequestAuthKey() {
+  auto sender = sctp_config_request_.MakeBuilder();
+  auto builder = sender.MakeBuilder<SctpConfigRequest>();
+  builder.add_request_key(true);
+  sender.CheckOk(sender.Send(builder.Finish()));
+}
+
 }  // namespace message_bridge
 }  // namespace aos
diff --git a/aos/network/message_bridge_server_lib.h b/aos/network/message_bridge_server_lib.h
index 6a3fdd8..b8377c6 100644
--- a/aos/network/message_bridge_server_lib.h
+++ b/aos/network/message_bridge_server_lib.h
@@ -15,6 +15,7 @@
 #include "aos/network/message_bridge_server_status.h"
 #include "aos/network/remote_data_generated.h"
 #include "aos/network/remote_message_generated.h"
+#include "aos/network/sctp_config_request_generated.h"
 #include "aos/network/sctp_server.h"
 #include "aos/network/timestamp_channel.h"
 #include "aos/network/timestamp_generated.h"
@@ -167,7 +168,8 @@
 // node.  It handles the session and dispatches data to the ChannelState.
 class MessageBridgeServer {
  public:
-  MessageBridgeServer(aos::ShmEventLoop *event_loop, std::string config_sha256);
+  MessageBridgeServer(aos::ShmEventLoop *event_loop, std::string config_sha256,
+                      SctpAuthMethod requested_authentication);
 
   // Delete copy/move constructors explicitly--we internally pass around
   // pointers to internal state.
@@ -202,6 +204,9 @@
     return event_loop_->configuration()->channels()->size();
   }
 
+  // Sends a request for the currently active authentication key.
+  void RequestAuthKey();
+
   // Event loop to schedule everything on.
   aos::ShmEventLoop *event_loop_;
 
@@ -224,6 +229,12 @@
   std::vector<sctp_assoc_t> reconnected_;
 
   FixedAllocator allocator_;
+
+  // We use this timer to poll the active authentication key.
+  aos::TimerHandler *refresh_key_timer_;
+
+  // Used to request the current sctp settings to be used.
+  aos::Sender<SctpConfigRequest> sctp_config_request_;
 };
 
 }  // namespace message_bridge
diff --git a/aos/network/message_bridge_test_combined_timestamps_common.json b/aos/network/message_bridge_test_combined_timestamps_common.json
index 13a0514..a99c37c 100644
--- a/aos/network/message_bridge_test_combined_timestamps_common.json
+++ b/aos/network/message_bridge_test_combined_timestamps_common.json
@@ -100,6 +100,38 @@
       "max_size": 2048
     },
     {
+      "name": "/pi1/aos",
+      "type": "aos.message_bridge.SctpConfig",
+      "source_node": "pi1",
+      "frequency": 10,
+      "num_senders": 1,
+      "max_size": 256
+    },
+    {
+      "name": "/pi2/aos",
+      "type": "aos.message_bridge.SctpConfig",
+      "source_node": "pi2",
+      "frequency": 10,
+      "num_senders": 1,
+      "max_size": 256
+    },
+    {
+      "name": "/pi1/aos",
+      "type": "aos.message_bridge.SctpConfigRequest",
+      "source_node": "pi1",
+      "frequency": 1,
+      "num_senders": 2,
+      "max_size": 32
+    },
+    {
+      "name": "/pi2/aos",
+      "type": "aos.message_bridge.SctpConfigRequest",
+      "source_node": "pi2",
+      "frequency": 1,
+      "num_senders": 2,
+      "max_size": 32
+    },
+    {
       "name": "/test",
       "type": "aos.examples.Ping",
       "source_node": "pi1",
diff --git a/aos/network/message_bridge_test_common.json b/aos/network/message_bridge_test_common.json
index 9bb0863..09d19c8 100644
--- a/aos/network/message_bridge_test_common.json
+++ b/aos/network/message_bridge_test_common.json
@@ -121,6 +121,38 @@
       "max_size": 2048
     },
     {
+      "name": "/pi1/aos",
+      "type": "aos.message_bridge.SctpConfig",
+      "source_node": "pi1",
+      "frequency": 10,
+      "num_senders": 1,
+      "max_size": 256
+    },
+    {
+      "name": "/pi2/aos",
+      "type": "aos.message_bridge.SctpConfig",
+      "source_node": "pi2",
+      "frequency": 10,
+      "num_senders": 1,
+      "max_size": 256
+    },
+    {
+      "name": "/pi1/aos",
+      "type": "aos.message_bridge.SctpConfigRequest",
+      "source_node": "pi1",
+      "frequency": 1,
+      "num_senders": 2,
+      "max_size": 32
+    },
+    {
+      "name": "/pi2/aos",
+      "type": "aos.message_bridge.SctpConfigRequest",
+      "source_node": "pi2",
+      "frequency": 1,
+      "num_senders": 2,
+      "max_size": 32
+    },
+    {
       "name": "/test",
       "type": "aos.examples.Ping",
       "source_node": "pi1",
diff --git a/aos/network/message_bridge_test_lib.cc b/aos/network/message_bridge_test_lib.cc
index e818075..d5a5f8f 100644
--- a/aos/network/message_bridge_test_lib.cc
+++ b/aos/network/message_bridge_test_lib.cc
@@ -81,7 +81,8 @@
   pi1_server_event_loop->SetRuntimeRealtimePriority(1);
   pi1_message_bridge_server = std::make_unique<MessageBridgeServer>(
       pi1_server_event_loop.get(),
-      server_config_sha256.size() == 0 ? config_sha256 : server_config_sha256);
+      server_config_sha256.size() == 0 ? config_sha256 : server_config_sha256,
+      SctpAuthMethod::kNoAuth);
 }
 
 void MessageBridgeParameterizedTest::RunPi1Server(
@@ -115,7 +116,7 @@
       std::make_unique<aos::ShmEventLoop>(&config.message());
   pi1_client_event_loop->SetRuntimeRealtimePriority(1);
   pi1_message_bridge_client = std::make_unique<MessageBridgeClient>(
-      pi1_client_event_loop.get(), config_sha256);
+      pi1_client_event_loop.get(), config_sha256, SctpAuthMethod::kNoAuth);
 }
 
 void MessageBridgeParameterizedTest::StartPi1Client() {
@@ -168,7 +169,7 @@
       std::make_unique<aos::ShmEventLoop>(&config.message());
   pi2_server_event_loop->SetRuntimeRealtimePriority(1);
   pi2_message_bridge_server = std::make_unique<MessageBridgeServer>(
-      pi2_server_event_loop.get(), config_sha256);
+      pi2_server_event_loop.get(), config_sha256, SctpAuthMethod::kNoAuth);
 }
 
 void MessageBridgeParameterizedTest::RunPi2Server(
@@ -202,7 +203,7 @@
       std::make_unique<aos::ShmEventLoop>(&config.message());
   pi2_client_event_loop->SetRuntimeRealtimePriority(1);
   pi2_message_bridge_client = std::make_unique<MessageBridgeClient>(
-      pi2_client_event_loop.get(), config_sha256);
+      pi2_client_event_loop.get(), config_sha256, SctpAuthMethod::kNoAuth);
 }
 
 void MessageBridgeParameterizedTest::RunPi2Client(
diff --git a/aos/network/sctp_config.fbs b/aos/network/sctp_config.fbs
new file mode 100644
index 0000000..4c1819b
--- /dev/null
+++ b/aos/network/sctp_config.fbs
@@ -0,0 +1,10 @@
+namespace aos.message_bridge;
+
+// SCTP Configuration options for message bridge.
+table SctpConfig {
+  // The authentication key to use.
+  key:[ubyte] (id: 0);
+}
+
+root_type SctpConfig;
+
diff --git a/aos/network/sctp_config_request.fbs b/aos/network/sctp_config_request.fbs
new file mode 100644
index 0000000..196589b
--- /dev/null
+++ b/aos/network/sctp_config_request.fbs
@@ -0,0 +1,10 @@
+namespace aos.message_bridge;
+
+// SCTP configuration requests for message bridge.
+table SctpConfigRequest {
+  // When set, the authentication key is being requested.
+  request_key:bool (id: 0);
+}
+
+root_type SctpConfigRequest;
+