James Kuszmaul | cf32412 | 2023-01-14 14:07:17 -0800 | [diff] [blame^] | 1 | // Copyright (c) FIRST and other WPILib contributors. |
| 2 | // Open Source Software; you can modify and/or share it under the terms of |
| 3 | // the WPILib BSD license file in the root directory of this project. |
| 4 | |
| 5 | #include "wpinet/WebSocketServer.h" |
| 6 | |
| 7 | #include <utility> |
| 8 | |
| 9 | #include <wpi/StringExtras.h> |
| 10 | #include <wpi/fmt/raw_ostream.h> |
| 11 | |
| 12 | #include "wpinet/raw_uv_ostream.h" |
| 13 | #include "wpinet/uv/Buffer.h" |
| 14 | #include "wpinet/uv/Stream.h" |
| 15 | |
| 16 | using namespace wpi; |
| 17 | |
| 18 | WebSocketServerHelper::WebSocketServerHelper(HttpParser& req) { |
| 19 | req.header.connect([this](std::string_view name, std::string_view value) { |
| 20 | if (equals_lower(name, "host")) { |
| 21 | m_gotHost = true; |
| 22 | } else if (equals_lower(name, "upgrade")) { |
| 23 | if (equals_lower(value, "websocket")) { |
| 24 | m_websocket = true; |
| 25 | } |
| 26 | } else if (equals_lower(name, "sec-websocket-key")) { |
| 27 | m_key = value; |
| 28 | } else if (equals_lower(name, "sec-websocket-version")) { |
| 29 | m_version = value; |
| 30 | } else if (equals_lower(name, "sec-websocket-protocol")) { |
| 31 | // Protocols are comma delimited, repeated headers add to list |
| 32 | SmallVector<std::string_view, 2> protocols; |
| 33 | split(value, protocols, ",", -1, false); |
| 34 | for (auto protocol : protocols) { |
| 35 | protocol = trim(protocol); |
| 36 | if (!protocol.empty()) { |
| 37 | m_protocols.emplace_back(protocol); |
| 38 | } |
| 39 | } |
| 40 | } |
| 41 | }); |
| 42 | req.headersComplete.connect([&req, this](bool) { |
| 43 | if (req.IsUpgrade() && IsUpgrade()) { |
| 44 | upgrade(); |
| 45 | } |
| 46 | }); |
| 47 | } |
| 48 | |
| 49 | std::pair<bool, std::string_view> WebSocketServerHelper::MatchProtocol( |
| 50 | std::span<const std::string_view> protocols) { |
| 51 | if (protocols.empty() && m_protocols.empty()) { |
| 52 | return {true, {}}; |
| 53 | } |
| 54 | for (auto protocol : protocols) { |
| 55 | for (auto&& clientProto : m_protocols) { |
| 56 | if (protocol == clientProto) { |
| 57 | return {true, protocol}; |
| 58 | } |
| 59 | } |
| 60 | } |
| 61 | return {false, {}}; |
| 62 | } |
| 63 | |
| 64 | WebSocketServer::WebSocketServer(uv::Stream& stream, |
| 65 | std::span<const std::string_view> protocols, |
| 66 | ServerOptions options, const private_init&) |
| 67 | : m_stream{stream}, |
| 68 | m_helper{m_req}, |
| 69 | m_protocols{protocols.begin(), protocols.end()}, |
| 70 | m_options{std::move(options)} { |
| 71 | // Header handling |
| 72 | m_req.header.connect([this](std::string_view name, std::string_view value) { |
| 73 | if (equals_lower(name, "host")) { |
| 74 | if (m_options.checkHost) { |
| 75 | if (!m_options.checkHost(value)) { |
| 76 | Abort(401, "Unrecognized Host"); |
| 77 | } |
| 78 | } |
| 79 | } |
| 80 | }); |
| 81 | m_req.url.connect([this](std::string_view name) { |
| 82 | if (m_options.checkUrl) { |
| 83 | if (!m_options.checkUrl(name)) { |
| 84 | Abort(404, "Not Found"); |
| 85 | } |
| 86 | } |
| 87 | }); |
| 88 | m_req.headersComplete.connect([this](bool) { |
| 89 | // We only accept websocket connections |
| 90 | if (!m_helper.IsUpgrade() || !m_req.IsUpgrade()) { |
| 91 | Abort(426, "Upgrade Required"); |
| 92 | } |
| 93 | }); |
| 94 | |
| 95 | // Handle upgrade event |
| 96 | m_helper.upgrade.connect([this] { |
| 97 | if (m_aborted) { |
| 98 | return; |
| 99 | } |
| 100 | |
| 101 | // Negotiate sub-protocol |
| 102 | SmallVector<std::string_view, 2> protocols{m_protocols.begin(), |
| 103 | m_protocols.end()}; |
| 104 | std::string_view protocol = m_helper.MatchProtocol(protocols).second; |
| 105 | |
| 106 | // Disconnect our header reader |
| 107 | m_dataConn.disconnect(); |
| 108 | |
| 109 | // Accepting the stream may destroy this (as it replaces the stream user |
| 110 | // data), so grab a shared pointer first. |
| 111 | auto self = shared_from_this(); |
| 112 | |
| 113 | // Accept the upgrade |
| 114 | auto ws = m_helper.Accept(m_stream, protocol); |
| 115 | |
| 116 | // Connect the websocket open event to our connected event. |
| 117 | ws->open.connect_extended( |
| 118 | [self, s = ws.get()](auto conn, std::string_view) { |
| 119 | self->connected(self->m_req.GetUrl(), *s); |
| 120 | conn.disconnect(); // one-shot |
| 121 | }); |
| 122 | }); |
| 123 | |
| 124 | // Set up stream |
| 125 | stream.StartRead(); |
| 126 | m_dataConn = |
| 127 | stream.data.connect_connection([this](uv::Buffer& buf, size_t size) { |
| 128 | if (m_aborted) { |
| 129 | return; |
| 130 | } |
| 131 | m_req.Execute(std::string_view{buf.base, size}); |
| 132 | if (m_req.HasError()) { |
| 133 | Abort(400, "Bad Request"); |
| 134 | } |
| 135 | }); |
| 136 | m_errorConn = |
| 137 | stream.error.connect_connection([this](uv::Error) { m_stream.Close(); }); |
| 138 | m_endConn = stream.end.connect_connection([this] { m_stream.Close(); }); |
| 139 | } |
| 140 | |
| 141 | std::shared_ptr<WebSocketServer> WebSocketServer::Create( |
| 142 | uv::Stream& stream, std::span<const std::string_view> protocols, |
| 143 | const ServerOptions& options) { |
| 144 | auto server = std::make_shared<WebSocketServer>(stream, protocols, options, |
| 145 | private_init{}); |
| 146 | stream.SetData(server); |
| 147 | return server; |
| 148 | } |
| 149 | |
| 150 | void WebSocketServer::Abort(uint16_t code, std::string_view reason) { |
| 151 | if (m_aborted) { |
| 152 | return; |
| 153 | } |
| 154 | m_aborted = true; |
| 155 | |
| 156 | // Build response |
| 157 | SmallVector<uv::Buffer, 4> bufs; |
| 158 | raw_uv_ostream os{bufs, 1024}; |
| 159 | |
| 160 | // Handle unsupported version |
| 161 | fmt::print(os, "HTTP/1.1 {} {}\r\n", code, reason); |
| 162 | if (code == 426) { |
| 163 | os << "Upgrade: WebSocket\r\n"; |
| 164 | } |
| 165 | os << "\r\n"; |
| 166 | m_stream.Write(bufs, [this](auto bufs, uv::Error) { |
| 167 | for (auto& buf : bufs) { |
| 168 | buf.Deallocate(); |
| 169 | } |
| 170 | m_stream.Shutdown([this] { m_stream.Close(); }); |
| 171 | }); |
| 172 | } |