Add a helper class for implementing intrusive linked lists
Change-Id: I6a5b25be916c6cfeb760fc3ed38637914e8af3a4
diff --git a/aos/common/util/BUILD b/aos/common/util/BUILD
index bd70463..7f045bd 100644
--- a/aos/common/util/BUILD
+++ b/aos/common/util/BUILD
@@ -179,3 +179,25 @@
'compiler_memory_barrier.h',
],
)
+
+cc_library(
+  name = 'linked_list',
+  hdrs = ['linked_list.h'],
+  deps = [
+    # linked_list.h calls LOG(FATAL, ...) in LinkedList::Remove.
+    '//aos/common/logging',
+    '//aos/common:transaction',
+  ],
+)
+
+# Unit tests for the LinkedList helper declared in linked_list.h
+# (covers Add, Remove, Each, and Find).
+cc_test(
+  name = 'linked_list_test',
+  srcs = ['linked_list_test.cc'],
+  deps = [
+    ':linked_list',
+    '//aos/testing:googletest',
+    '//aos/common/logging',
+  ],
+)
diff --git a/aos/common/util/linked_list.h b/aos/common/util/linked_list.h
new file mode 100644
index 0000000..5821389
--- /dev/null
+++ b/aos/common/util/linked_list.h
@@ -0,0 +1,113 @@
+#ifndef AOS_COMMON_UTIL_LINKED_LIST_H_
+#define AOS_COMMON_UTIL_LINKED_LIST_H_
+
+#include <functional>
+
+#include "aos/common/transaction.h"
+
+namespace aos {
+namespace util {
+
+// Handles manipulating an intrusive linked list. T must look like the
+// following:
+// struct T {
+// ...
+// T *next;
+// ...
+// };
+// This class doesn't deal with creating or destroying them, so
+// constructors/destructors/other member variables/member functions are all
+// fine, but the next pointer must be there for this class to work.
+// This class will handle all manipulations of next. next does not need to be
+// initialized before calling Add and should not be changed afterwards.
+// next can (and probably should) be private if the appropriate instantiation of
+// this class is friended.
+template <class T>
+class LinkedList {
+ public:
+  // Returns the first element of the list, or nullptr if the list is empty.
+  T *head() const { return head_; }
+
+  bool Empty() const { return head() == nullptr; }
+
+  // Inserts t at the front of the list without recording any restore work.
+  void Add(T *t) { Add<0>(t, nullptr); }
+
+  // Inserts t at the front of the list, overwriting t->next.
+  // restore_pointers (if non-null) will be used so the operation can be safely
+  // reverted at any point.
+  template <int number_works>
+  void Add(T *t, transaction::WorkStack<transaction::RestorePointerWork,
+                                        number_works> *restore_pointers) {
+    if (restore_pointers != nullptr) restore_pointers->AddWork(&t->next);
+    t->next = head();
+    if (restore_pointers != nullptr) restore_pointers->AddWork(&head_);
+    head_ = t;
+  }
+
+  // Unlinks t from the list without recording any restore work.
+  void Remove(T *t) { Remove<0>(t, nullptr); }
+
+  // Unlinks t from the list; logs at FATAL level if t is not in it.
+  // t->next is left untouched; only the pointer which referred to t changes.
+  // restore_pointers (if non-null) will be used so the operation can be safely
+  // reverted at any point.
+  template <int number_works>
+  void Remove(T *t, transaction::WorkStack<transaction::RestorePointerWork,
+                                           number_works> *restore_pointers) {
+    // pointer is the link (head_ or some element's next) being examined.
+    T **pointer = &head_;
+    while (*pointer != nullptr) {
+      if (*pointer == t) {
+        if (restore_pointers != nullptr) restore_pointers->AddWork(pointer);
+        *pointer = t->next;
+        return;
+      }
+      pointer = &(*pointer)->next;
+    }
+    // %p requires a void * argument.
+    LOG(FATAL, "%p is not in the list\n", static_cast<void *>(t));
+  }
+
+  // Calls function for each element of the list.
+  // function can modify these elements in any way except touching the next
+  // pointer (including by calling other methods of this object).
+  void Each(::std::function<void(T *)> function) const {
+    T *c = head();
+    while (c != nullptr) {
+      T *const next = c->next;
+      function(c);
+      c = next;
+    }
+  }
+
+  // Returns the first element of the list where function returns true or
+  // nullptr if it returns false for all.
+  T *Find(::std::function<bool(const T *)> function) const {
+    for (T *c = head(); c != nullptr; c = c->next) {
+      if (function(c)) return c;
+    }
+    return nullptr;
+  }
+
+ private:
+  T *head_ = nullptr;
+};
+
+// Keeps track of something along with a next pointer. Useful for things that
+// either have types without next pointers or for storing pointers to things
+// that belong in multiple lists.
+template <typename V>
+struct LinkedListReference {
+  V item;  // The value this node keeps track of.
+
+ private:
+  // Only the matching LinkedList instantiation gets to touch next.
+  friend class LinkedList<LinkedListReference<V>>;
+  LinkedListReference<V> *next;
+};
+
+} // namespace util
+} // namespace aos
+
+#endif // AOS_COMMON_UTIL_LINKED_LIST_H_
diff --git a/aos/common/util/linked_list_test.cc b/aos/common/util/linked_list_test.cc
new file mode 100644
index 0000000..0f4b2c8
--- /dev/null
+++ b/aos/common/util/linked_list_test.cc
@@ -0,0 +1,119 @@
+#include "aos/common/util/linked_list.h"
+
+#include <vector>
+
+#include "gtest/gtest.h"
+
+namespace aos {
+namespace util {
+namespace testing {
+
+class LinkedListTest : public ::testing::Test {
+ public:
+  // Frees any elements which a test left in the list.
+  ~LinkedListTest() override {
+    while (!list.Empty()) RemoveElement(list.head());
+  }
+
+  struct Member {
+    explicit Member(int i) : i(i) {}
+    int i;
+    Member *next = nullptr;
+  };
+  LinkedList<Member> list;
+
+  // Allocates a new Member holding i and adds it to the front of list.
+  Member *AddElement(int i) {
+    Member *member = new Member(i);
+    list.Add(member);
+    return member;
+  }
+
+  void RemoveElement(Member *member) {
+    list.Remove(member);
+    delete member;
+  }
+
+  // Returns the first Member holding i or nullptr if there isn't one.
+  Member *GetMember(int i) const {
+    return list.Find([i](const Member *member) { return member->i == i; });
+  }
+  bool HasMember(int i) const { return GetMember(i) != nullptr; }
+
+  // Returns the values of all the elements, in list order.
+  ::std::vector<int> GetMembers() const {
+    ::std::vector<int> r;
+    list.Each([&r](Member *member) { r.push_back(member->i); });
+    return r;
+  }
+};
+
+// Tests that adding and removing elements works correctly.
+TEST_F(LinkedListTest, Basic) {
+  EXPECT_TRUE(list.Empty());
+  for (const int value : {971, 254, 1678}) {
+    AddElement(value);
+    EXPECT_FALSE(list.Empty());
+  }
+
+  // Add puts new elements at the front, so they are listed newest-first.
+  const ::std::vector<int> expected{1678, 254, 971};
+  EXPECT_EQ(expected, GetMembers());
+
+  EXPECT_EQ(1678, list.head()->i);
+  RemoveElement(list.head());
+  EXPECT_EQ(254, list.head()->i);
+  EXPECT_FALSE(list.Empty());
+  RemoveElement(list.head());
+  EXPECT_EQ(971, list.head()->i);
+  EXPECT_FALSE(list.Empty());
+  RemoveElement(list.head());
+  EXPECT_TRUE(list.Empty());
+}
+
+TEST_F(LinkedListTest, Each) {
+  // Walks the whole list with Each and returns the values in the order they
+  // were visited.
+  const auto visit_all = [this]() -> ::std::vector<int> {
+    ::std::vector<int> seen;
+    list.Each([&seen](Member *member) { seen.push_back(member->i); });
+    return seen;
+  };
+
+  // A lone element is visited by itself.
+  AddElement(971);
+  EXPECT_EQ((::std::vector<int>{971}), visit_all());
+
+  // Elements are visited starting with the most recently added one.
+  AddElement(254);
+  EXPECT_EQ((::std::vector<int>{254, 971}), visit_all());
+
+  // Every element is still visited once the list grows further.
+  AddElement(1678);
+  EXPECT_EQ((::std::vector<int>{1678, 254, 971}), visit_all());
+}
+
+TEST_F(LinkedListTest, Find) {
+  // Nothing matches until the 254 element has been added.
+  AddElement(971);
+  EXPECT_EQ(nullptr, GetMember(254));
+  // The match is found whether it is at the head or behind a newer element.
+  Member *const target = AddElement(254);
+  EXPECT_EQ(target, GetMember(254));
+  AddElement(1678);
+  EXPECT_EQ(target, GetMember(254));
+}
+
+// Removing an element from the middle of the list used to break it.
+TEST_F(LinkedListTest, RemoveFromMiddle) {
+  AddElement(971);
+  Member *const middle = AddElement(254);
+  AddElement(1678);
+  // Only the middle element should disappear; its neighbors stay linked.
+  RemoveElement(middle);
+  EXPECT_EQ((::std::vector<int>{1678, 971}), GetMembers());
+}
+
+} // namespace testing
+} // namespace util
+} // namespace aos