blob: 056cb62099325a324f0833f77c264280424d507a [file] [log] [blame]
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2018-2019 FIRST. All Rights Reserved.                        */
/* Open Source Software - may be modified and shared by FRC teams. The code   */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project.                                                               */
/*----------------------------------------------------------------------------*/
7
8#include "wpi/WebSocketServer.h"
9
10#include "wpi/raw_uv_ostream.h"
11#include "wpi/uv/Buffer.h"
12#include "wpi/uv/Stream.h"
13
14using namespace wpi;
15
16WebSocketServerHelper::WebSocketServerHelper(HttpParser& req) {
17 req.header.connect([this](StringRef name, StringRef value) {
18 if (name.equals_lower("host")) {
19 m_gotHost = true;
20 } else if (name.equals_lower("upgrade")) {
21 if (value.equals_lower("websocket")) m_websocket = true;
22 } else if (name.equals_lower("sec-websocket-key")) {
23 m_key = value;
24 } else if (name.equals_lower("sec-websocket-version")) {
25 m_version = value;
26 } else if (name.equals_lower("sec-websocket-protocol")) {
27 // Protocols are comma delimited, repeated headers add to list
28 SmallVector<StringRef, 2> protocols;
29 value.split(protocols, ",", -1, false);
30 for (auto protocol : protocols) {
31 protocol = protocol.trim();
32 if (!protocol.empty()) m_protocols.emplace_back(protocol);
33 }
34 }
35 });
36 req.headersComplete.connect([&req, this](bool) {
37 if (req.IsUpgrade() && IsUpgrade()) upgrade();
38 });
39}
40
41std::pair<bool, StringRef> WebSocketServerHelper::MatchProtocol(
42 ArrayRef<StringRef> protocols) {
43 if (protocols.empty() && m_protocols.empty())
44 return std::make_pair(true, StringRef{});
45 for (auto protocol : protocols) {
46 for (auto&& clientProto : m_protocols) {
47 if (protocol == clientProto) return std::make_pair(true, protocol);
48 }
49 }
50 return std::make_pair(false, StringRef{});
51}
52
53WebSocketServer::WebSocketServer(uv::Stream& stream,
54 ArrayRef<StringRef> protocols,
55 const ServerOptions& options,
56 const private_init&)
57 : m_stream{stream},
58 m_helper{m_req},
59 m_protocols{protocols.begin(), protocols.end()},
60 m_options{options} {
61 // Header handling
62 m_req.header.connect([this](StringRef name, StringRef value) {
63 if (name.equals_lower("host")) {
64 if (m_options.checkHost) {
65 if (!m_options.checkHost(value)) Abort(401, "Unrecognized Host");
66 }
67 }
68 });
69 m_req.url.connect([this](StringRef name) {
70 if (m_options.checkUrl) {
71 if (!m_options.checkUrl(name)) Abort(404, "Not Found");
72 }
73 });
74 m_req.headersComplete.connect([this](bool) {
75 // We only accept websocket connections
76 if (!m_helper.IsUpgrade() || !m_req.IsUpgrade())
77 Abort(426, "Upgrade Required");
78 });
79
80 // Handle upgrade event
81 m_helper.upgrade.connect([this] {
82 if (m_aborted) return;
83
84 // Negotiate sub-protocol
85 SmallVector<StringRef, 2> protocols{m_protocols.begin(), m_protocols.end()};
86 StringRef protocol = m_helper.MatchProtocol(protocols).second;
87
88 // Disconnect our header reader
89 m_dataConn.disconnect();
90
91 // Accepting the stream may destroy this (as it replaces the stream user
92 // data), so grab a shared pointer first.
93 auto self = shared_from_this();
94
95 // Accept the upgrade
96 auto ws = m_helper.Accept(m_stream, protocol);
97
98 // Connect the websocket open event to our connected event.
James Kuszmaul4f3ad3c2019-12-01 16:35:21 -080099 ws->open.connect_extended([self, s = ws.get()](auto conn, StringRef) {
Brian Silverman41cdd3e2019-01-19 19:48:58 -0800100 self->connected(self->m_req.GetUrl(), *s);
101 conn.disconnect(); // one-shot
102 });
103 });
104
105 // Set up stream
106 stream.StartRead();
107 m_dataConn =
108 stream.data.connect_connection([this](uv::Buffer& buf, size_t size) {
109 if (m_aborted) return;
110 m_req.Execute(StringRef{buf.base, size});
111 if (m_req.HasError()) Abort(400, "Bad Request");
112 });
113 m_errorConn =
114 stream.error.connect_connection([this](uv::Error) { m_stream.Close(); });
115 m_endConn = stream.end.connect_connection([this] { m_stream.Close(); });
116}
117
118std::shared_ptr<WebSocketServer> WebSocketServer::Create(
119 uv::Stream& stream, ArrayRef<StringRef> protocols,
120 const ServerOptions& options) {
121 auto server = std::make_shared<WebSocketServer>(stream, protocols, options,
122 private_init{});
123 stream.SetData(server);
124 return server;
125}
126
127void WebSocketServer::Abort(uint16_t code, StringRef reason) {
128 if (m_aborted) return;
129 m_aborted = true;
130
131 // Build response
132 SmallVector<uv::Buffer, 4> bufs;
133 raw_uv_ostream os{bufs, 1024};
134
135 // Handle unsupported version
136 os << "HTTP/1.1 " << code << ' ' << reason << "\r\n";
137 if (code == 426) os << "Upgrade: WebSocket\r\n";
138 os << "\r\n";
139 m_stream.Write(bufs, [this](auto bufs, uv::Error) {
140 for (auto& buf : bufs) buf.Deallocate();
141 m_stream.Shutdown([this] { m_stream.Close(); });
142 });
143}