// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.

#include "wpi/WebSocket.h"

#include <algorithm>
#include <cstdint>
#include <random>

#include "fmt/format.h"
#include "wpi/Base64.h"
#include "wpi/HttpParser.h"
#include "wpi/SmallString.h"
#include "wpi/SmallVector.h"
#include "wpi/StringExtras.h"
#include "wpi/raw_uv_ostream.h"
#include "wpi/sha1.h"
#include "wpi/uv/Stream.h"

using namespace wpi;

21namespace {
22class WebSocketWriteReq : public uv::WriteReq {
23 public:
24 explicit WebSocketWriteReq(
Austin Schuh812d0d12021-11-04 20:16:48 -070025 std::function<void(span<uv::Buffer>, uv::Error)> callback)
26 : m_callback{std::move(callback)} {
27 finish.connect([this](uv::Error err) {
28 span<uv::Buffer> bufs{m_bufs};
29 for (auto&& buf : bufs.subspan(0, m_startUser)) {
30 buf.Deallocate();
31 }
32 m_callback(bufs.subspan(m_startUser), err);
Brian Silverman8fce7482020-01-05 13:18:21 -080033 });
34 }
35
Austin Schuh812d0d12021-11-04 20:16:48 -070036 std::function<void(span<uv::Buffer>, uv::Error)> m_callback;
Brian Silverman8fce7482020-01-05 13:18:21 -080037 SmallVector<uv::Buffer, 4> m_bufs;
38 size_t m_startUser;
39};
40} // namespace
41
42class WebSocket::ClientHandshakeData {
43 public:
44 ClientHandshakeData() {
45 // key is a random nonce
46 static std::random_device rd;
47 static std::default_random_engine gen{rd()};
48 std::uniform_int_distribution<unsigned int> dist(0, 255);
49 char nonce[16]; // the nonce sent to the server
Austin Schuh812d0d12021-11-04 20:16:48 -070050 for (char& v : nonce) {
51 v = static_cast<char>(dist(gen));
52 }
Brian Silverman8fce7482020-01-05 13:18:21 -080053 raw_svector_ostream os(key);
Austin Schuh812d0d12021-11-04 20:16:48 -070054 Base64Encode(os, {nonce, 16});
Brian Silverman8fce7482020-01-05 13:18:21 -080055 }
56 ~ClientHandshakeData() {
57 if (auto t = timer.lock()) {
58 t->Stop();
59 t->Close();
60 }
61 }
62
63 SmallString<64> key; // the key sent to the server
64 SmallVector<std::string, 2> protocols; // valid protocols
65 HttpParser parser{HttpParser::kResponse}; // server response parser
66 bool hasUpgrade = false;
67 bool hasConnection = false;
68 bool hasAccept = false;
69 bool hasProtocol = false;
70
71 std::weak_ptr<uv::Timer> timer;
72};
73
Austin Schuh812d0d12021-11-04 20:16:48 -070074static std::string_view AcceptHash(std::string_view key,
75 SmallVectorImpl<char>& buf) {
Brian Silverman8fce7482020-01-05 13:18:21 -080076 SHA1 hash;
77 hash.Update(key);
78 hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
79 SmallString<64> hashBuf;
80 return Base64Encode(hash.RawFinal(hashBuf), buf);
81}
82
83WebSocket::WebSocket(uv::Stream& stream, bool server, const private_init&)
84 : m_stream{stream}, m_server{server} {
85 // Connect closed and error signals to ourselves
86 m_stream.closed.connect([this]() { SetClosed(1006, "handle closed"); });
87 m_stream.error.connect([this](uv::Error err) {
Austin Schuh812d0d12021-11-04 20:16:48 -070088 Terminate(1006, fmt::format("stream error: {}", err.name()));
Brian Silverman8fce7482020-01-05 13:18:21 -080089 });
90
91 // Start reading
92 m_stream.StopRead(); // we may have been reading
93 m_stream.StartRead();
94 m_stream.data.connect(
95 [this](uv::Buffer& buf, size_t size) { HandleIncoming(buf, size); });
96 m_stream.end.connect(
97 [this]() { Terminate(1006, "remote end closed connection"); });
98}
99
Austin Schuh812d0d12021-11-04 20:16:48 -0700100WebSocket::~WebSocket() = default;
Brian Silverman8fce7482020-01-05 13:18:21 -0800101
102std::shared_ptr<WebSocket> WebSocket::CreateClient(
Austin Schuh812d0d12021-11-04 20:16:48 -0700103 uv::Stream& stream, std::string_view uri, std::string_view host,
104 span<const std::string_view> protocols, const ClientOptions& options) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800105 auto ws = std::make_shared<WebSocket>(stream, false, private_init{});
106 stream.SetData(ws);
107 ws->StartClient(uri, host, protocols, options);
108 return ws;
109}
110
111std::shared_ptr<WebSocket> WebSocket::CreateServer(uv::Stream& stream,
Austin Schuh812d0d12021-11-04 20:16:48 -0700112 std::string_view key,
113 std::string_view version,
114 std::string_view protocol) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800115 auto ws = std::make_shared<WebSocket>(stream, true, private_init{});
116 stream.SetData(ws);
117 ws->StartServer(key, version, protocol);
118 return ws;
119}
120
Austin Schuh812d0d12021-11-04 20:16:48 -0700121void WebSocket::Close(uint16_t code, std::string_view reason) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800122 SendClose(code, reason);
Austin Schuh812d0d12021-11-04 20:16:48 -0700123 if (m_state != FAILED && m_state != CLOSED) {
124 m_state = CLOSING;
125 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800126}
127
Austin Schuh812d0d12021-11-04 20:16:48 -0700128void WebSocket::Fail(uint16_t code, std::string_view reason) {
129 if (m_state == FAILED || m_state == CLOSED) {
130 return;
131 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800132 SendClose(code, reason);
133 SetClosed(code, reason, true);
134 Shutdown();
135}
136
Austin Schuh812d0d12021-11-04 20:16:48 -0700137void WebSocket::Terminate(uint16_t code, std::string_view reason) {
138 if (m_state == FAILED || m_state == CLOSED) {
139 return;
140 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800141 SetClosed(code, reason);
142 Shutdown();
143}
144
Austin Schuh812d0d12021-11-04 20:16:48 -0700145void WebSocket::StartClient(std::string_view uri, std::string_view host,
146 span<const std::string_view> protocols,
Brian Silverman8fce7482020-01-05 13:18:21 -0800147 const ClientOptions& options) {
148 // Create client handshake data
149 m_clientHandshake = std::make_unique<ClientHandshakeData>();
150
151 // Build client request
152 SmallVector<uv::Buffer, 4> bufs;
153 raw_uv_ostream os{bufs, 4096};
154
155 os << "GET " << uri << " HTTP/1.1\r\n";
156 os << "Host: " << host << "\r\n";
157 os << "Upgrade: websocket\r\n";
158 os << "Connection: Upgrade\r\n";
159 os << "Sec-WebSocket-Key: " << m_clientHandshake->key << "\r\n";
160 os << "Sec-WebSocket-Version: 13\r\n";
161
162 // protocols (if provided)
163 if (!protocols.empty()) {
164 os << "Sec-WebSocket-Protocol: ";
165 bool first = true;
166 for (auto protocol : protocols) {
Austin Schuh812d0d12021-11-04 20:16:48 -0700167 if (!first) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800168 os << ", ";
Austin Schuh812d0d12021-11-04 20:16:48 -0700169 } else {
Brian Silverman8fce7482020-01-05 13:18:21 -0800170 first = false;
Austin Schuh812d0d12021-11-04 20:16:48 -0700171 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800172 os << protocol;
173 // also save for later checking against server response
174 m_clientHandshake->protocols.emplace_back(protocol);
175 }
176 os << "\r\n";
177 }
178
179 // other headers
Austin Schuh812d0d12021-11-04 20:16:48 -0700180 for (auto&& header : options.extraHeaders) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800181 os << header.first << ": " << header.second << "\r\n";
Austin Schuh812d0d12021-11-04 20:16:48 -0700182 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800183
184 // finish headers
185 os << "\r\n";
186
187 // Send client request
188 m_stream.Write(bufs, [](auto bufs, uv::Error) {
Austin Schuh812d0d12021-11-04 20:16:48 -0700189 for (auto& buf : bufs) {
190 buf.Deallocate();
191 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800192 });
193
194 // Set up client response handling
Austin Schuh812d0d12021-11-04 20:16:48 -0700195 m_clientHandshake->parser.status.connect([this](std::string_view status) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800196 unsigned int code = m_clientHandshake->parser.GetStatusCode();
Austin Schuh812d0d12021-11-04 20:16:48 -0700197 if (code != 101) {
198 Terminate(code, status);
199 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800200 });
201 m_clientHandshake->parser.header.connect(
Austin Schuh812d0d12021-11-04 20:16:48 -0700202 [this](std::string_view name, std::string_view value) {
203 value = trim(value);
204 if (equals_lower(name, "upgrade")) {
205 if (!equals_lower(value, "websocket")) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800206 return Terminate(1002, "invalid upgrade response value");
Austin Schuh812d0d12021-11-04 20:16:48 -0700207 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800208 m_clientHandshake->hasUpgrade = true;
Austin Schuh812d0d12021-11-04 20:16:48 -0700209 } else if (equals_lower(name, "connection")) {
210 if (!equals_lower(value, "upgrade")) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800211 return Terminate(1002, "invalid connection response value");
Austin Schuh812d0d12021-11-04 20:16:48 -0700212 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800213 m_clientHandshake->hasConnection = true;
Austin Schuh812d0d12021-11-04 20:16:48 -0700214 } else if (equals_lower(name, "sec-websocket-accept")) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800215 // Check against expected response
216 SmallString<64> acceptBuf;
Austin Schuh812d0d12021-11-04 20:16:48 -0700217 if (!equals(value, AcceptHash(m_clientHandshake->key, acceptBuf))) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800218 return Terminate(1002, "invalid accept key");
Austin Schuh812d0d12021-11-04 20:16:48 -0700219 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800220 m_clientHandshake->hasAccept = true;
Austin Schuh812d0d12021-11-04 20:16:48 -0700221 } else if (equals_lower(name, "sec-websocket-extensions")) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800222 // No extensions are supported
Austin Schuh812d0d12021-11-04 20:16:48 -0700223 if (!value.empty()) {
224 return Terminate(1010, "unsupported extension");
225 }
226 } else if (equals_lower(name, "sec-websocket-protocol")) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800227 // Make sure it was one of the provided protocols
228 bool match = false;
229 for (auto&& protocol : m_clientHandshake->protocols) {
Austin Schuh812d0d12021-11-04 20:16:48 -0700230 if (equals_lower(value, protocol)) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800231 match = true;
232 break;
233 }
234 }
Austin Schuh812d0d12021-11-04 20:16:48 -0700235 if (!match) {
236 return Terminate(1003, "unsupported protocol");
237 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800238 m_clientHandshake->hasProtocol = true;
239 m_protocol = value;
240 }
241 });
242 m_clientHandshake->parser.headersComplete.connect([this](bool) {
243 if (!m_clientHandshake->hasUpgrade || !m_clientHandshake->hasConnection ||
244 !m_clientHandshake->hasAccept ||
245 (!m_clientHandshake->hasProtocol &&
246 !m_clientHandshake->protocols.empty())) {
247 return Terminate(1002, "invalid response");
248 }
249 if (m_state == CONNECTING) {
250 m_state = OPEN;
251 open(m_protocol);
252 }
253 });
254
255 // Start handshake timer if a timeout was specified
256 if (options.handshakeTimeout != (uv::Timer::Time::max)()) {
257 auto timer = uv::Timer::Create(m_stream.GetLoopRef());
258 timer->timeout.connect(
259 [this]() { Terminate(1006, "connection timed out"); });
260 timer->Start(options.handshakeTimeout);
261 m_clientHandshake->timer = timer;
262 }
263}
264
Austin Schuh812d0d12021-11-04 20:16:48 -0700265void WebSocket::StartServer(std::string_view key, std::string_view version,
266 std::string_view protocol) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800267 m_protocol = protocol;
268
269 // Build server response
270 SmallVector<uv::Buffer, 4> bufs;
271 raw_uv_ostream os{bufs, 4096};
272
273 // Handle unsupported version
274 if (version != "13") {
275 os << "HTTP/1.1 426 Upgrade Required\r\n";
276 os << "Upgrade: WebSocket\r\n";
277 os << "Sec-WebSocket-Version: 13\r\n\r\n";
278 m_stream.Write(bufs, [this](auto bufs, uv::Error) {
Austin Schuh812d0d12021-11-04 20:16:48 -0700279 for (auto& buf : bufs) {
280 buf.Deallocate();
281 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800282 // XXX: Should we support sending a new handshake on the same connection?
283 // XXX: "this->" is required by GCC 5.5 (bug)
284 this->Terminate(1003, "unsupported protocol version");
285 });
286 return;
287 }
288
289 os << "HTTP/1.1 101 Switching Protocols\r\n";
290 os << "Upgrade: websocket\r\n";
291 os << "Connection: Upgrade\r\n";
292
293 // accept hash
294 SmallString<64> acceptBuf;
295 os << "Sec-WebSocket-Accept: " << AcceptHash(key, acceptBuf) << "\r\n";
296
Austin Schuh812d0d12021-11-04 20:16:48 -0700297 if (!protocol.empty()) {
298 os << "Sec-WebSocket-Protocol: " << protocol << "\r\n";
299 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800300
301 // end headers
302 os << "\r\n";
303
304 // Send server response
305 m_stream.Write(bufs, [this](auto bufs, uv::Error) {
Austin Schuh812d0d12021-11-04 20:16:48 -0700306 for (auto& buf : bufs) {
307 buf.Deallocate();
308 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800309 if (m_state == CONNECTING) {
310 m_state = OPEN;
311 open(m_protocol);
312 }
313 });
314}
315
Austin Schuh812d0d12021-11-04 20:16:48 -0700316void WebSocket::SendClose(uint16_t code, std::string_view reason) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800317 SmallVector<uv::Buffer, 4> bufs;
318 if (code != 1005) {
319 raw_uv_ostream os{bufs, 4096};
320 const uint8_t codeMsb[] = {static_cast<uint8_t>((code >> 8) & 0xff),
321 static_cast<uint8_t>(code & 0xff)};
Austin Schuh812d0d12021-11-04 20:16:48 -0700322 os << span{codeMsb};
323 os << reason;
Brian Silverman8fce7482020-01-05 13:18:21 -0800324 }
325 Send(kFlagFin | kOpClose, bufs, [](auto bufs, uv::Error) {
Austin Schuh812d0d12021-11-04 20:16:48 -0700326 for (auto&& buf : bufs) {
327 buf.Deallocate();
328 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800329 });
330}
331
Austin Schuh812d0d12021-11-04 20:16:48 -0700332void WebSocket::SetClosed(uint16_t code, std::string_view reason, bool failed) {
333 if (m_state == FAILED || m_state == CLOSED) {
334 return;
335 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800336 m_state = failed ? FAILED : CLOSED;
Austin Schuh812d0d12021-11-04 20:16:48 -0700337 closed(code, reason);
Brian Silverman8fce7482020-01-05 13:18:21 -0800338}
339
340void WebSocket::Shutdown() {
341 m_stream.Shutdown([this] { m_stream.Close(); });
342}
343
344void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
345 // ignore incoming data if we're failed or closed
Austin Schuh812d0d12021-11-04 20:16:48 -0700346 if (m_state == FAILED || m_state == CLOSED) {
347 return;
348 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800349
Austin Schuh812d0d12021-11-04 20:16:48 -0700350 std::string_view data{buf.base, size};
Brian Silverman8fce7482020-01-05 13:18:21 -0800351
352 // Handle connecting state (mainly on client)
353 if (m_state == CONNECTING) {
354 if (m_clientHandshake) {
355 data = m_clientHandshake->parser.Execute(data);
356 // check for parser failure
Austin Schuh812d0d12021-11-04 20:16:48 -0700357 if (m_clientHandshake->parser.HasError()) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800358 return Terminate(1003, "invalid response");
Austin Schuh812d0d12021-11-04 20:16:48 -0700359 }
360 if (m_state != OPEN) {
361 return; // not done with handshake yet
362 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800363
364 // we're done with the handshake, so release its memory
365 m_clientHandshake.reset();
366
367 // fall through to process additional data after handshake
368 } else {
369 return Terminate(1003, "got data on server before response");
370 }
371 }
372
373 // Message processing
374 while (!data.empty()) {
375 if (m_frameSize == UINT64_MAX) {
376 // Need at least two bytes to determine header length
377 if (m_header.size() < 2u) {
378 size_t toCopy = (std::min)(2u - m_header.size(), data.size());
Austin Schuh812d0d12021-11-04 20:16:48 -0700379 m_header.append(data.data(), data.data() + toCopy);
380 data.remove_prefix(toCopy);
381 if (m_header.size() < 2u) {
382 return; // need more data
383 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800384
385 // Validate RSV bits are zero
Austin Schuh812d0d12021-11-04 20:16:48 -0700386 if ((m_header[0] & 0x70) != 0) {
387 return Fail(1002, "nonzero RSV");
388 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800389 }
390
391 // Once we have first two bytes, we can calculate the header size
392 if (m_headerSize == 0) {
393 m_headerSize = 2;
394 uint8_t len = m_header[1] & kLenMask;
Austin Schuh812d0d12021-11-04 20:16:48 -0700395 if (len == 126) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800396 m_headerSize += 2;
Austin Schuh812d0d12021-11-04 20:16:48 -0700397 } else if (len == 127) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800398 m_headerSize += 8;
Austin Schuh812d0d12021-11-04 20:16:48 -0700399 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800400 bool masking = (m_header[1] & kFlagMasking) != 0;
Austin Schuh812d0d12021-11-04 20:16:48 -0700401 if (masking) {
402 m_headerSize += 4; // masking key
403 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800404 // On server side, incoming messages MUST be masked
405 // On client side, incoming messages MUST NOT be masked
Austin Schuh812d0d12021-11-04 20:16:48 -0700406 if (m_server && !masking) {
407 return Fail(1002, "client data not masked");
408 }
409 if (!m_server && masking) {
410 return Fail(1002, "server data masked");
411 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800412 }
413
414 // Need to complete header to calculate message size
415 if (m_header.size() < m_headerSize) {
416 size_t toCopy = (std::min)(m_headerSize - m_header.size(), data.size());
Austin Schuh812d0d12021-11-04 20:16:48 -0700417 m_header.append(data.data(), data.data() + toCopy);
418 data.remove_prefix(toCopy);
419 if (m_header.size() < m_headerSize) {
420 return; // need more data
421 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800422 }
423
424 if (m_header.size() >= m_headerSize) {
425 // get payload length
426 uint8_t len = m_header[1] & kLenMask;
Austin Schuh812d0d12021-11-04 20:16:48 -0700427 if (len == 126) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800428 m_frameSize = (static_cast<uint16_t>(m_header[2]) << 8) |
429 static_cast<uint16_t>(m_header[3]);
Austin Schuh812d0d12021-11-04 20:16:48 -0700430 } else if (len == 127) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800431 m_frameSize = (static_cast<uint64_t>(m_header[2]) << 56) |
432 (static_cast<uint64_t>(m_header[3]) << 48) |
433 (static_cast<uint64_t>(m_header[4]) << 40) |
434 (static_cast<uint64_t>(m_header[5]) << 32) |
435 (static_cast<uint64_t>(m_header[6]) << 24) |
436 (static_cast<uint64_t>(m_header[7]) << 16) |
437 (static_cast<uint64_t>(m_header[8]) << 8) |
438 static_cast<uint64_t>(m_header[9]);
Austin Schuh812d0d12021-11-04 20:16:48 -0700439 } else {
Brian Silverman8fce7482020-01-05 13:18:21 -0800440 m_frameSize = len;
Austin Schuh812d0d12021-11-04 20:16:48 -0700441 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800442
443 // limit maximum size
Austin Schuh812d0d12021-11-04 20:16:48 -0700444 if ((m_payload.size() + m_frameSize) > m_maxMessageSize) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800445 return Fail(1009, "message too large");
Austin Schuh812d0d12021-11-04 20:16:48 -0700446 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800447 }
448 }
449
450 if (m_frameSize != UINT64_MAX) {
451 size_t need = m_frameStart + m_frameSize - m_payload.size();
452 size_t toCopy = (std::min)(need, data.size());
Austin Schuh812d0d12021-11-04 20:16:48 -0700453 m_payload.append(data.data(), data.data() + toCopy);
454 data.remove_prefix(toCopy);
Brian Silverman8fce7482020-01-05 13:18:21 -0800455 need -= toCopy;
456 if (need == 0) {
457 // We have a complete frame
458 // If the message had masking, unmask it
459 if ((m_header[1] & kFlagMasking) != 0) {
460 uint8_t key[4] = {
461 m_header[m_headerSize - 4], m_header[m_headerSize - 3],
462 m_header[m_headerSize - 2], m_header[m_headerSize - 1]};
463 int n = 0;
Austin Schuh812d0d12021-11-04 20:16:48 -0700464 for (uint8_t& ch : span{m_payload}.subspan(m_frameStart)) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800465 ch ^= key[n++];
Austin Schuh812d0d12021-11-04 20:16:48 -0700466 if (n >= 4) {
467 n = 0;
468 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800469 }
470 }
471
472 // Handle message
473 bool fin = (m_header[0] & kFlagFin) != 0;
474 uint8_t opcode = m_header[0] & kOpMask;
475 switch (opcode) {
476 case kOpCont:
477 switch (m_fragmentOpcode) {
478 case kOpText:
Austin Schuh812d0d12021-11-04 20:16:48 -0700479 if (!m_combineFragments || fin) {
480 text(std::string_view{reinterpret_cast<char*>(
481 m_payload.data()),
482 m_payload.size()},
Brian Silverman8fce7482020-01-05 13:18:21 -0800483 fin);
Austin Schuh812d0d12021-11-04 20:16:48 -0700484 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800485 break;
486 case kOpBinary:
Austin Schuh812d0d12021-11-04 20:16:48 -0700487 if (!m_combineFragments || fin) {
488 binary(m_payload, fin);
489 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800490 break;
491 default:
492 // no preceding message?
493 return Fail(1002, "invalid continuation message");
494 }
Austin Schuh812d0d12021-11-04 20:16:48 -0700495 if (fin) {
496 m_fragmentOpcode = 0;
497 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800498 break;
499 case kOpText:
Austin Schuh812d0d12021-11-04 20:16:48 -0700500 if (m_fragmentOpcode != 0) {
501 return Fail(1002, "incomplete fragment");
502 }
503 if (!m_combineFragments || fin) {
504 text(std::string_view{reinterpret_cast<char*>(m_payload.data()),
505 m_payload.size()},
Brian Silverman8fce7482020-01-05 13:18:21 -0800506 fin);
Austin Schuh812d0d12021-11-04 20:16:48 -0700507 }
508 if (!fin) {
509 m_fragmentOpcode = opcode;
510 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800511 break;
512 case kOpBinary:
Austin Schuh812d0d12021-11-04 20:16:48 -0700513 if (m_fragmentOpcode != 0) {
514 return Fail(1002, "incomplete fragment");
515 }
516 if (!m_combineFragments || fin) {
517 binary(m_payload, fin);
518 }
519 if (!fin) {
520 m_fragmentOpcode = opcode;
521 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800522 break;
523 case kOpClose: {
524 uint16_t code;
Austin Schuh812d0d12021-11-04 20:16:48 -0700525 std::string_view reason;
Brian Silverman8fce7482020-01-05 13:18:21 -0800526 if (!fin) {
527 code = 1002;
528 reason = "cannot fragment control frames";
529 } else if (m_payload.size() < 2) {
530 code = 1005;
531 } else {
532 code = (static_cast<uint16_t>(m_payload[0]) << 8) |
533 static_cast<uint16_t>(m_payload[1]);
Austin Schuh812d0d12021-11-04 20:16:48 -0700534 reason = drop_front(
535 {reinterpret_cast<char*>(m_payload.data()), m_payload.size()},
536 2);
Brian Silverman8fce7482020-01-05 13:18:21 -0800537 }
538 // Echo the close if we didn't previously send it
Austin Schuh812d0d12021-11-04 20:16:48 -0700539 if (m_state != CLOSING) {
540 SendClose(code, reason);
541 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800542 SetClosed(code, reason);
543 // If we're the server, shutdown the connection.
Austin Schuh812d0d12021-11-04 20:16:48 -0700544 if (m_server) {
545 Shutdown();
546 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800547 break;
548 }
549 case kOpPing:
Austin Schuh812d0d12021-11-04 20:16:48 -0700550 if (!fin) {
551 return Fail(1002, "cannot fragment control frames");
552 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800553 ping(m_payload);
554 break;
555 case kOpPong:
Austin Schuh812d0d12021-11-04 20:16:48 -0700556 if (!fin) {
557 return Fail(1002, "cannot fragment control frames");
558 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800559 pong(m_payload);
560 break;
561 default:
562 return Fail(1002, "invalid message opcode");
563 }
564
565 // Prepare for next message
566 m_header.clear();
567 m_headerSize = 0;
Austin Schuh812d0d12021-11-04 20:16:48 -0700568 if (!m_combineFragments || fin) {
569 m_payload.clear();
570 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800571 m_frameStart = m_payload.size();
572 m_frameSize = UINT64_MAX;
573 }
574 }
575 }
576}
577
578void WebSocket::Send(
Austin Schuh812d0d12021-11-04 20:16:48 -0700579 uint8_t opcode, span<const uv::Buffer> data,
580 std::function<void(span<uv::Buffer>, uv::Error)> callback) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800581 // If we're not open, emit an error and don't send the data
582 if (m_state != OPEN) {
583 int err;
Austin Schuh812d0d12021-11-04 20:16:48 -0700584 if (m_state == CONNECTING) {
Brian Silverman8fce7482020-01-05 13:18:21 -0800585 err = UV_EAGAIN;
Austin Schuh812d0d12021-11-04 20:16:48 -0700586 } else {
Brian Silverman8fce7482020-01-05 13:18:21 -0800587 err = UV_ESHUTDOWN;
Austin Schuh812d0d12021-11-04 20:16:48 -0700588 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800589 SmallVector<uv::Buffer, 4> bufs{data.begin(), data.end()};
590 callback(bufs, uv::Error{err});
591 return;
592 }
593
Austin Schuh812d0d12021-11-04 20:16:48 -0700594 auto req = std::make_shared<WebSocketWriteReq>(std::move(callback));
Brian Silverman8fce7482020-01-05 13:18:21 -0800595 raw_uv_ostream os{req->m_bufs, 4096};
596
597 // opcode (includes FIN bit)
598 os << static_cast<unsigned char>(opcode);
599
600 // payload length
601 uint64_t size = 0;
Austin Schuh812d0d12021-11-04 20:16:48 -0700602 for (auto&& buf : data) {
603 size += buf.len;
604 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800605 if (size < 126) {
606 os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | size);
607 } else if (size <= 0xffff) {
608 os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | 126);
609 const uint8_t sizeMsb[] = {static_cast<uint8_t>((size >> 8) & 0xff),
610 static_cast<uint8_t>(size & 0xff)};
Austin Schuh812d0d12021-11-04 20:16:48 -0700611 os << span{sizeMsb};
Brian Silverman8fce7482020-01-05 13:18:21 -0800612 } else {
613 os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | 127);
614 const uint8_t sizeMsb[] = {static_cast<uint8_t>((size >> 56) & 0xff),
615 static_cast<uint8_t>((size >> 48) & 0xff),
616 static_cast<uint8_t>((size >> 40) & 0xff),
617 static_cast<uint8_t>((size >> 32) & 0xff),
618 static_cast<uint8_t>((size >> 24) & 0xff),
619 static_cast<uint8_t>((size >> 16) & 0xff),
620 static_cast<uint8_t>((size >> 8) & 0xff),
621 static_cast<uint8_t>(size & 0xff)};
Austin Schuh812d0d12021-11-04 20:16:48 -0700622 os << span{sizeMsb};
Brian Silverman8fce7482020-01-05 13:18:21 -0800623 }
624
625 // clients need to mask the input data
626 if (!m_server) {
627 // generate masking key
628 static std::random_device rd;
629 static std::default_random_engine gen{rd()};
630 std::uniform_int_distribution<unsigned int> dist(0, 255);
631 uint8_t key[4];
Austin Schuh812d0d12021-11-04 20:16:48 -0700632 for (uint8_t& v : key) {
633 v = dist(gen);
634 }
635 os << span<const uint8_t>{key, 4};
Brian Silverman8fce7482020-01-05 13:18:21 -0800636 // copy and mask data
637 int n = 0;
638 for (auto&& buf : data) {
639 for (auto&& ch : buf.data()) {
640 os << static_cast<unsigned char>(static_cast<uint8_t>(ch) ^ key[n++]);
Austin Schuh812d0d12021-11-04 20:16:48 -0700641 if (n >= 4) {
642 n = 0;
643 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800644 }
645 }
646 req->m_startUser = req->m_bufs.size();
647 req->m_bufs.append(data.begin(), data.end());
648 // don't send the user bufs as we copied their data
Austin Schuh812d0d12021-11-04 20:16:48 -0700649 m_stream.Write(span{req->m_bufs}.subspan(0, req->m_startUser), req);
Brian Silverman8fce7482020-01-05 13:18:21 -0800650 } else {
651 // servers can just send the buffers directly without masking
652 req->m_startUser = req->m_bufs.size();
653 req->m_bufs.append(data.begin(), data.end());
654 m_stream.Write(req->m_bufs, req);
655 }
656}