| // Copyright (c) FIRST and other WPILib contributors. |
| // Open Source Software; you can modify and/or share it under the terms of |
| // the WPILib BSD license file in the root directory of this project. |
| |
| #include "wpi/WebSocket.h" // NOLINT(build/include_order) |
| |
| #include "WebSocketTest.h" |
| |
| #include "wpi/HttpParser.h" |
| #include "wpi/StringExtras.h" |
| |
| namespace wpi { |
| |
| #ifdef _WIN32 |
| const char* WebSocketTest::pipeName = "\\\\.\\pipe\\websocket-unit-test"; |
| #else |
| const char* WebSocketTest::pipeName = "/tmp/websocket-unit-test"; |
| #endif |
| const uint8_t WebSocketTest::testMask[4] = {0x11, 0x22, 0x33, 0x44}; |
| |
| void WebSocketTest::SetUpTestCase() { |
| #ifndef _WIN32 |
| unlink(pipeName); |
| #endif |
| } |
| |
| std::vector<uint8_t> WebSocketTest::BuildHeader(uint8_t opcode, bool fin, |
| bool masking, uint64_t len) { |
| std::vector<uint8_t> data; |
| data.push_back(opcode | (fin ? 0x80u : 0x00u)); |
| if (len < 126) { |
| data.push_back(len | (masking ? 0x80 : 0x00u)); |
| } else if (len < 65536) { |
| data.push_back(126u | (masking ? 0x80 : 0x00u)); |
| data.push_back(len >> 8); |
| data.push_back(len & 0xff); |
| } else { |
| data.push_back(127u | (masking ? 0x80u : 0x00u)); |
| for (int i = 56; i >= 0; i -= 8) { |
| data.push_back((len >> i) & 0xff); |
| } |
| } |
| if (masking) { |
| data.insert(data.end(), &testMask[0], &testMask[4]); |
| } |
| return data; |
| } |
| |
| std::vector<uint8_t> WebSocketTest::BuildMessage(uint8_t opcode, bool fin, |
| bool masking, |
| span<const uint8_t> data) { |
| auto finalData = BuildHeader(opcode, fin, masking, data.size()); |
| size_t headerSize = finalData.size(); |
| finalData.insert(finalData.end(), data.begin(), data.end()); |
| if (masking) { |
| uint8_t mask[4] = {finalData[headerSize - 4], finalData[headerSize - 3], |
| finalData[headerSize - 2], finalData[headerSize - 1]}; |
| int n = 0; |
| for (size_t i = headerSize, end = finalData.size(); i < end; ++i) { |
| finalData[i] ^= mask[n++]; |
| if (n >= 4) { |
| n = 0; |
| } |
| } |
| } |
| return finalData; |
| } |
| |
| // If the message is masked, changes the mask to match the mask set by |
| // BuildHeader() by unmasking and remasking. |
| void WebSocketTest::AdjustMasking(span<uint8_t> message) { |
| if (message.size() < 2) { |
| return; |
| } |
| if ((message[1] & 0x80) == 0) { |
| return; // not masked |
| } |
| size_t maskPos; |
| uint8_t len = message[1] & 0x7f; |
| if (len == 126) { |
| maskPos = 4; |
| } else if (len == 127) { |
| maskPos = 10; |
| } else { |
| maskPos = 2; |
| } |
| uint8_t mask[4] = {message[maskPos], message[maskPos + 1], |
| message[maskPos + 2], message[maskPos + 3]}; |
| message[maskPos] = testMask[0]; |
| message[maskPos + 1] = testMask[1]; |
| message[maskPos + 2] = testMask[2]; |
| message[maskPos + 3] = testMask[3]; |
| int n = 0; |
| for (auto& ch : message.subspan(maskPos + 4)) { |
| ch ^= mask[n] ^ testMask[n]; |
| if (++n >= 4) { |
| n = 0; |
| } |
| } |
| } |
| |
| TEST_F(WebSocketTest, CreateClientBasic) { |
| int gotHost = 0; |
| int gotUpgrade = 0; |
| int gotConnection = 0; |
| int gotKey = 0; |
| int gotVersion = 0; |
| |
| HttpParser req{HttpParser::kRequest}; |
| req.url.connect([](std::string_view url) { ASSERT_EQ(url, "/test"); }); |
| req.header.connect([&](std::string_view name, std::string_view value) { |
| if (equals_lower(name, "host")) { |
| ASSERT_EQ(value, pipeName); |
| ++gotHost; |
| } else if (equals_lower(name, "upgrade")) { |
| ASSERT_EQ(value, "websocket"); |
| ++gotUpgrade; |
| } else if (equals_lower(name, "connection")) { |
| ASSERT_EQ(value, "Upgrade"); |
| ++gotConnection; |
| } else if (equals_lower(name, "sec-websocket-key")) { |
| ++gotKey; |
| } else if (equals_lower(name, "sec-websocket-version")) { |
| ASSERT_EQ(value, "13"); |
| ++gotVersion; |
| } else { |
| FAIL() << "unexpected header " << name; |
| } |
| }); |
| req.headersComplete.connect([&](bool) { Finish(); }); |
| |
| serverPipe->Listen([&]() { |
| auto conn = serverPipe->Accept(); |
| conn->StartRead(); |
| conn->data.connect([&](uv::Buffer& buf, size_t size) { |
| req.Execute(std::string_view{buf.base, size}); |
| if (req.HasError()) { |
| Finish(); |
| } |
| ASSERT_EQ(req.GetError(), HPE_OK) << http_errno_name(req.GetError()); |
| }); |
| }); |
| clientPipe->Connect(pipeName, [&]() { |
| auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName); |
| }); |
| |
| loop->Run(); |
| |
| if (HasFatalFailure()) { |
| return; |
| } |
| ASSERT_EQ(gotHost, 1); |
| ASSERT_EQ(gotUpgrade, 1); |
| ASSERT_EQ(gotConnection, 1); |
| ASSERT_EQ(gotKey, 1); |
| ASSERT_EQ(gotVersion, 1); |
| } |
| |
| TEST_F(WebSocketTest, CreateClientExtraHeaders) { |
| int gotExtra1 = 0; |
| int gotExtra2 = 0; |
| HttpParser req{HttpParser::kRequest}; |
| req.header.connect([&](std::string_view name, std::string_view value) { |
| if (equals(name, "Extra1")) { |
| ASSERT_EQ(value, "Data1"); |
| ++gotExtra1; |
| } else if (equals(name, "Extra2")) { |
| ASSERT_EQ(value, "Data2"); |
| ++gotExtra2; |
| } |
| }); |
| req.headersComplete.connect([&](bool) { Finish(); }); |
| |
| serverPipe->Listen([&]() { |
| auto conn = serverPipe->Accept(); |
| conn->StartRead(); |
| conn->data.connect([&](uv::Buffer& buf, size_t size) { |
| req.Execute(std::string_view{buf.base, size}); |
| if (req.HasError()) { |
| Finish(); |
| } |
| ASSERT_EQ(req.GetError(), HPE_OK) << http_errno_name(req.GetError()); |
| }); |
| }); |
| clientPipe->Connect(pipeName, [&]() { |
| WebSocket::ClientOptions options; |
| SmallVector<std::pair<std::string_view, std::string_view>, 4> extraHeaders; |
| extraHeaders.emplace_back("Extra1", "Data1"); |
| extraHeaders.emplace_back("Extra2", "Data2"); |
| options.extraHeaders = extraHeaders; |
| auto ws = |
| WebSocket::CreateClient(*clientPipe, "/test", pipeName, {}, options); |
| }); |
| |
| loop->Run(); |
| |
| if (HasFatalFailure()) { |
| return; |
| } |
| ASSERT_EQ(gotExtra1, 1); |
| ASSERT_EQ(gotExtra2, 1); |
| } |
| |
| TEST_F(WebSocketTest, CreateClientTimeout) { |
| int gotClosed = 0; |
| serverPipe->Listen([&]() { auto conn = serverPipe->Accept(); }); |
| clientPipe->Connect(pipeName, [&]() { |
| WebSocket::ClientOptions options; |
| options.handshakeTimeout = uv::Timer::Time{100}; |
| auto ws = |
| WebSocket::CreateClient(*clientPipe, "/test", pipeName, {}, options); |
| ws->closed.connect([&](uint16_t code, std::string_view) { |
| Finish(); |
| ++gotClosed; |
| ASSERT_EQ(code, 1006); |
| }); |
| }); |
| |
| loop->Run(); |
| |
| if (HasFatalFailure()) { |
| return; |
| } |
| ASSERT_EQ(gotClosed, 1); |
| } |
| |
| TEST_F(WebSocketTest, CreateServerBasic) { |
| int gotStatus = 0; |
| int gotUpgrade = 0; |
| int gotConnection = 0; |
| int gotAccept = 0; |
| int gotOpen = 0; |
| |
| HttpParser resp{HttpParser::kResponse}; |
| resp.status.connect([&](std::string_view status) { |
| ++gotStatus; |
| ASSERT_EQ(resp.GetStatusCode(), 101u) << "status: " << status; |
| }); |
| resp.header.connect([&](std::string_view name, std::string_view value) { |
| if (equals_lower(name, "upgrade")) { |
| ASSERT_EQ(value, "websocket"); |
| ++gotUpgrade; |
| } else if (equals_lower(name, "connection")) { |
| ASSERT_EQ(value, "Upgrade"); |
| ++gotConnection; |
| } else if (equals_lower(name, "sec-websocket-accept")) { |
| ASSERT_EQ(value, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); |
| ++gotAccept; |
| } else { |
| FAIL() << "unexpected header " << name; |
| } |
| }); |
| resp.headersComplete.connect([&](bool) { Finish(); }); |
| |
| serverPipe->Listen([&]() { |
| auto conn = serverPipe->Accept(); |
| auto ws = WebSocket::CreateServer(*conn, "dGhlIHNhbXBsZSBub25jZQ==", "13"); |
| ws->open.connect([&](std::string_view protocol) { |
| ++gotOpen; |
| ASSERT_TRUE(protocol.empty()); |
| }); |
| }); |
| clientPipe->Connect(pipeName, [&] { |
| clientPipe->StartRead(); |
| clientPipe->data.connect([&](uv::Buffer& buf, size_t size) { |
| resp.Execute(std::string_view{buf.base, size}); |
| if (resp.HasError()) { |
| Finish(); |
| } |
| ASSERT_EQ(resp.GetError(), HPE_OK) << http_errno_name(resp.GetError()); |
| }); |
| }); |
| |
| loop->Run(); |
| |
| if (HasFatalFailure()) { |
| return; |
| } |
| ASSERT_EQ(gotStatus, 1); |
| ASSERT_EQ(gotUpgrade, 1); |
| ASSERT_EQ(gotConnection, 1); |
| ASSERT_EQ(gotAccept, 1); |
| ASSERT_EQ(gotOpen, 1); |
| } |
| |
| TEST_F(WebSocketTest, CreateServerProtocol) { |
| int gotProtocol = 0; |
| int gotOpen = 0; |
| |
| HttpParser resp{HttpParser::kResponse}; |
| resp.header.connect([&](std::string_view name, std::string_view value) { |
| if (equals_lower(name, "sec-websocket-protocol")) { |
| ++gotProtocol; |
| ASSERT_EQ(value, "myProtocol"); |
| } |
| }); |
| resp.headersComplete.connect([&](bool) { Finish(); }); |
| |
| serverPipe->Listen([&]() { |
| auto conn = serverPipe->Accept(); |
| auto ws = WebSocket::CreateServer(*conn, "foo", "13", "myProtocol"); |
| ws->open.connect([&](std::string_view protocol) { |
| ++gotOpen; |
| ASSERT_EQ(protocol, "myProtocol"); |
| }); |
| }); |
| clientPipe->Connect(pipeName, [&] { |
| clientPipe->StartRead(); |
| clientPipe->data.connect([&](uv::Buffer& buf, size_t size) { |
| resp.Execute(std::string_view{buf.base, size}); |
| if (resp.HasError()) { |
| Finish(); |
| } |
| ASSERT_EQ(resp.GetError(), HPE_OK) << http_errno_name(resp.GetError()); |
| }); |
| }); |
| |
| loop->Run(); |
| |
| if (HasFatalFailure()) { |
| return; |
| } |
| ASSERT_EQ(gotProtocol, 1); |
| ASSERT_EQ(gotOpen, 1); |
| } |
| |
| TEST_F(WebSocketTest, CreateServerBadVersion) { |
| int gotStatus = 0; |
| int gotVersion = 0; |
| int gotUpgrade = 0; |
| |
| HttpParser resp{HttpParser::kResponse}; |
| resp.status.connect([&](std::string_view status) { |
| ++gotStatus; |
| ASSERT_EQ(resp.GetStatusCode(), 426u) << "status: " << status; |
| }); |
| resp.header.connect([&](std::string_view name, std::string_view value) { |
| if (equals_lower(name, "sec-websocket-version")) { |
| ++gotVersion; |
| ASSERT_EQ(value, "13"); |
| } else if (equals_lower(name, "upgrade")) { |
| ++gotUpgrade; |
| ASSERT_EQ(value, "WebSocket"); |
| } else { |
| FAIL() << "unexpected header " << name; |
| } |
| }); |
| resp.headersComplete.connect([&](bool) { Finish(); }); |
| |
| serverPipe->Listen([&] { |
| auto conn = serverPipe->Accept(); |
| auto ws = WebSocket::CreateServer(*conn, "foo", "14"); |
| ws->open.connect([&](std::string_view) { |
| Finish(); |
| FAIL(); |
| }); |
| }); |
| clientPipe->Connect(pipeName, [&] { |
| clientPipe->StartRead(); |
| clientPipe->data.connect([&](uv::Buffer& buf, size_t size) { |
| resp.Execute(std::string_view{buf.base, size}); |
| if (resp.HasError()) { |
| Finish(); |
| } |
| ASSERT_EQ(resp.GetError(), HPE_OK) << http_errno_name(resp.GetError()); |
| }); |
| }); |
| |
| loop->Run(); |
| |
| if (HasFatalFailure()) { |
| return; |
| } |
| ASSERT_EQ(gotStatus, 1); |
| ASSERT_EQ(gotVersion, 1); |
| ASSERT_EQ(gotUpgrade, 1); |
| } |
| |
| } // namespace wpi |