blob: a7bd233dad3d1286c397f8c35d28d88addd0a477 [file] [log] [blame]
Austin Schuhfaec51a2023-09-08 17:43:32 -07001#include "aos/ipc_lib/lockless_queue_stepping.h"
2
Stephan Pleines682928d2024-05-31 20:43:48 -07003#include <assert.h>
Austin Schuhfaec51a2023-09-08 17:43:32 -07004#include <elf.h>
Stephan Pleines682928d2024-05-31 20:43:48 -07005#include <errno.h>
6#include <signal.h>
7#include <string.h>
Austin Schuhfaec51a2023-09-08 17:43:32 -07008#include <sys/mman.h>
9#include <sys/procfs.h>
10#include <sys/ptrace.h>
11#include <sys/syscall.h>
12#include <sys/uio.h>
Stephan Pleines682928d2024-05-31 20:43:48 -070013#include <sys/wait.h>
Austin Schuhfaec51a2023-09-08 17:43:32 -070014#include <unistd.h>
Austin Schuhfaec51a2023-09-08 17:43:32 -070015
Stephan Pleines682928d2024-05-31 20:43:48 -070016#include <atomic>
17#include <new>
18#include <optional>
19#include <ostream>
20#include <string>
Austin Schuhfaec51a2023-09-08 17:43:32 -070021#include <thread>
Stephan Pleines682928d2024-05-31 20:43:48 -070022#include <utility>
Austin Schuhfaec51a2023-09-08 17:43:32 -070023
24#include "glog/logging.h"
25#include "gtest/gtest.h"
26
27#include "aos/ipc_lib/aos_sync.h"
Austin Schuhfaec51a2023-09-08 17:43:32 -070028#include "aos/ipc_lib/shm_observers.h"
29#include "aos/libc/aos_strsignal.h"
30#include "aos/testing/prevent_exit.h"
31
32#ifdef SUPPORTS_SHM_ROBUSTNESS_TEST
33
Stephan Pleinesd99b1ee2024-02-02 20:56:44 -080034namespace aos::ipc_lib::testing {
Austin Schuhfaec51a2023-09-08 17:43:32 -070035
36namespace {
37pid_t gettid() { return syscall(SYS_gettid); }
38
39::std::atomic<GlobalState *> global_state;
40
41struct sigaction old_segv_handler, old_trap_handler;
42
43// Calls the original signal handler.
44bool CallChainedAction(const struct sigaction &action, int signal,
45 siginfo_t *siginfo, void *context) {
46 if (action.sa_handler == SIG_IGN || action.sa_handler == SIG_DFL) {
47 return false;
48 }
49 if (action.sa_flags & SA_SIGINFO) {
50 action.sa_sigaction(signal, siginfo, context);
51 } else {
52 action.sa_handler(signal);
53 }
54 return true;
55}
56
// SIGSEGV handler for the traced child. A fault inside the PROT_READ queue
// mapping means the code under test is about to store to shared memory:
// record the write, make the memory writable, then raise a breakpoint so the
// tracing parent can rewind and single-step the faulting instruction.
void segv_handler(int signal, siginfo_t *siginfo, void *context_void) {
  GlobalState *my_global_state = GlobalState::Get();
  // Preserve errno across the handler so the interrupted code doesn't see a
  // stray value.
  const int saved_errno = errno;
  SIMPLE_ASSERT(signal == SIGSEGV, "wrong signal for SIGSEGV handler");

  // Only process memory addresses in our shared memory block.
  if (!my_global_state->IsInLocklessQueueMemory(siginfo->si_addr)) {
    // Not our fault: chain to whatever SIGSEGV handler was installed before
    // us, and treat it as a genuine crash if there wasn't one.
    if (CallChainedAction(old_segv_handler, signal, siginfo, context_void)) {
      errno = saved_errno;
      return;
    } else {
      SIMPLE_ASSERT(false, "actual SIGSEGV");
    }
  }
  SIMPLE_ASSERT(my_global_state->state == DieAtState::kRunning,
                "bad state for SIGSEGV");

  // Record this write (and _exit early if this is the designated death point).
  my_global_state->HandleWrite(siginfo->si_addr);

  my_global_state->ShmProtectOrDie(PROT_READ | PROT_WRITE);
  my_global_state->state = DieAtState::kWriting;
  errno = saved_errno;

  // Trap into the tracing parent so it can restore the pre-fault registers
  // and single-step the faulting store while the memory is writable.
#if defined(__x86_64__)
  __asm__ __volatile__("int $3" ::: "memory", "cc");
#elif defined(__aarch64__)
  __asm__ __volatile__("brk #0" ::: "memory", "cc");
#else
#error Unhandled architecture
#endif
}
88
// The SEGV handler has set a breakpoint 1 instruction in the future. This
// clears it, marks memory readonly, and continues.
void trap_handler(int signal, siginfo_t *, void * /*context*/) {
  GlobalState *my_global_state = GlobalState::Get();
  // Preserve errno across the handler.
  const int saved_errno = errno;
  SIMPLE_ASSERT(signal == SIGTRAP, "wrong signal for SIGTRAP handler");

  my_global_state->state = DieAtState::kWriting;
  // NOTE(review): this assert is vacuous -- |state| was assigned kWriting on
  // the line above. It presumably was meant to run *before* the assignment;
  // confirm intent before reordering.
  SIMPLE_ASSERT(my_global_state->state == DieAtState::kWriting,
                "bad state for SIGTRAP");
  // The single-stepped store is done; make the queue read-only again so the
  // next store faults too.
  my_global_state->ShmProtectOrDie(PROT_READ);
  my_global_state->state = DieAtState::kRunning;
  errno = saved_errno;
}
103
// Installs the signal handler.
// |handler| is installed for |signal| with SA_SIGINFO semantics; the previous
// disposition is stored through |old_action| so it can be chained later.
void InstallHandler(int signal, void (*handler)(int, siginfo_t *, void *),
                    struct sigaction *old_action) {
  struct sigaction action;
  memset(&action, 0, sizeof(action));
  action.sa_sigaction = handler;
  // We don't do a full normal signal handler exit with ptrace, so SA_NODEFER is
  // necessary to keep our signal handler active.
  action.sa_flags = SA_RESTART | SA_SIGINFO | SA_NODEFER;
#ifdef AOS_SANITIZER_thread
  // Tsan messes with signal handlers to check for race conditions, and it
  // causes problems, so we have to work around it for SIGTRAP.
  if (signal == SIGTRAP) {
    typedef int (*SigactionType)(int, const struct sigaction *,
                                 struct sigaction *);
    // Look up the real sigaction(2) underneath tsan's interceptor.
    SigactionType real_sigaction =
        reinterpret_cast<SigactionType>(dlsym(RTLD_NEXT, "sigaction"));
    if (sigaction == real_sigaction) {
      LOG(WARNING) << "failed to work around tsan signal handling weirdness";
    }
    PCHECK(real_sigaction(signal, &action, old_action) == 0);
    return;
  }
#endif
  PCHECK(sigaction(signal, &action, old_action) == 0);
}
130
131// A mutex lock is about to happen. Mark the memory rw, and check to see if we
132// should die.
133void futex_before(void *address, bool) {
134 GlobalState *my_global_state = GlobalState::Get();
135 if (my_global_state->IsInLocklessQueueMemory(address)) {
136 assert(my_global_state->state == DieAtState::kRunning);
137 my_global_state->HandleWrite(address);
138 my_global_state->ShmProtectOrDie(PROT_READ | PROT_WRITE);
139 my_global_state->state = DieAtState::kWriting;
140 }
141}
142
143// We have a manual trap for mutexes. Check to see if we were supposed to die
144// on this write (the compare/exchange for the mutex), and mark the memory ro
145// again.
146void futex_after(void *address, bool) {
147 GlobalState *my_global_state = GlobalState::Get();
148 if (my_global_state->IsInLocklessQueueMemory(address)) {
149 assert(my_global_state->state == DieAtState::kWriting);
150 my_global_state->ShmProtectOrDie(PROT_READ);
151 my_global_state->state = DieAtState::kRunning;
152 }
153}
154
155} // namespace
156
157void GlobalState::HandleWrite(void *address) {
158 uintptr_t address_offset = reinterpret_cast<uintptr_t>(address) -
159 reinterpret_cast<uintptr_t>(lockless_queue_memory);
160 if (writes_in != nullptr) {
161 SIMPLE_ASSERT(writes_in->At(current_location) == address_offset,
162 "wrong write order");
163 }
164 if (writes_out != nullptr) {
165 writes_out->Add(address_offset);
166 }
167 if (die_at != 0) {
168 if (die_at == current_location) {
169 _exit(kExitEarlyValue);
170 }
171 }
172 ++current_location;
173}
174
// Returns the process-global GlobalState pointer (set by MakeGlobalState; the
// state itself lives in a MAP_SHARED mapping so it survives fork()).
GlobalState *GlobalState::Get() {
  return global_state.load(::std::memory_order_relaxed);
}
178
179std::tuple<GlobalState *, WritesArray *> GlobalState::MakeGlobalState() {
180 // Map the global state and memory for the Writes array so it exists across
181 // the process boundary.
182 void *shared_allocations = static_cast<GlobalState *>(
183 mmap(nullptr, sizeof(GlobalState) + sizeof(WritesArray),
184 PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0));
185 CHECK_NE(MAP_FAILED, shared_allocations);
186
187 global_state.store(static_cast<GlobalState *>(shared_allocations));
188 void *expected_writes_shared_allocations = static_cast<void *>(
189 static_cast<uint8_t *>(shared_allocations) + sizeof(GlobalState));
190 WritesArray *expected_writes =
191 static_cast<WritesArray *>(expected_writes_shared_allocations);
192 new (expected_writes) WritesArray();
193 return std::make_pair(static_cast<GlobalState *>(shared_allocations),
194 expected_writes);
195}
196
197bool GlobalState::IsInLocklessQueueMemory(void *address) {
198 void *read_lockless_queue_memory = lockless_queue_memory;
199 if (address < read_lockless_queue_memory) {
200 return false;
201 if (reinterpret_cast<uintptr_t>(address) >
202 reinterpret_cast<uintptr_t>(read_lockless_queue_memory) +
203 lockless_queue_memory_size)
204 return false;
205 }
206 return true;
207}
208
// Changes the protection on the whole lockless queue mapping to |prot|,
// aborting (with the mprotect arguments in the message) on failure.
void GlobalState::ShmProtectOrDie(int prot) {
  PCHECK(mprotect(lockless_queue_memory, lockless_queue_memory_size, prot) !=
         -1)
      << ": mprotect(" << lockless_queue_memory << ", "
      << lockless_queue_memory_size << ", 0x" << std::hex << prot << ") failed";
}
215
// Installs the SIGSEGV (write detection) and SIGTRAP (single-step resume)
// handlers, and hooks the futex observers so mutex stores are tracked too.
void GlobalState::RegisterSegvAndTrapHandlers() {
  InstallHandler(SIGSEGV, segv_handler, &old_segv_handler);
  InstallHandler(SIGTRAP, trap_handler, &old_trap_handler);
  // Unlike SIGSEGV, we never chain SIGTRAP, so require that nothing else had
  // a handler installed.
  CHECK_EQ(old_trap_handler.sa_handler, SIG_DFL);
  linux_code::ipc_lib::SetShmAccessorObservers(futex_before, futex_after);
}
222
// gtest only allows creating fatal failures in functions returning void...
// status is from wait(2).
// Note: FAIL() expands to a return statement, so each branch exits.
void DetectFatalFailures(int status) {
  if (WIFEXITED(status)) {
    FAIL() << " child returned status "
           << ::std::to_string(WEXITSTATUS(status));
  } else if (WIFSIGNALED(status)) {
    FAIL() << " child exited because of signal "
           << aos_strsignal(WTERMSIG(status));
  } else {
    // Neither a normal exit nor a signal death; dump the raw status in hex.
    FAIL() << " child exited with status " << ::std::hex << status;
  }
}
236
// Returns true if it runs all the way through.
//
// Forks a child that runs |function| over the queue memory while this (the
// parent) traces it with ptrace. Each store to the read-only-mapped memory
// SEGVs in the child; the child's SEGV handler records the write, makes the
// memory writable, and traps. The parent then rewinds the child to the
// faulting instruction, single-steps it, and pokes the child with SIGTRAP so
// it can re-protect the memory. |die_at| selects which recorded write (1-based
// index; 0 = none) the child _exit()s at, simulating a crash mid-write.
bool RunFunctionDieAt(::std::function<void(void *)> prepare,
                      ::std::function<void(void *)> function,
                      bool *test_failure, size_t die_at,
                      uintptr_t writable_offset, const WritesArray *writes_in,
                      WritesArray *writes_out) {
  GlobalState *my_global_state = GlobalState::Get();
  my_global_state->writes_in = writes_in;
  my_global_state->writes_out = writes_out;
  my_global_state->die_at = die_at;
  my_global_state->current_location = 0;
  my_global_state->state = DieAtState::kDisabled;

  const pid_t pid = fork();
  PCHECK(pid != -1) << ": fork() failed";
  if (pid == 0) {
    // Run the test.
    ::aos::testing::PreventExit();

    prepare(my_global_state->lockless_queue_memory);

    // Update the robust list offset.
    linux_code::ipc_lib::SetRobustListOffset(writable_offset);
    // Install a segv handler (to detect writes to the memory block), and a trap
    // handler so we can single step.
    my_global_state->RegisterSegvAndTrapHandlers();

    PCHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) == 0);
    my_global_state->ShmProtectOrDie(PROT_READ);
    my_global_state->state = DieAtState::kRunning;

    function(my_global_state->lockless_queue_memory);
    my_global_state->state = DieAtState::kDisabled;
    my_global_state->ShmProtectOrDie(PROT_READ | PROT_WRITE);
    _exit(0);
  } else {
    // Annoying wrapper type because elf_gregset_t is an array, which C++
    // handles poorly.
    struct RestoreState {
      RestoreState(elf_gregset_t regs_in) {
        memcpy(regs, regs_in, sizeof(regs));
      }
      elf_gregset_t regs;
    };
    std::optional<RestoreState> restore_regs;
    bool pass_trap = false;
    // Wait until the child process dies.
    while (true) {
      int status;
      pid_t waited_on = waitpid(pid, &status, 0);
      if (waited_on == -1) {
        if (errno == EINTR) continue;
        PCHECK(false) << ": waitpid(" << pid << ", " << &status
                      << ", 0) failed";
      }
      CHECK_EQ(waited_on, pid)
          << ": waitpid got child " << waited_on << " instead of " << pid;
      if (WIFSTOPPED(status)) {
        // The child was stopped via ptrace.
        const int stop_signal = WSTOPSIG(status);
        // Grab the child's general-purpose registers at the stop.
        elf_gregset_t regs;
        {
          struct iovec iov;
          iov.iov_base = &regs;
          iov.iov_len = sizeof(regs);
          PCHECK(ptrace(PTRACE_GETREGSET, pid, NT_PRSTATUS, &iov) == 0);
          CHECK_EQ(iov.iov_len, sizeof(regs))
              << ": ptrace regset is the wrong size";
        }
        if (stop_signal == SIGSEGV) {
          // It's a SEGV, hopefully due to writing to the shared memory which is
          // marked read-only. We record the instruction that faulted so we can
          // look for it while single-stepping, then deliver the signal so the
          // child can mark it read-write and then poke us to single-step that
          // instruction.

          CHECK(!restore_regs)
              << ": Traced child got a SEGV while single-stepping";
          // Save all the registers to resume execution at the current location
          // in the child.
          restore_regs = RestoreState(regs);
          PCHECK(ptrace(PTRACE_CONT, pid, nullptr, SIGSEGV) == 0);
          continue;
        }
        if (stop_signal == SIGTRAP) {
          if (pass_trap) {
            // This is the new SIGTRAP we generated, which we just want to pass
            // through so the child's signal handler can restore the memory to
            // read-only
            PCHECK(ptrace(PTRACE_CONT, pid, nullptr, SIGTRAP) == 0);
            pass_trap = false;
            continue;
          }
          if (restore_regs) {
            // Restore the state we saved before delivering the SEGV, and then
            // single-step that one instruction.
            struct iovec iov;
            iov.iov_base = &restore_regs->regs;
            iov.iov_len = sizeof(restore_regs->regs);
            PCHECK(ptrace(PTRACE_SETREGSET, pid, NT_PRSTATUS, &iov) == 0);
            restore_regs = std::nullopt;
            PCHECK(ptrace(PTRACE_SINGLESTEP, pid, nullptr, nullptr) == 0);
            continue;
          }
          // We executed the single instruction that originally faulted, so
          // now deliver a SIGTRAP to the child so it can mark the memory
          // read-only again.
          pass_trap = true;
          PCHECK(kill(pid, SIGTRAP) == 0);
          PCHECK(ptrace(PTRACE_CONT, pid, nullptr, nullptr) == 0);
          continue;
        }
        LOG(FATAL) << "Traced child was stopped with unexpected signal: "
                   << static_cast<int>(WSTOPSIG(status));
      }
      if (WIFEXITED(status)) {
        // Clean exit means the whole function ran; kExitEarlyValue means the
        // child died at the requested write, which is the expected "failure".
        if (WEXITSTATUS(status) == 0) return true;
        if (WEXITSTATUS(status) == kExitEarlyValue) return false;
      }
      // Anything else (signal death, unexpected exit code) is a test failure.
      DetectFatalFailures(status);
      if (test_failure) *test_failure = true;
      return false;
    }
  }
}
362
// Maps the queue memory (plus a writable backup for the robust list), runs
// RunFunctionDieAt() on a fresh thread, then runs |check| over the resulting
// memory if nothing failed. Returns RunFunctionDieAt()'s result (true when
// |function| ran to completion without dying).
bool RunFunctionDieAtAndCheck(const LocklessQueueConfiguration &config,
                              ::std::function<void(void *)> prepare,
                              ::std::function<void(void *)> function,
                              ::std::function<void(void *)> check,
                              bool *test_failure, size_t die_at,
                              const WritesArray *writes_in,
                              WritesArray *writes_out) {
  // Allocate shared memory.
  // NOTE(review): these mappings are never munmap'd, so each call leaks two
  // regions; presumably acceptable for test code -- confirm.
  GlobalState *my_global_state = GlobalState::Get();
  my_global_state->lockless_queue_memory_size = LocklessQueueMemorySize(config);
  my_global_state->lockless_queue_memory = static_cast<void *>(
      mmap(nullptr, my_global_state->lockless_queue_memory_size,
           PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0));
  CHECK_NE(MAP_FAILED, my_global_state->lockless_queue_memory);

  // And the backup used to point the robust list at.
  my_global_state->lockless_queue_memory_lock_backup = static_cast<void *>(
      mmap(nullptr, my_global_state->lockless_queue_memory_size,
           PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0));
  CHECK_NE(MAP_FAILED, my_global_state->lockless_queue_memory_lock_backup);

  // The writable offset tells us how to convert from a pointer in the queue to
  // a pointer that is safe to write. This is so robust futexes don't spin the
  // kernel when the user maps a page PROT_READ, and the kernel tries to clear
  // the futex there.
  const uintptr_t writable_offset =
      reinterpret_cast<uintptr_t>(
          my_global_state->lockless_queue_memory_lock_backup) -
      reinterpret_cast<uintptr_t>(my_global_state->lockless_queue_memory);

  bool r;
  // Do the actual test in a new thread so any locked mutexes will be cleaned up
  // nicely with owner-died at the end.
  ::std::thread test_thread([&prepare, &function, &check, test_failure, die_at,
                             writes_in, writes_out, writable_offset, &r]() {
    r = RunFunctionDieAt(prepare, function, test_failure, die_at,
                         writable_offset, writes_in, writes_out);
    if (::testing::Test::HasFailure()) {
      r = false;
      if (test_failure) *test_failure = true;
      return;
    }

    // Only validate the memory contents if the run itself was clean.
    check(GlobalState::Get()->lockless_queue_memory);
  });
  test_thread.join();
  return r;
}
411
// Tests |function| to make sure it handles dying after each store it makes to
// shared memory; |check| should make sure |function| behaved correctly.
// This repeatedly creates the shared memory, runs |prepare|, runs |function|
// (killing the process it runs in at successive write points), and then runs
// |check|. It stops if anything reports a fatal gtest failure.
void TestShmRobustness(const LocklessQueueConfiguration &config,
                       ::std::function<void(void *)> prepare,
                       ::std::function<void(void *)> function,
                       ::std::function<void(void *)> check) {
  auto [my_global_state, expected_writes] = GlobalState::MakeGlobalState();

  // First pass (die_at == 0, i.e. never die): record every write into
  // expected_writes so later passes know how many death points exist.
  bool test_failed = false;
  ASSERT_TRUE(RunFunctionDieAtAndCheck(config, prepare, function, check,
                                       &test_failed, 0, nullptr,
                                       expected_writes));
  if (test_failed) {
    ADD_FAILURE();
    return;
  }

  // Replay, dying at each successive write, until a run makes it all the way
  // through (which means every death point has been exercised).
  size_t die_at = 1;
  while (true) {
    SCOPED_TRACE("dying at " + ::std::to_string(die_at) + "/" +
                 ::std::to_string(expected_writes->size()));
    if (RunFunctionDieAtAndCheck(config, prepare, function, check, &test_failed,
                                 die_at, expected_writes, nullptr)) {
      LOG(INFO) << "Tested " << die_at << " death points";
      return;
    }
    if (test_failed) {
      ADD_FAILURE();
    }
    if (::testing::Test::HasFailure()) return;
    ++die_at;
  }
}
449
450SharedTid::SharedTid() {
451 // Capture the tid in the child so we can tell if it died. Use mmap so it
452 // works across the process boundary.
453 tid_ =
454 static_cast<pid_t *>(mmap(nullptr, sizeof(pid_t), PROT_READ | PROT_WRITE,
455 MAP_SHARED | MAP_ANONYMOUS, -1, 0));
456 CHECK_NE(MAP_FAILED, tid_);
457}
458
459SharedTid::~SharedTid() { CHECK_EQ(munmap(tid_, sizeof(pid_t)), 0); }
460
461void SharedTid::Set() { *tid_ = gettid(); }
462
463pid_t SharedTid::Get() { return *tid_; }
464
Stephan Pleinesd99b1ee2024-02-02 20:56:44 -0800465} // namespace aos::ipc_lib::testing
Austin Schuhfaec51a2023-09-08 17:43:32 -0700466
#endif  // SUPPORTS_SHM_ROBUSTNESS_TEST