blob: d202e5d2c2f208b461220204805359a75295d83d [file] [log] [blame]
Adam Snaider96a0f4b2023-05-18 20:41:19 -07001#include <unistd.h>
2
3#include <chrono>
4#include <functional>
5
Austin Schuh99f7c6a2024-06-25 22:07:44 -07006#include "absl/flags/declare.h"
7#include "absl/flags/flag.h"
Adam Snaider96a0f4b2023-05-18 20:41:19 -07008#include "gmock/gmock-matchers.h"
9#include "gtest/gtest.h"
10
11#include "aos/events/epoll.h"
12#include "aos/network/sctp_client.h"
13#include "aos/network/sctp_lib.h"
14#include "aos/network/sctp_server.h"
Adam Snaider930bf7f2023-09-22 10:57:21 -070015#include "sctp_lib.h"
Adam Snaider96a0f4b2023-05-18 20:41:19 -070016
Austin Schuh99f7c6a2024-06-25 22:07:44 -070017ABSL_DECLARE_FLAG(bool, disable_ipv6);
Adam Snaider96a0f4b2023-05-18 20:41:19 -070018
19namespace aos::message_bridge::testing {
20
21using ::aos::internal::EPoll;
22using ::aos::internal::TimerFd;
23using ::testing::ElementsAre;
24
25using namespace ::std::chrono_literals;
26
27constexpr int kPort = 19423;
28constexpr int kStreams = 1;
29
namespace {
// Enables kernel SCTP authentication (net.sctp.auth_enable=1) when this build
// has SCTP auth support (HAS_SCTP_AUTH); otherwise does nothing. Tries sysctl
// in /usr/sbin first, then /sbin, and aborts through CHECK if the command
// fails.
void EnableSctpAuthIfAvailable() {
#if HAS_SCTP_AUTH
  // Opening an SCTP socket first brings up the kernel SCTP module (presumably
  // so that the net.sctp sysctls exist). The socket lives until this function
  // returns.
  SctpServer sctp_module_loader(1, "localhost");
  CHECK(system("/usr/sbin/sysctl net.sctp.auth_enable=1 || "
               "/sbin/sysctl net.sctp.auth_enable=1") == 0)
      << "Couldn't enable sctp authentication.";
#endif
}
}  // namespace
41
42// An asynchronous SCTP handler. It takes an SCTP receiver (a.k.a SctpServer or
43// SctpClient), and an `sctp_notification` handler and a `message` handler. It
44// asynchronously routes incoming messages to the appropriate handler.
45template <typename T>
46class SctpReceiver {
47 public:
48 SctpReceiver(
49 EPoll &epoll, T &receiver,
50 std::function<void(T &, const union sctp_notification *)> on_notify,
51 std::function<void(T &, std::vector<uint8_t>)> on_message)
52 : epoll_(epoll),
53 receiver_(receiver),
54 on_notify_(std::move(on_notify)),
55 on_message_(std::move(on_message)) {
56 epoll_.OnReadable(receiver_.fd(), [this]() { Read(); });
57 }
58
59 ~SctpReceiver() { epoll_.DeleteFd(receiver_.fd()); }
60
61 private:
62 // Handles an incoming message by routing it to the apropriate handler.
63 void Read() {
64 aos::unique_c_ptr<Message> message = receiver_.Read();
65 if (!message) {
66 return;
67 }
68
69 switch (message->message_type) {
70 case Message::kNotification: {
71 const union sctp_notification *notification =
72 reinterpret_cast<const union sctp_notification *>(message->data());
73 on_notify_(receiver_, notification);
74 break;
75 }
76 case Message::kMessage:
77 on_message_(receiver_, std::vector(message->data(),
78 message->data() + message->size));
79 break;
80 case Message::kOverflow:
81 LOG(FATAL) << "Overflow";
82 }
83 receiver_.FreeMessage(std::move(message));
84 }
85
86 EPoll &epoll_;
87 T &receiver_;
88 std::function<void(T &, const union sctp_notification *)> on_notify_;
89 std::function<void(T &, std::vector<uint8_t>)> on_message_;
90};
91
92// Base SctpTest class.
93//
94// The class provides a few virtual methods that should be overriden to define
95// the behavior of the test.
96class SctpTest : public ::testing::Test {
97 public:
98 SctpTest(std::vector<uint8_t> server_key = {},
99 std::vector<uint8_t> client_key = {},
Adam Snaider9bb33442023-06-26 16:31:37 -0700100 SctpAuthMethod requested_authentication = SctpAuthMethod::kNoAuth,
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700101 std::chrono::milliseconds timeout = 1000ms)
Adam Snaider9bb33442023-06-26 16:31:37 -0700102 : server_(kStreams, "", kPort, requested_authentication),
103 client_("localhost", kPort, kStreams, "", 0, requested_authentication),
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700104 client_receiver_(
105 epoll_, client_,
106 [this](SctpClient &client,
107 const union sctp_notification *notification) {
108 HandleNotification(client, notification);
109 },
110 [this](SctpClient &client, std::vector<uint8_t> message) {
111 HandleMessage(client, std::move(message));
112 }),
113 server_receiver_(
114 epoll_, server_,
115 [this](SctpServer &server,
116 const union sctp_notification *notification) {
117 HandleNotification(server, notification);
118 },
119 [this](SctpServer &server, std::vector<uint8_t> message) {
120 HandleMessage(server, std::move(message));
121 }) {
Adam Snaider9bb33442023-06-26 16:31:37 -0700122 server_.SetAuthKey(server_key);
123 client_.SetAuthKey(client_key);
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700124 timeout_.SetTime(aos::monotonic_clock::now() + timeout,
125 std::chrono::milliseconds::zero());
126 epoll_.OnReadable(timeout_.fd(), [this]() { TimeOut(); });
127 }
128
129 static void SetUpTestSuite() {
130 EnableSctpAuthIfAvailable();
131 // Buildkite seems to have issues with ipv6 sctp sockets...
Austin Schuh99f7c6a2024-06-25 22:07:44 -0700132 absl::SetFlag(&FLAGS_disable_ipv6, true);
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700133 }
134
135 void SetUp() override { Run(); }
136
137 protected:
138 // Handles a server notification message.
139 //
140 // The default behaviour is to track the sctp association ID.
141 virtual void HandleNotification(SctpServer &,
142 const union sctp_notification *notification) {
143 if (notification->sn_header.sn_type == SCTP_ASSOC_CHANGE) {
144 assoc_ = notification->sn_assoc_change.sac_assoc_id;
145 }
146 }
147
148 // Handles the client notification message.
149 virtual void HandleNotification(SctpClient &,
150 const union sctp_notification *) {}
151
152 // Handles a server "data" message.
153 virtual void HandleMessage(SctpServer &, std::vector<uint8_t>) {}
154 // Handles a client "data" message.
155 virtual void HandleMessage(SctpClient &, std::vector<uint8_t>) {}
156
157 // Defines the "timeout" behaviour (fail by default).
158 virtual void TimeOut() {
159 Quit();
160 FAIL() << "Timer expired";
161 }
162
163 virtual ~SctpTest() {}
164
165 // Quit the test.
166 void Quit() {
167 epoll_.DeleteFd(timeout_.fd());
168 epoll_.Quit();
169 }
170 void Run() { epoll_.Run(); }
171
172 SctpServer server_;
173 SctpClient client_;
174 sctp_assoc_t assoc_ = 0;
175
176 private:
177 TimerFd timeout_;
178 EPoll epoll_;
179 SctpReceiver<SctpClient> client_receiver_;
180 SctpReceiver<SctpServer> server_receiver_;
181};
182
183// Verifies we can ping the server, and the server replies.
184class SctpPingPongTest : public SctpTest {
185 public:
Adam Snaider9bb33442023-06-26 16:31:37 -0700186 SctpPingPongTest()
187 : SctpTest({}, {}, SctpAuthMethod::kNoAuth, /*timeout=*/2s) {
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700188 // Start by having the client send "ping".
189 client_.Send(0, "ping", 0);
190 }
191
192 void HandleMessage(SctpServer &server,
193 std::vector<uint8_t> message) override {
194 // Server should receive a ping message.
195 EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
196 got_ping_ = true;
197 ASSERT_NE(assoc_, 0);
198 // Reply with "pong".
199 server.Send("pong", assoc_, 0, 0);
200 }
201
202 void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
203 // Client should receive a "pong" message.
204 EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
205 got_pong_ = true;
206 // We are done.
207 Quit();
208 }
209 ~SctpPingPongTest() {
210 // Check that we got the ping/pong messages.
211 // This isn't strictly necessary as otherwise we would time out and fail
212 // anyway.
213 EXPECT_TRUE(got_ping_);
214 EXPECT_TRUE(got_pong_);
215 }
216
217 protected:
218 bool got_ping_ = false;
219 bool got_pong_ = false;
220};
221
222TEST_F(SctpPingPongTest, Test) {}
223
224#if HAS_SCTP_AUTH
225
226// Same as SctpPingPongTest but with authentication keys. Both keys are the
227// same so it should work the same way.
228class SctpAuthTest : public SctpTest {
229 public:
230 SctpAuthTest()
Adam Snaider9bb33442023-06-26 16:31:37 -0700231 : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, SctpAuthMethod::kAuth,
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700232 /*timeout*/ 20s) {
233 // Start by having the client send "ping".
234 client_.Send(0, "ping", 0);
235 }
236
237 void HandleMessage(SctpServer &server,
238 std::vector<uint8_t> message) override {
239 // Server should receive a ping message.
240 EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
241 got_ping_ = true;
242 ASSERT_NE(assoc_, 0);
243 // Reply with "pong".
244 server.Send("pong", assoc_, 0, 0);
245 }
246 void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
247 // Client should receive a "pong" message.
248 EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
249 got_pong_ = true;
250 // We are done.
251 Quit();
252 }
253 ~SctpAuthTest() {
254 EXPECT_TRUE(got_ping_);
255 EXPECT_TRUE(got_pong_);
256 }
257
258 protected:
259 bool got_ping_ = false;
260 bool got_pong_ = false;
261};
262
263TEST_F(SctpAuthTest, Test) {}
264
Adam Snaider9bb33442023-06-26 16:31:37 -0700265// Tests that we can dynamically change the SCTP authentication key used.
266class SctpChangingAuthKeysTest : public SctpTest {
267 public:
268 SctpChangingAuthKeysTest()
269 : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
270 SctpAuthMethod::kAuth) {
271 // Start by having the client send "ping".
272 client_.SetAuthKey({5, 4, 3, 2, 1});
273 server_.SetAuthKey({5, 4, 3, 2, 1});
274 client_.Send(0, "ping", 0);
275 }
276
277 void HandleMessage(SctpServer &server,
278 std::vector<uint8_t> message) override {
279 // Server should receive a ping message.
280 EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
281 got_ping_ = true;
282 ASSERT_NE(assoc_, 0);
283 // Reply with "pong".
284 server.Send("pong", assoc_, 0, 0);
285 }
286 void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
287 // Client should receive a "pong" message.
288 EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
289 got_pong_ = true;
290 // We are done.
291 Quit();
292 }
293
294 ~SctpChangingAuthKeysTest() {
295 EXPECT_TRUE(got_ping_);
296 EXPECT_TRUE(got_pong_);
297 }
298
299 protected:
300 bool got_ping_ = false;
301 bool got_pong_ = false;
302};
303
304TEST_F(SctpChangingAuthKeysTest, Test) {}
305
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700306// Keys don't match, we should send the `ping` message but the server should
307// never receive it. We then time out as nothing calls Quit.
308class SctpMismatchedAuthTest : public SctpTest {
309 public:
Adam Snaider9bb33442023-06-26 16:31:37 -0700310 SctpMismatchedAuthTest()
311 : SctpTest({1, 2, 3, 4, 5, 6}, {5, 6, 7, 8, 9, 10},
312 SctpAuthMethod::kAuth) {
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700313 // Start by having the client send "ping".
314 client_.Send(0, "ping", 0);
315 }
316
317 void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
318 FAIL() << "Authentication keys don't match. Message should be discarded";
319 Quit();
320 }
321
322 // We expect to time out since we never get the message.
323 void TimeOut() override { Quit(); }
324};
325
326TEST_F(SctpMismatchedAuthTest, Test) {}
327
328// Same as SctpMismatchedAuthTest but the client uses the null key. We should
329// see the same behaviour.
330class SctpOneNullKeyTest : public SctpTest {
331 public:
Adam Snaider9bb33442023-06-26 16:31:37 -0700332 SctpOneNullKeyTest()
333 : SctpTest({1, 2, 3, 4, 5, 6}, {}, SctpAuthMethod::kAuth) {
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700334 // Start by having the client send "ping".
335 client_.Send(0, "ping", 0);
336 }
337
338 void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
339 FAIL() << "Authentication keys don't match. Message should be discarded";
340 Quit();
341 }
342
343 // We expect to time out since we never get the message.
344 void TimeOut() override { Quit(); }
345};
346
347TEST_F(SctpOneNullKeyTest, Test) {}
Adam Snaider9bb33442023-06-26 16:31:37 -0700348
349// If we want SCTP authentication but we don't set the auth keys, we shouldn't
350// be able to send data.
351class SctpAuthKeysNotSet : public SctpTest {
352 public:
353 SctpAuthKeysNotSet() : SctpTest({}, {}, SctpAuthMethod::kAuth) {
354 // Start by having the client send "ping".
355 client_.Send(0, "ping", 0);
356 }
357
358 void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
359 FAIL() << "Haven't setup authentication keys. Should not get message.";
360 Quit();
361 }
362
363 // We expect to time out since we never get the message.
364 void TimeOut() override { Quit(); }
365};
366
367TEST_F(SctpAuthKeysNotSet, Test) {}
368
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700369#endif // HAS_SCTP_AUTH
370
371} // namespace aos::message_bridge::testing