aos/sctp: reassemble messages split across multiple recvmsg calls
Change-Id: I9929c49559f28b5595c8343f1c8e244ce37f7c15
diff --git a/aos/network/sctp_lib.cc b/aos/network/sctp_lib.cc
index 318c5a7..b522851 100644
--- a/aos/network/sctp_lib.cc
+++ b/aos/network/sctp_lib.cc
@@ -180,57 +180,85 @@
<< " sstat_primary.spinfo_rto:" << status.sstat_primary.spinfo_rto;
}
-aos::unique_c_ptr<Message> ReadSctpMessage(int fd, int max_size) {
+aos::unique_c_ptr<Message> ReadSctpMessage(int fd, size_t max_size) {
char incmsg[CMSG_SPACE(sizeof(_sctp_cmsg_data_t))];
struct iovec iov;
struct msghdr inmessage;
- memset(&inmessage, 0, sizeof(struct msghdr));
-
aos::unique_c_ptr<Message> result(
reinterpret_cast<Message *>(malloc(sizeof(Message) + max_size + 1)));
+ result->size = 0;
- iov.iov_len = max_size + 1;
- iov.iov_base = result->mutable_data();
+ int count = 0;
+ int last_flags = 0;
+ for (count = 0; !(last_flags & MSG_EOR); count++) {
+ memset(&inmessage, 0, sizeof(struct msghdr));
- inmessage.msg_iov = &iov;
- inmessage.msg_iovlen = 1;
+ iov.iov_len = max_size + 1 - result->size;
+ iov.iov_base = result->mutable_data() + result->size;
- inmessage.msg_control = incmsg;
- inmessage.msg_controllen = sizeof(incmsg);
+ inmessage.msg_iov = &iov;
+ inmessage.msg_iovlen = 1;
- inmessage.msg_namelen = sizeof(struct sockaddr_storage);
- inmessage.msg_name = &result->sin;
+ inmessage.msg_control = incmsg;
+ inmessage.msg_controllen = sizeof(incmsg);
- ssize_t size;
- PCHECK((size = recvmsg(fd, &inmessage, 0)) > 0);
+ inmessage.msg_namelen = sizeof(struct sockaddr_storage);
+ inmessage.msg_name = &result->sin;
- result->size = size;
+ ssize_t size;
+ PCHECK((size = recvmsg(fd, &inmessage, 0)) > 0);
+
+ if (count > 0) {
+ VLOG(1) << "Count: " << count;
+ VLOG(1) << "Last msg_flags: " << last_flags;
+ VLOG(1) << "msg_flags: " << inmessage.msg_flags;
+ VLOG(1) << "Current size: " << result->size;
+ VLOG(1) << "Received size: " << size;
+ CHECK_EQ(MSG_NOTIFICATION & inmessage.msg_flags, MSG_NOTIFICATION & last_flags);
+ }
+
+ result->size += size;
+ last_flags = inmessage.msg_flags;
+
+ for (struct cmsghdr *scmsg = CMSG_FIRSTHDR(&inmessage); scmsg != NULL;
+ scmsg = CMSG_NXTHDR(&inmessage, scmsg)) {
+ switch (scmsg->cmsg_type) {
+ case SCTP_RCVINFO: {
+ struct sctp_rcvinfo *data = reinterpret_cast<struct sctp_rcvinfo *>(CMSG_DATA(scmsg));
+ if (count > 0) {
+ VLOG(1) << "Got sctp_rcvinfo on continued packet";
+ CHECK_EQ(result->header.rcvinfo.rcv_sid, data->rcv_sid);
+ CHECK_EQ(result->header.rcvinfo.rcv_ssn, data->rcv_ssn);
+ CHECK_EQ(result->header.rcvinfo.rcv_ppid, data->rcv_ppid);
+ CHECK_EQ(result->header.rcvinfo.rcv_assoc_id, data->rcv_assoc_id);
+ }
+ result->header.rcvinfo = *data;
+ } break;
+ default:
+ LOG(INFO) << "\tUnknown type: " << scmsg->cmsg_type;
+ break;
+ }
+ }
+
+ CHECK_NE(inmessage.msg_flags & MSG_CTRUNC, MSG_CTRUNC)
+ << ": Control message truncated.";
+
+ CHECK_LE(result->size, max_size) << ": Message overflowed buffer on stream "
+ << result->header.rcvinfo.rcv_sid << ".";
+ }
+
+ result->partial_deliveries = count - 1;
+ if (count > 1) {
+ VLOG(1) << "Final count: " << count;
+ VLOG(1) << "Final size: " << result->size;
+ }
+
if ((MSG_NOTIFICATION & inmessage.msg_flags)) {
result->message_type = Message::kNotification;
} else {
result->message_type = Message::kMessage;
}
-
- for (struct cmsghdr *scmsg = CMSG_FIRSTHDR(&inmessage); scmsg != NULL;
- scmsg = CMSG_NXTHDR(&inmessage, scmsg)) {
- switch (scmsg->cmsg_type) {
- case SCTP_RCVINFO: {
- struct sctp_rcvinfo *data = (struct sctp_rcvinfo *)CMSG_DATA(scmsg);
- result->header.rcvinfo = *data;
- } break;
- default:
- LOG(INFO) << "\tUnknown type: " << scmsg->cmsg_type;
- break;
- }
- }
-
- CHECK_NE(inmessage.msg_flags & MSG_CTRUNC, MSG_CTRUNC)
- << ": Control message truncated.";
-
- CHECK_LE(size, max_size) << ": Message overflowed buffer on stream "
- << result->header.rcvinfo.rcv_sid << ".";
-
return result;
}