blob: 8e332e4a355b7dff9eee10eef30eff5d3de3a423 [file] [log] [blame]
Adam Snaider96a0f4b2023-05-18 20:41:19 -07001#include <unistd.h>
2
3#include <chrono>
4#include <functional>
5
6#include "gflags/gflags.h"
7#include "gmock/gmock-matchers.h"
8#include "gtest/gtest.h"
9
10#include "aos/events/epoll.h"
11#include "aos/network/sctp_client.h"
12#include "aos/network/sctp_lib.h"
13#include "aos/network/sctp_server.h"
14
15DECLARE_bool(disable_ipv6);
16
17namespace aos::message_bridge::testing {
18
19using ::aos::internal::EPoll;
20using ::aos::internal::TimerFd;
21using ::testing::ElementsAre;
22
23using namespace ::std::chrono_literals;
24
// Port shared by the test server and the client (which connects to localhost).
constexpr int kPort{19423};
// Stream count passed to both the SctpServer and the SctpClient.
constexpr int kStreams{1};
27
namespace {
// Turns on kernel SCTP authentication by running
// `sysctl net.sctp.auth_enable=1` (trying /usr/sbin/sysctl, then /sbin/sysctl)
// and CHECK-fails if system() returns nonzero. When HAS_SCTP_AUTH is unset or
// zero this function does nothing.
void EnableSctpAuthIfAvailable() {
#if !HAS_SCTP_AUTH
  // HAS_SCTP_AUTH is off for this build: nothing to enable.
#else
  CHECK(system("/usr/sbin/sysctl net.sctp.auth_enable=1 || /sbin/sysctl "
               "net.sctp.auth_enable=1") == 0)
      << "Couldn't enable sctp authentication.";
#endif
}
}  // namespace
37
38// An asynchronous SCTP handler. It takes an SCTP receiver (a.k.a SctpServer or
39// SctpClient), and an `sctp_notification` handler and a `message` handler. It
40// asynchronously routes incoming messages to the appropriate handler.
41template <typename T>
42class SctpReceiver {
43 public:
44 SctpReceiver(
45 EPoll &epoll, T &receiver,
46 std::function<void(T &, const union sctp_notification *)> on_notify,
47 std::function<void(T &, std::vector<uint8_t>)> on_message)
48 : epoll_(epoll),
49 receiver_(receiver),
50 on_notify_(std::move(on_notify)),
51 on_message_(std::move(on_message)) {
52 epoll_.OnReadable(receiver_.fd(), [this]() { Read(); });
53 }
54
55 ~SctpReceiver() { epoll_.DeleteFd(receiver_.fd()); }
56
57 private:
58 // Handles an incoming message by routing it to the apropriate handler.
59 void Read() {
60 aos::unique_c_ptr<Message> message = receiver_.Read();
61 if (!message) {
62 return;
63 }
64
65 switch (message->message_type) {
66 case Message::kNotification: {
67 const union sctp_notification *notification =
68 reinterpret_cast<const union sctp_notification *>(message->data());
69 on_notify_(receiver_, notification);
70 break;
71 }
72 case Message::kMessage:
73 on_message_(receiver_, std::vector(message->data(),
74 message->data() + message->size));
75 break;
76 case Message::kOverflow:
77 LOG(FATAL) << "Overflow";
78 }
79 receiver_.FreeMessage(std::move(message));
80 }
81
82 EPoll &epoll_;
83 T &receiver_;
84 std::function<void(T &, const union sctp_notification *)> on_notify_;
85 std::function<void(T &, std::vector<uint8_t>)> on_message_;
86};
87
88// Base SctpTest class.
89//
90// The class provides a few virtual methods that should be overriden to define
91// the behavior of the test.
92class SctpTest : public ::testing::Test {
93 public:
94 SctpTest(std::vector<uint8_t> server_key = {},
95 std::vector<uint8_t> client_key = {},
Adam Snaider9bb33442023-06-26 16:31:37 -070096 SctpAuthMethod requested_authentication = SctpAuthMethod::kNoAuth,
Adam Snaider96a0f4b2023-05-18 20:41:19 -070097 std::chrono::milliseconds timeout = 1000ms)
Adam Snaider9bb33442023-06-26 16:31:37 -070098 : server_(kStreams, "", kPort, requested_authentication),
99 client_("localhost", kPort, kStreams, "", 0, requested_authentication),
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700100 client_receiver_(
101 epoll_, client_,
102 [this](SctpClient &client,
103 const union sctp_notification *notification) {
104 HandleNotification(client, notification);
105 },
106 [this](SctpClient &client, std::vector<uint8_t> message) {
107 HandleMessage(client, std::move(message));
108 }),
109 server_receiver_(
110 epoll_, server_,
111 [this](SctpServer &server,
112 const union sctp_notification *notification) {
113 HandleNotification(server, notification);
114 },
115 [this](SctpServer &server, std::vector<uint8_t> message) {
116 HandleMessage(server, std::move(message));
117 }) {
Adam Snaider9bb33442023-06-26 16:31:37 -0700118 server_.SetAuthKey(server_key);
119 client_.SetAuthKey(client_key);
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700120 timeout_.SetTime(aos::monotonic_clock::now() + timeout,
121 std::chrono::milliseconds::zero());
122 epoll_.OnReadable(timeout_.fd(), [this]() { TimeOut(); });
123 }
124
125 static void SetUpTestSuite() {
126 EnableSctpAuthIfAvailable();
127 // Buildkite seems to have issues with ipv6 sctp sockets...
128 FLAGS_disable_ipv6 = true;
129 }
130
131 void SetUp() override { Run(); }
132
133 protected:
134 // Handles a server notification message.
135 //
136 // The default behaviour is to track the sctp association ID.
137 virtual void HandleNotification(SctpServer &,
138 const union sctp_notification *notification) {
139 if (notification->sn_header.sn_type == SCTP_ASSOC_CHANGE) {
140 assoc_ = notification->sn_assoc_change.sac_assoc_id;
141 }
142 }
143
144 // Handles the client notification message.
145 virtual void HandleNotification(SctpClient &,
146 const union sctp_notification *) {}
147
148 // Handles a server "data" message.
149 virtual void HandleMessage(SctpServer &, std::vector<uint8_t>) {}
150 // Handles a client "data" message.
151 virtual void HandleMessage(SctpClient &, std::vector<uint8_t>) {}
152
153 // Defines the "timeout" behaviour (fail by default).
154 virtual void TimeOut() {
155 Quit();
156 FAIL() << "Timer expired";
157 }
158
159 virtual ~SctpTest() {}
160
161 // Quit the test.
162 void Quit() {
163 epoll_.DeleteFd(timeout_.fd());
164 epoll_.Quit();
165 }
166 void Run() { epoll_.Run(); }
167
168 SctpServer server_;
169 SctpClient client_;
170 sctp_assoc_t assoc_ = 0;
171
172 private:
173 TimerFd timeout_;
174 EPoll epoll_;
175 SctpReceiver<SctpClient> client_receiver_;
176 SctpReceiver<SctpServer> server_receiver_;
177};
178
179// Verifies we can ping the server, and the server replies.
180class SctpPingPongTest : public SctpTest {
181 public:
Adam Snaider9bb33442023-06-26 16:31:37 -0700182 SctpPingPongTest()
183 : SctpTest({}, {}, SctpAuthMethod::kNoAuth, /*timeout=*/2s) {
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700184 // Start by having the client send "ping".
185 client_.Send(0, "ping", 0);
186 }
187
188 void HandleMessage(SctpServer &server,
189 std::vector<uint8_t> message) override {
190 // Server should receive a ping message.
191 EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
192 got_ping_ = true;
193 ASSERT_NE(assoc_, 0);
194 // Reply with "pong".
195 server.Send("pong", assoc_, 0, 0);
196 }
197
198 void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
199 // Client should receive a "pong" message.
200 EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
201 got_pong_ = true;
202 // We are done.
203 Quit();
204 }
205 ~SctpPingPongTest() {
206 // Check that we got the ping/pong messages.
207 // This isn't strictly necessary as otherwise we would time out and fail
208 // anyway.
209 EXPECT_TRUE(got_ping_);
210 EXPECT_TRUE(got_pong_);
211 }
212
213 protected:
214 bool got_ping_ = false;
215 bool got_pong_ = false;
216};
217
218TEST_F(SctpPingPongTest, Test) {}
219
220#if HAS_SCTP_AUTH
221
222// Same as SctpPingPongTest but with authentication keys. Both keys are the
223// same so it should work the same way.
224class SctpAuthTest : public SctpTest {
225 public:
226 SctpAuthTest()
Adam Snaider9bb33442023-06-26 16:31:37 -0700227 : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, SctpAuthMethod::kAuth,
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700228 /*timeout*/ 20s) {
229 // Start by having the client send "ping".
230 client_.Send(0, "ping", 0);
231 }
232
233 void HandleMessage(SctpServer &server,
234 std::vector<uint8_t> message) override {
235 // Server should receive a ping message.
236 EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
237 got_ping_ = true;
238 ASSERT_NE(assoc_, 0);
239 // Reply with "pong".
240 server.Send("pong", assoc_, 0, 0);
241 }
242 void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
243 // Client should receive a "pong" message.
244 EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
245 got_pong_ = true;
246 // We are done.
247 Quit();
248 }
249 ~SctpAuthTest() {
250 EXPECT_TRUE(got_ping_);
251 EXPECT_TRUE(got_pong_);
252 }
253
254 protected:
255 bool got_ping_ = false;
256 bool got_pong_ = false;
257};
258
259TEST_F(SctpAuthTest, Test) {}
260
Adam Snaider9bb33442023-06-26 16:31:37 -0700261// Tests that we can dynamically change the SCTP authentication key used.
262class SctpChangingAuthKeysTest : public SctpTest {
263 public:
264 SctpChangingAuthKeysTest()
265 : SctpTest({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
266 SctpAuthMethod::kAuth) {
267 // Start by having the client send "ping".
268 client_.SetAuthKey({5, 4, 3, 2, 1});
269 server_.SetAuthKey({5, 4, 3, 2, 1});
270 client_.Send(0, "ping", 0);
271 }
272
273 void HandleMessage(SctpServer &server,
274 std::vector<uint8_t> message) override {
275 // Server should receive a ping message.
276 EXPECT_THAT(message, ElementsAre('p', 'i', 'n', 'g'));
277 got_ping_ = true;
278 ASSERT_NE(assoc_, 0);
279 // Reply with "pong".
280 server.Send("pong", assoc_, 0, 0);
281 }
282 void HandleMessage(SctpClient &, std::vector<uint8_t> message) override {
283 // Client should receive a "pong" message.
284 EXPECT_THAT(message, ElementsAre('p', 'o', 'n', 'g'));
285 got_pong_ = true;
286 // We are done.
287 Quit();
288 }
289
290 ~SctpChangingAuthKeysTest() {
291 EXPECT_TRUE(got_ping_);
292 EXPECT_TRUE(got_pong_);
293 }
294
295 protected:
296 bool got_ping_ = false;
297 bool got_pong_ = false;
298};
299
300TEST_F(SctpChangingAuthKeysTest, Test) {}
301
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700302// Keys don't match, we should send the `ping` message but the server should
303// never receive it. We then time out as nothing calls Quit.
304class SctpMismatchedAuthTest : public SctpTest {
305 public:
Adam Snaider9bb33442023-06-26 16:31:37 -0700306 SctpMismatchedAuthTest()
307 : SctpTest({1, 2, 3, 4, 5, 6}, {5, 6, 7, 8, 9, 10},
308 SctpAuthMethod::kAuth) {
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700309 // Start by having the client send "ping".
310 client_.Send(0, "ping", 0);
311 }
312
313 void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
314 FAIL() << "Authentication keys don't match. Message should be discarded";
315 Quit();
316 }
317
318 // We expect to time out since we never get the message.
319 void TimeOut() override { Quit(); }
320};
321
322TEST_F(SctpMismatchedAuthTest, Test) {}
323
324// Same as SctpMismatchedAuthTest but the client uses the null key. We should
325// see the same behaviour.
326class SctpOneNullKeyTest : public SctpTest {
327 public:
Adam Snaider9bb33442023-06-26 16:31:37 -0700328 SctpOneNullKeyTest()
329 : SctpTest({1, 2, 3, 4, 5, 6}, {}, SctpAuthMethod::kAuth) {
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700330 // Start by having the client send "ping".
331 client_.Send(0, "ping", 0);
332 }
333
334 void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
335 FAIL() << "Authentication keys don't match. Message should be discarded";
336 Quit();
337 }
338
339 // We expect to time out since we never get the message.
340 void TimeOut() override { Quit(); }
341};
342
343TEST_F(SctpOneNullKeyTest, Test) {}
Adam Snaider9bb33442023-06-26 16:31:37 -0700344
345// If we want SCTP authentication but we don't set the auth keys, we shouldn't
346// be able to send data.
347class SctpAuthKeysNotSet : public SctpTest {
348 public:
349 SctpAuthKeysNotSet() : SctpTest({}, {}, SctpAuthMethod::kAuth) {
350 // Start by having the client send "ping".
351 client_.Send(0, "ping", 0);
352 }
353
354 void HandleMessage(SctpServer &, std::vector<uint8_t>) override {
355 FAIL() << "Haven't setup authentication keys. Should not get message.";
356 Quit();
357 }
358
359 // We expect to time out since we never get the message.
360 void TimeOut() override { Quit(); }
361};
362
363TEST_F(SctpAuthKeysNotSet, Test) {}
364
Adam Snaider96a0f4b2023-05-18 20:41:19 -0700365#endif // HAS_SCTP_AUTH
366
367} // namespace aos::message_bridge::testing