blob: 82c83bbf44ec0062f16828632f2e1bb5d2369f74 [file] [log] [blame]
Brian Silverman41cdd3e2019-01-19 19:48:58 -08001/*----------------------------------------------------------------------------*/
James Kuszmaul4f3ad3c2019-12-01 16:35:21 -08002/* Copyright (c) 2018-2019 FIRST. All Rights Reserved. */
Brian Silverman41cdd3e2019-01-19 19:48:58 -08003/* Open Source Software - may be modified and shared by FRC teams. The code */
4/* must be accompanied by the FIRST BSD license file in the root directory of */
5/* the project. */
6/*----------------------------------------------------------------------------*/
7
8#include "wpi/WebSocket.h"
9
10#include <random>
11
12#include "wpi/Base64.h"
13#include "wpi/HttpParser.h"
14#include "wpi/SmallString.h"
15#include "wpi/SmallVector.h"
16#include "wpi/raw_uv_ostream.h"
17#include "wpi/sha1.h"
18#include "wpi/uv/Stream.h"
19
20using namespace wpi;
21
22namespace {
23class WebSocketWriteReq : public uv::WriteReq {
24 public:
25 explicit WebSocketWriteReq(
26 std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
27 finish.connect([=](uv::Error err) {
28 MutableArrayRef<uv::Buffer> bufs{m_bufs};
29 for (auto&& buf : bufs.slice(0, m_startUser)) buf.Deallocate();
30 callback(bufs.slice(m_startUser), err);
31 });
32 }
33
34 SmallVector<uv::Buffer, 4> m_bufs;
35 size_t m_startUser;
36};
37} // namespace
38
39class WebSocket::ClientHandshakeData {
40 public:
41 ClientHandshakeData() {
42 // key is a random nonce
43 static std::random_device rd;
44 static std::default_random_engine gen{rd()};
45 std::uniform_int_distribution<unsigned int> dist(0, 255);
46 char nonce[16]; // the nonce sent to the server
47 for (char& v : nonce) v = static_cast<char>(dist(gen));
48 raw_svector_ostream os(key);
49 Base64Encode(os, StringRef{nonce, 16});
50 }
51 ~ClientHandshakeData() {
52 if (auto t = timer.lock()) {
53 t->Stop();
54 t->Close();
55 }
56 }
57
58 SmallString<64> key; // the key sent to the server
59 SmallVector<std::string, 2> protocols; // valid protocols
60 HttpParser parser{HttpParser::kResponse}; // server response parser
61 bool hasUpgrade = false;
62 bool hasConnection = false;
63 bool hasAccept = false;
64 bool hasProtocol = false;
65
66 std::weak_ptr<uv::Timer> timer;
67};
68
69static StringRef AcceptHash(StringRef key, SmallVectorImpl<char>& buf) {
70 SHA1 hash;
71 hash.Update(key);
72 hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
73 SmallString<64> hashBuf;
74 return Base64Encode(hash.RawFinal(hashBuf), buf);
75}
76
77WebSocket::WebSocket(uv::Stream& stream, bool server, const private_init&)
78 : m_stream{stream}, m_server{server} {
79 // Connect closed and error signals to ourselves
80 m_stream.closed.connect([this]() { SetClosed(1006, "handle closed"); });
81 m_stream.error.connect([this](uv::Error err) {
82 Terminate(1006, "stream error: " + Twine(err.name()));
83 });
84
85 // Start reading
86 m_stream.StopRead(); // we may have been reading
87 m_stream.StartRead();
88 m_stream.data.connect(
89 [this](uv::Buffer& buf, size_t size) { HandleIncoming(buf, size); });
90 m_stream.end.connect(
91 [this]() { Terminate(1006, "remote end closed connection"); });
92}
93
94WebSocket::~WebSocket() {}
95
96std::shared_ptr<WebSocket> WebSocket::CreateClient(
97 uv::Stream& stream, const Twine& uri, const Twine& host,
98 ArrayRef<StringRef> protocols, const ClientOptions& options) {
99 auto ws = std::make_shared<WebSocket>(stream, false, private_init{});
100 stream.SetData(ws);
101 ws->StartClient(uri, host, protocols, options);
102 return ws;
103}
104
105std::shared_ptr<WebSocket> WebSocket::CreateServer(uv::Stream& stream,
106 StringRef key,
107 StringRef version,
108 StringRef protocol) {
109 auto ws = std::make_shared<WebSocket>(stream, true, private_init{});
110 stream.SetData(ws);
111 ws->StartServer(key, version, protocol);
112 return ws;
113}
114
115void WebSocket::Close(uint16_t code, const Twine& reason) {
116 SendClose(code, reason);
117 if (m_state != FAILED && m_state != CLOSED) m_state = CLOSING;
118}
119
120void WebSocket::Fail(uint16_t code, const Twine& reason) {
121 if (m_state == FAILED || m_state == CLOSED) return;
122 SendClose(code, reason);
123 SetClosed(code, reason, true);
124 Shutdown();
125}
126
127void WebSocket::Terminate(uint16_t code, const Twine& reason) {
128 if (m_state == FAILED || m_state == CLOSED) return;
129 SetClosed(code, reason);
130 Shutdown();
131}
132
133void WebSocket::StartClient(const Twine& uri, const Twine& host,
134 ArrayRef<StringRef> protocols,
135 const ClientOptions& options) {
136 // Create client handshake data
137 m_clientHandshake = std::make_unique<ClientHandshakeData>();
138
139 // Build client request
140 SmallVector<uv::Buffer, 4> bufs;
141 raw_uv_ostream os{bufs, 4096};
142
143 os << "GET " << uri << " HTTP/1.1\r\n";
144 os << "Host: " << host << "\r\n";
145 os << "Upgrade: websocket\r\n";
146 os << "Connection: Upgrade\r\n";
147 os << "Sec-WebSocket-Key: " << m_clientHandshake->key << "\r\n";
148 os << "Sec-WebSocket-Version: 13\r\n";
149
150 // protocols (if provided)
151 if (!protocols.empty()) {
152 os << "Sec-WebSocket-Protocol: ";
153 bool first = true;
154 for (auto protocol : protocols) {
155 if (!first)
156 os << ", ";
157 else
158 first = false;
159 os << protocol;
160 // also save for later checking against server response
161 m_clientHandshake->protocols.emplace_back(protocol);
162 }
163 os << "\r\n";
164 }
165
166 // other headers
167 for (auto&& header : options.extraHeaders)
168 os << header.first << ": " << header.second << "\r\n";
169
170 // finish headers
171 os << "\r\n";
172
173 // Send client request
174 m_stream.Write(bufs, [](auto bufs, uv::Error) {
175 for (auto& buf : bufs) buf.Deallocate();
176 });
177
178 // Set up client response handling
179 m_clientHandshake->parser.status.connect([this](StringRef status) {
180 unsigned int code = m_clientHandshake->parser.GetStatusCode();
181 if (code != 101) Terminate(code, status);
182 });
183 m_clientHandshake->parser.header.connect(
184 [this](StringRef name, StringRef value) {
185 value = value.trim();
186 if (name.equals_lower("upgrade")) {
187 if (!value.equals_lower("websocket"))
188 return Terminate(1002, "invalid upgrade response value");
189 m_clientHandshake->hasUpgrade = true;
190 } else if (name.equals_lower("connection")) {
191 if (!value.equals_lower("upgrade"))
192 return Terminate(1002, "invalid connection response value");
193 m_clientHandshake->hasConnection = true;
194 } else if (name.equals_lower("sec-websocket-accept")) {
195 // Check against expected response
196 SmallString<64> acceptBuf;
197 if (!value.equals(AcceptHash(m_clientHandshake->key, acceptBuf)))
198 return Terminate(1002, "invalid accept key");
199 m_clientHandshake->hasAccept = true;
200 } else if (name.equals_lower("sec-websocket-extensions")) {
201 // No extensions are supported
202 if (!value.empty()) return Terminate(1010, "unsupported extension");
203 } else if (name.equals_lower("sec-websocket-protocol")) {
204 // Make sure it was one of the provided protocols
205 bool match = false;
206 for (auto&& protocol : m_clientHandshake->protocols) {
207 if (value.equals_lower(protocol)) {
208 match = true;
209 break;
210 }
211 }
212 if (!match) return Terminate(1003, "unsupported protocol");
213 m_clientHandshake->hasProtocol = true;
214 m_protocol = value;
215 }
216 });
217 m_clientHandshake->parser.headersComplete.connect([this](bool) {
218 if (!m_clientHandshake->hasUpgrade || !m_clientHandshake->hasConnection ||
219 !m_clientHandshake->hasAccept ||
220 (!m_clientHandshake->hasProtocol &&
221 !m_clientHandshake->protocols.empty())) {
222 return Terminate(1002, "invalid response");
223 }
224 if (m_state == CONNECTING) {
225 m_state = OPEN;
226 open(m_protocol);
227 }
228 });
229
230 // Start handshake timer if a timeout was specified
James Kuszmaul4f3ad3c2019-12-01 16:35:21 -0800231 if (options.handshakeTimeout != (uv::Timer::Time::max)()) {
Brian Silverman41cdd3e2019-01-19 19:48:58 -0800232 auto timer = uv::Timer::Create(m_stream.GetLoopRef());
233 timer->timeout.connect(
234 [this]() { Terminate(1006, "connection timed out"); });
235 timer->Start(options.handshakeTimeout);
236 m_clientHandshake->timer = timer;
237 }
238}
239
240void WebSocket::StartServer(StringRef key, StringRef version,
241 StringRef protocol) {
242 m_protocol = protocol;
243
244 // Build server response
245 SmallVector<uv::Buffer, 4> bufs;
246 raw_uv_ostream os{bufs, 4096};
247
248 // Handle unsupported version
249 if (version != "13") {
250 os << "HTTP/1.1 426 Upgrade Required\r\n";
251 os << "Upgrade: WebSocket\r\n";
252 os << "Sec-WebSocket-Version: 13\r\n\r\n";
253 m_stream.Write(bufs, [this](auto bufs, uv::Error) {
254 for (auto& buf : bufs) buf.Deallocate();
255 // XXX: Should we support sending a new handshake on the same connection?
256 // XXX: "this->" is required by GCC 5.5 (bug)
257 this->Terminate(1003, "unsupported protocol version");
258 });
259 return;
260 }
261
262 os << "HTTP/1.1 101 Switching Protocols\r\n";
263 os << "Upgrade: websocket\r\n";
264 os << "Connection: Upgrade\r\n";
265
266 // accept hash
267 SmallString<64> acceptBuf;
268 os << "Sec-WebSocket-Accept: " << AcceptHash(key, acceptBuf) << "\r\n";
269
270 if (!protocol.empty()) os << "Sec-WebSocket-Protocol: " << protocol << "\r\n";
271
272 // end headers
273 os << "\r\n";
274
275 // Send server response
276 m_stream.Write(bufs, [this](auto bufs, uv::Error) {
277 for (auto& buf : bufs) buf.Deallocate();
278 if (m_state == CONNECTING) {
279 m_state = OPEN;
280 open(m_protocol);
281 }
282 });
283}
284
285void WebSocket::SendClose(uint16_t code, const Twine& reason) {
286 SmallVector<uv::Buffer, 4> bufs;
287 if (code != 1005) {
288 raw_uv_ostream os{bufs, 4096};
James Kuszmaul4f3ad3c2019-12-01 16:35:21 -0800289 const uint8_t codeMsb[] = {static_cast<uint8_t>((code >> 8) & 0xff),
290 static_cast<uint8_t>(code & 0xff)};
291 os << ArrayRef<uint8_t>(codeMsb);
Brian Silverman41cdd3e2019-01-19 19:48:58 -0800292 reason.print(os);
293 }
294 Send(kFlagFin | kOpClose, bufs, [](auto bufs, uv::Error) {
295 for (auto&& buf : bufs) buf.Deallocate();
296 });
297}
298
299void WebSocket::SetClosed(uint16_t code, const Twine& reason, bool failed) {
300 if (m_state == FAILED || m_state == CLOSED) return;
301 m_state = failed ? FAILED : CLOSED;
302 SmallString<64> reasonBuf;
303 closed(code, reason.toStringRef(reasonBuf));
304}
305
306void WebSocket::Shutdown() {
307 m_stream.Shutdown([this] { m_stream.Close(); });
308}
309
310void WebSocket::HandleIncoming(uv::Buffer& buf, size_t size) {
311 // ignore incoming data if we're failed or closed
312 if (m_state == FAILED || m_state == CLOSED) return;
313
314 StringRef data{buf.base, size};
315
316 // Handle connecting state (mainly on client)
317 if (m_state == CONNECTING) {
318 if (m_clientHandshake) {
319 data = m_clientHandshake->parser.Execute(data);
320 // check for parser failure
321 if (m_clientHandshake->parser.HasError())
322 return Terminate(1003, "invalid response");
323 if (m_state != OPEN) return; // not done with handshake yet
324
325 // we're done with the handshake, so release its memory
326 m_clientHandshake.reset();
327
328 // fall through to process additional data after handshake
329 } else {
330 return Terminate(1003, "got data on server before response");
331 }
332 }
333
334 // Message processing
335 while (!data.empty()) {
336 if (m_frameSize == UINT64_MAX) {
337 // Need at least two bytes to determine header length
338 if (m_header.size() < 2u) {
James Kuszmaul4f3ad3c2019-12-01 16:35:21 -0800339 size_t toCopy = (std::min)(2u - m_header.size(), data.size());
Brian Silverman41cdd3e2019-01-19 19:48:58 -0800340 m_header.append(data.bytes_begin(), data.bytes_begin() + toCopy);
341 data = data.drop_front(toCopy);
342 if (m_header.size() < 2u) return; // need more data
343
344 // Validate RSV bits are zero
345 if ((m_header[0] & 0x70) != 0) return Fail(1002, "nonzero RSV");
346 }
347
348 // Once we have first two bytes, we can calculate the header size
349 if (m_headerSize == 0) {
350 m_headerSize = 2;
351 uint8_t len = m_header[1] & kLenMask;
352 if (len == 126)
353 m_headerSize += 2;
354 else if (len == 127)
355 m_headerSize += 8;
356 bool masking = (m_header[1] & kFlagMasking) != 0;
357 if (masking) m_headerSize += 4; // masking key
358 // On server side, incoming messages MUST be masked
359 // On client side, incoming messages MUST NOT be masked
360 if (m_server && !masking) return Fail(1002, "client data not masked");
361 if (!m_server && masking) return Fail(1002, "server data masked");
362 }
363
364 // Need to complete header to calculate message size
365 if (m_header.size() < m_headerSize) {
James Kuszmaul4f3ad3c2019-12-01 16:35:21 -0800366 size_t toCopy = (std::min)(m_headerSize - m_header.size(), data.size());
Brian Silverman41cdd3e2019-01-19 19:48:58 -0800367 m_header.append(data.bytes_begin(), data.bytes_begin() + toCopy);
368 data = data.drop_front(toCopy);
369 if (m_header.size() < m_headerSize) return; // need more data
370 }
371
372 if (m_header.size() >= m_headerSize) {
373 // get payload length
374 uint8_t len = m_header[1] & kLenMask;
375 if (len == 126)
376 m_frameSize = (static_cast<uint16_t>(m_header[2]) << 8) |
377 static_cast<uint16_t>(m_header[3]);
378 else if (len == 127)
379 m_frameSize = (static_cast<uint64_t>(m_header[2]) << 56) |
380 (static_cast<uint64_t>(m_header[3]) << 48) |
381 (static_cast<uint64_t>(m_header[4]) << 40) |
382 (static_cast<uint64_t>(m_header[5]) << 32) |
383 (static_cast<uint64_t>(m_header[6]) << 24) |
384 (static_cast<uint64_t>(m_header[7]) << 16) |
385 (static_cast<uint64_t>(m_header[8]) << 8) |
386 static_cast<uint64_t>(m_header[9]);
387 else
388 m_frameSize = len;
389
390 // limit maximum size
391 if ((m_payload.size() + m_frameSize) > m_maxMessageSize)
392 return Fail(1009, "message too large");
393 }
394 }
395
396 if (m_frameSize != UINT64_MAX) {
397 size_t need = m_frameStart + m_frameSize - m_payload.size();
James Kuszmaul4f3ad3c2019-12-01 16:35:21 -0800398 size_t toCopy = (std::min)(need, data.size());
Brian Silverman41cdd3e2019-01-19 19:48:58 -0800399 m_payload.append(data.bytes_begin(), data.bytes_begin() + toCopy);
400 data = data.drop_front(toCopy);
401 need -= toCopy;
402 if (need == 0) {
403 // We have a complete frame
404 // If the message had masking, unmask it
405 if ((m_header[1] & kFlagMasking) != 0) {
406 uint8_t key[4] = {
407 m_header[m_headerSize - 4], m_header[m_headerSize - 3],
408 m_header[m_headerSize - 2], m_header[m_headerSize - 1]};
409 int n = 0;
410 for (uint8_t& ch :
411 MutableArrayRef<uint8_t>{m_payload}.slice(m_frameStart)) {
412 ch ^= key[n++];
413 if (n >= 4) n = 0;
414 }
415 }
416
417 // Handle message
418 bool fin = (m_header[0] & kFlagFin) != 0;
419 uint8_t opcode = m_header[0] & kOpMask;
420 switch (opcode) {
421 case kOpCont:
422 switch (m_fragmentOpcode) {
423 case kOpText:
424 if (!m_combineFragments || fin)
425 text(StringRef{reinterpret_cast<char*>(m_payload.data()),
426 m_payload.size()},
427 fin);
428 break;
429 case kOpBinary:
430 if (!m_combineFragments || fin) binary(m_payload, fin);
431 break;
432 default:
433 // no preceding message?
434 return Fail(1002, "invalid continuation message");
435 }
436 if (fin) m_fragmentOpcode = 0;
437 break;
438 case kOpText:
439 if (m_fragmentOpcode != 0) return Fail(1002, "incomplete fragment");
440 if (!m_combineFragments || fin)
441 text(StringRef{reinterpret_cast<char*>(m_payload.data()),
442 m_payload.size()},
443 fin);
444 if (!fin) m_fragmentOpcode = opcode;
445 break;
446 case kOpBinary:
447 if (m_fragmentOpcode != 0) return Fail(1002, "incomplete fragment");
448 if (!m_combineFragments || fin) binary(m_payload, fin);
449 if (!fin) m_fragmentOpcode = opcode;
450 break;
451 case kOpClose: {
452 uint16_t code;
453 StringRef reason;
454 if (!fin) {
455 code = 1002;
456 reason = "cannot fragment control frames";
457 } else if (m_payload.size() < 2) {
458 code = 1005;
459 } else {
460 code = (static_cast<uint16_t>(m_payload[0]) << 8) |
461 static_cast<uint16_t>(m_payload[1]);
462 reason = StringRef{reinterpret_cast<char*>(m_payload.data()),
463 m_payload.size()}
464 .drop_front(2);
465 }
466 // Echo the close if we didn't previously send it
467 if (m_state != CLOSING) SendClose(code, reason);
468 SetClosed(code, reason);
469 // If we're the server, shutdown the connection.
470 if (m_server) Shutdown();
471 break;
472 }
473 case kOpPing:
474 if (!fin) return Fail(1002, "cannot fragment control frames");
475 ping(m_payload);
476 break;
477 case kOpPong:
478 if (!fin) return Fail(1002, "cannot fragment control frames");
479 pong(m_payload);
480 break;
481 default:
482 return Fail(1002, "invalid message opcode");
483 }
484
485 // Prepare for next message
486 m_header.clear();
487 m_headerSize = 0;
488 if (!m_combineFragments || fin) m_payload.clear();
489 m_frameStart = m_payload.size();
490 m_frameSize = UINT64_MAX;
491 }
492 }
493 }
494}
495
496void WebSocket::Send(
497 uint8_t opcode, ArrayRef<uv::Buffer> data,
498 std::function<void(MutableArrayRef<uv::Buffer>, uv::Error)> callback) {
499 // If we're not open, emit an error and don't send the data
500 if (m_state != OPEN) {
501 int err;
502 if (m_state == CONNECTING)
503 err = UV_EAGAIN;
504 else
505 err = UV_ESHUTDOWN;
506 SmallVector<uv::Buffer, 4> bufs{data.begin(), data.end()};
507 callback(bufs, uv::Error{err});
508 return;
509 }
510
511 auto req = std::make_shared<WebSocketWriteReq>(callback);
512 raw_uv_ostream os{req->m_bufs, 4096};
513
514 // opcode (includes FIN bit)
515 os << static_cast<unsigned char>(opcode);
516
517 // payload length
518 uint64_t size = 0;
519 for (auto&& buf : data) size += buf.len;
520 if (size < 126) {
521 os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | size);
522 } else if (size <= 0xffff) {
523 os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | 126);
James Kuszmaul4f3ad3c2019-12-01 16:35:21 -0800524 const uint8_t sizeMsb[] = {static_cast<uint8_t>((size >> 8) & 0xff),
525 static_cast<uint8_t>(size & 0xff)};
526 os << ArrayRef<uint8_t>(sizeMsb);
Brian Silverman41cdd3e2019-01-19 19:48:58 -0800527 } else {
528 os << static_cast<unsigned char>((m_server ? 0x00 : kFlagMasking) | 127);
James Kuszmaul4f3ad3c2019-12-01 16:35:21 -0800529 const uint8_t sizeMsb[] = {static_cast<uint8_t>((size >> 56) & 0xff),
530 static_cast<uint8_t>((size >> 48) & 0xff),
531 static_cast<uint8_t>((size >> 40) & 0xff),
532 static_cast<uint8_t>((size >> 32) & 0xff),
533 static_cast<uint8_t>((size >> 24) & 0xff),
534 static_cast<uint8_t>((size >> 16) & 0xff),
535 static_cast<uint8_t>((size >> 8) & 0xff),
536 static_cast<uint8_t>(size & 0xff)};
537 os << ArrayRef<uint8_t>(sizeMsb);
Brian Silverman41cdd3e2019-01-19 19:48:58 -0800538 }
539
540 // clients need to mask the input data
541 if (!m_server) {
542 // generate masking key
543 static std::random_device rd;
544 static std::default_random_engine gen{rd()};
545 std::uniform_int_distribution<unsigned int> dist(0, 255);
546 uint8_t key[4];
547 for (uint8_t& v : key) v = dist(gen);
548 os << ArrayRef<uint8_t>{key, 4};
549 // copy and mask data
550 int n = 0;
551 for (auto&& buf : data) {
552 for (auto&& ch : buf.data()) {
553 os << static_cast<unsigned char>(static_cast<uint8_t>(ch) ^ key[n++]);
554 if (n >= 4) n = 0;
555 }
556 }
557 req->m_startUser = req->m_bufs.size();
558 req->m_bufs.append(data.begin(), data.end());
559 // don't send the user bufs as we copied their data
560 m_stream.Write(ArrayRef<uv::Buffer>{req->m_bufs}.slice(0, req->m_startUser),
561 req);
562 } else {
563 // servers can just send the buffers directly without masking
564 req->m_startUser = req->m_bufs.size();
565 req->m_bufs.append(data.begin(), data.end());
566 m_stream.Write(req->m_bufs, req);
567 }
568}