blob: 2b8fcb2aaa488133b447b5382f6883c39bb86162 [file] [log] [blame]
Brian Silvermanf7bd1c22015-12-24 16:07:11 -08001/*----------------------------------------------------------------------------*/
2/* Copyright (c) FIRST 2015. All Rights Reserved. */
3/* Open Source Software - may be modified and shared by FRC teams. The code */
4/* must be accompanied by the FIRST BSD license file in the root directory of */
5/* the project. */
6/*----------------------------------------------------------------------------*/
7
#include "Dispatcher.h"

#include <algorithm>
#include <cmath>
#include <iterator>

#include "tcpsockets/TCPAcceptor.h"
#include "tcpsockets/TCPConnector.h"
#include "Log.h"
16
17using namespace nt;
18
// NOTE(review): presumably expands to the static storage/initialization that
// backs Dispatcher's singleton accessor — confirm against the
// ATOMIC_STATIC_INIT macro definition.
ATOMIC_STATIC_INIT(Dispatcher)
20
21void Dispatcher::StartServer(StringRef persist_filename,
22 const char* listen_address, unsigned int port) {
23 DispatcherBase::StartServer(persist_filename,
24 std::unique_ptr<NetworkAcceptor>(new TCPAcceptor(
25 static_cast<int>(port), listen_address)));
26}
27
28void Dispatcher::StartClient(const char* server_name, unsigned int port) {
29 std::string server_name_copy(server_name);
30 DispatcherBase::StartClient([=]() -> std::unique_ptr<NetworkStream> {
31 return TCPConnector::connect(server_name_copy.c_str(),
32 static_cast<int>(port), 1);
33 });
34}
35
36Dispatcher::Dispatcher()
37 : Dispatcher(Storage::GetInstance(), Notifier::GetInstance()) {}
38
39DispatcherBase::DispatcherBase(Storage& storage, Notifier& notifier)
40 : m_storage(storage), m_notifier(notifier) {
41 m_active = false;
42 m_update_rate = 100;
43}
44
45DispatcherBase::~DispatcherBase() {
46 Logger::GetInstance().SetLogger(nullptr);
47 Stop();
48}
49
50void DispatcherBase::StartServer(StringRef persist_filename,
51 std::unique_ptr<NetworkAcceptor> acceptor) {
52 {
53 std::lock_guard<std::mutex> lock(m_user_mutex);
54 if (m_active) return;
55 m_active = true;
56 }
57 m_server = true;
58 m_persist_filename = persist_filename;
59 m_server_acceptor = std::move(acceptor);
60
61 // Load persistent file. Ignore errors, but pass along warnings.
62 if (!persist_filename.empty()) {
63 bool first = true;
64 m_storage.LoadPersistent(
65 persist_filename, [&](std::size_t line, const char* msg) {
66 if (first) {
67 first = false;
68 WARNING("When reading initial persistent values from '"
69 << persist_filename << "':");
70 }
71 WARNING(persist_filename << ":" << line << ": " << msg);
72 });
73 }
74
75 using namespace std::placeholders;
76 m_storage.SetOutgoing(std::bind(&Dispatcher::QueueOutgoing, this, _1, _2, _3),
77 m_server);
78
79 m_dispatch_thread = std::thread(&Dispatcher::DispatchThreadMain, this);
80 m_clientserver_thread = std::thread(&Dispatcher::ServerThreadMain, this);
81}
82
83void DispatcherBase::StartClient(
84 std::function<std::unique_ptr<NetworkStream>()> connect) {
85 {
86 std::lock_guard<std::mutex> lock(m_user_mutex);
87 if (m_active) return;
88 m_active = true;
89 }
90 m_server = false;
91
92 using namespace std::placeholders;
93 m_storage.SetOutgoing(std::bind(&Dispatcher::QueueOutgoing, this, _1, _2, _3),
94 m_server);
95
96 m_dispatch_thread = std::thread(&Dispatcher::DispatchThreadMain, this);
97 m_clientserver_thread =
98 std::thread(&Dispatcher::ClientThreadMain, this, connect);
99}
100
101void DispatcherBase::Stop() {
102 m_active = false;
103
104 // wake up dispatch thread with a flush
105 m_flush_cv.notify_one();
106
107 // wake up client thread with a reconnect
108 ClientReconnect();
109
110 // wake up server thread by shutting down the socket
111 if (m_server_acceptor) m_server_acceptor->shutdown();
112
113 // join threads, with timeout
114 if (m_dispatch_thread.joinable()) m_dispatch_thread.join();
115 if (m_clientserver_thread.joinable()) m_clientserver_thread.join();
116
117 std::vector<std::shared_ptr<NetworkConnection>> conns;
118 {
119 std::lock_guard<std::mutex> lock(m_user_mutex);
120 conns.swap(m_connections);
121 }
122
123 // close all connections
124 conns.resize(0);
125}
126
127void DispatcherBase::SetUpdateRate(double interval) {
128 // don't allow update rates faster than 100 ms or slower than 1 second
129 if (interval < 0.1)
130 interval = 0.1;
131 else if (interval > 1.0)
132 interval = 1.0;
133 m_update_rate = static_cast<unsigned int>(interval * 1000);
134}
135
136void DispatcherBase::SetIdentity(llvm::StringRef name) {
137 std::lock_guard<std::mutex> lock(m_user_mutex);
138 m_identity = name;
139}
140
141void DispatcherBase::Flush() {
142 auto now = std::chrono::steady_clock::now();
143 {
144 std::lock_guard<std::mutex> lock(m_flush_mutex);
145 // don't allow flushes more often than every 100 ms
146 if ((now - m_last_flush) < std::chrono::milliseconds(100))
147 return;
148 m_last_flush = now;
149 m_do_flush = true;
150 }
151 m_flush_cv.notify_one();
152}
153
154std::vector<ConnectionInfo> DispatcherBase::GetConnections() const {
155 std::vector<ConnectionInfo> conns;
156 if (!m_active) return conns;
157
158 std::lock_guard<std::mutex> lock(m_user_mutex);
159 for (auto& conn : m_connections) {
160 if (conn->state() != NetworkConnection::kActive) continue;
161 conns.emplace_back(conn->info());
162 }
163
164 return conns;
165}
166
167void DispatcherBase::NotifyConnections(
168 ConnectionListenerCallback callback) const {
169 std::lock_guard<std::mutex> lock(m_user_mutex);
170 for (auto& conn : m_connections) {
171 if (conn->state() != NetworkConnection::kActive) continue;
172 m_notifier.NotifyConnection(true, conn->info(), callback);
173 }
174}
175
176void DispatcherBase::DispatchThreadMain() {
177 auto timeout_time = std::chrono::steady_clock::now();
178
179 static const auto save_delta_time = std::chrono::seconds(1);
180 auto next_save_time = timeout_time + save_delta_time;
181
182 int count = 0;
183
184 std::unique_lock<std::mutex> flush_lock(m_flush_mutex);
185 while (m_active) {
186 // handle loop taking too long
187 auto start = std::chrono::steady_clock::now();
188 if (start > timeout_time)
189 timeout_time = start;
190
191 // wait for periodic or when flushed
192 timeout_time += std::chrono::milliseconds(m_update_rate);
193 m_flush_cv.wait_until(flush_lock, timeout_time,
194 [&] { return !m_active || m_do_flush; });
195 m_do_flush = false;
196 if (!m_active) break; // in case we were woken up to terminate
197
198 // perform periodic persistent save
199 if (m_server && !m_persist_filename.empty() && start > next_save_time) {
200 next_save_time += save_delta_time;
201 // handle loop taking too long
202 if (start > next_save_time) next_save_time = start + save_delta_time;
203 const char* err = m_storage.SavePersistent(m_persist_filename, true);
204 if (err) WARNING("periodic persistent save: " << err);
205 }
206
207 {
208 std::lock_guard<std::mutex> user_lock(m_user_mutex);
209 bool reconnect = false;
210
211 if (++count > 10) {
212 DEBUG("dispatch running " << m_connections.size() << " connections");
213 count = 0;
214 }
215
216 for (auto& conn : m_connections) {
217 // post outgoing messages if connection is active
218 // only send keep-alives on client
219 if (conn->state() == NetworkConnection::kActive)
220 conn->PostOutgoing(!m_server);
221
222 // if client, reconnect if connection died
223 if (!m_server && conn->state() == NetworkConnection::kDead)
224 reconnect = true;
225 }
226 // reconnect if we disconnected (and a reconnect is not in progress)
227 if (reconnect && !m_do_reconnect) {
228 m_do_reconnect = true;
229 m_reconnect_cv.notify_one();
230 }
231 }
232 }
233}
234
235void DispatcherBase::QueueOutgoing(std::shared_ptr<Message> msg,
236 NetworkConnection* only,
237 NetworkConnection* except) {
238 std::lock_guard<std::mutex> user_lock(m_user_mutex);
239 for (auto& conn : m_connections) {
240 if (conn.get() == except) continue;
241 if (only && conn.get() != only) continue;
242 auto state = conn->state();
243 if (state != NetworkConnection::kSynchronized &&
244 state != NetworkConnection::kActive) continue;
245 conn->QueueOutgoing(msg);
246 }
247}
248
249void DispatcherBase::ServerThreadMain() {
250 if (m_server_acceptor->start() != 0) {
251 m_active = false;
252 return;
253 }
254 while (m_active) {
255 auto stream = m_server_acceptor->accept();
256 if (!stream) {
257 m_active = false;
258 return;
259 }
260 if (!m_active) return;
261 DEBUG("server: client connection from " << stream->getPeerIP() << " port "
262 << stream->getPeerPort());
263
264 // add to connections list
265 using namespace std::placeholders;
266 auto conn = std::make_shared<NetworkConnection>(
267 std::move(stream), m_notifier,
268 std::bind(&Dispatcher::ServerHandshake, this, _1, _2, _3),
269 std::bind(&Storage::GetEntryType, &m_storage, _1));
270 conn->set_process_incoming(
271 std::bind(&Storage::ProcessIncoming, &m_storage, _1, _2,
272 std::weak_ptr<NetworkConnection>(conn)));
273 {
274 std::lock_guard<std::mutex> lock(m_user_mutex);
275 // reuse dead connection slots
276 bool placed = false;
277 for (auto& c : m_connections) {
278 if (c->state() == NetworkConnection::kDead) {
279 c = conn;
280 placed = true;
281 break;
282 }
283 }
284 if (!placed) m_connections.emplace_back(conn);
285 conn->Start();
286 }
287 }
288}
289
// Client-mode connection thread. While the dispatcher is active:
//  1. Sleep 500 ms, then call `connect`; retry on failure.
//  2. On success, replace any existing connection with a new
//     NetworkConnection that runs ClientHandshake.
//  3. Block on m_reconnect_cv until a reconnect is requested
//     (m_do_reconnect) or the dispatcher is stopped.
void DispatcherBase::ClientThreadMain(
    std::function<std::unique_ptr<NetworkStream>()> connect) {
  while (m_active) {
    // sleep between retries
    std::this_thread::sleep_for(std::chrono::milliseconds(500));

    // try to connect (with timeout)
    DEBUG("client trying to connect");
    auto stream = connect();
    if (!stream) continue;  // keep retrying
    DEBUG("client connected");

    // m_user_mutex is held from here until the wait below releases it; it is
    // re-acquired on wakeup and released at the end of this iteration.
    std::unique_lock<std::mutex> lock(m_user_mutex);
    using namespace std::placeholders;
    auto conn = std::make_shared<NetworkConnection>(
        std::move(stream), m_notifier,
        std::bind(&Dispatcher::ClientHandshake, this, _1, _2, _3),
        std::bind(&Storage::GetEntryType, &m_storage, _1));
    conn->set_process_incoming(
        std::bind(&Storage::ProcessIncoming, &m_storage, _1, _2,
                  std::weak_ptr<NetworkConnection>(conn)));
    // NOTE(review): the previous connection is released here while
    // m_user_mutex is held. If NetworkConnection's destructor waits for its
    // own threads and one of them needs m_user_mutex (ClientHandshake and
    // QueueOutgoing both lock it), this could deadlock — confirm against
    // NetworkConnection.
    m_connections.resize(0);  // disconnect any current
    m_connections.emplace_back(conn);
    // Protocol revision to request; ClientReconnect() updates this value.
    conn->set_proto_rev(m_reconnect_proto_rev);
    conn->Start();

    // block until told to reconnect
    m_do_reconnect = false;
    m_reconnect_cv.wait(lock, [&] { return !m_active || m_do_reconnect; });
  }
}
321
// Client side of the connection handshake.
//  - Sends ClientHello carrying this dispatcher's identity, then reads the
//    server's reply (a null message is treated as a disconnect).
//  - On ProtoUnsup it fails; if the server reported 0x0200, it first asks
//    the client thread to reconnect using protocol 0x0200.
//  - For proto_rev >= 0x0300 the reply must be ServerHello. This records the
//    server's id, and bit 0 of its flags clears `new_server`.
//  - It then collects EntryAssign messages until ServerHelloDone and applies
//    them through Storage::ApplyInitialAssignments.
//  - Any resulting outgoing messages are sent, plus ClientHelloDone for
//    proto_rev >= 0x0300.
// Returns true when connected; false on disconnect or an unexpected message.
bool DispatcherBase::ClientHandshake(
    NetworkConnection& conn,
    std::function<std::shared_ptr<Message>()> get_msg,
    std::function<void(llvm::ArrayRef<std::shared_ptr<Message>>)> send_msgs) {
  // get identity
  std::string self_id;
  {
    std::lock_guard<std::mutex> lock(m_user_mutex);
    self_id = m_identity;
  }

  // send client hello
  DEBUG("client: sending hello");
  send_msgs(Message::ClientHello(self_id));

  // wait for response
  auto msg = get_msg();
  if (!msg) {
    // disconnected, retry
    DEBUG("client: server disconnected before first response");
    return false;
  }

  // Server rejected our protocol revision; fall back to 0x0200 if offered.
  if (msg->Is(Message::kProtoUnsup)) {
    if (msg->id() == 0x0200) ClientReconnect(0x0200);
    return false;
  }

  bool new_server = true;
  if (conn.proto_rev() >= 0x0300) {
    // should be server hello; if not, disconnect.
    if (!msg->Is(Message::kServerHello)) return false;
    conn.set_remote_id(msg->str());
    if ((msg->flags() & 1) != 0) new_server = false;
    // get the next message
    msg = get_msg();
  }

  // receive initial assignments
  std::vector<std::shared_ptr<Message>> incoming;
  for (;;) {
    if (!msg) {
      // disconnected, retry
      DEBUG("client: server disconnected during initial entries");
      return false;
    }
    DEBUG4("received init str=" << msg->str() << " id=" << msg->id()
           << " seq_num=" << msg->seq_num_uid());
    if (msg->Is(Message::kServerHelloDone)) break;
    if (!msg->Is(Message::kEntryAssign)) {
      // unexpected message
      DEBUG("client: received message (" << msg->type() << ") other than entry assignment during initial handshake");
      return false;
    }
    incoming.emplace_back(std::move(msg));
    // get the next message
    msg = get_msg();
  }

  // generate outgoing assignments
  NetworkConnection::Outgoing outgoing;

  m_storage.ApplyInitialAssignments(conn, incoming, new_server, &outgoing);

  // 3.0+ servers expect the client to close its side of the handshake.
  if (conn.proto_rev() >= 0x0300)
    outgoing.emplace_back(Message::ClientHelloDone());

  if (!outgoing.empty()) send_msgs(outgoing);

  INFO("client: CONNECTED to server " << conn.stream().getPeerIP() << " port "
       << conn.stream().getPeerPort());
  return true;
}
395
396bool DispatcherBase::ServerHandshake(
397 NetworkConnection& conn,
398 std::function<std::shared_ptr<Message>()> get_msg,
399 std::function<void(llvm::ArrayRef<std::shared_ptr<Message>>)> send_msgs) {
400 // Wait for the client to send us a hello.
401 auto msg = get_msg();
402 if (!msg) {
403 DEBUG("server: client disconnected before sending hello");
404 return false;
405 }
406 if (!msg->Is(Message::kClientHello)) {
407 DEBUG("server: client initial message was not client hello");
408 return false;
409 }
410
411 // Check that the client requested version is not too high.
412 unsigned int proto_rev = msg->id();
413 if (proto_rev > 0x0300) {
414 DEBUG("server: client requested proto > 0x0300");
415 send_msgs(Message::ProtoUnsup());
416 return false;
417 }
418
419 if (proto_rev >= 0x0300) conn.set_remote_id(msg->str());
420
421 // Set the proto version to the client requested version
422 DEBUG("server: client protocol " << proto_rev);
423 conn.set_proto_rev(proto_rev);
424
425 // Send initial set of assignments
426 NetworkConnection::Outgoing outgoing;
427
428 // Start with server hello. TODO: initial connection flag
429 if (proto_rev >= 0x0300) {
430 std::lock_guard<std::mutex> lock(m_user_mutex);
431 outgoing.emplace_back(Message::ServerHello(0u, m_identity));
432 }
433
434 // Get snapshot of initial assignments
435 m_storage.GetInitialAssignments(conn, &outgoing);
436
437 // Finish with server hello done
438 outgoing.emplace_back(Message::ServerHelloDone());
439
440 // Batch transmit
441 DEBUG("server: sending initial assignments");
442 send_msgs(outgoing);
443
444 // In proto rev 3.0 and later, the handshake concludes with a client hello
445 // done message, so we can batch the assigns before marking the connection
446 // active. In pre-3.0, we need to just immediately mark it active and hand
447 // off control to the dispatcher to assign them as they arrive.
448 if (proto_rev >= 0x0300) {
449 // receive client initial assignments
450 std::vector<std::shared_ptr<Message>> incoming;
451 msg = get_msg();
452 for (;;) {
453 if (!msg) {
454 // disconnected, retry
455 DEBUG("server: disconnected waiting for initial entries");
456 return false;
457 }
458 if (msg->Is(Message::kClientHelloDone)) break;
459 if (!msg->Is(Message::kEntryAssign)) {
460 // unexpected message
461 DEBUG("server: received message ("
462 << msg->type()
463 << ") other than entry assignment during initial handshake");
464 return false;
465 }
466 incoming.push_back(msg);
467 // get the next message (blocks)
468 msg = get_msg();
469 }
470 for (auto& msg : incoming)
471 m_storage.ProcessIncoming(msg, &conn, std::weak_ptr<NetworkConnection>());
472 }
473
474 INFO("server: client CONNECTED: " << conn.stream().getPeerIP() << " port "
475 << conn.stream().getPeerPort());
476 return true;
477}
478
479void DispatcherBase::ClientReconnect(unsigned int proto_rev) {
480 if (m_server) return;
481 {
482 std::lock_guard<std::mutex> lock(m_user_mutex);
483 m_reconnect_proto_rev = proto_rev;
484 m_do_reconnect = true;
485 }
486 m_reconnect_cv.notify_one();
487}