blob: 40d1631882a99402669e8e97b842d415f3fbbeb1 [file] [log] [blame]
Austin Schuhfaec51a2023-09-08 17:43:32 -07001#include "aos/ipc_lib/lockless_queue_stepping.h"
2
3#include <dlfcn.h>
4#include <elf.h>
5#include <linux/futex.h>
6#include <sys/mman.h>
7#include <sys/procfs.h>
8#include <sys/ptrace.h>
9#include <sys/syscall.h>
10#include <sys/uio.h>
11#include <unistd.h>
12#include <wait.h>
13
14#include <memory>
15#include <thread>
16
17#include "glog/logging.h"
18#include "gtest/gtest.h"
19
20#include "aos/ipc_lib/aos_sync.h"
21#include "aos/ipc_lib/lockless_queue_memory.h"
22#include "aos/ipc_lib/shm_observers.h"
23#include "aos/libc/aos_strsignal.h"
24#include "aos/testing/prevent_exit.h"
25
26#ifdef SUPPORTS_SHM_ROBUSTNESS_TEST
27
Stephan Pleinesd99b1ee2024-02-02 20:56:44 -080028namespace aos::ipc_lib::testing {
Austin Schuhfaec51a2023-09-08 17:43:32 -070029
30namespace {
// Wrapper around the raw gettid syscall; used because the test needs the
// kernel thread id (valid across fork), and older libcs do not expose it.
pid_t gettid() { return static_cast<pid_t>(syscall(SYS_gettid)); }
32
// Pointer to the shared GlobalState.  Stored atomically because it is read
// from signal handlers; set by GlobalState::MakeGlobalState().
::std::atomic<GlobalState *> global_state;

// The sigactions that were installed before InstallHandler() replaced them,
// kept so segv_handler can chain unrelated faults to the previous handler.
struct sigaction old_segv_handler, old_trap_handler;
36
// Calls the original (chained) signal handler described by |action|.
// Returns false when there is no user handler to call (SIG_IGN / SIG_DFL);
// returns true after invoking the handler.
bool CallChainedAction(const struct sigaction &action, int signal,
                       siginfo_t *siginfo, void *context) {
  const bool has_user_handler =
      action.sa_handler != SIG_IGN && action.sa_handler != SIG_DFL;
  if (!has_user_handler) {
    return false;
  }
  // SA_SIGINFO selects which member of the handler union is active.
  if ((action.sa_flags & SA_SIGINFO) != 0) {
    action.sa_sigaction(signal, siginfo, context);
  } else {
    action.sa_handler(signal);
  }
  return true;
}
50
// SIGSEGV handler for the traced child.  Writes to the shared queue memory
// (kept PROT_READ while running) fault into here: record the write, make the
// memory writable, then raise a breakpoint so the tracing parent can
// single-step the faulting instruction (see RunFunctionDieAt's tracer loop).
void segv_handler(int signal, siginfo_t *siginfo, void *context_void) {
  GlobalState *my_global_state = GlobalState::Get();
  // Preserve errno across the handler, as required of signal handlers.
  const int saved_errno = errno;
  SIMPLE_ASSERT(signal == SIGSEGV, "wrong signal for SIGSEGV handler");

  // Only process memory addresses in our shared memory block.  Anything else
  // is a genuine crash: chain to the previous handler if there was one,
  // otherwise abort.
  if (!my_global_state->IsInLocklessQueueMemory(siginfo->si_addr)) {
    if (CallChainedAction(old_segv_handler, signal, siginfo, context_void)) {
      errno = saved_errno;
      return;
    } else {
      SIMPLE_ASSERT(false, "actual SIGSEGV");
    }
  }
  SIMPLE_ASSERT(my_global_state->state == DieAtState::kRunning,
                "bad state for SIGSEGV");

  // Record/verify this write, possibly _exit()ing at the configured die_at
  // point.
  my_global_state->HandleWrite(siginfo->si_addr);

  my_global_state->ShmProtectOrDie(PROT_READ | PROT_WRITE);
  my_global_state->state = DieAtState::kWriting;
  errno = saved_errno;

  // Trap into the tracer so it can single-step the instruction that faulted
  // and hand control back.
#if defined(__x86_64__)
  __asm__ __volatile__("int $3" ::: "memory", "cc");
#elif defined(__aarch64__)
  __asm__ __volatile__("brk #0" ::: "memory", "cc");
#else
#error Unhandled architecture
#endif
}
82
// The SEGV handler has set a breakpoint 1 instruction in the future. This
// clears it, marks memory readonly, and continues.
void trap_handler(int signal, siginfo_t *, void * /*context*/) {
  GlobalState *my_global_state = GlobalState::Get();
  // Preserve errno across the handler, as required of signal handlers.
  const int saved_errno = errno;
  SIMPLE_ASSERT(signal == SIGTRAP, "wrong signal for SIGTRAP handler");

  // NOTE(review): this assignment makes the assert immediately below
  // vacuous; it looks like the intent was to verify that segv_handler had
  // already set kWriting before trapping — confirm before changing either
  // line.
  my_global_state->state = DieAtState::kWriting;
  SIMPLE_ASSERT(my_global_state->state == DieAtState::kWriting,
                "bad state for SIGTRAP");
  // Re-protect the queue memory so the next write faults into segv_handler
  // again, then resume normal running state.
  my_global_state->ShmProtectOrDie(PROT_READ);
  my_global_state->state = DieAtState::kRunning;
  errno = saved_errno;
}
97
// Installs |handler| as the SA_SIGINFO-style action for |signal|, saving the
// previously installed action into *old_action so it can be chained later.
void InstallHandler(int signal, void (*handler)(int, siginfo_t *, void *),
                    struct sigaction *old_action) {
  struct sigaction action;
  memset(&action, 0, sizeof(action));
  action.sa_sigaction = handler;
  // We don't do a full normal signal handler exit with ptrace, so SA_NODEFER
  // is necessary to keep our signal handler active.
  action.sa_flags = SA_RESTART | SA_SIGINFO | SA_NODEFER;
#ifdef AOS_SANITIZER_thread
  // Tsan messes with signal handlers to check for race conditions, and it
  // causes problems, so we have to work around it for SIGTRAP.  Look up the
  // un-intercepted sigaction with RTLD_NEXT and call that directly.
  if (signal == SIGTRAP) {
    typedef int (*SigactionType)(int, const struct sigaction *,
                                 struct sigaction *);
    SigactionType real_sigaction =
        reinterpret_cast<SigactionType>(dlsym(RTLD_NEXT, "sigaction"));
    // If dlsym handed back the same function we would have called anyway,
    // the tsan interceptor was not found and the workaround is a no-op.
    if (sigaction == real_sigaction) {
      LOG(WARNING) << "failed to work around tsan signal handling weirdness";
    }
    PCHECK(real_sigaction(signal, &action, old_action) == 0);
    return;
  }
#endif
  PCHECK(sigaction(signal, &action, old_action) == 0);
}
124
125// A mutex lock is about to happen. Mark the memory rw, and check to see if we
126// should die.
127void futex_before(void *address, bool) {
128 GlobalState *my_global_state = GlobalState::Get();
129 if (my_global_state->IsInLocklessQueueMemory(address)) {
130 assert(my_global_state->state == DieAtState::kRunning);
131 my_global_state->HandleWrite(address);
132 my_global_state->ShmProtectOrDie(PROT_READ | PROT_WRITE);
133 my_global_state->state = DieAtState::kWriting;
134 }
135}
136
137// We have a manual trap for mutexes. Check to see if we were supposed to die
138// on this write (the compare/exchange for the mutex), and mark the memory ro
139// again.
140void futex_after(void *address, bool) {
141 GlobalState *my_global_state = GlobalState::Get();
142 if (my_global_state->IsInLocklessQueueMemory(address)) {
143 assert(my_global_state->state == DieAtState::kWriting);
144 my_global_state->ShmProtectOrDie(PROT_READ);
145 my_global_state->state = DieAtState::kRunning;
146 }
147}
148
149} // namespace
150
151void GlobalState::HandleWrite(void *address) {
152 uintptr_t address_offset = reinterpret_cast<uintptr_t>(address) -
153 reinterpret_cast<uintptr_t>(lockless_queue_memory);
154 if (writes_in != nullptr) {
155 SIMPLE_ASSERT(writes_in->At(current_location) == address_offset,
156 "wrong write order");
157 }
158 if (writes_out != nullptr) {
159 writes_out->Add(address_offset);
160 }
161 if (die_at != 0) {
162 if (die_at == current_location) {
163 _exit(kExitEarlyValue);
164 }
165 }
166 ++current_location;
167}
168
// Returns the process-global GlobalState pointer set by MakeGlobalState().
// Relaxed ordering is used; the pointer is only read here, including from
// signal handlers.
GlobalState *GlobalState::Get() {
  return global_state.load(::std::memory_order_relaxed);
}
172
173std::tuple<GlobalState *, WritesArray *> GlobalState::MakeGlobalState() {
174 // Map the global state and memory for the Writes array so it exists across
175 // the process boundary.
176 void *shared_allocations = static_cast<GlobalState *>(
177 mmap(nullptr, sizeof(GlobalState) + sizeof(WritesArray),
178 PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0));
179 CHECK_NE(MAP_FAILED, shared_allocations);
180
181 global_state.store(static_cast<GlobalState *>(shared_allocations));
182 void *expected_writes_shared_allocations = static_cast<void *>(
183 static_cast<uint8_t *>(shared_allocations) + sizeof(GlobalState));
184 WritesArray *expected_writes =
185 static_cast<WritesArray *>(expected_writes_shared_allocations);
186 new (expected_writes) WritesArray();
187 return std::make_pair(static_cast<GlobalState *>(shared_allocations),
188 expected_writes);
189}
190
191bool GlobalState::IsInLocklessQueueMemory(void *address) {
192 void *read_lockless_queue_memory = lockless_queue_memory;
193 if (address < read_lockless_queue_memory) {
194 return false;
195 if (reinterpret_cast<uintptr_t>(address) >
196 reinterpret_cast<uintptr_t>(read_lockless_queue_memory) +
197 lockless_queue_memory_size)
198 return false;
199 }
200 return true;
201}
202
203void GlobalState::ShmProtectOrDie(int prot) {
204 PCHECK(mprotect(lockless_queue_memory, lockless_queue_memory_size, prot) !=
205 -1)
206 << ": mprotect(" << lockless_queue_memory << ", "
207 << lockless_queue_memory_size << ", 0x" << std::hex << prot << ") failed";
208}
209
// Installs the SEGV and TRAP handlers (saving the previous actions into the
// file-level globals for chaining) and hooks the futex shm observers so
// mutex operations toggle the memory protection too.
void GlobalState::RegisterSegvAndTrapHandlers() {
  InstallHandler(SIGSEGV, segv_handler, &old_segv_handler);
  InstallHandler(SIGTRAP, trap_handler, &old_trap_handler);
  // Nothing else should have installed a SIGTRAP handler already; chaining
  // is only supported for SIGSEGV.
  CHECK_EQ(old_trap_handler.sa_handler, SIG_DFL);
  linux_code::ipc_lib::SetShmAccessorObservers(futex_before, futex_after);
}
216
// gtest only allows creating fatal failures in functions returning void...
// status is from wait(2).
//
// Translates a child's wait(2) status into a gtest fatal failure with a
// human-readable description.  Every branch FAILs; callers only reach this
// after ruling out the expected exit codes.
void DetectFatalFailures(int status) {
  if (WIFEXITED(status)) {
    FAIL() << " child returned status "
           << ::std::to_string(WEXITSTATUS(status));
  } else if (WIFSIGNALED(status)) {
    FAIL() << " child exited because of signal "
           << aos_strsignal(WTERMSIG(status));
  } else {
    // Neither a normal exit nor a signal; dump the raw status.
    FAIL() << " child exited with status " << ::std::hex << status;
  }
}
230
// Returns true if it runs all the way through.
//
// Forks a child that runs prepare() then function() against the shared
// queue memory while this (parent) process traces it with ptrace.  Writes
// to the queue memory fault into segv_handler in the child; the parent
// single-steps the faulting instruction and pokes the child (SIGTRAP) so
// trap_handler can re-protect the memory.
//
//   die_at:          write index at which the child _exit()s early;
//                    0 disables dying (see GlobalState::HandleWrite).
//   writable_offset: offset from the queue memory to a writable alias,
//                    handed to SetRobustListOffset so robust-futex cleanup
//                    does not fault on the read-only mapping.
//   writes_in:       if non-null, the expected sequence of write offsets.
//   writes_out:      if non-null, receives the observed write offsets.
//   test_failure:    set to true when the child failed in a way gtest
//                    should report.
bool RunFunctionDieAt(::std::function<void(void *)> prepare,
                      ::std::function<void(void *)> function,
                      bool *test_failure, size_t die_at,
                      uintptr_t writable_offset, const WritesArray *writes_in,
                      WritesArray *writes_out) {
  GlobalState *my_global_state = GlobalState::Get();
  my_global_state->writes_in = writes_in;
  my_global_state->writes_out = writes_out;
  my_global_state->die_at = die_at;
  my_global_state->current_location = 0;
  my_global_state->state = DieAtState::kDisabled;

  const pid_t pid = fork();
  PCHECK(pid != -1) << ": fork() failed";
  if (pid == 0) {
    // Run the test.
    ::aos::testing::PreventExit();

    prepare(my_global_state->lockless_queue_memory);

    // Update the robust list offset.
    linux_code::ipc_lib::SetRobustListOffset(writable_offset);
    // Install a segv handler (to detect writes to the memory block), and a
    // trap handler so we can single step.
    my_global_state->RegisterSegvAndTrapHandlers();

    PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0);
    my_global_state->ShmProtectOrDie(PROT_READ);
    my_global_state->state = DieAtState::kRunning;

    function(my_global_state->lockless_queue_memory);
    my_global_state->state = DieAtState::kDisabled;
    my_global_state->ShmProtectOrDie(PROT_READ | PROT_WRITE);
    _exit(0);
  } else {
    // Annoying wrapper type because elf_gregset_t is an array, which C++
    // handles poorly.
    struct RestoreState {
      RestoreState(elf_gregset_t regs_in) {
        memcpy(regs, regs_in, sizeof(regs));
      }
      elf_gregset_t regs;
    };
    // Register state captured at the SEGV, to restore before single-stepping.
    std::optional<RestoreState> restore_regs;
    // True while a parent-generated SIGTRAP is in flight to the child.
    bool pass_trap = false;
    // Wait until the child process dies.
    while (true) {
      int status;
      pid_t waited_on = waitpid(pid, &status, 0);
      if (waited_on == -1) {
        if (errno == EINTR) continue;
        PCHECK(false) << ": waitpid(" << pid << ", " << &status
                      << ", 0) failed";
      }
      CHECK_EQ(waited_on, pid)
          << ": waitpid got child " << waited_on << " instead of " << pid;
      if (WIFSTOPPED(status)) {
        // The child was stopped via ptrace.
        const int stop_signal = WSTOPSIG(status);
        elf_gregset_t regs;
        {
          struct iovec iov;
          iov.iov_base = &regs;
          iov.iov_len = sizeof(regs);
          PCHECK(ptrace(PTRACE_GETREGSET, pid, NT_PRSTATUS, &iov) == 0);
          CHECK_EQ(iov.iov_len, sizeof(regs))
              << ": ptrace regset is the wrong size";
        }
        if (stop_signal == SIGSEGV) {
          // It's a SEGV, hopefully due to writing to the shared memory which is
          // marked read-only. We record the instruction that faulted so we can
          // look for it while single-stepping, then deliver the signal so the
          // child can mark it read-write and then poke us to single-step that
          // instruction.

          CHECK(!restore_regs)
              << ": Traced child got a SEGV while single-stepping";
          // Save all the registers to resume execution at the current location
          // in the child.
          restore_regs = RestoreState(regs);
          PCHECK(ptrace(PTRACE_CONT, pid, nullptr, SIGSEGV) == 0);
          continue;
        }
        if (stop_signal == SIGTRAP) {
          if (pass_trap) {
            // This is the new SIGTRAP we generated, which we just want to pass
            // through so the child's signal handler can restore the memory to
            // read-only
            PCHECK(ptrace(PTRACE_CONT, pid, nullptr, SIGTRAP) == 0);
            pass_trap = false;
            continue;
          }
          if (restore_regs) {
            // Restore the state we saved before delivering the SEGV, and then
            // single-step that one instruction.
            struct iovec iov;
            iov.iov_base = &restore_regs->regs;
            iov.iov_len = sizeof(restore_regs->regs);
            PCHECK(ptrace(PTRACE_SETREGSET, pid, NT_PRSTATUS, &iov) == 0);
            restore_regs = std::nullopt;
            PCHECK(ptrace(PTRACE_SINGLESTEP, pid, nullptr, nullptr) == 0);
            continue;
          }
          // We executed the single instruction that originally faulted, so
          // now deliver a SIGTRAP to the child so it can mark the memory
          // read-only again.
          pass_trap = true;
          PCHECK(kill(pid, SIGTRAP) == 0);
          PCHECK(ptrace(PTRACE_CONT, pid, nullptr, nullptr) == 0);
          continue;
        }
        LOG(FATAL) << "Traced child was stopped with unexpected signal: "
                   << static_cast<int>(WSTOPSIG(status));
      }
      if (WIFEXITED(status)) {
        if (WEXITSTATUS(status) == 0) return true;
        if (WEXITSTATUS(status) == kExitEarlyValue) return false;
      }
      // Anything else (signal death, unexpected exit code) is a test failure.
      DetectFatalFailures(status);
      if (test_failure) *test_failure = true;
      return false;
    }
  }
}
356
// Maps the lockless-queue memory (plus a writable backup mapping used for
// the robust list), runs RunFunctionDieAt() on a fresh thread, and — if
// nothing failed — runs check() against the resulting memory.  Returns the
// value RunFunctionDieAt() returned (true means the child ran to
// completion).
bool RunFunctionDieAtAndCheck(const LocklessQueueConfiguration &config,
                              ::std::function<void(void *)> prepare,
                              ::std::function<void(void *)> function,
                              ::std::function<void(void *)> check,
                              bool *test_failure, size_t die_at,
                              const WritesArray *writes_in,
                              WritesArray *writes_out) {
  // Allocate shared memory.
  GlobalState *my_global_state = GlobalState::Get();
  my_global_state->lockless_queue_memory_size = LocklessQueueMemorySize(config);
  my_global_state->lockless_queue_memory = static_cast<void *>(
      mmap(nullptr, my_global_state->lockless_queue_memory_size,
           PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0));
  CHECK_NE(MAP_FAILED, my_global_state->lockless_queue_memory);

  // And the backup used to point the robust list at.
  my_global_state->lockless_queue_memory_lock_backup = static_cast<void *>(
      mmap(nullptr, my_global_state->lockless_queue_memory_size,
           PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0));
  CHECK_NE(MAP_FAILED, my_global_state->lockless_queue_memory_lock_backup);

  // The writable offset tells us how to convert from a pointer in the queue to
  // a pointer that is safe to write. This is so robust futexes don't spin the
  // kernel when the user maps a page PROT_READ, and the kernel tries to clear
  // the futex there.
  const uintptr_t writable_offset =
      reinterpret_cast<uintptr_t>(
          my_global_state->lockless_queue_memory_lock_backup) -
      reinterpret_cast<uintptr_t>(my_global_state->lockless_queue_memory);

  bool r;
  // Do the actual test in a new thread so any locked mutexes will be cleaned up
  // nicely with owner-died at the end.
  ::std::thread test_thread([&prepare, &function, &check, test_failure, die_at,
                             writes_in, writes_out, writable_offset, &r]() {
    r = RunFunctionDieAt(prepare, function, test_failure, die_at,
                         writable_offset, writes_in, writes_out);
    if (::testing::Test::HasFailure()) {
      r = false;
      if (test_failure) *test_failure = true;
      return;
    }

    // The run finished (or died at the requested point); validate the
    // resulting shared memory.
    check(GlobalState::Get()->lockless_queue_memory);
  });
  test_thread.join();
  return r;
}
405
// Tests function to make sure it handles dying after each store it makes to
// shared memory. check should make sure function behaved correctly.
// This will repeatedly create a new TestSharedMemory, run prepare, run
// function, and then run check, killing the process function is running in
// at various points. It will stop if anything reports a fatal gtest failure.
//
// Pass 1 (die_at == 0) records every write that function makes; subsequent
// passes replay the recorded writes and kill the child at write 1, 2, 3, ...
// until a run finally completes without dying early.
void TestShmRobustness(const LocklessQueueConfiguration &config,
                       ::std::function<void(void *)> prepare,
                       ::std::function<void(void *)> function,
                       ::std::function<void(void *)> check) {
  auto [my_global_state, expected_writes] = GlobalState::MakeGlobalState();

  // Recording pass: die_at == 0 means never die; capture the write sequence.
  bool test_failed = false;
  ASSERT_TRUE(RunFunctionDieAtAndCheck(config, prepare, function, check,
                                       &test_failed, 0, nullptr,
                                       expected_writes));
  if (test_failed) {
    ADD_FAILURE();
    return;
  }

  size_t die_at = 1;
  while (true) {
    SCOPED_TRACE("dying at " + ::std::to_string(die_at) + "/" +
                 ::std::to_string(expected_writes->size()));
    // A true return means the child ran to completion (die_at is past the
    // last write), so every death point has been exercised.
    if (RunFunctionDieAtAndCheck(config, prepare, function, check, &test_failed,
                                 die_at, expected_writes, nullptr)) {
      LOG(INFO) << "Tested " << die_at << " death points";
      return;
    }
    if (test_failed) {
      ADD_FAILURE();
    }
    if (::testing::Test::HasFailure()) return;
    ++die_at;
  }
}
443
444SharedTid::SharedTid() {
445 // Capture the tid in the child so we can tell if it died. Use mmap so it
446 // works across the process boundary.
447 tid_ =
448 static_cast<pid_t *>(mmap(nullptr, sizeof(pid_t), PROT_READ | PROT_WRITE,
449 MAP_SHARED | MAP_ANONYMOUS, -1, 0));
450 CHECK_NE(MAP_FAILED, tid_);
451}
452
453SharedTid::~SharedTid() { CHECK_EQ(munmap(tid_, sizeof(pid_t)), 0); }
454
455void SharedTid::Set() { *tid_ = gettid(); }
456
457pid_t SharedTid::Get() { return *tid_; }
458
Stephan Pleinesd99b1ee2024-02-02 20:56:44 -0800459} // namespace aos::ipc_lib::testing
Austin Schuhfaec51a2023-09-08 17:43:32 -0700460
#endif  // SUPPORTS_SHM_ROBUSTNESS_TEST