Add a helper class for implementing intrusive linked lists

Change-Id: I6a5b25be916c6cfeb760fc3ed38637914e8af3a4
diff --git a/aos/common/util/BUILD b/aos/common/util/BUILD
index bd70463..7f045bd 100644
--- a/aos/common/util/BUILD
+++ b/aos/common/util/BUILD
@@ -179,3 +179,26 @@
     'compiler_memory_barrier.h',
   ],
 )
+
+cc_library(
+  name = 'linked_list',
+  hdrs = [
+    'linked_list.h',
+  ],
+  deps = [
+    '//aos/common:transaction',
+    '//aos/common/logging',
+  ],
+)
+
+cc_test(
+  name = 'linked_list_test',
+  srcs = [
+    'linked_list_test.cc',
+  ],
+  deps = [
+    ':linked_list',
+    '//aos/testing:googletest',
+    '//aos/common/logging',
+  ],
+)
diff --git a/aos/common/util/linked_list.h b/aos/common/util/linked_list.h
new file mode 100644
index 0000000..5821389
--- /dev/null
+++ b/aos/common/util/linked_list.h
@@ -0,0 +1,122 @@
+#ifndef AOS_COMMON_UTIL_LINKED_LIST_H_
+#define AOS_COMMON_UTIL_LINKED_LIST_H_
+
+#include <functional>
+
+#include "aos/common/logging/logging.h"
+#include "aos/common/transaction.h"
+
+namespace aos {
+namespace util {
+
+// Handles manipulating an intrusive linked list. T must look like the
+// following:
+// struct T {
+//   ...
+//   T *next;
+//   ...
+// };
+// This class doesn't deal with creating or destroying them, so
+// constructors/destructors/other member variables/member functions are all
+// fine, but the next pointer must be there for this class to work.
+// This class will handle all manipulations of next. It does not need to be
+// initialized before calling Add and should not be changed afterwards.
+// next can (and probably should) be private if the appropriate instantiation of
+// this class is friended.
+template <class T>
+class LinkedList {
+ public:
+  // Returns the first element of the list, or nullptr if it is empty.
+  T *head() const { return head_; }
+
+  bool Empty() const { return head() == nullptr; }
+
+  // Adds t to the front of the list.
+  void Add(T *t) {
+    Add<0>(t, nullptr);
+  }
+
+  // Adds t to the front of the list.
+  // restore_pointers (if non-null) will be used so the operation can be safely
+  // reverted at any point.
+  template <int number_works>
+  void Add(T *t, transaction::WorkStack<transaction::RestorePointerWork,
+                                        number_works> *restore_pointers) {
+    if (restore_pointers != nullptr) restore_pointers->AddWork(&t->next);
+    t->next = head();
+    if (restore_pointers != nullptr) restore_pointers->AddWork(&head_);
+    head_ = t;
+  }
+
+  // Removes t from the list. Logs a FATAL error if t is not in the list.
+  void Remove(T *t) {
+    Remove<0>(t, nullptr);
+  }
+
+  // Removes t from the list. Logs a FATAL error if t is not in the list.
+  // restore_pointers (if non-null) will be used so the operation can be safely
+  // reverted at any point.
+  template <int number_works>
+  void Remove(T *t, transaction::WorkStack<transaction::RestorePointerWork,
+                                           number_works> *restore_pointers) {
+    // pointer is the address of the link (head_ or some element's next) that
+    // refers to the element currently being examined.
+    T **pointer = &head_;
+    while (*pointer != nullptr) {
+      if (*pointer == t) {
+        if (restore_pointers != nullptr) {
+          restore_pointers->AddWork(pointer);
+        }
+        *pointer = t->next;
+        return;
+      }
+      pointer = &(*pointer)->next;
+    }
+    // %p requires a void *, so convert explicitly.
+    LOG(FATAL, "%p is not in the list\n", static_cast<void *>(t));
+  }
+
+  // Calls function for each element of the list, starting at head().
+  // function can modify these elements in any way except touching the next
+  // pointer (including by calling other methods of this object).
+  void Each(::std::function<void(T *)> function) const {
+    T *c = head();
+    while (c != nullptr) {
+      T *const next = c->next;
+      function(c);
+      c = next;
+    }
+  }
+
+  // Returns the first element of the list where function returns true, or
+  // nullptr if it returns false for all.
+  T *Find(::std::function<bool(const T *)> function) const {
+    T *c = head();
+    while (c != nullptr) {
+      if (function(c)) return c;
+      c = c->next;
+    }
+    return nullptr;
+  }
+
+ private:
+  T *head_ = nullptr;
+};
+
+// Keeps track of something along with a next pointer. Useful for things that
+// either have types without next pointers or for storing pointers to things
+// that belong in multiple lists.
+template <class V>
+struct LinkedListReference {
+  V item;
+
+ private:
+  friend class LinkedList<LinkedListReference>;
+
+  LinkedListReference *next;
+};
+
+}  // namespace util
+}  // namespace aos
+
+#endif  // AOS_COMMON_UTIL_LINKED_LIST_H_
diff --git a/aos/common/util/linked_list_test.cc b/aos/common/util/linked_list_test.cc
new file mode 100644
index 0000000..0f4b2c8
--- /dev/null
+++ b/aos/common/util/linked_list_test.cc
@@ -0,0 +1,140 @@
+#include "aos/common/util/linked_list.h"
+
+#include <vector>
+
+#include "gtest/gtest.h"
+
+namespace aos {
+namespace util {
+namespace testing {
+
+class LinkedListTest : public ::testing::Test {
+ public:
+  // Deletes any elements a test left in the list.
+  ~LinkedListTest() override {
+    while (list.head() != nullptr) {
+      RemoveElement(list.head());
+    }
+  }
+
+  struct Member {
+    explicit Member(int i) : i(i) {}
+
+    int i;
+    Member *next = nullptr;
+  };
+  LinkedList<Member> list;
+
+  // Allocates a new Member holding i and adds it to the front of list.
+  Member *AddElement(int i) {
+    Member *member = new Member(i);
+    list.Add(member);
+    return member;
+  }
+
+  // Removes member from list and frees it.
+  void RemoveElement(Member *member) {
+    list.Remove(member);
+    delete member;
+  }
+
+  Member *GetMember(int i) const {
+    return list.Find([i](const Member *member) { return member->i == i; });
+  }
+
+  bool HasMember(int i) const { return GetMember(i) != nullptr; }
+
+  // Returns the values of all members, in list order (head first).
+  ::std::vector<int> GetMembers() const {
+    ::std::vector<int> r;
+    list.Each([&r](Member *member) { r.push_back(member->i); });
+    return r;
+  }
+};
+
+// Tests that adding and removing elements works correctly.
+TEST_F(LinkedListTest, Basic) {
+  EXPECT_TRUE(list.Empty());
+  AddElement(971);
+  EXPECT_FALSE(list.Empty());
+  AddElement(254);
+  EXPECT_FALSE(list.Empty());
+  AddElement(1678);
+  EXPECT_FALSE(list.Empty());
+
+  EXPECT_EQ((::std::vector<int>{1678, 254, 971}), GetMembers());
+
+  EXPECT_EQ(1678, list.head()->i);
+  RemoveElement(list.head());
+  EXPECT_EQ(254, list.head()->i);
+  EXPECT_FALSE(list.Empty());
+  RemoveElement(list.head());
+  EXPECT_EQ(971, list.head()->i);
+  EXPECT_FALSE(list.Empty());
+  RemoveElement(list.head());
+  EXPECT_TRUE(list.Empty());
+}
+
+TEST_F(LinkedListTest, Each) {
+  ::std::vector<int> found;
+  auto add_to_found = [&found](Member *member) {
+    found.push_back(member->i);
+  };
+
+  AddElement(971);
+  found.clear();
+  list.Each(add_to_found);
+  EXPECT_EQ((::std::vector<int>{971}), found);
+
+  AddElement(254);
+  found.clear();
+  list.Each(add_to_found);
+  EXPECT_EQ((::std::vector<int>{254, 971}), found);
+
+  AddElement(1678);
+  found.clear();
+  list.Each(add_to_found);
+  EXPECT_EQ((::std::vector<int>{1678, 254, 971}), found);
+}
+
+TEST_F(LinkedListTest, Find) {
+  auto find_254 = [](const Member *member) { return member->i == 254; };
+
+  AddElement(971);
+  EXPECT_EQ(nullptr, list.Find(find_254));
+  Member *member = AddElement(254);
+  EXPECT_EQ(member, list.Find(find_254));
+  AddElement(1678);
+  EXPECT_EQ(member, list.Find(find_254));
+}
+
+// Removing an element from the middle of the list used to break it.
+TEST_F(LinkedListTest, RemoveFromMiddle) {
+  AddElement(971);
+  auto in_middle = AddElement(254);
+  AddElement(1678);
+  RemoveElement(in_middle);
+
+  EXPECT_EQ((::std::vector<int>{1678, 971}), GetMembers());
+}
+
+// Removing the last element of the list (the first one added) works too.
+TEST_F(LinkedListTest, RemoveFromEnd) {
+  auto at_end = AddElement(971);
+  AddElement(254);
+  AddElement(1678);
+  RemoveElement(at_end);
+
+  EXPECT_EQ((::std::vector<int>{1678, 254}), GetMembers());
+}
+
+// Removing an element which is not in the list is a fatal error.
+TEST_F(LinkedListTest, RemoveMissing) {
+  AddElement(971);
+  Member not_in_list(254);
+  EXPECT_DEATH(list.Remove(&not_in_list), "");
+}
+
+}  // namespace testing
+}  // namespace util
+}  // namespace aos