blob: 9d85e6676a1265401f34c873694857688456969d [file] [log] [blame]
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#include "wpi/WebSocket.h" // NOLINT(build/include_order)
#include "WebSocketTest.h"
#include "wpi/Base64.h"
#include "wpi/HttpParser.h"
#include "wpi/SmallString.h"
#include "wpi/StringExtras.h"
#include "wpi/raw_uv_ostream.h"
#include "wpi/sha1.h"
namespace wpi {
class WebSocketClientTest : public WebSocketTest {
public:
WebSocketClientTest() {
// Bare bones server
req.header.connect([this](std::string_view name, std::string_view value) {
// save key (required for valid response)
if (equals_lower(name, "sec-websocket-key")) {
clientKey = value;
}
});
req.headersComplete.connect([this](bool) {
// send response
SmallVector<uv::Buffer, 4> bufs;
raw_uv_ostream os{bufs, 4096};
os << "HTTP/1.1 101 Switching Protocols\r\n";
os << "Upgrade: websocket\r\n";
os << "Connection: Upgrade\r\n";
// accept hash
SHA1 hash;
hash.Update(clientKey);
hash.Update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
if (mockBadAccept) {
hash.Update("1");
}
SmallString<64> hashBuf;
SmallString<64> acceptBuf;
os << "Sec-WebSocket-Accept: "
<< Base64Encode(hash.RawFinal(hashBuf), acceptBuf) << "\r\n";
if (!mockProtocol.empty()) {
os << "Sec-WebSocket-Protocol: " << mockProtocol << "\r\n";
}
os << "\r\n";
conn->Write(bufs, [](auto bufs, uv::Error) {
for (auto& buf : bufs) {
buf.Deallocate();
}
});
serverHeadersDone = true;
if (connected) {
connected();
}
});
serverPipe->Listen([this] {
conn = serverPipe->Accept();
conn->StartRead();
conn->data.connect([this](uv::Buffer& buf, size_t size) {
std::string_view data{buf.base, size};
if (!serverHeadersDone) {
data = req.Execute(data);
if (req.HasError()) {
Finish();
}
ASSERT_EQ(req.GetError(), HPE_OK) << http_errno_name(req.GetError());
if (data.empty()) {
return;
}
}
wireData.insert(wireData.end(), data.begin(), data.end());
});
conn->end.connect([this] { Finish(); });
});
}
bool mockBadAccept = false;
std::vector<uint8_t> wireData;
std::shared_ptr<uv::Pipe> conn;
HttpParser req{HttpParser::kRequest};
SmallString<64> clientKey;
std::string mockProtocol;
bool serverHeadersDone = false;
std::function<void()> connected;
};
// A handshake without any subprotocol must produce exactly one open
// notification carrying an empty protocol string.
TEST_F(WebSocketClientTest, Open) {
  int openCount = 0;

  clientPipe->Connect(pipeName, [&] {
    auto client = WebSocket::CreateClient(*clientPipe, "/test", pipeName);

    // Only close codes 1005 and 1006 are tolerated here.
    client->closed.connect([&](uint16_t code, std::string_view reason) {
      Finish();
      const bool toleratedCode = (code == 1005 || code == 1006);
      if (!toleratedCode) {
        FAIL() << "Code: " << code << " Reason: " << reason;
      }
    });

    client->open.connect([&](std::string_view protocol) {
      ++openCount;
      Finish();
      ASSERT_TRUE(protocol.empty());
    });
  });

  loop->Run();

  // Skip the count check when a callback already recorded a fatal failure.
  if (!HasFatalFailure()) {
    ASSERT_EQ(openCount, 1);
  }
}
// When the server replies with a wrong Sec-WebSocket-Accept value, the client
// must report closed with code 1002 and must never report open.
TEST_F(WebSocketClientTest, BadAccept) {
  mockBadAccept = true;
  int closedCount = 0;

  clientPipe->Connect(pipeName, [&] {
    auto client = WebSocket::CreateClient(*clientPipe, "/test", pipeName);

    client->closed.connect([&](uint16_t code, std::string_view msg) {
      ++closedCount;
      Finish();
      ASSERT_EQ(code, 1002) << "Message: " << msg;
    });

    client->open.connect([&](std::string_view) {
      Finish();
      FAIL() << "Got open";
    });
  });

  loop->Run();

  // Skip the count check when a callback already recorded a fatal failure.
  if (!HasFatalFailure()) {
    ASSERT_EQ(closedCount, 1);
  }
}
// The client offers two subprotocols and the server selects the first one;
// the open notification must carry the selected protocol name.
TEST_F(WebSocketClientTest, ProtocolGood) {
  int gotOpen = 0;
  mockProtocol = "myProtocol";

  clientPipe->Connect(pipeName, [&] {
    auto ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName,
                                      {"myProtocol", "myProtocol2"});

    // Only close codes 1005 and 1006 are tolerated here.
    ws->closed.connect([&](uint16_t code, std::string_view msg) {
      Finish();
      if (code != 1005 && code != 1006) {
        // Keep a separator between the code and the label so the failure
        // output stays readable.
        FAIL() << "Code: " << code << " Message: " << msg;
      }
    });

    ws->open.connect([&](std::string_view protocol) {
      ++gotOpen;
      Finish();
      ASSERT_EQ(protocol, "myProtocol");
    });
  });

  loop->Run();

  // Skip the count check when a callback already recorded a fatal failure.
  if (HasFatalFailure()) {
    return;
  }
  ASSERT_EQ(gotOpen, 1);
}
// The server answers with a subprotocol although the client requested none;
// the client must report closed with code 1003 and must never report open.
TEST_F(WebSocketClientTest, ProtocolRespNotReq) {
  mockProtocol = "myProtocol";
  int closedCount = 0;

  clientPipe->Connect(pipeName, [&] {
    auto client = WebSocket::CreateClient(*clientPipe, "/test", pipeName);

    client->closed.connect([&](uint16_t code, std::string_view msg) {
      ++closedCount;
      Finish();
      ASSERT_EQ(code, 1003) << "Message: " << msg;
    });

    client->open.connect([&](std::string_view) {
      Finish();
      FAIL() << "Got open";
    });
  });

  loop->Run();

  // Skip the count check when a callback already recorded a fatal failure.
  if (!HasFatalFailure()) {
    ASSERT_EQ(closedCount, 1);
  }
}
// The client requests a subprotocol but the server's response names none;
// the client must report closed with code 1002 and must never report open.
TEST_F(WebSocketClientTest, ProtocolReqNotResp) {
  int closedCount = 0;

  clientPipe->Connect(pipeName, [&] {
    auto client = WebSocket::CreateClient(*clientPipe, "/test", pipeName,
                                          {{"myProtocol"}});

    client->closed.connect([&](uint16_t code, std::string_view msg) {
      ++closedCount;
      Finish();
      ASSERT_EQ(code, 1002) << "Message: " << msg;
    });

    client->open.connect([&](std::string_view) {
      Finish();
      FAIL() << "Got open";
    });
  });

  loop->Run();

  // Skip the count check when a callback already recorded a fatal failure.
  if (!HasFatalFailure()) {
    ASSERT_EQ(closedCount, 1);
  }
}
//
// Send and receive data. Most of these cases are tested in
// WebSocketServerTest, so only spot check differences like masking.
//
/**
 * Parameterized fixture for data transfer tests; the parameter is the payload
 * size in bytes.
 *
 * The constructor starts the client connection. Once the pipe is connected,
 * the WebSocket client is created and stored in `ws`, after which the
 * test-provided `setupWebSocket` hook (if any) is run so the test can attach
 * its handlers.
 */
class WebSocketClientDataTest : public WebSocketClientTest,
                                public ::testing::WithParamInterface<size_t> {
 public:
  WebSocketClientDataTest() {
    clientPipe->Connect(pipeName, [this] {
      ws = WebSocket::CreateClient(*clientPipe, "/test", pipeName);
      if (!setupWebSocket) {
        return;
      }
      setupWebSocket();
    });
  }

  /// Optional hook run right after `ws` has been created.
  std::function<void()> setupWebSocket;
  /// The client WebSocket under test.
  std::shared_ptr<WebSocket> ws;
};
// Runs every WebSocketClientDataTest case once per payload size listed below.
// NOTE(review): the sizes presumably straddle the frame payload-length
// encoding boundaries (125/126 and 65535/65536) in addition to the empty and
// one-byte cases -- confirm against the framing rules.
INSTANTIATE_TEST_SUITE_P(WebSocketClientDataTests, WebSocketClientDataTest,
::testing::Values(0, 1, 125, 126, 65535, 65536));
// Sends one binary message from the client and checks both the completion
// callback and the bytes the server saw on the wire.
TEST_P(WebSocketClientDataTest, SendBinary) {
  std::vector<uint8_t> payload(GetParam(), 0x03u);
  int sendCallbacks = 0;

  setupWebSocket = [&] {
    ws->open.connect([&](std::string_view) {
      ws->SendBinary({{payload}}, [&](auto sentBufs, uv::Error) {
        ++sendCallbacks;
        ws->Terminate();
        // The callback must hand back the caller's buffer.
        ASSERT_FALSE(sentBufs.empty());
        ASSERT_EQ(sentBufs[0].base,
                  reinterpret_cast<const char*>(payload.data()));
      });
    });
  };

  loop->Run();

  // Expected frame: opcode 0x02 with the second and third flags set.
  auto expected = BuildMessage(0x02, true, true, payload);
  AdjustMasking(wireData);
  ASSERT_EQ(wireData, expected);
  ASSERT_EQ(sendCallbacks, 1);
}
// Has the server write one binary message after the handshake and checks that
// the client delivers it in a single, final binary callback.
TEST_P(WebSocketClientDataTest, ReceiveBinary) {
  std::vector<uint8_t> payload(GetParam(), 0x03u);
  int binaryCallbacks = 0;

  setupWebSocket = [&] {
    ws->binary.connect([&](auto inData, bool fin) {
      ++binaryCallbacks;
      ws->Terminate();
      ASSERT_TRUE(fin);
      const std::vector<uint8_t> received(inData.begin(), inData.end());
      ASSERT_EQ(payload, received);
    });
  };

  // Frame built with opcode 0x02; written once the handshake reply is queued.
  auto frame = BuildMessage(0x02, true, false, payload);
  connected = [&] {
    conn->Write({{frame}}, [](auto, uv::Error) {});
  };

  loop->Run();

  ASSERT_EQ(binaryCallbacks, 1);
}
//
// The client must close the connection if a masked frame is received.
//
// Has the server write a frame built with masking enabled; the client must
// not deliver it as text and must report closed with code 1002.
TEST_P(WebSocketClientDataTest, ReceiveMasked) {
  std::vector<uint8_t> payload(GetParam(), ' ');
  int closedCallbacks = 0;

  setupWebSocket = [&] {
    ws->text.connect([&](std::string_view, bool) {
      ws->Terminate();
      FAIL() << "Should not have gotten masked message";
    });
    ws->closed.connect([&](uint16_t code, std::string_view reason) {
      ++closedCallbacks;
      ASSERT_EQ(code, 1002) << "reason: " << reason;
    });
  };

  // Frame built with opcode 0x01; written once the handshake reply is queued.
  auto frame = BuildMessage(0x01, true, true, payload);
  connected = [&] {
    conn->Write({{frame}}, [](auto, uv::Error) {});
  };

  loop->Run();

  ASSERT_EQ(closedCallbacks, 1);
}
} // namespace wpi