blob: 1562f3bf04f606a909e2c06e2227c4f74d6fcd52 [file] [log] [blame]
Austin Schuh812d0d12021-11-04 20:16:48 -07001// Copyright (c) FIRST and other WPILib contributors.
2// Open Source Software; you can modify and/or share it under the terms of
3// the WPILib BSD license file in the root directory of this project.
Brian Silverman8fce7482020-01-05 13:18:21 -08004
5#include "wpi/WebSocketServer.h"
6
Austin Schuh812d0d12021-11-04 20:16:48 -07007#include <utility>
8
9#include "wpi/StringExtras.h"
10#include "wpi/fmt/raw_ostream.h"
Brian Silverman8fce7482020-01-05 13:18:21 -080011#include "wpi/raw_uv_ostream.h"
12#include "wpi/uv/Buffer.h"
13#include "wpi/uv/Stream.h"
14
15using namespace wpi;
16
17WebSocketServerHelper::WebSocketServerHelper(HttpParser& req) {
Austin Schuh812d0d12021-11-04 20:16:48 -070018 req.header.connect([this](std::string_view name, std::string_view value) {
19 if (equals_lower(name, "host")) {
Brian Silverman8fce7482020-01-05 13:18:21 -080020 m_gotHost = true;
Austin Schuh812d0d12021-11-04 20:16:48 -070021 } else if (equals_lower(name, "upgrade")) {
22 if (equals_lower(value, "websocket")) {
23 m_websocket = true;
24 }
25 } else if (equals_lower(name, "sec-websocket-key")) {
Brian Silverman8fce7482020-01-05 13:18:21 -080026 m_key = value;
Austin Schuh812d0d12021-11-04 20:16:48 -070027 } else if (equals_lower(name, "sec-websocket-version")) {
Brian Silverman8fce7482020-01-05 13:18:21 -080028 m_version = value;
Austin Schuh812d0d12021-11-04 20:16:48 -070029 } else if (equals_lower(name, "sec-websocket-protocol")) {
Brian Silverman8fce7482020-01-05 13:18:21 -080030 // Protocols are comma delimited, repeated headers add to list
Austin Schuh812d0d12021-11-04 20:16:48 -070031 SmallVector<std::string_view, 2> protocols;
32 split(value, protocols, ",", -1, false);
Brian Silverman8fce7482020-01-05 13:18:21 -080033 for (auto protocol : protocols) {
Austin Schuh812d0d12021-11-04 20:16:48 -070034 protocol = trim(protocol);
35 if (!protocol.empty()) {
36 m_protocols.emplace_back(protocol);
37 }
Brian Silverman8fce7482020-01-05 13:18:21 -080038 }
39 }
40 });
41 req.headersComplete.connect([&req, this](bool) {
Austin Schuh812d0d12021-11-04 20:16:48 -070042 if (req.IsUpgrade() && IsUpgrade()) {
43 upgrade();
44 }
Brian Silverman8fce7482020-01-05 13:18:21 -080045 });
46}
47
Austin Schuh812d0d12021-11-04 20:16:48 -070048std::pair<bool, std::string_view> WebSocketServerHelper::MatchProtocol(
49 span<const std::string_view> protocols) {
50 if (protocols.empty() && m_protocols.empty()) {
51 return {true, {}};
52 }
Brian Silverman8fce7482020-01-05 13:18:21 -080053 for (auto protocol : protocols) {
54 for (auto&& clientProto : m_protocols) {
Austin Schuh812d0d12021-11-04 20:16:48 -070055 if (protocol == clientProto) {
56 return {true, protocol};
57 }
Brian Silverman8fce7482020-01-05 13:18:21 -080058 }
59 }
Austin Schuh812d0d12021-11-04 20:16:48 -070060 return {false, {}};
Brian Silverman8fce7482020-01-05 13:18:21 -080061}
62
63WebSocketServer::WebSocketServer(uv::Stream& stream,
Austin Schuh812d0d12021-11-04 20:16:48 -070064 span<const std::string_view> protocols,
65 ServerOptions options, const private_init&)
Brian Silverman8fce7482020-01-05 13:18:21 -080066 : m_stream{stream},
67 m_helper{m_req},
68 m_protocols{protocols.begin(), protocols.end()},
Austin Schuh812d0d12021-11-04 20:16:48 -070069 m_options{std::move(options)} {
Brian Silverman8fce7482020-01-05 13:18:21 -080070 // Header handling
Austin Schuh812d0d12021-11-04 20:16:48 -070071 m_req.header.connect([this](std::string_view name, std::string_view value) {
72 if (equals_lower(name, "host")) {
Brian Silverman8fce7482020-01-05 13:18:21 -080073 if (m_options.checkHost) {
Austin Schuh812d0d12021-11-04 20:16:48 -070074 if (!m_options.checkHost(value)) {
75 Abort(401, "Unrecognized Host");
76 }
Brian Silverman8fce7482020-01-05 13:18:21 -080077 }
78 }
79 });
Austin Schuh812d0d12021-11-04 20:16:48 -070080 m_req.url.connect([this](std::string_view name) {
Brian Silverman8fce7482020-01-05 13:18:21 -080081 if (m_options.checkUrl) {
Austin Schuh812d0d12021-11-04 20:16:48 -070082 if (!m_options.checkUrl(name)) {
83 Abort(404, "Not Found");
84 }
Brian Silverman8fce7482020-01-05 13:18:21 -080085 }
86 });
87 m_req.headersComplete.connect([this](bool) {
88 // We only accept websocket connections
Austin Schuh812d0d12021-11-04 20:16:48 -070089 if (!m_helper.IsUpgrade() || !m_req.IsUpgrade()) {
Brian Silverman8fce7482020-01-05 13:18:21 -080090 Abort(426, "Upgrade Required");
Austin Schuh812d0d12021-11-04 20:16:48 -070091 }
Brian Silverman8fce7482020-01-05 13:18:21 -080092 });
93
94 // Handle upgrade event
95 m_helper.upgrade.connect([this] {
Austin Schuh812d0d12021-11-04 20:16:48 -070096 if (m_aborted) {
97 return;
98 }
Brian Silverman8fce7482020-01-05 13:18:21 -080099
100 // Negotiate sub-protocol
Austin Schuh812d0d12021-11-04 20:16:48 -0700101 SmallVector<std::string_view, 2> protocols{m_protocols.begin(),
102 m_protocols.end()};
103 std::string_view protocol = m_helper.MatchProtocol(protocols).second;
Brian Silverman8fce7482020-01-05 13:18:21 -0800104
105 // Disconnect our header reader
106 m_dataConn.disconnect();
107
108 // Accepting the stream may destroy this (as it replaces the stream user
109 // data), so grab a shared pointer first.
110 auto self = shared_from_this();
111
112 // Accept the upgrade
113 auto ws = m_helper.Accept(m_stream, protocol);
114
115 // Connect the websocket open event to our connected event.
Austin Schuh812d0d12021-11-04 20:16:48 -0700116 ws->open.connect_extended(
117 [self, s = ws.get()](auto conn, std::string_view) {
118 self->connected(self->m_req.GetUrl(), *s);
119 conn.disconnect(); // one-shot
120 });
Brian Silverman8fce7482020-01-05 13:18:21 -0800121 });
122
123 // Set up stream
124 stream.StartRead();
125 m_dataConn =
126 stream.data.connect_connection([this](uv::Buffer& buf, size_t size) {
Austin Schuh812d0d12021-11-04 20:16:48 -0700127 if (m_aborted) {
128 return;
129 }
130 m_req.Execute(std::string_view{buf.base, size});
131 if (m_req.HasError()) {
132 Abort(400, "Bad Request");
133 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800134 });
135 m_errorConn =
136 stream.error.connect_connection([this](uv::Error) { m_stream.Close(); });
137 m_endConn = stream.end.connect_connection([this] { m_stream.Close(); });
138}
139
140std::shared_ptr<WebSocketServer> WebSocketServer::Create(
Austin Schuh812d0d12021-11-04 20:16:48 -0700141 uv::Stream& stream, span<const std::string_view> protocols,
Brian Silverman8fce7482020-01-05 13:18:21 -0800142 const ServerOptions& options) {
143 auto server = std::make_shared<WebSocketServer>(stream, protocols, options,
144 private_init{});
145 stream.SetData(server);
146 return server;
147}
148
Austin Schuh812d0d12021-11-04 20:16:48 -0700149void WebSocketServer::Abort(uint16_t code, std::string_view reason) {
150 if (m_aborted) {
151 return;
152 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800153 m_aborted = true;
154
155 // Build response
156 SmallVector<uv::Buffer, 4> bufs;
157 raw_uv_ostream os{bufs, 1024};
158
159 // Handle unsupported version
Austin Schuh812d0d12021-11-04 20:16:48 -0700160 fmt::print(os, "HTTP/1.1 {} {}\r\n", code, reason);
161 if (code == 426) {
162 os << "Upgrade: WebSocket\r\n";
163 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800164 os << "\r\n";
165 m_stream.Write(bufs, [this](auto bufs, uv::Error) {
Austin Schuh812d0d12021-11-04 20:16:48 -0700166 for (auto& buf : bufs) {
167 buf.Deallocate();
168 }
Brian Silverman8fce7482020-01-05 13:18:21 -0800169 m_stream.Shutdown([this] { m_stream.Close(); });
170 });
171}