Add backpressure to web_proxy

Have the client send back the queue/packet index of the latest message
it has processed on a given channel. The server then avoids sending
more than 10000 messages (configurable via --pre_send_messages) ahead
of that point.
Also, fix up some memory management to ensure that data channels
actually get closed/destroyed at the end of a browser session.
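
In sketch form, the server-side gating (see Subscriber::NextBuffer in
web_proxy.cc below) amounts to:

  if (FLAGS_pre_send_messages > 0 &&
      channel->reported_queue_index + FLAGS_pre_send_messages <
          channel->current_queue_index) {
    // The client is too far behind; wait for it to report progress
    // before sending anything else on this channel.
    return nullptr;
  }

The client reports its progress by sending a ChannelState flatbuffer
(the queue_index and packet_index of the last message it received) back
over the same data channel.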
Change-Id: Id1795d7496f410332407624a559d6a16a1698702
Signed-off-by: James Kuszmaul <jabukuszmaul@gmail.com>
diff --git a/aos/network/rawrtc.cc b/aos/network/rawrtc.cc
index 89141db..98f2448 100644
--- a/aos/network/rawrtc.cc
+++ b/aos/network/rawrtc.cc
@@ -27,6 +27,12 @@
ScopedDataChannel::ScopedDataChannel() {}
+std::shared_ptr<ScopedDataChannel> ScopedDataChannel::MakeDataChannel() {
+ std::shared_ptr<ScopedDataChannel> channel(new ScopedDataChannel());
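+ // Hold a self-reference so the channel stays alive until it is actually
+ // closed; StaticDataChannelCloseHandler resets self_ once that happens.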
+ channel->self_ = channel;
+ return channel;
+}
+
void ScopedDataChannel::Open(struct rawrtc_peer_connection *connection,
const std::string &label) {
label_ = label;
@@ -89,7 +95,7 @@
ScopedDataChannel::~ScopedDataChannel() {
CHECK(opened_);
- CHECK(closed_);
+ CHECK(closed_) << ": Never closed " << label();
CHECK(data_channel_ == nullptr)
<< ": Destroying open data channel " << this << ".";
}
@@ -129,6 +135,7 @@
on_close();
}
mem_deref(data_channel);
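+ // Drop the self-reference taken in MakeDataChannel(); if nothing else
+ // holds a reference, the ScopedDataChannel is destroyed here.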
+ client->self_.reset();
}
void ScopedDataChannel::StaticDataChannelMessageHandler(
@@ -285,7 +292,7 @@
RawRTCConnection *const client = reinterpret_cast<RawRTCConnection *>(arg);
if (client->on_data_channel_) {
std::shared_ptr<ScopedDataChannel> new_channel =
- std::make_shared<ScopedDataChannel>();
+ ScopedDataChannel::MakeDataChannel();
new_channel->Open(channel);
client->on_data_channel_(std::move(new_channel));
}
diff --git a/aos/network/rawrtc.h b/aos/network/rawrtc.h
index 3f37435..4c61313 100644
--- a/aos/network/rawrtc.h
+++ b/aos/network/rawrtc.h
@@ -44,7 +44,7 @@
// on_close callback and shuts down the channel.
class ScopedDataChannel {
public:
- ScopedDataChannel();
+ static std::shared_ptr<ScopedDataChannel> MakeDataChannel();
ScopedDataChannel(const ScopedDataChannel &) = delete;
ScopedDataChannel &operator=(const ScopedDataChannel &) = delete;
@@ -90,6 +90,7 @@
uint64_t buffered_amount();
private:
+ ScopedDataChannel();
// Trampolines from C -> C++.
static void StaticDataChannelOpenHandler(void *const arg);
static void StaticBufferedAmountLowHandler(void *const arg);
diff --git a/aos/network/web_proxy.cc b/aos/network/web_proxy.cc
index 94473fb..fdd8f1e 100644
--- a/aos/network/web_proxy.cc
+++ b/aos/network/web_proxy.cc
@@ -18,6 +18,12 @@
}
DEFINE_int32(proxy_port, 1180, "Port to use for the web proxy server.");
+DEFINE_int32(pre_send_messages, 10000,
+             "Number of messages per queue to send to a client before "
+             "waiting on confirmation that earlier messages were received. "
+             "If set to -1, messages will not be throttled at all. This "
+             "prevents a situation where, when run on localhost, the large "
+             "number of WebRTC packets can overwhelm the browser and crash "
+             "the webpage.");
namespace aos {
namespace web_proxy {
@@ -234,7 +240,8 @@
}
}
for (auto &conn : channels_) {
- std::shared_ptr<ScopedDataChannel> rtc_channel = conn.first;
+ std::shared_ptr<ScopedDataChannel> rtc_channel = conn.first.lock();
+ CHECK(rtc_channel) << "data_channel was destroyed too early.";
ChannelInformation *channel_data = &conn.second;
if (channel_data->transfer_method == TransferMethod::SUBSAMPLE) {
SkipToLastMessage(channel_data);
@@ -261,7 +268,6 @@
void Subscriber::AddListener(std::shared_ptr<ScopedDataChannel> data_channel,
TransferMethod transfer_method) {
- VLOG(1) << "Adding listener for " << data_channel.get();
ChannelInformation info;
info.transfer_method = transfer_method;
@@ -271,13 +277,35 @@
fetcher_->Fetch();
}
- channels_.emplace(data_channel, info);
+ channels_.emplace_back(std::make_pair(data_channel, info));
+
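+ // Listen for ChannelState acknowledgements from the client; these track
+ // how far it has gotten, which NextBuffer() uses for backpressure.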
+ data_channel->set_on_message(
+ [this, index = channels_.size() - 1](
+ struct mbuf *const buffer,
+ const enum rawrtc_data_channel_message_flag /*flags*/) {
+ FlatbufferSpan<ChannelState> message(
+ {mbuf_buf(buffer), mbuf_get_left(buffer)});
+ if (!message.Verify()) {
+ LOG(ERROR) << "Invalid flatbuffer received from browser client.";
+ return;
+ }
+
+ channels_[index].second.reported_queue_index =
+ message.message().queue_index();
+ channels_[index].second.reported_packet_index =
+ message.message().packet_index();
+ });
}
-void Subscriber::RemoveListener(
- std::shared_ptr<ScopedDataChannel> data_channel) {
- VLOG(1) << "Removing listener for " << data_channel.get();
- channels_.erase(data_channel);
+void Subscriber::RemoveListener(
+    std::shared_ptr<ScopedDataChannel> data_channel) {
+ channels_.erase(
+ std::remove_if(
+ channels_.begin(), channels_.end(),
+ [data_channel](const std::pair<std::weak_ptr<ScopedDataChannel>,
+ ChannelInformation> &channel) {
+ return channel.first.lock().get() == data_channel.get();
+ }),
+ channels_.end());
}
std::shared_ptr<struct mbuf> Subscriber::NextBuffer(
@@ -294,10 +322,24 @@
channel->next_packet_number = 0;
return message_buffer_.front().data.at(0);
}
+ // TODO(james): Handle queue index wrapping when >2^32 messages are sent on a
+ // channel.
if (channel->current_queue_index > latest_index) {
// We are still waiting on the next message to appear; return.
return nullptr;
}
+ if (FLAGS_pre_send_messages > 0) {
+ // Don't buffer up an excessive number of messages to the client.
+ // This currently ignores the packet index (and really, any concept of
+ // message size), but the main goal is just to avoid locking up the client
+ // browser, not to be ultra precise about anything. It's also not clear that
+ // message *size* is necessarily even the determining factor in causing
+ // issues.
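+ // E.g., with the default of 10000: if the client last reported queue
+ // index 5 and current_queue_index is 10006, then 5 + 10000 < 10006 and
+ // we hold off until the client reports more progress.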
+ if (channel->reported_queue_index + FLAGS_pre_send_messages <
+ channel->current_queue_index) {
+ return nullptr;
+ }
+ }
CHECK_EQ(latest_index - earliest_index + 1, message_buffer_.size())
<< "Inconsistent queue indices.";
const size_t packets_in_message =
@@ -575,7 +617,7 @@
auto it = channels_.find(channel_index);
if (it == channels_.end()) {
std::shared_ptr<ScopedDataChannel> data_channel =
- std::make_shared<ScopedDataChannel>();
+ ScopedDataChannel::MakeDataChannel();
std::weak_ptr<ScopedDataChannel> data_channel_weak_ptr = data_channel;
@@ -584,9 +626,10 @@
std::shared_ptr<ScopedDataChannel> data_channel =
data_channel_weak_ptr.lock();
CHECK(data_channel) << ": Subscriber got destroyed before we started.";
- // Raw pointer inside the subscriber so we don't have a circular
+ // Weak ptr inside the subscriber so we don't have a circular
// reference. AddListener will close it.
- subscribers_[channel_index]->AddListener(data_channel, transfer_method);
+ subscribers_[channel_index]->AddListener(data_channel,
+ transfer_method);
});
Subscriber *subscriber = subscribers_[channel_index].get();
diff --git a/aos/network/web_proxy.fbs b/aos/network/web_proxy.fbs
index 817add1..6c85acb 100644
--- a/aos/network/web_proxy.fbs
+++ b/aos/network/web_proxy.fbs
@@ -72,6 +72,16 @@
method:TransferMethod (id: 1);
}
+// Used by the client to report the most recent message it has received.
+// This allows the server to avoid overloading the client (which we've had
+// issues with in the past).
+table ChannelState {
+ // queue_index and packet_index correspond to the similarly named fields in
+ // MessageHeader.
+ queue_index:uint (id: 0);
+ packet_index:uint (id: 1);
+}
+
table SubscriberRequest {
// The channels that we want transferred to this client.
channels_to_transfer:[ChannelRequest] (id: 0);
diff --git a/aos/network/web_proxy.h b/aos/network/web_proxy.h
index 4ad0630..0815ebf 100644
--- a/aos/network/web_proxy.h
+++ b/aos/network/web_proxy.h
@@ -111,8 +111,17 @@
private:
struct ChannelInformation {
TransferMethod transfer_method;
+ // Queue index (same as the queue index within the AOS channel) of the
+ // message that we are currently sending or, if we are between messages,
+ // the next message we will send.
uint32_t current_queue_index = 0;
+ // Index of the next packet to send within current_queue_index (large
+ // messages are broken into multiple packets, as we have encountered
+ // issues with how some WebRTC implementations handle large packets).
size_t next_packet_number = 0;
+ // The last queue/packet index reported by the client.
+ uint32_t reported_queue_index = 0;
+ size_t reported_packet_index = 0;
};
struct Message {
uint32_t index = 0xffffffff;
@@ -126,7 +135,21 @@
int channel_index_;
int buffer_size_;
std::deque<Message> message_buffer_;
- std::map<std::shared_ptr<ScopedDataChannel>, ChannelInformation> channels_;
+ // The ScopedDataChannel that we use for actually sending data over WebRTC
+ // is stored using a weak_ptr because:
+ // (a) There are some dangers of accidentally creating circular dependencies
+ // that prevent a ScopedDataChannel from ever being destroyed.
+ // (b) The inter-dependencies involved are complicated enough that we want
+ // to be able to check whether someone has destroyed the ScopedDataChannel
+ // before using it (if it has been destroyed and the Subscriber still
+ // wants to use it, that is a bug, but checking for bugs is useful).
+ // This particular location *may* be able to get away with a shared_ptr, but
+ // because the ScopedDataChannel effectively destroys itself (see
+ // ScopedDataChannel::StaticDataChannelCloseHandler) while also potentially
+ // holding references to other objects (e.g., through the various handlers
+ // that can be registered), creating unnecessary shared_ptr's is dubious.
+ std::vector<std::pair<std::weak_ptr<ScopedDataChannel>, ChannelInformation>>
+ channels_;
};
// Class to manage a WebRTC connection to a browser.
diff --git a/aos/network/www/proxy.ts b/aos/network/www/proxy.ts
index 3da817c..c52f495 100644
--- a/aos/network/www/proxy.ts
+++ b/aos/network/www/proxy.ts
@@ -15,6 +15,7 @@
import SubscriberRequest = web_proxy.aos.web_proxy.SubscriberRequest;
import ChannelRequestFb = web_proxy.aos.web_proxy.ChannelRequest;
import TransferMethod = web_proxy.aos.web_proxy.TransferMethod;
+import ChannelState = web_proxy.aos.web_proxy.ChannelState;
// There is one handler for each DataChannel; it maintains the state of
// multi-part messages and delegates to a callback when the message is fully
@@ -34,6 +35,16 @@
const messageHeader = MessageHeader.getRootAsMessageHeader(
fbBuffer as unknown as flatbuffers.ByteBuffer);
const time = messageHeader.monotonicSentTime().toFloat64() * 1e-9;
+
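+ // Acknowledge the message back to the server so that it can apply
+ // backpressure (see the pre_send_messages flag in web_proxy.cc).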
+ const stateBuilder = new Builder(512) as unknown as flatbuffers.Builder;
+ ChannelState.startChannelState(stateBuilder);
+ ChannelState.addQueueIndex(stateBuilder, messageHeader.queueIndex());
+ ChannelState.addPacketIndex(stateBuilder, messageHeader.packetIndex());
+ const state = ChannelState.endChannelState(stateBuilder);
+ stateBuilder.finish(state);
+ const stateArray = stateBuilder.asUint8Array();
+ this.channel.send(stateArray);
+
// Short circuit if only one packet
if (messageHeader.packetCount() === 1) {
this.handlerFunc(messageHeader.dataArray(), time);