Squashed 'third_party/allwpilib_2019/' content from commit bd05dfa1c
Change-Id: I2b1c2250cdb9b055133780c33593292098c375b7
git-subtree-dir: third_party/allwpilib_2019
git-subtree-split: bd05dfa1c7cca74c4fac451e7b9d6a37e7b53447
diff --git a/wpiutil/src/main/native/cpp/WebSocket.cpp b/wpiutil/src/main/native/cpp/WebSocket.cpp
new file mode 100644
index 0000000..a38be20
--- /dev/null
+++ b/wpiutil/src/main/native/cpp/WebSocket.cpp
@@ -0,0 +1,565 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) 2018 FIRST. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "wpi/WebSocket.h"
+
+#include <random>
+
+#include "wpi/Base64.h"
+#include "wpi/HttpParser.h"
+#include "wpi/SmallString.h"
+#include "wpi/SmallVector.h"
+#include "wpi/raw_uv_ostream.h"
+#include "wpi/sha1.h"
+#include "wpi/uv/Stream.h"
+
+using namespace wpi;
+
+namespace {
+class WebSocketWriteReq : public uv::WriteReq {
+ public:
+ explicit WebSocketWriteReq(
+ std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
+ finish.connect([=](uv::Error err) {
+ MutableArrayRef<uv::Buffer> bufs{m_bufs};
+ for (auto&& buf : bufs.slice(0, m_startUser)) buf.Deallocate();
+ callback(bufs.slice(m_startUser), err);
+ });
+ }
+
+ SmallVector<uv::Buffer, 4> m_bufs;
+ size_t m_startUser;
+};
+} // namespace
+
+class WebSocket::ClientHandshakeData {
+ public:
+ ClientHandshakeData() {
+ // key is a random nonce
+ static std::random_device rd;
+ static std::default_random_engine gen{rd()};
+ std::uniform_int_distribution<unsigned int> dist(0, 255);
+ char nonce[16]; // the nonce sent to the server
+ for (char& v : nonce) v = static_cast<char>(dist(gen));
+ raw_svector_ostream os(key);
+ Base64Encode(os, StringRef{nonce, 16});
+ }
+ ~ClientHandshakeData() {
+ if (auto t = timer.lock()) {
+ t->Stop();
+ t->Close();
+ }
+ }
+
+ SmallString<64> key; // the key sent to the server
+ SmallVector<std::string, 2> protocols; // valid protocols
+ HttpParser parser{HttpParser::kResponse}; // server response parser
+ bool hasUpgrade = false;
+ bool hasConnection = false;
+ bool hasAccept = false;
+ bool hasProtocol = false;
+
+ std::weak_ptr<uv::Timer> timer;
+};
+
+static StringRef AcceptHash(StringRef key, SmallVectorImpl<char>& buf) {
+ SHA1 hash;
+ hash.Update(key);
+ hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
+ SmallString<64> hashBuf;
+ return Base64Encode(hash.RawFinal(hashBuf), buf);
+}
+
+WebSocket::WebSocket(uv::Stream& stream, bool server, const private_init&)
+ : m_stream{stream}, m_server{server} {
+ // Connect closed and error signals to ourselves
+ m_stream.closed.connect([this]() { SetClosed(1006, "handle closed"); });
+ m_stream.error.connect([this](uv::Error err) {
+ Terminate(1006, "stream error: " + Twine(err.name()));
+ });
+
+ // Start reading
+ m_stream.StopRead(); // we may have been reading
+ m_stream.StartRead();
+ m_stream.data.connect(
+ [this](uv::Buffer& buf, size_t size) { HandleIncoming(buf, size); });
+ m_stream.end.connect(
+ [this]() { Terminate(1006, "remote end closed connection"); });
+}
+
+WebSocket::~WebSocket() {}
+
+std::shared_ptr<WebSocket> WebSocket::CreateClient(
+ uv::Stream& stream, const Twine& uri, const Twine& host,
+ ArrayRef<StringRef> protocols, const ClientOptions& options) {
+ auto ws = std::make_shared<WebSocket>(stream, false, private_init{});
+ stream.SetData(ws);
+ ws->StartClient(uri, host, protocols, options);
+ return ws;
+}
+
+std::shared_ptr<WebSocket> WebSocket::CreateServer(uv::Stream& stream,
+ StringRef key,
+ StringRef version,
+ StringRef protocol) {
+ auto ws = std::make_shared<WebSocket>(stream, true, private_init{});
+ stream.SetData(ws);
+ ws->StartServer(key, version, protocol);
+ return ws;
+}
+
+void WebSocket::Close(uint16_t code, const Twine& reason) {
+ SendClose(code, reason);
+ if (m_state != FAILED && m_state != CLOSED) m_state = CLOSING;
+}
+
+void WebSocket::Fail(uint16_t code, const Twine& reason) {
+ if (m_state == FAILED || m_state == CLOSED) return;
+ SendClose(code, reason);
+ SetClosed(code, reason, true);
+ Shutdown();
+}
+
+void WebSocket::Terminate(uint16_t code, const Twine& reason) {
+ if (m_state == FAILED || m_state == CLOSED) return;
+ SetClosed(code, reason);
+ Shutdown();
+}
+
+void WebSocket::StartClient(const Twine& uri, const Twine& host,
+ ArrayRef<StringRef> protocols,
+ const ClientOptions& options) {
+ // Create client handshake data
+ m_clientHandshake = std::make_unique<ClientHandshakeData>();
+
+ // Build client request
+ SmallVector<uv::Buffer, 4> bufs;
+ raw_uv_ostream os{bufs, 4096};
+
+ os << "GET " << uri << " HTTP/1.1\r\n";
+ os << "Host: " << host << "\r\n";
+ os << "Upgrade: websocket\r\n";
+ os << "Connection: Upgrade\r\n";
+ os << "Sec-WebSocket-Key: " << m_clientHandshake->key << "\r\n";
+ os << "Sec-WebSocket-Version: 13\r\n";
+
+ // protocols (if provided)
+ if (!protocols.empty()) {
+ os << "Sec-WebSocket-Protocol: ";
+ bool first = true;
+ for (auto protocol : protocols) {
+ if (!first)
+ os << ", ";
+ else
+ first = false;
+ os << protocol;
+ // also save for later checking against server response
+ m_clientHandshake->protocols.emplace_back(protocol);
+ }
+ os << "\r\n";
+ }
+
+ // other headers
+ for (auto&& header : options.extraHeaders)
+ os << header.first << ": " << header.second << "\r\n";
+
+ // finish headers
+ os << "\r\n";
+
+ // Send client request
+ m_stream.Write(bufs, [](auto bufs, uv::Error) {
+ for (auto& buf : bufs) buf.Deallocate();
+ });
+
+ // Set up client response handling
+ m_clientHandshake->parser.status.connect([this](StringRef status) {
+ unsigned int code = m_clientHandshake->parser.GetStatusCode();
+ if (code != 101) Terminate(code, status);
+ });
+ m_clientHandshake->parser.header.connect(
+ [this](StringRef name, StringRef value) {
+ value = value.trim();
+ if (name.equals_lower("upgrade")) {
+ if (!value.equals_lower("websocket"))
+ return Terminate(1002, "invalid upgrade response value");
+ m_clientHandshake->hasUpgrade = true;
+ } else if (name.equals_lower("connection")) {
+ if (!value.equals_lower("upgrade"))
+ return Terminate(1002, "invalid connection response value");
+ m_clientHandshake->hasConnection = true;
+ } else if (name.equals_lower("sec-websocket-accept")) {
+ // Check against expected response
+ SmallString<64> acceptBuf;
+ if (!value.equals(AcceptHash(m_clientHandshake->key, acceptBuf)))
+ return Terminate(1002, "invalid accept key");
+ m_clientHandshake->hasAccept = true;
+ } else if (name.equals_lower("sec-websocket-extensions")) {
+ // No extensions are supported
+ if (!value.empty()) return Terminate(1010, "unsupported extension");
+ } else if (name.equals_lower("sec-websocket-protocol")) {
+ // Make sure it was one of the provided protocols
+ bool match = false;
+ for (auto&& protocol : m_clientHandshake->protocols) {
+ if (value.equals_lower(protocol)) {
+ match = true;
+ break;
+ }
+ }
+ if (!match) return Terminate(1003, "unsupported protocol");
+ m_clientHandshake->hasProtocol = true;
+ m_protocol = value;
+ }
+ });
+ m_clientHandshake->parser.headersComplete.connect([this](bool) {
+ if (!m_clientHandshake->hasUpgrade || !m_clientHandshake->hasConnection ||
+ !m_clientHandshake->hasAccept ||
+ (!m_clientHandshake->hasProtocol &&
+ !m_clientHandshake->protocols.empty())) {
+ return Terminate(1002, "invalid response");
+ }
+ if (m_state == CONNECTING) {
+ m_state = OPEN;
+ open(m_protocol);
+ }
+ });
+
+ // Start handshake timer if a timeout was specified
+ if (options.handshakeTimeout != uv::Timer::Time::max()) {
+ auto timer = uv::Timer::Create(m_stream.GetLoopRef());
+ timer->timeout.connect(
+ [this]() { Terminate(1006, "connection timed out"); });
+ timer->Start(options.handshakeTimeout);
+ m_clientHandshake->timer = timer;
+ }
+}
+
+void WebSocket::StartServer(StringRef key, StringRef version,
+ StringRef protocol) {
+ m_protocol = protocol;
+
+ // Build server response
+ SmallVector<uv::Buffer, 4> bufs;
+ raw_uv_ostream os{bufs, 4096};
+
+ // Handle unsupported version
+ if (version != "13") {
+ os << "HTTP/1.1 426 Upgrade Required\r\n";
+ os << "Upgrade: WebSocket\r\n";
+ os << "Sec-WebSocket-Version: 13\r\n\r\n";
+ m_stream.Write(bufs, [this](auto bufs, uv::Error) {
+ for (auto& buf : bufs) buf.Deallocate();
+ // XXX: Should we support sending a new handshake on the same connection?
+ // XXX: "this->" is required by GCC 5.5 (bug)
+ this->Terminate(1003, "unsupported protocol version");
+ });
+ return;
+ }
+
+ os << "HTTP/1.1 101 Switching Protocols\r\n";
+ os << "Upgrade: websocket\r\n";
+ os << "Connection: Upgrade\r\n";
+
+ // accept hash
+ SmallString<64> acceptBuf;
+ os << "Sec-WebSocket-Accept: " << AcceptHash(key, acceptBuf) << "\r\n";
+
+ if (!protocol.empty()) os << "Sec-WebSocket-Protocol: " << protocol << "\r\n";
+
+ // end headers
+ os << "\r\n";
+
+ // Send server response
+ m_stream.Write(bufs, [this](auto bufs, uv::Error) {
+ for (auto& buf : bufs) buf.Deallocate();
+ if (m_state == CONNECTING) {
+ m_state = OPEN;
+ open(m_protocol);
+ }
+ });
+}
+
+void WebSocket::SendClose(uint16_t code, const Twine& reason) {
+ SmallVector<uv::Buffer, 4> bufs;
+ if (code != 1005) {
+ raw_uv_ostream os{bufs, 4096};
+ os << ArrayRef<uint8_t>{static_cast<uint8_t>((code >> 8) & 0xff),
+ static_cast<uint8_t>(code & 0xff)};
+ reason.print(os);
+ }
+ Send(kFlagFin | kOpClose, bufs, [](auto bufs, uv::Error) {
+ for (auto&& buf : bufs) buf.Deallocate();
+ });
+}
+
+void WebSocket::SetClosed(uint16_t code, const Twine& reason, bool failed) {
+ if (m_state == FAILED || m_state == CLOSED) return;
+ m_state = failed ? FAILED : CLOSED;
+ SmallString<64> reasonBuf;
+ closed(code, reason.toStringRef(reasonBuf));
+}
+
+void WebSocket::Shutdown() {
+ m_stream.Shutdown([this] { m_stream.Close(); });
+}
+
+void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
+ // ignore incoming data if we're failed or closed
+ if (m_state == FAILED || m_state == CLOSED) return;
+
+ StringRef data{buf.base, size};
+
+ // Handle connecting state (mainly on client)
+ if (m_state == CONNECTING) {
+ if (m_clientHandshake) {
+ data = m_clientHandshake->parser.Execute(data);
+ // check for parser failure
+ if (m_clientHandshake->parser.HasError())
+ return Terminate(1003, "invalid response");
+ if (m_state != OPEN) return; // not done with handshake yet
+
+ // we're done with the handshake, so release its memory
+ m_clientHandshake.reset();
+
+ // fall through to process additional data after handshake
+ } else {
+ return Terminate(1003, "got data on server before response");
+ }
+ }
+
+ // Message processing
+ while (!data.empty()) {
+ if (m_frameSize == UINT64_MAX) {
+ // Need at least two bytes to determine header length
+ if (m_header.size() < 2u) {
+ size_t toCopy = std::min(2u - m_header.size(), data.size());
+ m_header.append(data.bytes_begin(), data.bytes_begin() + toCopy);
+ data = data.drop_front(toCopy);
+ if (m_header.size() < 2u) return; // need more data
+
+ // Validate RSV bits are zero
+ if ((m_header[0] & 0x70) != 0) return Fail(1002, "nonzero RSV");
+ }
+
+ // Once we have first two bytes, we can calculate the header size
+ if (m_headerSize == 0) {
+ m_headerSize = 2;
+ uint8_t len = m_header[1] & kLenMask;
+ if (len == 126)
+ m_headerSize += 2;
+ else if (len == 127)
+ m_headerSize += 8;
+ bool masking = (m_header[1] & kFlagMasking) != 0;
+ if (masking) m_headerSize += 4; // masking key
+ // On server side, incoming messages MUST be masked
+ // On client side, incoming messages MUST NOT be masked
+ if (m_server && !masking) return Fail(1002, "client data not masked");
+ if (!m_server && masking) return Fail(1002, "server data masked");
+ }
+
+ // Need to complete header to calculate message size
+ if (m_header.size() < m_headerSize) {
+ size_t toCopy = std::min(m_headerSize - m_header.size(), data.size());
+ m_header.append(data.bytes_begin(), data.bytes_begin() + toCopy);
+ data = data.drop_front(toCopy);
+ if (m_header.size() < m_headerSize) return; // need more data
+ }
+
+ if (m_header.size() >= m_headerSize) {
+ // get payload length
+ uint8_t len = m_header[1] & kLenMask;
+ if (len == 126)
+ m_frameSize = (static_cast<uint16_t>(m_header[2]) << 8) |
+ static_cast<uint16_t>(m_header[3]);
+ else if (len == 127)
+ m_frameSize = (static_cast<uint64_t>(m_header[2]) << 56) |
+ (static_cast<uint64_t>(m_header[3]) << 48) |
+ (static_cast<uint64_t>(m_header[4]) << 40) |
+ (static_cast<uint64_t>(m_header[5]) << 32) |
+ (static_cast<uint64_t>(m_header[6]) << 24) |
+ (static_cast<uint64_t>(m_header[7]) << 16) |
+ (static_cast<uint64_t>(m_header[8]) << 8) |
+ static_cast<uint64_t>(m_header[9]);
+ else
+ m_frameSize = len;
+
+ // limit maximum size
+ if ((m_payload.size() + m_frameSize) > m_maxMessageSize)
+ return Fail(1009, "message too large");
+ }
+ }
+
+ if (m_frameSize != UINT64_MAX) {
+ size_t need = m_frameStart + m_frameSize - m_payload.size();
+ size_t toCopy = std::min(need, data.size());
+ m_payload.append(data.bytes_begin(), data.bytes_begin() + toCopy);
+ data = data.drop_front(toCopy);
+ need -= toCopy;
+ if (need == 0) {
+ // We have a complete frame
+ // If the message had masking, unmask it
+ if ((m_header[1] & kFlagMasking) != 0) {
+ uint8_t key[4] = {
+ m_header[m_headerSize - 4], m_header[m_headerSize - 3],
+ m_header[m_headerSize - 2], m_header[m_headerSize - 1]};
+ int n = 0;
+ for (uint8_t& ch :
+ MutableArrayRef<uint8_t>{m_payload}.slice(m_frameStart)) {
+ ch ^= key[n++];
+ if (n >= 4) n = 0;
+ }
+ }
+
+ // Handle message
+ bool fin = (m_header[0] & kFlagFin) != 0;
+ uint8_t opcode = m_header[0] & kOpMask;
+ switch (opcode) {
+ case kOpCont:
+ switch (m_fragmentOpcode) {
+ case kOpText:
+ if (!m_combineFragments || fin)
+ text(StringRef{reinterpret_cast<char*>(m_payload.data()),
+ m_payload.size()},
+ fin);
+ break;
+ case kOpBinary:
+ if (!m_combineFragments || fin) binary(m_payload, fin);
+ break;
+ default:
+ // no preceding message?
+ return Fail(1002, "invalid continuation message");
+ }
+ if (fin) m_fragmentOpcode = 0;
+ break;
+ case kOpText:
+ if (m_fragmentOpcode != 0) return Fail(1002, "incomplete fragment");
+ if (!m_combineFragments || fin)
+ text(StringRef{reinterpret_cast<char*>(m_payload.data()),
+ m_payload.size()},
+ fin);
+ if (!fin) m_fragmentOpcode = opcode;
+ break;
+ case kOpBinary:
+ if (m_fragmentOpcode != 0) return Fail(1002, "incomplete fragment");
+ if (!m_combineFragments || fin) binary(m_payload, fin);
+ if (!fin) m_fragmentOpcode = opcode;
+ break;
+ case kOpClose: {
+ uint16_t code;
+ StringRef reason;
+ if (!fin) {
+ code = 1002;
+ reason = "cannot fragment control frames";
+ } else if (m_payload.size() < 2) {
+ code = 1005;
+ } else {
+ code = (static_cast<uint16_t>(m_payload[0]) << 8) |
+ static_cast<uint16_t>(m_payload[1]);
+ reason = StringRef{reinterpret_cast<char*>(m_payload.data()),
+ m_payload.size()}
+ .drop_front(2);
+ }
+ // Echo the close if we didn't previously send it
+ if (m_state != CLOSING) SendClose(code, reason);
+ SetClosed(code, reason);
+ // If we're the server, shutdown the connection.
+ if (m_server) Shutdown();
+ break;
+ }
+ case kOpPing:
+ if (!fin) return Fail(1002, "cannot fragment control frames");
+ ping(m_payload);
+ break;
+ case kOpPong:
+ if (!fin) return Fail(1002, "cannot fragment control frames");
+ pong(m_payload);
+ break;
+ default:
+ return Fail(1002, "invalid message opcode");
+ }
+
+ // Prepare for next message
+ m_header.clear();
+ m_headerSize = 0;
+ if (!m_combineFragments || fin) m_payload.clear();
+ m_frameStart = m_payload.size();
+ m_frameSize = UINT64_MAX;
+ }
+ }
+ }
+}
+
+void WebSocket::Send(
+ uint8_t opcode, ArrayRef<uv::Buffer> data,
+ std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
+ // If we're not open, emit an error and don't send the data
+ if (m_state != OPEN) {
+ int err;
+ if (m_state == CONNECTING)
+ err = UV_EAGAIN;
+ else
+ err = UV_ESHUTDOWN;
+ SmallVector<uv::Buffer, 4> bufs{data.begin(), data.end()};
+ callback(bufs, uv::Error{err});
+ return;
+ }
+
+ auto req = std::make_shared<WebSocketWriteReq>(callback);
+ raw_uv_ostream os{req->m_bufs, 4096};
+
+ // opcode (includes FIN bit)
+ os << static_cast<unsigned char>(opcode);
+
+ // payload length
+ uint64_t size = 0;
+ for (auto&& buf : data) size += buf.len;
+ if (size < 126) {
+ os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | size);
+ } else if (size <= 0xffff) {
+ os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | 126);
+ os << ArrayRef<uint8_t>{static_cast<uint8_t>((size >> 8) & 0xff),
+ static_cast<uint8_t>(size & 0xff)};
+ } else {
+ os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | 127);
+ os << ArrayRef<uint8_t>{static_cast<uint8_t>((size >> 56) & 0xff),
+ static_cast<uint8_t>((size >> 48) & 0xff),
+ static_cast<uint8_t>((size >> 40) & 0xff),
+ static_cast<uint8_t>((size >> 32) & 0xff),
+ static_cast<uint8_t>((size >> 24) & 0xff),
+ static_cast<uint8_t>((size >> 16) & 0xff),
+ static_cast<uint8_t>((size >> 8) & 0xff),
+ static_cast<uint8_t>(size & 0xff)};
+ }
+
+ // clients need to mask the input data
+ if (!m_server) {
+ // generate masking key
+ static std::random_device rd;
+ static std::default_random_engine gen{rd()};
+ std::uniform_int_distribution<unsigned int> dist(0, 255);
+ uint8_t key[4];
+ for (uint8_t& v : key) v = dist(gen);
+ os << ArrayRef<uint8_t>{key, 4};
+ // copy and mask data
+ int n = 0;
+ for (auto&& buf : data) {
+ for (auto&& ch : buf.data()) {
+ os << static_cast<unsigned char>(static_cast<uint8_t>(ch) ^ key[n++]);
+ if (n >= 4) n = 0;
+ }
+ }
+ req->m_startUser = req->m_bufs.size();
+ req->m_bufs.append(data.begin(), data.end());
+ // don't send the user bufs as we copied their data
+ m_stream.Write(ArrayRef<uv::Buffer>{req->m_bufs}.slice(0, req->m_startUser),
+ req);
+ } else {
+ // servers can just send the buffers directly without masking
+ req->m_startUser = req->m_bufs.size();
+ req->m_bufs.append(data.begin(), data.end());
+ m_stream.Write(req->m_bufs, req);
+ }
+}