blob: 09424bb311296f3346dc0676db2dd8ff13f4dc0d [file] [log] [blame]
James Kuszmaulcf324122023-01-14 14:07:17 -08001// Copyright (c) FIRST and other WPILib contributors.
2// Open Source Software; you can modify and/or share it under the terms of
3// the WPILib BSD license file in the root directory of this project.
4
5#include "wpinet/WebSocketServer.h"
6
7#include <utility>
8
9#include <wpi/StringExtras.h>
10#include <wpi/fmt/raw_ostream.h>
11
12#include "wpinet/raw_uv_ostream.h"
13#include "wpinet/uv/Buffer.h"
14#include "wpinet/uv/Stream.h"
15
16using namespace wpi;
17
18WebSocketServerHelper::WebSocketServerHelper(HttpParser& req) {
19 req.header.connect([this](std::string_view name, std::string_view value) {
20 if (equals_lower(name, "host")) {
21 m_gotHost = true;
22 } else if (equals_lower(name, "upgrade")) {
23 if (equals_lower(value, "websocket")) {
24 m_websocket = true;
25 }
26 } else if (equals_lower(name, "sec-websocket-key")) {
27 m_key = value;
28 } else if (equals_lower(name, "sec-websocket-version")) {
29 m_version = value;
30 } else if (equals_lower(name, "sec-websocket-protocol")) {
31 // Protocols are comma delimited, repeated headers add to list
32 SmallVector<std::string_view, 2> protocols;
33 split(value, protocols, ",", -1, false);
34 for (auto protocol : protocols) {
35 protocol = trim(protocol);
36 if (!protocol.empty()) {
37 m_protocols.emplace_back(protocol);
38 }
39 }
40 }
41 });
42 req.headersComplete.connect([&req, this](bool) {
43 if (req.IsUpgrade() && IsUpgrade()) {
44 upgrade();
45 }
46 });
47}
48
49std::pair<bool, std::string_view> WebSocketServerHelper::MatchProtocol(
50 std::span<const std::string_view> protocols) {
51 if (protocols.empty() && m_protocols.empty()) {
52 return {true, {}};
53 }
54 for (auto protocol : protocols) {
55 for (auto&& clientProto : m_protocols) {
56 if (protocol == clientProto) {
57 return {true, protocol};
58 }
59 }
60 }
61 return {false, {}};
62}
63
64WebSocketServer::WebSocketServer(uv::Stream& stream,
65 std::span<const std::string_view> protocols,
66 ServerOptions options, const private_init&)
67 : m_stream{stream},
68 m_helper{m_req},
69 m_protocols{protocols.begin(), protocols.end()},
70 m_options{std::move(options)} {
71 // Header handling
72 m_req.header.connect([this](std::string_view name, std::string_view value) {
73 if (equals_lower(name, "host")) {
74 if (m_options.checkHost) {
75 if (!m_options.checkHost(value)) {
76 Abort(401, "Unrecognized Host");
77 }
78 }
79 }
80 });
81 m_req.url.connect([this](std::string_view name) {
82 if (m_options.checkUrl) {
83 if (!m_options.checkUrl(name)) {
84 Abort(404, "Not Found");
85 }
86 }
87 });
88 m_req.headersComplete.connect([this](bool) {
89 // We only accept websocket connections
90 if (!m_helper.IsUpgrade() || !m_req.IsUpgrade()) {
91 Abort(426, "Upgrade Required");
92 }
93 });
94
95 // Handle upgrade event
96 m_helper.upgrade.connect([this] {
97 if (m_aborted) {
98 return;
99 }
100
101 // Negotiate sub-protocol
102 SmallVector<std::string_view, 2> protocols{m_protocols.begin(),
103 m_protocols.end()};
104 std::string_view protocol = m_helper.MatchProtocol(protocols).second;
105
106 // Disconnect our header reader
107 m_dataConn.disconnect();
108
109 // Accepting the stream may destroy this (as it replaces the stream user
110 // data), so grab a shared pointer first.
111 auto self = shared_from_this();
112
113 // Accept the upgrade
114 auto ws = m_helper.Accept(m_stream, protocol);
115
116 // Connect the websocket open event to our connected event.
117 ws->open.connect_extended(
118 [self, s = ws.get()](auto conn, std::string_view) {
119 self->connected(self->m_req.GetUrl(), *s);
120 conn.disconnect(); // one-shot
121 });
122 });
123
124 // Set up stream
125 stream.StartRead();
126 m_dataConn =
127 stream.data.connect_connection([this](uv::Buffer& buf, size_t size) {
128 if (m_aborted) {
129 return;
130 }
131 m_req.Execute(std::string_view{buf.base, size});
132 if (m_req.HasError()) {
133 Abort(400, "Bad Request");
134 }
135 });
136 m_errorConn =
137 stream.error.connect_connection([this](uv::Error) { m_stream.Close(); });
138 m_endConn = stream.end.connect_connection([this] { m_stream.Close(); });
139}
140
141std::shared_ptr<WebSocketServer> WebSocketServer::Create(
142 uv::Stream& stream, std::span<const std::string_view> protocols,
143 const ServerOptions& options) {
144 auto server = std::make_shared<WebSocketServer>(stream, protocols, options,
145 private_init{});
146 stream.SetData(server);
147 return server;
148}
149
150void WebSocketServer::Abort(uint16_t code, std::string_view reason) {
151 if (m_aborted) {
152 return;
153 }
154 m_aborted = true;
155
156 // Build response
157 SmallVector<uv::Buffer, 4> bufs;
158 raw_uv_ostream os{bufs, 1024};
159
160 // Handle unsupported version
161 fmt::print(os, "HTTP/1.1 {} {}\r\n", code, reason);
162 if (code == 426) {
163 os << "Upgrade: WebSocket\r\n";
164 }
165 os << "\r\n";
166 m_stream.Write(bufs, [this](auto bufs, uv::Error) {
167 for (auto& buf : bufs) {
168 buf.Deallocate();
169 }
170 m_stream.Shutdown([this] { m_stream.Close(); });
171 });
172}