blob: 458d2d1d795d8d815e7a89690df5d507dab9269b [file] [log] [blame]
Adam Snaider96a0f4b2023-05-18 20:41:19 -07001#include <unistd.h>
2
3#include <chrono>
4#include <functional>
5
6#include "gflags/gflags.h"
7#include "gmock/gmock-matchers.h"
8#include "gtest/gtest.h"
9
10#include "aos/events/epoll.h"
11#include "aos/network/sctp_client.h"
12#include "aos/network/sctp_lib.h"
13#include "aos/network/sctp_server.h"
14
15DECLARE_bool(disable_ipv6);
16
17namespace aos::message_bridge::testing {
18
19using ::aos::internal::EPoll;
20using ::aos::internal::TimerFd;
21using ::testing::ElementsAre;
22
23using namespace ::std::chrono_literals;
24
25constexpr int kPort = 19423;
26constexpr int kStreams = 1;
27
namespace {

// Prepares the host for the authenticated tests when the build has
// HAS_SCTP_AUTH set: runs modprobe for the sctp module and sets the
// net.sctp.auth_enable sysctl to 1, trying the /usr/sbin binaries first and
// the /sbin ones as a fallback. The modprobe step is separated by `;`, so the
// status that gets CHECKed is the one of the sysctl step.
//
// When HAS_SCTP_AUTH is unset (or 0) this function does nothing.
void EnableSctpAuthIfAvailable() {
#if !HAS_SCTP_AUTH
  // SCTP authentication is compiled out: there is nothing to configure.
#else
  CHECK(system("/usr/sbin/modprobe sctp || /sbin/modprobe sctp; "
               "/usr/sbin/sysctl net.sctp.auth_enable=1 || "
               "/sbin/sysctl "
               "net.sctp.auth_enable=1") == 0)
      << "Couldn't enable sctp authentication.";
#endif
}

}  // namespace
39
40// An asynchronous SCTP handler. It takes an SCTP receiver (a.k.a SctpServer or
41// SctpClient), and an `sctp_notification` handler and a `message` handler. It
42// asynchronously routes incoming messages to the appropriate handler.
43template <typename T>
44class SctpReceiver {
45 public:
46 SctpReceiver(
47 EPoll &epoll, T &receiver,
48 std::function<void(T &, const union sctp_notification *)> on_notify,
49 std::function<void(T &, std::vector<uint8_t>)> on_message)
50 : epoll_(epoll),
51 receiver_(receiver),
52 on_notify_(std::move(on_notify)),
53 on_message_(std::move(on_message)) {
54 epoll_.OnReadable(receiver_.fd(), [this]() { Read(); });
55 }
56
57 ~SctpReceiver() { epoll_.DeleteFd(receiver_.fd()); }
58
59 private:
60 // Handles an incoming message by routing it to the apropriate handler.
61 void Read() {
62 aos::unique_c_ptr<Message> message = receiver_.Read();
63 if (!message) {
64 return;
65 }
66
67 switch (message->message_type) {
68 case Message::kNotification: {
69 const union sctp_notification *notification =
70 reinterpret_cast<const union sctp_notification *>(message->data());
71 on_notify_(receiver_, notification);
72 break;
73 }
74 case Message::kMessage:
75 on_message_(receiver_, std::vector(message->data(),
76 message->data() + message->size));
77 break;
78 case Message::kOverflow:
79 LOG(FATAL) << "Overflow";
80 }
81 receiver_.FreeMessage(std::move(message));
82 }
83
84 EPoll &epoll_;
85 T &receiver_;
86 std::function<void(T &, const union sctp_notification *)> on_notify_;
87 std::function<void(T &, std::vector<uint8_t>)> on_message_;
88};
89
90// Base SctpTest class.
91//
92// The class provides a few virtual methods that should be overriden to define
93// the behavior of the test.
94class SctpTest : public ::testing::Test {
95 public:
96 SctpTest(std::vector<uint8_t> server_key = {},
97 std::vector<uint8_t> client_key = {},
Adam Snaider9bb33442023-06-26 16:31:37 -070098 SctpAuthMethod requested_authentication = SctpAuthMethod::kNoAuth,
Adam Snaider96a0f4b2023-05-18 20:41:19 -070099 std::chrono::milliseconds timeout = 1000ms)
Adam Snaider9bb33442023-06-26 16:31:37 -0700100 : server_(kStreams, "", kPort, requested_authentication),
101 client_("localhost", kPort, kStreams, "", 0, requested_authentication),
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700102 client_receiver_(
103 epoll_, client_,
104 [this](SctpClient &client,
105 const union sctp_notification *notification) {
106 HandleNotification(client, notification);
107 },
108 [this](SctpClient &client, std::vector<uint8_t> message) {
109 HandleMessage(client, std::move(message));
110 }),
111 server_receiver_(
112 epoll_, server_,
113 [this](SctpServer &server,
114 const union sctp_notification *notification) {
115 HandleNotification(server, notification);
116 },
117 [this](SctpServer &server, std::vector<uint8_t> message) {
118 HandleMessage(server, std::move(message));
119 }) {
Adam Snaider9bb33442023-06-26 16:31:37 -0700120 server_.SetAuthKey(server_key);
121 client_.SetAuthKey(client_key);
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700122 timeout_.SetTime(aos::monotonic_clock::now() + timeout,
123 std::chrono::milliseconds::zero());
124 epoll_.OnReadable(timeout_.fd(), [this]() { TimeOut(); });
125 }
126
127 static void SetUpTestSuite() {
128 EnableSctpAuthIfAvailable();
129 // Buildkite seems to have issues with ipv6 sctp sockets...
130 FLAGS_disable_ipv6 = true;
131 }
132
133 void SetUp() override { Run(); }
134
135 protected:
136 // Handles a server notification message.
137 //
138 // The default behaviour is to track the sctp association ID.
139 virtual void HandleNotification(SctpServer &,
140 const union sctp_notification *notification) {
141 if (notification->sn_header.sn_type == SCTP_ASSOC_CHANGE) {
142 assoc_ = notification->sn_assoc_change.sac_assoc_id;
143 }
144 }
145
146 // Handles the client notification message.
147 virtual void HandleNotification(SctpClient &,
148 const union sctp_notification *) {}
149
150 // Handles a server "data" message.
151 virtual void HandleMessage(SctpServer &, std::vector<uint8_t>) {}
152 // Handles a client "data" message.
153 virtual void HandleMessage(SctpClient &, std::vector<uint8_t>) {}
154
155 // Defines the "timeout" behaviour (fail by default).
156 virtual void TimeOut() {
157 Quit();
158 FAIL() << "Timer expired";
159 }
160
161 virtual ~SctpTest() {}
162
163 // Quit the test.
164 void Quit() {
165 epoll_.DeleteFd(timeout_.fd());
166 epoll_.Quit();
167 }
168 void Run() { epoll_.Run(); }
169
170 SctpServer server_;
171 SctpClient client_;
172 sctp_assoc_t assoc_ = 0;
173
174 private:
175 TimerFd timeout_;
176 EPoll epoll_;
177 SctpReceiver<SctpClient> client_receiver_;
178 SctpReceiver<SctpServer> server_receiver_;
179};
180
181// Verifies we can ping the server, and the server replies.
182class SctpPingPongTest : public SctpTest {
183 public:
Adam Snaider9bb33442023-06-26 16:31:37 -0700184 SctpPingPongTest()
185 : SctpTest({}, {}, SctpAuthMethod::kNoAuth, /*timeout=*/2s) {
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700186 // Start by having the client send "ping".
187 client_.Send(0, "ping", 0);
188 }
189
190 void HandleMessage(SctpServer &server,
191 std::vector<uint8_t> message) override {
192 // Server should receive a ping message.
193 EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
194 got_ping_ = true;
195 ASSERT_NE(assoc_, 0);
196 // Reply with "pong".
197 server.Send("pong", assoc_, 0, 0);
198 }
199
200 void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
201 // Client should receive a "pong" message.
202 EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
203 got_pong_ = true;
204 // We are done.
205 Quit();
206 }
207 ~SctpPingPongTest() {
208 // Check that we got the ping/pong messages.
209 // This isn't strictly necessary as otherwise we would time out and fail
210 // anyway.
211 EXPECT_TRUE(got_ping_);
212 EXPECT_TRUE(got_pong_);
213 }
214
215 protected:
216 bool got_ping_ = false;
217 bool got_pong_ = false;
218};
219
220TEST_F(SctpPingPongTest, Test) {}
221
222#if HAS_SCTP_AUTH
223
224// Same as SctpPingPongTest but with authentication keys. Both keys are the
225// same so it should work the same way.
226class SctpAuthTest : public SctpTest {
227 public:
228 SctpAuthTest()
Adam Snaider9bb33442023-06-26 16:31:37 -0700229 : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, SctpAuthMethod::kAuth,
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700230 /*timeout*/ 20s) {
231 // Start by having the client send "ping".
232 client_.Send(0, "ping", 0);
233 }
234
235 void HandleMessage(SctpServer &server,
236 std::vector<uint8_t> message) override {
237 // Server should receive a ping message.
238 EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
239 got_ping_ = true;
240 ASSERT_NE(assoc_, 0);
241 // Reply with "pong".
242 server.Send("pong", assoc_, 0, 0);
243 }
244 void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
245 // Client should receive a "pong" message.
246 EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
247 got_pong_ = true;
248 // We are done.
249 Quit();
250 }
251 ~SctpAuthTest() {
252 EXPECT_TRUE(got_ping_);
253 EXPECT_TRUE(got_pong_);
254 }
255
256 protected:
257 bool got_ping_ = false;
258 bool got_pong_ = false;
259};
260
261TEST_F(SctpAuthTest, Test) {}
262
Adam Snaider9bb33442023-06-26 16:31:37 -0700263// Tests that we can dynamically change the SCTP authentication key used.
264class SctpChangingAuthKeysTest : public SctpTest {
265 public:
266 SctpChangingAuthKeysTest()
267 : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
268 SctpAuthMethod::kAuth) {
269 // Start by having the client send "ping".
270 client_.SetAuthKey({5, 4, 3, 2, 1});
271 server_.SetAuthKey({5, 4, 3, 2, 1});
272 client_.Send(0, "ping", 0);
273 }
274
275 void HandleMessage(SctpServer &server,
276 std::vector<uint8_t> message) override {
277 // Server should receive a ping message.
278 EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
279 got_ping_ = true;
280 ASSERT_NE(assoc_, 0);
281 // Reply with "pong".
282 server.Send("pong", assoc_, 0, 0);
283 }
284 void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
285 // Client should receive a "pong" message.
286 EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
287 got_pong_ = true;
288 // We are done.
289 Quit();
290 }
291
292 ~SctpChangingAuthKeysTest() {
293 EXPECT_TRUE(got_ping_);
294 EXPECT_TRUE(got_pong_);
295 }
296
297 protected:
298 bool got_ping_ = false;
299 bool got_pong_ = false;
300};
301
302TEST_F(SctpChangingAuthKeysTest, Test) {}
303
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700304// Keys don't match, we should send the `ping` message but the server should
305// never receive it. We then time out as nothing calls Quit.
306class SctpMismatchedAuthTest : public SctpTest {
307 public:
Adam Snaider9bb33442023-06-26 16:31:37 -0700308 SctpMismatchedAuthTest()
309 : SctpTest({1, 2, 3, 4, 5, 6}, {5, 6, 7, 8, 9, 10},
310 SctpAuthMethod::kAuth) {
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700311 // Start by having the client send "ping".
312 client_.Send(0, "ping", 0);
313 }
314
315 void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
316 FAIL() << "Authentication keys don't match. Message should be discarded";
317 Quit();
318 }
319
320 // We expect to time out since we never get the message.
321 void TimeOut() override { Quit(); }
322};
323
324TEST_F(SctpMismatchedAuthTest, Test) {}
325
326// Same as SctpMismatchedAuthTest but the client uses the null key. We should
327// see the same behaviour.
328class SctpOneNullKeyTest : public SctpTest {
329 public:
Adam Snaider9bb33442023-06-26 16:31:37 -0700330 SctpOneNullKeyTest()
331 : SctpTest({1, 2, 3, 4, 5, 6}, {}, SctpAuthMethod::kAuth) {
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700332 // Start by having the client send "ping".
333 client_.Send(0, "ping", 0);
334 }
335
336 void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
337 FAIL() << "Authentication keys don't match. Message should be discarded";
338 Quit();
339 }
340
341 // We expect to time out since we never get the message.
342 void TimeOut() override { Quit(); }
343};
344
345TEST_F(SctpOneNullKeyTest, Test) {}
Adam Snaider9bb33442023-06-26 16:31:37 -0700346
347// If we want SCTP authentication but we don't set the auth keys, we shouldn't
348// be able to send data.
349class SctpAuthKeysNotSet : public SctpTest {
350 public:
351 SctpAuthKeysNotSet() : SctpTest({}, {}, SctpAuthMethod::kAuth) {
352 // Start by having the client send "ping".
353 client_.Send(0, "ping", 0);
354 }
355
356 void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
357 FAIL() << "Haven't setup authentication keys. Should not get message.";
358 Quit();
359 }
360
361 // We expect to time out since we never get the message.
362 void TimeOut() override { Quit(); }
363};
364
365TEST_F(SctpAuthKeysNotSet, Test) {}
366
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700367#endif // HAS_SCTP_AUTH
368
369} // namespace aos::message_bridge::testing