blob: edf13d498f13640b88c79ac71f45efc04504c6cb [file] [log] [blame]
Adam Snaider96a0f4b2023-05-18 20:41:19 -07001#include <unistd.h>
2
3#include <chrono>
4#include <functional>
5
6#include "gflags/gflags.h"
7#include "gmock/gmock-matchers.h"
8#include "gtest/gtest.h"
9
10#include "aos/events/epoll.h"
11#include "aos/network/sctp_client.h"
12#include "aos/network/sctp_lib.h"
13#include "aos/network/sctp_server.h"
14
15DECLARE_bool(disable_ipv6);
16
17namespace aos::message_bridge::testing {
18
19using ::aos::internal::EPoll;
20using ::aos::internal::TimerFd;
21using ::testing::ElementsAre;
22
23using namespace ::std::chrono_literals;
24
25constexpr int kPort = 19423;
26constexpr int kStreams = 1;
27
namespace {
// Enables SCTP authentication via the net.sctp.auth_enable sysctl when this
// build was compiled with HAS_SCTP_AUTH. The shell command tries
// /usr/sbin/sysctl first and falls back to /sbin/sysctl; the process dies if
// the command does not exit successfully. Without HAS_SCTP_AUTH this is a
// no-op.
void EnableSctpAuthIfAvailable() {
#if HAS_SCTP_AUTH
  const int status = system(
      "/usr/sbin/sysctl net.sctp.auth_enable=1 || /sbin/sysctl "
      "net.sctp.auth_enable=1");
  // CHECK_EQ (rather than CHECK(... == 0)) logs the actual return value of
  // system() alongside the message, which makes a failure diagnosable.
  CHECK_EQ(status, 0) << "Couldn't enable sctp authentication.";
#endif
}
}  // namespace
37
38// An asynchronous SCTP handler. It takes an SCTP receiver (a.k.a SctpServer or
39// SctpClient), and an `sctp_notification` handler and a `message` handler. It
40// asynchronously routes incoming messages to the appropriate handler.
41template <typename T>
42class SctpReceiver {
43 public:
44 SctpReceiver(
45 EPoll &epoll, T &receiver,
46 std::function<void(T &, const union sctp_notification *)> on_notify,
47 std::function<void(T &, std::vector<uint8_t>)> on_message)
48 : epoll_(epoll),
49 receiver_(receiver),
50 on_notify_(std::move(on_notify)),
51 on_message_(std::move(on_message)) {
52 epoll_.OnReadable(receiver_.fd(), [this]() { Read(); });
53 }
54
55 ~SctpReceiver() { epoll_.DeleteFd(receiver_.fd()); }
56
57 private:
58 // Handles an incoming message by routing it to the apropriate handler.
59 void Read() {
60 aos::unique_c_ptr<Message> message = receiver_.Read();
61 if (!message) {
62 return;
63 }
64
65 switch (message->message_type) {
66 case Message::kNotification: {
67 const union sctp_notification *notification =
68 reinterpret_cast<const union sctp_notification *>(message->data());
69 on_notify_(receiver_, notification);
70 break;
71 }
72 case Message::kMessage:
73 on_message_(receiver_, std::vector(message->data(),
74 message->data() + message->size));
75 break;
76 case Message::kOverflow:
77 LOG(FATAL) << "Overflow";
78 }
79 receiver_.FreeMessage(std::move(message));
80 }
81
82 EPoll &epoll_;
83 T &receiver_;
84 std::function<void(T &, const union sctp_notification *)> on_notify_;
85 std::function<void(T &, std::vector<uint8_t>)> on_message_;
86};
87
88// Base SctpTest class.
89//
90// The class provides a few virtual methods that should be overriden to define
91// the behavior of the test.
92class SctpTest : public ::testing::Test {
93 public:
94 SctpTest(std::vector<uint8_t> server_key = {},
95 std::vector<uint8_t> client_key = {},
96 std::chrono::milliseconds timeout = 1000ms)
97 : server_(kStreams, "", kPort, std::move(server_key)),
98 client_("localhost", kPort, kStreams, "", 0, std::move(client_key)),
99 client_receiver_(
100 epoll_, client_,
101 [this](SctpClient &client,
102 const union sctp_notification *notification) {
103 HandleNotification(client, notification);
104 },
105 [this](SctpClient &client, std::vector<uint8_t> message) {
106 HandleMessage(client, std::move(message));
107 }),
108 server_receiver_(
109 epoll_, server_,
110 [this](SctpServer &server,
111 const union sctp_notification *notification) {
112 HandleNotification(server, notification);
113 },
114 [this](SctpServer &server, std::vector<uint8_t> message) {
115 HandleMessage(server, std::move(message));
116 }) {
117 timeout_.SetTime(aos::monotonic_clock::now() + timeout,
118 std::chrono::milliseconds::zero());
119 epoll_.OnReadable(timeout_.fd(), [this]() { TimeOut(); });
120 }
121
122 static void SetUpTestSuite() {
123 EnableSctpAuthIfAvailable();
124 // Buildkite seems to have issues with ipv6 sctp sockets...
125 FLAGS_disable_ipv6 = true;
126 }
127
128 void SetUp() override { Run(); }
129
130 protected:
131 // Handles a server notification message.
132 //
133 // The default behaviour is to track the sctp association ID.
134 virtual void HandleNotification(SctpServer &,
135 const union sctp_notification *notification) {
136 if (notification->sn_header.sn_type == SCTP_ASSOC_CHANGE) {
137 assoc_ = notification->sn_assoc_change.sac_assoc_id;
138 }
139 }
140
141 // Handles the client notification message.
142 virtual void HandleNotification(SctpClient &,
143 const union sctp_notification *) {}
144
145 // Handles a server "data" message.
146 virtual void HandleMessage(SctpServer &, std::vector<uint8_t>) {}
147 // Handles a client "data" message.
148 virtual void HandleMessage(SctpClient &, std::vector<uint8_t>) {}
149
150 // Defines the "timeout" behaviour (fail by default).
151 virtual void TimeOut() {
152 Quit();
153 FAIL() << "Timer expired";
154 }
155
156 virtual ~SctpTest() {}
157
158 // Quit the test.
159 void Quit() {
160 epoll_.DeleteFd(timeout_.fd());
161 epoll_.Quit();
162 }
163 void Run() { epoll_.Run(); }
164
165 SctpServer server_;
166 SctpClient client_;
167 sctp_assoc_t assoc_ = 0;
168
169 private:
170 TimerFd timeout_;
171 EPoll epoll_;
172 SctpReceiver<SctpClient> client_receiver_;
173 SctpReceiver<SctpServer> server_receiver_;
174};
175
176// Verifies we can ping the server, and the server replies.
177class SctpPingPongTest : public SctpTest {
178 public:
179 SctpPingPongTest() : SctpTest({}, {}, /*timeout=*/2s) {
180 // Start by having the client send "ping".
181 client_.Send(0, "ping", 0);
182 }
183
184 void HandleMessage(SctpServer &server,
185 std::vector<uint8_t> message) override {
186 // Server should receive a ping message.
187 EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
188 got_ping_ = true;
189 ASSERT_NE(assoc_, 0);
190 // Reply with "pong".
191 server.Send("pong", assoc_, 0, 0);
192 }
193
194 void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
195 // Client should receive a "pong" message.
196 EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
197 got_pong_ = true;
198 // We are done.
199 Quit();
200 }
201 ~SctpPingPongTest() {
202 // Check that we got the ping/pong messages.
203 // This isn't strictly necessary as otherwise we would time out and fail
204 // anyway.
205 EXPECT_TRUE(got_ping_);
206 EXPECT_TRUE(got_pong_);
207 }
208
209 protected:
210 bool got_ping_ = false;
211 bool got_pong_ = false;
212};
213
214TEST_F(SctpPingPongTest, Test) {}
215
216#if HAS_SCTP_AUTH
217
218// Same as SctpPingPongTest but with authentication keys. Both keys are the
219// same so it should work the same way.
220class SctpAuthTest : public SctpTest {
221 public:
222 SctpAuthTest()
223 : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
224 /*timeout*/ 20s) {
225 // Start by having the client send "ping".
226 client_.Send(0, "ping", 0);
227 }
228
229 void HandleMessage(SctpServer &server,
230 std::vector<uint8_t> message) override {
231 // Server should receive a ping message.
232 EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
233 got_ping_ = true;
234 ASSERT_NE(assoc_, 0);
235 // Reply with "pong".
236 server.Send("pong", assoc_, 0, 0);
237 }
238 void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
239 // Client should receive a "pong" message.
240 EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
241 got_pong_ = true;
242 // We are done.
243 Quit();
244 }
245 ~SctpAuthTest() {
246 EXPECT_TRUE(got_ping_);
247 EXPECT_TRUE(got_pong_);
248 }
249
250 protected:
251 bool got_ping_ = false;
252 bool got_pong_ = false;
253};
254
255TEST_F(SctpAuthTest, Test) {}
256
257// Keys don't match, we should send the `ping` message but the server should
258// never receive it. We then time out as nothing calls Quit.
259class SctpMismatchedAuthTest : public SctpTest {
260 public:
261 SctpMismatchedAuthTest() : SctpTest({1, 2, 3, 4, 5, 6}, {5, 6, 7, 8, 9, 10}) {
262 // Start by having the client send "ping".
263 client_.Send(0, "ping", 0);
264 }
265
266 void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
267 FAIL() << "Authentication keys don't match. Message should be discarded";
268 Quit();
269 }
270
271 // We expect to time out since we never get the message.
272 void TimeOut() override { Quit(); }
273};
274
275TEST_F(SctpMismatchedAuthTest, Test) {}
276
277// Same as SctpMismatchedAuthTest but the client uses the null key. We should
278// see the same behaviour.
279class SctpOneNullKeyTest : public SctpTest {
280 public:
281 SctpOneNullKeyTest() : SctpTest({1, 2, 3, 4, 5, 6}, {}) {
282 // Start by having the client send "ping".
283 client_.Send(0, "ping", 0);
284 }
285
286 void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
287 FAIL() << "Authentication keys don't match. Message should be discarded";
288 Quit();
289 }
290
291 // We expect to time out since we never get the message.
292 void TimeOut() override { Quit(); }
293};
294
295TEST_F(SctpOneNullKeyTest, Test) {}
296#endif // HAS_SCTP_AUTH
297
298} // namespace aos::message_bridge::testing