add an implementation of memory transactions + tests
Change-Id: Id901328da7d246268b7af951729acb1fada06836
diff --git a/aos/build/aos_all.gyp b/aos/build/aos_all.gyp
index 195997a..b546087 100644
--- a/aos/build/aos_all.gyp
+++ b/aos/build/aos_all.gyp
@@ -37,6 +37,7 @@
'<(AOS)/common/common.gyp:once_test',
'<(AOS)/common/common.gyp:event_test',
'<(AOS)/common/common.gyp:queue_testutils_test',
+ '<(AOS)/common/common.gyp:transaction_test',
'<(AOS)/common/logging/logging.gyp:logging_impl_test',
'<(AOS)/common/util/util.gyp:options_test',
'<(AOS)/common/common.gyp:queue_test',
diff --git a/aos/common/BUILD b/aos/common/BUILD
index e21a9fe..45c1d72 100644
--- a/aos/common/BUILD
+++ b/aos/common/BUILD
@@ -399,3 +399,27 @@
':die',
],
)
+
+cc_library(
+ name = 'transaction',
+ hdrs = [
+ 'transaction.h',
+ ],
+ deps = [
+ '//aos/common/logging:logging_interface',
+ '//aos/common/util:compiler_memory_barrier',
+ ],
+)
+
+cc_test(
+ name = 'transaction_test',
+ srcs = [
+ 'transaction_test.cc',
+ ],
+ deps = [
+ ':transaction',
+ '//aos/testing:googletest',
+ '//aos/common/logging',
+ '//aos/common/util:death_test_log_implementation',
+ ],
+)
diff --git a/aos/common/common.gyp b/aos/common/common.gyp
index c9e1b6a..9c43936 100644
--- a/aos/common/common.gyp
+++ b/aos/common/common.gyp
@@ -361,5 +361,31 @@
'die',
],
},
+ {
+ 'target_name': 'transaction',
+ 'type': 'static_library',
+ 'sources': [
+ #'transaction.h',
+ ],
+ 'dependencies': [
+ '<(AOS)/build/aos.gyp:logging_interface',
+ ],
+ 'export_dependent_settings': [
+ '<(AOS)/build/aos.gyp:logging_interface',
+ ],
+ },
+ {
+ 'target_name': 'transaction_test',
+ 'type': 'executable',
+ 'sources': [
+ 'transaction_test.cc',
+ ],
+ 'dependencies': [
+ 'transaction',
+ '<(EXTERNALS):gtest',
+ '<(AOS)/build/aos.gyp:logging',
+ '<(AOS)/common/util/util.gyp:death_test_log_implementation',
+ ],
+ },
],
}
diff --git a/aos/common/transaction.h b/aos/common/transaction.h
new file mode 100644
index 0000000..e6e9e41
--- /dev/null
+++ b/aos/common/transaction.h
@@ -0,0 +1,103 @@
+#ifndef AOS_COMMON_TRANSACTION_H_
+#define AOS_COMMON_TRANSACTION_H_
+
+#include <stdint.h>
+
+#include <array>
+
+#include "aos/common/util/compiler_memory_barrier.h"
+#include "aos/common/logging/logging.h"
+
+namespace aos {
+namespace transaction {
+
+// Manages a LIFO stack of Work objects. Designed to help implement transactions
+// by providing a safe way to undo things etc.
+//
+// number_works Work objects are created statically and then Create is called on
+// each as it is added to the stack. When the work should do whatever it does,
+// DoWork() will be called. The work objects get no notification when they are
+// dropped off of the stack.
+//
+// Work::DoWork() must be idempotent because it may get called multiple times if
+// CompleteWork() is interrupted part of the way through.
+//
+// This class handles compiler memory barriers etc to make sure only fully
+// created works are ever invoked, and each work will be fully created by the
+// time AddWork returns. This does not mean it's safe for multiple threads to
+// interact with an instance of this class at the same time.
+template <class Work, int number_works>
+class WorkStack {
+ public:
+ // Calls DoWork() on all the works that have been added and then removes them
+ // all from the stack.
+ void CompleteWork() {
+ int current = stack_index_;
+ while (current > 0) {
+ stack_.at(--current).DoWork();
+ }
+ aos_compiler_memory_barrier();
+ stack_index_ = 0;
+ aos_compiler_memory_barrier();
+ }
+
+ // Drops all works that have been added.
+ void DropWork() {
+ stack_index_ = 0;
+ aos_compiler_memory_barrier();
+ }
+
+ // Returns true if we have any works to complete right now.
+ bool HasWork() const { return stack_index_ != 0; }
+
+ // Forwards all of its arguments to Work::Create, which it calls on the next
+ // work to be added.
+ template <class... A>
+ void AddWork(A &&... a) {
+ if (stack_index_ >= number_works) {
+ LOG(FATAL, "too many works\n");
+ }
+ stack_.at(stack_index_).Create(::std::forward<A>(a)...);
+ aos_compiler_memory_barrier();
+ ++stack_index_;
+ aos_compiler_memory_barrier();
+ }
+
+ private:
+ // The next index into stack_ for a new work to be added.
+ int stack_index_ = 0;
+ ::std::array<Work, number_works> stack_;
+};
+
+// When invoked, sets *pointer to the value it had when this work was Created.
+template <class T>
+class RestoreValueWork {
+ public:
+ void Create(T *pointer) {
+ pointer_ = pointer;
+ value_ = *pointer;
+ }
+ void DoWork() {
+ *pointer_ = value_;
+ }
+
+ private:
+ T *pointer_;
+ T value_;
+};
+
+// Handles the casting necessary to restore any kind of pointer.
+class RestorePointerWork : public RestoreValueWork<void *> {
+ public:
+ template <class T>
+ void Create(T **pointer) {
+ static_assert(sizeof(T *) == sizeof(void *),
+ "that's a weird pointer");
+ RestoreValueWork<void *>::Create(reinterpret_cast<void **>(pointer));
+ }
+};
+
+} // namespace transaction
+} // namespace aos
+
+#endif // AOS_COMMON_TRANSACTION_H_
diff --git a/aos/common/transaction_test.cc b/aos/common/transaction_test.cc
new file mode 100644
index 0000000..e74fd69
--- /dev/null
+++ b/aos/common/transaction_test.cc
@@ -0,0 +1,106 @@
+#include "aos/common/transaction.h"
+
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "aos/common/util/death_test_log_implementation.h"
+
+namespace aos {
+namespace transaction {
+namespace testing {
+
+class WorkStackTest : public ::testing::Test {
+ public:
+ // Contains an index which it adds to the created_works and invoked_works
+ // vectors of its containing WorkStackTest.
+ class TestWork {
+ public:
+ void Create(WorkStackTest *test, int i) {
+ test->created_works()->push_back(i);
+ i_ = i;
+ test_ = test;
+ }
+ void DoWork() {
+ test_->invoked_works()->push_back(i_);
+ }
+
+ int i() const { return i_; }
+
+ private:
+ int i_;
+ WorkStackTest *test_;
+ };
+
+ ::std::vector<int> *created_works() { return &created_works_; }
+ ::std::vector<int> *invoked_works() { return &invoked_works_; }
+ WorkStack<TestWork, 20> *work_stack() { return &work_stack_; }
+
+ // Creates a TestWork with index i and adds it to work_stack().
+ void CreateWork(int i) {
+ work_stack_.AddWork(this, i);
+ }
+
+ private:
+ ::std::vector<int> created_works_, invoked_works_;
+ WorkStack<TestWork, 20> work_stack_;
+};
+
+typedef WorkStackTest WorkStackDeathTest;
+
+TEST_F(WorkStackTest, Basic) {
+ EXPECT_FALSE(work_stack()->HasWork());
+ EXPECT_EQ(0u, created_works()->size());
+ EXPECT_EQ(0u, invoked_works()->size());
+
+ CreateWork(971);
+ EXPECT_TRUE(work_stack()->HasWork());
+ EXPECT_EQ(1u, created_works()->size());
+ EXPECT_EQ(0u, invoked_works()->size());
+ EXPECT_EQ(971, created_works()->at(0));
+
+ work_stack()->CompleteWork();
+ EXPECT_FALSE(work_stack()->HasWork());
+ EXPECT_EQ(1u, created_works()->size());
+ EXPECT_EQ(1u, invoked_works()->size());
+ EXPECT_EQ(971, invoked_works()->at(0));
+}
+
+TEST_F(WorkStackTest, DropWork) {
+ CreateWork(971);
+ CreateWork(254);
+ EXPECT_EQ(2u, created_works()->size());
+
+ work_stack()->DropWork();
+ EXPECT_FALSE(work_stack()->HasWork());
+ work_stack()->CompleteWork();
+ EXPECT_EQ(0u, invoked_works()->size());
+}
+
+// Tests that the works get run in the correct order.
+TEST_F(WorkStackTest, InvocationOrder) {
+ CreateWork(971);
+ CreateWork(254);
+ CreateWork(1678);
+
+ work_stack()->CompleteWork();
+ EXPECT_EQ((::std::vector<int>{971, 254, 1678}), *created_works());
+ EXPECT_EQ((::std::vector<int>{1678, 254, 971}), *invoked_works());
+}
+
+// Tests that it handles adding too many works intelligently.
+TEST_F(WorkStackDeathTest, TooManyWorks) {
+ logging::Init();
+ EXPECT_DEATH(
+ {
+ logging::AddImplementation(new util::DeathTestLogImplementation());
+ for (int i = 0; i < 1000; ++i) {
+ CreateWork(i);
+ }
+ },
+ ".*too many works.*");
+}
+
+} // namespace testing
+} // namespace transaction
+} // namespace aos
diff --git a/aos/common/util/BUILD b/aos/common/util/BUILD
index 646f3bf..bd70463 100644
--- a/aos/common/util/BUILD
+++ b/aos/common/util/BUILD
@@ -172,3 +172,10 @@
'//aos/testing:googletest',
],
)
+
+cc_library(
+ name = 'compiler_memory_barrier',
+ hdrs = [
+ 'compiler_memory_barrier.h',
+ ],
+)
diff --git a/aos/common/util/compiler_memory_barrier.h b/aos/common/util/compiler_memory_barrier.h
new file mode 100644
index 0000000..5126941
--- /dev/null
+++ b/aos/common/util/compiler_memory_barrier.h
@@ -0,0 +1,11 @@
+#ifndef AOS_COMMON_UTIL_COMPILER_MEMORY_BARRIER_H_
+#define AOS_COMMON_UTIL_COMPILER_MEMORY_BARRIER_H_
+
+// Prevents the compiler from reordering memory operations around this.
+// Using this function makes it clearer what you're doing and easier to be
+// portable.
+static inline void aos_compiler_memory_barrier(void) {
+ __asm__ __volatile__("" ::: "memory");
+}
+
+#endif // AOS_COMMON_UTIL_COMPILER_MEMORY_BARRIER_H_