blob: bce5c190425cce5d9919fe8ef86d5961d5d11f67 [file] [log] [blame]
Adam Snaider96a0f4b2023-05-18 20:41:19 -07001#include <unistd.h>
2
3#include <chrono>
4#include <functional>
5
6#include "gflags/gflags.h"
7#include "gmock/gmock-matchers.h"
8#include "gtest/gtest.h"
9
10#include "aos/events/epoll.h"
11#include "aos/network/sctp_client.h"
12#include "aos/network/sctp_lib.h"
13#include "aos/network/sctp_server.h"
Adam Snaider930bf7f2023-09-22 10:57:21 -070014#include "sctp_lib.h"
Adam Snaider96a0f4b2023-05-18 20:41:19 -070015
16DECLARE_bool(disable_ipv6);
17
18namespace aos::message_bridge::testing {
19
20using ::aos::internal::EPoll;
21using ::aos::internal::TimerFd;
22using ::testing::ElementsAre;
23
24using namespace ::std::chrono_literals;
25
// Port the test server binds to and the test client connects to.
constexpr int kPort{19423};
// Number of SCTP streams both the server and the client are created with.
constexpr int kStreams{1};
28
namespace {
// Turns on kernel SCTP authentication (net.sctp.auth_enable=1) when this build
// was compiled with HAS_SCTP_AUTH; otherwise this is a no-op. Dies via CHECK if
// the sysctl command fails.
void EnableSctpAuthIfAvailable() {
#if HAS_SCTP_AUTH
  // Open an SCTP socket first so that the kernel SCTP module gets loaded before
  // its sysctl is written.
  SctpServer server(1, "localhost");
  // Try both common install locations of the sysctl binary.
  constexpr const char *kEnableAuthCommand =
      "/usr/sbin/sysctl net.sctp.auth_enable=1 || "
      "/sbin/sysctl net.sctp.auth_enable=1";
  const int status = system(kEnableAuthCommand);
  CHECK(status == 0) << "Couldn't enable sctp authentication.";
#endif
}
}  // namespace
40
41// An asynchronous SCTP handler. It takes an SCTP receiver (a.k.a SctpServer or
42// SctpClient), and an `sctp_notification` handler and a `message` handler. It
43// asynchronously routes incoming messages to the appropriate handler.
44template <typename T>
45class SctpReceiver {
46 public:
47 SctpReceiver(
48 EPoll &epoll, T &receiver,
49 std::function<void(T &, const union sctp_notification *)> on_notify,
50 std::function<void(T &, std::vector<uint8_t>)> on_message)
51 : epoll_(epoll),
52 receiver_(receiver),
53 on_notify_(std::move(on_notify)),
54 on_message_(std::move(on_message)) {
55 epoll_.OnReadable(receiver_.fd(), [this]() { Read(); });
56 }
57
58 ~SctpReceiver() { epoll_.DeleteFd(receiver_.fd()); }
59
60 private:
61 // Handles an incoming message by routing it to the apropriate handler.
62 void Read() {
63 aos::unique_c_ptr<Message> message = receiver_.Read();
64 if (!message) {
65 return;
66 }
67
68 switch (message->message_type) {
69 case Message::kNotification: {
70 const union sctp_notification *notification =
71 reinterpret_cast<const union sctp_notification *>(message->data());
72 on_notify_(receiver_, notification);
73 break;
74 }
75 case Message::kMessage:
76 on_message_(receiver_, std::vector(message->data(),
77 message->data() + message->size));
78 break;
79 case Message::kOverflow:
80 LOG(FATAL) << "Overflow";
81 }
82 receiver_.FreeMessage(std::move(message));
83 }
84
85 EPoll &epoll_;
86 T &receiver_;
87 std::function<void(T &, const union sctp_notification *)> on_notify_;
88 std::function<void(T &, std::vector<uint8_t>)> on_message_;
89};
90
91// Base SctpTest class.
92//
93// The class provides a few virtual methods that should be overriden to define
94// the behavior of the test.
95class SctpTest : public ::testing::Test {
96 public:
97 SctpTest(std::vector<uint8_t> server_key = {},
98 std::vector<uint8_t> client_key = {},
Adam Snaider9bb33442023-06-26 16:31:37 -070099 SctpAuthMethod requested_authentication = SctpAuthMethod::kNoAuth,
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700100 std::chrono::milliseconds timeout = 1000ms)
Adam Snaider9bb33442023-06-26 16:31:37 -0700101 : server_(kStreams, "", kPort, requested_authentication),
102 client_("localhost", kPort, kStreams, "", 0, requested_authentication),
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700103 client_receiver_(
104 epoll_, client_,
105 [this](SctpClient &client,
106 const union sctp_notification *notification) {
107 HandleNotification(client, notification);
108 },
109 [this](SctpClient &client, std::vector<uint8_t> message) {
110 HandleMessage(client, std::move(message));
111 }),
112 server_receiver_(
113 epoll_, server_,
114 [this](SctpServer &server,
115 const union sctp_notification *notification) {
116 HandleNotification(server, notification);
117 },
118 [this](SctpServer &server, std::vector<uint8_t> message) {
119 HandleMessage(server, std::move(message));
120 }) {
Adam Snaider9bb33442023-06-26 16:31:37 -0700121 server_.SetAuthKey(server_key);
122 client_.SetAuthKey(client_key);
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700123 timeout_.SetTime(aos::monotonic_clock::now() + timeout,
124 std::chrono::milliseconds::zero());
125 epoll_.OnReadable(timeout_.fd(), [this]() { TimeOut(); });
126 }
127
128 static void SetUpTestSuite() {
129 EnableSctpAuthIfAvailable();
130 // Buildkite seems to have issues with ipv6 sctp sockets...
131 FLAGS_disable_ipv6 = true;
132 }
133
134 void SetUp() override { Run(); }
135
136 protected:
137 // Handles a server notification message.
138 //
139 // The default behaviour is to track the sctp association ID.
140 virtual void HandleNotification(SctpServer &,
141 const union sctp_notification *notification) {
142 if (notification->sn_header.sn_type == SCTP_ASSOC_CHANGE) {
143 assoc_ = notification->sn_assoc_change.sac_assoc_id;
144 }
145 }
146
147 // Handles the client notification message.
148 virtual void HandleNotification(SctpClient &,
149 const union sctp_notification *) {}
150
151 // Handles a server "data" message.
152 virtual void HandleMessage(SctpServer &, std::vector<uint8_t>) {}
153 // Handles a client "data" message.
154 virtual void HandleMessage(SctpClient &, std::vector<uint8_t>) {}
155
156 // Defines the "timeout" behaviour (fail by default).
157 virtual void TimeOut() {
158 Quit();
159 FAIL() << "Timer expired";
160 }
161
162 virtual ~SctpTest() {}
163
164 // Quit the test.
165 void Quit() {
166 epoll_.DeleteFd(timeout_.fd());
167 epoll_.Quit();
168 }
169 void Run() { epoll_.Run(); }
170
171 SctpServer server_;
172 SctpClient client_;
173 sctp_assoc_t assoc_ = 0;
174
175 private:
176 TimerFd timeout_;
177 EPoll epoll_;
178 SctpReceiver<SctpClient> client_receiver_;
179 SctpReceiver<SctpServer> server_receiver_;
180};
181
182// Verifies we can ping the server, and the server replies.
183class SctpPingPongTest : public SctpTest {
184 public:
Adam Snaider9bb33442023-06-26 16:31:37 -0700185 SctpPingPongTest()
186 : SctpTest({}, {}, SctpAuthMethod::kNoAuth, /*timeout=*/2s) {
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700187 // Start by having the client send "ping".
188 client_.Send(0, "ping", 0);
189 }
190
191 void HandleMessage(SctpServer &server,
192 std::vector<uint8_t> message) override {
193 // Server should receive a ping message.
194 EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
195 got_ping_ = true;
196 ASSERT_NE(assoc_, 0);
197 // Reply with "pong".
198 server.Send("pong", assoc_, 0, 0);
199 }
200
201 void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
202 // Client should receive a "pong" message.
203 EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
204 got_pong_ = true;
205 // We are done.
206 Quit();
207 }
208 ~SctpPingPongTest() {
209 // Check that we got the ping/pong messages.
210 // This isn't strictly necessary as otherwise we would time out and fail
211 // anyway.
212 EXPECT_TRUE(got_ping_);
213 EXPECT_TRUE(got_pong_);
214 }
215
216 protected:
217 bool got_ping_ = false;
218 bool got_pong_ = false;
219};
220
221TEST_F(SctpPingPongTest, Test) {}
222
223#if HAS_SCTP_AUTH
224
225// Same as SctpPingPongTest but with authentication keys. Both keys are the
226// same so it should work the same way.
227class SctpAuthTest : public SctpTest {
228 public:
229 SctpAuthTest()
Adam Snaider9bb33442023-06-26 16:31:37 -0700230 : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, SctpAuthMethod::kAuth,
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700231 /*timeout*/ 20s) {
232 // Start by having the client send "ping".
233 client_.Send(0, "ping", 0);
234 }
235
236 void HandleMessage(SctpServer &server,
237 std::vector<uint8_t> message) override {
238 // Server should receive a ping message.
239 EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
240 got_ping_ = true;
241 ASSERT_NE(assoc_, 0);
242 // Reply with "pong".
243 server.Send("pong", assoc_, 0, 0);
244 }
245 void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
246 // Client should receive a "pong" message.
247 EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
248 got_pong_ = true;
249 // We are done.
250 Quit();
251 }
252 ~SctpAuthTest() {
253 EXPECT_TRUE(got_ping_);
254 EXPECT_TRUE(got_pong_);
255 }
256
257 protected:
258 bool got_ping_ = false;
259 bool got_pong_ = false;
260};
261
262TEST_F(SctpAuthTest, Test) {}
263
Adam Snaider9bb33442023-06-26 16:31:37 -0700264// Tests that we can dynamically change the SCTP authentication key used.
265class SctpChangingAuthKeysTest : public SctpTest {
266 public:
267 SctpChangingAuthKeysTest()
268 : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
269 SctpAuthMethod::kAuth) {
270 // Start by having the client send "ping".
271 client_.SetAuthKey({5, 4, 3, 2, 1});
272 server_.SetAuthKey({5, 4, 3, 2, 1});
273 client_.Send(0, "ping", 0);
274 }
275
276 void HandleMessage(SctpServer &server,
277 std::vector<uint8_t> message) override {
278 // Server should receive a ping message.
279 EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
280 got_ping_ = true;
281 ASSERT_NE(assoc_, 0);
282 // Reply with "pong".
283 server.Send("pong", assoc_, 0, 0);
284 }
285 void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
286 // Client should receive a "pong" message.
287 EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
288 got_pong_ = true;
289 // We are done.
290 Quit();
291 }
292
293 ~SctpChangingAuthKeysTest() {
294 EXPECT_TRUE(got_ping_);
295 EXPECT_TRUE(got_pong_);
296 }
297
298 protected:
299 bool got_ping_ = false;
300 bool got_pong_ = false;
301};
302
303TEST_F(SctpChangingAuthKeysTest, Test) {}
304
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700305// Keys don't match, we should send the `ping` message but the server should
306// never receive it. We then time out as nothing calls Quit.
307class SctpMismatchedAuthTest : public SctpTest {
308 public:
Adam Snaider9bb33442023-06-26 16:31:37 -0700309 SctpMismatchedAuthTest()
310 : SctpTest({1, 2, 3, 4, 5, 6}, {5, 6, 7, 8, 9, 10},
311 SctpAuthMethod::kAuth) {
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700312 // Start by having the client send "ping".
313 client_.Send(0, "ping", 0);
314 }
315
316 void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
317 FAIL() << "Authentication keys don't match. Message should be discarded";
318 Quit();
319 }
320
321 // We expect to time out since we never get the message.
322 void TimeOut() override { Quit(); }
323};
324
325TEST_F(SctpMismatchedAuthTest, Test) {}
326
327// Same as SctpMismatchedAuthTest but the client uses the null key. We should
328// see the same behaviour.
329class SctpOneNullKeyTest : public SctpTest {
330 public:
Adam Snaider9bb33442023-06-26 16:31:37 -0700331 SctpOneNullKeyTest()
332 : SctpTest({1, 2, 3, 4, 5, 6}, {}, SctpAuthMethod::kAuth) {
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700333 // Start by having the client send "ping".
334 client_.Send(0, "ping", 0);
335 }
336
337 void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
338 FAIL() << "Authentication keys don't match. Message should be discarded";
339 Quit();
340 }
341
342 // We expect to time out since we never get the message.
343 void TimeOut() override { Quit(); }
344};
345
346TEST_F(SctpOneNullKeyTest, Test) {}
Adam Snaider9bb33442023-06-26 16:31:37 -0700347
348// If we want SCTP authentication but we don't set the auth keys, we shouldn't
349// be able to send data.
350class SctpAuthKeysNotSet : public SctpTest {
351 public:
352 SctpAuthKeysNotSet() : SctpTest({}, {}, SctpAuthMethod::kAuth) {
353 // Start by having the client send "ping".
354 client_.Send(0, "ping", 0);
355 }
356
357 void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
358 FAIL() << "Haven't setup authentication keys. Should not get message.";
359 Quit();
360 }
361
362 // We expect to time out since we never get the message.
363 void TimeOut() override { Quit(); }
364};
365
366TEST_F(SctpAuthKeysNotSet, Test) {}
367
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700368#endif // HAS_SCTP_AUTH
369
370} // namespace aos::message_bridge::testing