blob: 692e2a7dd9e6d7d4ec16d289bb745479bac3d92e [file] [log] [blame]
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2018 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
7
8#include "wpi/WebSocket.h" // NOLINT(build/include_order)
9
10#include "WebSocketTest.h"
11#include "wpi/Base64.h"
12#include "wpi/HttpParser.h"
13#include "wpi/SmallString.h"
14#include "wpi/raw_uv_ostream.h"
15#include "wpi/sha1.h"
16
17namespace wpi {
18
19class WebSocketClientTest : public WebSocketTest {
20 public:
21 WebSocketClientTest() {
22 // Bare bones server
23 req.header.connect([this](StringRef name, StringRef value) {
24 // save key (required for valid response)
25 if (name.equals_lower("sec-websocket-key")) clientKey = value;
26 });
27 req.headersComplete.connect([this](bool) {
28 // send response
29 SmallVector<uv::Buffer, 4> bufs;
30 raw_uv_ostream os{bufs, 4096};
31 os << "HTTP/1.1 101 Switching Protocols\r\n";
32 os << "Upgrade: websocket\r\n";
33 os << "Connection: Upgrade\r\n";
34
35 // accept hash
36 SHA1 hash;
37 hash.Update(clientKey);
38 hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
39 if (mockBadAccept) hash.Update("1");
40 SmallString<64> hashBuf;
41 SmallString<64> acceptBuf;
42 os << "Sec-WebSocket-Accept: "
43 << Base64Encode(hash.RawFinal(hashBuf), acceptBuf) << "\r\n";
44
45 if (!mockProtocol.empty())
46 os << "Sec-WebSocket-Protocol: " << mockProtocol << "\r\n";
47
48 os << "\r\n";
49
50 conn->Write(bufs, [](auto bufs, uv::Error) {
51 for (auto& buf : bufs) buf.Deallocate();
52 });
53
54 serverHeadersDone = true;
55 if (connected) connected();
56 });
57
58 serverPipe->Listen([this] {
59 conn = serverPipe->Accept();
60 conn->StartRead();
61 conn->data.connect([this](uv::Buffer& buf, size_t size) {
62 StringRef data{buf.base, size};
63 if (!serverHeadersDone) {
64 data = req.Execute(data);
65 if (req.HasError()) Finish();
66 ASSERT_EQ(req.GetError(), HPE_OK) << http_errno_name(req.GetError());
67 if (data.empty()) return;
68 }
69 wireData.insert(wireData.end(), data.bytes_begin(), data.bytes_end());
70 });
71 conn->end.connect([this] { Finish(); });
72 });
73 }
74
75 bool mockBadAccept = false;
76 std::vector<uint8_t> wireData;
77 std::shared_ptr<uv::Pipe> conn;
78 HttpParser req{HttpParser::kRequest};
79 SmallString<64> clientKey;
80 std::string mockProtocol;
81 bool serverHeadersDone = false;
82 std::function<void()> connected;
83};
84
85TEST_F(WebSocketClientTest, Open) {
86 int gotOpen = 0;
87
88 clientPipe->Connect(pipeName, [&] {
89 auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
90 ws->closed.connect([&](uint16_t code, StringRef reason) {
91 Finish();
92 if (code != 1005 && code != 1006)
93 FAIL() << "Code: " << code << " Reason: " << reason;
94 });
95 ws->open.connect([&](StringRef protocol) {
96 ++gotOpen;
97 Finish();
98 ASSERT_TRUE(protocol.empty());
99 });
100 });
101
102 loop->Run();
103
104 if (HasFatalFailure()) return;
105 ASSERT_EQ(gotOpen, 1);
106}
107
108TEST_F(WebSocketClientTest, BadAccept) {
109 int gotClosed = 0;
110
111 mockBadAccept = true;
112
113 clientPipe->Connect(pipeName, [&] {
114 auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
115 ws->closed.connect([&](uint16_t code, StringRef msg) {
116 Finish();
117 ++gotClosed;
118 ASSERT_EQ(code, 1002) << "Message: " << msg;
119 });
120 ws->open.connect([&](StringRef protocol) {
121 Finish();
122 FAIL() << "Got open";
123 });
124 });
125
126 loop->Run();
127
128 if (HasFatalFailure()) return;
129 ASSERT_EQ(gotClosed, 1);
130}
131
132TEST_F(WebSocketClientTest, ProtocolGood) {
133 int gotOpen = 0;
134
135 mockProtocol = "myProtocol";
136
137 clientPipe->Connect(pipeName, [&] {
138 auto ws = WebSocket::CreateClient(
139 *clientPipe, "/test", pipeName,
140 ArrayRef<StringRef>{"myProtocol", "myProtocol2"});
141 ws->closed.connect([&](uint16_t code, StringRef msg) {
142 Finish();
143 if (code != 1005 && code != 1006)
144 FAIL() << "Code: " << code << "Message: " << msg;
145 });
146 ws->open.connect([&](StringRef protocol) {
147 ++gotOpen;
148 Finish();
149 ASSERT_EQ(protocol, "myProtocol");
150 });
151 });
152
153 loop->Run();
154
155 if (HasFatalFailure()) return;
156 ASSERT_EQ(gotOpen, 1);
157}
158
159TEST_F(WebSocketClientTest, ProtocolRespNotReq) {
160 int gotClosed = 0;
161
162 mockProtocol = "myProtocol";
163
164 clientPipe->Connect(pipeName, [&] {
165 auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
166 ws->closed.connect([&](uint16_t code, StringRef msg) {
167 Finish();
168 ++gotClosed;
169 ASSERT_EQ(code, 1003) << "Message: " << msg;
170 });
171 ws->open.connect([&](StringRef protocol) {
172 Finish();
173 FAIL() << "Got open";
174 });
175 });
176
177 loop->Run();
178
179 if (HasFatalFailure()) return;
180 ASSERT_EQ(gotClosed, 1);
181}
182
183TEST_F(WebSocketClientTest, ProtocolReqNotResp) {
184 int gotClosed = 0;
185
186 clientPipe->Connect(pipeName, [&] {
187 auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName,
188 StringRef{"myProtocol"});
189 ws->closed.connect([&](uint16_t code, StringRef msg) {
190 Finish();
191 ++gotClosed;
192 ASSERT_EQ(code, 1002) << "Message: " << msg;
193 });
194 ws->open.connect([&](StringRef protocol) {
195 Finish();
196 FAIL() << "Got open";
197 });
198 });
199
200 loop->Run();
201
202 if (HasFatalFailure()) return;
203 ASSERT_EQ(gotClosed, 1);
204}
205
//
// Send and receive data.  Most of these cases are tested in
// WebSocketServerTest, so only spot check differences like masking.
//
210
211class WebSocketClientDataTest : public WebSocketClientTest,
212 public ::testing::WithParamInterface<size_t> {
213 public:
214 WebSocketClientDataTest() {
215 clientPipe->Connect(pipeName, [&] {
216 ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
217 if (setupWebSocket) setupWebSocket();
218 });
219 }
220
221 std::function<void()> setupWebSocket;
222 std::shared_ptr<WebSocket> ws;
223};
224
225INSTANTIATE_TEST_CASE_P(WebSocketClientDataTests, WebSocketClientDataTest,
226 ::testing::Values(0, 1, 125, 126, 65535, 65536), );
227
228TEST_P(WebSocketClientDataTest, SendBinary) {
229 int gotCallback = 0;
230 std::vector<uint8_t> data(GetParam(), 0x03u);
231 setupWebSocket = [&] {
232 ws->open.connect([&](StringRef) {
233 ws->SendBinary(uv::Buffer(data), [&](auto bufs, uv::Error) {
234 ++gotCallback;
235 ws->Terminate();
236 ASSERT_FALSE(bufs.empty());
237 ASSERT_EQ(bufs[0].base, reinterpret_cast<const char*>(data.data()));
238 });
239 });
240 };
241
242 loop->Run();
243
244 auto expectData = BuildMessage(0x02, true, true, data);
245 AdjustMasking(wireData);
246 ASSERT_EQ(wireData, expectData);
247 ASSERT_EQ(gotCallback, 1);
248}
249
250TEST_P(WebSocketClientDataTest, ReceiveBinary) {
251 int gotCallback = 0;
252 std::vector<uint8_t> data(GetParam(), 0x03u);
253 setupWebSocket = [&] {
254 ws->binary.connect([&](ArrayRef<uint8_t> inData, bool fin) {
255 ++gotCallback;
256 ws->Terminate();
257 ASSERT_TRUE(fin);
258 std::vector<uint8_t> recvData{inData.begin(), inData.end()};
259 ASSERT_EQ(data, recvData);
260 });
261 };
262 auto message = BuildMessage(0x02, true, false, data);
263 connected = [&] {
264 conn->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
265 };
266
267 loop->Run();
268
269 ASSERT_EQ(gotCallback, 1);
270}
271
//
// The client must close the connection if a masked frame is received.
//
275
276TEST_P(WebSocketClientDataTest, ReceiveMasked) {
277 int gotCallback = 0;
278 std::vector<uint8_t> data(GetParam(), ' ');
279 setupWebSocket = [&] {
280 ws->text.connect([&](StringRef, bool) {
281 ws->Terminate();
282 FAIL() << "Should not have gotten masked message";
283 });
284 ws->closed.connect([&](uint16_t code, StringRef reason) {
285 ++gotCallback;
286 ASSERT_EQ(code, 1002) << "reason: " << reason;
287 });
288 };
289 auto message = BuildMessage(0x01, true, true, data);
290 connected = [&] {
291 conn->Write(uv::Buffer(message), [&](auto bufs, uv::Error) {});
292 };
293
294 loop->Run();
295
296 ASSERT_EQ(gotCallback, 1);
297}
298
299} // namespace wpi