add IPCRecursiveMutexLocker + tests

Change-Id: Ie7fca7032266935278e646534bea1180136fd3d0
diff --git a/aos/common/mutex.h b/aos/common/mutex.h
index db1b012..b952644 100644
--- a/aos/common/mutex.h
+++ b/aos/common/mutex.h
@@ -43,6 +43,10 @@
   // Doesn't wait for the mutex to be unlocked if it is locked.
   State TryLock() __attribute__((warn_unused_result));
 
+  // Returns true iff the current task has this mutex locked.
+  // This is mainly for IPCRecursiveMutexLocker to use.
+  bool OwnedBySelf() const;
+
  private:
   aos_mutex impl_;
 
@@ -101,6 +105,36 @@
   DISALLOW_COPY_AND_ASSIGN(IPCMutexLocker);
 };
 
+// A version of IPCMutexLocker which only locks (and unlocks) the mutex if the
+// current task does not already hold it.
+class IPCRecursiveMutexLocker {
+ public:
+  explicit IPCRecursiveMutexLocker(Mutex *mutex)
+      : mutex_(mutex),
+        locked_(!mutex_->OwnedBySelf()),
+        owner_died_(locked_ ? mutex_->Lock() : false) {}
+  ~IPCRecursiveMutexLocker() {
+    if (__builtin_expect(!owner_died_checked_, false)) {
+      ::aos::Die("nobody checked if the previous owner of mutex %p died", this);
+    }
+    if (locked_) mutex_->Unlock();
+  }
+
+  // Whether or not the previous owner died. If this is not called at least
+  // once, the destructor will ::aos::Die.
+  __attribute__((warn_unused_result)) bool owner_died() {
+    owner_died_checked_ = true;
+    return __builtin_expect(owner_died_, false);
+  }
+
+ private:
+  Mutex *const mutex_;
+  const bool locked_, owner_died_;
+  bool owner_died_checked_ = false;
+
+  DISALLOW_COPY_AND_ASSIGN(IPCRecursiveMutexLocker);
+};
+
 }  // namespace aos
 
 #endif  // AOS_COMMON_MUTEX_H_
diff --git a/aos/common/mutex_test.cc b/aos/common/mutex_test.cc
index 56f7772..75323c8 100644
--- a/aos/common/mutex_test.cc
+++ b/aos/common/mutex_test.cc
@@ -21,7 +21,7 @@
 
 class MutexTest : public ::testing::Test {
  public:
-  Mutex test_mutex;
+  Mutex test_mutex_;
 
  protected:
   void SetUp() override {
@@ -35,39 +35,40 @@
 typedef MutexTest MutexLockerDeathTest;
 typedef MutexTest IPCMutexLockerTest;
 typedef MutexTest IPCMutexLockerDeathTest;
+typedef MutexTest IPCRecursiveMutexLockerTest;
 
 TEST_F(MutexTest, TryLock) {
-  EXPECT_EQ(Mutex::State::kLocked, test_mutex.TryLock());
-  EXPECT_EQ(Mutex::State::kUnlocked, test_mutex.TryLock());
+  EXPECT_EQ(Mutex::State::kLocked, test_mutex_.TryLock());
+  EXPECT_EQ(Mutex::State::kUnlocked, test_mutex_.TryLock());
 
-  test_mutex.Unlock();
+  test_mutex_.Unlock();
 }
 
 TEST_F(MutexTest, Lock) {
-  ASSERT_FALSE(test_mutex.Lock());
-  EXPECT_EQ(Mutex::State::kUnlocked, test_mutex.TryLock());
+  ASSERT_FALSE(test_mutex_.Lock());
+  EXPECT_EQ(Mutex::State::kUnlocked, test_mutex_.TryLock());
 
-  test_mutex.Unlock();
+  test_mutex_.Unlock();
 }
 
 TEST_F(MutexTest, Unlock) {
-  ASSERT_FALSE(test_mutex.Lock());
-  EXPECT_EQ(Mutex::State::kUnlocked, test_mutex.TryLock());
-  test_mutex.Unlock();
-  EXPECT_EQ(Mutex::State::kLocked, test_mutex.TryLock());
+  ASSERT_FALSE(test_mutex_.Lock());
+  EXPECT_EQ(Mutex::State::kUnlocked, test_mutex_.TryLock());
+  test_mutex_.Unlock();
+  EXPECT_EQ(Mutex::State::kLocked, test_mutex_.TryLock());
 
-  test_mutex.Unlock();
+  test_mutex_.Unlock();
 }
 
 // Sees what happens with multiple unlocks.
 TEST_F(MutexDeathTest, RepeatUnlock) {
   logging::Init();
-  ASSERT_FALSE(test_mutex.Lock());
-  test_mutex.Unlock();
+  ASSERT_FALSE(test_mutex_.Lock());
+  test_mutex_.Unlock();
   EXPECT_DEATH(
       {
         logging::AddImplementation(new util::DeathTestLogImplementation());
-        test_mutex.Unlock();
+        test_mutex_.Unlock();
       },
       ".*multiple unlock.*");
 }
@@ -78,7 +79,7 @@
   EXPECT_DEATH(
       {
         logging::AddImplementation(new util::DeathTestLogImplementation());
-        test_mutex.Unlock();
+        test_mutex_.Unlock();
       },
       ".*multiple unlock.*");
 }
@@ -88,8 +89,8 @@
   EXPECT_DEATH(
       {
         logging::AddImplementation(new util::DeathTestLogImplementation());
-        ASSERT_FALSE(test_mutex.Lock());
-        ASSERT_FALSE(test_mutex.Lock());
+        ASSERT_FALSE(test_mutex_.Lock());
+        ASSERT_FALSE(test_mutex_.Lock());
       },
       ".*multiple lock.*");
 }
@@ -134,9 +135,9 @@
 TEST_F(MutexTest, ThreadSanitizerContended) {
   int counter = 0;
   AdderThread threads[2]{
-      {&counter, &test_mutex, ::aos::time::Time::InSeconds(0.2),
+      {&counter, &test_mutex_, ::aos::time::Time::InSeconds(0.2),
        ::aos::time::Time::InSeconds(0)},
-      {&counter, &test_mutex, ::aos::time::Time::InSeconds(0),
+      {&counter, &test_mutex_, ::aos::time::Time::InSeconds(0),
        ::aos::time::Time::InSeconds(0)}, };
   for (auto &c : threads) {
     c.Start();
@@ -153,12 +154,12 @@
   int counter = 0;
   ::std::thread thread([&counter, this]() {
     for (int i = 0; i < 1000; ++i) {
-      MutexLocker locker(&test_mutex);
+      MutexLocker locker(&test_mutex_);
       ++counter;
     }
   });
   for (int i = 0; i < 1000; ++i) {
-    MutexLocker locker(&test_mutex);
+    MutexLocker locker(&test_mutex_);
     --counter;
   }
   thread.join();
@@ -170,9 +171,9 @@
 TEST_F(MutexTest, ThreadSanitizerUncontended) {
   int counter = 0;
   AdderThread threads[2]{
-      {&counter, &test_mutex, ::aos::time::Time::InSeconds(0.2),
+      {&counter, &test_mutex_, ::aos::time::Time::InSeconds(0.2),
        ::aos::time::Time::InSeconds(0)},
-      {&counter, &test_mutex, ::aos::time::Time::InSeconds(0),
+      {&counter, &test_mutex_, ::aos::time::Time::InSeconds(0),
        ::aos::time::Time::InSeconds(0)}, };
   for (auto &c : threads) {
     c.Start();
@@ -183,31 +184,90 @@
   EXPECT_EQ(2, counter);
 }
 
+namespace {
+
+class LockerThread : public util::Thread {
+ public:
+  LockerThread(Mutex *mutex, bool lock, bool unlock)
+      : mutex_(mutex), lock_(lock), unlock_(unlock) {}
+
+ private:
+  virtual void Run() override {
+    if (lock_) ASSERT_FALSE(mutex_->Lock());
+    if (unlock_) mutex_->Unlock();
+  }
+
+  Mutex *const mutex_;
+  const bool lock_, unlock_;
+};
+
+}  // namespace
+
+// Makes sure that we don't SIGSEGV or something with multiple threads.
+TEST_F(MutexTest, MultiThreadedLock) {
+  LockerThread t(&test_mutex_, true, true);
+  t.Start();
+  ASSERT_FALSE(test_mutex_.Lock());
+  test_mutex_.Unlock();
+  t.Join();
+}
+
 TEST_F(MutexLockerTest, Basic) {
   {
-    aos::MutexLocker locker(&test_mutex);
-    EXPECT_EQ(Mutex::State::kUnlocked, test_mutex.TryLock());
+    aos::MutexLocker locker(&test_mutex_);
+    EXPECT_EQ(Mutex::State::kUnlocked, test_mutex_.TryLock());
   }
-  EXPECT_EQ(Mutex::State::kLocked, test_mutex.TryLock());
+  EXPECT_EQ(Mutex::State::kLocked, test_mutex_.TryLock());
 
-  test_mutex.Unlock();
+  test_mutex_.Unlock();
 }
 
 TEST_F(IPCMutexLockerTest, Basic) {
   {
-    aos::IPCMutexLocker locker(&test_mutex);
-    EXPECT_EQ(Mutex::State::kUnlocked, test_mutex.TryLock());
+    aos::IPCMutexLocker locker(&test_mutex_);
+    EXPECT_EQ(Mutex::State::kUnlocked, test_mutex_.TryLock());
     EXPECT_FALSE(locker.owner_died());
   }
-  EXPECT_EQ(Mutex::State::kLocked, test_mutex.TryLock());
+  EXPECT_EQ(Mutex::State::kLocked, test_mutex_.TryLock());
 
-  test_mutex.Unlock();
+  test_mutex_.Unlock();
 }
 
+// Tests what happens when the caller doesn't check if the previous owner died
+// with an IPCMutexLocker.
 TEST_F(IPCMutexLockerDeathTest, NoCheckOwnerDied) {
-  EXPECT_DEATH({ aos::IPCMutexLocker locker(&test_mutex); },
+  EXPECT_DEATH({ aos::IPCMutexLocker locker(&test_mutex_); },
                "nobody checked if the previous owner of mutex [^ ]+ died.*");
 }
 
+TEST_F(IPCRecursiveMutexLockerTest, Basic) {
+  {
+    aos::IPCRecursiveMutexLocker locker(&test_mutex_);
+    EXPECT_EQ(Mutex::State::kUnlocked, test_mutex_.TryLock());
+    EXPECT_FALSE(locker.owner_died());
+  }
+  EXPECT_EQ(Mutex::State::kLocked, test_mutex_.TryLock());
+
+  test_mutex_.Unlock();
+}
+
+// Tests actually locking a mutex recursively with IPCRecursiveMutexLocker.
+TEST_F(IPCRecursiveMutexLockerTest, RecursiveLock) {
+  {
+    aos::IPCRecursiveMutexLocker locker(&test_mutex_);
+    EXPECT_EQ(Mutex::State::kUnlocked, test_mutex_.TryLock());
+    {
+      aos::IPCRecursiveMutexLocker locker(&test_mutex_);
+      EXPECT_EQ(Mutex::State::kUnlocked, test_mutex_.TryLock());
+      EXPECT_FALSE(locker.owner_died());
+    }
+    EXPECT_EQ(Mutex::State::kUnlocked, test_mutex_.TryLock());
+    EXPECT_FALSE(locker.owner_died());
+  }
+  EXPECT_EQ(Mutex::State::kLocked, test_mutex_.TryLock());
+
+  test_mutex_.Unlock();
+}
+
 }  // namespace testing
 }  // namespace aos
diff --git a/aos/linux_code/ipc_lib/mutex.cc b/aos/linux_code/ipc_lib/mutex.cc
index 9e270c9..796b841 100644
--- a/aos/linux_code/ipc_lib/mutex.cc
+++ b/aos/linux_code/ipc_lib/mutex.cc
@@ -15,7 +15,7 @@
 }
 
 Mutex::~Mutex() {
-  if (__builtin_expect(mutex_islocked(&impl_), 0)) {
+  if (__builtin_expect(mutex_islocked(&impl_), false)) {
     LOG(FATAL, "destroying locked mutex %p (aka %p)\n",
         this, &impl_);
   }
@@ -51,4 +51,8 @@
   }
 }
 
+bool Mutex::OwnedBySelf() const {
+  return mutex_islocked(&impl_);
+}
+
 }  // namespace aos