Add backpressure to web_proxy

Have the client report back, on each channel, the queue/packet index of
the message it has most recently processed. The server then avoids
sending more than 10000 messages ahead of that point.

Also, fix up some memory management to ensure that data channels
actually get closed/destroyed at the end of a browser session.

Change-Id: Id1795d7496f410332407624a559d6a16a1698702
Signed-off-by: James Kuszmaul <jabukuszmaul@gmail.com>
diff --git a/aos/network/rawrtc.cc b/aos/network/rawrtc.cc
index 89141db..98f2448 100644
--- a/aos/network/rawrtc.cc
+++ b/aos/network/rawrtc.cc
@@ -27,6 +27,12 @@
 
 ScopedDataChannel::ScopedDataChannel() {}
 
+std::shared_ptr<ScopedDataChannel> ScopedDataChannel::MakeDataChannel() {
+  std::shared_ptr<ScopedDataChannel> channel(new ScopedDataChannel());
+  channel->self_ = channel;
+  return channel;
+}
+
 void ScopedDataChannel::Open(struct rawrtc_peer_connection *connection,
                              const std::string &label) {
   label_ = label;
@@ -89,7 +95,7 @@
 
 ScopedDataChannel::~ScopedDataChannel() {
   CHECK(opened_);
-  CHECK(closed_);
+  CHECK(closed_) << ": Never closed " << label();
   CHECK(data_channel_ == nullptr)
       << ": Destroying open data channel " << this << ".";
 }
@@ -129,6 +135,7 @@
     on_close();
   }
   mem_deref(data_channel);
+  client->self_.reset();
 }
 
 void ScopedDataChannel::StaticDataChannelMessageHandler(
@@ -285,7 +292,7 @@
   RawRTCConnection *const client = reinterpret_cast<RawRTCConnection *>(arg);
   if (client->on_data_channel_) {
     std::shared_ptr<ScopedDataChannel> new_channel =
-        std::make_shared<ScopedDataChannel>();
+        ScopedDataChannel::MakeDataChannel();
     new_channel->Open(channel);
     client->on_data_channel_(std::move(new_channel));
   }
diff --git a/aos/network/rawrtc.h b/aos/network/rawrtc.h
index 3f37435..4c61313 100644
--- a/aos/network/rawrtc.h
+++ b/aos/network/rawrtc.h
@@ -44,7 +44,7 @@
 //    on_close callback and shuts down the channel.
 class ScopedDataChannel {
  public:
-  ScopedDataChannel();
+  static std::shared_ptr<ScopedDataChannel> MakeDataChannel();
   ScopedDataChannel(const ScopedDataChannel &) = delete;
   ScopedDataChannel &operator=(const ScopedDataChannel &) = delete;
 
@@ -90,6 +90,7 @@
   uint64_t buffered_amount();
 
  private:
+  ScopedDataChannel();
   // Trampolines from C -> C++.
   static void StaticDataChannelOpenHandler(void *const arg);
   static void StaticBufferedAmountLowHandler(void *const arg);
diff --git a/aos/network/web_proxy.cc b/aos/network/web_proxy.cc
index 94473fb..fdd8f1e 100644
--- a/aos/network/web_proxy.cc
+++ b/aos/network/web_proxy.cc
@@ -18,6 +18,12 @@
 }
 
 DEFINE_int32(proxy_port, 1180, "Port to use for the web proxy server.");
+DEFINE_int32(pre_send_messages, 10000,
+             "Number of messages / queue to send to a client before waiting on "
+             "confirmation that the initial message was received. If set to "
+             "-1, will not throttle messages at all. This prevents a situation "
+             "where, when run on localhost, the large number of WebRTC packets "
+             "can overwhelm the browser and crash the webpage.");
 
 namespace aos {
 namespace web_proxy {
@@ -234,7 +240,8 @@
     }
   }
   for (auto &conn : channels_) {
-    std::shared_ptr<ScopedDataChannel> rtc_channel = conn.first;
+    std::shared_ptr<ScopedDataChannel> rtc_channel = conn.first.lock();
+    CHECK(rtc_channel) << "data_channel was destroyed too early.";
     ChannelInformation *channel_data = &conn.second;
     if (channel_data->transfer_method == TransferMethod::SUBSAMPLE) {
       SkipToLastMessage(channel_data);
@@ -261,7 +268,6 @@
 
 void Subscriber::AddListener(std::shared_ptr<ScopedDataChannel> data_channel,
                              TransferMethod transfer_method) {
-  VLOG(1) << "Adding listener for " << data_channel.get();
   ChannelInformation info;
   info.transfer_method = transfer_method;
 
@@ -271,13 +277,35 @@
     fetcher_->Fetch();
   }
 
-  channels_.emplace(data_channel, info);
+  channels_.emplace_back(std::make_pair(data_channel, info));
+
+  data_channel->set_on_message(
+      [this, index = channels_.size() - 1](
+          struct mbuf *const buffer,
+          const enum rawrtc_data_channel_message_flag /*flags*/) {
+        FlatbufferSpan<ChannelState> message(
+            {mbuf_buf(buffer), mbuf_get_left(buffer)});
+        if (!message.Verify()) {
+          LOG(ERROR) << "Invalid flatbuffer received from browser client.";
+          return;
+        }
+
+        channels_[index].second.reported_queue_index =
+            message.message().queue_index();
+        channels_[index].second.reported_packet_index =
+            message.message().packet_index();
+      });
 }
 
-void Subscriber::RemoveListener(
-    std::shared_ptr<ScopedDataChannel> data_channel) {
-  VLOG(1) << "Removing listener for " << data_channel.get();
-  channels_.erase(data_channel);
+void Subscriber::RemoveListener(std::shared_ptr<ScopedDataChannel> data_channel) {
+  channels_.erase(
+      std::remove_if(
+          channels_.begin(), channels_.end(),
+          [data_channel](const std::pair<std::weak_ptr<ScopedDataChannel>,
+                                         ChannelInformation> &channel) {
+            return channel.first.lock().get() == data_channel.get();
+          }),
+      channels_.end());
 }
 
 std::shared_ptr<struct mbuf> Subscriber::NextBuffer(
@@ -294,10 +322,24 @@
     channel->next_packet_number = 0;
     return message_buffer_.front().data.at(0);
   }
+  // TODO(james): Handle queue index wrapping when >2^32 messages are sent on a
+  // channel.
   if (channel->current_queue_index > latest_index) {
     // We are still waiting on the next message to appear; return.
     return nullptr;
   }
+  if (FLAGS_pre_send_messages > 0) {
+    // Don't buffer up an excessive number of messages to the client.
+    // This currently ignores the packet index (and really, any concept of
+    // message size), but the main goal is just to avoid locking up the client
+    // browser, not to be ultra precise about anything. It's also not clear that
+    // message *size* is necessarily even the determining factor in causing
+    // issues.
+    if (channel->reported_queue_index + FLAGS_pre_send_messages <
+        channel->current_queue_index) {
+      return nullptr;
+    }
+  }
   CHECK_EQ(latest_index - earliest_index + 1, message_buffer_.size())
       << "Inconsistent queue indices.";
   const size_t packets_in_message =
@@ -575,7 +617,7 @@
     auto it = channels_.find(channel_index);
     if (it == channels_.end()) {
       std::shared_ptr<ScopedDataChannel> data_channel =
-          std::make_shared<ScopedDataChannel>();
+          ScopedDataChannel::MakeDataChannel();
 
       std::weak_ptr<ScopedDataChannel> data_channel_weak_ptr = data_channel;
 
@@ -584,9 +626,10 @@
         std::shared_ptr<ScopedDataChannel> data_channel =
             data_channel_weak_ptr.lock();
         CHECK(data_channel) << ": Subscriber got destroyed before we started.";
-        // Weak ptr inside the subscriber so we don't have a circular
+        // Weak pointer inside the subscriber so we don't have a circular
         // reference.  AddListener will close it.
-        subscribers_[channel_index]->AddListener(data_channel, transfer_method);
+        subscribers_[channel_index]->AddListener(data_channel,
+                                                 transfer_method);
       });
 
       Subscriber *subscriber = subscribers_[channel_index].get();
diff --git a/aos/network/web_proxy.fbs b/aos/network/web_proxy.fbs
index 817add1..6c85acb 100644
--- a/aos/network/web_proxy.fbs
+++ b/aos/network/web_proxy.fbs
@@ -72,6 +72,16 @@
   method:TransferMethod (id: 1);
 }
 
+// This is used to communicate the most recently received message by the client.
+// This allows the server to avoid overloading the client (which we've had
+// issues with in the past).
+table ChannelState {
+  // queue_index and packet_index correspond to the similarly named fields in
+  // MessageHeader.
+  queue_index:uint (id: 0);
+  packet_index:uint (id: 1);
+}
+
 table SubscriberRequest {
   // The channels that we want transfered to this client.
   channels_to_transfer:[ChannelRequest] (id: 0);
diff --git a/aos/network/web_proxy.h b/aos/network/web_proxy.h
index 4ad0630..0815ebf 100644
--- a/aos/network/web_proxy.h
+++ b/aos/network/web_proxy.h
@@ -111,8 +111,17 @@
  private:
   struct ChannelInformation {
     TransferMethod transfer_method;
+    // Queue index (same as the queue index within the AOS channel) of the
+    // message that we are currently sending or, if we are between messages,
+    // the next message we will send.
     uint32_t current_queue_index = 0;
+    // Index of the next packet to send within current_queue_index (large
+    // messages are broken into multiple packets, as we have encountered
+    // issues with how some WebRTC implementations handle large packets).
     size_t next_packet_number = 0;
+    // The last queue/packet index reported by the client.
+    uint32_t reported_queue_index = 0;
+    size_t reported_packet_index = 0;
   };
   struct Message {
     uint32_t index = 0xffffffff;
@@ -126,7 +135,21 @@
   int channel_index_;
   int buffer_size_;
   std::deque<Message> message_buffer_;
-  std::map<std::shared_ptr<ScopedDataChannel>, ChannelInformation> channels_;
+  // The ScopedDataChannel that we use for actually sending data over WebRTC
+  // is stored using a weak_ptr because:
+  // (a) There are some dangers of accidentally creating circular dependencies
+  //     that prevent a ScopedDataChannel from ever being destroyed.
+  // (b) The inter-dependencies involved are complicated enough that we want
+  //     to be able to check whether someone has destroyed the ScopedDataChannel
+  //     before using it (if it has been destroyed and the Subscriber still
+  //     wants to use it, that is a bug, but checking for bugs is useful).
+  // This particular location *may* be able to get away with a shared_ptr, but
+  // because the ScopedDataChannel effectively destroys itself (see
+  // ScopedDataChannel::StaticDataChannelCloseHandler) while also potentially
+  // holding references to other objects (e.g., through the various handlers
+  // that can be registered), creating unnecessary shared_ptr's is dubious.
+  std::vector<std::pair<std::weak_ptr<ScopedDataChannel>, ChannelInformation>>
+      channels_;
 };
 
 // Class to manage a WebRTC connection to a browser.
diff --git a/aos/network/www/proxy.ts b/aos/network/www/proxy.ts
index 3da817c..c52f495 100644
--- a/aos/network/www/proxy.ts
+++ b/aos/network/www/proxy.ts
@@ -15,6 +15,7 @@
 import SubscriberRequest = web_proxy.aos.web_proxy.SubscriberRequest;
 import ChannelRequestFb = web_proxy.aos.web_proxy.ChannelRequest;
 import TransferMethod = web_proxy.aos.web_proxy.TransferMethod;
+import ChannelState = web_proxy.aos.web_proxy.ChannelState;
 
 // There is one handler for each DataChannel, it maintains the state of
 // multi-part messages and delegates to a callback when the message is fully
@@ -34,6 +35,16 @@
     const messageHeader = MessageHeader.getRootAsMessageHeader(
         fbBuffer as unknown as flatbuffers.ByteBuffer);
     const time = messageHeader.monotonicSentTime().toFloat64() * 1e-9;
+
+    const stateBuilder = new Builder(512) as unknown as flatbuffers.Builder;
+    ChannelState.startChannelState(stateBuilder);
+    ChannelState.addQueueIndex(stateBuilder, messageHeader.queueIndex());
+    ChannelState.addPacketIndex(stateBuilder, messageHeader.packetIndex());
+    const state = ChannelState.endChannelState(stateBuilder);
+    stateBuilder.finish(state);
+    const stateArray = stateBuilder.asUint8Array();
+    this.channel.send(stateArray);
+
     // Short circuit if only one packet
     if (messageHeader.packetCount() === 1) {
       this.handlerFunc(messageHeader.dataArray(), time);