blob: e2903e0a5eacbbede0e09b7de830e322428ba28c [file] [log] [blame]
Brian Silvermanf7bd1c22015-12-24 16:07:11 -08001/*----------------------------------------------------------------------------*/
2/* Copyright (c) FIRST 2015. All Rights Reserved. */
3/* Open Source Software - may be modified and shared by FRC teams. The code */
4/* must be accompanied by the FIRST BSD license file in the root directory of */
5/* the project. */
6/*----------------------------------------------------------------------------*/
7
#include "Storage.h"

#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <tuple>

#include "llvm/StringExtras.h"
#include "Base64.h"
#include "Log.h"
#include "NetworkConnection.h"
18
19using namespace nt;
20
21ATOMIC_STATIC_INIT(Storage)
22
23Storage::Storage()
24 : Storage(Notifier::GetInstance(), RpcServer::GetInstance()) {}
25
26Storage::Storage(Notifier& notifier, RpcServer& rpc_server)
27 : m_notifier(notifier), m_rpc_server(rpc_server) {
28 m_terminating = false;
29}
30
31Storage::~Storage() {
32 Logger::GetInstance().SetLogger(nullptr);
33 m_terminating = true;
34 m_rpc_results_cond.notify_all();
35}
36
37void Storage::SetOutgoing(QueueOutgoingFunc queue_outgoing, bool server) {
38 std::lock_guard<std::mutex> lock(m_mutex);
39 m_queue_outgoing = queue_outgoing;
40 m_server = server;
41}
42
43void Storage::ClearOutgoing() {
44 m_queue_outgoing = nullptr;
45}
46
47NT_Type Storage::GetEntryType(unsigned int id) const {
48 std::lock_guard<std::mutex> lock(m_mutex);
49 if (id >= m_idmap.size()) return NT_UNASSIGNED;
50 Entry* entry = m_idmap[id];
51 if (!entry || !entry->value) return NT_UNASSIGNED;
52 return entry->value->type();
53}
54
55void Storage::ProcessIncoming(std::shared_ptr<Message> msg,
56 NetworkConnection* conn,
57 std::weak_ptr<NetworkConnection> conn_weak) {
58 std::unique_lock<std::mutex> lock(m_mutex);
59 switch (msg->type()) {
60 case Message::kKeepAlive:
61 break; // ignore
62 case Message::kClientHello:
63 case Message::kProtoUnsup:
64 case Message::kServerHelloDone:
65 case Message::kServerHello:
66 case Message::kClientHelloDone:
67 // shouldn't get these, but ignore if we do
68 break;
69 case Message::kEntryAssign: {
70 unsigned int id = msg->id();
71 StringRef name = msg->str();
72 Entry* entry;
73 bool may_need_update = false;
74 if (m_server) {
75 // if we're a server, id=0xffff requests are requests for an id
76 // to be assigned, and we need to send the new assignment back to
77 // the sender as well as all other connections.
78 if (id == 0xffff) {
79 // see if it was already assigned; ignore if so.
80 if (m_entries.count(name) != 0) return;
81
82 // create it locally
83 id = m_idmap.size();
84 auto& new_entry = m_entries[name];
85 if (!new_entry) new_entry.reset(new Entry(name));
86 entry = new_entry.get();
87 entry->value = msg->value();
88 entry->flags = msg->flags();
89 entry->id = id;
90 m_idmap.push_back(entry);
91
92 // update persistent dirty flag if it's persistent
93 if (entry->IsPersistent()) m_persistent_dirty = true;
94
95 // notify
96 m_notifier.NotifyEntry(name, entry->value, NT_NOTIFY_NEW);
97
98 // send the assignment to everyone (including the originator)
99 if (m_queue_outgoing) {
100 auto queue_outgoing = m_queue_outgoing;
101 auto outmsg = Message::EntryAssign(
102 name, id, entry->seq_num.value(), msg->value(), msg->flags());
103 lock.unlock();
104 queue_outgoing(outmsg, nullptr, nullptr);
105 }
106 return;
107 }
108 if (id >= m_idmap.size() || !m_idmap[id]) {
109 // ignore arbitrary entry assignments
110 // this can happen due to e.g. assignment to deleted entry
111 lock.unlock();
112 DEBUG("server: received assignment to unknown entry");
113 return;
114 }
115 entry = m_idmap[id];
116 } else {
117 // clients simply accept new assignments
118 if (id == 0xffff) {
119 lock.unlock();
120 DEBUG("client: received entry assignment request?");
121 return;
122 }
123 if (id >= m_idmap.size()) m_idmap.resize(id+1);
124 entry = m_idmap[id];
125 if (!entry) {
126 // create local
127 auto& new_entry = m_entries[name];
128 if (!new_entry) {
129 // didn't exist at all (rather than just being a response to a
130 // id assignment request)
131 new_entry.reset(new Entry(name));
132 new_entry->value = msg->value();
133 new_entry->flags = msg->flags();
134 new_entry->id = id;
135 m_idmap[id] = new_entry.get();
136
137 // notify
138 m_notifier.NotifyEntry(name, new_entry->value, NT_NOTIFY_NEW);
139 return;
140 }
141 may_need_update = true; // we may need to send an update message
142 entry = new_entry.get();
143 entry->id = id;
144 m_idmap[id] = entry;
145
146 // if the received flags don't match what we sent, we most likely
147 // updated flags locally in the interim; send flags update message.
148 if (msg->flags() != entry->flags) {
149 auto queue_outgoing = m_queue_outgoing;
150 auto outmsg = Message::FlagsUpdate(id, entry->flags);
151 lock.unlock();
152 queue_outgoing(outmsg, nullptr, nullptr);
153 lock.lock();
154 }
155 }
156 }
157
158 // common client and server handling
159
160 // already exists; ignore if sequence number not higher than local
161 SequenceNumber seq_num(msg->seq_num_uid());
162 if (seq_num < entry->seq_num) {
163 if (may_need_update) {
164 auto queue_outgoing = m_queue_outgoing;
165 auto outmsg = Message::EntryUpdate(entry->id, entry->seq_num.value(),
166 entry->value);
167 lock.unlock();
168 queue_outgoing(outmsg, nullptr, nullptr);
169 }
170 return;
171 }
172
173 // sanity check: name should match id
174 if (msg->str() != entry->name) {
175 lock.unlock();
176 DEBUG("entry assignment for same id with different name?");
177 return;
178 }
179
180 unsigned int notify_flags = NT_NOTIFY_UPDATE;
181
182 // don't update flags from a <3.0 remote (not part of message)
183 // don't update flags if this is a server response to a client id request
184 if (!may_need_update && conn->proto_rev() >= 0x0300) {
185 // update persistent dirty flag if persistent flag changed
186 if ((entry->flags & NT_PERSISTENT) != (msg->flags() & NT_PERSISTENT))
187 m_persistent_dirty = true;
188 if (entry->flags != msg->flags())
189 notify_flags |= NT_NOTIFY_FLAGS;
190 entry->flags = msg->flags();
191 }
192
193 // update persistent dirty flag if the value changed and it's persistent
194 if (entry->IsPersistent() && *entry->value != *msg->value())
195 m_persistent_dirty = true;
196
197 // update local
198 entry->value = msg->value();
199 entry->seq_num = seq_num;
200
201 // notify
202 m_notifier.NotifyEntry(name, entry->value, notify_flags);
203
204 // broadcast to all other connections (note for client there won't
205 // be any other connections, so don't bother)
206 if (m_server && m_queue_outgoing) {
207 auto queue_outgoing = m_queue_outgoing;
208 auto outmsg =
209 Message::EntryAssign(entry->name, id, msg->seq_num_uid(),
210 msg->value(), entry->flags);
211 lock.unlock();
212 queue_outgoing(outmsg, nullptr, conn);
213 }
214 break;
215 }
216 case Message::kEntryUpdate: {
217 unsigned int id = msg->id();
218 if (id >= m_idmap.size() || !m_idmap[id]) {
219 // ignore arbitrary entry updates;
220 // this can happen due to deleted entries
221 lock.unlock();
222 DEBUG("received update to unknown entry");
223 return;
224 }
225 Entry* entry = m_idmap[id];
226
227 // ignore if sequence number not higher than local
228 SequenceNumber seq_num(msg->seq_num_uid());
229 if (seq_num <= entry->seq_num) return;
230
231 // update local
232 entry->value = msg->value();
233 entry->seq_num = seq_num;
234
235 // update persistent dirty flag if it's a persistent value
236 if (entry->IsPersistent()) m_persistent_dirty = true;
237
238 // notify
239 m_notifier.NotifyEntry(entry->name, entry->value, NT_NOTIFY_UPDATE);
240
241 // broadcast to all other connections (note for client there won't
242 // be any other connections, so don't bother)
243 if (m_server && m_queue_outgoing) {
244 auto queue_outgoing = m_queue_outgoing;
245 lock.unlock();
246 queue_outgoing(msg, nullptr, conn);
247 }
248 break;
249 }
250 case Message::kFlagsUpdate: {
251 unsigned int id = msg->id();
252 if (id >= m_idmap.size() || !m_idmap[id]) {
253 // ignore arbitrary entry updates;
254 // this can happen due to deleted entries
255 lock.unlock();
256 DEBUG("received flags update to unknown entry");
257 return;
258 }
259 Entry* entry = m_idmap[id];
260
261 // ignore if flags didn't actually change
262 if (entry->flags == msg->flags()) return;
263
264 // update persistent dirty flag if persistent flag changed
265 if ((entry->flags & NT_PERSISTENT) != (msg->flags() & NT_PERSISTENT))
266 m_persistent_dirty = true;
267
268 // update local
269 entry->flags = msg->flags();
270
271 // notify
272 m_notifier.NotifyEntry(entry->name, entry->value, NT_NOTIFY_FLAGS);
273
274 // broadcast to all other connections (note for client there won't
275 // be any other connections, so don't bother)
276 if (m_server && m_queue_outgoing) {
277 auto queue_outgoing = m_queue_outgoing;
278 lock.unlock();
279 queue_outgoing(msg, nullptr, conn);
280 }
281 break;
282 }
283 case Message::kEntryDelete: {
284 unsigned int id = msg->id();
285 if (id >= m_idmap.size() || !m_idmap[id]) {
286 // ignore arbitrary entry updates;
287 // this can happen due to deleted entries
288 lock.unlock();
289 DEBUG("received delete to unknown entry");
290 return;
291 }
292 Entry* entry = m_idmap[id];
293
294 // update persistent dirty flag if it's a persistent value
295 if (entry->IsPersistent()) m_persistent_dirty = true;
296
297 // delete it from idmap
298 m_idmap[id] = nullptr;
299
300 // get entry (as we'll need it for notify) and erase it from the map
301 // it should always be in the map, but sanity check just in case
302 auto i = m_entries.find(entry->name);
303 if (i != m_entries.end()) {
304 auto entry2 = std::move(i->getValue()); // move the value out
305 m_entries.erase(i);
306
307 // notify
308 m_notifier.NotifyEntry(entry2->name, entry2->value, NT_NOTIFY_DELETE);
309 }
310
311 // broadcast to all other connections (note for client there won't
312 // be any other connections, so don't bother)
313 if (m_server && m_queue_outgoing) {
314 auto queue_outgoing = m_queue_outgoing;
315 lock.unlock();
316 queue_outgoing(msg, nullptr, conn);
317 }
318 break;
319 }
320 case Message::kClearEntries: {
321 // update local
322 EntriesMap map;
323 m_entries.swap(map);
324 m_idmap.resize(0);
325
326 // set persistent dirty flag
327 m_persistent_dirty = true;
328
329 // notify
330 for (auto& entry : map)
331 m_notifier.NotifyEntry(entry.getKey(), entry.getValue()->value,
332 NT_NOTIFY_DELETE);
333
334 // broadcast to all other connections (note for client there won't
335 // be any other connections, so don't bother)
336 if (m_server && m_queue_outgoing) {
337 auto queue_outgoing = m_queue_outgoing;
338 lock.unlock();
339 queue_outgoing(msg, nullptr, conn);
340 }
341 break;
342 }
343 case Message::kExecuteRpc: {
344 if (!m_server) return; // only process on server
345 unsigned int id = msg->id();
346 if (id >= m_idmap.size() || !m_idmap[id]) {
347 // ignore call to non-existent RPC
348 // this can happen due to deleted entries
349 lock.unlock();
350 DEBUG("received RPC call to unknown entry");
351 return;
352 }
353 Entry* entry = m_idmap[id];
354 if (!entry->value->IsRpc()) {
355 lock.unlock();
356 DEBUG("received RPC call to non-RPC entry");
357 return;
358 }
359 m_rpc_server.ProcessRpc(entry->name, msg, entry->rpc_callback,
360 conn->uid(), [=](std::shared_ptr<Message> msg) {
361 auto c = conn_weak.lock();
362 if (c) c->QueueOutgoing(msg);
363 });
364 break;
365 }
366 case Message::kRpcResponse: {
367 if (m_server) return; // only process on client
368 m_rpc_results.insert(std::make_pair(
369 std::make_pair(msg->id(), msg->seq_num_uid()), msg->str()));
370 m_rpc_results_cond.notify_all();
371 break;
372 }
373 default:
374 break;
375 }
376}
377
378void Storage::GetInitialAssignments(
379 NetworkConnection& conn, std::vector<std::shared_ptr<Message>>* msgs) {
380 std::lock_guard<std::mutex> lock(m_mutex);
381 conn.set_state(NetworkConnection::kSynchronized);
382 for (auto& i : m_entries) {
383 Entry* entry = i.getValue().get();
384 msgs->emplace_back(Message::EntryAssign(i.getKey(), entry->id,
385 entry->seq_num.value(),
386 entry->value, entry->flags));
387 }
388}
389
390void Storage::ApplyInitialAssignments(
391 NetworkConnection& conn, llvm::ArrayRef<std::shared_ptr<Message>> msgs,
392 bool new_server, std::vector<std::shared_ptr<Message>>* out_msgs) {
393 std::unique_lock<std::mutex> lock(m_mutex);
394 if (m_server) return; // should not do this on server
395
396 conn.set_state(NetworkConnection::kSynchronized);
397
398 std::vector<std::shared_ptr<Message>> update_msgs;
399
400 // clear existing id's
401 for (auto& i : m_entries) i.getValue()->id = 0xffff;
402
403 // clear existing idmap
404 m_idmap.resize(0);
405
406 // apply assignments
407 for (auto& msg : msgs) {
408 if (!msg->Is(Message::kEntryAssign)) {
409 DEBUG("client: received non-entry assignment request?");
410 continue;
411 }
412
413 unsigned int id = msg->id();
414 if (id == 0xffff) {
415 DEBUG("client: received entry assignment request?");
416 continue;
417 }
418
419 SequenceNumber seq_num(msg->seq_num_uid());
420 StringRef name = msg->str();
421
422 auto& entry = m_entries[name];
423 if (!entry) {
424 // doesn't currently exist
425 entry.reset(new Entry(name));
426 entry->value = msg->value();
427 entry->flags = msg->flags();
428 entry->seq_num = seq_num;
429 // notify
430 m_notifier.NotifyEntry(name, entry->value, NT_NOTIFY_NEW);
431 } else {
432 // if reconnect and sequence number not higher than local, then we
433 // don't update the local value and instead send it back to the server
434 // as an update message
435 if (!new_server && seq_num <= entry->seq_num) {
436 update_msgs.emplace_back(Message::EntryUpdate(
437 entry->id, entry->seq_num.value(), entry->value));
438 } else {
439 entry->value = msg->value();
440 entry->seq_num = seq_num;
441 unsigned int notify_flags = NT_NOTIFY_UPDATE;
442 // don't update flags from a <3.0 remote (not part of message)
443 if (conn.proto_rev() >= 0x0300) {
444 if (entry->flags != msg->flags()) notify_flags |= NT_NOTIFY_FLAGS;
445 entry->flags = msg->flags();
446 }
447 // notify
448 m_notifier.NotifyEntry(name, entry->value, notify_flags);
449 }
450 }
451
452 // set id and save to idmap
453 entry->id = id;
454 if (id >= m_idmap.size()) m_idmap.resize(id+1);
455 m_idmap[id] = entry.get();
456 }
457
458 // generate assign messages for unassigned local entries
459 for (auto& i : m_entries) {
460 Entry* entry = i.getValue().get();
461 if (entry->id != 0xffff) continue;
462 out_msgs->emplace_back(Message::EntryAssign(entry->name, entry->id,
463 entry->seq_num.value(),
464 entry->value, entry->flags));
465 }
466 auto queue_outgoing = m_queue_outgoing;
467 lock.unlock();
468 for (auto& msg : update_msgs) queue_outgoing(msg, nullptr, nullptr);
469}
470
471std::shared_ptr<Value> Storage::GetEntryValue(StringRef name) const {
472 std::lock_guard<std::mutex> lock(m_mutex);
473 auto i = m_entries.find(name);
474 return i == m_entries.end() ? nullptr : i->getValue()->value;
475}
476
477bool Storage::SetEntryValue(StringRef name, std::shared_ptr<Value> value) {
478 if (name.empty()) return true;
479 if (!value) return true;
480 std::unique_lock<std::mutex> lock(m_mutex);
481 auto& new_entry = m_entries[name];
482 if (!new_entry) new_entry.reset(new Entry(name));
483 Entry* entry = new_entry.get();
484 auto old_value = entry->value;
485 if (old_value && old_value->type() != value->type())
486 return false; // error on type mismatch
487 entry->value = value;
488
489 // if we're the server, assign an id if it doesn't have one
490 if (m_server && entry->id == 0xffff) {
491 unsigned int id = m_idmap.size();
492 entry->id = id;
493 m_idmap.push_back(entry);
494 }
495
496 // update persistent dirty flag if value changed and it's persistent
497 if (entry->IsPersistent() && *old_value != *value) m_persistent_dirty = true;
498
499 // notify (for local listeners)
500 if (m_notifier.local_notifiers()) {
501 if (!old_value)
502 m_notifier.NotifyEntry(name, value, NT_NOTIFY_NEW | NT_NOTIFY_LOCAL);
503 else if (*old_value != *value)
504 m_notifier.NotifyEntry(name, value, NT_NOTIFY_UPDATE | NT_NOTIFY_LOCAL);
505 }
506
507 // generate message
508 if (!m_queue_outgoing) return true;
509 auto queue_outgoing = m_queue_outgoing;
510 if (!old_value) {
511 auto msg = Message::EntryAssign(name, entry->id, entry->seq_num.value(),
512 value, entry->flags);
513 lock.unlock();
514 queue_outgoing(msg, nullptr, nullptr);
515 } else if (*old_value != *value) {
516 ++entry->seq_num;
517 // don't send an update if we don't have an assigned id yet
518 if (entry->id != 0xffff) {
519 auto msg =
520 Message::EntryUpdate(entry->id, entry->seq_num.value(), value);
521 lock.unlock();
522 queue_outgoing(msg, nullptr, nullptr);
523 }
524 }
525 return true;
526}
527
528void Storage::SetEntryTypeValue(StringRef name, std::shared_ptr<Value> value) {
529 if (name.empty()) return;
530 if (!value) return;
531 std::unique_lock<std::mutex> lock(m_mutex);
532 auto& new_entry = m_entries[name];
533 if (!new_entry) new_entry.reset(new Entry(name));
534 Entry* entry = new_entry.get();
535 auto old_value = entry->value;
536 entry->value = value;
537 if (old_value && *old_value == *value) return;
538
539 // if we're the server, assign an id if it doesn't have one
540 if (m_server && entry->id == 0xffff) {
541 unsigned int id = m_idmap.size();
542 entry->id = id;
543 m_idmap.push_back(entry);
544 }
545
546 // update persistent dirty flag if it's a persistent value
547 if (entry->IsPersistent()) m_persistent_dirty = true;
548
549 // notify (for local listeners)
550 if (m_notifier.local_notifiers()) {
551 if (!old_value)
552 m_notifier.NotifyEntry(name, value, NT_NOTIFY_NEW | NT_NOTIFY_LOCAL);
553 else
554 m_notifier.NotifyEntry(name, value, NT_NOTIFY_UPDATE | NT_NOTIFY_LOCAL);
555 }
556
557 // generate message
558 if (!m_queue_outgoing) return;
559 auto queue_outgoing = m_queue_outgoing;
560 if (!old_value || old_value->type() != value->type()) {
561 ++entry->seq_num;
562 auto msg = Message::EntryAssign(name, entry->id, entry->seq_num.value(),
563 value, entry->flags);
564 lock.unlock();
565 queue_outgoing(msg, nullptr, nullptr);
566 } else {
567 ++entry->seq_num;
568 // don't send an update if we don't have an assigned id yet
569 if (entry->id != 0xffff) {
570 auto msg =
571 Message::EntryUpdate(entry->id, entry->seq_num.value(), value);
572 lock.unlock();
573 queue_outgoing(msg, nullptr, nullptr);
574 }
575 }
576}
577
578void Storage::SetEntryFlags(StringRef name, unsigned int flags) {
579 if (name.empty()) return;
580 std::unique_lock<std::mutex> lock(m_mutex);
581 auto i = m_entries.find(name);
582 if (i == m_entries.end()) return;
583 Entry* entry = i->getValue().get();
584 if (entry->flags == flags) return;
585
586 // update persistent dirty flag if persistent flag changed
587 if ((entry->flags & NT_PERSISTENT) != (flags & NT_PERSISTENT))
588 m_persistent_dirty = true;
589
590 entry->flags = flags;
591
592 // notify
593 m_notifier.NotifyEntry(name, entry->value, NT_NOTIFY_FLAGS | NT_NOTIFY_LOCAL);
594
595 // generate message
596 if (!m_queue_outgoing) return;
597 auto queue_outgoing = m_queue_outgoing;
598 unsigned int id = entry->id;
599 // don't send an update if we don't have an assigned id yet
600 if (id != 0xffff) {
601 lock.unlock();
602 queue_outgoing(Message::FlagsUpdate(id, flags), nullptr, nullptr);
603 }
604}
605
606unsigned int Storage::GetEntryFlags(StringRef name) const {
607 std::lock_guard<std::mutex> lock(m_mutex);
608 auto i = m_entries.find(name);
609 return i == m_entries.end() ? 0 : i->getValue()->flags;
610}
611
612void Storage::DeleteEntry(StringRef name) {
613 std::unique_lock<std::mutex> lock(m_mutex);
614 auto i = m_entries.find(name);
615 if (i == m_entries.end()) return;
616 auto entry = std::move(i->getValue());
617 unsigned int id = entry->id;
618
619 // update persistent dirty flag if it's a persistent value
620 if (entry->IsPersistent()) m_persistent_dirty = true;
621
622 m_entries.erase(i); // erase from map
623 if (id < m_idmap.size()) m_idmap[id] = nullptr;
624
625 if (!entry->value) return;
626
627 // notify
628 m_notifier.NotifyEntry(name, entry->value,
629 NT_NOTIFY_DELETE | NT_NOTIFY_LOCAL);
630
631 // if it had a value, generate message
632 // don't send an update if we don't have an assigned id yet
633 if (id != 0xffff) {
634 if (!m_queue_outgoing) return;
635 auto queue_outgoing = m_queue_outgoing;
636 lock.unlock();
637 queue_outgoing(Message::EntryDelete(id), nullptr, nullptr);
638 }
639}
640
641void Storage::DeleteAllEntries() {
642 std::unique_lock<std::mutex> lock(m_mutex);
643 if (m_entries.empty()) return;
644 EntriesMap map;
645 m_entries.swap(map);
646 m_idmap.resize(0);
647
648 // set persistent dirty flag
649 m_persistent_dirty = true;
650
651 // notify
652 if (m_notifier.local_notifiers()) {
653 for (auto& entry : map)
654 m_notifier.NotifyEntry(entry.getKey(), entry.getValue()->value,
655 NT_NOTIFY_DELETE | NT_NOTIFY_LOCAL);
656 }
657
658 // generate message
659 if (!m_queue_outgoing) return;
660 auto queue_outgoing = m_queue_outgoing;
661 lock.unlock();
662 queue_outgoing(Message::ClearEntries(), nullptr, nullptr);
663}
664
665std::vector<EntryInfo> Storage::GetEntryInfo(StringRef prefix,
666 unsigned int types) {
667 std::lock_guard<std::mutex> lock(m_mutex);
668 std::vector<EntryInfo> infos;
669 for (auto& i : m_entries) {
670 if (!i.getKey().startswith(prefix)) continue;
671 Entry* entry = i.getValue().get();
672 auto value = entry->value;
673 if (!value) continue;
674 if (types != 0 && (types & value->type()) == 0) continue;
675 EntryInfo info;
676 info.name = i.getKey();
677 info.type = value->type();
678 info.flags = entry->flags;
679 info.last_change = value->last_change();
680 infos.push_back(std::move(info));
681 }
682 return infos;
683}
684
685void Storage::NotifyEntries(StringRef prefix,
686 EntryListenerCallback only) const {
687 std::lock_guard<std::mutex> lock(m_mutex);
688 for (auto& i : m_entries) {
689 if (!i.getKey().startswith(prefix)) continue;
690 m_notifier.NotifyEntry(i.getKey(), i.getValue()->value, NT_NOTIFY_IMMEDIATE,
691 only);
692 }
693}
694
695/* Escapes and writes a string, including start and end double quotes */
696static void WriteString(std::ostream& os, llvm::StringRef str) {
697 os << '"';
698 for (auto c : str) {
699 switch (c) {
700 case '\\':
701 os << "\\\\";
702 break;
703 case '\t':
704 os << "\\t";
705 break;
706 case '\n':
707 os << "\\n";
708 break;
709 case '"':
710 os << "\\\"";
711 break;
712 default:
713 if (std::isprint(c)) {
714 os << c;
715 break;
716 }
717
718 // Write out the escaped representation.
719 os << "\\x";
720 os << llvm::hexdigit((c >> 4) & 0xF);
721 os << llvm::hexdigit((c >> 0) & 0xF);
722 }
723 }
724 os << '"';
725}
726
727bool Storage::GetPersistentEntries(
728 bool periodic,
729 std::vector<std::pair<std::string, std::shared_ptr<Value>>>* entries)
730 const {
731 // copy values out of storage as quickly as possible so lock isn't held
732 {
733 std::lock_guard<std::mutex> lock(m_mutex);
734 // for periodic, don't re-save unless something has changed
735 if (periodic && !m_persistent_dirty) return false;
736 m_persistent_dirty = false;
737 entries->reserve(m_entries.size());
738 for (auto& i : m_entries) {
739 Entry* entry = i.getValue().get();
740 // only write persistent-flagged values
741 if (!entry->IsPersistent()) continue;
742 entries->emplace_back(i.getKey(), entry->value);
743 }
744 }
745
746 // sort in name order
747 std::sort(entries->begin(), entries->end(),
748 [](const std::pair<std::string, std::shared_ptr<Value>>& a,
749 const std::pair<std::string, std::shared_ptr<Value>>& b) {
750 return a.first < b.first;
751 });
752 return true;
753}
754
755static void SavePersistentImpl(
756 std::ostream& os,
757 llvm::ArrayRef<std::pair<std::string, std::shared_ptr<Value>>> entries) {
758 std::string base64_encoded;
759
760 // header
761 os << "[NetworkTables Storage 3.0]\n";
762
763 for (auto& i : entries) {
764 // type
765 auto v = i.second;
766 if (!v) continue;
767 switch (v->type()) {
768 case NT_BOOLEAN:
769 os << "boolean ";
770 break;
771 case NT_DOUBLE:
772 os << "double ";
773 break;
774 case NT_STRING:
775 os << "string ";
776 break;
777 case NT_RAW:
778 os << "raw ";
779 break;
780 case NT_BOOLEAN_ARRAY:
781 os << "array boolean ";
782 break;
783 case NT_DOUBLE_ARRAY:
784 os << "array double ";
785 break;
786 case NT_STRING_ARRAY:
787 os << "array string ";
788 break;
789 default:
790 continue;
791 }
792
793 // name
794 WriteString(os, i.first);
795
796 // =
797 os << '=';
798
799 // value
800 switch (v->type()) {
801 case NT_BOOLEAN:
802 os << (v->GetBoolean() ? "true" : "false");
803 break;
804 case NT_DOUBLE:
805 os << v->GetDouble();
806 break;
807 case NT_STRING:
808 WriteString(os, v->GetString());
809 break;
810 case NT_RAW:
811 Base64Encode(v->GetRaw(), &base64_encoded);
812 os << base64_encoded;
813 break;
814 case NT_BOOLEAN_ARRAY: {
815 bool first = true;
816 for (auto elem : v->GetBooleanArray()) {
817 if (!first) os << ',';
818 first = false;
819 os << (elem ? "true" : "false");
820 }
821 break;
822 }
823 case NT_DOUBLE_ARRAY: {
824 bool first = true;
825 for (auto elem : v->GetDoubleArray()) {
826 if (!first) os << ',';
827 first = false;
828 os << elem;
829 }
830 break;
831 }
832 case NT_STRING_ARRAY: {
833 bool first = true;
834 for (auto& elem : v->GetStringArray()) {
835 if (!first) os << ',';
836 first = false;
837 WriteString(os, elem);
838 }
839 break;
840 }
841 default:
842 break;
843 }
844
845 // eol
846 os << '\n';
847 }
848}
849
850void Storage::SavePersistent(std::ostream& os, bool periodic) const {
851 std::vector<std::pair<std::string, std::shared_ptr<Value>>> entries;
852 if (!GetPersistentEntries(periodic, &entries)) return;
853 SavePersistentImpl(os, entries);
854}
855
856const char* Storage::SavePersistent(StringRef filename, bool periodic) const {
857 std::string fn = filename;
858 std::string tmp = filename;
859 tmp += ".tmp";
860 std::string bak = filename;
861 bak += ".bak";
862
863 // Get entries before creating file
864 std::vector<std::pair<std::string, std::shared_ptr<Value>>> entries;
865 if (!GetPersistentEntries(periodic, &entries)) return nullptr;
866
867 const char* err = nullptr;
868
869 // start by writing to temporary file
870 std::ofstream os(tmp);
871 if (!os) {
872 err = "could not open file";
873 goto done;
874 }
875 DEBUG("saving persistent file '" << filename << "'");
876 SavePersistentImpl(os, entries);
877 os.flush();
878 if (!os) {
879 os.close();
880 std::remove(tmp.c_str());
881 err = "error saving file";
882 goto done;
883 }
884 os.close();
885
886 // Safely move to real file. We ignore any failures related to the backup.
887 std::remove(bak.c_str());
888 std::rename(fn.c_str(), bak.c_str());
889 if (std::rename(tmp.c_str(), fn.c_str()) != 0) {
890 std::rename(bak.c_str(), fn.c_str()); // attempt to restore backup
891 err = "could not rename temp file to real file";
892 goto done;
893 }
894
895done:
896 // try again if there was an error
897 if (err && periodic) m_persistent_dirty = true;
898 return err;
899}
900
901/* Extracts an escaped string token. Does not unescape the string.
902 * If a string cannot be matched, an empty string is returned.
903 * If the string is unterminated, an empty tail string is returned.
904 * The returned token includes the starting and trailing quotes (unless the
905 * string is unterminated).
906 * Returns a pair containing the extracted token (if any) and the remaining
907 * tail string.
908 */
909static std::pair<llvm::StringRef, llvm::StringRef> ReadStringToken(
910 llvm::StringRef source) {
911 // Match opening quote
912 if (source.empty() || source.front() != '"')
913 return std::make_pair(llvm::StringRef(), source);
914
915 // Scan for ending double quote, checking for escaped as we go.
916 std::size_t size = source.size();
917 std::size_t pos;
918 for (pos = 1; pos < size; ++pos) {
919 if (source[pos] == '"' && source[pos - 1] != '\\') {
920 ++pos; // we want to include the trailing quote in the result
921 break;
922 }
923 }
924 return std::make_pair(source.slice(0, pos), source.substr(pos));
925}
926
/* Converts a hex digit character to its value (0-15). Only meaningful for
 * characters already accepted by isxdigit(); any other character yields
 * ch - '0'.
 */
static int fromxdigit(char ch) {
  const bool lower = ch >= 'a' && ch <= 'f';
  const bool upper = ch >= 'A' && ch <= 'F';
  if (lower) return 10 + (ch - 'a');
  if (upper) return 10 + (ch - 'A');
  return ch - '0';
}
935
936static void UnescapeString(llvm::StringRef source, std::string* dest) {
937 assert(source.size() >= 2 && source.front() == '"' && source.back() == '"');
938 dest->clear();
939 dest->reserve(source.size() - 2);
940 for (auto s = source.begin() + 1, end = source.end() - 1; s != end; ++s) {
941 if (*s != '\\') {
942 dest->push_back(*s);
943 continue;
944 }
945 switch (*++s) {
946 case '\\':
947 case '"':
948 dest->push_back(s[-1]);
949 break;
950 case 't':
951 dest->push_back('\t');
952 break;
953 case 'n':
954 dest->push_back('\n');
955 break;
956 case 'x': {
957 if (!isxdigit(*(s+1))) {
958 dest->push_back('x'); // treat it like a unknown escape
959 break;
960 }
961 int ch = fromxdigit(*++s);
962 if (isxdigit(*(s+1))) {
963 ch <<= 4;
964 ch |= fromxdigit(*++s);
965 }
966 dest->push_back(static_cast<char>(ch));
967 break;
968 }
969 default:
970 dest->push_back(s[-1]);
971 break;
972 }
973 }
974}
975
976bool Storage::LoadPersistent(
977 std::istream& is,
978 std::function<void(std::size_t line, const char* msg)> warn) {
979 std::string line_str;
980 std::size_t line_num = 1;
981
982 // entries to add
983 std::vector<std::pair<std::string, std::shared_ptr<Value>>> entries;
984
985 // declare these outside the loop to reduce reallocs
986 std::string name, str;
987 std::vector<int> boolean_array;
988 std::vector<double> double_array;
989 std::vector<std::string> string_array;
990
991 // ignore blank lines and lines that start with ; or # (comments)
992 while (std::getline(is, line_str)) {
993 llvm::StringRef line = llvm::StringRef(line_str).trim();
994 if (!line.empty() && line.front() != ';' && line.front() != '#')
995 break;
996 }
997
998 // header
999 if (line_str != "[NetworkTables Storage 3.0]") {
1000 if (warn) warn(line_num, "header line mismatch, ignoring rest of file");
1001 return false;
1002 }
1003
1004 while (std::getline(is, line_str)) {
1005 llvm::StringRef line = llvm::StringRef(line_str).trim();
1006 ++line_num;
1007
1008 // ignore blank lines and lines that start with ; or # (comments)
1009 if (line.empty() || line.front() == ';' || line.front() == '#')
1010 continue;
1011
1012 // type
1013 llvm::StringRef type_tok;
1014 std::tie(type_tok, line) = line.split(' ');
1015 NT_Type type = NT_UNASSIGNED;
1016 if (type_tok == "boolean") type = NT_BOOLEAN;
1017 else if (type_tok == "double") type = NT_DOUBLE;
1018 else if (type_tok == "string") type = NT_STRING;
1019 else if (type_tok == "raw") type = NT_RAW;
1020 else if (type_tok == "array") {
1021 llvm::StringRef array_tok;
1022 std::tie(array_tok, line) = line.split(' ');
1023 if (array_tok == "boolean") type = NT_BOOLEAN_ARRAY;
1024 else if (array_tok == "double") type = NT_DOUBLE_ARRAY;
1025 else if (array_tok == "string") type = NT_STRING_ARRAY;
1026 }
1027 if (type == NT_UNASSIGNED) {
1028 if (warn) warn(line_num, "unrecognized type");
1029 continue;
1030 }
1031
1032 // name
1033 llvm::StringRef name_tok;
1034 std::tie(name_tok, line) = ReadStringToken(line);
1035 if (name_tok.empty()) {
1036 if (warn) warn(line_num, "missing name");
1037 continue;
1038 }
1039 if (name_tok.back() != '"') {
1040 if (warn) warn(line_num, "unterminated name string");
1041 continue;
1042 }
1043 UnescapeString(name_tok, &name);
1044
1045 // =
1046 line = line.ltrim(" \t");
1047 if (line.empty() || line.front() != '=') {
1048 if (warn) warn(line_num, "expected = after name");
1049 continue;
1050 }
1051 line = line.drop_front().ltrim(" \t");
1052
1053 // value
1054 std::shared_ptr<Value> value;
1055 switch (type) {
1056 case NT_BOOLEAN:
1057 // only true or false is accepted
1058 if (line == "true")
1059 value = Value::MakeBoolean(true);
1060 else if (line == "false")
1061 value = Value::MakeBoolean(false);
1062 else {
1063 if (warn)
1064 warn(line_num, "unrecognized boolean value, not 'true' or 'false'");
1065 goto next_line;
1066 }
1067 break;
1068 case NT_DOUBLE: {
1069 // need to convert to null-terminated string for strtod()
1070 str.clear();
1071 str += line;
1072 char* end;
1073 double v = std::strtod(str.c_str(), &end);
1074 if (*end != '\0') {
1075 if (warn) warn(line_num, "invalid double value");
1076 goto next_line;
1077 }
1078 value = Value::MakeDouble(v);
1079 break;
1080 }
1081 case NT_STRING: {
1082 llvm::StringRef str_tok;
1083 std::tie(str_tok, line) = ReadStringToken(line);
1084 if (str_tok.empty()) {
1085 if (warn) warn(line_num, "missing string value");
1086 goto next_line;
1087 }
1088 if (str_tok.back() != '"') {
1089 if (warn) warn(line_num, "unterminated string value");
1090 goto next_line;
1091 }
1092 UnescapeString(str_tok, &str);
1093 value = Value::MakeString(std::move(str));
1094 break;
1095 }
1096 case NT_RAW:
1097 Base64Decode(line, &str);
1098 value = Value::MakeRaw(std::move(str));
1099 break;
1100 case NT_BOOLEAN_ARRAY: {
1101 llvm::StringRef elem_tok;
1102 boolean_array.clear();
1103 while (!line.empty()) {
1104 std::tie(elem_tok, line) = line.split(',');
1105 elem_tok = elem_tok.trim(" \t");
1106 if (elem_tok == "true")
1107 boolean_array.push_back(1);
1108 else if (elem_tok == "false")
1109 boolean_array.push_back(0);
1110 else {
1111 if (warn)
1112 warn(line_num,
1113 "unrecognized boolean value, not 'true' or 'false'");
1114 goto next_line;
1115 }
1116 }
1117
1118 value = Value::MakeBooleanArray(std::move(boolean_array));
1119 break;
1120 }
1121 case NT_DOUBLE_ARRAY: {
1122 llvm::StringRef elem_tok;
1123 double_array.clear();
1124 while (!line.empty()) {
1125 std::tie(elem_tok, line) = line.split(',');
1126 elem_tok = elem_tok.trim(" \t");
1127 // need to convert to null-terminated string for strtod()
1128 str.clear();
1129 str += elem_tok;
1130 char* end;
1131 double v = std::strtod(str.c_str(), &end);
1132 if (*end != '\0') {
1133 if (warn) warn(line_num, "invalid double value");
1134 goto next_line;
1135 }
1136 double_array.push_back(v);
1137 }
1138
1139 value = Value::MakeDoubleArray(std::move(double_array));
1140 break;
1141 }
1142 case NT_STRING_ARRAY: {
1143 llvm::StringRef elem_tok;
1144 string_array.clear();
1145 while (!line.empty()) {
1146 std::tie(elem_tok, line) = ReadStringToken(line);
1147 if (elem_tok.empty()) {
1148 if (warn) warn(line_num, "missing string value");
1149 goto next_line;
1150 }
1151 if (elem_tok.back() != '"') {
1152 if (warn) warn(line_num, "unterminated string value");
1153 goto next_line;
1154 }
1155
1156 UnescapeString(elem_tok, &str);
1157 string_array.push_back(std::move(str));
1158
1159 line = line.ltrim(" \t");
1160 if (line.empty()) break;
1161 if (line.front() != ',') {
1162 if (warn) warn(line_num, "expected comma between strings");
1163 goto next_line;
1164 }
1165 line = line.drop_front().ltrim(" \t");
1166 }
1167
1168 value = Value::MakeStringArray(std::move(string_array));
1169 break;
1170 }
1171 default:
1172 break;
1173 }
1174 if (!name.empty() && value)
1175 entries.push_back(std::make_pair(std::move(name), std::move(value)));
1176next_line:
1177 ;
1178 }
1179
1180 // copy values into storage as quickly as possible so lock isn't held
1181 {
1182 std::vector<std::shared_ptr<Message>> msgs;
1183 std::unique_lock<std::mutex> lock(m_mutex);
1184 for (auto& i : entries) {
1185 auto& new_entry = m_entries[i.first];
1186 if (!new_entry) new_entry.reset(new Entry(i.first));
1187 Entry* entry = new_entry.get();
1188 auto old_value = entry->value;
1189 entry->value = i.second;
1190 bool was_persist = entry->IsPersistent();
1191 if (!was_persist) entry->flags |= NT_PERSISTENT;
1192
1193 // if we're the server, assign an id if it doesn't have one
1194 if (m_server && entry->id == 0xffff) {
1195 unsigned int id = m_idmap.size();
1196 entry->id = id;
1197 m_idmap.push_back(entry);
1198 }
1199
1200 // notify (for local listeners)
1201 if (m_notifier.local_notifiers()) {
1202 if (!old_value)
1203 m_notifier.NotifyEntry(i.first, i.second,
1204 NT_NOTIFY_NEW | NT_NOTIFY_LOCAL);
1205 else if (*old_value != *i.second) {
1206 unsigned int notify_flags = NT_NOTIFY_UPDATE | NT_NOTIFY_LOCAL;
1207 if (!was_persist) notify_flags |= NT_NOTIFY_FLAGS;
1208 m_notifier.NotifyEntry(i.first, i.second, notify_flags);
1209 }
1210 }
1211
1212 if (!m_queue_outgoing) continue; // shortcut
1213 ++entry->seq_num;
1214
1215 // put on update queue
1216 if (!old_value || old_value->type() != i.second->type())
1217 msgs.emplace_back(Message::EntryAssign(i.first, entry->id,
1218 entry->seq_num.value(),
1219 i.second, entry->flags));
1220 else if (entry->id != 0xffff) {
1221 // don't send an update if we don't have an assigned id yet
1222 if (*old_value != *i.second)
1223 msgs.emplace_back(Message::EntryUpdate(
1224 entry->id, entry->seq_num.value(), i.second));
1225 if (!was_persist)
1226 msgs.emplace_back(Message::FlagsUpdate(entry->id, entry->flags));
1227 }
1228 }
1229
1230 if (m_queue_outgoing) {
1231 auto queue_outgoing = m_queue_outgoing;
1232 lock.unlock();
1233 for (auto& msg : msgs) queue_outgoing(std::move(msg), nullptr, nullptr);
1234 }
1235 }
1236
1237 return true;
1238}
1239
1240const char* Storage::LoadPersistent(
1241 StringRef filename,
1242 std::function<void(std::size_t line, const char* msg)> warn) {
1243 std::ifstream is(filename);
1244 if (!is) return "could not open file";
1245 if (!LoadPersistent(is, warn)) return "error reading file";
1246 return nullptr;
1247}
1248
1249void Storage::CreateRpc(StringRef name, StringRef def, RpcCallback callback) {
1250 if (name.empty() || def.empty() || !callback) return;
1251 std::unique_lock<std::mutex> lock(m_mutex);
1252 if (!m_server) return; // only server can create RPCs
1253
1254 auto& new_entry = m_entries[name];
1255 if (!new_entry) new_entry.reset(new Entry(name));
1256 Entry* entry = new_entry.get();
1257 auto old_value = entry->value;
1258 auto value = Value::MakeRpc(def);
1259 entry->value = value;
1260
1261 // set up the new callback
1262 entry->rpc_callback = callback;
1263
1264 // start the RPC server
1265 if (!m_rpc_server.active()) m_rpc_server.Start();
1266
1267 if (old_value && *old_value == *value) return;
1268
1269 // assign an id if it doesn't have one
1270 if (entry->id == 0xffff) {
1271 unsigned int id = m_idmap.size();
1272 entry->id = id;
1273 m_idmap.push_back(entry);
1274 }
1275
1276 // generate message
1277 if (!m_queue_outgoing) return;
1278 auto queue_outgoing = m_queue_outgoing;
1279 if (!old_value || old_value->type() != value->type()) {
1280 ++entry->seq_num;
1281 auto msg = Message::EntryAssign(name, entry->id, entry->seq_num.value(),
1282 value, entry->flags);
1283 lock.unlock();
1284 queue_outgoing(msg, nullptr, nullptr);
1285 } else {
1286 ++entry->seq_num;
1287 auto msg = Message::EntryUpdate(entry->id, entry->seq_num.value(), value);
1288 lock.unlock();
1289 queue_outgoing(msg, nullptr, nullptr);
1290 }
1291}
1292
1293void Storage::CreatePolledRpc(StringRef name, StringRef def) {
1294 if (name.empty() || def.empty()) return;
1295 std::unique_lock<std::mutex> lock(m_mutex);
1296 if (!m_server) return; // only server can create RPCs
1297
1298 auto& new_entry = m_entries[name];
1299 if (!new_entry) new_entry.reset(new Entry(name));
1300 Entry* entry = new_entry.get();
1301 auto old_value = entry->value;
1302 auto value = Value::MakeRpc(def);
1303 entry->value = value;
1304
1305 // a nullptr callback indicates a polled RPC
1306 entry->rpc_callback = nullptr;
1307
1308 if (old_value && *old_value == *value) return;
1309
1310 // assign an id if it doesn't have one
1311 if (entry->id == 0xffff) {
1312 unsigned int id = m_idmap.size();
1313 entry->id = id;
1314 m_idmap.push_back(entry);
1315 }
1316
1317 // generate message
1318 if (!m_queue_outgoing) return;
1319 auto queue_outgoing = m_queue_outgoing;
1320 if (!old_value || old_value->type() != value->type()) {
1321 ++entry->seq_num;
1322 auto msg = Message::EntryAssign(name, entry->id, entry->seq_num.value(),
1323 value, entry->flags);
1324 lock.unlock();
1325 queue_outgoing(msg, nullptr, nullptr);
1326 } else {
1327 ++entry->seq_num;
1328 auto msg = Message::EntryUpdate(entry->id, entry->seq_num.value(), value);
1329 lock.unlock();
1330 queue_outgoing(msg, nullptr, nullptr);
1331 }
1332}
1333
1334unsigned int Storage::CallRpc(StringRef name, StringRef params) {
1335 std::unique_lock<std::mutex> lock(m_mutex);
1336 auto i = m_entries.find(name);
1337 if (i == m_entries.end()) return 0;
1338 auto& entry = i->getValue();
1339 if (!entry->value->IsRpc()) return 0;
1340
1341 ++entry->rpc_call_uid;
1342 if (entry->rpc_call_uid > 0xffff) entry->rpc_call_uid = 0;
1343 unsigned int combined_uid = (entry->id << 16) | entry->rpc_call_uid;
1344 auto msg = Message::ExecuteRpc(entry->id, entry->rpc_call_uid, params);
1345 if (m_server) {
1346 // RPCs are unlikely to be used locally on the server, but handle it
1347 // gracefully anyway.
1348 auto rpc_callback = entry->rpc_callback;
1349 lock.unlock();
1350 m_rpc_server.ProcessRpc(
1351 name, msg, rpc_callback, 0xffffU, [this](std::shared_ptr<Message> msg) {
1352 std::lock_guard<std::mutex> lock(m_mutex);
1353 m_rpc_results.insert(std::make_pair(
1354 std::make_pair(msg->id(), msg->seq_num_uid()), msg->str()));
1355 m_rpc_results_cond.notify_all();
1356 });
1357 } else {
1358 auto queue_outgoing = m_queue_outgoing;
1359 lock.unlock();
1360 queue_outgoing(msg, nullptr, nullptr);
1361 }
1362 return combined_uid;
1363}
1364
1365bool Storage::GetRpcResult(bool blocking, unsigned int call_uid,
1366 std::string* result) {
1367 std::unique_lock<std::mutex> lock(m_mutex);
1368 for (;;) {
1369 auto i =
1370 m_rpc_results.find(std::make_pair(call_uid >> 16, call_uid & 0xffff));
1371 if (i == m_rpc_results.end()) {
1372 if (!blocking || m_terminating) return false;
1373 m_rpc_results_cond.wait(lock);
1374 if (m_terminating) return false;
1375 continue;
1376 }
1377 result->swap(i->getSecond());
1378 m_rpc_results.erase(i);
1379 return true;
1380 }
1381}