blob: d202e5d2c2f208b461220204805359a75295d83d [file] [log] [blame]
#include <unistd.h>
#include <chrono>
#include <functional>
#include "absl/flags/declare.h"
#include "absl/flags/flag.h"
#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"
#include "aos/events/epoll.h"
#include "aos/network/sctp_client.h"
#include "aos/network/sctp_lib.h"
#include "aos/network/sctp_server.h"
#include "sctp_lib.h"
ABSL_DECLARE_FLAG(bool, disable_ipv6);
namespace aos::message_bridge::testing {
using ::aos::internal::EPoll;
using ::aos::internal::TimerFd;
using ::testing::ElementsAre;
using namespace ::std::chrono_literals;
constexpr int kPort = 19423;
constexpr int kStreams = 1;
namespace {
void EnableSctpAuthIfAvailable() {
#if HAS_SCTP_AUTH
// Open an SCTP socket to bring the kernel SCTP module
SctpServer server(1, "localhost");
CHECK(system("/usr/sbin/sysctl net.sctp.auth_enable=1 || "
"/sbin/sysctl net.sctp.auth_enable=1") == 0)
<< "Couldn't enable sctp authentication.";
#endif
}
} // namespace
// An asynchronous SCTP handler. It takes an SCTP receiver (a.k.a SctpServer or
// SctpClient), and an `sctp_notification` handler and a `message` handler. It
// asynchronously routes incoming messages to the appropriate handler.
template <typename T>
class SctpReceiver {
public:
SctpReceiver(
EPoll &epoll, T &receiver,
std::function<void(T &, const union sctp_notification *)> on_notify,
std::function<void(T &, std::vector<uint8_t>)> on_message)
: epoll_(epoll),
receiver_(receiver),
on_notify_(std::move(on_notify)),
on_message_(std::move(on_message)) {
epoll_.OnReadable(receiver_.fd(), [this]() { Read(); });
}
~SctpReceiver() { epoll_.DeleteFd(receiver_.fd()); }
private:
// Handles an incoming message by routing it to the apropriate handler.
void Read() {
aos::unique_c_ptr<Message> message = receiver_.Read();
if (!message) {
return;
}
switch (message->message_type) {
case Message::kNotification: {
const union sctp_notification *notification =
reinterpret_cast<const union sctp_notification *>(message->data());
on_notify_(receiver_, notification);
break;
}
case Message::kMessage:
on_message_(receiver_, std::vector(message->data(),
message->data() + message->size));
break;
case Message::kOverflow:
LOG(FATAL) << "Overflow";
}
receiver_.FreeMessage(std::move(message));
}
EPoll &epoll_;
T &receiver_;
std::function<void(T &, const union sctp_notification *)> on_notify_;
std::function<void(T &, std::vector<uint8_t>)> on_message_;
};
// Base SctpTest class.
//
// The class provides a few virtual methods that should be overriden to define
// the behavior of the test.
class SctpTest : public ::testing::Test {
public:
SctpTest(std::vector<uint8_t> server_key = {},
std::vector<uint8_t> client_key = {},
SctpAuthMethod requested_authentication = SctpAuthMethod::kNoAuth,
std::chrono::milliseconds timeout = 1000ms)
: server_(kStreams, "", kPort, requested_authentication),
client_("localhost", kPort, kStreams, "", 0, requested_authentication),
client_receiver_(
epoll_, client_,
[this](SctpClient &client,
const union sctp_notification *notification) {
HandleNotification(client, notification);
},
[this](SctpClient &client, std::vector<uint8_t> message) {
HandleMessage(client, std::move(message));
}),
server_receiver_(
epoll_, server_,
[this](SctpServer &server,
const union sctp_notification *notification) {
HandleNotification(server, notification);
},
[this](SctpServer &server, std::vector<uint8_t> message) {
HandleMessage(server, std::move(message));
}) {
server_.SetAuthKey(server_key);
client_.SetAuthKey(client_key);
timeout_.SetTime(aos::monotonic_clock::now() + timeout,
std::chrono::milliseconds::zero());
epoll_.OnReadable(timeout_.fd(), [this]() { TimeOut(); });
}
static void SetUpTestSuite() {
EnableSctpAuthIfAvailable();
// Buildkite seems to have issues with ipv6 sctp sockets...
absl::SetFlag(&FLAGS_disable_ipv6, true);
}
void SetUp() override { Run(); }
protected:
// Handles a server notification message.
//
// The default behaviour is to track the sctp association ID.
virtual void HandleNotification(SctpServer &,
const union sctp_notification *notification) {
if (notification->sn_header.sn_type == SCTP_ASSOC_CHANGE) {
assoc_ = notification->sn_assoc_change.sac_assoc_id;
}
}
// Handles the client notification message.
virtual void HandleNotification(SctpClient &,
const union sctp_notification *) {}
// Handles a server "data" message.
virtual void HandleMessage(SctpServer &, std::vector<uint8_t>) {}
// Handles a client "data" message.
virtual void HandleMessage(SctpClient &, std::vector<uint8_t>) {}
// Defines the "timeout" behaviour (fail by default).
virtual void TimeOut() {
Quit();
FAIL() << "Timer expired";
}
virtual ~SctpTest() {}
// Quit the test.
void Quit() {
epoll_.DeleteFd(timeout_.fd());
epoll_.Quit();
}
void Run() { epoll_.Run(); }
SctpServer server_;
SctpClient client_;
sctp_assoc_t assoc_ = 0;
private:
TimerFd timeout_;
EPoll epoll_;
SctpReceiver<SctpClient> client_receiver_;
SctpReceiver<SctpServer> server_receiver_;
};
// Verifies we can ping the server, and the server replies.
class SctpPingPongTest : public SctpTest {
public:
SctpPingPongTest()
: SctpTest({}, {}, SctpAuthMethod::kNoAuth, /*timeout=*/2s) {
// Start by having the client send "ping".
client_.Send(0, "ping", 0);
}
void HandleMessage(SctpServer &server,
std::vector<uint8_t> message) override {
// Server should receive a ping message.
EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
got_ping_ = true;
ASSERT_NE(assoc_, 0);
// Reply with "pong".
server.Send("pong", assoc_, 0, 0);
}
void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
// Client should receive a "pong" message.
EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
got_pong_ = true;
// We are done.
Quit();
}
~SctpPingPongTest() {
// Check that we got the ping/pong messages.
// This isn't strictly necessary as otherwise we would time out and fail
// anyway.
EXPECT_TRUE(got_ping_);
EXPECT_TRUE(got_pong_);
}
protected:
bool got_ping_ = false;
bool got_pong_ = false;
};
TEST_F(SctpPingPongTest, Test) {}
#if HAS_SCTP_AUTH
// Same as SctpPingPongTest but with authentication keys. Both keys are the
// same so it should work the same way.
class SctpAuthTest : public SctpTest {
public:
SctpAuthTest()
: SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, SctpAuthMethod::kAuth,
/*timeout*/ 20s) {
// Start by having the client send "ping".
client_.Send(0, "ping", 0);
}
void HandleMessage(SctpServer &server,
std::vector<uint8_t> message) override {
// Server should receive a ping message.
EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
got_ping_ = true;
ASSERT_NE(assoc_, 0);
// Reply with "pong".
server.Send("pong", assoc_, 0, 0);
}
void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
// Client should receive a "pong" message.
EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
got_pong_ = true;
// We are done.
Quit();
}
~SctpAuthTest() {
EXPECT_TRUE(got_ping_);
EXPECT_TRUE(got_pong_);
}
protected:
bool got_ping_ = false;
bool got_pong_ = false;
};
TEST_F(SctpAuthTest, Test) {}
// Tests that we can dynamically change the SCTP authentication key used.
class SctpChangingAuthKeysTest : public SctpTest {
public:
SctpChangingAuthKeysTest()
: SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
SctpAuthMethod::kAuth) {
// Start by having the client send "ping".
client_.SetAuthKey({5, 4, 3, 2, 1});
server_.SetAuthKey({5, 4, 3, 2, 1});
client_.Send(0, "ping", 0);
}
void HandleMessage(SctpServer &server,
std::vector<uint8_t> message) override {
// Server should receive a ping message.
EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
got_ping_ = true;
ASSERT_NE(assoc_, 0);
// Reply with "pong".
server.Send("pong", assoc_, 0, 0);
}
void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
// Client should receive a "pong" message.
EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
got_pong_ = true;
// We are done.
Quit();
}
~SctpChangingAuthKeysTest() {
EXPECT_TRUE(got_ping_);
EXPECT_TRUE(got_pong_);
}
protected:
bool got_ping_ = false;
bool got_pong_ = false;
};
TEST_F(SctpChangingAuthKeysTest, Test) {}
// Keys don't match, we should send the `ping` message but the server should
// never receive it. We then time out as nothing calls Quit.
class SctpMismatchedAuthTest : public SctpTest {
public:
SctpMismatchedAuthTest()
: SctpTest({1, 2, 3, 4, 5, 6}, {5, 6, 7, 8, 9, 10},
SctpAuthMethod::kAuth) {
// Start by having the client send "ping".
client_.Send(0, "ping", 0);
}
void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
FAIL() << "Authentication keys don't match. Message should be discarded";
Quit();
}
// We expect to time out since we never get the message.
void TimeOut() override { Quit(); }
};
TEST_F(SctpMismatchedAuthTest, Test) {}
// Same as SctpMismatchedAuthTest but the client uses the null key. We should
// see the same behaviour.
class SctpOneNullKeyTest : public SctpTest {
public:
SctpOneNullKeyTest()
: SctpTest({1, 2, 3, 4, 5, 6}, {}, SctpAuthMethod::kAuth) {
// Start by having the client send "ping".
client_.Send(0, "ping", 0);
}
void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
FAIL() << "Authentication keys don't match. Message should be discarded";
Quit();
}
// We expect to time out since we never get the message.
void TimeOut() override { Quit(); }
};
TEST_F(SctpOneNullKeyTest, Test) {}
// If we want SCTP authentication but we don't set the auth keys, we shouldn't
// be able to send data.
class SctpAuthKeysNotSet : public SctpTest {
public:
SctpAuthKeysNotSet() : SctpTest({}, {}, SctpAuthMethod::kAuth) {
// Start by having the client send "ping".
client_.Send(0, "ping", 0);
}
void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
FAIL() << "Haven't setup authentication keys. Should not get message.";
Quit();
}
// We expect to time out since we never get the message.
void TimeOut() override { Quit(); }
};
TEST_F(SctpAuthKeysNotSet, Test) {}
#endif // HAS_SCTP_AUTH
} // namespace aos::message_bridge::testing