blob: 2b1bf872dc9f27a5799d0228cebec3129b7f8ef7 [file] [log] [blame]
Austin Schuh70cc9552019-01-21 19:46:48 -08001// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2018 Google Inc. All rights reserved.
3// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: vitus@google.com (Michael Vitus)
30
31// This include must come before any #ifndef check on Ceres compile options.
32#include "ceres/internal/port.h"
33
34#ifdef CERES_USE_CXX11_THREADS
35
36#include "ceres/thread_pool.h"
37
38#include <chrono>
39#include <condition_variable>
40#include <mutex>
41#include <thread>
42
43#include "gmock/gmock.h"
44#include "gtest/gtest.h"
45#include "glog/logging.h"
46
47namespace ceres {
48namespace internal {
49
50// Adds a number of tasks to the thread pool and ensures they all run.
51TEST(ThreadPool, AddTask) {
52 int value = 0;
53 const int num_tasks = 100;
54 {
55 ThreadPool thread_pool(2);
56
57 std::condition_variable condition;
58 std::mutex mutex;
59
60 for (int i = 0; i < num_tasks; ++i) {
61 thread_pool.AddTask([&]() {
62 std::lock_guard<std::mutex> lock(mutex);
63 ++value;
64 condition.notify_all();
65 });
66 }
67
68 std::unique_lock<std::mutex> lock(mutex);
69 condition.wait(lock, [&](){return value == num_tasks;});
70 }
71
72 EXPECT_EQ(num_tasks, value);
73}
74
75// Adds a number of tasks to the queue and resizes the thread pool while the
76// threads are executing their work.
77TEST(ThreadPool, ResizingDuringExecution) {
78 int value = 0;
79
80 const int num_tasks = 100;
81
82 // Run this test in a scope to delete the thread pool and all of the threads
83 // are stopped.
84 {
85 ThreadPool thread_pool(/*num_threads=*/2);
86
87 std::condition_variable condition;
88 std::mutex mutex;
89
90 // Acquire a lock on the mutex to prevent the threads from finishing their
91 // execution so we can test resizing the thread pool while the workers are
92 // executing a task.
93 std::unique_lock<std::mutex> lock(mutex);
94
95 // The same task for all of the workers to execute.
96 auto task = [&]() {
97 // This will block until the mutex is released inside the condition
98 // variable.
99 std::lock_guard<std::mutex> lock(mutex);
100 ++value;
101 condition.notify_all();
102 };
103
104 // Add the initial set of tasks to run.
105 for (int i = 0; i < num_tasks / 2; ++i) {
106 thread_pool.AddTask(task);
107 }
108
109 // Resize the thread pool while tasks are executing.
110 thread_pool.Resize(/*num_threads=*/3);
111
112 // Add more tasks to the thread pool to guarantee these are also completed.
113 for (int i = 0; i < num_tasks / 2; ++i) {
114 thread_pool.AddTask(task);
115 }
116
117 // Unlock the mutex to unblock all of the threads and wait until all of the
118 // tasks are completed.
119 condition.wait(lock, [&](){return value == num_tasks;});
120 }
121
122 EXPECT_EQ(num_tasks, value);
123}
124
125// Tests the destructor will wait until all running tasks are finished before
126// destructing the thread pool.
127TEST(ThreadPool, Destructor) {
128 // Ensure the hardware supports more than 1 thread to ensure the test will
129 // pass.
130 const int num_hardware_threads = std::thread::hardware_concurrency();
131 if (num_hardware_threads <= 1) {
132 LOG(ERROR)
133 << "Test not supported, the hardware does not support threading.";
134 return;
135 }
136
137 std::condition_variable condition;
138 std::mutex mutex;
139 // Lock the mutex to ensure the tasks are blocked.
140 std::unique_lock<std::mutex> master_lock(mutex);
141 int value = 0;
142
143 // Create a thread that will instantiate and delete the thread pool. This is
144 // required because we need to block on the thread pool being deleted and
145 // signal the tasks to finish.
146 std::thread thread([&]() {
147 ThreadPool thread_pool(/*num_threads=*/2);
148
149 for (int i = 0; i < 100; ++i) {
150 thread_pool.AddTask([&]() {
151 // This will block until the mutex is released inside the condition
152 // variable.
153 std::lock_guard<std::mutex> lock(mutex);
154 ++value;
155 condition.notify_all();
156 });
157 }
158 // The thread pool should be deleted.
159 });
160
161 // Give the thread pool time to start, add all the tasks, and then delete
162 // itself.
163 std::this_thread::sleep_for(std::chrono::milliseconds(500));
164
165 // Unlock the tasks.
166 master_lock.unlock();
167
168 // Wait for the thread to complete.
169 thread.join();
170
171 EXPECT_EQ(100, value);
172}
173
174TEST(ThreadPool, Resize) {
175 // Ensure the hardware supports more than 1 thread to ensure the test will
176 // pass.
177 const int num_hardware_threads = std::thread::hardware_concurrency();
178 if (num_hardware_threads <= 1) {
179 LOG(ERROR)
180 << "Test not supported, the hardware does not support threading.";
181 return;
182 }
183
184 ThreadPool thread_pool(1);
185
186 EXPECT_EQ(1, thread_pool.Size());
187
188 thread_pool.Resize(2);
189
190 EXPECT_EQ(2, thread_pool.Size());
191
192 // Try reducing the thread pool size and verify it stays the same size.
193 thread_pool.Resize(1);
194 EXPECT_EQ(2, thread_pool.Size());
195}
196
197} // namespace internal
198} // namespace ceres
199
200#endif // CERES_USE_CXX11_THREADS