Handle SCTP messages fragmented by the kernel

The Linux kernel fragments messages under memory pressure, and there is
no obvious way to turn that off. Handle the fragments in userspace
rather than worrying about changing the kernel.

Change-Id: If4f29351e7b2baac2e551d9e4fc5238350b744d5
Signed-off-by: Austin Schuh <austin.schuh@bluerivertech.com>
diff --git a/aos/network/message_bridge_client_lib.cc b/aos/network/message_bridge_client_lib.cc
index 1044b2f..80a41d9 100644
--- a/aos/network/message_bridge_client_lib.cc
+++ b/aos/network/message_bridge_client_lib.cc
@@ -90,7 +90,6 @@
   return fbb.Release();
 }
 
-
 }  // namespace
 
 SctpClientConnection::SctpClientConnection(
@@ -150,6 +149,9 @@
 void SctpClientConnection::MessageReceived() {
   // Dispatch the message to the correct receiver.
   aos::unique_c_ptr<Message> message = client_.Read();
+  if (!message) {
+    return;
+  }
 
   if (message->message_type == Message::kNotification) {
     const union sctp_notification *snp =
@@ -228,9 +230,14 @@
   VLOG(1) << "Got a message of size " << message->size;
   CHECK_EQ(message->size, flatbuffers::GetPrefixedSize(message->data()) +
                               sizeof(flatbuffers::uoffset_t));
+  {
+    flatbuffers::Verifier verifier(message->data(), message->size);
+    CHECK(remote_data->Verify(verifier));
+  }
 
   const int stream = message->header.rcvinfo.rcv_sid - kControlStreams();
-  SctpClientChannelState *channel_state = &((*channels_)[stream_to_channel_[stream]]);
+  SctpClientChannelState *channel_state =
+      &((*channels_)[stream_to_channel_[stream]]);
 
   if (remote_data->queue_index() == channel_state->last_queue_index &&
       monotonic_clock::time_point(
diff --git a/aos/network/message_bridge_server_lib.cc b/aos/network/message_bridge_server_lib.cc
index c929e31..98f5101 100644
--- a/aos/network/message_bridge_server_lib.cc
+++ b/aos/network/message_bridge_server_lib.cc
@@ -329,7 +329,6 @@
 
     if (configuration::ChannelIsSendableOnNode(channel, event_loop_->node()) &&
         channel->has_destination_nodes()) {
-
       bool any_reliable = false;
       for (const Connection *connection : *channel->destination_nodes()) {
         if (connection->time_to_live() == 0) {
@@ -432,6 +431,9 @@
 
 void MessageBridgeServer::MessageReceived() {
   aos::unique_c_ptr<Message> message = server_.Read();
+  if (!message) {
+    return;
+  }
 
   if (message->message_type == Message::kNotification) {
     const union sctp_notification *snp =
@@ -472,6 +474,10 @@
   if (message->header.rcvinfo.rcv_sid == kConnectStream()) {
     // Control channel!
     const Connect *connect = flatbuffers::GetRoot<Connect>(message->data());
+    {
+      flatbuffers::Verifier verifier(message->data(), message->size);
+      CHECK(connect->Verify(verifier));
+    }
     VLOG(1) << FlatbufferToJson(connect);
 
     // Account for the control channel and delivery times channel.
@@ -516,6 +522,10 @@
     // Message delivery
     const logger::MessageHeader *message_header =
         flatbuffers::GetRoot<logger::MessageHeader>(message->data());
+    {
+      flatbuffers::Verifier verifier(message->data(), message->size);
+      CHECK(message_header->Verify(verifier));
+    }
 
     CHECK_LT(message_header->channel_index(), channels_.size());
     CHECK_NOTNULL(channels_[message_header->channel_index()])
diff --git a/aos/network/sctp_client.cc b/aos/network/sctp_client.cc
index 113b525..33115be 100644
--- a/aos/network/sctp_client.cc
+++ b/aos/network/sctp_client.cc
@@ -38,17 +38,6 @@
     PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_NODELAY, &on, sizeof(int)) == 0);
   }
 
-  {
-    // TODO(austin): This is the old style registration...  But, the sctp
-    // stack out in the wild for linux is old and primitive.
-    struct sctp_event_subscribe subscribe;
-    memset(&subscribe, 0, sizeof(subscribe));
-    subscribe.sctp_association_event = 1;
-    subscribe.sctp_stream_change_event = 1;
-    PCHECK(setsockopt(fd(), SOL_SCTP, SCTP_EVENTS, (char *)&subscribe,
-                      sizeof(subscribe)) == 0);
-  }
-
   PCHECK(bind(fd(), (struct sockaddr *)&sockaddr_local_,
               sockaddr_local_.ss_family == AF_INET6
                   ? sizeof(struct sockaddr_in6)
diff --git a/aos/network/sctp_lib.cc b/aos/network/sctp_lib.cc
index 07bc86c..4651d82 100644
--- a/aos/network/sctp_lib.cc
+++ b/aos/network/sctp_lib.cc
@@ -8,6 +8,7 @@
 #include <sys/types.h>
 #include <unistd.h>
 
+#include <algorithm>
 #include <string_view>
 
 #include "aos/util/file.h"
@@ -193,10 +194,14 @@
   {
     // Per https://tools.ietf.org/html/rfc6458
     // Setting this to !0 allows event notifications to be interleaved
-    // with data if enabled, and would have to be handled in the code.
-    // Enabling interleaving would only matter during congestion, which
-    // typically only happens during application startup.
-    int interleaving = 0;
+    // with data if enabled. This typically only matters during congestion.
+    // However, Linux seems to interleave under memory pressure regardless of
+    // this setting, so we have to handle it in the code anyway. We might as
+    // well turn it on all the time.
+    // TODO(Brian): Change this to 2 once we have kernels that support it, and
+    // also address the TODO in ProcessNotification to match on all the
+    // necessary fields.
+    int interleaving = 1;
     PCHECK(setsockopt(fd_, IPPROTO_SCTP, SCTP_FRAGMENT_INTERLEAVE,
                       &interleaving, sizeof(interleaving)) == 0);
   }
@@ -207,6 +212,18 @@
            0);
   }
 
+  {
+    // TODO(austin): This is the old style registration...  But, the sctp
+    // stack out in the wild for linux is old and primitive.
+    struct sctp_event_subscribe subscribe;
+    memset(&subscribe, 0, sizeof(subscribe));
+    subscribe.sctp_association_event = 1;
+    subscribe.sctp_stream_change_event = 1;
+    subscribe.sctp_partial_delivery_event = 1;
+    PCHECK(setsockopt(fd(), SOL_SCTP, SCTP_EVENTS, (char *)&subscribe,
+                      sizeof(subscribe)) == 0);
+  }
+
   DoSetMaxSize();
 }
 
@@ -268,21 +285,21 @@
   return true;
 }
 
+// We read each fragment into a fresh Message, because most messages won't be
+// fragmented. If we do end up with a fragment, then we copy the data out of it.
 aos::unique_c_ptr<Message> SctpReadWrite::ReadMessage() {
   CHECK(fd_ != -1);
-  aos::unique_c_ptr<Message> result(
-      reinterpret_cast<Message *>(malloc(sizeof(Message) + max_size_ + 1)));
-  result->size = 0;
 
-  int count = 0;
-  int last_flags = 0;
-  for (count = 0; !(last_flags & MSG_EOR); count++) {
+  while (true) {
+    aos::unique_c_ptr<Message> result(
+        reinterpret_cast<Message *>(malloc(sizeof(Message) + max_size_ + 1)));
+
     struct msghdr inmessage;
     memset(&inmessage, 0, sizeof(struct msghdr));
 
     struct iovec iov;
-    iov.iov_len = max_size_ + 1 - result->size;
-    iov.iov_base = result->mutable_data() + result->size;
+    iov.iov_len = max_size_ + 1;
+    iov.iov_base = result->mutable_data();
 
     inmessage.msg_iov = &iov;
     inmessage.msg_iovlen = 1;
@@ -294,63 +311,127 @@
     inmessage.msg_namelen = sizeof(struct sockaddr_storage);
     inmessage.msg_name = &result->sin;
 
-    ssize_t size;
-    PCHECK((size = recvmsg(fd_, &inmessage, 0)) > 0);
-
-    if (count > 0) {
-      VLOG(1) << "Count: " << count;
-      VLOG(1) << "Last msg_flags: " << last_flags;
-      VLOG(1) << "msg_flags: " << inmessage.msg_flags;
-      VLOG(1) << "Current size: " << result->size;
-      VLOG(1) << "Received size: " << size;
-      CHECK_EQ(MSG_NOTIFICATION & inmessage.msg_flags,
-               MSG_NOTIFICATION & last_flags);
-    }
-
-    result->size += size;
-    last_flags = inmessage.msg_flags;
-
-    for (struct cmsghdr *scmsg = CMSG_FIRSTHDR(&inmessage); scmsg != NULL;
-         scmsg = CMSG_NXTHDR(&inmessage, scmsg)) {
-      switch (scmsg->cmsg_type) {
-        case SCTP_RCVINFO: {
-          struct sctp_rcvinfo *data =
-              reinterpret_cast<struct sctp_rcvinfo *>(CMSG_DATA(scmsg));
-          if (count > 0) {
-            VLOG(1) << "Got sctp_rcvinfo on continued packet";
-            CHECK_EQ(result->header.rcvinfo.rcv_sid, data->rcv_sid);
-            CHECK_EQ(result->header.rcvinfo.rcv_ssn, data->rcv_ssn);
-            CHECK_EQ(result->header.rcvinfo.rcv_ppid, data->rcv_ppid);
-            CHECK_EQ(result->header.rcvinfo.rcv_assoc_id, data->rcv_assoc_id);
-          }
-          result->header.rcvinfo = *data;
-        } break;
-        default:
-          LOG(INFO) << "\tUnknown type: " << scmsg->cmsg_type;
-          break;
+    const ssize_t size = recvmsg(fd_, &inmessage, MSG_DONTWAIT);
+    if (size == -1) {
+      if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK) {
+        // These are all non-fatal failures indicating we should retry later.
+        return nullptr;
       }
+      PLOG(FATAL) << "recvmsg on sctp socket " << fd_ << " failed";
     }
 
-    CHECK_NE(last_flags & MSG_CTRUNC, MSG_CTRUNC)
+    CHECK(!(inmessage.msg_flags & MSG_CTRUNC))
         << ": Control message truncated.";
 
-    CHECK_LE(result->size, max_size_)
+    CHECK_LE(size, static_cast<ssize_t>(max_size_))
         << ": Message overflowed buffer on stream "
         << result->header.rcvinfo.rcv_sid << ".";
-  }
 
-  result->partial_deliveries = count - 1;
-  if (count > 1) {
-    VLOG(1) << "Final count: " << count;
-    VLOG(1) << "Final size: " << result->size;
-  }
+    result->size = size;
+    if (MSG_NOTIFICATION & inmessage.msg_flags) {
+      result->message_type = Message::kNotification;
+    } else {
+      result->message_type = Message::kMessage;
+    }
+    result->partial_deliveries = 0;
 
-  if ((MSG_NOTIFICATION & last_flags)) {
-    result->message_type = Message::kNotification;
-  } else {
-    result->message_type = Message::kMessage;
+    {
+      bool found_rcvinfo = false;
+      for (struct cmsghdr *scmsg = CMSG_FIRSTHDR(&inmessage); scmsg != NULL;
+           scmsg = CMSG_NXTHDR(&inmessage, scmsg)) {
+        switch (scmsg->cmsg_type) {
+          case SCTP_RCVINFO: {
+            CHECK(!found_rcvinfo);
+            found_rcvinfo = true;
+            result->header.rcvinfo =
+                *reinterpret_cast<struct sctp_rcvinfo *>(CMSG_DATA(scmsg));
+          } break;
+          default:
+            LOG(INFO) << "\tUnknown type: " << scmsg->cmsg_type;
+            break;
+        }
+      }
+      CHECK_EQ(found_rcvinfo, result->message_type == Message::kMessage)
+          << ": Failed to find a SCTP_RCVINFO cmsghdr. flags: "
+          << inmessage.msg_flags;
+    }
+    if (result->message_type == Message::kNotification) {
+      // Notifications are never fragmented, just return it now.
+      CHECK(inmessage.msg_flags & MSG_EOR)
+          << ": Notifications should never be big enough to fragment";
+      if (ProcessNotification(result.get())) {
+        // We handled this notification internally, so don't pass it on.
+        return nullptr;
+      }
+      return result;
+    }
+
+    auto partial_message_iterator =
+        std::find_if(partial_messages_.begin(), partial_messages_.end(),
+                     [&result](const aos::unique_c_ptr<Message> &candidate) {
+                       return result->header.rcvinfo.rcv_sid ==
+                                  candidate->header.rcvinfo.rcv_sid &&
+                              result->header.rcvinfo.rcv_ssn ==
+                                  candidate->header.rcvinfo.rcv_ssn &&
+                              result->header.rcvinfo.rcv_assoc_id ==
+                                  candidate->header.rcvinfo.rcv_assoc_id;
+                     });
+    if (partial_message_iterator != partial_messages_.end()) {
+      const aos::unique_c_ptr<Message> &partial_message =
+          *partial_message_iterator;
+      // Verify it's really part of the same message.
+      CHECK_EQ(partial_message->message_type, result->message_type)
+          << ": for " << result->header.rcvinfo.rcv_sid << ","
+          << result->header.rcvinfo.rcv_ssn << ","
+          << result->header.rcvinfo.rcv_assoc_id;
+      CHECK_EQ(partial_message->header.rcvinfo.rcv_ppid,
+               result->header.rcvinfo.rcv_ppid)
+          << ": for " << result->header.rcvinfo.rcv_sid << ","
+          << result->header.rcvinfo.rcv_ssn << ","
+          << result->header.rcvinfo.rcv_assoc_id;
+
+      // Now copy the data over and update the size.
+      CHECK_LE(partial_message->size + result->size, max_size_)
+          << ": Assembled fragments overflowed buffer on stream "
+          << result->header.rcvinfo.rcv_sid << ".";
+      memcpy(partial_message->mutable_data() + partial_message->size,
+             result->data(), result->size);
+      ++partial_message->partial_deliveries;
+      VLOG(1) << "Merged fragment of " << result->size << " after "
+              << partial_message->size << ", had "
+              << partial_message->partial_deliveries
+              << ", for: " << result->header.rcvinfo.rcv_sid << ","
+              << result->header.rcvinfo.rcv_ssn << ","
+              << result->header.rcvinfo.rcv_assoc_id;
+      partial_message->size += result->size;
+      result.reset();
+    }
+
+    if (inmessage.msg_flags & MSG_EOR) {
+      // This is the last fragment, so we have something to return.
+      if (partial_message_iterator != partial_messages_.end()) {
+        // It was already merged into the message in the list, so now we pull
+        // that out of the list and return it.
+        CHECK(!result);
+        result = std::move(*partial_message_iterator);
+        partial_messages_.erase(partial_message_iterator);
+        VLOG(1) << "Final count: " << (result->partial_deliveries + 1)
+                << ", size: " << result->size
+                << ", for: " << result->header.rcvinfo.rcv_sid << ","
+                << result->header.rcvinfo.rcv_ssn << ","
+                << result->header.rcvinfo.rcv_assoc_id;
+      }
+      CHECK(result);
+      return result;
+    }
+    if (partial_message_iterator == partial_messages_.end()) {
+      VLOG(1) << "Starting fragment for: " << result->header.rcvinfo.rcv_sid
+              << "," << result->header.rcvinfo.rcv_ssn << ","
+              << result->header.rcvinfo.rcv_assoc_id;
+      // Need to record this as the first fragment.
+      partial_messages_.emplace_back(std::move(result));
+    }
   }
-  return result;
 }
 
 void SctpReadWrite::CloseSocket() {
@@ -381,6 +462,46 @@
          0);
 }
 
+bool SctpReadWrite::ProcessNotification(const Message *message) {
+  const union sctp_notification *const snp =
+      reinterpret_cast<const union sctp_notification *>(message->data());
+  switch (snp->sn_header.sn_type) {
+    case SCTP_PARTIAL_DELIVERY_EVENT: {
+      const struct sctp_pdapi_event *const partial_delivery =
+          &snp->sn_pdapi_event;
+      CHECK_EQ(partial_delivery->pdapi_length, sizeof(*partial_delivery))
+          << ": Kernel's SCTP code is not a version we support";
+      switch (partial_delivery->pdapi_indication) {
+        case SCTP_PARTIAL_DELIVERY_ABORTED: {
+          const auto iterator = std::find_if(
+              partial_messages_.begin(), partial_messages_.end(),
+              [partial_delivery](const aos::unique_c_ptr<Message> &candidate) {
+                // TODO(Brian): Once we have new enough userspace headers, for
+                // kernels that support level-2 interleaving, we'll need to add
+                // this:
+                //   candidate->header.rcvinfo.rcv_sid ==
+                //     partial_delivery->pdapi_stream &&
+                //   candidate->header.rcvinfo.rcv_ssn ==
+                //     partial_delivery->pdapi_seq &&
+                return candidate->header.rcvinfo.rcv_assoc_id ==
+                       partial_delivery->pdapi_assoc_id;
+              });
+          CHECK(iterator != partial_messages_.end())
+              << ": Got out of sync with the kernel for "
+              << partial_delivery->pdapi_assoc_id;
+          VLOG(1) << "Pruning partial delivery for "
+                  << iterator->get()->header.rcvinfo.rcv_sid << ","
+                  << iterator->get()->header.rcvinfo.rcv_ssn << ","
+                  << iterator->get()->header.rcvinfo.rcv_assoc_id;
+          partial_messages_.erase(iterator);
+        }
+          return true;
+      }
+    } break;
+  }
+  return false;
+}
+
 void Message::LogRcvInfo() const {
   LOG(INFO) << "\tSNDRCV (stream=" << header.rcvinfo.rcv_sid
             << " ssn=" << header.rcvinfo.rcv_ssn
diff --git a/aos/network/sctp_lib.h b/aos/network/sctp_lib.h
index 427f828..6b40600 100644
--- a/aos/network/sctp_lib.h
+++ b/aos/network/sctp_lib.h
@@ -8,6 +8,7 @@
 #include <optional>
 #include <string>
 #include <string_view>
+#include <vector>
 
 #include "aos/unique_malloc_ptr.h"
 #include "gflags/gflags.h"
@@ -94,6 +95,9 @@
   int fd() const { return fd_; }
 
   void SetMaxSize(size_t max_size) {
+    CHECK(partial_messages_.empty())
+        << ": May not update size with queued fragments because we do not "
+           "track individual message sizes";
     max_size_ = max_size;
     if (fd_ != -1) {
       DoSetMaxSize();
@@ -104,12 +108,18 @@
   void CloseSocket();
   void DoSetMaxSize();
 
+  // Examines a notification message for ones we handle here.
+  // Returns true if the notification was handled by this class.
+  bool ProcessNotification(const Message *message);
+
   int fd_ = -1;
 
   // We use this as a unique identifier that just increments for each message.
   uint32_t send_ppid_ = 0;
 
   size_t max_size_ = 1000;
+
+  std::vector<aos::unique_c_ptr<Message>> partial_messages_;
 };
 
 // Returns the max network buffer available for reading for a socket.
diff --git a/aos/network/sctp_server.cc b/aos/network/sctp_server.cc
index 327f5a0..894a1f1 100644
--- a/aos/network/sctp_server.cc
+++ b/aos/network/sctp_server.cc
@@ -26,16 +26,6 @@
     sctp_.OpenSocket(sockaddr_local_);
 
     {
-      struct sctp_event_subscribe subscribe;
-      memset(&subscribe, 0, sizeof(subscribe));
-      subscribe.sctp_association_event = 1;
-      subscribe.sctp_send_failure_event = 1;
-      subscribe.sctp_partial_delivery_event = 1;
-
-      PCHECK(setsockopt(fd(), SOL_SCTP, SCTP_EVENTS, (char *)&subscribe,
-                        sizeof(subscribe)) == 0);
-    }
-    {
       // Turn off the NAGLE algorithm.
       int on = 1;
       PCHECK(setsockopt(fd(), IPPROTO_SCTP, SCTP_NODELAY, &on, sizeof(int)) ==