Merge "aos/network: opt in message_bridge to dynamic vlog"
diff --git a/aos/events/event_loop.cc b/aos/events/event_loop.cc
index 3679062..df0390c 100644
--- a/aos/events/event_loop.cc
+++ b/aos/events/event_loop.cc
@@ -61,6 +61,24 @@
                 realtime_remote_time, remote_queue_index, source_boot_uuid);
 }
 
+void RawSender::RecordSendResult(const Error error, size_t message_size) {
+  // Timing-report bookkeeping for one send: failures bump the counter for
+  // their error kind, successes feed the sent-count and size statistics.
+  switch (error) {
+    case Error::kMessagesSentTooFast:
+      timing_.IncrementError(timing::SendError::MESSAGE_SENT_TOO_FAST);
+      return;
+    case Error::kInvalidRedzone:
+      timing_.IncrementError(timing::SendError::INVALID_REDZONE);
+      return;
+    case Error::kOk:
+      if (!timing_.sender) return;
+      timing_.size.Add(message_size);
+      timing_.sender->mutate_count(timing_.sender->count() + 1);
+      return;
+  }
+}
+
 RawFetcher::RawFetcher(EventLoop *event_loop, const Channel *channel)
     : event_loop_(event_loop),
       channel_(channel),
diff --git a/aos/events/event_loop.h b/aos/events/event_loop.h
index c2498b3..984a006 100644
--- a/aos/events/event_loop.h
+++ b/aos/events/event_loop.h
@@ -244,6 +244,8 @@
                        uint32_t remote_queue_index,
                        const UUID &source_boot_uuid);
 
+  void RecordSendResult(const Error error, size_t message_size);
+
   EventLoop *const event_loop_;
   const Channel *const channel_;
   const std::string ftrace_prefix_;
diff --git a/aos/events/event_loop_param_test.cc b/aos/events/event_loop_param_test.cc
index c8ceeb5..f9e673c 100644
--- a/aos/events/event_loop_param_test.cc
+++ b/aos/events/event_loop_param_test.cc
@@ -1918,6 +1918,63 @@
   }
 }
 
+// Tests that the RawSender::Send(void*, size_t) overload tracks things properly
+// in its timing report.
+TEST_P(AbstractEventLoopTest, CopySenderTimingReport) {
+  gflags::FlagSaver flag_saver;
+  FLAGS_timing_report_ms = 1000;
+  auto loop1 = Make();
+  auto loop2 = MakePrimary();
+
+  const FlatbufferDetachedBuffer<TestMessage> kMessage =
+      JsonToFlatbuffer<TestMessage>("{}");
+
+  std::unique_ptr<aos::RawSender> sender =
+      loop2->MakeRawSender(configuration::GetChannel(
+          loop2->configuration(), "/test", "aos.TestMessage", "", nullptr));
+
+  Fetcher<timing::Report> report_fetcher =
+      loop1->MakeFetcher<timing::Report>("/aos");
+  EXPECT_FALSE(report_fetcher.Fetch());
+
+  loop2->OnRun([&]() {
+    for (int ii = 0; ii < TestChannelQueueSize(loop2.get()); ++ii) {
+      EXPECT_EQ(sender->Send(kMessage.span().data(), kMessage.span().size()),
+                RawSender::Error::kOk);
+    }
+    EXPECT_EQ(sender->Send(kMessage.span().data(), kMessage.span().size()),
+              RawSender::Error::kMessagesSentTooFast);
+  });
+  // Quit after 1 timing report, midway through the next cycle.
+  EndEventLoop(loop2.get(), chrono::milliseconds(1500));
+
+  Run();
+
+  if (do_timing_reports() == DoTimingReports::kYes) {
+    // Check that the sent too fast actually got recorded by the timing report.
+    FlatbufferDetachedBuffer<timing::Report> primary_report =
+        FlatbufferDetachedBuffer<timing::Report>::Empty();
+    while (report_fetcher.FetchNext()) {
+      if (report_fetcher->name()->string_view() == "primary") {
+        primary_report = CopyFlatBuffer(report_fetcher.get());
+      }
+    }
+
+    // Fail cleanly, rather than dereferencing null, if no report was found.
+    ASSERT_NE(primary_report.message().name(), nullptr);
+    EXPECT_EQ(primary_report.message().name()->string_view(), "primary");
+    ASSERT_NE(primary_report.message().senders(), nullptr);
+    EXPECT_EQ(primary_report.message().senders()->size(), 3u);
+    const auto *raw_sender = primary_report.message().senders()->Get(0);
+    ASSERT_NE(raw_sender->error_counts(), nullptr);
+    EXPECT_EQ(raw_sender->error_counts()
+                  ->Get(static_cast<size_t>(
+                      timing::SendError::MESSAGE_SENT_TOO_FAST))
+                  ->count(),
+              1u);
+  }
+}
+
 // Tests that senders count correctly in the timing report.
 TEST_P(AbstractEventLoopTest, WatcherTimingReport) {
   FLAGS_timing_report_ms = 1000;
diff --git a/aos/events/event_loop_tmpl.h b/aos/events/event_loop_tmpl.h
index 1382332..b3c36cc 100644
--- a/aos/events/event_loop_tmpl.h
+++ b/aos/events/event_loop_tmpl.h
@@ -143,26 +143,13 @@
     uint32_t remote_queue_index, const UUID &uuid) {
   const auto err = DoSend(size, monotonic_remote_time, realtime_remote_time,
                           remote_queue_index, uuid);
-  switch (err) {
-    case Error::kOk: {
-      if (timing_.sender) {
-        timing_.size.Add(size);
-        timing_.sender->mutate_count(timing_.sender->count() + 1);
-      }
-      ftrace_.FormatMessage(
-          "%.*s: sent internal: event=%" PRId64 " queue=%" PRIu32,
-          static_cast<int>(ftrace_prefix_.size()), ftrace_prefix_.data(),
-          static_cast<int64_t>(
-              monotonic_sent_time().time_since_epoch().count()),
-          sent_queue_index());
-      break;
-    }
-    case Error::kMessagesSentTooFast:
-      timing_.IncrementError(timing::SendError::MESSAGE_SENT_TOO_FAST);
-      break;
-    case Error::kInvalidRedzone:
-      timing_.IncrementError(timing::SendError::INVALID_REDZONE);
-      break;
+  RecordSendResult(err, size);
+  if (err == Error::kOk) {
+    ftrace_.FormatMessage(
+        "%.*s: sent internal: event=%" PRId64 " queue=%" PRIu32,
+        static_cast<int>(ftrace_prefix_.size()), ftrace_prefix_.data(),
+        static_cast<int64_t>(monotonic_sent_time().time_since_epoch().count()),
+        sent_queue_index());
   }
   return err;
 }
@@ -179,11 +166,8 @@
     uint32_t remote_queue_index, const UUID &uuid) {
   const auto err = DoSend(data, size, monotonic_remote_time,
                           realtime_remote_time, remote_queue_index, uuid);
+  RecordSendResult(err, size);
   if (err == RawSender::Error::kOk) {
-    if (timing_.sender) {
-      timing_.size.Add(size);
-      timing_.sender->mutate_count(timing_.sender->count() + 1);
-    }
     ftrace_.FormatMessage(
         "%.*s: sent external: event=%" PRId64 " queue=%" PRIu32,
         static_cast<int>(ftrace_prefix_.size()), ftrace_prefix_.data(),
@@ -206,11 +190,8 @@
   const size_t size = data->size();
   const auto err = DoSend(std::move(data), monotonic_remote_time,
                           realtime_remote_time, remote_queue_index, uuid);
+  RecordSendResult(err, size);
   if (err == Error::kOk) {
-    if (timing_.sender) {
-      timing_.size.Add(size);
-      timing_.sender->mutate_count(timing_.sender->count() + 1);
-    }
     ftrace_.FormatMessage(
         "%.*s: sent shared: event=%" PRId64 " queue=%" PRIu32,
         static_cast<int>(ftrace_prefix_.size()), ftrace_prefix_.data(),
diff --git a/aos/events/logging/log_writer.cc b/aos/events/logging/log_writer.cc
index b8d0c7b..2a53f05 100644
--- a/aos/events/logging/log_writer.cc
+++ b/aos/events/logging/log_writer.cc
@@ -95,7 +95,8 @@
         if (it != timestamp_logger_channels.end()) {
           CHECK(!is_split);
           CHECK_LT(channel_index, std::get<2>(it->second).size());
-          std::get<2>(it->second)[channel_index] = (connection->time_to_live() == 0);
+          std::get<2>(it->second)[channel_index] =
+              (connection->time_to_live() == 0);
         } else {
           if (is_split) {
             timestamp_logger_channels.insert(std::make_pair(
@@ -272,7 +273,7 @@
   return true;
 }
 
-std::string Logger::WriteConfiguration(LogNamer* log_namer) {
+std::string Logger::WriteConfiguration(LogNamer *log_namer) {
   std::string config_sha256;
 
   if (separate_config_) {
@@ -351,8 +352,8 @@
   const aos::monotonic_clock::time_point header_time =
       event_loop_->monotonic_now();
 
-  VLOG(1) << "Logging node as " << FlatbufferToJson(node_)
-          << " start_time " << last_synchronized_time_ << ", took "
+  VLOG(1) << "Logging node as " << FlatbufferToJson(node_) << " start_time "
+          << last_synchronized_time_ << ", took "
           << chrono::duration<double>(fetch_time - beginning_time).count()
           << " to fetch, "
           << chrono::duration<double>(header_time - fetch_time).count()
@@ -373,8 +374,8 @@
                         polling_period_);
 }
 
-std::unique_ptr<LogNamer> Logger::RestartLogging(std::unique_ptr<LogNamer> log_namer,
-                          std::optional<UUID> log_start_uuid) {
+std::unique_ptr<LogNamer> Logger::RestartLogging(
+    std::unique_ptr<LogNamer> log_namer, std::optional<UUID> log_start_uuid) {
   CHECK(log_namer_) << ": Unexpected restart while not logging";
 
   VLOG(1) << "Restarting logger for " << FlatbufferToJson(node_);
@@ -382,13 +383,38 @@
   // Force out every currently pending message, pointing all fetchers at the
   // last (currently available) records.  Note that LogUntil() updates
   // last_synchronized_time_ to the time value that it receives.
-  while(LogUntil(last_synchronized_time_ + polling_period_));
+  while (LogUntil(last_synchronized_time_ + polling_period_))
+    ;
 
   std::unique_ptr<LogNamer> old_log_namer = std::move(log_namer_);
   log_namer_ = std::move(log_namer);
 
+  // Now grab a representative time on both the RT and monotonic clock.  Average
+  // a monotonic clock before and after to reduce the error.
   const aos::monotonic_clock::time_point beginning_time =
       event_loop_->monotonic_now();
+  const aos::realtime_clock::time_point beginning_time_rt =
+      event_loop_->realtime_now();
+  const aos::monotonic_clock::time_point beginning_time2 =
+      event_loop_->monotonic_now();
+
+  if (beginning_time > last_synchronized_time_) {
+    LOG(WARNING) << "Took over " << polling_period_.count()
+                 << "ns to swap log_namer";
+  }
+
+  // Since we are going to log all in 1 big go, we need our log start time to be
+  // after the previous LogUntil call finished, but before 1 period after it.
+  // The best way to guarantee that is to pick a start time that is the earliest
+  // of the two.  That covers the case where the OS puts us to sleep between
+  // when we finish LogUntil and capture beginning_time.
+  const aos::monotonic_clock::time_point monotonic_start_time =
+      std::min(last_synchronized_time_, beginning_time);
+  const aos::realtime_clock::time_point realtime_start_time =
+      (beginning_time_rt + (monotonic_start_time.time_since_epoch() -
+                            ((beginning_time.time_since_epoch() +
+                              beginning_time2.time_since_epoch()) /
+                             2)));
 
   auto config_sha256 = WriteConfiguration(log_namer_.get());
 
@@ -400,15 +426,15 @@
   // Note that WriteHeader updates last_synchronized_time_ to be the
   // current time when it is called, which is then the "start time"
   // of the new (restarted) log. This timestamp will be after
-  // the timestamp of the last message fetched on each channel.
-  WriteHeader();
+  // the timestamp of the last message fetched on each channel, but is carefully
+  // picked per the comment above to not violate max_out_of_order_duration.
+  WriteHeader(monotonic_start_time, realtime_start_time);
 
   const aos::monotonic_clock::time_point header_time =
       event_loop_->monotonic_now();
 
   // Write the transition record(s) for each channel ...
   for (FetcherStruct &f : fetchers_) {
-
     // Create writers from the new namer
     NewDataWriter *next_writer = nullptr;
     NewDataWriter *next_timestamp_writer = nullptr;
@@ -426,37 +452,38 @@
     }
 
     if (f.fetcher->context().data != nullptr) {
-
-      // Write the last message fetched as the first of the new log of this type.
-      // The timestamps on these will all be before the new start time.
+      // Write the last message fetched as the first of the new log of this
+      // type. The timestamps on these will all be before the new start time.
       WriteData(next_writer, f);
       WriteTimestamps(next_timestamp_writer, f);
       WriteContent(next_contents_writer, f);
 
-      // It is possible that a few more snuck in. Write them all out also, including
-      // any that should also be in the old log.
+      // It is possible that a few more snuck in. Write them all out also,
+      // including any that should also be in the old log.
       while (true) {
-          // Get the next message ...
-          const auto start = event_loop_->monotonic_now();
-          const bool got_new = f.fetcher->FetchNext();
-          const auto end = event_loop_->monotonic_now();
-          RecordFetchResult(start, end, got_new, &f);
+        // Get the next message ...
+        const auto start = event_loop_->monotonic_now();
+        const bool got_new = f.fetcher->FetchNext();
+        const auto end = event_loop_->monotonic_now();
+        RecordFetchResult(start, end, got_new, &f);
 
-          if (got_new) {
-            if (f.fetcher->context().monotonic_event_time < last_synchronized_time_) {
-              WriteFetchedRecord(f);
-            }
-
+        if (got_new) {
+          if (f.fetcher->context().monotonic_event_time <=
+              last_synchronized_time_) {
+            WriteFetchedRecord(f);
             WriteData(next_writer, f);
             WriteTimestamps(next_timestamp_writer, f);
             WriteContent(next_contents_writer, f);
 
-            if (f.fetcher->context().monotonic_event_time > last_synchronized_time_) {
-              break;
-            }
           } else {
+            f.written = false;
             break;
           }
+
+        } else {
+          f.written = true;
+          break;
+        }
       }
     }
 
@@ -464,18 +491,18 @@
     f.writer = next_writer;
     f.timestamp_writer = next_timestamp_writer;
     f.contents_writer = next_contents_writer;
-    f.written = true;
   }
 
   const aos::monotonic_clock::time_point channel_time =
       event_loop_->monotonic_now();
 
-  VLOG(1) << "Logging node as " << FlatbufferToJson(node_)
-          << " restart_time " << last_synchronized_time_ << ", took "
+  VLOG(1) << "Logging node as " << FlatbufferToJson(node_) << " restart_time "
+          << last_synchronized_time_ << ", took "
           << chrono::duration<double>(header_time - beginning_time).count()
           << " to prepare and write header, "
           << chrono::duration<double>(channel_time - header_time).count()
-          << " to write initial channel messages, boot uuid " << event_loop_->boot_uuid();
+          << " to write initial channel messages, boot uuid "
+          << event_loop_->boot_uuid();
 
   return old_log_namer;
 }
@@ -505,15 +532,16 @@
   return std::move(log_namer_);
 }
 
-void Logger::WriteHeader() {
+void Logger::WriteHeader(aos::monotonic_clock::time_point monotonic_start_time,
+                         aos::realtime_clock::time_point realtime_start_time) {
   if (configuration::MultiNode(configuration_)) {
     server_statistics_fetcher_.Fetch();
   }
 
-  const aos::monotonic_clock::time_point monotonic_start_time =
-      event_loop_->monotonic_now();
-  const aos::realtime_clock::time_point realtime_start_time =
-      event_loop_->realtime_now();
+  if (monotonic_start_time == aos::monotonic_clock::min_time) {
+    monotonic_start_time = event_loop_->monotonic_now();
+    realtime_start_time = event_loop_->realtime_now();
+  }
 
   // We need to pick a point in time to declare the log file "started".  This
   // starts here.  It needs to be after everything is fetched so that the
@@ -732,26 +760,26 @@
                                        max_header_size_);
     fbb.ForceDefaults(true);
 
-    fbb.FinishSizePrefixed(PackMessage(&fbb, f.fetcher->context(),
-                                       f.channel_index, f.log_type));
+    fbb.FinishSizePrefixed(
+        PackMessage(&fbb, f.fetcher->context(), f.channel_index, f.log_type));
     const auto end = event_loop_->monotonic_now();
     RecordCreateMessageTime(start, end, f);
 
-    max_header_size_ = std::max(max_header_size_,
-                                fbb.GetSize() - f.fetcher->context().size);
+    max_header_size_ =
+        std::max(max_header_size_, fbb.GetSize() - f.fetcher->context().size);
     writer->QueueMessage(&fbb, source_node_boot_uuid, end);
 
-    VLOG(2) << "Wrote data as node "
-            << FlatbufferToJson(node_) << " for channel "
+    VLOG(2) << "Wrote data as node " << FlatbufferToJson(node_)
+            << " for channel "
             << configuration::CleanedChannelToString(f.fetcher->channel())
             << " to " << writer->filename() << " data "
-            << FlatbufferToJson(
-                   flatbuffers::GetSizePrefixedRoot<MessageHeader>(
-                       fbb.GetBufferPointer()));
+            << FlatbufferToJson(flatbuffers::GetSizePrefixedRoot<MessageHeader>(
+                   fbb.GetBufferPointer()));
   }
 }
 
-void Logger::WriteTimestamps(NewDataWriter *timestamp_writer, const FetcherStruct &f) {
+void Logger::WriteTimestamps(NewDataWriter *timestamp_writer,
+                             const FetcherStruct &f) {
   if (timestamp_writer != nullptr) {
     // And now handle timestamps.
     const auto start = event_loop_->monotonic_now();
@@ -771,17 +799,17 @@
         f.fetcher->context().monotonic_event_time, f.reliable_forwarding);
     timestamp_writer->QueueMessage(&fbb, event_loop_->boot_uuid(), end);
 
-    VLOG(2) << "Wrote timestamps as node "
-            << FlatbufferToJson(node_) << " for channel "
+    VLOG(2) << "Wrote timestamps as node " << FlatbufferToJson(node_)
+            << " for channel "
             << configuration::CleanedChannelToString(f.fetcher->channel())
             << " to " << timestamp_writer->filename() << " timestamp "
-            << FlatbufferToJson(
-                   flatbuffers::GetSizePrefixedRoot<MessageHeader>(
-                       fbb.GetBufferPointer()));
+            << FlatbufferToJson(flatbuffers::GetSizePrefixedRoot<MessageHeader>(
+                   fbb.GetBufferPointer()));
   }
 }
 
-void Logger::WriteContent(NewDataWriter *contents_writer, const FetcherStruct &f) {
+void Logger::WriteContent(NewDataWriter *contents_writer,
+                          const FetcherStruct &f) {
   if (contents_writer != nullptr) {
     const auto start = event_loop_->monotonic_now();
     // And now handle the special message contents channel.  Copy the
@@ -841,15 +869,16 @@
             ? f.channel_reliable_contents[msg->channel_index()]
             : f.reliable_contents;
 
-    contents_writer->UpdateRemote(node_index_, event_loop_->boot_uuid(),
+    contents_writer->UpdateRemote(
+        node_index_, event_loop_->boot_uuid(),
         monotonic_clock::time_point(
             chrono::nanoseconds(msg->monotonic_remote_time())),
         monotonic_clock::time_point(
             chrono::nanoseconds(msg->monotonic_sent_time())),
         reliable, monotonic_timestamp_time);
 
-    contents_writer->QueueMessage(
-        &fbb, UUID::FromVector(msg->boot_uuid()), end);
+    contents_writer->QueueMessage(&fbb, UUID::FromVector(msg->boot_uuid()),
+                                  end);
   }
 }
 
diff --git a/aos/events/logging/log_writer.h b/aos/events/logging/log_writer.h
index 3c75f29..7beef03 100644
--- a/aos/events/logging/log_writer.h
+++ b/aos/events/logging/log_writer.h
@@ -224,7 +224,10 @@
   // Start/Restart write configuration into LogNamer space.
   std::string WriteConfiguration(LogNamer* log_namer);
 
-  void WriteHeader();
+  void WriteHeader(aos::monotonic_clock::time_point monotonic_start_time =
+                       aos::monotonic_clock::min_time,
+                   aos::realtime_clock::time_point realtime_start_time =
+                       aos::realtime_clock::min_time);
 
   // Makes a template header for all the follower nodes.
   aos::SizePrefixedFlatbufferDetachedBuffer<LogFileHeader> MakeHeader(
diff --git a/aos/network/message_bridge_client_lib.cc b/aos/network/message_bridge_client_lib.cc
index b2bf1e4..b7ee3cc 100644
--- a/aos/network/message_bridge_client_lib.cc
+++ b/aos/network/message_bridge_client_lib.cc
@@ -164,21 +164,24 @@
       case SCTP_ASSOC_CHANGE: {
         const struct sctp_assoc_change *sac = &snp->sn_assoc_change;
         switch (sac->sac_state) {
+          case SCTP_RESTART:
+            NodeDisconnected();
+            [[fallthrough]];
           case SCTP_COMM_UP:
             NodeConnected(sac->sac_assoc_id);
 
             VLOG(1) << "Received up from " << message->PeerAddress() << " on "
-                    << sac->sac_assoc_id;
+                    << sac->sac_assoc_id << " state " << sac->sac_state;
             break;
           case SCTP_COMM_LOST:
           case SCTP_SHUTDOWN_COMP:
           case SCTP_CANT_STR_ASSOC: {
             NodeDisconnected();
             VLOG(1) << "Disconnect from " << message->PeerAddress() << " on "
-                    << sac->sac_assoc_id;
+                    << sac->sac_assoc_id << " state " << sac->sac_state;
           } break;
-          case SCTP_RESTART:
-            LOG(FATAL) << "Never seen this before.";
+          default:
+            LOG(FATAL) << "Never seen state " << sac->sac_state << " before.";
             break;
         }
       } break;
@@ -193,19 +196,26 @@
 }
 
 void SctpClientConnection::SendConnect() {
+  ScheduleConnectTimeout();
+
+  // If we're already connected, assume something went wrong and abort
+  // the connection.  A failed abort is not fatal here, but is worth logging.
+  if (client_status_->GetClientConnection(client_index_)->state() ==
+      aos::message_bridge::State::CONNECTED) {
+    if (!client_.Abort()) LOG(WARNING) << "Failed to abort stale connection.";
+    return;
+  }
   // Try to send the connect message.  If that fails, retry.
   if (client_.Send(kConnectStream(),
                    std::string_view(reinterpret_cast<const char *>(
                                         connect_message_.span().data()),
                                     connect_message_.span().size()),
                    0)) {
-    VLOG(1) << "Connect to " << remote_node_->hostname()->string_view()
+    VLOG(1) << "Sending connect to " << remote_node_->hostname()->string_view()
             << " succeeded.";
-    ScheduleConnectTimeout();
   } else {
     VLOG(1) << "Connect to " << remote_node_->hostname()->string_view()
             << " failed.";
-    NodeDisconnected();
   }
 }
 
@@ -215,11 +225,14 @@
   // We want to tell the kernel to schedule the packets on this new stream with
   // the priority scheduler.  This only needs to be done once per stream.
   client_.SetPriorityScheduler(assoc_id);
+  client_.SetAssociationId(assoc_id);
 
   client_status_->Connect(client_index_);
 }
 
 void SctpClientConnection::NodeDisconnected() {
+  client_.SetAssociationId(0);
+
   connect_timer_->Setup(
       event_loop_->monotonic_now() + chrono::milliseconds(100),
       chrono::milliseconds(100));
diff --git a/aos/network/message_bridge_server_lib.cc b/aos/network/message_bridge_server_lib.cc
index de30af1..f6c0c4d 100644
--- a/aos/network/message_bridge_server_lib.cc
+++ b/aos/network/message_bridge_server_lib.cc
@@ -434,10 +434,13 @@
       case SCTP_ASSOC_CHANGE: {
         const struct sctp_assoc_change *sac = &snp->sn_assoc_change;
         switch (sac->sac_state) {
+          case SCTP_RESTART:
+            NodeDisconnected(sac->sac_assoc_id);
+            [[fallthrough]];
           case SCTP_COMM_UP:
             NodeConnected(sac->sac_assoc_id);
             VLOG(1) << "Received up from " << message->PeerAddress() << " on "
-                    << sac->sac_assoc_id;
+                    << sac->sac_assoc_id << " state " << sac->sac_state;
             break;
           case SCTP_COMM_LOST:
           case SCTP_SHUTDOWN_COMP:
@@ -446,8 +449,8 @@
             VLOG(1) << "Disconnect from " << message->PeerAddress() << " on "
                     << sac->sac_assoc_id << " state " << sac->sac_state;
             break;
-          case SCTP_RESTART:
-            LOG(FATAL) << "Never seen this before.";
+          default:
+            LOG(FATAL) << "Never seen state " << sac->sac_state << " before.";
             break;
         }
       } break;
diff --git a/aos/network/sctp_client.h b/aos/network/sctp_client.h
index 9356612..1ba32ff 100644
--- a/aos/network/sctp_client.h
+++ b/aos/network/sctp_client.h
@@ -26,9 +26,13 @@
   // Sends a block of data on a stream with a TTL.
   // TODO(austin): time_to_live should be a chrono::duration
   bool Send(int stream, std::string_view data, int time_to_live) {
-    return sctp_.SendMessage(stream, data, time_to_live, sockaddr_remote_, 0);
+    return sctp_.SendMessage(stream, data, time_to_live, sockaddr_remote_,
+                             sac_assoc_id_);
   }
 
+  // Aborts a connection.  Returns true on success.
+  bool Abort() { return sctp_.Abort(sac_assoc_id_); }
+
   int fd() { return sctp_.fd(); }
 
   // Enables the priority scheduler.  This is a SCTP feature which lets us
@@ -45,10 +49,17 @@
 
   void SetMaxSize(size_t max_size) { sctp_.SetMaxSize(max_size); }
 
+  void SetAssociationId(sctp_assoc_t sac_assoc_id) {
+    sac_assoc_id_ = sac_assoc_id;
+  }
+
  private:
   struct sockaddr_storage sockaddr_remote_;
   struct sockaddr_storage sockaddr_local_;
   SctpReadWrite sctp_;
+
+  // Valid if != 0.
+  sctp_assoc_t sac_assoc_id_ = 0;
 };
 
 }  // namespace message_bridge
diff --git a/aos/network/sctp_lib.cc b/aos/network/sctp_lib.cc
index 21c0b78..53a28cb 100644
--- a/aos/network/sctp_lib.cc
+++ b/aos/network/sctp_lib.cc
@@ -481,6 +481,48 @@
   }
 }
 
+bool SctpReadWrite::Abort(sctp_assoc_t snd_assoc_id) {
+  if (fd_ == -1) {
+    return true;
+  }
+  VLOG(1) << "Sending abort to assoc " << snd_assoc_id;
+
+  // Zero-initialize so no field the kernel reads is indeterminate.  There is
+  // no payload, and the assoc_id in the sndrcvinfo (rather than msg_name)
+  // selects the destination.
+  struct msghdr outmsg = {};
+
+  // Build up the sndinfo message.  The control buffer must be suitably
+  // aligned for cmsghdr, which a plain char array does not guarantee.
+  alignas(struct cmsghdr) char
+      outcmsg[CMSG_SPACE(sizeof(struct sctp_sndrcvinfo))] = {};
+  outmsg.msg_control = outcmsg;
+  outmsg.msg_controllen = sizeof(outcmsg);
+
+  struct cmsghdr *cmsg = CMSG_FIRSTHDR(&outmsg);
+  cmsg->cmsg_level = IPPROTO_SCTP;
+  cmsg->cmsg_type = SCTP_SNDRCV;
+  cmsg->cmsg_len = CMSG_LEN(sizeof(struct sctp_sndrcvinfo));
+
+  // outcmsg is zeroed, so only the non-zero fields need to be set.
+  struct sctp_sndrcvinfo *sinfo =
+      reinterpret_cast<struct sctp_sndrcvinfo *>(CMSG_DATA(cmsg));
+  sinfo->sinfo_flags = SCTP_ABORT;
+  sinfo->sinfo_assoc_id = snd_assoc_id;
+
+  const ssize_t size = sendmsg(fd_, &outmsg, MSG_NOSIGNAL | MSG_DONTWAIT);
+  if (size == -1) {
+    // EPIPE, EAGAIN and ESHUTDOWN are tolerated quietly (the association is
+    // already gone, or the send would block); anything else gets a warning.
+    if (errno != EPIPE && errno != EAGAIN && errno != ESHUTDOWN) {
+      PLOG(WARNING) << "Failed to send abort to assoc " << snd_assoc_id;
+    }
+    return false;
+  }
+  CHECK_EQ(0, size);
+  return true;
+}
+
 void SctpReadWrite::CloseSocket() {
   if (fd_ == -1) {
     return;
diff --git a/aos/network/sctp_lib.h b/aos/network/sctp_lib.h
index e81e6b3..6cd11a3 100644
--- a/aos/network/sctp_lib.h
+++ b/aos/network/sctp_lib.h
@@ -100,6 +100,9 @@
   // Returns nullptr if the kernel blocks before returning a complete message.
   aos::unique_c_ptr<Message> ReadMessage();
 
+  // Send an abort message for the given association.
+  bool Abort(sctp_assoc_t snd_assoc_id);
+
   int fd() const { return fd_; }
 
   void SetMaxSize(size_t max_size) {
diff --git a/aos/network/sctp_server.cc b/aos/network/sctp_server.cc
index 93d1e88..15da7c0 100644
--- a/aos/network/sctp_server.cc
+++ b/aos/network/sctp_server.cc
@@ -69,43 +69,6 @@
   }
 }
 
-bool SctpServer::Abort(sctp_assoc_t snd_assoc_id) {
-  // Use the assoc_id for the destination instead of the msg_name.
-  struct msghdr outmsg;
-  outmsg.msg_namelen = 0;
-
-  outmsg.msg_iovlen = 0;
-
-  // Build up the sndinfo message.
-  char outcmsg[CMSG_SPACE(sizeof(struct sctp_sndrcvinfo))];
-  outmsg.msg_control = outcmsg;
-  outmsg.msg_controllen = CMSG_SPACE(sizeof(struct sctp_sndrcvinfo));
-  outmsg.msg_flags = 0;
-
-  struct cmsghdr *cmsg = CMSG_FIRSTHDR(&outmsg);
-  cmsg->cmsg_level = IPPROTO_SCTP;
-  cmsg->cmsg_type = SCTP_SNDRCV;
-  cmsg->cmsg_len = CMSG_LEN(sizeof(struct sctp_sndrcvinfo));
-
-  struct sctp_sndrcvinfo *sinfo = (struct sctp_sndrcvinfo *)CMSG_DATA(cmsg);
-  memset(sinfo, 0, sizeof(struct sctp_sndrcvinfo));
-  sinfo->sinfo_stream = 0;
-  sinfo->sinfo_flags = SCTP_ABORT;
-  sinfo->sinfo_assoc_id = snd_assoc_id;
-
-  // And send.
-  const ssize_t size = sendmsg(fd(), &outmsg, MSG_NOSIGNAL | MSG_DONTWAIT);
-  if (size == -1) {
-    if (errno == EPIPE || errno == EAGAIN || errno == ESHUTDOWN) {
-      return false;
-    }
-    return false;
-  } else {
-    CHECK_EQ(0, size);
-    return true;
-  }
-}
-
 void SctpServer::SetPriorityScheduler(sctp_assoc_t assoc_id) {
   struct sctp_assoc_value scheduler;
   memset(&scheduler, 0, sizeof(scheduler));
diff --git a/aos/network/sctp_server.h b/aos/network/sctp_server.h
index 6ffe93e..dbfd1ac 100644
--- a/aos/network/sctp_server.h
+++ b/aos/network/sctp_server.h
@@ -39,7 +39,7 @@
   }
 
   // Aborts a connection.  Returns true on success.
-  bool Abort(sctp_assoc_t snd_assoc_id);
+  bool Abort(sctp_assoc_t snd_assoc_id) { return sctp_.Abort(snd_assoc_id); }
 
   int fd() { return sctp_.fd(); }