add support for thread-local objects with destructors
Change-Id: I19d809fec90d6d2ac4f5d772e3ab5af30025fb2e
diff --git a/aos/linux_code/complex_thread_local.cc b/aos/linux_code/complex_thread_local.cc
new file mode 100644
index 0000000..d57323e
--- /dev/null
+++ b/aos/linux_code/complex_thread_local.cc
@@ -0,0 +1,69 @@
+#include "aos/linux_code/complex_thread_local.h"
+
+#include <pthread.h>
+
+#include "aos/common/once.h"
+#include "aos/common/die.h"
+
+#define SIMPLE_CHECK(call) \
+ do { \
+ const int value = call; \
+ if (value != 0) { \
+ PRDie(value, "%s failed", #call); \
+ } \
+ } while (false)
+
+namespace aos {
+namespace {
+
+// Defined below; CreateKey needs it and it needs key_once.
+void ExecuteDestructorList(void *v);
+
+// Creates the one pthread key whose per-thread value is the head of that
+// thread's ComplexThreadLocalDestructor list. Dies if pthread_key_create fails.
+pthread_key_t *CreateKey() {
+  static pthread_key_t r;
+  SIMPLE_CHECK(pthread_key_create(&r, ExecuteDestructorList));
+  return &r;
+}
+
+// Runs CreateKey the first time the key is needed.
+::aos::Once<pthread_key_t> key_once(CreateKey);
+
+// Registered as the key's destructor, so pthreads calls this at thread exit
+// with the thread's (non-NULL) list head. Runs every entry's function.
+//
+// pthreads sets the key's value to NULL before calling this. The rest of the
+// list is therefore stored back in the key before each entry runs, and the
+// next entry is re-read from the key afterwards. As a result, a function that
+// calls Remove (e.g. via ComplexThreadLocal::Clear) on a later entry sees a
+// consistent list instead of dying. Entries it Adds are run too. The key is
+// left NULL on return, so pthreads does not call this again.
+void ExecuteDestructorList(void *v) {
+  pthread_key_t *const key = key_once.Get();
+  ComplexThreadLocalDestructor *c =
+      static_cast<ComplexThreadLocalDestructor *>(v);
+  while (c != nullptr) {
+    // Unlink c before running it so nothing its function does can see it.
+    SIMPLE_CHECK(pthread_setspecific(*key, c->next));
+    c->function(c->param);
+    c = static_cast<ComplexThreadLocalDestructor *>(pthread_getspecific(*key));
+  }
+}
+
+}  // namespace
+
+// Pushes this entry onto the front of the calling thread's destructor list.
+void ComplexThreadLocalDestructor::Add() {
+  static_assert(
+      ::std::is_pod<ComplexThreadLocalDestructor>::value,
+      "ComplexThreadLocalDestructor might not be safe to pass through void*");
+  const pthread_key_t key = *key_once.Get();
+
+  // Whatever is currently at the head (nullptr if empty) becomes our successor.
+  void *const old_head = pthread_getspecific(key);
+  next = static_cast<ComplexThreadLocalDestructor *>(old_head);
+  SIMPLE_CHECK(pthread_setspecific(key, this));
+}
+
+// Unlinks this entry from the calling thread's destructor list. If this is the
+// head, the key is updated to point at the rest of the list. Dies if this is
+// not in the list.
+void ComplexThreadLocalDestructor::Remove() {
+  pthread_key_t *const key = key_once.Get();
+
+  ComplexThreadLocalDestructor *previous = nullptr;
+  for (ComplexThreadLocalDestructor *c =
+           static_cast<ComplexThreadLocalDestructor *>(
+               pthread_getspecific(*key));
+       c != nullptr; previous = c, c = c->next) {
+    if (c != this) continue;
+
+    if (previous == nullptr) {
+      // It's the first one, so the key itself has to skip over it.
+      SIMPLE_CHECK(pthread_setspecific(*key, next));
+    } else {
+      previous->next = next;
+    }
+    // Don't leave a stale link behind; Add sets it again if this is reused.
+    next = nullptr;
+    return;
+  }
+  // %p requires a void * argument.
+  ::aos::Die("%p is not in the destructor list\n", static_cast<void *>(this));
+}
+
+} // namespace aos
diff --git a/aos/linux_code/complex_thread_local.h b/aos/linux_code/complex_thread_local.h
new file mode 100644
index 0000000..7ada875
--- /dev/null
+++ b/aos/linux_code/complex_thread_local.h
@@ -0,0 +1,131 @@
+#ifndef AOS_LINUX_CODE_COMPLEX_THREAD_LOCAL_H_
+#define AOS_LINUX_CODE_COMPLEX_THREAD_LOCAL_H_
+
+#include <assert.h>
+
+#include <new>
+#include <type_traits>
+#include <utility>
+
+namespace aos {
+
+// Instances form a (per-thread) list of destructor functions to call when the
+// thread exits.
+// Only ComplexThreadLocal should use this.
+//
+// This must remain a POD type. complex_thread_local.cc static_asserts that,
+// because instances are passed through void * (as pthread key values).
+struct ComplexThreadLocalDestructor {
+ // Adds this to the front of the list of destructors in this thread.
+ // NOTE(review): adding an instance that is already in the list would make the
+ // list cyclic; callers must not do that.
+ void Add();
+ // Removes this from the list of destructors in this thread. ::aos::Dies if it
+ // is not there.
+ void Remove();
+
+ // Called as function(param) at thread exit while this is still in the list.
+ void (*function)(void *);
+ void *param;
+
+ // The next (older) entry in this thread's list, or nullptr. Set by Add.
+ ComplexThreadLocalDestructor *next;
+};
+
+// Handles creating a thread-local (per type) object with non-trivial
+// constructor and/or destructor. It will be correctly destroyed on thread exit.
+//
+// Each thread using an instantiation of this class has its own independent slot
+// for storing a T. An instance of T is not actually constructed until a thread
+// calls Create. After that, get() etc return a pointer to it until Clear is
+// called or the object is destroyed at thread exit.
+//
+// All of the state is static, so every ComplexThreadLocal<T> object with the
+// same T refers to the same per-thread slot. For example, all instances of
+// Something below share a single Data in each thread. That Data is built from
+// whichever a_string_ was passed to the first Create call in that thread.
+//
+// Example usage:
+//   class Something {
+//    private:
+//     class Data {
+//      public:
+//       Data(const ::std::string &value) : value_(value) {}
+//
+//       int DoSomething() {
+//         if (cached_result_ == 0) {
+//           // Do something expensive with value_ and store it in
+//           // cached_result_.
+//         }
+//         return cached_result_;
+//       }
+//
+//      private:
+//       const ::std::string value_;
+//       int cached_result_ = 0;
+//     };
+//     ComplexThreadLocal<Data> thread_local_;
+//     ::std::string a_string_;
+//
+//     int DoSomething() {
+//       thread_local_.Create(a_string_);
+//       return thread_local_->DoSomething();
+//     }
+//   };
+//
+// The current implementation is based on
+// <http://stackoverflow.com/questions/12049684/gcc-4-7-on-linux-pthreads-nontrivial-thread-local-workaround-using-thread-n>.
+// TODO(brians): Change this to just simple standard C++ thread_local once all
+// of our compilers have support.
+template <typename T>
+class ComplexThreadLocal {
+ public:
+  // Actually creates the object in this thread if there is not one there
+  // already.
+  // args are all perfectly forwarded to the constructor.
+  template <typename... Args>
+  void Create(Args &&... args) {
+    if (initialized) return;
+    new (&storage) T(::std::forward<Args>(args)...);
+    destructor.function = DestroyAtThreadExit;
+    destructor.param = &storage;
+    destructor.Add();
+    initialized = true;
+  }
+
+  // Removes the object in this thread (if any), including calling its
+  // destructor.
+  void Clear() {
+    if (!initialized) return;
+    destructor.Remove();
+    PlacementDelete(&storage);
+    initialized = false;
+  }
+
+  // Returns true if there is already an object in this thread.
+  bool created() const { return initialized; }
+
+  // Returns the object currently created in this thread or nullptr.
+  T *operator->() const {
+    return get();
+  }
+  T *get() const {
+    if (initialized) {
+      return static_cast<T *>(static_cast<void *>(&storage));
+    } else {
+      return nullptr;
+    }
+  }
+
+ private:
+  typedef typename ::std::aligned_storage<
+      sizeof(T), ::std::alignment_of<T>::value>::type Storage;
+
+  // Convenient helper for calling a destructor.
+  static void PlacementDelete(void *t) { static_cast<T *>(t)->~T(); }
+
+  // Registered with destructor so it runs when the thread exits while the
+  // object still exists. Unlike a bare PlacementDelete, this also marks the
+  // slot empty, mirroring Clear. Code that runs later during thread exit (such
+  // as another object's destructor) then sees created() == false and
+  // get() == nullptr, not a pointer to a destroyed T.
+  static void DestroyAtThreadExit(void *t) {
+    PlacementDelete(t);
+    initialized = false;
+  }
+
+  // True iff this storage has been initialized.
+  static __thread bool initialized;
+  // Where we actually store the object for this thread (if any).
+  static __thread Storage storage;
+  // The linked list element representing this storage.
+  static __thread ComplexThreadLocalDestructor destructor;
+};
+
+template <typename T>
+__thread bool ComplexThreadLocal<T>::initialized;
+template <typename T>
+__thread typename ComplexThreadLocal<T>::Storage ComplexThreadLocal<T>::storage;
+template <typename T>
+__thread ComplexThreadLocalDestructor ComplexThreadLocal<T>::destructor;
+
+} // namespace aos
+
+#endif // AOS_LINUX_CODE_COMPLEX_THREAD_LOCAL_H_
diff --git a/aos/linux_code/complex_thread_local_test.cc b/aos/linux_code/complex_thread_local_test.cc
new file mode 100644
index 0000000..97f0568
--- /dev/null
+++ b/aos/linux_code/complex_thread_local_test.cc
@@ -0,0 +1,101 @@
+#include "aos/linux_code/complex_thread_local.h"
+
+#include <atomic>
+
+#include "gtest/gtest.h"
+
+#include "aos/common/util/thread.h"
+
+namespace aos {
+namespace testing {
+
+class ComplexThreadLocalTest : public ::testing::Test {
+ protected:
+  // Counts constructions and destructions (atomically, so other threads may
+  // create and destroy instances too).
+  struct TraceableObject {
+    explicit TraceableObject(int data = 0) : data(data) { ++constructions; }
+    ~TraceableObject() { ++destructions; }
+
+    static ::std::atomic<int> constructions, destructions;
+
+    int data;
+  };
+  ComplexThreadLocal<TraceableObject> local;
+
+ private:
+  // Destroys anything a previous test left in this thread's slot, checks that
+  // the counters balance, and then resets them for this test.
+  void SetUp() override {
+    local.Clear();
+    EXPECT_EQ(TraceableObject::constructions, TraceableObject::destructions)
+        << "There should be no way to create and destroy different numbers.";
+    TraceableObject::constructions = TraceableObject::destructions = 0;
+  }
+};
+::std::atomic<int> ComplexThreadLocalTest::TraceableObject::constructions;
+::std::atomic<int> ComplexThreadLocalTest::TraceableObject::destructions;
+
+TEST_F(ComplexThreadLocalTest, Basic) {
+  // Before Create, nothing has been constructed and there is no object.
+  EXPECT_EQ(0, TraceableObject::constructions.load());
+  EXPECT_EQ(0, TraceableObject::destructions.load());
+  EXPECT_FALSE(local.created());
+  EXPECT_TRUE(local.get() == nullptr);
+
+  // Create constructs exactly one object from its arguments.
+  local.Create(971);
+  EXPECT_EQ(1, TraceableObject::constructions.load());
+  EXPECT_EQ(0, TraceableObject::destructions.load());
+  EXPECT_TRUE(local.created());
+  EXPECT_EQ(971, local.get()->data);
+
+  // Creating again while an object exists does nothing; the old value stays.
+  local.Create(254);
+  EXPECT_EQ(1, TraceableObject::constructions.load());
+  EXPECT_EQ(0, TraceableObject::destructions.load());
+  EXPECT_TRUE(local.created());
+  EXPECT_EQ(971, local.get()->data);
+
+  // Clear destroys it and empties the slot.
+  local.Clear();
+  EXPECT_EQ(1, TraceableObject::constructions.load());
+  EXPECT_EQ(1, TraceableObject::destructions.load());
+  EXPECT_FALSE(local.created());
+  EXPECT_TRUE(local.get() == nullptr);
+
+  // After clearing, Create builds a fresh object.
+  local.Create(973);
+  EXPECT_EQ(2, TraceableObject::constructions.load());
+  EXPECT_EQ(1, TraceableObject::destructions.load());
+  EXPECT_TRUE(local.created());
+  EXPECT_EQ(973, local.get()->data);
+}
+
+TEST_F(ComplexThreadLocalTest, AnotherThread) {
+  EXPECT_FALSE(local.created());
+
+  // RunInOtherThread runs this in a new thread. The checks after it rely on
+  // that thread having exited, which must destroy the object it created.
+  auto create_in_other_thread = [this]() {
+    EXPECT_FALSE(local.created());
+    local.Create(971);
+    EXPECT_TRUE(local.created());
+    EXPECT_EQ(971, local.get()->data);
+    EXPECT_EQ(1, TraceableObject::constructions.load());
+    EXPECT_EQ(0, TraceableObject::destructions.load());
+  };
+  util::FunctionThread::RunInOtherThread(create_in_other_thread);
+
+  // The other thread's object is gone, and this thread never had one.
+  EXPECT_EQ(1, TraceableObject::constructions.load());
+  EXPECT_EQ(1, TraceableObject::destructions.load());
+  EXPECT_FALSE(local.created());
+}
+
+TEST_F(ComplexThreadLocalTest, TwoThreads) {
+  // The other thread gets its own object, independent of this thread's.
+  util::FunctionThread thread([this](util::FunctionThread *) {
+    local.Create(971);
+    EXPECT_EQ(971, local.get()->data);
+    EXPECT_EQ(0, TraceableObject::destructions.load());
+  });
+  thread.Start();
+
+  local.Create(973);
+  EXPECT_EQ(973, local.get()->data);
+
+  // Once the other thread is joined, its object has been destroyed but this
+  // thread's is untouched.
+  thread.Join();
+  EXPECT_TRUE(local.created());
+  EXPECT_EQ(2, TraceableObject::constructions.load());
+  EXPECT_EQ(1, TraceableObject::destructions.load());
+
+  // Clearing destroys this thread's object as well.
+  local.Clear();
+  EXPECT_EQ(2, TraceableObject::constructions.load());
+  EXPECT_EQ(2, TraceableObject::destructions.load());
+  EXPECT_FALSE(local.created());
+}
+
+} // namespace testing
+} // namespace aos
diff --git a/aos/linux_code/linux_code.gyp b/aos/linux_code/linux_code.gyp
index 74b788d..c34b42e 100644
--- a/aos/linux_code/linux_code.gyp
+++ b/aos/linux_code/linux_code.gyp
@@ -1,6 +1,30 @@
{
'targets': [
{
+ 'target_name': 'complex_thread_local',
+ 'type': 'static_library',
+ 'sources': [
+ 'complex_thread_local.cc',
+ ],
+ # complex_thread_local.cc uses ::aos::Once (once.h) and PRDie/Die (die.h).
+ 'dependencies': [
+ '<(AOS)/common/common.gyp:once',
+ '<(AOS)/common/common.gyp:die',
+ ],
+ },
+ {
+ 'target_name': 'complex_thread_local_test',
+ 'type': 'executable',
+ 'sources': [
+ 'complex_thread_local_test.cc',
+ ],
+ 'dependencies': [
+ 'complex_thread_local',
+ '<(EXTERNALS):gtest',
+ # The tests use util::FunctionThread to run code in other threads.
+ '<(AOS)/common/util/util.gyp:thread',
+ # NOTE(review): the test does not include any logging header itself;
+ # presumably something it links against needs this — confirm.
+ '<(AOS)/build/aos.gyp:logging',
+ ],
+ },
+ {
'target_name': 'init',
'type': 'static_library',
'sources': [