blob: 43b901ecb680ef0449259db4dd8b17c018845993 [file] [log] [blame]
James Kuszmaulcf324122023-01-14 14:07:17 -08001// Copyright (c) FIRST and other WPILib contributors.
2// Open Source Software; you can modify and/or share it under the terms of
3// the WPILib BSD license file in the root directory of this project.
4
5#include "wpinet/WebSocket.h"
6
7#include <random>
James Kuszmaulb13e13f2023-11-22 20:44:04 -08008#include <span>
9#include <string>
10#include <string_view>
James Kuszmaulcf324122023-01-14 14:07:17 -080011
12#include <fmt/format.h>
13#include <wpi/Base64.h>
14#include <wpi/SmallString.h>
15#include <wpi/SmallVector.h>
16#include <wpi/StringExtras.h>
17#include <wpi/raw_ostream.h>
18#include <wpi/sha1.h>
19
James Kuszmaulb13e13f2023-11-22 20:44:04 -080020#include "WebSocketDebug.h"
21#include "WebSocketSerializer.h"
James Kuszmaulcf324122023-01-14 14:07:17 -080022#include "wpinet/HttpParser.h"
23#include "wpinet/raw_uv_ostream.h"
24#include "wpinet/uv/Stream.h"
25
26using namespace wpi;
27
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG
// Render a byte buffer for verbose logging as comma-terminated two-digit hex
// ("0a,ff,"). Payload bytes are only rendered when
// WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT is also defined; otherwise an empty
// string is returned.
static std::string DebugBinary([[maybe_unused]] std::span<const uint8_t> val) {
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
  std::string hex;
  for (uint8_t byte : val) {
    hex += fmt::format("{:02x},", static_cast<unsigned int>(byte) & 0xff);
  }
  return hex;
#else
  return "";
#endif
}

// Pass text payloads through to the log only when content logging is enabled;
// otherwise log an empty string.
static inline std::string_view DebugText(
    [[maybe_unused]] std::string_view val) {
#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
  return val;
#else
  return "";
#endif
}
#endif  // WPINET_WEBSOCKET_VERBOSE_DEBUG
50
51class WebSocket::WriteReq : public uv::WriteReq,
52 public detail::WebSocketWriteReqBase {
James Kuszmaulcf324122023-01-14 14:07:17 -080053 public:
James Kuszmaulb13e13f2023-11-22 20:44:04 -080054 explicit WriteReq(
55 std::weak_ptr<WebSocket> ws,
James Kuszmaulcf324122023-01-14 14:07:17 -080056 std::function<void(std::span<uv::Buffer>, uv::Error)> callback)
James Kuszmaulb13e13f2023-11-22 20:44:04 -080057 : m_ws{std::move(ws)}, m_callback{std::move(callback)} {
58 finish.connect([this](uv::Error err) { Send(err); });
James Kuszmaulcf324122023-01-14 14:07:17 -080059 }
60
James Kuszmaulb13e13f2023-11-22 20:44:04 -080061 void Send(uv::Error err) {
62 auto ws = m_ws.lock();
63 if (!ws || err) {
64 WS_DEBUG("no WS or error, calling callback\n");
65 m_frames.ReleaseBufs();
66 m_callback(m_userBufs, err);
67 return;
68 }
69
70 // Continue() is designed so this is *only* called on frame boundaries
71 if (m_controlCont) {
72 // We have a control frame; switch to it. We will come back here via
73 // the control frame's m_cont when it's done.
74 WS_DEBUG("Continuing with a control write\n");
75 auto controlCont = std::move(m_controlCont);
76 m_controlCont.reset();
77 return controlCont->Send({});
78 }
79 int result = Continue(ws->m_stream, shared_from_this());
80 WS_DEBUG("Continue() -> {}\n", result);
81 if (result <= 0) {
82 m_frames.ReleaseBufs();
83 m_callback(m_userBufs, uv::Error{result});
84 if (result == 0 && m_cont) {
85 WS_DEBUG("Continuing with another write\n");
86 ws->m_curWriteReq = m_cont;
87 return m_cont->Send({});
88 } else {
89 ws->m_writeInProgress = false;
90 ws->m_curWriteReq.reset();
91 ws->m_lastWriteReq.reset();
92 }
93 }
94 }
95
96 std::weak_ptr<WebSocket> m_ws;
James Kuszmaulcf324122023-01-14 14:07:17 -080097 std::function<void(std::span<uv::Buffer>, uv::Error)> m_callback;
James Kuszmaulb13e13f2023-11-22 20:44:04 -080098 std::shared_ptr<WriteReq> m_cont;
99 std::shared_ptr<WriteReq> m_controlCont;
James Kuszmaulcf324122023-01-14 14:07:17 -0800100};
James Kuszmaulcf324122023-01-14 14:07:17 -0800101
// Second header byte: MASK flag and 7-bit payload length field
// (RFC 6455 section 5.2).
static constexpr uint8_t kFlagMasking{0x80};
static constexpr uint8_t kLenMask{0x7f};
// Allocation size handed to raw_uv_ostream when serializing outgoing data.
static constexpr size_t kWriteAllocSize{4096};
James Kuszmaulcf324122023-01-14 14:07:17 -0800105
106class WebSocket::ClientHandshakeData {
107 public:
108 ClientHandshakeData() {
109 // key is a random nonce
110 static std::random_device rd;
111 static std::default_random_engine gen{rd()};
112 std::uniform_int_distribution<unsigned int> dist(0, 255);
113 char nonce[16]; // the nonce sent to the server
114 for (char& v : nonce) {
115 v = static_cast<char>(dist(gen));
116 }
117 raw_svector_ostream os(key);
118 Base64Encode(os, {nonce, 16});
119 }
120 ~ClientHandshakeData() {
121 if (auto t = timer.lock()) {
122 t->Stop();
123 t->Close();
124 }
125 }
126
127 SmallString<64> key; // the key sent to the server
128 SmallVector<std::string, 2> protocols; // valid protocols
129 HttpParser parser{HttpParser::kResponse}; // server response parser
130 bool hasUpgrade = false;
131 bool hasConnection = false;
132 bool hasAccept = false;
133 bool hasProtocol = false;
134
135 std::weak_ptr<uv::Timer> timer;
136};
137
138static std::string_view AcceptHash(std::string_view key,
139 SmallVectorImpl<char>& buf) {
140 SHA1 hash;
141 hash.Update(key);
142 hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
143 SmallString<64> hashBuf;
144 return Base64Encode(hash.RawFinal(hashBuf), buf);
145}
146
147WebSocket::WebSocket(uv::Stream& stream, bool server, const private_init&)
148 : m_stream{stream}, m_server{server} {
149 // Connect closed and error signals to ourselves
150 m_stream.closed.connect([this]() { SetClosed(1006, "handle closed"); });
151 m_stream.error.connect([this](uv::Error err) {
152 Terminate(1006, fmt::format("stream error: {}", err.name()));
153 });
154
155 // Start reading
156 m_stream.StopRead(); // we may have been reading
157 m_stream.StartRead();
158 m_stream.data.connect(
159 [this](uv::Buffer& buf, size_t size) { HandleIncoming(buf, size); });
160 m_stream.end.connect(
161 [this]() { Terminate(1006, "remote end closed connection"); });
162}
163
164WebSocket::~WebSocket() = default;
165
166std::shared_ptr<WebSocket> WebSocket::CreateClient(
167 uv::Stream& stream, std::string_view uri, std::string_view host,
168 std::span<const std::string_view> protocols, const ClientOptions& options) {
169 auto ws = std::make_shared<WebSocket>(stream, false, private_init{});
170 stream.SetData(ws);
171 ws->StartClient(uri, host, protocols, options);
172 return ws;
173}
174
175std::shared_ptr<WebSocket> WebSocket::CreateServer(uv::Stream& stream,
176 std::string_view key,
177 std::string_view version,
178 std::string_view protocol) {
179 auto ws = std::make_shared<WebSocket>(stream, true, private_init{});
180 stream.SetData(ws);
181 ws->StartServer(key, version, protocol);
182 return ws;
183}
184
185void WebSocket::Close(uint16_t code, std::string_view reason) {
186 SendClose(code, reason);
187 if (m_state != FAILED && m_state != CLOSED) {
188 m_state = CLOSING;
189 }
190}
191
192void WebSocket::Fail(uint16_t code, std::string_view reason) {
193 if (m_state == FAILED || m_state == CLOSED) {
194 return;
195 }
196 SendClose(code, reason);
197 SetClosed(code, reason, true);
198 Shutdown();
199}
200
201void WebSocket::Terminate(uint16_t code, std::string_view reason) {
202 if (m_state == FAILED || m_state == CLOSED) {
203 return;
204 }
205 SetClosed(code, reason);
206 Shutdown();
207}
208
209void WebSocket::StartClient(std::string_view uri, std::string_view host,
210 std::span<const std::string_view> protocols,
211 const ClientOptions& options) {
212 // Create client handshake data
213 m_clientHandshake = std::make_unique<ClientHandshakeData>();
214
215 // Build client request
216 SmallVector<uv::Buffer, 4> bufs;
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800217 raw_uv_ostream os{bufs, kWriteAllocSize};
James Kuszmaulcf324122023-01-14 14:07:17 -0800218
219 os << "GET " << uri << " HTTP/1.1\r\n";
220 os << "Host: " << host << "\r\n";
221 os << "Upgrade: websocket\r\n";
222 os << "Connection: Upgrade\r\n";
223 os << "Sec-WebSocket-Key: " << m_clientHandshake->key << "\r\n";
224 os << "Sec-WebSocket-Version: 13\r\n";
225
226 // protocols (if provided)
227 if (!protocols.empty()) {
228 os << "Sec-WebSocket-Protocol: ";
229 bool first = true;
230 for (auto protocol : protocols) {
231 if (!first) {
232 os << ", ";
233 } else {
234 first = false;
235 }
236 os << protocol;
237 // also save for later checking against server response
238 m_clientHandshake->protocols.emplace_back(protocol);
239 }
240 os << "\r\n";
241 }
242
243 // other headers
244 for (auto&& header : options.extraHeaders) {
245 os << header.first << ": " << header.second << "\r\n";
246 }
247
248 // finish headers
249 os << "\r\n";
250
251 // Send client request
252 m_stream.Write(bufs, [](auto bufs, uv::Error) {
253 for (auto& buf : bufs) {
254 buf.Deallocate();
255 }
256 });
257
258 // Set up client response handling
259 m_clientHandshake->parser.status.connect([this](std::string_view status) {
260 unsigned int code = m_clientHandshake->parser.GetStatusCode();
261 if (code != 101) {
262 Terminate(code, status);
263 }
264 });
265 m_clientHandshake->parser.header.connect(
266 [this](std::string_view name, std::string_view value) {
267 value = trim(value);
268 if (equals_lower(name, "upgrade")) {
269 if (!equals_lower(value, "websocket")) {
270 return Terminate(1002, "invalid upgrade response value");
271 }
272 m_clientHandshake->hasUpgrade = true;
273 } else if (equals_lower(name, "connection")) {
274 if (!equals_lower(value, "upgrade")) {
275 return Terminate(1002, "invalid connection response value");
276 }
277 m_clientHandshake->hasConnection = true;
278 } else if (equals_lower(name, "sec-websocket-accept")) {
279 // Check against expected response
280 SmallString<64> acceptBuf;
281 if (!equals(value, AcceptHash(m_clientHandshake->key, acceptBuf))) {
282 return Terminate(1002, "invalid accept key");
283 }
284 m_clientHandshake->hasAccept = true;
285 } else if (equals_lower(name, "sec-websocket-extensions")) {
286 // No extensions are supported
287 if (!value.empty()) {
288 return Terminate(1010, "unsupported extension");
289 }
290 } else if (equals_lower(name, "sec-websocket-protocol")) {
291 // Make sure it was one of the provided protocols
292 bool match = false;
293 for (auto&& protocol : m_clientHandshake->protocols) {
294 if (equals_lower(value, protocol)) {
295 match = true;
296 break;
297 }
298 }
299 if (!match) {
300 return Terminate(1003, "unsupported protocol");
301 }
302 m_clientHandshake->hasProtocol = true;
303 m_protocol = value;
304 }
305 });
306 m_clientHandshake->parser.headersComplete.connect([this](bool) {
307 if (!m_clientHandshake->hasUpgrade || !m_clientHandshake->hasConnection ||
308 !m_clientHandshake->hasAccept ||
309 (!m_clientHandshake->hasProtocol &&
310 !m_clientHandshake->protocols.empty())) {
311 return Terminate(1002, "invalid response");
312 }
313 if (m_state == CONNECTING) {
314 m_state = OPEN;
315 open(m_protocol);
316 }
317 });
318
319 // Start handshake timer if a timeout was specified
320 if (options.handshakeTimeout != (uv::Timer::Time::max)()) {
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800321 if (auto timer = uv::Timer::Create(m_stream.GetLoopRef())) {
322 timer->timeout.connect(
323 [this]() { Terminate(1006, "connection timed out"); });
324 timer->Start(options.handshakeTimeout);
325 m_clientHandshake->timer = timer;
326 }
James Kuszmaulcf324122023-01-14 14:07:17 -0800327 }
328}
329
330void WebSocket::StartServer(std::string_view key, std::string_view version,
331 std::string_view protocol) {
332 m_protocol = protocol;
333
334 // Build server response
335 SmallVector<uv::Buffer, 4> bufs;
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800336 raw_uv_ostream os{bufs, kWriteAllocSize};
James Kuszmaulcf324122023-01-14 14:07:17 -0800337
338 // Handle unsupported version
339 if (version != "13") {
340 os << "HTTP/1.1 426 Upgrade Required\r\n";
341 os << "Upgrade: WebSocket\r\n";
342 os << "Sec-WebSocket-Version: 13\r\n\r\n";
343 m_stream.Write(bufs, [this](auto bufs, uv::Error) {
344 for (auto& buf : bufs) {
345 buf.Deallocate();
346 }
347 // XXX: Should we support sending a new handshake on the same connection?
348 // XXX: "this->" is required by GCC 5.5 (bug)
349 this->Terminate(1003, "unsupported protocol version");
350 });
351 return;
352 }
353
354 os << "HTTP/1.1 101 Switching Protocols\r\n";
355 os << "Upgrade: websocket\r\n";
356 os << "Connection: Upgrade\r\n";
357
358 // accept hash
359 SmallString<64> acceptBuf;
360 os << "Sec-WebSocket-Accept: " << AcceptHash(key, acceptBuf) << "\r\n";
361
362 if (!protocol.empty()) {
363 os << "Sec-WebSocket-Protocol: " << protocol << "\r\n";
364 }
365
366 // end headers
367 os << "\r\n";
368
369 // Send server response
370 m_stream.Write(bufs, [this](auto bufs, uv::Error) {
371 for (auto& buf : bufs) {
372 buf.Deallocate();
373 }
374 if (m_state == CONNECTING) {
375 m_state = OPEN;
376 open(m_protocol);
377 }
378 });
379}
380
381void WebSocket::SendClose(uint16_t code, std::string_view reason) {
382 SmallVector<uv::Buffer, 4> bufs;
383 if (code != 1005) {
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800384 raw_uv_ostream os{bufs, kWriteAllocSize};
James Kuszmaulcf324122023-01-14 14:07:17 -0800385 const uint8_t codeMsb[] = {static_cast<uint8_t>((code >> 8) & 0xff),
386 static_cast<uint8_t>(code & 0xff)};
387 os << std::span{codeMsb};
388 os << reason;
389 }
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800390 SendControl(kFlagFin | kOpClose, bufs, [](auto bufs, uv::Error) {
James Kuszmaulcf324122023-01-14 14:07:17 -0800391 for (auto&& buf : bufs) {
392 buf.Deallocate();
393 }
394 });
395}
396
397void WebSocket::SetClosed(uint16_t code, std::string_view reason, bool failed) {
398 if (m_state == FAILED || m_state == CLOSED) {
399 return;
400 }
401 m_state = failed ? FAILED : CLOSED;
402 closed(code, reason);
403}
404
405void WebSocket::Shutdown() {
406 m_stream.Shutdown([this] { m_stream.Close(); });
407}
408
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800409static inline void Unmask(std::span<uint8_t> data,
410 std::span<const uint8_t, 4> key) {
411 int n = 0;
412 for (uint8_t& ch : data) {
413 ch ^= key[n++];
414 if (n >= 4) {
415 n = 0;
416 }
417 }
418}
419
James Kuszmaulcf324122023-01-14 14:07:17 -0800420void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
421 // ignore incoming data if we're failed or closed
422 if (m_state == FAILED || m_state == CLOSED) {
423 return;
424 }
425
426 std::string_view data{buf.base, size};
427
428 // Handle connecting state (mainly on client)
429 if (m_state == CONNECTING) {
430 if (m_clientHandshake) {
431 data = m_clientHandshake->parser.Execute(data);
432 // check for parser failure
433 if (m_clientHandshake->parser.HasError()) {
434 return Terminate(1003, "invalid response");
435 }
436 if (m_state != OPEN) {
437 return; // not done with handshake yet
438 }
439
440 // we're done with the handshake, so release its memory
441 m_clientHandshake.reset();
442
443 // fall through to process additional data after handshake
444 } else {
445 return Terminate(1003, "got data on server before response");
446 }
447 }
448
449 // Message processing
450 while (!data.empty()) {
451 if (m_frameSize == UINT64_MAX) {
452 // Need at least two bytes to determine header length
453 if (m_header.size() < 2u) {
454 size_t toCopy = (std::min)(2u - m_header.size(), data.size());
455 m_header.append(data.data(), data.data() + toCopy);
456 data.remove_prefix(toCopy);
457 if (m_header.size() < 2u) {
458 return; // need more data
459 }
460
461 // Validate RSV bits are zero
462 if ((m_header[0] & 0x70) != 0) {
463 return Fail(1002, "nonzero RSV");
464 }
465 }
466
467 // Once we have first two bytes, we can calculate the header size
468 if (m_headerSize == 0) {
469 m_headerSize = 2;
470 uint8_t len = m_header[1] & kLenMask;
471 if (len == 126) {
472 m_headerSize += 2;
473 } else if (len == 127) {
474 m_headerSize += 8;
475 }
476 bool masking = (m_header[1] & kFlagMasking) != 0;
477 if (masking) {
478 m_headerSize += 4; // masking key
479 }
480 // On server side, incoming messages MUST be masked
481 // On client side, incoming messages MUST NOT be masked
482 if (m_server && !masking) {
483 return Fail(1002, "client data not masked");
484 }
485 if (!m_server && masking) {
486 return Fail(1002, "server data masked");
487 }
488 }
489
490 // Need to complete header to calculate message size
491 if (m_header.size() < m_headerSize) {
492 size_t toCopy = (std::min)(m_headerSize - m_header.size(), data.size());
493 m_header.append(data.data(), data.data() + toCopy);
494 data.remove_prefix(toCopy);
495 if (m_header.size() < m_headerSize) {
496 return; // need more data
497 }
498 }
499
500 if (m_header.size() >= m_headerSize) {
501 // get payload length
502 uint8_t len = m_header[1] & kLenMask;
503 if (len == 126) {
504 m_frameSize = (static_cast<uint16_t>(m_header[2]) << 8) |
505 static_cast<uint16_t>(m_header[3]);
506 } else if (len == 127) {
507 m_frameSize = (static_cast<uint64_t>(m_header[2]) << 56) |
508 (static_cast<uint64_t>(m_header[3]) << 48) |
509 (static_cast<uint64_t>(m_header[4]) << 40) |
510 (static_cast<uint64_t>(m_header[5]) << 32) |
511 (static_cast<uint64_t>(m_header[6]) << 24) |
512 (static_cast<uint64_t>(m_header[7]) << 16) |
513 (static_cast<uint64_t>(m_header[8]) << 8) |
514 static_cast<uint64_t>(m_header[9]);
515 } else {
516 m_frameSize = len;
517 }
518
519 // limit maximum size
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800520 bool control = (m_header[0] & kFlagControl) != 0;
521 if (((control ? m_controlPayload.size() : m_payload.size()) +
522 m_frameSize) > m_maxMessageSize) {
James Kuszmaulcf324122023-01-14 14:07:17 -0800523 return Fail(1009, "message too large");
524 }
525 }
526 }
527
528 if (m_frameSize != UINT64_MAX) {
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800529 bool control = (m_header[0] & kFlagControl) != 0;
530 size_t need;
531 if (control) {
532 need = m_frameSize - m_controlPayload.size();
533 } else {
534 need = m_frameStart + m_frameSize - m_payload.size();
535 }
James Kuszmaulcf324122023-01-14 14:07:17 -0800536 size_t toCopy = (std::min)(need, data.size());
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800537 if (control) {
538 m_controlPayload.append(data.data(), data.data() + toCopy);
539 } else {
540 m_payload.append(data.data(), data.data() + toCopy);
541 }
James Kuszmaulcf324122023-01-14 14:07:17 -0800542 data.remove_prefix(toCopy);
543 need -= toCopy;
544 if (need == 0) {
545 // We have a complete frame
546 // If the message had masking, unmask it
547 if ((m_header[1] & kFlagMasking) != 0) {
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800548 Unmask(control ? std::span{m_controlPayload}
549 : std::span{m_payload}.subspan(m_frameStart),
550 std::span<const uint8_t, 4>{&m_header[m_headerSize - 4], 4});
James Kuszmaulcf324122023-01-14 14:07:17 -0800551 }
552
553 // Handle message
554 bool fin = (m_header[0] & kFlagFin) != 0;
555 uint8_t opcode = m_header[0] & kOpMask;
556 switch (opcode) {
557 case kOpCont:
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800558 WS_DEBUG("WS Fragment {} [{}]\n", m_payload.size(),
559 DebugBinary(m_payload));
James Kuszmaulcf324122023-01-14 14:07:17 -0800560 switch (m_fragmentOpcode) {
561 case kOpText:
562 if (!m_combineFragments || fin) {
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800563 std::string_view content{
564 reinterpret_cast<char*>(m_payload.data()),
565 m_payload.size()};
566 WS_DEBUG("WS RecvText(Defrag) {} ({})\n", m_payload.size(),
567 DebugText(content));
568 text(content, fin);
James Kuszmaulcf324122023-01-14 14:07:17 -0800569 }
570 break;
571 case kOpBinary:
572 if (!m_combineFragments || fin) {
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800573 WS_DEBUG("WS RecvBinary(Defrag) {} ({})\n", m_payload.size(),
574 DebugBinary(m_payload));
James Kuszmaulcf324122023-01-14 14:07:17 -0800575 binary(m_payload, fin);
576 }
577 break;
578 default:
579 // no preceding message?
580 return Fail(1002, "invalid continuation message");
581 }
582 if (fin) {
583 m_fragmentOpcode = 0;
584 }
585 break;
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800586 case kOpText: {
587 std::string_view content{reinterpret_cast<char*>(m_payload.data()),
588 m_payload.size()};
James Kuszmaulcf324122023-01-14 14:07:17 -0800589 if (m_fragmentOpcode != 0) {
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800590 WS_DEBUG("WS RecvText {} ({}) -> INCOMPLETE FRAGMENT\n",
591 m_payload.size(), DebugText(content));
James Kuszmaulcf324122023-01-14 14:07:17 -0800592 return Fail(1002, "incomplete fragment");
593 }
594 if (!m_combineFragments || fin) {
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800595 WS_DEBUG("WS RecvText {} ({})\n", m_payload.size(),
596 DebugText(content));
597 text(content, fin);
James Kuszmaulcf324122023-01-14 14:07:17 -0800598 }
599 if (!fin) {
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800600 WS_DEBUG("WS RecvText {} StartFrag\n", m_payload.size());
James Kuszmaulcf324122023-01-14 14:07:17 -0800601 m_fragmentOpcode = opcode;
602 }
603 break;
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800604 }
James Kuszmaulcf324122023-01-14 14:07:17 -0800605 case kOpBinary:
606 if (m_fragmentOpcode != 0) {
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800607 WS_DEBUG("WS RecvBinary {} ({}) -> INCOMPLETE FRAGMENT\n",
608 m_payload.size(), DebugBinary(m_payload));
James Kuszmaulcf324122023-01-14 14:07:17 -0800609 return Fail(1002, "incomplete fragment");
610 }
611 if (!m_combineFragments || fin) {
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800612 WS_DEBUG("WS RecvBinary {} ({})\n", m_payload.size(),
613 DebugBinary(m_payload));
James Kuszmaulcf324122023-01-14 14:07:17 -0800614 binary(m_payload, fin);
615 }
616 if (!fin) {
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800617 WS_DEBUG("WS RecvBinary {} StartFrag\n", m_payload.size());
James Kuszmaulcf324122023-01-14 14:07:17 -0800618 m_fragmentOpcode = opcode;
619 }
620 break;
621 case kOpClose: {
622 uint16_t code;
623 std::string_view reason;
624 if (!fin) {
625 code = 1002;
626 reason = "cannot fragment control frames";
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800627 } else if (m_controlPayload.size() < 2) {
James Kuszmaulcf324122023-01-14 14:07:17 -0800628 code = 1005;
629 } else {
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800630 code = (static_cast<uint16_t>(m_controlPayload[0]) << 8) |
631 static_cast<uint16_t>(m_controlPayload[1]);
632 reason =
633 drop_front({reinterpret_cast<char*>(m_controlPayload.data()),
634 m_controlPayload.size()},
635 2);
James Kuszmaulcf324122023-01-14 14:07:17 -0800636 }
637 // Echo the close if we didn't previously send it
638 if (m_state != CLOSING) {
639 SendClose(code, reason);
640 }
641 SetClosed(code, reason);
642 // If we're the server, shutdown the connection.
643 if (m_server) {
644 Shutdown();
645 }
646 break;
647 }
648 case kOpPing:
649 if (!fin) {
650 return Fail(1002, "cannot fragment control frames");
651 }
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800652 // If the connection is open, send a Pong in response
653 if (m_state == OPEN) {
654 SmallVector<uv::Buffer, 4> bufs;
655 {
656 raw_uv_ostream os{bufs, kWriteAllocSize};
657 os << m_controlPayload;
658 }
659 SendPong(bufs, [](auto bufs, uv::Error) {
660 for (auto&& buf : bufs) {
661 buf.Deallocate();
662 }
663 });
664 }
665 WS_DEBUG("WS RecvPing() {} ({})\n", m_controlPayload.size(),
666 DebugBinary(m_controlPayload));
667 ping(m_controlPayload);
James Kuszmaulcf324122023-01-14 14:07:17 -0800668 break;
669 case kOpPong:
670 if (!fin) {
671 return Fail(1002, "cannot fragment control frames");
672 }
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800673 WS_DEBUG("WS RecvPong() {} ({})\n", m_controlPayload.size(),
674 DebugBinary(m_controlPayload));
675 pong(m_controlPayload);
James Kuszmaulcf324122023-01-14 14:07:17 -0800676 break;
677 default:
678 return Fail(1002, "invalid message opcode");
679 }
680
681 // Prepare for next message
682 m_header.clear();
683 m_headerSize = 0;
684 if (!m_combineFragments || fin) {
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800685 if (control) {
686 m_controlPayload.clear();
687 } else {
688 m_payload.clear();
689 }
James Kuszmaulcf324122023-01-14 14:07:17 -0800690 }
691 m_frameStart = m_payload.size();
692 m_frameSize = UINT64_MAX;
693 }
694 }
695 }
696}
697
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800698static void VerboseDebug(const WebSocket::Frame& frame) {
James Kuszmaulcf324122023-01-14 14:07:17 -0800699#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800700 if ((frame.opcode & 0x7f) == 0x01) {
James Kuszmaulcf324122023-01-14 14:07:17 -0800701 SmallString<128> str;
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800702#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
703 for (auto&& d : frame.data) {
James Kuszmaulcf324122023-01-14 14:07:17 -0800704 str.append(std::string_view(d.base, d.len));
705 }
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800706#endif
James Kuszmaulcf324122023-01-14 14:07:17 -0800707 fmt::print("WS SendText({})\n", str.str());
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800708 } else if ((frame.opcode & 0x7f) == 0x02) {
James Kuszmaulcf324122023-01-14 14:07:17 -0800709 SmallString<128> str;
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800710#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
James Kuszmaulcf324122023-01-14 14:07:17 -0800711 raw_svector_ostream stros{str};
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800712 for (auto&& d : frame.data) {
James Kuszmaulcf324122023-01-14 14:07:17 -0800713 for (auto ch : d.data()) {
714 stros << fmt::format("{:02x},", static_cast<unsigned int>(ch) & 0xff);
715 }
716 }
James Kuszmaulcf324122023-01-14 14:07:17 -0800717#endif
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800718 fmt::print("WS SendBinary({})\n", str.str());
James Kuszmaulcf324122023-01-14 14:07:17 -0800719 } else {
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800720 SmallString<128> str;
721#ifdef WPINET_WEBSOCKET_VERBOSE_DEBUG_CONTENT
722 raw_svector_ostream stros{str};
723 for (auto&& d : frame.data) {
724 for (auto ch : d.data()) {
725 stros << fmt::format("{:02x},", static_cast<unsigned int>(ch) & 0xff);
James Kuszmaulcf324122023-01-14 14:07:17 -0800726 }
727 }
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800728#endif
729 fmt::print("WS SendOp({}, {})\n", frame.opcode, str.str());
James Kuszmaulcf324122023-01-14 14:07:17 -0800730 }
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800731#endif
James Kuszmaulcf324122023-01-14 14:07:17 -0800732}
733
734void WebSocket::SendFrames(
735 std::span<const Frame> frames,
736 std::function<void(std::span<uv::Buffer>, uv::Error)> callback) {
737 // If we're not open, emit an error and don't send the data
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800738 WS_DEBUG("SendFrames({})\n", frames.size());
James Kuszmaulcf324122023-01-14 14:07:17 -0800739 if (m_state != OPEN) {
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800740 SendError(frames, callback);
James Kuszmaulcf324122023-01-14 14:07:17 -0800741 return;
742 }
743
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800744 // Build request
745 auto req = std::make_shared<WriteReq>(weak_from_this(), std::move(callback));
746 int numBytes = 0;
747 for (auto&& frame : frames) {
748 VerboseDebug(frame);
749 numBytes += req->m_frames.AddFrame(frame, m_server);
750 req->m_continueFrameOffs.emplace_back(numBytes);
751 req->m_userBufs.append(frame.data.begin(), frame.data.end());
752 }
753
754 if (m_writeInProgress) {
755 if (auto lastReq = m_lastWriteReq.lock()) {
756 // if write currently in progress, process as a continuation of that
757 m_lastWriteReq = req;
758 // make sure we're really at the end
759 while (lastReq->m_cont) {
760 lastReq = lastReq->m_cont;
761 }
762 lastReq->m_cont = std::move(req);
763 return;
764 }
765 }
766
767 m_writeInProgress = true;
768 m_curWriteReq = req;
769 m_lastWriteReq = req;
770 req->Send({});
771}
772
773std::span<const WebSocket::Frame> WebSocket::TrySendFrames(
774 std::span<const Frame> frames,
775 std::function<void(std::span<uv::Buffer>, uv::Error)> callback) {
776 // If we're not open, emit an error and don't send the data
777 if (m_state != WebSocket::OPEN) {
778 SendError(frames, callback);
779 return {};
780 }
781
782 // If something else is still in flight, don't send anything
783 if (m_writeInProgress) {
784 return frames;
785 }
786
787 return detail::TrySendFrames(
788 m_server, m_stream, frames,
789 [this](std::function<void(std::span<uv::Buffer>, uv::Error)>&& cb) {
790 auto req = std::make_shared<WriteReq>(weak_from_this(), std::move(cb));
791 m_writeInProgress = true;
792 m_curWriteReq = req;
793 m_lastWriteReq = req;
794 return req;
795 },
796 std::move(callback));
797}
798
799void WebSocket::SendControl(
800 uint8_t opcode, std::span<const uv::Buffer> data,
801 std::function<void(std::span<uv::Buffer>, uv::Error)> callback) {
802 Frame frame{opcode, data};
803 // If we're not open, emit an error and don't send the data
804 if (m_state != WebSocket::OPEN) {
805 SendError({{frame}}, callback);
806 return;
807 }
808
809 // If nothing else is in flight, just use SendFrames()
810 std::shared_ptr<WriteReq> curReq = m_curWriteReq.lock();
811 if (!m_writeInProgress || !curReq) {
812 return SendFrames({{frame}}, std::move(callback));
813 }
814
815 // There's a write request in flight, but since this is a control frame, we
816 // want to send it as soon as we can, without waiting for all frames in that
817 // request (or any continuations) to be sent.
818 auto req = std::make_shared<WriteReq>(weak_from_this(), std::move(callback));
819 VerboseDebug(frame);
820 size_t numBytes = req->m_frames.AddFrame(frame, m_server);
821 req->m_userBufs.append(frame.data.begin(), frame.data.end());
822 req->m_continueFrameOffs.emplace_back(numBytes);
823 req->m_cont = curReq;
824 // There may be multiple control packets in flight; maintain in-order
825 // transmission. Linear search here is O(n^2), but should be pretty rare.
826 if (!curReq->m_controlCont) {
827 curReq->m_controlCont = std::move(req);
828 } else {
829 curReq = curReq->m_controlCont;
830 while (curReq->m_cont != req->m_cont) {
831 curReq = curReq->m_cont;
832 }
833 curReq->m_cont = std::move(req);
834 }
835}
836
837void WebSocket::SendError(
838 std::span<const Frame> frames,
839 const std::function<void(std::span<uv::Buffer>, uv::Error)>& callback) {
840 int err;
841 if (m_state == WebSocket::CONNECTING) {
842 err = UV_EAGAIN;
843 } else {
844 err = UV_ESHUTDOWN;
845 }
James Kuszmaulcf324122023-01-14 14:07:17 -0800846 SmallVector<uv::Buffer, 4> bufs;
847 for (auto&& frame : frames) {
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800848 bufs.append(frame.data.begin(), frame.data.end());
James Kuszmaulcf324122023-01-14 14:07:17 -0800849 }
James Kuszmaulb13e13f2023-11-22 20:44:04 -0800850 callback(bufs, uv::Error{err});
James Kuszmaulcf324122023-01-14 14:07:17 -0800851}