Brian Silverman | 41cdd3e | 2019-01-19 19:48:58 -0800 | [diff] [blame^] | 1 | /*----------------------------------------------------------------------------*/ |
| 2 | /* Copyright (c) 2018 FIRST. All Rights Reserved. */ |
| 3 | /* Open Source Software - may be modified and shared by FRC teams. The code */ |
| 4 | /* must be accompanied by the FIRST BSD license file in the root directory of */ |
| 5 | /* the project. */ |
| 6 | /*----------------------------------------------------------------------------*/ |
| 7 | |
| 8 | #include "wpi/WebSocketServer.h" |
| 9 | |
| 10 | #include "wpi/raw_uv_ostream.h" |
| 11 | #include "wpi/uv/Buffer.h" |
| 12 | #include "wpi/uv/Stream.h" |
| 13 | |
| 14 | using namespace wpi; |
| 15 | |
| 16 | WebSocketServerHelper::WebSocketServerHelper(HttpParser& req) { |
| 17 | req.header.connect([this](StringRef name, StringRef value) { |
| 18 | if (name.equals_lower("host")) { |
| 19 | m_gotHost = true; |
| 20 | } else if (name.equals_lower("upgrade")) { |
| 21 | if (value.equals_lower("websocket")) m_websocket = true; |
| 22 | } else if (name.equals_lower("sec-websocket-key")) { |
| 23 | m_key = value; |
| 24 | } else if (name.equals_lower("sec-websocket-version")) { |
| 25 | m_version = value; |
| 26 | } else if (name.equals_lower("sec-websocket-protocol")) { |
| 27 | // Protocols are comma delimited, repeated headers add to list |
| 28 | SmallVector<StringRef, 2> protocols; |
| 29 | value.split(protocols, ",", -1, false); |
| 30 | for (auto protocol : protocols) { |
| 31 | protocol = protocol.trim(); |
| 32 | if (!protocol.empty()) m_protocols.emplace_back(protocol); |
| 33 | } |
| 34 | } |
| 35 | }); |
| 36 | req.headersComplete.connect([&req, this](bool) { |
| 37 | if (req.IsUpgrade() && IsUpgrade()) upgrade(); |
| 38 | }); |
| 39 | } |
| 40 | |
| 41 | std::pair<bool, StringRef> WebSocketServerHelper::MatchProtocol( |
| 42 | ArrayRef<StringRef> protocols) { |
| 43 | if (protocols.empty() && m_protocols.empty()) |
| 44 | return std::make_pair(true, StringRef{}); |
| 45 | for (auto protocol : protocols) { |
| 46 | for (auto&& clientProto : m_protocols) { |
| 47 | if (protocol == clientProto) return std::make_pair(true, protocol); |
| 48 | } |
| 49 | } |
| 50 | return std::make_pair(false, StringRef{}); |
| 51 | } |
| 52 | |
| 53 | WebSocketServer::WebSocketServer(uv::Stream& stream, |
| 54 | ArrayRef<StringRef> protocols, |
| 55 | const ServerOptions& options, |
| 56 | const private_init&) |
| 57 | : m_stream{stream}, |
| 58 | m_helper{m_req}, |
| 59 | m_protocols{protocols.begin(), protocols.end()}, |
| 60 | m_options{options} { |
| 61 | // Header handling |
| 62 | m_req.header.connect([this](StringRef name, StringRef value) { |
| 63 | if (name.equals_lower("host")) { |
| 64 | if (m_options.checkHost) { |
| 65 | if (!m_options.checkHost(value)) Abort(401, "Unrecognized Host"); |
| 66 | } |
| 67 | } |
| 68 | }); |
| 69 | m_req.url.connect([this](StringRef name) { |
| 70 | if (m_options.checkUrl) { |
| 71 | if (!m_options.checkUrl(name)) Abort(404, "Not Found"); |
| 72 | } |
| 73 | }); |
| 74 | m_req.headersComplete.connect([this](bool) { |
| 75 | // We only accept websocket connections |
| 76 | if (!m_helper.IsUpgrade() || !m_req.IsUpgrade()) |
| 77 | Abort(426, "Upgrade Required"); |
| 78 | }); |
| 79 | |
| 80 | // Handle upgrade event |
| 81 | m_helper.upgrade.connect([this] { |
| 82 | if (m_aborted) return; |
| 83 | |
| 84 | // Negotiate sub-protocol |
| 85 | SmallVector<StringRef, 2> protocols{m_protocols.begin(), m_protocols.end()}; |
| 86 | StringRef protocol = m_helper.MatchProtocol(protocols).second; |
| 87 | |
| 88 | // Disconnect our header reader |
| 89 | m_dataConn.disconnect(); |
| 90 | |
| 91 | // Accepting the stream may destroy this (as it replaces the stream user |
| 92 | // data), so grab a shared pointer first. |
| 93 | auto self = shared_from_this(); |
| 94 | |
| 95 | // Accept the upgrade |
| 96 | auto ws = m_helper.Accept(m_stream, protocol); |
| 97 | |
| 98 | // Connect the websocket open event to our connected event. |
| 99 | ws->open.connect_extended([ self, s = ws.get() ](auto conn, StringRef) { |
| 100 | self->connected(self->m_req.GetUrl(), *s); |
| 101 | conn.disconnect(); // one-shot |
| 102 | }); |
| 103 | }); |
| 104 | |
| 105 | // Set up stream |
| 106 | stream.StartRead(); |
| 107 | m_dataConn = |
| 108 | stream.data.connect_connection([this](uv::Buffer& buf, size_t size) { |
| 109 | if (m_aborted) return; |
| 110 | m_req.Execute(StringRef{buf.base, size}); |
| 111 | if (m_req.HasError()) Abort(400, "Bad Request"); |
| 112 | }); |
| 113 | m_errorConn = |
| 114 | stream.error.connect_connection([this](uv::Error) { m_stream.Close(); }); |
| 115 | m_endConn = stream.end.connect_connection([this] { m_stream.Close(); }); |
| 116 | } |
| 117 | |
| 118 | std::shared_ptr<WebSocketServer> WebSocketServer::Create( |
| 119 | uv::Stream& stream, ArrayRef<StringRef> protocols, |
| 120 | const ServerOptions& options) { |
| 121 | auto server = std::make_shared<WebSocketServer>(stream, protocols, options, |
| 122 | private_init{}); |
| 123 | stream.SetData(server); |
| 124 | return server; |
| 125 | } |
| 126 | |
| 127 | void WebSocketServer::Abort(uint16_t code, StringRef reason) { |
| 128 | if (m_aborted) return; |
| 129 | m_aborted = true; |
| 130 | |
| 131 | // Build response |
| 132 | SmallVector<uv::Buffer, 4> bufs; |
| 133 | raw_uv_ostream os{bufs, 1024}; |
| 134 | |
| 135 | // Handle unsupported version |
| 136 | os << "HTTP/1.1 " << code << ' ' << reason << "\r\n"; |
| 137 | if (code == 426) os << "Upgrade: WebSocket\r\n"; |
| 138 | os << "\r\n"; |
| 139 | m_stream.Write(bufs, [this](auto bufs, uv::Error) { |
| 140 | for (auto& buf : bufs) buf.Deallocate(); |
| 141 | m_stream.Shutdown([this] { m_stream.Close(); }); |
| 142 | }); |
| 143 | } |