blob: 0ec52143a95929ab019beab75328f51e729b6dcf [file] [log] [blame]
#include "aos/ipc_lib/lockless_queue_stepping.h"
2
3#include <dlfcn.h>
4#include <elf.h>
5#include <linux/futex.h>
6#include <sys/mman.h>
7#include <sys/procfs.h>
8#include <sys/ptrace.h>
9#include <sys/syscall.h>
10#include <sys/uio.h>
11#include <unistd.h>
12#include <wait.h>
13
14#include <memory>
15#include <thread>
16
17#include "glog/logging.h"
18#include "gtest/gtest.h"
19
20#include "aos/ipc_lib/aos_sync.h"
21#include "aos/ipc_lib/lockless_queue_memory.h"
22#include "aos/ipc_lib/shm_observers.h"
23#include "aos/libc/aos_strsignal.h"
24#include "aos/testing/prevent_exit.h"
25
26#ifdef SUPPORTS_SHM_ROBUSTNESS_TEST
27
28namespace aos {
29namespace ipc_lib {
30namespace testing {
31
32namespace {
// Returns the kernel thread id of the calling thread. Calls the raw syscall
// because older glibc does not provide a gettid() wrapper.
pid_t gettid() { return static_cast<pid_t>(syscall(SYS_gettid)); }
34
// Pointer to the mmap'ed GlobalState shared between the tracing parent, the
// forked child, and the signal handlers; atomic because it is read from
// signal-handler context (see GlobalState::Get()).
::std::atomic<GlobalState *> global_state;

// Previously-installed SIGSEGV/SIGTRAP dispositions, saved by InstallHandler
// so unrelated signals can be chained through (see CallChainedAction).
struct sigaction old_segv_handler, old_trap_handler;
38
// Forwards a signal to the previously-installed handler described by
// `action`. Returns false (without calling anything) when the old disposition
// was SIG_IGN or SIG_DFL; returns true after invoking the old handler through
// whichever of the two union members (sa_sigaction vs sa_handler) applies.
bool CallChainedAction(const struct sigaction &action, int signal,
                       siginfo_t *siginfo, void *context) {
  const bool has_real_handler =
      action.sa_handler != SIG_IGN && action.sa_handler != SIG_DFL;
  if (!has_real_handler) {
    return false;
  }
  if ((action.sa_flags & SA_SIGINFO) != 0) {
    action.sa_sigaction(signal, siginfo, context);
  } else {
    action.sa_handler(signal);
  }
  return true;
}
52
// SIGSEGV handler for the traced child. The shared queue memory is kept
// PROT_READ while the test body runs, so every store into it faults here.
// We record/verify the write, make the region writable, and then execute a
// breakpoint so the ptrace-ing parent regains control and can single-step
// the faulting instruction (see trap_handler and RunFunctionDieAt).
void segv_handler(int signal, siginfo_t *siginfo, void *context_void) {
  GlobalState *my_global_state = GlobalState::Get();
  const int saved_errno = errno;
  SIMPLE_ASSERT(signal == SIGSEGV, "wrong signal for SIGSEGV handler");

  // Only process memory addresses in our shared memory block.
  if (!my_global_state->IsInLocklessQueueMemory(siginfo->si_addr)) {
    // Not our fault: chain to whatever SIGSEGV handler was installed before
    // us, or die loudly if there was none.
    if (CallChainedAction(old_segv_handler, signal, siginfo, context_void)) {
      errno = saved_errno;
      return;
    } else {
      SIMPLE_ASSERT(false, "actual SIGSEGV");
    }
  }
  SIMPLE_ASSERT(my_global_state->state == DieAtState::kRunning,
                "bad state for SIGSEGV");

  // Checks write ordering against writes_in, records into writes_out, and
  // _exit()s early if this write is the configured death point.
  my_global_state->HandleWrite(siginfo->si_addr);

  my_global_state->ShmProtectOrDie(PROT_READ | PROT_WRITE);
  my_global_state->state = DieAtState::kWriting;
  errno = saved_errno;

  // Hit a breakpoint so the tracing parent can restore the registers it saved
  // when this SEGV was delivered and single-step the write instruction.
#if defined(__x86_64__)
  __asm__ __volatile__("int $3" ::: "memory", "cc");
#elif defined(__aarch64__)
  __asm__ __volatile__("brk #0" ::: "memory", "cc");
#else
#error Unhandled architecture
#endif
}
84
// The SEGV handler has set a breakpoint 1 instruction in the future. This
// clears it, marks memory readonly, and continues.
void trap_handler(int signal, siginfo_t *, void * /*context*/) {
  GlobalState *my_global_state = GlobalState::Get();
  const int saved_errno = errno;
  SIMPLE_ASSERT(signal == SIGTRAP, "wrong signal for SIGTRAP handler");

  // NOTE(review): this assignment makes the SIMPLE_ASSERT immediately below
  // vacuous — it can never fail. Presumably the intent was to assert that
  // segv_handler had already moved us into kWriting; confirm before removing
  // either line.
  my_global_state->state = DieAtState::kWriting;
  SIMPLE_ASSERT(my_global_state->state == DieAtState::kWriting,
                "bad state for SIGTRAP");
  // Re-protect the queue memory so the next write faults again.
  my_global_state->ShmProtectOrDie(PROT_READ);
  my_global_state->state = DieAtState::kRunning;
  errno = saved_errno;
}
99
// Installs `handler` as the SA_SIGINFO-style handler for `signal`, saving the
// previous disposition into `old_action` for later chaining.
void InstallHandler(int signal, void (*handler)(int, siginfo_t *, void *),
                    struct sigaction *old_action) {
  struct sigaction action;
  memset(&action, 0, sizeof(action));
  action.sa_sigaction = handler;
  // We don't do a full normal signal handler exit with ptrace, so SA_NODEFER is
  // necessary to keep our signal handler active.
  action.sa_flags = SA_RESTART | SA_SIGINFO | SA_NODEFER;
#ifdef AOS_SANITIZER_thread
  // Tsan messes with signal handlers to check for race conditions, and it
  // causes problems, so we have to work around it for SIGTRAP.
  if (signal == SIGTRAP) {
    typedef int (*SigactionType)(int, const struct sigaction *,
                                 struct sigaction *);
    // Look up the real (non-interposed) sigaction so tsan never sees this
    // handler registration.
    SigactionType real_sigaction =
        reinterpret_cast<SigactionType>(dlsym(RTLD_NEXT, "sigaction"));
    // If dlsym resolved to the same symbol we already link against, tsan is
    // not actually interposing and the workaround is a no-op.
    if (sigaction == real_sigaction) {
      LOG(WARNING) << "failed to work around tsan signal handling weirdness";
    }
    PCHECK(real_sigaction(signal, &action, old_action) == 0);
    return;
  }
#endif
  PCHECK(sigaction(signal, &action, old_action) == 0);
}
126
127// A mutex lock is about to happen. Mark the memory rw, and check to see if we
128// should die.
129void futex_before(void *address, bool) {
130 GlobalState *my_global_state = GlobalState::Get();
131 if (my_global_state->IsInLocklessQueueMemory(address)) {
132 assert(my_global_state->state == DieAtState::kRunning);
133 my_global_state->HandleWrite(address);
134 my_global_state->ShmProtectOrDie(PROT_READ | PROT_WRITE);
135 my_global_state->state = DieAtState::kWriting;
136 }
137}
138
139// We have a manual trap for mutexes. Check to see if we were supposed to die
140// on this write (the compare/exchange for the mutex), and mark the memory ro
141// again.
142void futex_after(void *address, bool) {
143 GlobalState *my_global_state = GlobalState::Get();
144 if (my_global_state->IsInLocklessQueueMemory(address)) {
145 assert(my_global_state->state == DieAtState::kWriting);
146 my_global_state->ShmProtectOrDie(PROT_READ);
147 my_global_state->state = DieAtState::kRunning;
148 }
149}
150
151} // namespace
152
153void GlobalState::HandleWrite(void *address) {
154 uintptr_t address_offset = reinterpret_cast<uintptr_t>(address) -
155 reinterpret_cast<uintptr_t>(lockless_queue_memory);
156 if (writes_in != nullptr) {
157 SIMPLE_ASSERT(writes_in->At(current_location) == address_offset,
158 "wrong write order");
159 }
160 if (writes_out != nullptr) {
161 writes_out->Add(address_offset);
162 }
163 if (die_at != 0) {
164 if (die_at == current_location) {
165 _exit(kExitEarlyValue);
166 }
167 }
168 ++current_location;
169}
170
171GlobalState *GlobalState::Get() {
172 return global_state.load(::std::memory_order_relaxed);
173}
174
175std::tuple<GlobalState *, WritesArray *> GlobalState::MakeGlobalState() {
176 // Map the global state and memory for the Writes array so it exists across
177 // the process boundary.
178 void *shared_allocations = static_cast<GlobalState *>(
179 mmap(nullptr, sizeof(GlobalState) + sizeof(WritesArray),
180 PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0));
181 CHECK_NE(MAP_FAILED, shared_allocations);
182
183 global_state.store(static_cast<GlobalState *>(shared_allocations));
184 void *expected_writes_shared_allocations = static_cast<void *>(
185 static_cast<uint8_t *>(shared_allocations) + sizeof(GlobalState));
186 WritesArray *expected_writes =
187 static_cast<WritesArray *>(expected_writes_shared_allocations);
188 new (expected_writes) WritesArray();
189 return std::make_pair(static_cast<GlobalState *>(shared_allocations),
190 expected_writes);
191}
192
193bool GlobalState::IsInLocklessQueueMemory(void *address) {
194 void *read_lockless_queue_memory = lockless_queue_memory;
195 if (address < read_lockless_queue_memory) {
196 return false;
197 if (reinterpret_cast<uintptr_t>(address) >
198 reinterpret_cast<uintptr_t>(read_lockless_queue_memory) +
199 lockless_queue_memory_size)
200 return false;
201 }
202 return true;
203}
204
// Changes the protection on the entire queue mapping to `prot` (PROT_* bits),
// dying with a descriptive message if mprotect fails.
void GlobalState::ShmProtectOrDie(int prot) {
  PCHECK(mprotect(lockless_queue_memory, lockless_queue_memory_size, prot) !=
         -1)
      << ": mprotect(" << lockless_queue_memory << ", "
      << lockless_queue_memory_size << ", 0x" << std::hex << prot << ") failed";
}
211
// Installs the SIGSEGV/SIGTRAP handlers used to intercept shared-memory
// writes, and hooks the futex observers for mutex operations.
void GlobalState::RegisterSegvAndTrapHandlers() {
  InstallHandler(SIGSEGV, segv_handler, &old_segv_handler);
  InstallHandler(SIGTRAP, trap_handler, &old_trap_handler);
  // Unrelated SIGSEGVs are chained to the old handler, but we assume nobody
  // else was handling SIGTRAP; verify we did not clobber a real handler.
  CHECK_EQ(old_trap_handler.sa_handler, SIG_DFL);
  linux_code::ipc_lib::SetShmAccessorObservers(futex_before, futex_after);
}
218
219// gtest only allows creating fatal failures in functions returning void...
220// status is from wait(2).
221void DetectFatalFailures(int status) {
222 if (WIFEXITED(status)) {
223 FAIL() << " child returned status "
224 << ::std::to_string(WEXITSTATUS(status));
225 } else if (WIFSIGNALED(status)) {
226 FAIL() << " child exited because of signal "
227 << aos_strsignal(WTERMSIG(status));
228 } else {
229 FAIL() << " child exited with status " << ::std::hex << status;
230 }
231}
232
// Returns true if it runs all the way through.
//
// Forks a child that runs `function` against the shared queue memory under
// ptrace, with the memory PROT_READ so every write SEGVs. The parent
// implements a SEGV -> restore-registers -> single-step -> SIGTRAP protocol
// with the child's signal handlers (segv_handler / trap_handler) so each
// write is observed exactly once. Returns false when the child _exit()ed
// early at the configured death point (die_at); sets *test_failure (if
// non-null) and reports via gtest on any other abnormal exit.
bool RunFunctionDieAt(::std::function<void(void *)> prepare,
                      ::std::function<void(void *)> function,
                      bool *test_failure, size_t die_at,
                      uintptr_t writable_offset, const WritesArray *writes_in,
                      WritesArray *writes_out) {
  GlobalState *my_global_state = GlobalState::Get();
  my_global_state->writes_in = writes_in;
  my_global_state->writes_out = writes_out;
  my_global_state->die_at = die_at;
  my_global_state->current_location = 0;
  my_global_state->state = DieAtState::kDisabled;

  const pid_t pid = fork();
  PCHECK(pid != -1) << ": fork() failed";
  if (pid == 0) {
    // Run the test.
    ::aos::testing::PreventExit();

    prepare(my_global_state->lockless_queue_memory);

    // Update the robust list offset.
    linux_code::ipc_lib::SetRobustListOffset(writable_offset);
    // Install a segv handler (to detect writes to the memory block), and a trap
    // handler so we can single step.
    my_global_state->RegisterSegvAndTrapHandlers();

    PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0);
    my_global_state->ShmProtectOrDie(PROT_READ);
    my_global_state->state = DieAtState::kRunning;

    function(my_global_state->lockless_queue_memory);
    my_global_state->state = DieAtState::kDisabled;
    my_global_state->ShmProtectOrDie(PROT_READ | PROT_WRITE);
    _exit(0);
  } else {
    // Annoying wrapper type because elf_gregset_t is an array, which C++
    // handles poorly.
    struct RestoreState {
      RestoreState(elf_gregset_t regs_in) {
        memcpy(regs, regs_in, sizeof(regs));
      }
      elf_gregset_t regs;
    };
    // Registers saved at SEGV time; engaged exactly while we are waiting to
    // single-step the faulting instruction.
    std::optional<RestoreState> restore_regs;
    // True while the SIGTRAP we kill()ed at the child is still in flight.
    bool pass_trap = false;
    // Wait until the child process dies.
    while (true) {
      int status;
      pid_t waited_on = waitpid(pid, &status, 0);
      if (waited_on == -1) {
        if (errno == EINTR) continue;
        PCHECK(false) << ": waitpid(" << pid << ", " << &status
                      << ", 0) failed";
      }
      CHECK_EQ(waited_on, pid)
          << ": waitpid got child " << waited_on << " instead of " << pid;
      if (WIFSTOPPED(status)) {
        // The child was stopped via ptrace.
        const int stop_signal = WSTOPSIG(status);
        elf_gregset_t regs;
        {
          struct iovec iov;
          iov.iov_base = &regs;
          iov.iov_len = sizeof(regs);
          PCHECK(ptrace(PTRACE_GETREGSET, pid, NT_PRSTATUS, &iov) == 0);
          CHECK_EQ(iov.iov_len, sizeof(regs))
              << ": ptrace regset is the wrong size";
        }
        if (stop_signal == SIGSEGV) {
          // It's a SEGV, hopefully due to writing to the shared memory which is
          // marked read-only. We record the instruction that faulted so we can
          // look for it while single-stepping, then deliver the signal so the
          // child can mark it read-write and then poke us to single-step that
          // instruction.

          CHECK(!restore_regs)
              << ": Traced child got a SEGV while single-stepping";
          // Save all the registers to resume execution at the current location
          // in the child.
          restore_regs = RestoreState(regs);
          PCHECK(ptrace(PTRACE_CONT, pid, nullptr, SIGSEGV) == 0);
          continue;
        }
        if (stop_signal == SIGTRAP) {
          if (pass_trap) {
            // This is the new SIGTRAP we generated, which we just want to pass
            // through so the child's signal handler can restore the memory to
            // read-only
            PCHECK(ptrace(PTRACE_CONT, pid, nullptr, SIGTRAP) == 0);
            pass_trap = false;
            continue;
          }
          if (restore_regs) {
            // Restore the state we saved before delivering the SEGV, and then
            // single-step that one instruction.
            struct iovec iov;
            iov.iov_base = &restore_regs->regs;
            iov.iov_len = sizeof(restore_regs->regs);
            PCHECK(ptrace(PTRACE_SETREGSET, pid, NT_PRSTATUS, &iov) == 0);
            restore_regs = std::nullopt;
            PCHECK(ptrace(PTRACE_SINGLESTEP, pid, nullptr, nullptr) == 0);
            continue;
          }
          // We executed the single instruction that originally faulted, so
          // now deliver a SIGTRAP to the child so it can mark the memory
          // read-only again.
          pass_trap = true;
          PCHECK(kill(pid, SIGTRAP) == 0);
          PCHECK(ptrace(PTRACE_CONT, pid, nullptr, nullptr) == 0);
          continue;
        }
        LOG(FATAL) << "Traced child was stopped with unexpected signal: "
                   << static_cast<int>(WSTOPSIG(status));
      }
      if (WIFEXITED(status)) {
        if (WEXITSTATUS(status) == 0) return true;
        if (WEXITSTATUS(status) == kExitEarlyValue) return false;
      }
      // Any other termination (signal, unexpected exit code) is a test
      // failure worth reporting.
      DetectFatalFailures(status);
      if (test_failure) *test_failure = true;
      return false;
    }
  }
}
358
// Sets up the shared queue memory (plus the robust-list backup mapping) for
// `config`, then runs one RunFunctionDieAt iteration and, if the child made it
// all the way through, runs `check` against the resulting memory. Returns the
// RunFunctionDieAt result.
bool RunFunctionDieAtAndCheck(const LocklessQueueConfiguration &config,
                              ::std::function<void(void *)> prepare,
                              ::std::function<void(void *)> function,
                              ::std::function<void(void *)> check,
                              bool *test_failure, size_t die_at,
                              const WritesArray *writes_in,
                              WritesArray *writes_out) {
  // Allocate shared memory.
  GlobalState *my_global_state = GlobalState::Get();
  my_global_state->lockless_queue_memory_size = LocklessQueueMemorySize(config);
  my_global_state->lockless_queue_memory = static_cast<void *>(
      mmap(nullptr, my_global_state->lockless_queue_memory_size,
           PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0));
  CHECK_NE(MAP_FAILED, my_global_state->lockless_queue_memory);

  // And the backup used to point the robust list at.
  my_global_state->lockless_queue_memory_lock_backup = static_cast<void *>(
      mmap(nullptr, my_global_state->lockless_queue_memory_size,
           PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0));
  CHECK_NE(MAP_FAILED, my_global_state->lockless_queue_memory_lock_backup);

  // The writable offset tells us how to convert from a pointer in the queue to
  // a pointer that is safe to write. This is so robust futexes don't spin the
  // kernel when the user maps a page PROT_READ, and the kernel tries to clear
  // the futex there.
  const uintptr_t writable_offset =
      reinterpret_cast<uintptr_t>(
          my_global_state->lockless_queue_memory_lock_backup) -
      reinterpret_cast<uintptr_t>(my_global_state->lockless_queue_memory);

  bool r;
  // Do the actual test in a new thread so any locked mutexes will be cleaned up
  // nicely with owner-died at the end.
  ::std::thread test_thread([&prepare, &function, &check, test_failure, die_at,
                             writes_in, writes_out, writable_offset, &r]() {
    r = RunFunctionDieAt(prepare, function, test_failure, die_at,
                         writable_offset, writes_in, writes_out);
    // Skip check() if the run itself already produced a gtest failure.
    if (::testing::Test::HasFailure()) {
      r = false;
      if (test_failure) *test_failure = true;
      return;
    }

    check(GlobalState::Get()->lockless_queue_memory);
  });
  // r is safe to read after the join; the thread wrote it before exiting.
  test_thread.join();
  return r;
}
407
// Tests function to make sure it handles dying after each store it makes to
// shared memory. check should make sure function behaved correctly.
// This will repeatedly create a new TestSharedMemory, run prepare, run
// function, and then
// run check, killing the process function is running in at various points. It
// will stop if anything reports a fatal gtest failure.
void TestShmRobustness(const LocklessQueueConfiguration &config,
                       ::std::function<void(void *)> prepare,
                       ::std::function<void(void *)> function,
                       ::std::function<void(void *)> check) {
  // my_global_state is otherwise unused here; MakeGlobalState() also
  // publishes it through the global_state atomic for everyone else.
  auto [my_global_state, expected_writes] = GlobalState::MakeGlobalState();

  // First pass: die_at == 0 means "never die" (see HandleWrite), so this run
  // records the complete write sequence into expected_writes.
  bool test_failed = false;
  ASSERT_TRUE(RunFunctionDieAtAndCheck(config, prepare, function, check,
                                       &test_failed, 0, nullptr,
                                       expected_writes));
  if (test_failed) {
    ADD_FAILURE();
    return;
  }

  // Replay, dying at each successive write, until a run survives all the way
  // through (RunFunctionDieAtAndCheck returns true) or gtest records a
  // failure.
  size_t die_at = 1;
  while (true) {
    SCOPED_TRACE("dying at " + ::std::to_string(die_at) + "/" +
                 ::std::to_string(expected_writes->size()));
    if (RunFunctionDieAtAndCheck(config, prepare, function, check, &test_failed,
                                 die_at, expected_writes, nullptr)) {
      LOG(INFO) << "Tested " << die_at << " death points";
      return;
    }
    if (test_failed) {
      ADD_FAILURE();
    }
    if (::testing::Test::HasFailure()) return;
    ++die_at;
  }
}
445
SharedTid::SharedTid() {
  // Capture the tid in the child so we can tell if it died. Use mmap so it
  // works across the process boundary.
  tid_ =
      static_cast<pid_t *>(mmap(nullptr, sizeof(pid_t), PROT_READ | PROT_WRITE,
                                MAP_SHARED | MAP_ANONYMOUS, -1, 0));
  CHECK_NE(MAP_FAILED, tid_);
}

// Unmaps the shared tid slot.
SharedTid::~SharedTid() { CHECK_EQ(munmap(tid_, sizeof(pid_t)), 0); }

// Records the calling thread's tid into the shared mapping.
void SharedTid::Set() { *tid_ = gettid(); }

// Returns the tid last stored with Set() (possibly by another process).
pid_t SharedTid::Get() { return *tid_; }
460
461bool PretendOwnerDied(aos_mutex *mutex, pid_t tid) {
462 if ((mutex->futex & FUTEX_TID_MASK) == tid) {
463 mutex->futex = FUTEX_OWNER_DIED;
464 return true;
465 }
466 return false;
467}
468
469} // namespace testing
470} // namespace ipc_lib
471} // namespace aos
472
#endif  // SUPPORTS_SHM_ROBUSTNESS_TEST