Move ReadSctpMessage to a class

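This is in preparation for reassembling partial messages in userspace.
SctpReadWrite now owns the socket: it opens it, sizes its buffers, sends
and reads messages on it, and closes it.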

Change-Id: Ifa530698058ea775362eee4ec1bf9e6e0d3dd5de
Signed-off-by: Austin Schuh <austin.schuh@bluerivertech.com>
diff --git a/aos/network/sctp_lib.cc b/aos/network/sctp_lib.cc
index c5e2080..07bc86c 100644
--- a/aos/network/sctp_lib.cc
+++ b/aos/network/sctp_lib.cc
@@ -185,9 +185,93 @@
             << " sstat_primary.spinfo_rto:" << status.sstat_primary.spinfo_rto;
 }
 
-aos::unique_c_ptr<Message> ReadSctpMessage(int fd, size_t max_size) {
+// Creates and configures the SCTP socket, then sizes its kernel buffers.
+void SctpReadWrite::OpenSocket(const struct sockaddr_storage &sockaddr_local) {
+  fd_ = socket(sockaddr_local.ss_family, SOCK_SEQPACKET, IPPROTO_SCTP);
+  PCHECK(fd_ != -1);
+  LOG(INFO) << "socket(" << Family(sockaddr_local)
+            << ", SOCK_SEQPACKET, IPPROTO_SCTP) = " << fd_;
+  {
+    // Per https://tools.ietf.org/html/rfc6458, setting this to !0 allows
+    // event notifications (if enabled) to be interleaved with data, which the
+    // code would then have to handle.  Interleaving only matters during
+    // congestion, which typically only happens during application startup.
+    const int interleaving = 0;
+    PCHECK(setsockopt(fd_, IPPROTO_SCTP, SCTP_FRAGMENT_INTERLEAVE,
+                      &interleaving, sizeof(interleaving)) == 0);
+  }
+  {
+    // Enable recvinfo when a packet arrives.
+    const int on = 1;
+    PCHECK(setsockopt(fd_, IPPROTO_SCTP, SCTP_RECVRCVINFO, &on, sizeof(on)) ==
+           0);
+  }
+
+  DoSetMaxSize();
+}
+
+// Sends data as a single SCTP message on stream, addressed to sockaddr_remote
+// when set and tagged with snd_assoc_id.  Returns false on EPIPE, EAGAIN,
+// ESHUTDOWN or EINTR; any other sendmsg failure is fatal.
+bool SctpReadWrite::SendMessage(
+    int stream, std::string_view data, int time_to_live,
+    std::optional<struct sockaddr_storage> sockaddr_remote,
+    sctp_assoc_t snd_assoc_id) {
+  CHECK(fd_ != -1);
+  struct iovec iov;
+  iov.iov_base = const_cast<char *>(data.data());
+  iov.iov_len = data.size();
+
+  // Zero-initialize so msg_name and msg_flags are never stack garbage.
+  struct msghdr outmsg = {};
+  if (sockaddr_remote) {
+    outmsg.msg_name = &*sockaddr_remote;
+    outmsg.msg_namelen = sizeof(*sockaddr_remote);
+    VLOG(1) << "Sending to " << Address(*sockaddr_remote);
+  }
+
+  // Data to send.
+  outmsg.msg_iov = &iov;
+  outmsg.msg_iovlen = 1;
+
+  // Build up the sndinfo message.  The control buffer must be aligned for
+  // cmsghdr, which a plain char array does not guarantee, and is zeroed so
+  // every sctp_sndrcvinfo field not set below (e.g. sinfo_flags) is 0.
+  alignas(struct cmsghdr) char
+      outcmsg[CMSG_SPACE(sizeof(struct sctp_sndrcvinfo))] = {};
+  outmsg.msg_control = outcmsg;
+  outmsg.msg_controllen = sizeof(outcmsg);
+
+  struct cmsghdr *cmsg = CMSG_FIRSTHDR(&outmsg);
+  cmsg->cmsg_level = IPPROTO_SCTP;
+  cmsg->cmsg_type = SCTP_SNDRCV;
+  cmsg->cmsg_len = CMSG_LEN(sizeof(struct sctp_sndrcvinfo));
+
+  struct sctp_sndrcvinfo *sinfo =
+      reinterpret_cast<struct sctp_sndrcvinfo *>(CMSG_DATA(cmsg));
+  sinfo->sinfo_ppid = ++send_ppid_;
+  sinfo->sinfo_stream = stream;
+  sinfo->sinfo_assoc_id = snd_assoc_id;
+  sinfo->sinfo_timetolive = time_to_live;
+
+  const ssize_t size = sendmsg(fd_, &outmsg, MSG_NOSIGNAL | MSG_DONTWAIT);
+  if (size == -1) {
+    if (errno == EPIPE || errno == EAGAIN || errno == ESHUTDOWN ||
+        errno == EINTR) {
+      return false;
+    }
+    PLOG(FATAL) << "sendmsg on sctp socket failed";
+    return false;
+  }
+  CHECK_EQ(static_cast<ssize_t>(data.size()), size);
+  VLOG(1) << "Sent " << data.size();
+  return true;
+}
+
+aos::unique_c_ptr<Message> SctpReadWrite::ReadMessage() {
+  CHECK(fd_ != -1);
   aos::unique_c_ptr<Message> result(
-      reinterpret_cast<Message *>(malloc(sizeof(Message) + max_size + 1)));
+      reinterpret_cast<Message *>(malloc(sizeof(Message) + max_size_ + 1)));
   result->size = 0;
 
   int count = 0;
@@ -197,7 +281,7 @@
     memset(&inmessage, 0, sizeof(struct msghdr));
 
     struct iovec iov;
-    iov.iov_len = max_size + 1 - result->size;
+    iov.iov_len = max_size_ + 1 - result->size;
     iov.iov_base = result->mutable_data() + result->size;
 
     inmessage.msg_iov = &iov;
@@ -211,7 +295,7 @@
     inmessage.msg_name = &result->sin;
 
     ssize_t size;
-    PCHECK((size = recvmsg(fd, &inmessage, 0)) > 0);
+    PCHECK((size = recvmsg(fd_, &inmessage, 0)) > 0);
 
     if (count > 0) {
       VLOG(1) << "Count: " << count;
@@ -250,8 +334,9 @@
     CHECK_NE(last_flags & MSG_CTRUNC, MSG_CTRUNC)
         << ": Control message truncated.";
 
-    CHECK_LE(result->size, max_size) << ": Message overflowed buffer on stream "
-                                     << result->header.rcvinfo.rcv_sid << ".";
+    CHECK_LE(result->size, max_size_)
+        << ": Message overflowed buffer on stream "
+        << result->header.rcvinfo.rcv_sid << ".";
   }
 
   result->partial_deliveries = count - 1;
@@ -268,6 +353,34 @@
   return result;
 }
 
+void SctpReadWrite::CloseSocket() {
+  // Closing a socket that was never opened (or is already closed) is a no-op.
+  if (fd_ != -1) {
+    LOG(INFO) << "close(" << fd_ << ")";
+    PCHECK(close(fd_) == 0);
+    fd_ = -1;
+  }
+}
+
+void SctpReadWrite::DoSetMaxSize() {
+  // Have the kernel give us a factor of 10 more so more than one full sized
+  // packet can be in flight.  The socket options take an int, not a size_t.
+  const size_t max_size = max_size_ * 10;
+  CHECK_GE(ReadRMemMax(), max_size)
+      << "rmem_max is too low. To increase rmem_max temporarily, do sysctl "
+         "-w net.core.rmem_max="
+      << max_size;
+  CHECK_GE(ReadWMemMax(), max_size)
+      << "wmem_max is too low. To increase wmem_max temporarily, do sysctl "
+         "-w net.core.wmem_max="
+      << max_size;
+  const int buffer_size = static_cast<int>(max_size);
+  PCHECK(setsockopt(fd(), SOL_SOCKET, SO_RCVBUF, &buffer_size,
+                    sizeof(buffer_size)) == 0);
+  PCHECK(setsockopt(fd(), SOL_SOCKET, SO_SNDBUF, &buffer_size,
+                    sizeof(buffer_size)) == 0);
+}
+
 void Message::LogRcvInfo() const {
   LOG(INFO) << "\tSNDRCV (stream=" << header.rcvinfo.rcv_sid
             << " ssn=" << header.rcvinfo.rcv_ssn