blob: 2db9b54a550fce95f5e1f60ebba528616b2cc2f4 [file] [log] [blame]
Brian Silverman41cdd3e2019-01-19 19:48:58 -08001/*----------------------------------------------------------------------------*/
James Kuszmaul4f3ad3c2019-12-01 16:35:21 -08002/* Copyright (c) 2018-2019 FIRST. All Rights Reserved. */
Brian Silverman41cdd3e2019-01-19 19:48:58 -08003/* Open Source Software - may be modified and shared by FRC teams. The code */
4/* must be accompanied by the FIRST BSD license file in the root directory of */
5/* the project. */
6/*----------------------------------------------------------------------------*/
7
8#include "wpi/WebSocket.h" // NOLINT(build/include_order)
9
10#include "WebSocketTest.h"
11#include "wpi/Base64.h"
12#include "wpi/HttpParser.h"
13#include "wpi/SmallString.h"
14#include "wpi/raw_uv_ostream.h"
15#include "wpi/sha1.h"
16
17namespace wpi {
18
19class WebSocketClientTest : public WebSocketTest {
20 public:
21 WebSocketClientTest() {
22 // Bare bones server
23 req.header.connect([this](StringRef name, StringRef value) {
24 // save key (required for valid response)
25 if (name.equals_lower("sec-websocket-key")) clientKey = value;
26 });
27 req.headersComplete.connect([this](bool) {
28 // send response
29 SmallVector<uv::Buffer, 4> bufs;
30 raw_uv_ostream os{bufs, 4096};
31 os << "HTTP/1.1 101 Switching Protocols\r\n";
32 os << "Upgrade: websocket\r\n";
33 os << "Connection: Upgrade\r\n";
34
35 // accept hash
36 SHA1 hash;
37 hash.Update(clientKey);
38 hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
39 if (mockBadAccept) hash.Update("1");
40 SmallString<64> hashBuf;
41 SmallString<64> acceptBuf;
42 os << "Sec-WebSocket-Accept: "
43 << Base64Encode(hash.RawFinal(hashBuf), acceptBuf) << "\r\n";
44
45 if (!mockProtocol.empty())
46 os << "Sec-WebSocket-Protocol: " << mockProtocol << "\r\n";
47
48 os << "\r\n";
49
50 conn->Write(bufs, [](auto bufs, uv::Error) {
51 for (auto& buf : bufs) buf.Deallocate();
52 });
53
54 serverHeadersDone = true;
55 if (connected) connected();
56 });
57
58 serverPipe->Listen([this] {
59 conn = serverPipe->Accept();
60 conn->StartRead();
61 conn->data.connect([this](uv::Buffer& buf, size_t size) {
62 StringRef data{buf.base, size};
63 if (!serverHeadersDone) {
64 data = req.Execute(data);
65 if (req.HasError()) Finish();
66 ASSERT_EQ(req.GetError(), HPE_OK) << http_errno_name(req.GetError());
67 if (data.empty()) return;
68 }
69 wireData.insert(wireData.end(), data.bytes_begin(), data.bytes_end());
70 });
71 conn->end.connect([this] { Finish(); });
72 });
73 }
74
75 bool mockBadAccept = false;
76 std::vector<uint8_t> wireData;
77 std::shared_ptr<uv::Pipe> conn;
78 HttpParser req{HttpParser::kRequest};
79 SmallString<64> clientKey;
80 std::string mockProtocol;
81 bool serverHeadersDone = false;
82 std::function<void()> connected;
83};
84
85TEST_F(WebSocketClientTest, Open) {
86 int gotOpen = 0;
87
88 clientPipe->Connect(pipeName, [&] {
89 auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
90 ws->closed.connect([&](uint16_t code, StringRef reason) {
91 Finish();
92 if (code != 1005 && code != 1006)
93 FAIL() << "Code: " << code << " Reason: " << reason;
94 });
95 ws->open.connect([&](StringRef protocol) {
96 ++gotOpen;
97 Finish();
98 ASSERT_TRUE(protocol.empty());
99 });
100 });
101
102 loop->Run();
103
104 if (HasFatalFailure()) return;
105 ASSERT_EQ(gotOpen, 1);
106}
107
108TEST_F(WebSocketClientTest, BadAccept) {
109 int gotClosed = 0;
110
111 mockBadAccept = true;
112
113 clientPipe->Connect(pipeName, [&] {
114 auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
115 ws->closed.connect([&](uint16_t code, StringRef msg) {
116 Finish();
117 ++gotClosed;
118 ASSERT_EQ(code, 1002) << "Message: " << msg;
119 });
120 ws->open.connect([&](StringRef protocol) {
121 Finish();
122 FAIL() << "Got open";
123 });
124 });
125
126 loop->Run();
127
128 if (HasFatalFailure()) return;
129 ASSERT_EQ(gotClosed, 1);
130}
131
132TEST_F(WebSocketClientTest, ProtocolGood) {
133 int gotOpen = 0;
134
135 mockProtocol = "myProtocol";
136
137 clientPipe->Connect(pipeName, [&] {
James Kuszmaul4f3ad3c2019-12-01 16:35:21 -0800138 auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName,
139 {"myProtocol", "myProtocol2"});
Brian Silverman41cdd3e2019-01-19 19:48:58 -0800140 ws->closed.connect([&](uint16_t code, StringRef msg) {
141 Finish();
142 if (code != 1005 && code != 1006)
143 FAIL() << "Code: " << code << "Message: " << msg;
144 });
145 ws->open.connect([&](StringRef protocol) {
146 ++gotOpen;
147 Finish();
148 ASSERT_EQ(protocol, "myProtocol");
149 });
150 });
151
152 loop->Run();
153
154 if (HasFatalFailure()) return;
155 ASSERT_EQ(gotOpen, 1);
156}
157
158TEST_F(WebSocketClientTest, ProtocolRespNotReq) {
159 int gotClosed = 0;
160
161 mockProtocol = "myProtocol";
162
163 clientPipe->Connect(pipeName, [&] {
164 auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
165 ws->closed.connect([&](uint16_t code, StringRef msg) {
166 Finish();
167 ++gotClosed;
168 ASSERT_EQ(code, 1003) << "Message: " << msg;
169 });
170 ws->open.connect([&](StringRef protocol) {
171 Finish();
172 FAIL() << "Got open";
173 });
174 });
175
176 loop->Run();
177
178 if (HasFatalFailure()) return;
179 ASSERT_EQ(gotClosed, 1);
180}
181
182TEST_F(WebSocketClientTest, ProtocolReqNotResp) {
183 int gotClosed = 0;
184
185 clientPipe->Connect(pipeName, [&] {
186 auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName,
187 StringRef{"myProtocol"});
188 ws->closed.connect([&](uint16_t code, StringRef msg) {
189 Finish();
190 ++gotClosed;
191 ASSERT_EQ(code, 1002) << "Message: " << msg;
192 });
193 ws->open.connect([&](StringRef protocol) {
194 Finish();
195 FAIL() << "Got open";
196 });
197 });
198
199 loop->Run();
200
201 if (HasFatalFailure()) return;
202 ASSERT_EQ(gotClosed, 1);
203}
204
205//
206// Send and receive data. Most of these cases are tested in
207// WebSocketServerTest, so only spot check differences like masking.
208//
209
210class WebSocketClientDataTest : public WebSocketClientTest,
211 public ::testing::WithParamInterface<size_t> {
212 public:
213 WebSocketClientDataTest() {
214 clientPipe->Connect(pipeName, [&] {
215 ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
216 if (setupWebSocket) setupWebSocket();
217 });
218 }
219
220 std::function<void()> setupWebSocket;
221 std::shared_ptr<WebSocket> ws;
222};
223
James Kuszmaul4f3ad3c2019-12-01 16:35:21 -0800224INSTANTIATE_TEST_SUITE_P(WebSocketClientDataTests, WebSocketClientDataTest,
225 ::testing::Values(0, 1, 125, 126, 65535, 65536));
Brian Silverman41cdd3e2019-01-19 19:48:58 -0800226
227TEST_P(WebSocketClientDataTest, SendBinary) {
228 int gotCallback = 0;
229 std::vector<uint8_t> data(GetParam(), 0x03u);
230 setupWebSocket = [&] {
231 ws->open.connect([&](StringRef) {
232 ws->SendBinary(uv::Buffer(data), [&](auto bufs, uv::Error) {
233 ++gotCallback;
234 ws->Terminate();
235 ASSERT_FALSE(bufs.empty());
236 ASSERT_EQ(bufs[0].base, reinterpret_cast<const char*>(data.data()));
237 });
238 });
239 };
240
241 loop->Run();
242
243 auto expectData = BuildMessage(0x02, true, true, data);
244 AdjustMasking(wireData);
245 ASSERT_EQ(wireData, expectData);
246 ASSERT_EQ(gotCallback, 1);
247}
248
249TEST_P(WebSocketClientDataTest, ReceiveBinary) {
250 int gotCallback = 0;
251 std::vector<uint8_t> data(GetParam(), 0x03u);
252 setupWebSocket = [&] {
253 ws->binary.connect([&](ArrayRef<uint8_t> inData, bool fin) {
254 ++gotCallback;
255 ws->Terminate();
256 ASSERT_TRUE(fin);
257 std::vector<uint8_t> recvData{inData.begin(), inData.end()};
258 ASSERT_EQ(data, recvData);
259 });
260 };
261 auto message = BuildMessage(0x02, true, false, data);
262 connected = [&] {
263 conn->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
264 };
265
266 loop->Run();
267
268 ASSERT_EQ(gotCallback, 1);
269}
270
271//
272// The client must close the connection if a masked frame is received.
273//
274
275TEST_P(WebSocketClientDataTest, ReceiveMasked) {
276 int gotCallback = 0;
277 std::vector<uint8_t> data(GetParam(), ' ');
278 setupWebSocket = [&] {
279 ws->text.connect([&](StringRef, bool) {
280 ws->Terminate();
281 FAIL() << "Should not have gotten masked message";
282 });
283 ws->closed.connect([&](uint16_t code, StringRef reason) {
284 ++gotCallback;
285 ASSERT_EQ(code, 1002) << "reason: " << reason;
286 });
287 };
288 auto message = BuildMessage(0x01, true, true, data);
289 connected = [&] {
290 conn->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
291 };
292
293 loop->Run();
294
295 ASSERT_EQ(gotCallback, 1);
296}
297
298} // namespace wpi