blob: 8608ad577a830932d06512d46550414370678ad9 [file] [log] [blame]
Austin Schuhfaec51a2023-09-08 17:43:32 -07001#include "aos/ipc_lib/lockless_queue_stepping.h"
2
Stephan Pleines682928d2024-05-31 20:43:48 -07003#include <assert.h>
Austin Schuhfaec51a2023-09-08 17:43:32 -07004#include <elf.h>
Stephan Pleines682928d2024-05-31 20:43:48 -07005#include <errno.h>
6#include <signal.h>
7#include <string.h>
Austin Schuhfaec51a2023-09-08 17:43:32 -07008#include <sys/mman.h>
9#include <sys/procfs.h>
10#include <sys/ptrace.h>
11#include <sys/syscall.h>
12#include <sys/uio.h>
Stephan Pleines682928d2024-05-31 20:43:48 -070013#include <sys/wait.h>
Austin Schuhfaec51a2023-09-08 17:43:32 -070014#include <unistd.h>
Austin Schuhfaec51a2023-09-08 17:43:32 -070015
Stephan Pleines682928d2024-05-31 20:43:48 -070016#include <atomic>
17#include <new>
18#include <optional>
19#include <ostream>
20#include <string>
Austin Schuhfaec51a2023-09-08 17:43:32 -070021#include <thread>
Stephan Pleines682928d2024-05-31 20:43:48 -070022#include <utility>
Austin Schuhfaec51a2023-09-08 17:43:32 -070023
Austin Schuh99f7c6a2024-06-25 22:07:44 -070024#include "absl/log/check.h"
25#include "absl/log/log.h"
Austin Schuhfaec51a2023-09-08 17:43:32 -070026#include "gtest/gtest.h"
27
28#include "aos/ipc_lib/aos_sync.h"
Austin Schuhfaec51a2023-09-08 17:43:32 -070029#include "aos/ipc_lib/shm_observers.h"
30#include "aos/libc/aos_strsignal.h"
31#include "aos/testing/prevent_exit.h"
32
33#ifdef SUPPORTS_SHM_ROBUSTNESS_TEST
34
Stephan Pleinesd99b1ee2024-02-02 20:56:44 -080035namespace aos::ipc_lib::testing {
Austin Schuhfaec51a2023-09-08 17:43:32 -070036
37namespace {
// Returns the calling thread's kernel thread id.  Older glibc does not wrap
// gettid(2), so issue the raw syscall directly.
pid_t gettid() { return static_cast<pid_t>(syscall(SYS_gettid)); }
39
// Pointer to the process-shared GlobalState, published by
// GlobalState::MakeGlobalState() and read by the signal handlers and futex
// observers via GlobalState::Get().
::std::atomic<GlobalState *> global_state;

// The signal actions that were installed before ours, saved so segv_handler
// can chain to a pre-existing SIGSEGV handler.
struct sigaction old_segv_handler, old_trap_handler;
43
// Forwards a signal to a previously-installed handler, if one exists.
// Returns false when the saved action is SIG_IGN/SIG_DFL (i.e. there is no
// real handler to chain to), true after invoking the old handler.
bool CallChainedAction(const struct sigaction &action, int signal,
                       siginfo_t *siginfo, void *context) {
  const bool has_handler =
      action.sa_handler != SIG_IGN && action.sa_handler != SIG_DFL;
  if (!has_handler) {
    return false;
  }
  // SA_SIGINFO selects which member of the handler union is active.
  if ((action.sa_flags & SA_SIGINFO) != 0) {
    action.sa_sigaction(signal, siginfo, context);
  } else {
    action.sa_handler(signal);
  }
  return true;
}
57
// SIGSEGV handler for the traced child.  A fault inside the lockless queue
// block means the child tried to write to memory we marked PROT_READ: record
// the write, make the block writable, and hit a breakpoint so the ptrace-ing
// parent regains control before the faulting instruction re-executes.
void segv_handler(int signal, siginfo_t *siginfo, void *context_void) {
  GlobalState *my_global_state = GlobalState::Get();
  // Preserve errno so the interrupted code never observes a change.
  const int saved_errno = errno;
  SIMPLE_ASSERT(signal == SIGSEGV, "wrong signal for SIGSEGV handler");

  // Only process memory addresses in our shared memory block.
  if (!my_global_state->IsInLocklessQueueMemory(siginfo->si_addr)) {
    // Not our fault: chain to whatever SIGSEGV handler was installed before
    // us, and abort if there wasn't one (a genuine crash).
    if (CallChainedAction(old_segv_handler, signal, siginfo, context_void)) {
      errno = saved_errno;
      return;
    } else {
      SIMPLE_ASSERT(false, "actual SIGSEGV");
    }
  }
  SIMPLE_ASSERT(my_global_state->state == DieAtState::kRunning,
                "bad state for SIGSEGV");

  // Record this write (and possibly _exit() here if this is the requested
  // death point).
  my_global_state->HandleWrite(siginfo->si_addr);

  // Let the faulting store proceed once we resume.
  my_global_state->ShmProtectOrDie(PROT_READ | PROT_WRITE);
  my_global_state->state = DieAtState::kWriting;
  errno = saved_errno;

  // Trap into the tracing parent so it can single-step the instruction that
  // faulted.
#if defined(__x86_64__)
  __asm__ __volatile__("int $3" ::: "memory", "cc");
#elif defined(__aarch64__)
  __asm__ __volatile__("brk #0" ::: "memory", "cc");
#else
#error Unhandled architecture
#endif
}
89
90// The SEGV handler has set a breakpoint 1 instruction in the future. This
91// clears it, marks memory readonly, and continues.
92void trap_handler(int signal, siginfo_t *, void * /*context*/) {
93 GlobalState *my_global_state = GlobalState::Get();
94 const int saved_errno = errno;
95 SIMPLE_ASSERT(signal == SIGTRAP, "wrong signal for SIGTRAP handler");
96
97 my_global_state->state = DieAtState::kWriting;
98 SIMPLE_ASSERT(my_global_state->state == DieAtState::kWriting,
99 "bad state for SIGTRAP");
100 my_global_state->ShmProtectOrDie(PROT_READ);
101 my_global_state->state = DieAtState::kRunning;
102 errno = saved_errno;
103}
104
// Installs `handler` as the sigaction for `signal`, saving the previously
// installed action into *old_action so it can be chained later.
void InstallHandler(int signal, void (*handler)(int, siginfo_t *, void *),
                    struct sigaction *old_action) {
  struct sigaction action;
  memset(&action, 0, sizeof(action));
  action.sa_sigaction = handler;
  // We don't do a full normal signal handler exit with ptrace, so SA_NODEFER is
  // necessary to keep our signal handler active.
  action.sa_flags = SA_RESTART | SA_SIGINFO | SA_NODEFER;
#ifdef AOS_SANITIZER_thread
  // Tsan messes with signal handlers to check for race conditions, and it
  // causes problems, so we have to work around it for SIGTRAP.
  if (signal == SIGTRAP) {
    typedef int (*SigactionType)(int, const struct sigaction *,
                                 struct sigaction *);
    // Look up the real sigaction() underneath tsan's interceptor.
    SigactionType real_sigaction =
        reinterpret_cast<SigactionType>(dlsym(RTLD_NEXT, "sigaction"));
    if (sigaction == real_sigaction) {
      // dlsym gave us back the interceptor; the workaround didn't take.
      LOG(WARNING) << "failed to work around tsan signal handling weirdness";
    }
    PCHECK(real_sigaction(signal, &action, old_action) == 0);
    return;
  }
#endif
  PCHECK(sigaction(signal, &action, old_action) == 0);
}
131
132// A mutex lock is about to happen. Mark the memory rw, and check to see if we
133// should die.
134void futex_before(void *address, bool) {
135 GlobalState *my_global_state = GlobalState::Get();
136 if (my_global_state->IsInLocklessQueueMemory(address)) {
137 assert(my_global_state->state == DieAtState::kRunning);
138 my_global_state->HandleWrite(address);
139 my_global_state->ShmProtectOrDie(PROT_READ | PROT_WRITE);
140 my_global_state->state = DieAtState::kWriting;
141 }
142}
143
144// We have a manual trap for mutexes. Check to see if we were supposed to die
145// on this write (the compare/exchange for the mutex), and mark the memory ro
146// again.
147void futex_after(void *address, bool) {
148 GlobalState *my_global_state = GlobalState::Get();
149 if (my_global_state->IsInLocklessQueueMemory(address)) {
150 assert(my_global_state->state == DieAtState::kWriting);
151 my_global_state->ShmProtectOrDie(PROT_READ);
152 my_global_state->state = DieAtState::kRunning;
153 }
154}
155
156} // namespace
157
158void GlobalState::HandleWrite(void *address) {
159 uintptr_t address_offset = reinterpret_cast<uintptr_t>(address) -
160 reinterpret_cast<uintptr_t>(lockless_queue_memory);
161 if (writes_in != nullptr) {
162 SIMPLE_ASSERT(writes_in->At(current_location) == address_offset,
163 "wrong write order");
164 }
165 if (writes_out != nullptr) {
166 writes_out->Add(address_offset);
167 }
168 if (die_at != 0) {
169 if (die_at == current_location) {
170 _exit(kExitEarlyValue);
171 }
172 }
173 ++current_location;
174}
175
// Returns the process-wide GlobalState pointer published by MakeGlobalState().
GlobalState *GlobalState::Get() {
  return global_state.load(::std::memory_order_relaxed);
}
179
180std::tuple<GlobalState *, WritesArray *> GlobalState::MakeGlobalState() {
181 // Map the global state and memory for the Writes array so it exists across
182 // the process boundary.
183 void *shared_allocations = static_cast<GlobalState *>(
184 mmap(nullptr, sizeof(GlobalState) + sizeof(WritesArray),
185 PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0));
186 CHECK_NE(MAP_FAILED, shared_allocations);
187
188 global_state.store(static_cast<GlobalState *>(shared_allocations));
189 void *expected_writes_shared_allocations = static_cast<void *>(
190 static_cast<uint8_t *>(shared_allocations) + sizeof(GlobalState));
191 WritesArray *expected_writes =
192 static_cast<WritesArray *>(expected_writes_shared_allocations);
193 new (expected_writes) WritesArray();
194 return std::make_pair(static_cast<GlobalState *>(shared_allocations),
195 expected_writes);
196}
197
198bool GlobalState::IsInLocklessQueueMemory(void *address) {
199 void *read_lockless_queue_memory = lockless_queue_memory;
200 if (address < read_lockless_queue_memory) {
201 return false;
202 if (reinterpret_cast<uintptr_t>(address) >
203 reinterpret_cast<uintptr_t>(read_lockless_queue_memory) +
204 lockless_queue_memory_size)
205 return false;
206 }
207 return true;
208}
209
210void GlobalState::ShmProtectOrDie(int prot) {
211 PCHECK(mprotect(lockless_queue_memory, lockless_queue_memory_size, prot) !=
212 -1)
213 << ": mprotect(" << lockless_queue_memory << ", "
214 << lockless_queue_memory_size << ", 0x" << std::hex << prot << ") failed";
215}
216
// Wires up everything needed to trace writes to the shared memory block: the
// SIGSEGV handler (detects plain writes), the SIGTRAP handler (re-protects
// after single-stepping), and the futex observers (mutex writes bypass the
// SEGV path, so they need explicit before/after callbacks).
void GlobalState::RegisterSegvAndTrapHandlers() {
  InstallHandler(SIGSEGV, segv_handler, &old_segv_handler);
  InstallHandler(SIGTRAP, trap_handler, &old_trap_handler);
  // We can chain a pre-existing SIGSEGV handler, but not a SIGTRAP one, so
  // require that nothing else installed one.
  CHECK_EQ(old_trap_handler.sa_handler, SIG_DFL);
  linux_code::ipc_lib::SetShmAccessorObservers(futex_before, futex_after);
}
223
// gtest only allows creating fatal failures in functions returning void...
// status is from wait(2).
// Reports how the child died (exit code, signal, or raw status) as a gtest
// fatal failure in the calling test.
void DetectFatalFailures(int status) {
  if (WIFEXITED(status)) {
    FAIL() << " child returned status "
           << ::std::to_string(WEXITSTATUS(status));
  } else if (WIFSIGNALED(status)) {
    FAIL() << " child exited because of signal "
           << aos_strsignal(WTERMSIG(status));
  } else {
    // Neither a normal exit nor a signal death; dump the raw status.
    FAIL() << " child exited with status " << ::std::hex << status;
  }
}
237
// Returns true if it runs all the way through.
//
// Forks a child that runs `function` against the (read-protected) shared
// memory while the parent ptrace()s it.  Each write the child makes SEGVs;
// the parent restores the child's pre-fault registers, single-steps the
// faulting instruction, and pokes the child with SIGTRAP so its handler can
// re-protect the memory.  HandleWrite() in the child _exit()s with
// kExitEarlyValue when write number `die_at` is reached.
bool RunFunctionDieAt(::std::function<void(void *)> prepare,
                      ::std::function<void(void *)> function,
                      bool *test_failure, size_t die_at,
                      uintptr_t writable_offset, const WritesArray *writes_in,
                      WritesArray *writes_out) {
  GlobalState *my_global_state = GlobalState::Get();
  // Configure the shared state before forking so the child sees it.
  my_global_state->writes_in = writes_in;
  my_global_state->writes_out = writes_out;
  my_global_state->die_at = die_at;
  my_global_state->current_location = 0;
  my_global_state->state = DieAtState::kDisabled;

  const pid_t pid = fork();
  PCHECK(pid != -1) << ": fork() failed";
  if (pid == 0) {
    // Run the test.
    ::aos::testing::PreventExit();

    prepare(my_global_state->lockless_queue_memory);

    // Update the robust list offset.
    linux_code::ipc_lib::SetRobustListOffset(writable_offset);
    // Install a segv handler (to detect writes to the memory block), and a trap
    // handler so we can single step.
    my_global_state->RegisterSegvAndTrapHandlers();

    PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0);
    my_global_state->ShmProtectOrDie(PROT_READ);
    my_global_state->state = DieAtState::kRunning;

    function(my_global_state->lockless_queue_memory);
    // Made it all the way through without dying; exit status 0 signals that
    // to the parent.
    my_global_state->state = DieAtState::kDisabled;
    my_global_state->ShmProtectOrDie(PROT_READ | PROT_WRITE);
    _exit(0);
  } else {
    // Annoying wrapper type because elf_gregset_t is an array, which C++
    // handles poorly.
    struct RestoreState {
      RestoreState(elf_gregset_t regs_in) {
        memcpy(regs, regs_in, sizeof(regs));
      }
      elf_gregset_t regs;
    };
    // Set while we are between "child SEGVed" and "child trapped back into
    // us"; holds the registers to restore before single-stepping.
    std::optional<RestoreState> restore_regs;
    // Set when the next SIGTRAP stop should simply be delivered to the child
    // (so its handler can re-protect the memory).
    bool pass_trap = false;
    // Wait until the child process dies.
    while (true) {
      int status;
      pid_t waited_on = waitpid(pid, &status, 0);
      if (waited_on == -1) {
        if (errno == EINTR) continue;
        PCHECK(false) << ": waitpid(" << pid << ", " << &status
                      << ", 0) failed";
      }
      CHECK_EQ(waited_on, pid)
          << ": waitpid got child " << waited_on << " instead of " << pid;
      if (WIFSTOPPED(status)) {
        // The child was stopped via ptrace.
        const int stop_signal = WSTOPSIG(status);
        elf_gregset_t regs;
        {
          // Snapshot the child's registers at this stop.
          struct iovec iov;
          iov.iov_base = &regs;
          iov.iov_len = sizeof(regs);
          PCHECK(ptrace(PTRACE_GETREGSET, pid, NT_PRSTATUS, &iov) == 0);
          CHECK_EQ(iov.iov_len, sizeof(regs))
              << ": ptrace regset is the wrong size";
        }
        if (stop_signal == SIGSEGV) {
          // It's a SEGV, hopefully due to writing to the shared memory which is
          // marked read-only. We record the instruction that faulted so we can
          // look for it while single-stepping, then deliver the signal so the
          // child can mark it read-write and then poke us to single-step that
          // instruction.

          CHECK(!restore_regs)
              << ": Traced child got a SEGV while single-stepping";
          // Save all the registers to resume execution at the current location
          // in the child.
          restore_regs = RestoreState(regs);
          PCHECK(ptrace(PTRACE_CONT, pid, nullptr, SIGSEGV) == 0);
          continue;
        }
        if (stop_signal == SIGTRAP) {
          if (pass_trap) {
            // This is the new SIGTRAP we generated, which we just want to pass
            // through so the child's signal handler can restore the memory to
            // read-only
            PCHECK(ptrace(PTRACE_CONT, pid, nullptr, SIGTRAP) == 0);
            pass_trap = false;
            continue;
          }
          if (restore_regs) {
            // Restore the state we saved before delivering the SEGV, and then
            // single-step that one instruction.
            struct iovec iov;
            iov.iov_base = &restore_regs->regs;
            iov.iov_len = sizeof(restore_regs->regs);
            PCHECK(ptrace(PTRACE_SETREGSET, pid, NT_PRSTATUS, &iov) == 0);
            restore_regs = std::nullopt;
            PCHECK(ptrace(PTRACE_SINGLESTEP, pid, nullptr, nullptr) == 0);
            continue;
          }
          // We executed the single instruction that originally faulted, so
          // now deliver a SIGTRAP to the child so it can mark the memory
          // read-only again.
          pass_trap = true;
          PCHECK(kill(pid, SIGTRAP) == 0);
          PCHECK(ptrace(PTRACE_CONT, pid, nullptr, nullptr) == 0);
          continue;
        }
        LOG(FATAL) << "Traced child was stopped with unexpected signal: "
                   << static_cast<int>(WSTOPSIG(status));
      }
      if (WIFEXITED(status)) {
        // Exit 0 means function() ran to completion; kExitEarlyValue means
        // HandleWrite() killed the child at the requested die_at point.
        if (WEXITSTATUS(status) == 0) return true;
        if (WEXITSTATUS(status) == kExitEarlyValue) return false;
      }
      // Any other way of dying is a test failure.
      DetectFatalFailures(status);
      if (test_failure) *test_failure = true;
      return false;
    }
  }
}
363
// Maps the lockless queue memory (plus a writable backup mapping used for the
// robust list), runs RunFunctionDieAt() on a fresh thread, and then runs
// check() on the resulting memory if the run didn't fail.
bool RunFunctionDieAtAndCheck(const LocklessQueueConfiguration &config,
                              ::std::function<void(void *)> prepare,
                              ::std::function<void(void *)> function,
                              ::std::function<void(void *)> check,
                              bool *test_failure, size_t die_at,
                              const WritesArray *writes_in,
                              WritesArray *writes_out) {
  // Allocate shared memory.
  GlobalState *my_global_state = GlobalState::Get();
  my_global_state->lockless_queue_memory_size = LocklessQueueMemorySize(config);
  my_global_state->lockless_queue_memory = static_cast<void *>(
      mmap(nullptr, my_global_state->lockless_queue_memory_size,
           PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0));
  CHECK_NE(MAP_FAILED, my_global_state->lockless_queue_memory);

  // And the backup used to point the robust list at.
  my_global_state->lockless_queue_memory_lock_backup = static_cast<void *>(
      mmap(nullptr, my_global_state->lockless_queue_memory_size,
           PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0));
  CHECK_NE(MAP_FAILED, my_global_state->lockless_queue_memory_lock_backup);

  // The writable offset tells us how to convert from a pointer in the queue to
  // a pointer that is safe to write. This is so robust futexes don't spin the
  // kernel when the user maps a page PROT_READ, and the kernel tries to clear
  // the futex there.
  const uintptr_t writable_offset =
      reinterpret_cast<uintptr_t>(
          my_global_state->lockless_queue_memory_lock_backup) -
      reinterpret_cast<uintptr_t>(my_global_state->lockless_queue_memory);

  bool r;
  // Do the actual test in a new thread so any locked mutexes will be cleaned up
  // nicely with owner-died at the end.
  ::std::thread test_thread([&prepare, &function, &check, test_failure, die_at,
                             writes_in, writes_out, writable_offset, &r]() {
    r = RunFunctionDieAt(prepare, function, test_failure, die_at,
                         writable_offset, writes_in, writes_out);
    if (::testing::Test::HasFailure()) {
      r = false;
      if (test_failure) *test_failure = true;
      return;
    }

    // The run itself was clean; verify the queue memory is in the state
    // check() expects.
    check(GlobalState::Get()->lockless_queue_memory);
  });
  test_thread.join();
  return r;
}
412
// Tests function to make sure it handles dying after each store it makes to
// shared memory. check should make sure function behaved correctly.
// This will repeatedly create a new TestSharedMemory, run prepare, run
// function, and then
// run check, killing the process function is running in at various points. It
// will stop if anything reports a fatal gtest failure.
void TestShmRobustness(const LocklessQueueConfiguration &config,
                       ::std::function<void(void *)> prepare,
                       ::std::function<void(void *)> function,
                       ::std::function<void(void *)> check) {
  auto [my_global_state, expected_writes] = GlobalState::MakeGlobalState();

  bool test_failed = false;
  // First pass: die_at == 0 means "don't die", and records the full sequence
  // of writes into expected_writes for the later passes to replay against.
  ASSERT_TRUE(RunFunctionDieAtAndCheck(config, prepare, function, check,
                                       &test_failed, 0, nullptr,
                                       expected_writes));
  if (test_failed) {
    ADD_FAILURE();
    return;
  }

  // Now kill the child at each successive write until one run makes it all
  // the way through (RunFunctionDieAtAndCheck returns true).
  size_t die_at = 1;
  while (true) {
    SCOPED_TRACE("dying at " + ::std::to_string(die_at) + "/" +
                 ::std::to_string(expected_writes->size()));
    if (RunFunctionDieAtAndCheck(config, prepare, function, check, &test_failed,
                                 die_at, expected_writes, nullptr)) {
      LOG(INFO) << "Tested " << die_at << " death points";
      return;
    }
    if (test_failed) {
      ADD_FAILURE();
    }
    if (::testing::Test::HasFailure()) return;
    ++die_at;
  }
}
450
451SharedTid::SharedTid() {
452 // Capture the tid in the child so we can tell if it died. Use mmap so it
453 // works across the process boundary.
454 tid_ =
455 static_cast<pid_t *>(mmap(nullptr, sizeof(pid_t), PROT_READ | PROT_WRITE,
456 MAP_SHARED | MAP_ANONYMOUS, -1, 0));
457 CHECK_NE(MAP_FAILED, tid_);
458}
459
// Releases the shared tid mapping; munmap failure would indicate corruption.
SharedTid::~SharedTid() { CHECK_EQ(munmap(tid_, sizeof(pid_t)), 0); }
461
// Stores the calling thread's tid into the shared mapping.
void SharedTid::Set() { *tid_ = gettid(); }
463
// Reads the tid most recently Set(), possibly by another process.
pid_t SharedTid::Get() { return *tid_; }
465
Stephan Pleinesd99b1ee2024-02-02 20:56:44 -0800466} // namespace aos::ipc_lib::testing
Austin Schuhfaec51a2023-09-08 17:43:32 -0700467
#endif  // SUPPORTS_SHM_ROBUSTNESS_TEST