blob: c27bac051aa4b7740666cc7b3ea4e54561dab2ed [file] [log] [blame]
Brian Silverman41cdd3e2019-01-19 19:48:58 -08001/*----------------------------------------------------------------------------*/
2/* Copyright (c) 2018 FIRST. All Rights Reserved. */
3/* Open Source Software - may be modified and shared by FRC teams. The code */
4/* must be accompanied by the FIRST BSD license file in the root directory of */
5/* the project. */
6/*----------------------------------------------------------------------------*/
7
8#include "wpi/WebSocket.h" // NOLINT(build/include_order)
9
10#include "WebSocketTest.h"
11
12#include "wpi/HttpParser.h"
13
14namespace wpi {
15
16#ifdef _WIN32
17const char* WebSocketTest::pipeName = "\\\\.\\pipe\\websocket-unit-test";
18#else
19const char* WebSocketTest::pipeName = "/tmp/websocket-unit-test";
20#endif
21const uint8_t WebSocketTest::testMask[4] = {0x11, 0x22, 0x33, 0x44};
22
23void WebSocketTest::SetUpTestCase() {
24#ifndef _WIN32
25 unlink(pipeName);
26#endif
27}
28
29std::vector<uint8_t> WebSocketTest::BuildHeader(uint8_t opcode, bool fin,
30 bool masking, uint64_t len) {
31 std::vector<uint8_t> data;
32 data.push_back(opcode | (fin ? 0x80u : 0x00u));
33 if (len < 126) {
34 data.push_back(len | (masking ? 0x80 : 0x00u));
35 } else if (len < 65536) {
36 data.push_back(126u | (masking ? 0x80 : 0x00u));
37 data.push_back(len >> 8);
38 data.push_back(len & 0xff);
39 } else {
40 data.push_back(127u | (masking ? 0x80u : 0x00u));
41 for (int i = 56; i >= 0; i -= 8) data.push_back((len >> i) & 0xff);
42 }
43 if (masking) data.insert(data.end(), &testMask[0], &testMask[4]);
44 return data;
45}
46
47std::vector<uint8_t> WebSocketTest::BuildMessage(uint8_t opcode, bool fin,
48 bool masking,
49 ArrayRef<uint8_t> data) {
50 auto finalData = BuildHeader(opcode, fin, masking, data.size());
51 size_t headerSize = finalData.size();
52 finalData.insert(finalData.end(), data.begin(), data.end());
53 if (masking) {
54 uint8_t mask[4] = {finalData[headerSize - 4], finalData[headerSize - 3],
55 finalData[headerSize - 2], finalData[headerSize - 1]};
56 int n = 0;
57 for (size_t i = headerSize, end = finalData.size(); i < end; ++i) {
58 finalData[i] ^= mask[n++];
59 if (n >= 4) n = 0;
60 }
61 }
62 return finalData;
63}
64
65// If the message is masked, changes the mask to match the mask set by
66// BuildHeader() by unmasking and remasking.
67void WebSocketTest::AdjustMasking(MutableArrayRef<uint8_t> message) {
68 if (message.size() < 2) return;
69 if ((message[1] & 0x80) == 0) return; // not masked
70 size_t maskPos;
71 uint8_t len = message[1] & 0x7f;
72 if (len == 126)
73 maskPos = 4;
74 else if (len == 127)
75 maskPos = 10;
76 else
77 maskPos = 2;
78 uint8_t mask[4] = {message[maskPos], message[maskPos + 1],
79 message[maskPos + 2], message[maskPos + 3]};
80 message[maskPos] = testMask[0];
81 message[maskPos + 1] = testMask[1];
82 message[maskPos + 2] = testMask[2];
83 message[maskPos + 3] = testMask[3];
84 int n = 0;
85 for (auto& ch : message.slice(maskPos + 4)) {
86 ch ^= mask[n] ^ testMask[n];
87 if (++n >= 4) n = 0;
88 }
89}
90
91TEST_F(WebSocketTest, CreateClientBasic) {
92 int gotHost = 0;
93 int gotUpgrade = 0;
94 int gotConnection = 0;
95 int gotKey = 0;
96 int gotVersion = 0;
97
98 HttpParser req{HttpParser::kRequest};
99 req.url.connect([](StringRef url) { ASSERT_EQ(url, "/test"); });
100 req.header.connect([&](StringRef name, StringRef value) {
101 if (name.equals_lower("host")) {
102 ASSERT_EQ(value, pipeName);
103 ++gotHost;
104 } else if (name.equals_lower("upgrade")) {
105 ASSERT_EQ(value, "websocket");
106 ++gotUpgrade;
107 } else if (name.equals_lower("connection")) {
108 ASSERT_EQ(value, "Upgrade");
109 ++gotConnection;
110 } else if (name.equals_lower("sec-websocket-key")) {
111 ++gotKey;
112 } else if (name.equals_lower("sec-websocket-version")) {
113 ASSERT_EQ(value, "13");
114 ++gotVersion;
115 } else {
116 FAIL() << "unexpected header " << name.str();
117 }
118 });
119 req.headersComplete.connect([&](bool) { Finish(); });
120
121 serverPipe->Listen([&]() {
122 auto conn = serverPipe->Accept();
123 conn->StartRead();
124 conn->data.connect([&](uv::Buffer& buf, size_t size) {
125 req.Execute(StringRef{buf.base, size});
126 if (req.HasError()) Finish();
127 ASSERT_EQ(req.GetError(), HPE_OK) << http_errno_name(req.GetError());
128 });
129 });
130 clientPipe->Connect(pipeName, [&]() {
131 auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
132 });
133
134 loop->Run();
135
136 if (HasFatalFailure()) return;
137 ASSERT_EQ(gotHost, 1);
138 ASSERT_EQ(gotUpgrade, 1);
139 ASSERT_EQ(gotConnection, 1);
140 ASSERT_EQ(gotKey, 1);
141 ASSERT_EQ(gotVersion, 1);
142}
143
144TEST_F(WebSocketTest, CreateClientExtraHeaders) {
145 int gotExtra1 = 0;
146 int gotExtra2 = 0;
147 HttpParser req{HttpParser::kRequest};
148 req.header.connect([&](StringRef name, StringRef value) {
149 if (name.equals("Extra1")) {
150 ASSERT_EQ(value, "Data1");
151 ++gotExtra1;
152 } else if (name.equals("Extra2")) {
153 ASSERT_EQ(value, "Data2");
154 ++gotExtra2;
155 }
156 });
157 req.headersComplete.connect([&](bool) { Finish(); });
158
159 serverPipe->Listen([&]() {
160 auto conn = serverPipe->Accept();
161 conn->StartRead();
162 conn->data.connect([&](uv::Buffer& buf, size_t size) {
163 req.Execute(StringRef{buf.base, size});
164 if (req.HasError()) Finish();
165 ASSERT_EQ(req.GetError(), HPE_OK) << http_errno_name(req.GetError());
166 });
167 });
168 clientPipe->Connect(pipeName, [&]() {
169 WebSocket::ClientOptions options;
170 SmallVector<std::pair<StringRef, StringRef>, 4> extraHeaders;
171 extraHeaders.emplace_back("Extra1", "Data1");
172 extraHeaders.emplace_back("Extra2", "Data2");
173 options.extraHeaders = extraHeaders;
174 auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName,
175 ArrayRef<StringRef>{}, options);
176 });
177
178 loop->Run();
179
180 if (HasFatalFailure()) return;
181 ASSERT_EQ(gotExtra1, 1);
182 ASSERT_EQ(gotExtra2, 1);
183}
184
185TEST_F(WebSocketTest, CreateClientTimeout) {
186 int gotClosed = 0;
187 serverPipe->Listen([&]() { auto conn = serverPipe->Accept(); });
188 clientPipe->Connect(pipeName, [&]() {
189 WebSocket::ClientOptions options;
190 options.handshakeTimeout = uv::Timer::Time{100};
191 auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName,
192 ArrayRef<StringRef>{}, options);
193 ws->closed.connect([&](uint16_t code, StringRef) {
194 Finish();
195 ++gotClosed;
196 ASSERT_EQ(code, 1006);
197 });
198 });
199
200 loop->Run();
201
202 if (HasFatalFailure()) return;
203 ASSERT_EQ(gotClosed, 1);
204}
205
206TEST_F(WebSocketTest, CreateServerBasic) {
207 int gotStatus = 0;
208 int gotUpgrade = 0;
209 int gotConnection = 0;
210 int gotAccept = 0;
211 int gotOpen = 0;
212
213 HttpParser resp{HttpParser::kResponse};
214 resp.status.connect([&](StringRef status) {
215 ++gotStatus;
216 ASSERT_EQ(resp.GetStatusCode(), 101u) << "status: " << status;
217 });
218 resp.header.connect([&](StringRef name, StringRef value) {
219 if (name.equals_lower("upgrade")) {
220 ASSERT_EQ(value, "websocket");
221 ++gotUpgrade;
222 } else if (name.equals_lower("connection")) {
223 ASSERT_EQ(value, "Upgrade");
224 ++gotConnection;
225 } else if (name.equals_lower("sec-websocket-accept")) {
226 ASSERT_EQ(value, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
227 ++gotAccept;
228 } else {
229 FAIL() << "unexpected header " << name.str();
230 }
231 });
232 resp.headersComplete.connect([&](bool) { Finish(); });
233
234 serverPipe->Listen([&]() {
235 auto conn = serverPipe->Accept();
236 auto ws = WebSocket::CreateServer(*conn, "dGhlIHNhbXBsZSBub25jZQ==", "13");
237 ws->open.connect([&](StringRef protocol) {
238 ++gotOpen;
239 ASSERT_TRUE(protocol.empty());
240 });
241 });
242 clientPipe->Connect(pipeName, [&] {
243 clientPipe->StartRead();
244 clientPipe->data.connect([&](uv::Buffer& buf, size_t size) {
245 resp.Execute(StringRef{buf.base, size});
246 if (resp.HasError()) Finish();
247 ASSERT_EQ(resp.GetError(), HPE_OK) << http_errno_name(resp.GetError());
248 });
249 });
250
251 loop->Run();
252
253 if (HasFatalFailure()) return;
254 ASSERT_EQ(gotStatus, 1);
255 ASSERT_EQ(gotUpgrade, 1);
256 ASSERT_EQ(gotConnection, 1);
257 ASSERT_EQ(gotAccept, 1);
258 ASSERT_EQ(gotOpen, 1);
259}
260
261TEST_F(WebSocketTest, CreateServerProtocol) {
262 int gotProtocol = 0;
263 int gotOpen = 0;
264
265 HttpParser resp{HttpParser::kResponse};
266 resp.header.connect([&](StringRef name, StringRef value) {
267 if (name.equals_lower("sec-websocket-protocol")) {
268 ++gotProtocol;
269 ASSERT_EQ(value, "myProtocol");
270 }
271 });
272 resp.headersComplete.connect([&](bool) { Finish(); });
273
274 serverPipe->Listen([&]() {
275 auto conn = serverPipe->Accept();
276 auto ws = WebSocket::CreateServer(*conn, "foo", "13", "myProtocol");
277 ws->open.connect([&](StringRef protocol) {
278 ++gotOpen;
279 ASSERT_EQ(protocol, "myProtocol");
280 });
281 });
282 clientPipe->Connect(pipeName, [&] {
283 clientPipe->StartRead();
284 clientPipe->data.connect([&](uv::Buffer& buf, size_t size) {
285 resp.Execute(StringRef{buf.base, size});
286 if (resp.HasError()) Finish();
287 ASSERT_EQ(resp.GetError(), HPE_OK) << http_errno_name(resp.GetError());
288 });
289 });
290
291 loop->Run();
292
293 if (HasFatalFailure()) return;
294 ASSERT_EQ(gotProtocol, 1);
295 ASSERT_EQ(gotOpen, 1);
296}
297
298TEST_F(WebSocketTest, CreateServerBadVersion) {
299 int gotStatus = 0;
300 int gotVersion = 0;
301 int gotUpgrade = 0;
302
303 HttpParser resp{HttpParser::kResponse};
304 resp.status.connect([&](StringRef status) {
305 ++gotStatus;
306 ASSERT_EQ(resp.GetStatusCode(), 426u) << "status: " << status;
307 });
308 resp.header.connect([&](StringRef name, StringRef value) {
309 if (name.equals_lower("sec-websocket-version")) {
310 ++gotVersion;
311 ASSERT_EQ(value, "13");
312 } else if (name.equals_lower("upgrade")) {
313 ++gotUpgrade;
314 ASSERT_EQ(value, "WebSocket");
315 } else {
316 FAIL() << "unexpected header " << name.str();
317 }
318 });
319 resp.headersComplete.connect([&](bool) { Finish(); });
320
321 serverPipe->Listen([&] {
322 auto conn = serverPipe->Accept();
323 auto ws = WebSocket::CreateServer(*conn, "foo", "14");
324 ws->open.connect([&](StringRef) {
325 Finish();
326 FAIL();
327 });
328 });
329 clientPipe->Connect(pipeName, [&] {
330 clientPipe->StartRead();
331 clientPipe->data.connect([&](uv::Buffer& buf, size_t size) {
332 resp.Execute(StringRef{buf.base, size});
333 if (resp.HasError()) Finish();
334 ASSERT_EQ(resp.GetError(), HPE_OK) << http_errno_name(resp.GetError());
335 });
336 });
337
338 loop->Run();
339
340 if (HasFatalFailure()) return;
341 ASSERT_EQ(gotStatus, 1);
342 ASSERT_EQ(gotVersion, 1);
343 ASSERT_EQ(gotUpgrade, 1);
344}
345
346} // namespace wpi