Squashed 'third_party/ntcore_2016/' content from commit d8de5e4
Change-Id: Id4839f41b6a620d8bae58dcf1710016671cc4992
git-subtree-dir: third_party/ntcore_2016
git-subtree-split: d8de5e4f19e612e7102172c0dbf152ce82d3d63a
diff --git a/src/Base64.cpp b/src/Base64.cpp
new file mode 100644
index 0000000..17aa125
--- /dev/null
+++ b/src/Base64.cpp
@@ -0,0 +1,152 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+/* ====================================================================
+ * Copyright (c) 1995-1999 The Apache Group. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in
+ * the documentation and/or other materials provided with the
+ * distribution.
+ *
+ * 3. All advertising materials mentioning features or use of this
+ * software must display the following acknowledgment:
+ * "This product includes software developed by the Apache Group
+ * for use in the Apache HTTP server project (http://www.apache.org/)."
+ *
+ * 4. The names "Apache Server" and "Apache Group" must not be used to
+ * endorse or promote products derived from this software without
+ * prior written permission. For written permission, please contact
+ * apache@apache.org.
+ *
+ * 5. Products derived from this software may not be called "Apache"
+ * nor may "Apache" appear in their names without prior written
+ * permission of the Apache Group.
+ *
+ * 6. Redistributions of any form whatsoever must retain the following
+ * acknowledgment:
+ * "This product includes software developed by the Apache Group
+ * for use in the Apache HTTP server project (http://www.apache.org/)."
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE APACHE GROUP ``AS IS'' AND ANY
+ * EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE APACHE GROUP OR
+ * ITS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
+ * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ * ====================================================================
+ *
+ * This software consists of voluntary contributions made by many
+ * individuals on behalf of the Apache Group and was originally based
+ * on public domain software written at the National Center for
+ * Supercomputing Applications, University of Illinois, Urbana-Champaign.
+ * For more information on the Apache Group and the Apache HTTP server
+ * project, please see <http://www.apache.org/>.
+ *
+ */
+
+#include "Base64.h"
+
+namespace nt {
+
+// Maps every byte value to its 6-bit base64 digit; 64 marks a non-base64 byte.
+static const unsigned char pr2six[256] =
+{
+ // 16 entries per row, indexed by byte value 0x00-0xFF ('=' maps to 64)
+ 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
+ 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
+ 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 62, 64, 64, 64, 63,  // '+'=62 '/'=63
+ 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 64, 64, 64, 64, 64, 64,  // '0'-'9'
+ 64, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,  // 'A'-'O'
+ 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 64, 64, 64, 64, 64,  // 'P'-'Z'
+ 64, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,  // 'a'-'o'
+ 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 64, 64, 64, 64, 64,  // 'p'-'z'
+ 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,  // 0x80+: invalid
+ 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
+ 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
+ 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
+ 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
+ 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
+ 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
+ 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64
+};
+
+std::size_t Base64Decode(llvm::StringRef encoded, std::string* plain) {
+  // Decodes the leading run of base64 chars in 'encoded' into *plain
+  // (cleared first); returns the run length plus its implied '=' count.
+  const unsigned char* const begin = encoded.bytes_begin();
+  const unsigned char* end = begin;
+  // Test the bound BEFORE dereferencing so we never read bytes_end().
+  while (end != encoded.bytes_end() && pr2six[*end] <= 63) ++end;
+  std::size_t nprbytes = end - begin;
+  plain->clear();
+  if (nprbytes == 0) return 0;
+  plain->reserve(((nprbytes + 3) / 4) * 3);
+  auto put = [plain](int v) { *plain += static_cast<char>(v); };
+  const unsigned char* cur = begin;
+  // Every full group of 4 sextets yields 3 bytes; keep the last 1-4 chars.
+  while (nprbytes > 4) {
+    put(pr2six[cur[0]] << 2 | pr2six[cur[1]] >> 4);
+    put(pr2six[cur[1]] << 4 | pr2six[cur[2]] >> 2);
+    put(pr2six[cur[2]] << 6 | pr2six[cur[3]]);
+    cur += 4;
+    nprbytes -= 4;
+  }
+  // Note: (nprbytes == 1) would be an error, so just ignore that case
+  if (nprbytes > 1) put(pr2six[cur[0]] << 2 | pr2six[cur[1]] >> 4);
+  if (nprbytes > 2) put(pr2six[cur[1]] << 4 | pr2six[cur[2]] >> 2);
+  if (nprbytes > 3) put(pr2six[cur[2]] << 6 | pr2six[cur[3]]);
+  return (end - begin) + ((4 - nprbytes) & 3);
+}
+
+// Base64 alphabet: a 6-bit value indexes its output character.
+static const char basis_64[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+
+void Base64Encode(llvm::StringRef plain, std::string* encoded) {
+  // Writes 'plain' as '='-padded base64 (basis_64 alphabet) into *encoded.
+  std::string& out = *encoded;
+  out.clear();
+  const std::size_t total = plain.size();
+  if (total == 0) return;
+  out.reserve(((total + 2) / 3 * 4) + 1);  // 4 chars per 3-byte group
+  std::size_t pos = 0;
+  for (; pos + 2 < total; pos += 3) {  // each full 3-byte group -> 4 chars
+    const unsigned char b0 = plain[pos], b1 = plain[pos + 1];
+    const unsigned char b2 = plain[pos + 2];
+    out += basis_64[b0 >> 2];
+    out += basis_64[((b0 & 0x3) << 4) | (b1 >> 4)];
+    out += basis_64[((b1 & 0xF) << 2) | (b2 >> 6)];
+    out += basis_64[b2 & 0x3F];
+  }
+  if (pos == total) return;  // no 1- or 2-byte tail left to pad
+  const unsigned char b0 = plain[pos];
+  out += basis_64[b0 >> 2];
+  if (pos + 1 == total) {  // 1-byte tail: 2 chars + "=="
+    out += basis_64[(b0 & 0x3) << 4];
+    out += "==";
+    return;
+  }
+  const unsigned char b1 = plain[pos + 1];  // 2-byte tail: 3 chars + "="
+  out += basis_64[((b0 & 0x3) << 4) | (b1 >> 4)];
+  out += basis_64[(b1 & 0xF) << 2];
+  out += '=';
+}
+
+} // namespace nt
diff --git a/src/Base64.h b/src/Base64.h
new file mode 100644
index 0000000..a86e699
--- /dev/null
+++ b/src/Base64.h
@@ -0,0 +1,23 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef NT_BASE64_H_
+#define NT_BASE64_H_
+
+#include <cstddef>
+#include <string>
+
+#include "llvm/StringRef.h"
+
+namespace nt {
+// Decode the leading base64 run of `encoded` into *plain; returns run length + pad.
+std::size_t Base64Decode(llvm::StringRef encoded, std::string* plain);
+void Base64Encode(llvm::StringRef plain, std::string* encoded);  // '='-padded output
+
+} // namespace nt
+
+#endif // NT_BASE64_H_
diff --git a/src/Dispatcher.cpp b/src/Dispatcher.cpp
new file mode 100644
index 0000000..2b8fcb2
--- /dev/null
+++ b/src/Dispatcher.cpp
@@ -0,0 +1,487 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "Dispatcher.h"
+
+#include <algorithm>
+#include <iterator>
+
+#include "tcpsockets/TCPAcceptor.h"
+#include "tcpsockets/TCPConnector.h"
+#include "Log.h"
+
+using namespace nt;
+
+ATOMIC_STATIC_INIT(Dispatcher)  // NOTE(review): presumably the storage behind GetInstance()'s ATOMIC_STATIC
+
+void Dispatcher::StartServer(StringRef persist_filename,
+                             const char* listen_address, unsigned int port) {
+  std::unique_ptr<NetworkAcceptor> acceptor(  // TCP listener on address:port
+      new TCPAcceptor(static_cast<int>(port), listen_address));
+  DispatcherBase::StartServer(persist_filename, std::move(acceptor));
+}
+
+void Dispatcher::StartClient(const char* server_name, unsigned int port) {
+  std::string host(server_name);  // owned copy: the callback may outlive caller
+  auto connect = [host, port]() -> std::unique_ptr<NetworkStream> {
+    return TCPConnector::connect(host.c_str(), static_cast<int>(port), 1);
+  };
+  DispatcherBase::StartClient(connect);
+}
+
+// Default instance: wired to Storage::GetInstance() and Notifier::GetInstance().
+Dispatcher::Dispatcher() : Dispatcher(Storage::GetInstance(), Notifier::GetInstance()) {}
+
+DispatcherBase::DispatcherBase(Storage& storage, Notifier& notifier)
+    : m_storage(storage),
+      m_notifier(notifier),
+      m_active(false),      // threads start only in StartServer/StartClient
+      m_update_rate(100) {}  // default periodic dispatch interval, in ms
+
+DispatcherBase::~DispatcherBase() {
+ Logger::GetInstance().SetLogger(nullptr);  // NOTE(review): clears the global log sink before Stop(); intent not shown here
+ Stop();  // signal, wake and join the worker threads, then drop all connections
+}
+
+void DispatcherBase::StartServer(StringRef persist_filename,
+ std::unique_ptr<NetworkAcceptor> acceptor) {  // server mode; no-op if already active
+ {
+ std::lock_guard<std::mutex> lock(m_user_mutex);
+ if (m_active) return;
+ m_active = true;
+ }
+ m_server = true;
+ m_persist_filename = persist_filename;
+ m_server_acceptor = std::move(acceptor);
+ // Persistent values are loaded before any worker thread is started.
+ // Load persistent file. Ignore errors, but pass along warnings.
+ if (!persist_filename.empty()) {
+ bool first = true;  // print the "When reading..." header only once
+ m_storage.LoadPersistent(
+ persist_filename, [&](std::size_t line, const char* msg) {  // per-warning callback
+ if (first) {
+ first = false;
+ WARNING("When reading initial persistent values from '"
+ << persist_filename << "':");
+ }
+ WARNING(persist_filename << ":" << line << ": " << msg);
+ });
+ }
+ // Outgoing Storage messages are routed to QueueOutgoing() on this object.
+ using namespace std::placeholders;
+ m_storage.SetOutgoing(std::bind(&Dispatcher::QueueOutgoing, this, _1, _2, _3),
+ m_server);
+ // NOTE(review): std::terminate if a prior run's threads were never joined (no Stop()).
+ m_dispatch_thread = std::thread(&Dispatcher::DispatchThreadMain, this);
+ m_clientserver_thread = std::thread(&Dispatcher::ServerThreadMain, this);
+}
+
+void DispatcherBase::StartClient(
+ std::function<std::unique_ptr<NetworkStream>()> connect) {  // client mode; no-op if already active
+ {
+ std::lock_guard<std::mutex> lock(m_user_mutex);
+ if (m_active) return;
+ m_active = true;
+ }
+ m_server = false;
+ // Outgoing Storage messages are routed to QueueOutgoing() on this object.
+ using namespace std::placeholders;
+ m_storage.SetOutgoing(std::bind(&Dispatcher::QueueOutgoing, this, _1, _2, _3),
+ m_server);
+ // ClientThreadMain calls connect() in a retry loop to (re)establish the link.
+ m_dispatch_thread = std::thread(&Dispatcher::DispatchThreadMain, this);
+ m_clientserver_thread =
+ std::thread(&Dispatcher::ClientThreadMain, this, connect);
+}
+
+void DispatcherBase::Stop() {  // also called from ~DispatcherBase; harmless if never started
+ m_active = false;  // every worker loop exits once it observes this
+ // Each worker blocks somewhere different, so wake each one explicitly.
+ // wake dispatch thread (a missed notify just delays exit to its next timeout)
+ m_flush_cv.notify_one();
+
+ // wake up client thread with a reconnect (no-op in server mode)
+ ClientReconnect();
+
+ // wake up server thread by shutting down the socket
+ if (m_server_acceptor) m_server_acceptor->shutdown();
+
+ // join threads (std::thread::join blocks; there is no timeout)
+ if (m_dispatch_thread.joinable()) m_dispatch_thread.join();
+ if (m_clientserver_thread.joinable()) m_clientserver_thread.join();
+ // Detach the list under the lock; the connections are released outside it.
+ std::vector<std::shared_ptr<NetworkConnection>> conns;
+ {
+ std::lock_guard<std::mutex> lock(m_user_mutex);
+ conns.swap(m_connections);
+ }
+
+ // close all connections: drop our references (dtor presumably closes them)
+ conns.resize(0);
+}
+
+void DispatcherBase::SetUpdateRate(double interval) {
+  // Clamp to [0.1, 1.0] seconds: no faster than 100 ms, no slower than 1 s.
+  // Written as !(x >= 0.1) so a NaN interval also clamps to 0.1 instead of
+  // reaching the (undefined) NaN -> unsigned conversion below.
+  if (!(interval >= 0.1)) interval = 0.1;
+  else if (interval > 1.0) interval = 1.0;
+  m_update_rate = static_cast<unsigned int>(interval * 1000);  // in ms
+}
+
+void DispatcherBase::SetIdentity(llvm::StringRef name) {  // sent in hellos
+  std::lock_guard<std::mutex> guard(m_user_mutex);
+  m_identity.assign(name.data(), name.size());
+}
+
+void DispatcherBase::Flush() {
+  const auto now = std::chrono::steady_clock::now();
+  bool wake = false;
+  {
+    std::lock_guard<std::mutex> guard(m_flush_mutex);
+    if (now - m_last_flush >= std::chrono::milliseconds(100)) {  // rate limit
+      m_last_flush = now;
+      wake = m_do_flush = true;
+    }
+  }
+  if (wake) m_flush_cv.notify_one();  // wake DispatchThreadMain early
+}
+
+std::vector<ConnectionInfo> DispatcherBase::GetConnections() const {
+  // Info for each kActive connection; empty if the dispatcher isn't running.
+  std::vector<ConnectionInfo> result;
+  if (m_active) {
+    std::lock_guard<std::mutex> guard(m_user_mutex);
+    for (const auto& c : m_connections) {
+      if (c->state() == NetworkConnection::kActive)
+        result.emplace_back(c->info());
+    }
+  }
+  return result;
+}
+
+void DispatcherBase::NotifyConnections(
+    ConnectionListenerCallback callback) const {
+  // Forward each kActive connection's info to NotifyConnection(true, ...).
+  std::lock_guard<std::mutex> guard(m_user_mutex);
+  for (const auto& c : m_connections)
+    if (c->state() == NetworkConnection::kActive)
+      m_notifier.NotifyConnection(true, c->info(), callback);
+}
+
+void DispatcherBase::DispatchThreadMain() {  // periodic send / save / reconnect-detect loop
+ auto timeout_time = std::chrono::steady_clock::now();
+ // steady_clock: pacing is unaffected by wall-clock adjustments.
+ static const auto save_delta_time = std::chrono::seconds(1);
+ auto next_save_time = timeout_time + save_delta_time;
+ // count drives a DEBUG status line on every 11th pass.
+ int count = 0;
+ // m_flush_mutex stays held except inside wait_until, so Flush() waits for a pass.
+ std::unique_lock<std::mutex> flush_lock(m_flush_mutex);
+ while (m_active) {
+ // handle loop taking too long
+ auto start = std::chrono::steady_clock::now();
+ if (start > timeout_time)
+ timeout_time = start;
+ // (if we fell behind, restart the cadence from now rather than bursting)
+ // wait for periodic or when flushed
+ timeout_time += std::chrono::milliseconds(m_update_rate);
+ m_flush_cv.wait_until(flush_lock, timeout_time,
+ [&] { return !m_active || m_do_flush; });
+ m_do_flush = false;
+ if (!m_active) break; // in case we were woken up to terminate
+ // Server only: persistent save at most about once per save_delta_time.
+ // perform periodic persistent save
+ if (m_server && !m_persist_filename.empty() && start > next_save_time) {
+ next_save_time += save_delta_time;
+ // handle loop taking too long
+ if (start > next_save_time) next_save_time = start + save_delta_time;
+ const char* err = m_storage.SavePersistent(m_persist_filename, true);
+ if (err) WARNING("periodic persistent save: " << err);
+ }
+ // Lock order here: m_flush_mutex (already held), then m_user_mutex.
+ {
+ std::lock_guard<std::mutex> user_lock(m_user_mutex);
+ bool reconnect = false;
+
+ if (++count > 10) {
+ DEBUG("dispatch running " << m_connections.size() << " connections");
+ count = 0;
+ }
+
+ for (auto& conn : m_connections) {
+ // post outgoing messages if connection is active
+ // only send keep-alives on client
+ if (conn->state() == NetworkConnection::kActive)
+ conn->PostOutgoing(!m_server);
+
+ // if client, reconnect if connection died
+ if (!m_server && conn->state() == NetworkConnection::kDead)
+ reconnect = true;
+ }
+ // reconnect if we disconnected (and a reconnect is not in progress)
+ if (reconnect && !m_do_reconnect) {  // ClientThreadMain waits on this flag
+ m_do_reconnect = true;
+ m_reconnect_cv.notify_one();
+ }
+ }
+ }
+}
+
+void DispatcherBase::QueueOutgoing(std::shared_ptr<Message> msg,
+                                   NetworkConnection* only,
+                                   NetworkConnection* except) {
+  // Fan msg out to synced/active connections (optional 'only'; skip 'except').
+  std::lock_guard<std::mutex> guard(m_user_mutex);
+  for (auto& c : m_connections) {
+    if (c.get() == except || (only && c.get() != only)) continue;
+    const auto st = c->state();
+    if (st == NetworkConnection::kSynchronized ||
+        st == NetworkConnection::kActive)
+      c->QueueOutgoing(msg);
+  }
+}
+
+void DispatcherBase::ServerThreadMain() {  // accept loop (server mode only)
+ if (m_server_acceptor->start() != 0) {  // nonzero: presumably listen failed
+ m_active = false;
+ return;
+ }
+ while (m_active) {
+ auto stream = m_server_acceptor->accept();  // Stop() calls shutdown() to unblock this
+ if (!stream) {
+ m_active = false;  // NOTE(review): one failed accept() stops the whole dispatcher
+ return;
+ }
+ if (!m_active) return;
+ DEBUG("server: client connection from " << stream->getPeerIP() << " port "
+ << stream->getPeerPort());
+ // ServerHandshake is handed to the connection; presumably it runs after Start().
+ // add to connections list
+ using namespace std::placeholders;
+ auto conn = std::make_shared<NetworkConnection>(
+ std::move(stream), m_notifier,
+ std::bind(&Dispatcher::ServerHandshake, this, _1, _2, _3),
+ std::bind(&Storage::GetEntryType, &m_storage, _1));
+ conn->set_process_incoming(  // weak_ptr: the stored callback must not keep conn alive
+ std::bind(&Storage::ProcessIncoming, &m_storage, _1, _2,
+ std::weak_ptr<NetworkConnection>(conn)));
+ {
+ std::lock_guard<std::mutex> lock(m_user_mutex);
+ // reuse dead connection slots
+ bool placed = false;
+ for (auto& c : m_connections) {
+ if (c->state() == NetworkConnection::kDead) {
+ c = conn;
+ placed = true;
+ break;
+ }
+ }
+ if (!placed) m_connections.emplace_back(conn);
+ conn->Start();
+ }
+ }
+}
+
+void DispatcherBase::ClientThreadMain(  // connect / wait-for-reconnect loop
+ std::function<std::unique_ptr<NetworkStream>()> connect) {
+ while (m_active) {
+ // sleep between retries
+ std::this_thread::sleep_for(std::chrono::milliseconds(500));
+ // (the 500 ms pause also precedes the very first attempt)
+ // try to connect (with timeout)
+ DEBUG("client trying to connect");
+ auto stream = connect();
+ if (!stream) continue; // keep retrying
+ DEBUG("client connected");
+ // m_user_mutex is held from here until the wait below releases it.
+ std::unique_lock<std::mutex> lock(m_user_mutex);
+ using namespace std::placeholders;
+ auto conn = std::make_shared<NetworkConnection>(
+ std::move(stream), m_notifier,
+ std::bind(&Dispatcher::ClientHandshake, this, _1, _2, _3),
+ std::bind(&Storage::GetEntryType, &m_storage, _1));
+ conn->set_process_incoming(  // weak_ptr: the stored callback must not keep conn alive
+ std::bind(&Storage::ProcessIncoming, &m_storage, _1, _2,
+ std::weak_ptr<NetworkConnection>(conn)));
+ m_connections.resize(0); // disconnect any current (destroyed under m_user_mutex)
+ m_connections.emplace_back(conn);
+ conn->set_proto_rev(m_reconnect_proto_rev);  // last rev given to ClientReconnect (init 0x0300)
+ conn->Start();
+ // NOTE(review): deadlocks if an old conn's dtor joins a thread inside ClientHandshake.
+ // block until told to reconnect
+ m_do_reconnect = false;
+ m_reconnect_cv.wait(lock, [&] { return !m_active || m_do_reconnect; });
+ }
+}
+
+bool DispatcherBase::ClientHandshake(  // true = handshake done; false = failed
+ NetworkConnection& conn,
+ std::function<std::shared_ptr<Message>()> get_msg,
+ std::function<void(llvm::ArrayRef<std::shared_ptr<Message>>)> send_msgs) {
+ // get identity
+ std::string self_id;
+ {
+ std::lock_guard<std::mutex> lock(m_user_mutex);
+ self_id = m_identity;
+ }
+ // ClientHello, [3.0+ ServerHello], assigns, ServerHelloDone, [3.0+ ClientHelloDone]
+ // send client hello
+ DEBUG("client: sending hello");
+ send_msgs(Message::ClientHello(self_id));
+
+ // wait for response
+ auto msg = get_msg();
+ if (!msg) {
+ // disconnected, retry
+ DEBUG("client: server disconnected before first response");
+ return false;
+ }
+ // PROTO_UNSUP carries the server's revision; for 2.0, reconnect offering 0x0200.
+ if (msg->Is(Message::kProtoUnsup)) {
+ if (msg->id() == 0x0200) ClientReconnect(0x0200);
+ return false;
+ }
+ // ServerHello flag bit 0 set => new_server = false for ApplyInitialAssignments.
+ bool new_server = true;
+ if (conn.proto_rev() >= 0x0300) {
+ // should be server hello; if not, disconnect.
+ if (!msg->Is(Message::kServerHello)) return false;
+ conn.set_remote_id(msg->str());
+ if ((msg->flags() & 1) != 0) new_server = false;
+ // get the next message
+ msg = get_msg();
+ }
+
+ // receive initial assignments
+ std::vector<std::shared_ptr<Message>> incoming;
+ for (;;) {
+ if (!msg) {
+ // disconnected, retry
+ DEBUG("client: server disconnected during initial entries");
+ return false;
+ }
+ DEBUG4("received init str=" << msg->str() << " id=" << msg->id()
+ << " seq_num=" << msg->seq_num_uid());
+ if (msg->Is(Message::kServerHelloDone)) break;
+ if (!msg->Is(Message::kEntryAssign)) {
+ // unexpected message
+ DEBUG("client: received message (" << msg->type() << ") other than entry assignment during initial handshake");
+ return false;
+ }
+ incoming.emplace_back(std::move(msg));
+ // get the next message
+ msg = get_msg();
+ }
+
+ // generate outgoing assignments
+ NetworkConnection::Outgoing outgoing;
+ // Storage takes the server's entries and fills 'outgoing' with what we send back.
+ m_storage.ApplyInitialAssignments(conn, incoming, new_server, &outgoing);
+
+ if (conn.proto_rev() >= 0x0300)
+ outgoing.emplace_back(Message::ClientHelloDone());
+
+ if (!outgoing.empty()) send_msgs(outgoing);
+
+ INFO("client: CONNECTED to server " << conn.stream().getPeerIP() << " port "
+ << conn.stream().getPeerPort());
+ return true;
+}
+
+bool DispatcherBase::ServerHandshake(  // true = handshake done; false = failed
+ NetworkConnection& conn,
+ std::function<std::shared_ptr<Message>()> get_msg,
+ std::function<void(llvm::ArrayRef<std::shared_ptr<Message>>)> send_msgs) {
+ // Wait for the client to send us a hello.
+ auto msg = get_msg();
+ if (!msg) {
+ DEBUG("server: client disconnected before sending hello");
+ return false;
+ }
+ if (!msg->Is(Message::kClientHello)) {
+ DEBUG("server: client initial message was not client hello");
+ return false;
+ }
+ // Revisions above 3.0 get PROTO_UNSUP; anything else is adopted as-is.
+ // Check that the client requested version is not too high.
+ unsigned int proto_rev = msg->id();
+ if (proto_rev > 0x0300) {
+ DEBUG("server: client requested proto > 0x0300");
+ send_msgs(Message::ProtoUnsup());
+ return false;
+ }
+ // 3.0+ client hellos carry the client's identity string.
+ if (proto_rev >= 0x0300) conn.set_remote_id(msg->str());
+
+ // Set the proto version to the client requested version
+ DEBUG("server: client protocol " << proto_rev);
+ conn.set_proto_rev(proto_rev);
+
+ // Send initial set of assignments
+ NetworkConnection::Outgoing outgoing;
+
+ // Start with server hello. TODO: initial connection flag
+ if (proto_rev >= 0x0300) {
+ std::lock_guard<std::mutex> lock(m_user_mutex);
+ outgoing.emplace_back(Message::ServerHello(0u, m_identity));  // flags always 0 for now
+ }
+
+ // Get snapshot of initial assignments
+ m_storage.GetInitialAssignments(conn, &outgoing);
+
+ // Finish with server hello done
+ outgoing.emplace_back(Message::ServerHelloDone());
+
+ // Batch transmit
+ DEBUG("server: sending initial assignments");
+ send_msgs(outgoing);
+
+ // In proto rev 3.0 and later, the handshake concludes with a client hello
+ // done message, so we can batch the assigns before marking the connection
+ // active. In pre-3.0, we need to just immediately mark it active and hand
+ // off control to the dispatcher to assign them as they arrive.
+ if (proto_rev >= 0x0300) {
+ // receive client initial assignments
+ std::vector<std::shared_ptr<Message>> incoming;
+ msg = get_msg();
+ for (;;) {
+ if (!msg) {
+ // disconnected, retry
+ DEBUG("server: disconnected waiting for initial entries");
+ return false;
+ }
+ if (msg->Is(Message::kClientHelloDone)) break;
+ if (!msg->Is(Message::kEntryAssign)) {
+ // unexpected message
+ DEBUG("server: received message ("
+ << msg->type()
+ << ") other than entry assignment during initial handshake");
+ return false;
+ }
+ incoming.push_back(msg);
+ // get the next message (blocks)
+ msg = get_msg();
+ }
+ for (auto& msg : incoming)  // (this loop variable shadows the outer msg)
+ m_storage.ProcessIncoming(msg, &conn, std::weak_ptr<NetworkConnection>());
+ }
+ // Pre-3.0 clients reach this point right after the server's batch is sent.
+ INFO("server: client CONNECTED: " << conn.stream().getPeerIP() << " port "
+ << conn.stream().getPeerPort());
+ return true;
+}
+
+void DispatcherBase::ClientReconnect(unsigned int proto_rev) {
+  // Ask ClientThreadMain to drop its connection and connect again.
+  if (m_server) return;  // servers have no reconnect loop
+  std::unique_lock<std::mutex> guard(m_user_mutex);
+  m_reconnect_proto_rev = proto_rev;  // applied to the next connection
+  m_do_reconnect = true;
+  guard.unlock();  // notify after releasing the mutex
+  m_reconnect_cv.notify_one();
+}
diff --git a/src/Dispatcher.h b/src/Dispatcher.h
new file mode 100644
index 0000000..26f5e76
--- /dev/null
+++ b/src/Dispatcher.h
@@ -0,0 +1,127 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef NT_DISPATCHER_H_
+#define NT_DISPATCHER_H_
+
+#include <atomic>
+#include <chrono>
+#include <condition_variable>
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include "llvm/StringRef.h"
+#include "atomic_static.h"
+#include "NetworkConnection.h"
+#include "Notifier.h"
+#include "Storage.h"
+
+class NetworkAcceptor;
+class NetworkStream;
+
+namespace nt {
+
+class DispatcherBase {  // shared server/client engine; Dispatcher adds the TCP endpoints
+ friend class DispatcherTest;
+ public:
+ virtual ~DispatcherBase();
+ // Only one of StartServer/StartClient takes effect until Stop(); later calls are no-ops.
+ void StartServer(StringRef persist_filename,
+ std::unique_ptr<NetworkAcceptor> acceptor);
+ void StartClient(std::function<std::unique_ptr<NetworkStream>()> connect);
+ void Stop();
+ void SetUpdateRate(double interval);  // seconds, clamped to [0.1, 1.0]
+ void SetIdentity(llvm::StringRef name);  // name sent in hello messages
+ void Flush();  // early dispatch wakeup, at most once per 100 ms
+ std::vector<ConnectionInfo> GetConnections() const;  // kActive connections only
+ void NotifyConnections(ConnectionListenerCallback callback) const;
+ // True between a successful Start* call and Stop() (or a fatal accept error).
+ bool active() const { return m_active; }
+ // Non-copyable: owns worker threads and mutexes.
+ DispatcherBase(const DispatcherBase&) = delete;
+ DispatcherBase& operator=(const DispatcherBase&) = delete;
+ // Constructed only through subclasses such as Dispatcher.
+ protected:
+ DispatcherBase(Storage& storage, Notifier& notifier);
+ // Worker-thread entry points.
+ private:
+ void DispatchThreadMain();
+ void ServerThreadMain();
+ void ClientThreadMain(
+ std::function<std::unique_ptr<NetworkStream>()> connect);
+ // Handshake callbacks handed to each NetworkConnection; true on success.
+ bool ClientHandshake(
+ NetworkConnection& conn,
+ std::function<std::shared_ptr<Message>()> get_msg,
+ std::function<void(llvm::ArrayRef<std::shared_ptr<Message>>)> send_msgs);
+ bool ServerHandshake(
+ NetworkConnection& conn,
+ std::function<std::shared_ptr<Message>()> get_msg,
+ std::function<void(llvm::ArrayRef<std::shared_ptr<Message>>)> send_msgs);
+ // Client only: request a reconnect that offers proto_rev.
+ void ClientReconnect(unsigned int proto_rev = 0x0300);
+ // Storage's outgoing hook: queue msg on eligible connections.
+ void QueueOutgoing(std::shared_ptr<Message> msg, NetworkConnection* only,
+ NetworkConnection* except);
+ // m_server / m_persist_filename are written by Start* before threads start.
+ Storage& m_storage;
+ Notifier& m_notifier;
+ bool m_server = false;
+ std::string m_persist_filename;
+ std::thread m_dispatch_thread;
+ std::thread m_clientserver_thread;
+ // Server mode only.
+ std::unique_ptr<NetworkAcceptor> m_server_acceptor;
+
+ // Mutex for user-accessible items
+ mutable std::mutex m_user_mutex;
+ std::vector<std::shared_ptr<NetworkConnection>> m_connections;
+ std::string m_identity;
+
+ std::atomic_bool m_active; // set to false to terminate threads
+ std::atomic_uint m_update_rate; // periodic dispatch update rate, in ms
+
+ // Condition variable for forced dispatch wakeup (flush)
+ std::mutex m_flush_mutex;
+ std::condition_variable m_flush_cv;
+ std::chrono::steady_clock::time_point m_last_flush;
+ bool m_do_flush = false;
+
+ // Condition variable for client reconnect (uses user mutex)
+ std::condition_variable m_reconnect_cv;
+ unsigned int m_reconnect_proto_rev = 0x0300;
+ bool m_do_reconnect = true;
+};
+
+class Dispatcher : public DispatcherBase {  // TCP front end over DispatcherBase
+ friend class DispatcherTest;
+ public:
+ static Dispatcher& GetInstance() {
+ ATOMIC_STATIC(Dispatcher, instance);
+ return instance;
+ }
+ // These hide the DispatcherBase overloads of the same names (qualify to reach them).
+ void StartServer(StringRef persist_filename, const char* listen_address,
+ unsigned int port);
+ void StartClient(const char* server_name, unsigned int port);
+ // Private ctors: the default one uses Storage/Notifier::GetInstance().
+ private:
+ Dispatcher();
+ Dispatcher(Storage& storage, Notifier& notifier)
+ : DispatcherBase(storage, notifier) {}
+
+ ATOMIC_STATIC_DECL(Dispatcher)
+};
+
+
+} // namespace nt
+
+#endif // NT_DISPATCHER_H_
diff --git a/src/Log.cpp b/src/Log.cpp
new file mode 100644
index 0000000..9b5a18d
--- /dev/null
+++ b/src/Log.cpp
@@ -0,0 +1,62 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "Log.h"
+
+#include <cstdio>
+#ifdef _WIN32
+ #include <cstdlib>
+#else
+ #include <cstring>
+#endif
+
+#ifdef __APPLE__
+ #include <libgen.h>
+#endif
+
+using namespace nt;
+
+ATOMIC_STATIC_INIT(Logger)  // NOTE(review): presumably the storage behind GetInstance()'s ATOMIC_STATIC
+
+static void def_log_func(unsigned int level, const char* file,
+                         unsigned int line, const char* msg) {
+  // Default stderr sink: level 20 prints just the message; levels >= 30
+  // add a severity tag and "file:line"; every other level is dropped.
+  if (level == 20) {
+    std::fprintf(stderr, "NT: %s\n", msg);
+    return;
+  }
+
+  const char* levelmsg;
+  if (level >= 50)
+    levelmsg = "CRITICAL";
+  else if (level >= 40)
+    levelmsg = "ERROR";
+  else if (level >= 30)
+    levelmsg = "WARNING";
+  else
+    return;
+#ifdef _WIN32
+  // Max-size buffers so _splitpath_s cannot fail on a long file name.
+  char fname[_MAX_FNAME], ext[_MAX_EXT];
+  _splitpath_s(file, nullptr, 0, nullptr, 0, fname, sizeof(fname), ext,
+               sizeof(ext));
+  std::fprintf(stderr, "NT: %s: %s (%s%s:%u)\n", levelmsg, msg, fname, ext,
+               line);
+#elif __APPLE__
+  std::string basestr(file);  // basename() may modify its argument
+  std::fprintf(stderr, "NT: %s: %s (%s:%u)\n", levelmsg, msg,
+               basename(&basestr[0]), line);
+#else
+  std::fprintf(stderr, "NT: %s: %s (%s:%u)\n", levelmsg, msg, basename(file),
+               line);
+#endif
+}
+
+Logger::Logger() : m_func(def_log_func) {}  // starts with def_log_func (stderr)
+
+Logger::~Logger() = default;
diff --git a/src/Log.h b/src/Log.h
new file mode 100644
index 0000000..dd9e125
--- /dev/null
+++ b/src/Log.h
@@ -0,0 +1,84 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef NT_LOG_H_
+#define NT_LOG_H_
+
+#include <functional>
+#include <sstream>
+#include <string>
+
+#include "atomic_static.h"
+#include "ntcore_c.h"
+
+namespace nt {
+
+class Logger {
+ public:
+ static Logger& GetInstance() {
+ ATOMIC_STATIC(Logger, instance);
+ return instance;
+ }
+ ~Logger();
+
+ typedef std::function<void(unsigned int level, const char* file,
+ unsigned int line, const char* msg)> LogFunc;
+
+ void SetLogger(LogFunc func) { m_func = func; }
+
+ void set_min_level(unsigned int level) { m_min_level = level; }
+ unsigned int min_level() const { return m_min_level; }
+
+ void Log(unsigned int level, const char* file, unsigned int line,
+ const char* msg) {
+ if (!m_func || level < m_min_level) return;
+ m_func(level, file, line, msg);
+ }
+
+ bool HasLogger() const { return m_func != nullptr; }
+
+ private:
+ Logger();
+
+ LogFunc m_func;
+ unsigned int m_min_level = 20;
+
+ ATOMIC_STATIC_DECL(Logger)
+};
+// LOG(level, x): format x (a << chain) only when the message will be emitted.
+#define LOG(level, x)                                                    \
+  do {                                                                   \
+    nt::Logger& nt_log_ = nt::Logger::GetInstance();                     \
+    if (nt_log_.min_level() <= (level) && nt_log_.HasLogger()) {         \
+      std::ostringstream nt_oss_;                                        \
+      nt_oss_ << x;                                                      \
+      nt_log_.Log((level), __FILE__, __LINE__, nt_oss_.str().c_str());   \
+    }                                                                    \
+  } while (0)
+// Level wrappers around LOG(); NT_LOG_* presumably come from ntcore_c.h.
+#undef ERROR  // NOTE(review): presumably drops a platform ERROR macro (e.g. <windows.h>)
+#define ERROR(x) LOG(NT_LOG_ERROR, x)
+#define WARNING(x) LOG(NT_LOG_WARNING, x)
+#define INFO(x) LOG(NT_LOG_INFO, x)
+// Under NDEBUG the DEBUG* macros expand to no-ops and x is never evaluated.
+#ifdef NDEBUG
+#define DEBUG(x) do {} while (0)
+#define DEBUG1(x) do {} while (0)
+#define DEBUG2(x) do {} while (0)
+#define DEBUG3(x) do {} while (0)
+#define DEBUG4(x) do {} while (0)
+#else  // debug builds: forward to LOG() at the NT_LOG_DEBUG* levels
+#define DEBUG(x) LOG(NT_LOG_DEBUG, x)
+#define DEBUG1(x) LOG(NT_LOG_DEBUG1, x)
+#define DEBUG2(x) LOG(NT_LOG_DEBUG2, x)
+#define DEBUG3(x) LOG(NT_LOG_DEBUG3, x)
+#define DEBUG4(x) LOG(NT_LOG_DEBUG4, x)
+#endif  // NDEBUG
+
+} // namespace nt
+
+#endif // NT_LOG_H_
diff --git a/src/Message.cpp b/src/Message.cpp
new file mode 100644
index 0000000..161d303
--- /dev/null
+++ b/src/Message.cpp
@@ -0,0 +1,303 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "Message.h"
+
+#include "Log.h"
+#include "WireDecoder.h"
+#include "WireEncoder.h"
+
+// Payload of every CLEAR_ENTRIES message; Message::Read rejects a
+// CLEAR_ENTRIES whose payload differs. A typed constant rather than a macro.
+static const unsigned long kClearAllMagic = 0xD06CB27Aul;
+
+using namespace nt;
+
+// Decodes one message from `decoder`.
+//
+// The first byte selects the message type. The fields that follow depend on
+// that type and, for several types, on decoder.proto_rev(). `get_entry_type`
+// is only called for ENTRY_UPDATE under protocol revisions below 3.0, because
+// that wire format omits the value type.
+//
+// Returns the decoded message. Returns nullptr if a field cannot be read or
+// the message is not valid for the current protocol revision. For the latter
+// case, and for unknown message types, a description is recorded with
+// decoder.set_error().
+std::shared_ptr<Message> Message::Read(WireDecoder& decoder,
+                                       GetEntryTypeFunc get_entry_type) {
+  unsigned int msg_type;
+  if (!decoder.Read8(&msg_type)) return nullptr;
+  // msg_type is an arbitrary byte. Bytes outside MsgType's value range make a
+  // static_cast to MsgType undefined behavior. So the message starts as
+  // kUnknown and receives its real type only after the switch below has
+  // recognized msg_type.
+  auto msg = std::make_shared<Message>(kUnknown, private_init());
+  switch (msg_type) {
+    case kKeepAlive:
+      break;
+    case kClientHello: {
+      unsigned int proto_rev;
+      if (!decoder.Read16(&proto_rev)) return nullptr;
+      msg->m_id = proto_rev;
+      // This intentionally uses the provided proto_rev instead of
+      // decoder.proto_rev().
+      if (proto_rev >= 0x0300u) {
+        if (!decoder.ReadString(&msg->m_str)) return nullptr;
+      }
+      break;
+    }
+    case kProtoUnsup: {
+      if (!decoder.Read16(&msg->m_id)) return nullptr;  // proto rev
+      break;
+    }
+    case kServerHelloDone:
+      if (decoder.proto_rev() < 0x0300u) {
+        decoder.set_error("received SERVER_HELLO_DONE in protocol < 3.0");
+        return nullptr;
+      }
+      break;
+    case kServerHello:
+      if (decoder.proto_rev() < 0x0300u) {
+        decoder.set_error("received SERVER_HELLO in protocol < 3.0");
+        return nullptr;
+      }
+      if (!decoder.Read8(&msg->m_flags)) return nullptr;
+      if (!decoder.ReadString(&msg->m_str)) return nullptr;
+      break;
+    case kClientHelloDone:
+      if (decoder.proto_rev() < 0x0300u) {
+        decoder.set_error("received CLIENT_HELLO_DONE in protocol < 3.0");
+        return nullptr;
+      }
+      break;
+    case kEntryAssign: {
+      if (!decoder.ReadString(&msg->m_str)) return nullptr;  // name
+      NT_Type type;
+      if (!decoder.ReadType(&type)) return nullptr;  // type
+      if (!decoder.Read16(&msg->m_id)) return nullptr;  // id
+      if (!decoder.Read16(&msg->m_seq_num_uid)) return nullptr;  // seq num
+      if (decoder.proto_rev() >= 0x0300u) {
+        if (!decoder.Read8(&msg->m_flags)) return nullptr;  // flags
+      }
+      msg->m_value = decoder.ReadValue(type);
+      if (!msg->m_value) return nullptr;
+      break;
+    }
+    case kEntryUpdate: {
+      if (!decoder.Read16(&msg->m_id)) return nullptr;  // id
+      if (!decoder.Read16(&msg->m_seq_num_uid)) return nullptr;  // seq num
+      NT_Type type;
+      if (decoder.proto_rev() >= 0x0300u) {
+        if (!decoder.ReadType(&type)) return nullptr;
+      } else {
+        // Pre-3.0 updates carry no type; look it up from the known entry.
+        type = get_entry_type(msg->m_id);
+      }
+      DEBUG4("update message data type: " << type);
+      msg->m_value = decoder.ReadValue(type);
+      if (!msg->m_value) return nullptr;
+      break;
+    }
+    case kFlagsUpdate: {
+      if (decoder.proto_rev() < 0x0300u) {
+        decoder.set_error("received FLAGS_UPDATE in protocol < 3.0");
+        return nullptr;
+      }
+      if (!decoder.Read16(&msg->m_id)) return nullptr;
+      if (!decoder.Read8(&msg->m_flags)) return nullptr;
+      break;
+    }
+    case kEntryDelete: {
+      if (decoder.proto_rev() < 0x0300u) {
+        decoder.set_error("received ENTRY_DELETE in protocol < 3.0");
+        return nullptr;
+      }
+      if (!decoder.Read16(&msg->m_id)) return nullptr;
+      break;
+    }
+    case kClearEntries: {
+      if (decoder.proto_rev() < 0x0300u) {
+        decoder.set_error("received CLEAR_ENTRIES in protocol < 3.0");
+        return nullptr;
+      }
+      unsigned long magic;
+      if (!decoder.Read32(&magic)) return nullptr;
+      if (magic != kClearAllMagic) {
+        decoder.set_error(
+            "received incorrect CLEAR_ENTRIES magic value, ignoring");
+        return nullptr;
+      }
+      break;
+    }
+    case kExecuteRpc: {
+      if (decoder.proto_rev() < 0x0300u) {
+        decoder.set_error("received EXECUTE_RPC in protocol < 3.0");
+        return nullptr;
+      }
+      if (!decoder.Read16(&msg->m_id)) return nullptr;
+      if (!decoder.Read16(&msg->m_seq_num_uid)) return nullptr;  // uid
+      // Parameters: ULEB128 byte count followed by that many raw bytes.
+      unsigned long size;
+      if (!decoder.ReadUleb128(&size)) return nullptr;
+      const char* params;
+      if (!decoder.Read(&params, size)) return nullptr;
+      msg->m_str = llvm::StringRef(params, size);
+      break;
+    }
+    case kRpcResponse: {
+      if (decoder.proto_rev() < 0x0300u) {
+        decoder.set_error("received RPC_RESPONSE in protocol < 3.0");
+        return nullptr;
+      }
+      if (!decoder.Read16(&msg->m_id)) return nullptr;
+      if (!decoder.Read16(&msg->m_seq_num_uid)) return nullptr;  // uid
+      // Results: ULEB128 byte count followed by that many raw bytes.
+      unsigned long size;
+      if (!decoder.ReadUleb128(&size)) return nullptr;
+      const char* results;
+      if (!decoder.Read(&results, size)) return nullptr;
+      msg->m_str = llvm::StringRef(results, size);
+      break;
+    }
+    default:
+      decoder.set_error("unrecognized message type");
+      INFO("unrecognized message type: " << msg_type);
+      return nullptr;
+  }
+  // Only recognized type bytes reach this point, so the cast is well-defined.
+  msg->m_type = static_cast<MsgType>(msg_type);
+  return msg;
+}
+
+// Builds a CLIENT_HELLO carrying this node's identity string. The protocol
+// revision is not stored in the message; Write() emits the encoder's own.
+std::shared_ptr<Message> Message::ClientHello(llvm::StringRef self_id) {
+  std::shared_ptr<Message> hello =
+      std::make_shared<Message>(kClientHello, private_init());
+  hello->m_str = self_id;
+  return hello;
+}
+
+// Builds a SERVER_HELLO carrying the server's flags and identity string.
+std::shared_ptr<Message> Message::ServerHello(unsigned int flags,
+                                              llvm::StringRef self_id) {
+  std::shared_ptr<Message> hello =
+      std::make_shared<Message>(kServerHello, private_init());
+  hello->m_flags = flags;
+  hello->m_str = self_id;
+  return hello;
+}
+
+// Builds an ENTRY_ASSIGN for entry `name` with the given id, sequence number,
+// value and flags. Write() dereferences the value, so `value` should be
+// non-null.
+std::shared_ptr<Message> Message::EntryAssign(llvm::StringRef name,
+                                              unsigned int id,
+                                              unsigned int seq_num,
+                                              std::shared_ptr<Value> value,
+                                              unsigned int flags) {
+  auto assign = std::make_shared<Message>(kEntryAssign, private_init());
+  assign->m_id = id;
+  assign->m_seq_num_uid = seq_num;
+  assign->m_flags = flags;
+  assign->m_str = name;
+  assign->m_value = value;
+  return assign;
+}
+
+// Builds an ENTRY_UPDATE giving entry `id` a new value at sequence number
+// `seq_num`. Write() dereferences the value, so `value` should be non-null.
+std::shared_ptr<Message> Message::EntryUpdate(unsigned int id,
+                                              unsigned int seq_num,
+                                              std::shared_ptr<Value> value) {
+  auto update = std::make_shared<Message>(kEntryUpdate, private_init());
+  update->m_id = id;
+  update->m_seq_num_uid = seq_num;
+  update->m_value = value;
+  return update;
+}
+
+// Builds a FLAGS_UPDATE that sets the flags of entry `id`.
+std::shared_ptr<Message> Message::FlagsUpdate(unsigned int id,
+                                              unsigned int flags) {
+  auto update = std::make_shared<Message>(kFlagsUpdate, private_init());
+  update->m_flags = flags;
+  update->m_id = id;
+  return update;
+}
+
+// Builds an ENTRY_DELETE for entry `id`.
+std::shared_ptr<Message> Message::EntryDelete(unsigned int id) {
+  std::shared_ptr<Message> deletion =
+      std::make_shared<Message>(kEntryDelete, private_init());
+  deletion->m_id = id;
+  return deletion;
+}
+
+// Builds an EXECUTE_RPC invoking RPC entry `id`. `uid` identifies this call,
+// and `params` is the already-serialized parameter payload, stored verbatim.
+std::shared_ptr<Message> Message::ExecuteRpc(unsigned int id, unsigned int uid,
+                                             llvm::StringRef params) {
+  auto call = std::make_shared<Message>(kExecuteRpc, private_init());
+  call->m_id = id;
+  call->m_seq_num_uid = uid;
+  call->m_str = params;
+  return call;
+}
+
+// Builds an RPC_RESPONSE for call `uid` of RPC entry `id`. `results` is the
+// already-serialized result payload, stored verbatim.
+std::shared_ptr<Message> Message::RpcResponse(unsigned int id, unsigned int uid,
+                                              llvm::StringRef results) {
+  auto response = std::make_shared<Message>(kRpcResponse, private_init());
+  response->m_id = id;
+  response->m_seq_num_uid = uid;
+  response->m_str = results;
+  return response;
+}
+
+// Serializes this message into `encoder` for revision encoder.proto_rev().
+// Field order here defines the wire layout and mirrors Message::Read.
+// Message types introduced in protocol 3.0 (SERVER_HELLO, CLIENT_HELLO_DONE,
+// FLAGS_UPDATE, ENTRY_DELETE, CLEAR_ENTRIES, EXECUTE_RPC, RPC_RESPONSE) write
+// nothing at all when encoding for an older revision. kUnknown (and any other
+// unhandled type) also writes nothing.
+void Message::Write(WireEncoder& encoder) const {
+ switch (m_type) {
+ case kKeepAlive:
+ encoder.Write8(kKeepAlive);
+ break;
+ case kClientHello:
+ encoder.Write8(kClientHello);
+ // The revision sent is the encoder's, not m_id.
+ encoder.Write16(encoder.proto_rev());
+ // Before 3.0 the hello ends after the revision (no identity string).
+ if (encoder.proto_rev() < 0x0300u) return;
+ encoder.WriteString(m_str);
+ break;
+ case kProtoUnsup:
+ encoder.Write8(kProtoUnsup);
+ encoder.Write16(encoder.proto_rev());
+ break;
+ case kServerHelloDone:
+ encoder.Write8(kServerHelloDone);
+ break;
+ case kServerHello:
+ if (encoder.proto_rev() < 0x0300u) return; // new message in version 3.0
+ encoder.Write8(kServerHello);
+ encoder.Write8(m_flags);
+ encoder.WriteString(m_str);
+ break;
+ case kClientHelloDone:
+ if (encoder.proto_rev() < 0x0300u) return; // new message in version 3.0
+ encoder.Write8(kClientHelloDone);
+ break;
+ case kEntryAssign:
+ // m_value is dereferenced without a null check; assignments must carry a
+ // value.
+ encoder.Write8(kEntryAssign);
+ encoder.WriteString(m_str);
+ encoder.WriteType(m_value->type());
+ encoder.Write16(m_id);
+ encoder.Write16(m_seq_num_uid);
+ // The flags byte only exists on the wire from protocol 3.0 on.
+ if (encoder.proto_rev() >= 0x0300u) encoder.Write8(m_flags);
+ encoder.WriteValue(*m_value);
+ break;
+ case kEntryUpdate:
+ encoder.Write8(kEntryUpdate);
+ encoder.Write16(m_id);
+ encoder.Write16(m_seq_num_uid);
+ // Pre-3.0 updates omit the type; the reader infers it from the entry id.
+ if (encoder.proto_rev() >= 0x0300u) encoder.WriteType(m_value->type());
+ encoder.WriteValue(*m_value);
+ break;
+ case kFlagsUpdate:
+ if (encoder.proto_rev() < 0x0300u) return; // new message in version 3.0
+ encoder.Write8(kFlagsUpdate);
+ encoder.Write16(m_id);
+ encoder.Write8(m_flags);
+ break;
+ case kEntryDelete:
+ if (encoder.proto_rev() < 0x0300u) return; // new message in version 3.0
+ encoder.Write8(kEntryDelete);
+ encoder.Write16(m_id);
+ break;
+ case kClearEntries:
+ if (encoder.proto_rev() < 0x0300u) return; // new message in version 3.0
+ encoder.Write8(kClearEntries);
+ encoder.Write32(kClearAllMagic);
+ break;
+ case kExecuteRpc:
+ if (encoder.proto_rev() < 0x0300u) return; // new message in version 3.0
+ encoder.Write8(kExecuteRpc);
+ encoder.Write16(m_id);
+ encoder.Write16(m_seq_num_uid);
+ // NOTE(review): Read decodes this payload as a ULEB128 length plus raw
+ // bytes; presumably WriteString emits that layout — confirm in WireEncoder.
+ encoder.WriteString(m_str);
+ break;
+ case kRpcResponse:
+ if (encoder.proto_rev() < 0x0300u) return; // new message in version 3.0
+ encoder.Write8(kRpcResponse);
+ encoder.Write16(m_id);
+ encoder.Write16(m_seq_num_uid);
+ // Same payload layout assumption as EXECUTE_RPC above.
+ encoder.WriteString(m_str);
+ break;
+ default:
+ break;
+ }
+}
diff --git a/src/Message.h b/src/Message.h
new file mode 100644
index 0000000..3047834
--- /dev/null
+++ b/src/Message.h
@@ -0,0 +1,117 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef NT_MESSAGE_H_
+#define NT_MESSAGE_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+
+#include "nt_Value.h"
+
+namespace nt {
+
+class WireDecoder;
+class WireEncoder;
+
+class Message {
+  // Constructor tag. Only Message's own members can name it, yet it can be
+  // forwarded through std::make_shared to the public tagged constructor.
+  struct private_init {};
+
+ public:
+  /// Message identifiers; the first byte of every message on the wire.
+  enum MsgType {
+    kUnknown = -1,
+    kKeepAlive = 0x00,
+    kClientHello = 0x01,
+    kProtoUnsup = 0x02,
+    kServerHelloDone = 0x03,
+    kServerHello = 0x04,
+    kClientHelloDone = 0x05,
+    kEntryAssign = 0x10,
+    kEntryUpdate = 0x11,
+    kFlagsUpdate = 0x12,
+    kEntryDelete = 0x13,
+    kClearEntries = 0x14,
+    kExecuteRpc = 0x20,
+    kRpcResponse = 0x21
+  };
+  /// Looks up an entry's value type by id. Read() uses it for pre-3.0
+  /// ENTRY_UPDATE messages, which do not carry the type.
+  using GetEntryTypeFunc = std::function<NT_Type(unsigned int id)>;
+
+  Message() = default;
+  Message(MsgType type, const private_init&) : m_type(type) {}
+
+  MsgType type() const { return m_type; }
+  bool Is(MsgType type) const { return type == m_type; }
+
+  // Payload accessors. Which ones are meaningful depends on type(); callers
+  // must know what a given message type carries.
+  llvm::StringRef str() const { return m_str; }
+  std::shared_ptr<Value> value() const { return m_value; }
+  unsigned int id() const { return m_id; }
+  unsigned int flags() const { return m_flags; }
+  unsigned int seq_num_uid() const { return m_seq_num_uid; }
+
+  // Conversion to and from the wire representation.
+  void Write(WireEncoder& encoder) const;
+  static std::shared_ptr<Message> Read(WireDecoder& decoder,
+                                       GetEntryTypeFunc get_entry_type);
+
+  // Factories for messages that carry no payload.
+  static std::shared_ptr<Message> KeepAlive() {
+    return std::make_shared<Message>(kKeepAlive, private_init());
+  }
+  static std::shared_ptr<Message> ProtoUnsup() {
+    return std::make_shared<Message>(kProtoUnsup, private_init());
+  }
+  static std::shared_ptr<Message> ServerHelloDone() {
+    return std::make_shared<Message>(kServerHelloDone, private_init());
+  }
+  static std::shared_ptr<Message> ClientHelloDone() {
+    return std::make_shared<Message>(kClientHelloDone, private_init());
+  }
+  static std::shared_ptr<Message> ClearEntries() {
+    return std::make_shared<Message>(kClearEntries, private_init());
+  }
+
+  // Factories for messages that carry a payload.
+  static std::shared_ptr<Message> ClientHello(llvm::StringRef self_id);
+  static std::shared_ptr<Message> ServerHello(unsigned int flags,
+                                              llvm::StringRef self_id);
+  static std::shared_ptr<Message> EntryAssign(llvm::StringRef name,
+                                              unsigned int id,
+                                              unsigned int seq_num,
+                                              std::shared_ptr<Value> value,
+                                              unsigned int flags);
+  static std::shared_ptr<Message> EntryUpdate(unsigned int id,
+                                              unsigned int seq_num,
+                                              std::shared_ptr<Value> value);
+  static std::shared_ptr<Message> FlagsUpdate(unsigned int id,
+                                              unsigned int flags);
+  static std::shared_ptr<Message> EntryDelete(unsigned int id);
+  static std::shared_ptr<Message> ExecuteRpc(unsigned int id, unsigned int uid,
+                                             llvm::StringRef params);
+  static std::shared_ptr<Message> RpcResponse(unsigned int id, unsigned int uid,
+                                              llvm::StringRef results);
+
+  Message(const Message&) = delete;
+  Message& operator=(const Message&) = delete;
+
+ private:
+  MsgType m_type = kUnknown;
+
+  // Payload fields; their meaning varies with m_type.
+  std::string m_str;
+  std::shared_ptr<Value> m_value;
+  unsigned int m_id = 0;  // also used for proto_rev
+  unsigned int m_flags = 0;
+  unsigned int m_seq_num_uid = 0;
+};
+
+} // namespace nt
+
+#endif // NT_MESSAGE_H_
diff --git a/src/NetworkConnection.cpp b/src/NetworkConnection.cpp
new file mode 100644
index 0000000..b2b741e
--- /dev/null
+++ b/src/NetworkConnection.cpp
@@ -0,0 +1,311 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "NetworkConnection.h"
+
+#include "support/timestamp.h"
+#include "tcpsockets/NetworkStream.h"
+#include "Log.h"
+#include "Notifier.h"
+#include "raw_socket_istream.h"
+#include "WireDecoder.h"
+#include "WireEncoder.h"
+
+using namespace nt;
+
+// Counter for per-connection uids. Each constructor takes the next value with
+// fetch_add. As an object with static storage duration it is zero-initialized,
+// so the first connection gets uid 0.
+std::atomic_uint NetworkConnection::s_uid;
+
+// Wraps an already-connected stream. The connection begins inactive in state
+// kCreated and assumes protocol revision 3.0. The handshake presumably lowers
+// the revision via set_proto_rev() if needed. No threads run until Start().
+NetworkConnection::NetworkConnection(std::unique_ptr<NetworkStream> stream,
+                                     Notifier& notifier,
+                                     HandshakeFunc handshake,
+                                     Message::GetEntryTypeFunc get_entry_type)
+    : m_uid(s_uid.fetch_add(1)),
+      m_stream(std::move(stream)),
+      m_notifier(notifier),
+      // The std::function parameters are by-value sinks. Move them into
+      // place rather than copying their (possibly heap-allocated) targets.
+      m_handshake(std::move(handshake)),
+      m_get_entry_type(std::move(get_entry_type)),
+      // Initialize the atomics directly instead of storing to them after
+      // default construction.
+      m_active(false),
+      m_proto_rev(0x0300),
+      m_state(static_cast<int>(kCreated)),
+      m_last_update(0) {
+  // turn off Nagle algorithm; we bundle packets for transmission
+  m_stream->setNoDelay();
+}
+
+// Shuts the connection down via Stop(). Stop() only joins threads that are
+// joinable, so this is safe even if Start() was never called.
+NetworkConnection::~NetworkConnection() {
+  Stop();
+}
+
+// Launches the writer and reader threads unless the connection is already
+// active. Batches left queued and shutdown flags from an earlier run are
+// cleared first.
+void NetworkConnection::Start() {
+  if (m_active) return;
+  m_active = true;
+  m_state = static_cast<int>(kInit);
+
+  // discard any batches still queued from before
+  while (!m_outgoing.empty()) {
+    m_outgoing.pop();
+  }
+
+  // neither worker thread has shut down in this run yet
+  {
+    std::unique_lock<std::mutex> shutdown_lock(m_shutdown_mutex);
+    m_write_shutdown = false;
+    m_read_shutdown = false;
+  }
+
+  m_write_thread = std::thread(&NetworkConnection::WriteThreadMain, this);
+  m_read_thread = std::thread(&NetworkConnection::ReadThreadMain, this);
+}
+
+namespace {
+
+// Joins `thread` once `done` becomes true, waiting at most `timeout`.
+// `done` is guarded by `mutex` and signalled through `cv`. If the wait times
+// out, the thread is detached so the caller never blocks indefinitely. Does
+// nothing if `thread` is not joinable.
+void JoinWithTimeout(std::thread& thread, std::mutex& mutex,
+                     std::condition_variable& cv, const bool& done,
+                     std::chrono::milliseconds timeout) {
+  if (!thread.joinable()) return;
+  std::unique_lock<std::mutex> lock(mutex);
+  auto timeout_time = std::chrono::steady_clock::now() + timeout;
+  if (cv.wait_until(lock, timeout_time, [&] { return done; }))
+    thread.join();
+  else
+    thread.detach();  // timed out, detach it
+}
+
+}  // namespace
+
+// Marks the connection dead and shuts both worker threads down:
+// - Closing the stream unblocks the reader.
+// - An empty batch wakes the writer.
+// Each thread then gets up to 200 ms to report shutdown before it is
+// detached instead of joined. Finally, anything left in the outgoing queue
+// is discarded.
+void NetworkConnection::Stop() {
+  DEBUG2("NetworkConnection stopping (" << this << ")");
+  m_state = static_cast<int>(kDead);
+  m_active = false;
+  // closing the stream so the read thread terminates
+  if (m_stream) m_stream->close();
+  // send an empty outgoing message set so the write thread terminates
+  m_outgoing.push(Outgoing());
+  // wait for threads to terminate, with timeout
+  const std::chrono::milliseconds thread_timeout(200);
+  JoinWithTimeout(m_write_thread, m_shutdown_mutex, m_write_shutdown_cv,
+                  m_write_shutdown, thread_timeout);
+  JoinWithTimeout(m_read_thread, m_shutdown_mutex, m_read_shutdown_cv,
+                  m_read_shutdown, thread_timeout);
+  // clear queue
+  while (!m_outgoing.empty()) m_outgoing.pop();
+}
+
+// Snapshot of this connection: remote id, peer address and port, time of the
+// last received message, and current protocol revision.
+ConnectionInfo NetworkConnection::info() const {
+  ConnectionInfo snapshot{remote_id(), m_stream->getPeerIP(),
+                          static_cast<unsigned int>(m_stream->getPeerPort()),
+                          m_last_update, m_proto_rev};
+  return snapshot;
+}
+
+// Returns a copy of the remote id last stored by set_remote_id(), read under
+// m_remote_id_mutex.
+std::string NetworkConnection::remote_id() const {
+  std::unique_lock<std::mutex> guard(m_remote_id_mutex);
+  std::string id_copy = m_remote_id;
+  return id_copy;
+}
+
+// Stores the remote peer's id under m_remote_id_mutex.
+void NetworkConnection::set_remote_id(StringRef remote_id) {
+  std::unique_lock<std::mutex> guard(m_remote_id_mutex);
+  m_remote_id = remote_id;
+}
+
+// Body of the reader thread.
+// 1. Runs m_handshake, passing it a function that decodes the next message
+//    and a function that queues a batch for the writer thread.
+// 2. On success, marks the connection kActive and issues a "connected"
+//    notification.
+// 3. Feeds each decoded message to m_process_incoming until the connection
+//    goes inactive or a read fails.
+// On exit it marks the connection dead and wakes the writer with an empty
+// batch. In every case it finally signals m_read_shutdown_cv so that Stop()
+// can join this thread.
+void NetworkConnection::ReadThreadMain() {
+ raw_socket_istream is(*m_stream);
+ WireDecoder decoder(is, m_proto_rev);
+
+ m_state = static_cast<int>(kHandshake);
+ // The reader lambda re-syncs the decoder with m_proto_rev before each read,
+ // since the handshake may change the revision mid-way.
+ if (!m_handshake(*this,
+ [&] {
+ decoder.set_proto_rev(m_proto_rev);
+ auto msg = Message::Read(decoder, m_get_entry_type);
+ if (!msg && decoder.error())
+ DEBUG("error reading in handshake: " << decoder.error());
+ return msg;
+ },
+ [&](llvm::ArrayRef<std::shared_ptr<Message>> msgs) {
+ m_outgoing.emplace(msgs);
+ })) {
+ // Handshake failed. No "connected" notification was sent, so the
+ // "disconnected" notification below is skipped as well.
+ // NOTE(review): this path also skips pushing an empty batch, so the writer
+ // stays blocked in m_outgoing.pop() until Stop() pushes one.
+ m_state = static_cast<int>(kDead);
+ m_active = false;
+ goto done;
+ }
+
+ m_state = static_cast<int>(kActive);
+ m_notifier.NotifyConnection(true, info());
+ while (m_active) {
+ if (!m_stream)
+ break;
+ // Pick up any protocol revision change before decoding the next message.
+ decoder.set_proto_rev(m_proto_rev);
+ decoder.Reset();
+ auto msg = Message::Read(decoder, m_get_entry_type);
+ if (!msg) {
+ if (decoder.error()) INFO("read error: " << decoder.error());
+ // terminate connection on bad message
+ if (m_stream) m_stream->close();
+ break;
+ }
+ DEBUG3("received type=" << msg->type() << " with str=" << msg->str()
+ << " id=" << msg->id()
+ << " seq_num=" << msg->seq_num_uid());
+ m_last_update = Now();
+ m_process_incoming(std::move(msg), this);
+ }
+ DEBUG2("read thread died (" << this << ")");
+ // Notify only if nobody else (Stop() or the writer thread) already marked
+ // the connection dead.
+ if (m_state != kDead) m_notifier.NotifyConnection(false, info());
+ m_state = static_cast<int>(kDead);
+ m_active = false;
+ m_outgoing.push(Outgoing()); // also kill write thread
+
+done:
+ // use condition variable to signal thread shutdown
+ {
+ std::lock_guard<std::mutex> lock(m_shutdown_mutex);
+ m_read_shutdown = true;
+ m_read_shutdown_cv.notify_one();
+ }
+}
+
+// Body of the writer thread. It blocks on m_outgoing for batches. Each batch
+// is encoded (null slots are skipped) into a single buffer at the current
+// protocol revision and sent with one send() call. Empty batches are only
+// wake-ups that re-check m_active.
+// The thread exits when m_active goes false, the stream is gone, or send()
+// returns 0. It then marks the connection dead, closes the stream (which
+// ends the reader), and signals m_write_shutdown_cv for Stop().
+void NetworkConnection::WriteThreadMain() {
+ WireEncoder encoder(m_proto_rev);
+
+ while (m_active) {
+ auto msgs = m_outgoing.pop();
+ DEBUG4("write thread woke up");
+ if (msgs.empty()) continue;
+ encoder.set_proto_rev(m_proto_rev);
+ encoder.Reset();
+ DEBUG3("sending " << msgs.size() << " messages");
+ for (auto& msg : msgs) {
+ // QueueOutgoing() nulls out superseded messages instead of erasing them.
+ if (msg) {
+ DEBUG3("sending type=" << msg->type() << " with str=" << msg->str()
+ << " id=" << msg->id()
+ << " seq_num=" << msg->seq_num_uid());
+ msg->Write(encoder);
+ }
+ }
+ // `err` is filled by send() but not otherwise inspected here.
+ NetworkStream::Error err;
+ if (!m_stream) break;
+ // Everything in the batch may have been skipped (nulls, or 3.0-only
+ // messages under an older protocol); then there is nothing to send.
+ if (encoder.size() == 0) continue;
+ if (m_stream->send(encoder.data(), encoder.size(), &err) == 0) break;
+ DEBUG4("sent " << encoder.size() << " bytes");
+ }
+ DEBUG2("write thread died (" << this << ")");
+ if (m_state != kDead) m_notifier.NotifyConnection(false, info());
+ m_state = static_cast<int>(kDead);
+ m_active = false;
+ if (m_stream) m_stream->close(); // also kill read thread
+
+ // use condition variable to signal thread shutdown
+ {
+ std::lock_guard<std::mutex> lock(m_shutdown_mutex);
+ m_write_shutdown = true;
+ m_write_shutdown_cv.notify_one();
+ }
+}
+
+// Adds `msg` to the pending batch that PostOutgoing() later hands to the
+// writer thread. Redundant entry traffic is coalesced along the way, all
+// under m_pending_mutex.
+// - m_pending_update[id].first is the 1-based position in m_pending_outgoing
+//   of the pending assign/update message for `id`; .second is the same for
+//   its flags update. 0 means none is pending.
+// - Superseded messages are overwritten in place or reset to null (the writer
+//   skips null slots), so recorded positions stay valid.
+// - id 0xffff is treated as unassigned and is never coalesced.
+void NetworkConnection::QueueOutgoing(std::shared_ptr<Message> msg) {
+ std::lock_guard<std::mutex> lock(m_pending_mutex);
+
+ // Merge with previous. One case we don't combine: delete/assign loop.
+ switch (msg->type()) {
+ case Message::kEntryAssign:
+ case Message::kEntryUpdate: {
+ // don't do this for unassigned id's
+ unsigned int id = msg->id();
+ if (id == 0xffff) {
+ m_pending_outgoing.push_back(msg);
+ break;
+ }
+ if (id < m_pending_update.size() && m_pending_update[id].first != 0) {
+ // overwrite the previous one for this id
+ auto& oldmsg = m_pending_outgoing[m_pending_update[id].first - 1];
+ // An update following a pending assign is folded into the assign so the
+ // peer still sees the entry's name, type and flags.
+ if (oldmsg && oldmsg->Is(Message::kEntryAssign) &&
+ msg->Is(Message::kEntryUpdate)) {
+ // need to update assignment with new seq_num and value
+ oldmsg = Message::EntryAssign(oldmsg->str(), id, msg->seq_num_uid(),
+ msg->value(), oldmsg->flags());
+ } else
+ oldmsg = msg; // easy update
+ } else {
+ // new, but remember it
+ std::size_t pos = m_pending_outgoing.size();
+ m_pending_outgoing.push_back(msg);
+ if (id >= m_pending_update.size()) m_pending_update.resize(id + 1);
+ m_pending_update[id].first = pos + 1;
+ }
+ break;
+ }
+ case Message::kEntryDelete: {
+ // don't do this for unassigned id's
+ unsigned int id = msg->id();
+ if (id == 0xffff) {
+ m_pending_outgoing.push_back(msg);
+ break;
+ }
+
+ // clear previous updates
+ if (id < m_pending_update.size()) {
+ if (m_pending_update[id].first != 0) {
+ m_pending_outgoing[m_pending_update[id].first - 1].reset();
+ m_pending_update[id].first = 0;
+ }
+ if (m_pending_update[id].second != 0) {
+ m_pending_outgoing[m_pending_update[id].second - 1].reset();
+ m_pending_update[id].second = 0;
+ }
+ }
+
+ // add deletion
+ m_pending_outgoing.push_back(msg);
+ break;
+ }
+ case Message::kFlagsUpdate: {
+ // don't do this for unassigned id's
+ unsigned int id = msg->id();
+ if (id == 0xffff) {
+ m_pending_outgoing.push_back(msg);
+ break;
+ }
+ if (id < m_pending_update.size() && m_pending_update[id].second != 0) {
+ // overwrite the previous one for this id
+ m_pending_outgoing[m_pending_update[id].second - 1] = msg;
+ } else {
+ // new, but remember it
+ std::size_t pos = m_pending_outgoing.size();
+ m_pending_outgoing.push_back(msg);
+ if (id >= m_pending_update.size()) m_pending_update.resize(id + 1);
+ m_pending_update[id].second = pos + 1;
+ }
+ break;
+ }
+ case Message::kClearEntries: {
+ // knock out all previous assigns/updates!
+ // (earlier deletes and clears are dropped too; this clear supersedes them)
+ for (auto& i : m_pending_outgoing) {
+ if (!i) continue;
+ auto t = i->type();
+ if (t == Message::kEntryAssign || t == Message::kEntryUpdate ||
+ t == Message::kFlagsUpdate || t == Message::kEntryDelete ||
+ t == Message::kClearEntries)
+ i.reset();
+ }
+ m_pending_update.resize(0);
+ m_pending_outgoing.push_back(msg);
+ break;
+ }
+ default:
+ m_pending_outgoing.push_back(msg);
+ break;
+ }
+}
+
+// Hands every message queued by QueueOutgoing() to the writer thread as one
+// batch and resets the coalescing bookkeeping. When nothing is pending, a
+// lone KEEP_ALIVE is posted instead, but only if `keep_alive` is set and at
+// least one second has passed since the previous post.
+void NetworkConnection::PostOutgoing(bool keep_alive) {
+  std::lock_guard<std::mutex> lock(m_pending_mutex);
+  const auto now = std::chrono::steady_clock::now();
+  if (!m_pending_outgoing.empty()) {
+    m_outgoing.emplace(std::move(m_pending_outgoing));
+    m_pending_outgoing.clear();
+    m_pending_update.clear();
+    m_last_post = now;
+    return;
+  }
+  if (!keep_alive || now - m_last_post < std::chrono::seconds(1)) return;
+  m_outgoing.emplace(Outgoing{Message::KeepAlive()});
+  m_last_post = now;
+}
diff --git a/src/NetworkConnection.h b/src/NetworkConnection.h
new file mode 100644
index 0000000..2f1073c
--- /dev/null
+++ b/src/NetworkConnection.h
@@ -0,0 +1,115 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef NT_NETWORKCONNECTION_H_
+#define NT_NETWORKCONNECTION_H_
+
+#include <atomic>
+#include <chrono>
+#include <memory>
+#include <thread>
+
+#include "support/ConcurrentQueue.h"
+#include "Message.h"
+#include "ntcore_cpp.h"
+
+class NetworkStream;
+
+namespace nt {
+
+class Notifier;
+
+class NetworkConnection {
+ public:
+  /// Lifecycle states, stored in m_state.
+  enum State { kCreated, kInit, kHandshake, kSynchronized, kActive, kDead };
+
+  /// Runs the protocol handshake on the reader thread. It receives the
+  /// connection, a function returning the next decoded message (nullptr on
+  /// failure), and a function that queues a batch for sending. Returning
+  /// false aborts the connection.
+  using HandshakeFunc = std::function<bool(
+      NetworkConnection& conn,
+      std::function<std::shared_ptr<Message>()> get_msg,
+      std::function<void(llvm::ArrayRef<std::shared_ptr<Message>>)> send_msgs)>;
+  /// Receives each message decoded after a successful handshake.
+  using ProcessIncomingFunc = std::function<void(std::shared_ptr<Message> msg,
+                                                 NetworkConnection* conn)>;
+  /// One batch of messages; null slots are skipped by the writer.
+  using Outgoing = std::vector<std::shared_ptr<Message>>;
+  using OutgoingQueue = ConcurrentQueue<Outgoing>;
+
+  NetworkConnection(std::unique_ptr<NetworkStream> stream,
+                    Notifier& notifier,
+                    HandshakeFunc handshake,
+                    Message::GetEntryTypeFunc get_entry_type);
+  ~NetworkConnection();
+
+  /// Installs the incoming-message handler; must be called before Start().
+  void set_process_incoming(ProcessIncomingFunc func) {
+    m_process_incoming = func;
+  }
+
+  void Start();
+  void Stop();
+
+  ConnectionInfo info() const;
+
+  bool active() const { return m_active; }
+  NetworkStream& stream() { return *m_stream; }
+
+  /// Adds a message to the pending batch (coalescing redundant updates).
+  void QueueOutgoing(std::shared_ptr<Message> msg);
+  /// Moves the pending batch to the writer, or posts a keep-alive.
+  void PostOutgoing(bool keep_alive);
+
+  unsigned int uid() const { return m_uid; }
+
+  unsigned int proto_rev() const { return m_proto_rev; }
+  void set_proto_rev(unsigned int proto_rev) { m_proto_rev = proto_rev; }
+
+  State state() const { return static_cast<State>(m_state.load()); }
+  void set_state(State state) { m_state = static_cast<int>(state); }
+
+  std::string remote_id() const;
+  void set_remote_id(StringRef remote_id);
+
+  unsigned long long last_update() const { return m_last_update; }
+
+  NetworkConnection(const NetworkConnection&) = delete;
+  NetworkConnection& operator=(const NetworkConnection&) = delete;
+
+ private:
+  void ReadThreadMain();
+  void WriteThreadMain();
+
+  static std::atomic_uint s_uid;
+
+  unsigned int m_uid;
+  std::unique_ptr<NetworkStream> m_stream;
+  Notifier& m_notifier;
+  OutgoingQueue m_outgoing;
+  HandshakeFunc m_handshake;
+  Message::GetEntryTypeFunc m_get_entry_type;
+  ProcessIncomingFunc m_process_incoming;
+  std::thread m_read_thread;
+  std::thread m_write_thread;
+  std::atomic_bool m_active;
+  std::atomic_uint m_proto_rev;
+  std::atomic_int m_state;  // holds a State value
+  mutable std::mutex m_remote_id_mutex;
+  std::string m_remote_id;
+  std::atomic_ullong m_last_update;
+  std::chrono::steady_clock::time_point m_last_post;
+
+  // Pending batch and coalescing positions; guarded by m_pending_mutex.
+  std::mutex m_pending_mutex;
+  Outgoing m_pending_outgoing;
+  std::vector<std::pair<std::size_t, std::size_t>> m_pending_update;
+
+  // Per-thread shutdown flags, guarded by m_shutdown_mutex; Stop() waits on
+  // the condition variables before joining.
+  std::mutex m_shutdown_mutex;
+  std::condition_variable m_read_shutdown_cv;
+  std::condition_variable m_write_shutdown_cv;
+  bool m_read_shutdown = false;
+  bool m_write_shutdown = false;
+};
+
+} // namespace nt
+
+#endif // NT_NETWORKCONNECTION_H_
diff --git a/src/Notifier.cpp b/src/Notifier.cpp
new file mode 100644
index 0000000..fb2e8dc
--- /dev/null
+++ b/src/Notifier.cpp
@@ -0,0 +1,202 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "Notifier.h"
+
+using namespace nt;
+
+// Storage backing Notifier::GetInstance() (macro from atomic_static.h).
+ATOMIC_STATIC_INIT(Notifier)
+// Cleared by the constructor and set by the destructor; exposed through
+// Notifier::destroyed().
+bool Notifier::s_destroyed = false;
+
+// Starts inactive with no local-notification listeners and clears the static
+// destroyed() flag.
+Notifier::Notifier() : m_active(false), m_local_notifiers(false) {
+  s_destroyed = false;
+}
+
+// Flags the notifier as destroyed before stopping the worker thread, so
+// destroyed() already reports true during shutdown.
+Notifier::~Notifier() {
+  s_destroyed = true;  // must precede Stop()
+  Stop();
+}
+
+// Launches the notifier thread unless it is already running. The shutdown
+// flag is reset first so that Stop() waits for this run's thread.
+void Notifier::Start() {
+  {
+    std::unique_lock<std::mutex> state_lock(m_mutex);
+    if (m_active) return;
+    m_active = true;
+  }
+  {
+    std::unique_lock<std::mutex> shutdown_lock(m_shutdown_mutex);
+    m_shutdown = false;
+  }
+  m_thread = std::thread(&Notifier::ThreadMain, this);
+}
+
+// Stops the notifier thread. It clears m_active, wakes the thread, and waits
+// up to one second for the thread to report shutdown. The thread is joined if
+// it finished in time and detached otherwise.
+void Notifier::Stop() {
+  {
+    // Clear m_active while holding m_mutex. ThreadMain checks m_active and the
+    // notification queues under m_mutex before blocking in m_cond.wait().
+    // Without the lock, the notify_one() below could fire between that check
+    // and the wait. The wakeup would then be lost, and the thread would sleep
+    // until the timeout below detaches it.
+    std::lock_guard<std::mutex> lock(m_mutex);
+    m_active = false;
+  }
+  // send notification so the thread terminates
+  m_cond.notify_one();
+  if (m_thread.joinable()) {
+    // join with timeout
+    std::unique_lock<std::mutex> lock(m_shutdown_mutex);
+    auto timeout_time =
+        std::chrono::steady_clock::now() + std::chrono::seconds(1);
+    if (m_shutdown_cv.wait_until(lock, timeout_time,
+                                 [&] { return m_shutdown; }))
+      m_thread.join();
+    else
+      m_thread.detach();  // timed out, detach it
+  }
+}
+
+// Body of the notifier thread. It runs m_on_start, then delivers queued entry
+// and connection notifications until m_active is cleared.
+// - m_mutex is held while the queues and listener lists are inspected, but is
+//   released around every callback. Callbacks may therefore run while the
+//   listener vectors change, which is why they are walked by index.
+// - A notification that carries an `only` callback goes solely to that
+//   callback, with uid 0.
+// - Any other notification goes to each live listener, with uid = index + 1.
+// NOTE(review): m_on_exit runs with m_mutex still locked (`lock` is in scope
+// at `done:`). It must not call Notifier methods that take m_mutex.
+void Notifier::ThreadMain() {
+ if (m_on_start) m_on_start();
+
+ std::unique_lock<std::mutex> lock(m_mutex);
+ while (m_active) {
+ // Sleep until there is work or Stop() clears m_active.
+ while (m_entry_notifications.empty() && m_conn_notifications.empty()) {
+ m_cond.wait(lock);
+ if (!m_active) goto done;
+ }
+
+ // Entry notifications
+ while (!m_entry_notifications.empty()) {
+ if (!m_active) goto done;
+ auto item = std::move(m_entry_notifications.front());
+ m_entry_notifications.pop();
+
+ // Notifications without a value are dropped.
+ if (!item.value) continue;
+ StringRef name(item.name);
+
+ if (item.only) {
+ // Don't hold mutex during callback execution!
+ lock.unlock();
+ item.only(0, name, item.value, item.flags);
+ lock.lock();
+ continue;
+ }
+
+ // Use index because iterator might get invalidated.
+ for (std::size_t i=0; i<m_entry_listeners.size(); ++i) {
+ if (!m_entry_listeners[i].callback) continue; // removed
+
+ // Flags must be within requested flag set for this listener.
+ // Because assign messages can result in both a value and flags update,
+ // we handle that case specially.
+ // (For such a combined notification, a listener that asked for either
+ // UPDATE or FLAGS qualifies; the remaining flags must still be a subset.)
+ unsigned int listen_flags = m_entry_listeners[i].flags;
+ unsigned int flags = item.flags;
+ unsigned int assign_both = NT_NOTIFY_UPDATE | NT_NOTIFY_FLAGS;
+ if ((flags & assign_both) == assign_both) {
+ if ((listen_flags & assign_both) == 0) continue;
+ listen_flags &= ~assign_both;
+ flags &= ~assign_both;
+ }
+ if ((flags & ~listen_flags) != 0) continue;
+
+ // must match prefix
+ if (!name.startswith(m_entry_listeners[i].prefix)) continue;
+
+ // make a copy of the callback so we can safely release the mutex
+ auto callback = m_entry_listeners[i].callback;
+
+ // Don't hold mutex during callback execution!
+ lock.unlock();
+ callback(i+1, name, item.value, item.flags);
+ lock.lock();
+ }
+ }
+
+ // Connection notifications
+ while (!m_conn_notifications.empty()) {
+ if (!m_active) goto done;
+ auto item = std::move(m_conn_notifications.front());
+ m_conn_notifications.pop();
+
+ if (item.only) {
+ // Don't hold mutex during callback execution!
+ lock.unlock();
+ item.only(0, item.connected, item.conn_info);
+ lock.lock();
+ continue;
+ }
+
+ // Use index because iterator might get invalidated.
+ for (std::size_t i=0; i<m_conn_listeners.size(); ++i) {
+ if (!m_conn_listeners[i]) continue; // removed
+ auto callback = m_conn_listeners[i];
+ // Don't hold mutex during callback execution!
+ lock.unlock();
+ callback(i+1, item.connected, item.conn_info);
+ lock.lock();
+ }
+ }
+ }
+
+done:
+ if (m_on_exit) m_on_exit();
+
+ // use condition variable to signal thread shutdown
+ {
+ std::lock_guard<std::mutex> lock(m_shutdown_mutex);
+ m_shutdown = true;
+ m_shutdown_cv.notify_one();
+ }
+}
+
+// Registers an entry listener for names starting with `prefix` and returns its
+// uid, which is its 1-based index in m_entry_listeners. Requesting
+// NT_NOTIFY_LOCAL turns on queuing of local-origin notifications.
+unsigned int Notifier::AddEntryListener(StringRef prefix,
+                                        EntryListenerCallback callback,
+                                        unsigned int flags) {
+  std::lock_guard<std::mutex> lock(m_mutex);
+  m_entry_listeners.emplace_back(prefix, callback, flags);
+  if (flags & NT_NOTIFY_LOCAL) m_local_notifiers = true;
+  return static_cast<unsigned int>(m_entry_listeners.size());
+}
+
+// Disables the entry listener with the given 1-based uid. Its slot is kept,
+// with a null callback, so the uids of other listeners stay valid. Unknown
+// uids are ignored; uid 0 wraps to a huge index and is ignored too.
+void Notifier::RemoveEntryListener(unsigned int entry_listener_uid) {
+  const unsigned int index = entry_listener_uid - 1;
+  std::lock_guard<std::mutex> lock(m_mutex);
+  if (index >= m_entry_listeners.size()) return;
+  m_entry_listeners[index].callback = nullptr;
+}
+
+// Queues an entry notification for the notifier thread. Ignored while the
+// notifier is inactive.
+void Notifier::NotifyEntry(StringRef name, std::shared_ptr<Value> value,
+                           unsigned int flags, EntryListenerCallback only) {
+  if (!m_active) return;
+  // Local-origin changes are dropped outright when no listener requested
+  // NT_NOTIFY_LOCAL. This avoids needless queue traffic in the common
+  // server-side case.
+  const bool is_local = (flags & NT_NOTIFY_LOCAL) != 0;
+  if (is_local && !m_local_notifiers) return;
+  {
+    std::lock_guard<std::mutex> lock(m_mutex);
+    m_entry_notifications.emplace(name, value, flags, only);
+  }
+  m_cond.notify_one();
+}
+
+// Registers a connection listener. Returns its uid: the 1-based index into
+// m_conn_listeners that RemoveConnectionListener() expects and that
+// ThreadMain passes to the callback.
+unsigned int Notifier::AddConnectionListener(
+    ConnectionListenerCallback callback) {
+  std::lock_guard<std::mutex> lock(m_mutex);
+  // The uid must come from m_conn_listeners, not m_entry_listeners, so that
+  // it names the slot this callback is actually stored in.
+  unsigned int uid = m_conn_listeners.size();
+  m_conn_listeners.emplace_back(callback);
+  return uid + 1;
+}
+
+// Disables the connection listener with the given 1-based uid by nulling its
+// slot, which keeps the other uids valid. Unknown uids (including 0, which
+// wraps around) are ignored.
+void Notifier::RemoveConnectionListener(unsigned int conn_listener_uid) {
+  const unsigned int index = conn_listener_uid - 1;
+  std::lock_guard<std::mutex> lock(m_mutex);
+  if (index >= m_conn_listeners.size()) return;
+  m_conn_listeners[index] = nullptr;
+}
+
+// Queues a connection up/down notification for the notifier thread. Ignored
+// while the notifier is inactive.
+void Notifier::NotifyConnection(bool connected,
+                                const ConnectionInfo& conn_info,
+                                ConnectionListenerCallback only) {
+  if (!m_active) return;
+  {
+    std::lock_guard<std::mutex> lock(m_mutex);
+    m_conn_notifications.emplace(connected, conn_info, only);
+  }
+  m_cond.notify_one();
+}
diff --git a/src/Notifier.h b/src/Notifier.h
new file mode 100644
index 0000000..d10054c
--- /dev/null
+++ b/src/Notifier.h
@@ -0,0 +1,120 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef NT_NOTIFIER_H_
+#define NT_NOTIFIER_H_
+
+#include <atomic>
+#include <condition_variable>
+#include <mutex>
+#include <queue>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include "atomic_static.h"
+#include "ntcore_cpp.h"
+
+namespace nt {
+
+// Dispatches entry and connection notifications to registered listeners.
+// Notifications are queued under m_mutex and m_cond is signalled; draining
+// the queues is presumably done by ThreadMain() on m_thread (defined
+// elsewhere) -- TODO confirm.
+class Notifier {
+ friend class NotifierTest;
+ public:
+ // Process-wide instance, created through the ATOMIC_STATIC helper macro.
+ static Notifier& GetInstance() {
+ ATOMIC_STATIC(Notifier, instance);
+ return instance;
+ }
+ ~Notifier();
+
+ void Start();
+ void Stop();
+
+ // True while notifications are being accepted (NotifyEntry and
+ // NotifyConnection return early otherwise).
+ bool active() const { return m_active; }
+ // True once any entry listener registered with NT_NOTIFY_LOCAL.
+ bool local_notifiers() const { return m_local_notifiers; }
+ // NOTE(review): presumably set when the singleton has been destroyed; the
+ // definition is not visible here.
+ static bool destroyed() { return s_destroyed; }
+
+ // Hooks stored for the notifier thread; when they are invoked is not
+ // visible here.
+ void SetOnStart(std::function<void()> on_start) { m_on_start = on_start; }
+ void SetOnExit(std::function<void()> on_exit) { m_on_exit = on_exit; }
+
+ // Returns a 1-based uid for RemoveEntryListener().
+ unsigned int AddEntryListener(StringRef prefix,
+ EntryListenerCallback callback,
+ unsigned int flags);
+ // Clears the listener's callback; its slot is kept so other uids stay valid.
+ void RemoveEntryListener(unsigned int entry_listener_uid);
+
+ // Queues an entry notification; `only`, if set, is carried with it.
+ void NotifyEntry(StringRef name, std::shared_ptr<Value> value,
+ unsigned int flags, EntryListenerCallback only = nullptr);
+
+ // Returns a 1-based uid for RemoveConnectionListener().
+ unsigned int AddConnectionListener(ConnectionListenerCallback callback);
+ void RemoveConnectionListener(unsigned int conn_listener_uid);
+
+ // Queues a connection notification; `only`, if set, is carried with it.
+ void NotifyConnection(bool connected, const ConnectionInfo& conn_info,
+ ConnectionListenerCallback only = nullptr);
+
+ private:
+ Notifier();
+
+ void ThreadMain();
+
+ std::atomic_bool m_active;
+ std::atomic_bool m_local_notifiers;
+
+ // Guards the listener vectors and the notification queues below.
+ std::mutex m_mutex;
+ // Signalled whenever a notification is queued.
+ std::condition_variable m_cond;
+
+ struct EntryListener {
+ EntryListener(StringRef prefix_, EntryListenerCallback callback_,
+ unsigned int flags_)
+ : prefix(prefix_), callback(callback_), flags(flags_) {}
+
+ std::string prefix;
+ EntryListenerCallback callback;
+ unsigned int flags;
+ };
+ // Indexed by (uid - 1); removed listeners have an empty callback.
+ std::vector<EntryListener> m_entry_listeners;
+ std::vector<ConnectionListenerCallback> m_conn_listeners;
+
+ struct EntryNotification {
+ EntryNotification(StringRef name_, std::shared_ptr<Value> value_,
+ unsigned int flags_, EntryListenerCallback only_)
+ : name(name_),
+ value(value_),
+ flags(flags_),
+ only(only_) {}
+
+ std::string name;
+ std::shared_ptr<Value> value;
+ unsigned int flags;
+ EntryListenerCallback only;
+ };
+ std::queue<EntryNotification> m_entry_notifications;
+
+ struct ConnectionNotification {
+ ConnectionNotification(bool connected_, const ConnectionInfo& conn_info_,
+ ConnectionListenerCallback only_)
+ : connected(connected_), conn_info(conn_info_), only(only_) {}
+
+ bool connected;
+ ConnectionInfo conn_info;
+ ConnectionListenerCallback only;
+ };
+ std::queue<ConnectionNotification> m_conn_notifications;
+
+ // Worker thread plus a flag/condvar presumably used to wait for its exit
+ // in Stop() (implementation not visible here) -- TODO confirm.
+ std::thread m_thread;
+ std::mutex m_shutdown_mutex;
+ std::condition_variable m_shutdown_cv;
+ bool m_shutdown = false;
+
+ std::function<void()> m_on_start;
+ std::function<void()> m_on_exit;
+
+ ATOMIC_STATIC_DECL(Notifier)
+ static bool s_destroyed;
+};
+
+} // namespace nt
+
+#endif // NT_NOTIFIER_H_
diff --git a/src/RpcServer.cpp b/src/RpcServer.cpp
new file mode 100644
index 0000000..43d37de
--- /dev/null
+++ b/src/RpcServer.cpp
@@ -0,0 +1,141 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "RpcServer.h"
+
+#include "Log.h"
+
+using namespace nt;
+
+ATOMIC_STATIC_INIT(RpcServer)
+
+// Starts out stopped and not terminating; Start() launches the worker thread.
+RpcServer::RpcServer() : m_active(false), m_terminating(false) {}
+
+// Stops the worker thread and releases any PollRpc() callers that are
+// blocked waiting for work.
+RpcServer::~RpcServer() {
+  Logger::GetInstance().SetLogger(nullptr);
+  Stop();
+  {
+    // PollRpc() tests m_terminating while holding m_mutex and then waits on
+    // m_poll_cond; setting it under the same mutex guarantees the
+    // notification below cannot slip in between that test and the wait.
+    std::lock_guard<std::mutex> lock(m_mutex);
+    m_terminating = true;
+  }
+  m_poll_cond.notify_all();
+}
+
+// Launch the worker thread that runs RPC callbacks; no-op if already active.
+void RpcServer::Start() {
+  std::unique_lock<std::mutex> active_lock(m_mutex);
+  if (m_active) return;
+  m_active = true;
+  active_lock.unlock();
+
+  // Reset the shutdown handshake before the new thread can signal it.
+  std::unique_lock<std::mutex> shutdown_lock(m_shutdown_mutex);
+  m_shutdown = false;
+  shutdown_lock.unlock();
+
+  m_thread = std::thread(&RpcServer::ThreadMain, this);
+}
+
+// Stop the worker thread. Waits up to one second for it to acknowledge the
+// shutdown; joins it if it did, otherwise detaches it.
+void RpcServer::Stop() {
+  {
+    // ThreadMain() checks m_active while holding m_mutex before waiting on
+    // m_call_cond. Clearing the flag under the same mutex ensures the
+    // notification below is not lost between that check and the wait
+    // (which would otherwise always cost the full timeout and a detach).
+    std::lock_guard<std::mutex> lock(m_mutex);
+    m_active = false;
+  }
+  if (m_thread.joinable()) {
+    // send notification so the thread terminates
+    m_call_cond.notify_one();
+    // join with timeout
+    std::unique_lock<std::mutex> lock(m_shutdown_mutex);
+    auto timeout_time =
+        std::chrono::steady_clock::now() + std::chrono::seconds(1);
+    if (m_shutdown_cv.wait_until(lock, timeout_time,
+                                 [&] { return m_shutdown; }))
+      m_thread.join();
+    else
+      m_thread.detach();  // timed out, detach it
+  }
+}
+
+// Queue an incoming RPC call. Calls that have a callback are run by the
+// worker thread (m_call_queue); calls without one wait for PollRpc()
+// (m_poll_queue). The matching condition variable is signalled afterwards.
+void RpcServer::ProcessRpc(StringRef name, std::shared_ptr<Message> msg,
+                           RpcCallback func, unsigned int conn_id,
+                           SendMsgFunc send_response) {
+  const bool has_callback = static_cast<bool>(func);
+  auto& queue = has_callback ? m_call_queue : m_poll_queue;
+  auto& cond = has_callback ? m_call_cond : m_poll_cond;
+  {
+    std::lock_guard<std::mutex> guard(m_mutex);
+    queue.emplace(name, msg, func, conn_id, send_response);
+  }
+  cond.notify_one();
+}
+
+// Take the next callback-less RPC call, filling `call_info`.
+// Returns false if the queue is empty and `blocking` is false, or when the
+// server is terminating. The reply function is remembered under
+// (rpc_id, call_uid) for PostRpcResponse().
+bool RpcServer::PollRpc(bool blocking, RpcCallInfo* call_info) {
+  std::unique_lock<std::mutex> lock(m_mutex);
+  for (;;) {
+    if (!m_poll_queue.empty()) break;
+    if (!blocking || m_terminating) return false;
+    m_poll_cond.wait(lock);
+  }
+
+  RpcCall& call = m_poll_queue.front();
+  const unsigned int rpc_id = call.msg->id();
+  // Caller-visible uid: connection id in the upper bits, sequence number
+  // in the low 16 bits.
+  const unsigned int call_uid =
+      (call.conn_id << 16) | call.msg->seq_num_uid();
+  call_info->rpc_id = rpc_id;
+  call_info->call_uid = call_uid;
+  call_info->name = std::move(call.name);
+  call_info->params = call.msg->str();
+  m_response_map.insert(
+      std::make_pair(std::make_pair(rpc_id, call_uid), call.send_response));
+  m_poll_queue.pop();
+  return true;
+}
+
+// Send the result of a call previously obtained through PollRpc().
+// Each call may be answered once; unknown or already-answered calls are
+// logged and ignored.
+void RpcServer::PostRpcResponse(unsigned int rpc_id, unsigned int call_uid,
+                                llvm::StringRef result) {
+  SendMsgFunc send_response;
+  {
+    // m_response_map is filled by PollRpc() under m_mutex, so look it up and
+    // remove the entry under the same lock.
+    std::lock_guard<std::mutex> lock(m_mutex);
+    auto i = m_response_map.find(std::make_pair(rpc_id, call_uid));
+    if (i != m_response_map.end()) {
+      send_response = std::move(i->getSecond());
+      m_response_map.erase(i);
+    }
+  }
+  if (!send_response) {
+    WARNING("posting RPC response to nonexistent call (or duplicate response)");
+    return;
+  }
+  // Run the reply callback without holding m_mutex.
+  send_response(Message::RpcResponse(rpc_id, call_uid, result));
+}
+
+// Worker thread: runs queued RPC callbacks (calls with a callback) and sends
+// their results, until Stop() clears m_active. Signals m_shutdown_cv on exit.
+void RpcServer::ThreadMain() {
+  std::unique_lock<std::mutex> lock(m_mutex);
+  while (m_active) {
+    // Sleep until there is work or a stop request.
+    m_call_cond.wait(lock,
+                     [&] { return !m_active || !m_call_queue.empty(); });
+
+    // Drain the queue, bailing out as soon as a stop is requested.
+    while (m_active && !m_call_queue.empty()) {
+      auto item = std::move(m_call_queue.front());
+      m_call_queue.pop();
+
+      DEBUG4("rpc calling " << item.name);
+
+      if (item.name.empty() || !item.msg || !item.func || !item.send_response)
+        continue;
+
+      // Don't hold mutex during callback execution!
+      lock.unlock();
+      auto result = item.func(item.name, item.msg->str());
+      item.send_response(Message::RpcResponse(item.msg->id(),
+                                              item.msg->seq_num_uid(), result));
+      lock.lock();
+    }
+  }
+  lock.unlock();
+
+  // use condition variable to signal thread shutdown
+  {
+    std::lock_guard<std::mutex> shutdown_lock(m_shutdown_mutex);
+    m_shutdown = true;
+    m_shutdown_cv.notify_one();
+  }
+}
diff --git a/src/RpcServer.h b/src/RpcServer.h
new file mode 100644
index 0000000..726034d
--- /dev/null
+++ b/src/RpcServer.h
@@ -0,0 +1,91 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef NT_RPCSERVER_H_
+#define NT_RPCSERVER_H_
+
+#include <atomic>
+#include <condition_variable>
+#include <mutex>
+#include <queue>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include "llvm/DenseMap.h"
+#include "atomic_static.h"
+#include "Message.h"
+#include "ntcore_cpp.h"
+
+namespace nt {
+
+// Executes or hands out incoming RPC calls (server side).
+// Calls with a callback are queued for an internal worker thread; calls
+// without one are retrieved with PollRpc() and answered with
+// PostRpcResponse().
+class RpcServer {
+ friend class RpcServerTest;
+ public:
+ // Process-wide instance, created through the ATOMIC_STATIC helper macro.
+ static RpcServer& GetInstance() {
+ ATOMIC_STATIC(RpcServer, instance);
+ return instance;
+ }
+ ~RpcServer();
+
+ // Function used to deliver a response message back to the caller.
+ typedef std::function<void(std::shared_ptr<Message>)> SendMsgFunc;
+
+ // Start() launches the worker thread; Stop() waits up to 1 s for it to
+ // exit, then detaches it.
+ void Start();
+ void Stop();
+
+ bool active() const { return m_active; }
+
+ // Queue an RPC call; see the class comment for how it is dispatched.
+ void ProcessRpc(StringRef name, std::shared_ptr<Message> msg,
+ RpcCallback func, unsigned int conn_id,
+ SendMsgFunc send_response);
+
+ // Dequeue a callback-less call; returns false when nothing is available
+ // (non-blocking) or when the server is terminating.
+ bool PollRpc(bool blocking, RpcCallInfo* call_info);
+ // Answer a call obtained from PollRpc(); duplicates are ignored.
+ void PostRpcResponse(unsigned int rpc_id, unsigned int call_uid,
+ llvm::StringRef result);
+
+ private:
+ RpcServer();
+
+ void ThreadMain();
+
+ std::atomic_bool m_active;
+ // Set by the destructor to release blocked PollRpc() callers.
+ std::atomic_bool m_terminating;
+
+ // Guards m_call_queue and m_poll_queue.
+ std::mutex m_mutex;
+ // m_call_cond wakes the worker thread; m_poll_cond wakes PollRpc().
+ std::condition_variable m_call_cond, m_poll_cond;
+
+ struct RpcCall {
+ RpcCall(StringRef name_, std::shared_ptr<Message> msg_, RpcCallback func_,
+ unsigned int conn_id_, SendMsgFunc send_response_)
+ : name(name_),
+ msg(msg_),
+ func(func_),
+ conn_id(conn_id_),
+ send_response(send_response_) {}
+
+ std::string name;
+ std::shared_ptr<Message> msg;
+ RpcCallback func;
+ unsigned int conn_id;
+ SendMsgFunc send_response;
+ };
+ std::queue<RpcCall> m_call_queue, m_poll_queue;
+
+ // Pending replies for polled calls, keyed by (rpc_id, call_uid).
+ llvm::DenseMap<std::pair<unsigned int, unsigned int>, SendMsgFunc>
+ m_response_map;
+
+ // Worker thread and the handshake Stop() uses to wait for its exit.
+ std::thread m_thread;
+ std::mutex m_shutdown_mutex;
+ std::condition_variable m_shutdown_cv;
+ bool m_shutdown = false;
+
+ ATOMIC_STATIC_DECL(RpcServer)
+};
+
+} // namespace nt
+
+#endif // NT_RPCSERVER_H_
diff --git a/src/SequenceNumber.cpp b/src/SequenceNumber.cpp
new file mode 100644
index 0000000..b22bfec
--- /dev/null
+++ b/src/SequenceNumber.cpp
@@ -0,0 +1,30 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "SequenceNumber.h"
+
+namespace nt {
+
+// RFC 1982 "less than" for 16-bit serial numbers: lhs precedes rhs when rhs
+// is ahead by less than half the number space (2^15).
+bool operator<(const SequenceNumber& lhs, const SequenceNumber& rhs) {
+  const unsigned int half = 1u << 15;
+  const unsigned int a = lhs.m_value;
+  const unsigned int b = rhs.m_value;
+  if (a == b) return false;
+  return (a < b) ? (b - a < half) : (a - b > half);
+}
+
+// RFC 1982 "greater than" for 16-bit serial numbers: lhs follows rhs when it
+// is ahead by less than half the number space (2^15).
+bool operator>(const SequenceNumber& lhs, const SequenceNumber& rhs) {
+  const unsigned int half = 1u << 15;
+  const unsigned int a = lhs.m_value;
+  const unsigned int b = rhs.m_value;
+  if (a == b) return false;
+  return (a > b) ? (a - b < half) : (b - a > half);
+}
+
+} // namespace nt
diff --git a/src/SequenceNumber.h b/src/SequenceNumber.h
new file mode 100644
index 0000000..a8f85a4
--- /dev/null
+++ b/src/SequenceNumber.h
@@ -0,0 +1,63 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef NT_SEQNUM_H_
+#define NT_SEQNUM_H_
+
+namespace nt {
+
+/* A sequence number per RFC 1982.
+ * Incrementing past 0xffff wraps back to 0; ordering uses serial number
+ * arithmetic, so a value can be "less than" a numerically smaller one. */
+class SequenceNumber {
+ public:
+  SequenceNumber() : m_value(0) {}
+  explicit SequenceNumber(unsigned int value) : m_value(value) {}
+  unsigned int value() const { return m_value; }
+
+  // Pre-increment, wrapping to 0 once the value would exceed 0xffff.
+  SequenceNumber& operator++() {
+    m_value = (m_value >= 0xffff) ? 0 : m_value + 1;
+    return *this;
+  }
+  // Post-increment; returns the value held before incrementing.
+  SequenceNumber operator++(int) {
+    SequenceNumber previous = *this;
+    ++*this;
+    return previous;
+  }
+
+  friend bool operator<(const SequenceNumber& lhs, const SequenceNumber& rhs);
+  friend bool operator>(const SequenceNumber& lhs, const SequenceNumber& rhs);
+  friend bool operator<=(const SequenceNumber& lhs, const SequenceNumber& rhs);
+  friend bool operator>=(const SequenceNumber& lhs, const SequenceNumber& rhs);
+  friend bool operator==(const SequenceNumber& lhs, const SequenceNumber& rhs);
+  friend bool operator!=(const SequenceNumber& lhs, const SequenceNumber& rhs);
+
+ private:
+  unsigned int m_value;
+};
+
+bool operator<(const SequenceNumber& lhs, const SequenceNumber& rhs);
+bool operator>(const SequenceNumber& lhs, const SequenceNumber& rhs);
+
+// Equality is plain value equality.
+inline bool operator==(const SequenceNumber& lhs, const SequenceNumber& rhs) {
+  return lhs.m_value == rhs.m_value;
+}
+
+inline bool operator!=(const SequenceNumber& lhs, const SequenceNumber& rhs) {
+  return !(lhs == rhs);
+}
+
+// Non-strict orderings combine equality with the RFC 1982 strict orderings
+// (for pairs exactly 2^15 apart neither holds).
+inline bool operator<=(const SequenceNumber& lhs, const SequenceNumber& rhs) {
+  return lhs < rhs || lhs == rhs;
+}
+
+inline bool operator>=(const SequenceNumber& lhs, const SequenceNumber& rhs) {
+  return lhs > rhs || lhs == rhs;
+}
+
+} // namespace nt
+
+#endif // NT_SEQNUM_H_
diff --git a/src/Storage.cpp b/src/Storage.cpp
new file mode 100644
index 0000000..e2903e0
--- /dev/null
+++ b/src/Storage.cpp
@@ -0,0 +1,1381 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "Storage.h"
+
+#include <cctype>
+#include <string>
+#include <tuple>
+
+#include "llvm/StringExtras.h"
+#include "Base64.h"
+#include "Log.h"
+#include "NetworkConnection.h"
+
+using namespace nt;
+
+// NOTE(review): presumably defines the singleton storage declared by
+// ATOMIC_STATIC_DECL in Storage.h -- macro body not visible here.
+ATOMIC_STATIC_INIT(Storage)
+
+// Default construction wires the storage to the global Notifier and
+// RpcServer singletons.
+Storage::Storage()
+ : Storage(Notifier::GetInstance(), RpcServer::GetInstance()) {}
+
+// Construct with explicit notifier / RPC server (e.g. for tests).
+Storage::Storage(Notifier& notifier, RpcServer& rpc_server)
+ : m_notifier(notifier), m_rpc_server(rpc_server) {
+ m_terminating = false;
+}
+
+// Marks the storage as terminating and wakes anyone waiting on
+// m_rpc_results_cond.
+Storage::~Storage() {
+  Logger::GetInstance().SetLogger(nullptr);
+  {
+    // RPC results are inserted and signalled under m_mutex, so waiters
+    // presumably test m_terminating under it too; set the flag under the
+    // same mutex so the wakeup below cannot be lost.
+    std::lock_guard<std::mutex> lock(m_mutex);
+    m_terminating = true;
+  }
+  m_rpc_results_cond.notify_all();
+}
+
+// Install the function used to queue outgoing messages and record whether
+// this node acts as the server.
+void Storage::SetOutgoing(QueueOutgoingFunc queue_outgoing, bool server) {
+  std::unique_lock<std::mutex> guard(m_mutex);
+  m_queue_outgoing = queue_outgoing;  // sink for messages we generate
+  m_server = server;                  // selects server vs client handling
+}
+
+// Remove the outgoing-message sink; later changes are kept local only.
+void Storage::ClearOutgoing() {
+  // Every other method reads/copies m_queue_outgoing under m_mutex, so the
+  // write must be synchronized too (std::function is not atomic).
+  std::lock_guard<std::mutex> lock(m_mutex);
+  m_queue_outgoing = nullptr;
+}
+
+// Type of the entry with network id `id`; unknown ids, deleted slots and
+// entries without a value report NT_UNASSIGNED.
+NT_Type Storage::GetEntryType(unsigned int id) const {
+  std::lock_guard<std::mutex> guard(m_mutex);
+  const Entry* entry = id < m_idmap.size() ? m_idmap[id] : nullptr;
+  if (entry && entry->value) return entry->value->type();
+  return NT_UNASSIGNED;
+}
+
+// Apply one message received on connection `conn` to local storage: update
+// entries, notify listeners and, on a server, relay the change to the other
+// connections. m_mutex is held while storage is touched and released before
+// any message is handed to m_queue_outgoing (which is always checked first,
+// since calling an empty function would throw).
+void Storage::ProcessIncoming(std::shared_ptr<Message> msg,
+                              NetworkConnection* conn,
+                              std::weak_ptr<NetworkConnection> conn_weak) {
+  std::unique_lock<std::mutex> lock(m_mutex);
+  switch (msg->type()) {
+    case Message::kKeepAlive:
+      break;  // ignore
+    case Message::kClientHello:
+    case Message::kProtoUnsup:
+    case Message::kServerHelloDone:
+    case Message::kServerHello:
+    case Message::kClientHelloDone:
+      // shouldn't get these, but ignore if we do
+      break;
+    case Message::kEntryAssign: {
+      unsigned int id = msg->id();
+      StringRef name = msg->str();
+      Entry* entry = nullptr;
+      bool may_need_update = false;
+      // Client only: flags update to send back to the server. It is queued
+      // after local processing so the lock is never dropped and re-taken
+      // while `entry` is in use (another thread could delete it meanwhile).
+      std::shared_ptr<Message> pending_flags_msg;
+      if (m_server) {
+        // if we're a server, id=0xffff requests are requests for an id
+        // to be assigned, and we need to send the new assignment back to
+        // the sender as well as all other connections.
+        if (id == 0xffff) {
+          // see if it was already assigned; ignore if so.
+          if (m_entries.count(name) != 0) return;
+
+          // create it locally
+          id = m_idmap.size();
+          auto& new_entry = m_entries[name];
+          if (!new_entry) new_entry.reset(new Entry(name));
+          entry = new_entry.get();
+          entry->value = msg->value();
+          entry->flags = msg->flags();
+          entry->id = id;
+          m_idmap.push_back(entry);
+
+          // update persistent dirty flag if it's persistent
+          if (entry->IsPersistent()) m_persistent_dirty = true;
+
+          // notify
+          m_notifier.NotifyEntry(name, entry->value, NT_NOTIFY_NEW);
+
+          // send the assignment to everyone (including the originator)
+          if (m_queue_outgoing) {
+            auto queue_outgoing = m_queue_outgoing;
+            auto outmsg = Message::EntryAssign(
+                name, id, entry->seq_num.value(), msg->value(), msg->flags());
+            lock.unlock();
+            queue_outgoing(outmsg, nullptr, nullptr);
+          }
+          return;
+        }
+        if (id >= m_idmap.size() || !m_idmap[id]) {
+          // ignore arbitrary entry assignments
+          // this can happen due to e.g. assignment to deleted entry
+          lock.unlock();
+          DEBUG("server: received assignment to unknown entry");
+          return;
+        }
+        entry = m_idmap[id];
+      } else {
+        // clients simply accept new assignments
+        if (id == 0xffff) {
+          lock.unlock();
+          DEBUG("client: received entry assignment request?");
+          return;
+        }
+        if (id >= m_idmap.size()) m_idmap.resize(id + 1);
+        entry = m_idmap[id];
+        if (!entry) {
+          // create local
+          auto& new_entry = m_entries[name];
+          if (!new_entry) {
+            // didn't exist at all (rather than just being a response to a
+            // id assignment request)
+            new_entry.reset(new Entry(name));
+            new_entry->value = msg->value();
+            new_entry->flags = msg->flags();
+            // adopt the server's sequence number so later updates compare
+            // against it rather than against 0
+            new_entry->seq_num = SequenceNumber(msg->seq_num_uid());
+            new_entry->id = id;
+            m_idmap[id] = new_entry.get();
+
+            // notify
+            m_notifier.NotifyEntry(name, new_entry->value, NT_NOTIFY_NEW);
+            return;
+          }
+          may_need_update = true;  // we may need to send an update message
+          entry = new_entry.get();
+          entry->id = id;
+          m_idmap[id] = entry;
+
+          // if the received flags don't match what we sent, we most likely
+          // updated flags locally in the interim; send flags update message.
+          if (msg->flags() != entry->flags)
+            pending_flags_msg = Message::FlagsUpdate(id, entry->flags);
+        }
+      }
+
+      // common client and server handling
+
+      // already exists; ignore if sequence number not higher than local
+      SequenceNumber seq_num(msg->seq_num_uid());
+      if (seq_num < entry->seq_num) {
+        // our copy is newer: send it (after any pending flags) to the server
+        if (may_need_update && m_queue_outgoing) {
+          auto queue_outgoing = m_queue_outgoing;
+          auto outmsg = Message::EntryUpdate(entry->id, entry->seq_num.value(),
+                                             entry->value);
+          lock.unlock();
+          if (pending_flags_msg)
+            queue_outgoing(pending_flags_msg, nullptr, nullptr);
+          queue_outgoing(outmsg, nullptr, nullptr);
+        }
+        return;
+      }
+
+      // sanity check: name should match id
+      if (msg->str() != entry->name) {
+        lock.unlock();
+        DEBUG("entry assignment for same id with different name?");
+        return;
+      }
+
+      unsigned int notify_flags = NT_NOTIFY_UPDATE;
+
+      // don't update flags from a <3.0 remote (not part of message)
+      // don't update flags if this is a server response to a client id request
+      if (!may_need_update && conn->proto_rev() >= 0x0300) {
+        // update persistent dirty flag if persistent flag changed
+        if ((entry->flags & NT_PERSISTENT) != (msg->flags() & NT_PERSISTENT))
+          m_persistent_dirty = true;
+        if (entry->flags != msg->flags()) notify_flags |= NT_NOTIFY_FLAGS;
+        entry->flags = msg->flags();
+      }
+
+      // update persistent dirty flag if the value changed and it's persistent
+      // (an entry may exist without a value, so guard the dereference)
+      if (entry->IsPersistent() &&
+          (!entry->value || *entry->value != *msg->value()))
+        m_persistent_dirty = true;
+
+      // update local
+      entry->value = msg->value();
+      entry->seq_num = seq_num;
+
+      // notify
+      m_notifier.NotifyEntry(name, entry->value, notify_flags);
+
+      if (m_server && m_queue_outgoing) {
+        // broadcast to all other connections (note for client there won't
+        // be any other connections, so don't bother)
+        auto queue_outgoing = m_queue_outgoing;
+        auto outmsg =
+            Message::EntryAssign(entry->name, id, msg->seq_num_uid(),
+                                 msg->value(), entry->flags);
+        lock.unlock();
+        queue_outgoing(outmsg, nullptr, conn);
+      } else if (pending_flags_msg && m_queue_outgoing) {
+        // client: report flags we changed locally while awaiting the id
+        auto queue_outgoing = m_queue_outgoing;
+        lock.unlock();
+        queue_outgoing(pending_flags_msg, nullptr, nullptr);
+      }
+      break;
+    }
+    case Message::kEntryUpdate: {
+      unsigned int id = msg->id();
+      if (id >= m_idmap.size() || !m_idmap[id]) {
+        // ignore arbitrary entry updates;
+        // this can happen due to deleted entries
+        lock.unlock();
+        DEBUG("received update to unknown entry");
+        return;
+      }
+      Entry* entry = m_idmap[id];
+
+      // ignore if sequence number not higher than local
+      SequenceNumber seq_num(msg->seq_num_uid());
+      if (seq_num <= entry->seq_num) return;
+
+      // update local
+      entry->value = msg->value();
+      entry->seq_num = seq_num;
+
+      // update persistent dirty flag if it's a persistent value
+      if (entry->IsPersistent()) m_persistent_dirty = true;
+
+      // notify
+      m_notifier.NotifyEntry(entry->name, entry->value, NT_NOTIFY_UPDATE);
+
+      // broadcast to all other connections (note for client there won't
+      // be any other connections, so don't bother)
+      if (m_server && m_queue_outgoing) {
+        auto queue_outgoing = m_queue_outgoing;
+        lock.unlock();
+        queue_outgoing(msg, nullptr, conn);
+      }
+      break;
+    }
+    case Message::kFlagsUpdate: {
+      unsigned int id = msg->id();
+      if (id >= m_idmap.size() || !m_idmap[id]) {
+        // ignore arbitrary entry updates;
+        // this can happen due to deleted entries
+        lock.unlock();
+        DEBUG("received flags update to unknown entry");
+        return;
+      }
+      Entry* entry = m_idmap[id];
+
+      // ignore if flags didn't actually change
+      if (entry->flags == msg->flags()) return;
+
+      // update persistent dirty flag if persistent flag changed
+      if ((entry->flags & NT_PERSISTENT) != (msg->flags() & NT_PERSISTENT))
+        m_persistent_dirty = true;
+
+      // update local
+      entry->flags = msg->flags();
+
+      // notify
+      m_notifier.NotifyEntry(entry->name, entry->value, NT_NOTIFY_FLAGS);
+
+      // broadcast to all other connections (note for client there won't
+      // be any other connections, so don't bother)
+      if (m_server && m_queue_outgoing) {
+        auto queue_outgoing = m_queue_outgoing;
+        lock.unlock();
+        queue_outgoing(msg, nullptr, conn);
+      }
+      break;
+    }
+    case Message::kEntryDelete: {
+      unsigned int id = msg->id();
+      if (id >= m_idmap.size() || !m_idmap[id]) {
+        // ignore arbitrary entry updates;
+        // this can happen due to deleted entries
+        lock.unlock();
+        DEBUG("received delete to unknown entry");
+        return;
+      }
+      Entry* entry = m_idmap[id];
+
+      // update persistent dirty flag if it's a persistent value
+      if (entry->IsPersistent()) m_persistent_dirty = true;
+
+      // delete it from idmap
+      m_idmap[id] = nullptr;
+
+      // get entry (as we'll need it for notify) and erase it from the map
+      // it should always be in the map, but sanity check just in case
+      auto i = m_entries.find(entry->name);
+      if (i != m_entries.end()) {
+        auto entry2 = std::move(i->getValue());  // move the value out
+        m_entries.erase(i);
+
+        // notify
+        m_notifier.NotifyEntry(entry2->name, entry2->value, NT_NOTIFY_DELETE);
+      }
+
+      // broadcast to all other connections (note for client there won't
+      // be any other connections, so don't bother)
+      if (m_server && m_queue_outgoing) {
+        auto queue_outgoing = m_queue_outgoing;
+        lock.unlock();
+        queue_outgoing(msg, nullptr, conn);
+      }
+      break;
+    }
+    case Message::kClearEntries: {
+      // update local
+      EntriesMap map;
+      m_entries.swap(map);
+      m_idmap.resize(0);
+
+      // set persistent dirty flag
+      m_persistent_dirty = true;
+
+      // notify
+      for (auto& entry : map)
+        m_notifier.NotifyEntry(entry.getKey(), entry.getValue()->value,
+                               NT_NOTIFY_DELETE);
+
+      // broadcast to all other connections (note for client there won't
+      // be any other connections, so don't bother)
+      if (m_server && m_queue_outgoing) {
+        auto queue_outgoing = m_queue_outgoing;
+        lock.unlock();
+        queue_outgoing(msg, nullptr, conn);
+      }
+      break;
+    }
+    case Message::kExecuteRpc: {
+      if (!m_server) return;  // only process on server
+      unsigned int id = msg->id();
+      if (id >= m_idmap.size() || !m_idmap[id]) {
+        // ignore call to non-existent RPC
+        // this can happen due to deleted entries
+        lock.unlock();
+        DEBUG("received RPC call to unknown entry");
+        return;
+      }
+      Entry* entry = m_idmap[id];
+      if (!entry->value->IsRpc()) {
+        lock.unlock();
+        DEBUG("received RPC call to non-RPC entry");
+        return;
+      }
+      // The reply goes only to the calling connection, if it still exists.
+      m_rpc_server.ProcessRpc(entry->name, msg, entry->rpc_callback,
+                              conn->uid(),
+                              [conn_weak](std::shared_ptr<Message> response) {
+                                auto c = conn_weak.lock();
+                                if (c) c->QueueOutgoing(response);
+                              });
+      break;
+    }
+    case Message::kRpcResponse: {
+      if (m_server) return;  // only process on client
+      m_rpc_results.insert(std::make_pair(
+          std::make_pair(msg->id(), msg->seq_num_uid()), msg->str()));
+      m_rpc_results_cond.notify_all();
+      break;
+    }
+    default:
+      break;
+  }
+}
+
+// Server side of the initial sync: mark `conn` synchronized and append an
+// EntryAssign message for every known entry to `msgs`.
+void Storage::GetInitialAssignments(
+    NetworkConnection& conn, std::vector<std::shared_ptr<Message>>* msgs) {
+  std::lock_guard<std::mutex> guard(m_mutex);
+  conn.set_state(NetworkConnection::kSynchronized);
+  for (auto& kv : m_entries) {
+    const Entry& e = *kv.getValue();
+    msgs->emplace_back(Message::EntryAssign(kv.getKey(), e.id,
+                                            e.seq_num.value(), e.value,
+                                            e.flags));
+  }
+}
+
+// Client side of the initial sync: replace local ids with the server's
+// assignments in `msgs`. Local entries the server doesn't know get an
+// id-request EntryAssign appended to `out_msgs`. On reconnect
+// (!new_server), local values that are at least as new as the server's are
+// sent back as EntryUpdate messages.
+void Storage::ApplyInitialAssignments(
+    NetworkConnection& conn, llvm::ArrayRef<std::shared_ptr<Message>> msgs,
+    bool new_server, std::vector<std::shared_ptr<Message>>* out_msgs) {
+  std::unique_lock<std::mutex> lock(m_mutex);
+  if (m_server) return;  // should not do this on server
+
+  conn.set_state(NetworkConnection::kSynchronized);
+
+  std::vector<std::shared_ptr<Message>> update_msgs;
+
+  // clear existing id's
+  for (auto& i : m_entries) i.getValue()->id = 0xffff;
+
+  // clear existing idmap
+  m_idmap.resize(0);
+
+  // apply assignments
+  for (auto& msg : msgs) {
+    if (!msg->Is(Message::kEntryAssign)) {
+      DEBUG("client: received non-entry assignment request?");
+      continue;
+    }
+
+    unsigned int id = msg->id();
+    if (id == 0xffff) {
+      DEBUG("client: received entry assignment request?");
+      continue;
+    }
+
+    SequenceNumber seq_num(msg->seq_num_uid());
+    StringRef name = msg->str();
+
+    auto& entry = m_entries[name];
+    if (!entry) {
+      // doesn't currently exist
+      entry.reset(new Entry(name));
+      entry->value = msg->value();
+      entry->flags = msg->flags();
+      entry->seq_num = seq_num;
+      // notify
+      m_notifier.NotifyEntry(name, entry->value, NT_NOTIFY_NEW);
+    } else {
+      // if reconnect and sequence number not higher than local, then we
+      // don't update the local value and instead send it back to the server
+      // as an update message
+      if (!new_server && seq_num <= entry->seq_num) {
+        // entry->id was reset to 0xffff above and is only reassigned below,
+        // so address the update with the id the server just assigned;
+        // otherwise the server would drop it as an unknown entry.
+        update_msgs.emplace_back(Message::EntryUpdate(
+            id, entry->seq_num.value(), entry->value));
+      } else {
+        entry->value = msg->value();
+        entry->seq_num = seq_num;
+        unsigned int notify_flags = NT_NOTIFY_UPDATE;
+        // don't update flags from a <3.0 remote (not part of message)
+        if (conn.proto_rev() >= 0x0300) {
+          if (entry->flags != msg->flags()) notify_flags |= NT_NOTIFY_FLAGS;
+          entry->flags = msg->flags();
+        }
+        // notify
+        m_notifier.NotifyEntry(name, entry->value, notify_flags);
+      }
+    }
+
+    // set id and save to idmap
+    entry->id = id;
+    if (id >= m_idmap.size()) m_idmap.resize(id + 1);
+    m_idmap[id] = entry.get();
+  }
+
+  // generate assign messages for unassigned local entries
+  for (auto& i : m_entries) {
+    Entry* entry = i.getValue().get();
+    if (entry->id != 0xffff) continue;
+    out_msgs->emplace_back(Message::EntryAssign(entry->name, entry->id,
+                                                entry->seq_num.value(),
+                                                entry->value, entry->flags));
+  }
+
+  // send back newer local values; skip if there is no outgoing sink, since
+  // calling an empty function would throw
+  if (update_msgs.empty() || !m_queue_outgoing) return;
+  auto queue_outgoing = m_queue_outgoing;
+  lock.unlock();
+  for (auto& msg : update_msgs) queue_outgoing(msg, nullptr, nullptr);
+}
+
+// Current value of entry `name`, or nullptr if no such entry exists.
+std::shared_ptr<Value> Storage::GetEntryValue(StringRef name) const {
+  std::lock_guard<std::mutex> guard(m_mutex);
+  auto found = m_entries.find(name);
+  if (found == m_entries.end()) return nullptr;
+  return found->getValue()->value;
+}
+
+// Set entry `name` to `value`, creating the entry if needed.
+// Returns false (and changes nothing) if the entry already holds a value of
+// a different type; empty names and null values are ignored (returns true).
+// Local listeners are notified and an EntryAssign/EntryUpdate is queued
+// when an outgoing sink is installed.
+bool Storage::SetEntryValue(StringRef name, std::shared_ptr<Value> value) {
+  if (name.empty()) return true;
+  if (!value) return true;
+  std::unique_lock<std::mutex> lock(m_mutex);
+  auto& new_entry = m_entries[name];
+  if (!new_entry) new_entry.reset(new Entry(name));
+  Entry* entry = new_entry.get();
+  auto old_value = entry->value;
+  if (old_value && old_value->type() != value->type())
+    return false;  // error on type mismatch
+  entry->value = value;
+
+  // if we're the server, assign an id if it doesn't have one
+  if (m_server && entry->id == 0xffff) {
+    unsigned int id = m_idmap.size();
+    entry->id = id;
+    m_idmap.push_back(entry);
+  }
+
+  // update persistent dirty flag if value changed and it's persistent
+  // (old_value is null when the entry had no value yet)
+  if (entry->IsPersistent() && (!old_value || *old_value != *value))
+    m_persistent_dirty = true;
+
+  // notify (for local listeners)
+  if (m_notifier.local_notifiers()) {
+    if (!old_value)
+      m_notifier.NotifyEntry(name, value, NT_NOTIFY_NEW | NT_NOTIFY_LOCAL);
+    else if (*old_value != *value)
+      m_notifier.NotifyEntry(name, value, NT_NOTIFY_UPDATE | NT_NOTIFY_LOCAL);
+  }
+
+  // generate message
+  if (!m_queue_outgoing) return true;
+  auto queue_outgoing = m_queue_outgoing;
+  if (!old_value) {
+    auto msg = Message::EntryAssign(name, entry->id, entry->seq_num.value(),
+                                    value, entry->flags);
+    lock.unlock();
+    queue_outgoing(msg, nullptr, nullptr);
+  } else if (*old_value != *value) {
+    ++entry->seq_num;
+    // don't send an update if we don't have an assigned id yet
+    if (entry->id != 0xffff) {
+      auto msg =
+          Message::EntryUpdate(entry->id, entry->seq_num.value(), value);
+      lock.unlock();
+      queue_outgoing(msg, nullptr, nullptr);
+    }
+  }
+  return true;
+}
+
+// Set entry `name` to `value`, creating it if needed and allowing the type to
+// change. A type change (or a brand-new entry) is announced with a full
+// EntryAssign; a same-type change with an EntryUpdate once an id is known.
+void Storage::SetEntryTypeValue(StringRef name, std::shared_ptr<Value> value) {
+  if (name.empty() || !value) return;
+  std::unique_lock<std::mutex> lock(m_mutex);
+  auto& slot = m_entries[name];
+  if (!slot) slot.reset(new Entry(name));
+  Entry* entry = slot.get();
+  std::shared_ptr<Value> old_value = entry->value;
+  entry->value = value;
+  // nothing more to do when the value is unchanged
+  if (old_value && *old_value == *value) return;
+
+  // servers hand out an id the first time an entry needs one
+  if (m_server && entry->id == 0xffff) {
+    entry->id = static_cast<unsigned int>(m_idmap.size());
+    m_idmap.push_back(entry);
+  }
+
+  if (entry->IsPersistent()) m_persistent_dirty = true;
+
+  if (m_notifier.local_notifiers()) {
+    unsigned int notify_flags = old_value ? NT_NOTIFY_UPDATE : NT_NOTIFY_NEW;
+    m_notifier.NotifyEntry(name, value, notify_flags | NT_NOTIFY_LOCAL);
+  }
+
+  if (!m_queue_outgoing) return;
+  auto queue_outgoing = m_queue_outgoing;
+  // every outgoing path consumes a new sequence number
+  ++entry->seq_num;
+  std::shared_ptr<Message> msg;
+  if (!old_value || old_value->type() != value->type()) {
+    msg = Message::EntryAssign(name, entry->id, entry->seq_num.value(), value,
+                               entry->flags);
+  } else if (entry->id != 0xffff) {
+    msg = Message::EntryUpdate(entry->id, entry->seq_num.value(), value);
+  } else {
+    return;  // no id assigned yet: nothing to send
+  }
+  lock.unlock();
+  queue_outgoing(msg, nullptr, nullptr);
+}
+
+// Replace the flags of existing entry `name`; unknown names and unchanged
+// flags are ignored. Notifies listeners and queues a FlagsUpdate once the
+// entry has an id.
+void Storage::SetEntryFlags(StringRef name, unsigned int flags) {
+  if (name.empty()) return;
+  std::unique_lock<std::mutex> lock(m_mutex);
+  auto it = m_entries.find(name);
+  if (it == m_entries.end()) return;
+  Entry& entry = *it->getValue();
+  const unsigned int old_flags = entry.flags;
+  if (old_flags == flags) return;
+
+  // toggling persistence means persistent storage must be rewritten
+  if (((old_flags ^ flags) & NT_PERSISTENT) != 0) m_persistent_dirty = true;
+
+  entry.flags = flags;
+  m_notifier.NotifyEntry(name, entry.value, NT_NOTIFY_FLAGS | NT_NOTIFY_LOCAL);
+
+  const unsigned int id = entry.id;
+  // nothing to send without a sink, or before an id has been assigned
+  if (!m_queue_outgoing || id == 0xffff) return;
+  auto queue_outgoing = m_queue_outgoing;
+  lock.unlock();
+  queue_outgoing(Message::FlagsUpdate(id, flags), nullptr, nullptr);
+}
+
+// Flags of entry `name`, or 0 if the entry does not exist.
+unsigned int Storage::GetEntryFlags(StringRef name) const {
+  std::lock_guard<std::mutex> guard(m_mutex);
+  auto found = m_entries.find(name);
+  if (found == m_entries.end()) return 0;
+  return found->getValue()->flags;
+}
+
+// Remove entry `name`. If it had a value, local listeners are notified and,
+// when it had an assigned id, an EntryDelete is queued.
+void Storage::DeleteEntry(StringRef name) {
+  std::unique_lock<std::mutex> lock(m_mutex);
+  auto it = m_entries.find(name);
+  if (it == m_entries.end()) return;
+  // take ownership first so the entry outlives its map slot
+  auto removed = std::move(it->getValue());
+  const unsigned int id = removed->id;
+
+  if (removed->IsPersistent()) m_persistent_dirty = true;
+
+  m_entries.erase(it);
+  if (id < m_idmap.size()) m_idmap[id] = nullptr;
+
+  // entries that never had a value produce no notification or message
+  if (!removed->value) return;
+
+  m_notifier.NotifyEntry(name, removed->value,
+                         NT_NOTIFY_DELETE | NT_NOTIFY_LOCAL);
+
+  if (id == 0xffff || !m_queue_outgoing) return;
+  auto queue_outgoing = m_queue_outgoing;
+  lock.unlock();
+  queue_outgoing(Message::EntryDelete(id), nullptr, nullptr);
+}
+
+// Removes every entry. Local listeners are notified for each removed entry
+// that held a value (entries without a value are skipped, as in
+// DeleteEntry()), and a ClearEntries message is queued if an outgoing queue is
+// configured. The persistent dirty flag is only set if a persistent entry was
+// actually removed.
+void Storage::DeleteAllEntries() {
+  std::unique_lock<std::mutex> lock(m_mutex);
+  if (m_entries.empty()) return;
+  EntriesMap map;
+  m_entries.swap(map);
+  m_idmap.resize(0);
+
+  // notify, and note whether any removed entry was persistent
+  const bool notify = m_notifier.local_notifiers();
+  bool had_persistent = false;
+  for (auto& entry : map) {
+    Entry* e = entry.getValue().get();
+    if (e->IsPersistent()) had_persistent = true;
+    // never hand listeners a null value
+    if (notify && e->value)
+      m_notifier.NotifyEntry(entry.getKey(), e->value,
+                             NT_NOTIFY_DELETE | NT_NOTIFY_LOCAL);
+  }
+
+  // only removing a persistent entry makes the saved file stale
+  if (had_persistent) m_persistent_dirty = true;
+
+  // generate message
+  if (!m_queue_outgoing) return;
+  auto queue_outgoing = m_queue_outgoing;
+  lock.unlock();
+  queue_outgoing(Message::ClearEntries(), nullptr, nullptr);
+}
+
+// Returns info for every entry whose name begins with prefix and that has a
+// value. If types is nonzero, only entries whose type bit is set in types are
+// included.
+std::vector<EntryInfo> Storage::GetEntryInfo(StringRef prefix,
+                                             unsigned int types) {
+  std::lock_guard<std::mutex> lock(m_mutex);
+  std::vector<EntryInfo> result;
+  for (auto& kv : m_entries) {
+    StringRef key = kv.getKey();
+    if (!key.startswith(prefix)) continue;
+    const Entry& e = *kv.getValue();
+    const std::shared_ptr<Value>& v = e.value;
+    if (!v) continue;
+    const bool filtered_out = types != 0 && (types & v->type()) == 0;
+    if (filtered_out) continue;
+    EntryInfo info;
+    info.name = key;
+    info.type = v->type();
+    info.flags = e.flags;
+    info.last_change = v->last_change();
+    result.push_back(std::move(info));
+  }
+  return result;
+}
+
+// Sends an immediate notification for every entry whose name begins with
+// prefix, optionally only to the listener `only`. Entries without a value are
+// skipped, consistent with GetEntryInfo() and DeleteEntry(), so listeners are
+// never handed a null value.
+void Storage::NotifyEntries(StringRef prefix,
+                            EntryListenerCallback only) const {
+  std::lock_guard<std::mutex> lock(m_mutex);
+  for (auto& i : m_entries) {
+    if (!i.getKey().startswith(prefix)) continue;
+    const auto& value = i.getValue()->value;
+    if (!value) continue;
+    m_notifier.NotifyEntry(i.getKey(), value, NT_NOTIFY_IMMEDIATE, only);
+  }
+}
+
+/* Escapes and writes a string, including start and end double quotes.
+ * Backslash, tab, newline and double quote get backslash escapes; any other
+ * non-printable byte is written as \x followed by two hex digits.
+ */
+static void WriteString(std::ostream& os, llvm::StringRef str) {
+  os << '"';
+  for (auto c : str) {
+    switch (c) {
+      case '\\':
+        os << "\\\\";
+        break;
+      case '\t':
+        os << "\\t";
+        break;
+      case '\n':
+        os << "\\n";
+        break;
+      case '"':
+        os << "\\\"";
+        break;
+      default: {
+        // std::isprint() is only defined for values representable as
+        // unsigned char; a negative (non-ASCII) char would be UB.
+        unsigned char uc = static_cast<unsigned char>(c);
+        if (std::isprint(uc)) {
+          os << c;
+          break;
+        }
+
+        // Write out the escaped representation.
+        os << "\\x";
+        os << llvm::hexdigit((uc >> 4) & 0xF);
+        os << llvm::hexdigit(uc & 0xF);
+        break;
+      }
+    }
+  }
+  os << '"';
+}
+
+// Copies every persistent entry (name and value) into *entries, sorted by
+// name, and clears the persistent dirty flag. When periodic is true and the
+// dirty flag is not set, returns false without touching *entries; otherwise
+// returns true.
+bool Storage::GetPersistentEntries(
+    bool periodic,
+    std::vector<std::pair<std::string, std::shared_ptr<Value>>>* entries)
+    const {
+  typedef std::pair<std::string, std::shared_ptr<Value>> NamedValue;
+
+  {
+    // hold the lock only while copying out of storage
+    std::lock_guard<std::mutex> lock(m_mutex);
+    if (periodic && !m_persistent_dirty) return false;
+    m_persistent_dirty = false;
+    entries->reserve(m_entries.size());
+    for (auto& kv : m_entries) {
+      const Entry& e = *kv.getValue();
+      if (e.IsPersistent()) entries->emplace_back(kv.getKey(), e.value);
+    }
+  }
+
+  std::sort(entries->begin(), entries->end(),
+            [](const NamedValue& lhs, const NamedValue& rhs) {
+              return lhs.first < rhs.first;
+            });
+  return true;
+}
+
+/* Writes a double such that parsing it back with std::strtod() yields the
+ * same value. The short "%g" form (6 significant digits, the same as the
+ * default std::ostream formatting previously used here) is kept whenever it
+ * round-trips; otherwise 17 significant digits are written, which always
+ * suffice to round-trip an IEEE 754 double.
+ */
+static void WriteDouble(std::ostream& os, double val) {
+  char buf[32];
+  std::snprintf(buf, sizeof(buf), "%g", val);
+  if (std::strtod(buf, nullptr) != val)
+    std::snprintf(buf, sizeof(buf), "%.17g", val);
+  os << buf;
+}
+
+/* Writes the persistent file: a header line followed by one
+ * `<type> "<name>"=<value>` line per entry. Entries with a null value or a
+ * type not listed below (e.g. RPC) are skipped.
+ */
+static void SavePersistentImpl(
+    std::ostream& os,
+    llvm::ArrayRef<std::pair<std::string, std::shared_ptr<Value>>> entries) {
+  std::string base64_encoded;
+
+  // header
+  os << "[NetworkTables Storage 3.0]\n";
+
+  for (auto& i : entries) {
+    // type
+    const auto& v = i.second;
+    if (!v) continue;
+    switch (v->type()) {
+      case NT_BOOLEAN:
+        os << "boolean ";
+        break;
+      case NT_DOUBLE:
+        os << "double ";
+        break;
+      case NT_STRING:
+        os << "string ";
+        break;
+      case NT_RAW:
+        os << "raw ";
+        break;
+      case NT_BOOLEAN_ARRAY:
+        os << "array boolean ";
+        break;
+      case NT_DOUBLE_ARRAY:
+        os << "array double ";
+        break;
+      case NT_STRING_ARRAY:
+        os << "array string ";
+        break;
+      default:
+        continue;
+    }
+
+    // name
+    WriteString(os, i.first);
+
+    // =
+    os << '=';
+
+    // value
+    switch (v->type()) {
+      case NT_BOOLEAN:
+        os << (v->GetBoolean() ? "true" : "false");
+        break;
+      case NT_DOUBLE:
+        // plain `os << double` keeps only 6 significant digits, which would
+        // silently lose precision across a save/load cycle
+        WriteDouble(os, v->GetDouble());
+        break;
+      case NT_STRING:
+        WriteString(os, v->GetString());
+        break;
+      case NT_RAW:
+        Base64Encode(v->GetRaw(), &base64_encoded);
+        os << base64_encoded;
+        break;
+      case NT_BOOLEAN_ARRAY: {
+        bool first = true;
+        for (auto elem : v->GetBooleanArray()) {
+          if (!first) os << ',';
+          first = false;
+          os << (elem ? "true" : "false");
+        }
+        break;
+      }
+      case NT_DOUBLE_ARRAY: {
+        bool first = true;
+        for (auto elem : v->GetDoubleArray()) {
+          if (!first) os << ',';
+          first = false;
+          WriteDouble(os, elem);
+        }
+        break;
+      }
+      case NT_STRING_ARRAY: {
+        bool first = true;
+        for (auto& elem : v->GetStringArray()) {
+          if (!first) os << ',';
+          first = false;
+          WriteString(os, elem);
+        }
+        break;
+      }
+      default:
+        break;
+    }
+
+    // eol
+    os << '\n';
+  }
+}
+
+// Stream-based save: writes all persistent entries to os in name order. When
+// periodic is true, nothing is written unless the dirty flag was set.
+void Storage::SavePersistent(std::ostream& os, bool periodic) const {
+  std::vector<std::pair<std::string, std::shared_ptr<Value>>> entries;
+  const bool have_entries = GetPersistentEntries(periodic, &entries);
+  if (have_entries) SavePersistentImpl(os, entries);
+}
+
+// Saves all persistent entries to filename. The data is first written to
+// "<filename>.tmp"; the previous file is kept as "<filename>.bak" and the
+// temporary file is then renamed into place. Returns nullptr on success (or
+// when a periodic save finds nothing dirty), otherwise an error message. A
+// failed periodic save re-marks the data dirty so the next period retries.
+const char* Storage::SavePersistent(StringRef filename, bool periodic) const {
+  std::string fn = filename;
+  std::string tmp = fn + ".tmp";
+  std::string bak = fn + ".bak";
+
+  // Get entries before creating file
+  std::vector<std::pair<std::string, std::shared_ptr<Value>>> entries;
+  if (!GetPersistentEntries(periodic, &entries)) return nullptr;
+
+  const char* err = nullptr;
+
+  // start by writing to temporary file
+  {
+    std::ofstream os(tmp);
+    if (!os) {
+      err = "could not open file";
+    } else {
+      DEBUG("saving persistent file '" << filename << "'");
+      SavePersistentImpl(os, entries);
+      os.flush();
+      const bool write_failed = !os;
+      os.close();
+      if (write_failed) {
+        std::remove(tmp.c_str());
+        err = "error saving file";
+      }
+    }
+  }
+
+  if (!err) {
+    // Safely move to real file. We ignore any failures related to the backup.
+    std::remove(bak.c_str());
+    std::rename(fn.c_str(), bak.c_str());
+    if (std::rename(tmp.c_str(), fn.c_str()) != 0) {
+      std::rename(bak.c_str(), fn.c_str());  // attempt to restore backup
+      err = "could not rename temp file to real file";
+    }
+  }
+
+  // try again if there was an error. Every other access to
+  // m_persistent_dirty holds m_mutex, so take it here too to avoid a race.
+  if (err && periodic) {
+    std::lock_guard<std::mutex> lock(m_mutex);
+    m_persistent_dirty = true;
+  }
+  return err;
+}
+
+/* Extracts an escaped string token. Does not unescape the string.
+ * If a string cannot be matched, an empty string is returned.
+ * If the string is unterminated, an empty tail string is returned.
+ * The returned token includes the starting and trailing quotes (unless the
+ * string is unterminated).
+ * Returns a pair containing the extracted token (if any) and the remaining
+ * tail string.
+ *
+ * A backslash always escapes the character that follows it, so an escaped
+ * backslash right before the closing quote (which WriteString() emits for
+ * strings ending in a backslash) does not hide the terminator.
+ */
+static std::pair<llvm::StringRef, llvm::StringRef> ReadStringToken(
+    llvm::StringRef source) {
+  // Match opening quote
+  if (source.empty() || source.front() != '"')
+    return std::make_pair(llvm::StringRef(), source);
+
+  // Scan for the ending double quote, stepping over escape sequences.
+  std::size_t size = source.size();
+  std::size_t pos = 1;
+  while (pos < size) {
+    char c = source[pos++];
+    if (c == '"') break;  // pos now includes the trailing quote
+    if (c == '\\' && pos < size) ++pos;  // skip the escaped character
+  }
+  return std::make_pair(source.slice(0, pos), source.substr(pos));
+}
+
+/* Converts a hexadecimal digit character ('0'-'9', 'a'-'f', 'A'-'F') to its
+ * value 0-15. Other characters yield ch - '0', which is meaningless; callers
+ * check with isxdigit() first.
+ */
+static int fromxdigit(char ch) {
+  if (ch >= '0' && ch <= '9') return ch - '0';
+  if (ch >= 'a' && ch <= 'f') return 10 + (ch - 'a');
+  if (ch >= 'A' && ch <= 'F') return 10 + (ch - 'A');
+  return ch - '0';
+}
+
+/* Unescapes a quoted string token (including its surrounding quotes, as
+ * returned by ReadStringToken()) into *dest. Recognized escapes are \\, \",
+ * \t, \n and \x followed by one or two hex digits. For any other escape the
+ * escaped character is kept and the backslash dropped.
+ */
+static void UnescapeString(llvm::StringRef source, std::string* dest) {
+  assert(source.size() >= 2 && source.front() == '"' && source.back() == '"');
+  dest->clear();
+  dest->reserve(source.size() - 2);
+  for (auto s = source.begin() + 1, end = source.end() - 1; s != end; ++s) {
+    if (*s != '\\') {
+      dest->push_back(*s);
+      continue;
+    }
+    // a lone trailing backslash has nothing to escape: keep it and stop
+    // (advancing past it would step beyond the closing quote)
+    if (s + 1 == end) {
+      dest->push_back('\\');
+      break;
+    }
+    switch (*++s) {
+      case '\\':
+      case '"':
+        dest->push_back(*s);  // the escaped character itself
+        break;
+      case 't':
+        dest->push_back('\t');
+        break;
+      case 'n':
+        dest->push_back('\n');
+        break;
+      case 'x': {
+        // isxdigit() requires an unsigned char value. s[1] is at most `end`
+        // (the closing quote), so the lookahead stays in bounds.
+        if (!std::isxdigit(static_cast<unsigned char>(s[1]))) {
+          dest->push_back('x');  // treat it like an unknown escape
+          break;
+        }
+        int ch = fromxdigit(*++s);
+        if (std::isxdigit(static_cast<unsigned char>(s[1]))) {
+          ch <<= 4;
+          ch |= fromxdigit(*++s);
+        }
+        dest->push_back(static_cast<char>(ch));
+        break;
+      }
+      default:
+        dest->push_back(*s);  // unknown escape: keep the escaped character
+        break;
+    }
+  }
+}
+
+/* Returns true if tok (as returned by ReadStringToken()) is a complete quoted
+ * string: at least two characters, starting with a quote, and ending with a
+ * quote that is not itself escaped by an odd run of backslashes.
+ */
+static bool IsTerminatedStringToken(llvm::StringRef tok) {
+  if (tok.size() < 2 || tok.front() != '"' || tok.back() != '"') return false;
+  std::size_t backslashes = 0;
+  for (std::size_t i = tok.size() - 1; i > 1 && tok[i - 1] == '\\'; --i)
+    ++backslashes;
+  return backslashes % 2 == 0;
+}
+
+/* Reads entries in the persistent file format from is and merges them into
+ * storage, marking every loaded entry persistent. Blank lines and lines
+ * starting with ';' or '#' are ignored. Lines that cannot be parsed are
+ * reported through warn (if set) with their 1-based line number and skipped.
+ * Returns false, without loading anything, if the header line is missing or
+ * does not match; otherwise returns true.
+ */
+bool Storage::LoadPersistent(
+    std::istream& is,
+    std::function<void(std::size_t line, const char* msg)> warn) {
+  std::string line_str;
+  std::size_t line_num = 0;  // number of physical lines read so far
+
+  // entries to add
+  std::vector<std::pair<std::string, std::shared_ptr<Value>>> entries;
+
+  // declare these outside the loop to reduce reallocs
+  std::string name, str;
+  std::vector<int> boolean_array;
+  std::vector<double> double_array;
+  std::vector<std::string> string_array;
+
+  // skip blank lines and lines that start with ; or # (comments) before the
+  // header, counting them so later warnings report the right line number
+  while (std::getline(is, line_str)) {
+    ++line_num;
+    llvm::StringRef line = llvm::StringRef(line_str).trim();
+    if (!line.empty() && line.front() != ';' && line.front() != '#')
+      break;
+  }
+
+  // header; compared trimmed so trailing whitespace or a CRLF line ending
+  // doesn't cause the whole file to be rejected
+  if (llvm::StringRef(line_str).trim() != "[NetworkTables Storage 3.0]") {
+    if (warn)
+      warn(line_num == 0 ? 1 : line_num,
+           "header line mismatch, ignoring rest of file");
+    return false;
+  }
+
+  while (std::getline(is, line_str)) {
+    llvm::StringRef line = llvm::StringRef(line_str).trim();
+    ++line_num;
+
+    // ignore blank lines and lines that start with ; or # (comments)
+    if (line.empty() || line.front() == ';' || line.front() == '#')
+      continue;
+
+    // type
+    llvm::StringRef type_tok;
+    std::tie(type_tok, line) = line.split(' ');
+    NT_Type type = NT_UNASSIGNED;
+    if (type_tok == "boolean") type = NT_BOOLEAN;
+    else if (type_tok == "double") type = NT_DOUBLE;
+    else if (type_tok == "string") type = NT_STRING;
+    else if (type_tok == "raw") type = NT_RAW;
+    else if (type_tok == "array") {
+      llvm::StringRef array_tok;
+      std::tie(array_tok, line) = line.split(' ');
+      if (array_tok == "boolean") type = NT_BOOLEAN_ARRAY;
+      else if (array_tok == "double") type = NT_DOUBLE_ARRAY;
+      else if (array_tok == "string") type = NT_STRING_ARRAY;
+    }
+    if (type == NT_UNASSIGNED) {
+      if (warn) warn(line_num, "unrecognized type");
+      continue;
+    }
+
+    // name
+    llvm::StringRef name_tok;
+    std::tie(name_tok, line) = ReadStringToken(line);
+    if (name_tok.empty()) {
+      if (warn) warn(line_num, "missing name");
+      continue;
+    }
+    // a bare `"` or an escaped final quote must not reach UnescapeString()
+    if (!IsTerminatedStringToken(name_tok)) {
+      if (warn) warn(line_num, "unterminated name string");
+      continue;
+    }
+    UnescapeString(name_tok, &name);
+
+    // =
+    line = line.ltrim(" \t");
+    if (line.empty() || line.front() != '=') {
+      if (warn) warn(line_num, "expected = after name");
+      continue;
+    }
+    line = line.drop_front().ltrim(" \t");
+
+    // value
+    std::shared_ptr<Value> value;
+    switch (type) {
+      case NT_BOOLEAN:
+        // only true or false is accepted
+        if (line == "true")
+          value = Value::MakeBoolean(true);
+        else if (line == "false")
+          value = Value::MakeBoolean(false);
+        else {
+          if (warn)
+            warn(line_num, "unrecognized boolean value, not 'true' or 'false'");
+          goto next_line;
+        }
+        break;
+      case NT_DOUBLE: {
+        // need to convert to null-terminated string for strtod()
+        str.clear();
+        str += line;
+        char* end;
+        double v = std::strtod(str.c_str(), &end);
+        // reject trailing garbage, and an empty value (which strtod() would
+        // otherwise report as a successful parse of 0)
+        if (end == str.c_str() || *end != '\0') {
+          if (warn) warn(line_num, "invalid double value");
+          goto next_line;
+        }
+        value = Value::MakeDouble(v);
+        break;
+      }
+      case NT_STRING: {
+        llvm::StringRef str_tok;
+        std::tie(str_tok, line) = ReadStringToken(line);
+        if (str_tok.empty()) {
+          if (warn) warn(line_num, "missing string value");
+          goto next_line;
+        }
+        if (!IsTerminatedStringToken(str_tok)) {
+          if (warn) warn(line_num, "unterminated string value");
+          goto next_line;
+        }
+        UnescapeString(str_tok, &str);
+        value = Value::MakeString(std::move(str));
+        break;
+      }
+      case NT_RAW:
+        Base64Decode(line, &str);
+        value = Value::MakeRaw(std::move(str));
+        break;
+      case NT_BOOLEAN_ARRAY: {
+        llvm::StringRef elem_tok;
+        boolean_array.clear();
+        while (!line.empty()) {
+          std::tie(elem_tok, line) = line.split(',');
+          elem_tok = elem_tok.trim(" \t");
+          if (elem_tok == "true")
+            boolean_array.push_back(1);
+          else if (elem_tok == "false")
+            boolean_array.push_back(0);
+          else {
+            if (warn)
+              warn(line_num,
+                   "unrecognized boolean value, not 'true' or 'false'");
+            goto next_line;
+          }
+        }
+
+        value = Value::MakeBooleanArray(std::move(boolean_array));
+        break;
+      }
+      case NT_DOUBLE_ARRAY: {
+        llvm::StringRef elem_tok;
+        double_array.clear();
+        while (!line.empty()) {
+          std::tie(elem_tok, line) = line.split(',');
+          elem_tok = elem_tok.trim(" \t");
+          // need to convert to null-terminated string for strtod()
+          str.clear();
+          str += elem_tok;
+          char* end;
+          double v = std::strtod(str.c_str(), &end);
+          // an empty element (e.g. "1,,2") is invalid, not 0
+          if (end == str.c_str() || *end != '\0') {
+            if (warn) warn(line_num, "invalid double value");
+            goto next_line;
+          }
+          double_array.push_back(v);
+        }
+
+        value = Value::MakeDoubleArray(std::move(double_array));
+        break;
+      }
+      case NT_STRING_ARRAY: {
+        llvm::StringRef elem_tok;
+        string_array.clear();
+        while (!line.empty()) {
+          std::tie(elem_tok, line) = ReadStringToken(line);
+          if (elem_tok.empty()) {
+            if (warn) warn(line_num, "missing string value");
+            goto next_line;
+          }
+          if (!IsTerminatedStringToken(elem_tok)) {
+            if (warn) warn(line_num, "unterminated string value");
+            goto next_line;
+          }
+
+          UnescapeString(elem_tok, &str);
+          string_array.push_back(std::move(str));
+
+          line = line.ltrim(" \t");
+          if (line.empty()) break;
+          if (line.front() != ',') {
+            if (warn) warn(line_num, "expected comma between strings");
+            goto next_line;
+          }
+          line = line.drop_front().ltrim(" \t");
+        }
+
+        value = Value::MakeStringArray(std::move(string_array));
+        break;
+      }
+      default:
+        break;
+    }
+    if (!name.empty() && value)
+      entries.push_back(std::make_pair(std::move(name), std::move(value)));
+next_line:
+    ;
+  }
+
+  // copy values into storage as quickly as possible so lock isn't held
+  {
+    std::vector<std::shared_ptr<Message>> msgs;
+    std::unique_lock<std::mutex> lock(m_mutex);
+    for (auto& i : entries) {
+      auto& new_entry = m_entries[i.first];
+      if (!new_entry) new_entry.reset(new Entry(i.first));
+      Entry* entry = new_entry.get();
+      auto old_value = entry->value;
+      entry->value = i.second;
+      bool was_persist = entry->IsPersistent();
+      if (!was_persist) entry->flags |= NT_PERSISTENT;
+
+      // if we're the server, assign an id if it doesn't have one
+      if (m_server && entry->id == 0xffff) {
+        unsigned int id = m_idmap.size();
+        entry->id = id;
+        m_idmap.push_back(entry);
+      }
+
+      // notify (for local listeners)
+      if (m_notifier.local_notifiers()) {
+        if (!old_value) {
+          m_notifier.NotifyEntry(i.first, i.second,
+                                 NT_NOTIFY_NEW | NT_NOTIFY_LOCAL);
+        } else if (*old_value != *i.second) {
+          unsigned int notify_flags = NT_NOTIFY_UPDATE | NT_NOTIFY_LOCAL;
+          if (!was_persist) notify_flags |= NT_NOTIFY_FLAGS;
+          m_notifier.NotifyEntry(i.first, i.second, notify_flags);
+        } else if (!was_persist) {
+          // value unchanged but the entry just became persistent; this
+          // mirrors the FlagsUpdate message generated below
+          m_notifier.NotifyEntry(i.first, i.second,
+                                 NT_NOTIFY_FLAGS | NT_NOTIFY_LOCAL);
+        }
+      }
+
+      if (!m_queue_outgoing) continue;  // shortcut
+      ++entry->seq_num;
+
+      // put on update queue
+      if (!old_value || old_value->type() != i.second->type())
+        msgs.emplace_back(Message::EntryAssign(i.first, entry->id,
+                                               entry->seq_num.value(),
+                                               i.second, entry->flags));
+      else if (entry->id != 0xffff) {
+        // don't send an update if we don't have an assigned id yet
+        if (*old_value != *i.second)
+          msgs.emplace_back(Message::EntryUpdate(
+              entry->id, entry->seq_num.value(), i.second));
+        if (!was_persist)
+          msgs.emplace_back(Message::FlagsUpdate(entry->id, entry->flags));
+      }
+    }
+
+    if (m_queue_outgoing) {
+      auto queue_outgoing = m_queue_outgoing;
+      lock.unlock();
+      for (auto& msg : msgs) queue_outgoing(std::move(msg), nullptr, nullptr);
+    }
+  }
+
+  return true;
+}
+
+// Opens filename and loads persistent entries from it. Returns nullptr on
+// success, or a short error message if the file can't be opened or the
+// stream-based loader rejects it (bad header).
+const char* Storage::LoadPersistent(
+    StringRef filename,
+    std::function<void(std::size_t line, const char* msg)> warn) {
+  std::ifstream is(filename);
+  if (!is) return "could not open file";
+  return LoadPersistent(is, warn) ? nullptr : "error reading file";
+}
+
+// Server only: creates (or redefines) the RPC entry `name` with definition
+// `def`, served by `callback`, and starts the RPC server if it isn't active.
+// If the definition is new or changed, the entry gets an id (if it lacks one)
+// and an assignment or update message is queued.
+void Storage::CreateRpc(StringRef name, StringRef def, RpcCallback callback) {
+  if (name.empty() || def.empty() || !callback) return;
+  std::unique_lock<std::mutex> lock(m_mutex);
+  if (!m_server) return;  // only server can create RPCs
+
+  auto& slot = m_entries[name];
+  if (!slot) slot.reset(new Entry(name));
+  Entry* entry = slot.get();
+
+  auto old_value = entry->value;
+  auto value = Value::MakeRpc(def);
+  entry->value = value;
+  entry->rpc_callback = callback;
+
+  if (!m_rpc_server.active()) m_rpc_server.Start();
+
+  const bool unchanged = old_value && *old_value == *value;
+  if (unchanged) return;
+
+  // no id yet: append the entry to the id map
+  if (entry->id == 0xffff) {
+    entry->id = m_idmap.size();
+    m_idmap.push_back(entry);
+  }
+
+  if (!m_queue_outgoing) return;
+  auto queue_outgoing = m_queue_outgoing;
+  ++entry->seq_num;
+  // a new entry or a type change needs a full assignment; otherwise an
+  // update carrying the new value suffices
+  std::shared_ptr<Message> msg;
+  if (!old_value || old_value->type() != value->type())
+    msg = Message::EntryAssign(name, entry->id, entry->seq_num.value(), value,
+                               entry->flags);
+  else
+    msg = Message::EntryUpdate(entry->id, entry->seq_num.value(), value);
+  lock.unlock();
+  queue_outgoing(msg, nullptr, nullptr);
+}
+
+// Server only: creates (or redefines) the RPC entry `name` with definition
+// `def` as a polled RPC (no callback). If the definition is new or changed,
+// the entry gets an id (if it lacks one) and an assignment or update message
+// is queued.
+void Storage::CreatePolledRpc(StringRef name, StringRef def) {
+  if (name.empty() || def.empty()) return;
+  std::unique_lock<std::mutex> lock(m_mutex);
+  if (!m_server) return;  // only server can create RPCs
+
+  auto& slot = m_entries[name];
+  if (!slot) slot.reset(new Entry(name));
+  Entry* entry = slot.get();
+
+  auto previous = entry->value;
+  auto rpc_value = Value::MakeRpc(def);
+  entry->value = rpc_value;
+  entry->rpc_callback = nullptr;  // a null callback marks the RPC as polled
+
+  if (previous && *previous == *rpc_value) return;
+
+  // no id yet: append the entry to the id map
+  if (entry->id == 0xffff) {
+    entry->id = m_idmap.size();
+    m_idmap.push_back(entry);
+  }
+
+  if (!m_queue_outgoing) return;
+  auto queue_outgoing = m_queue_outgoing;
+  ++entry->seq_num;
+  const bool needs_assign = !previous || previous->type() != rpc_value->type();
+  std::shared_ptr<Message> msg;
+  if (needs_assign) {
+    msg = Message::EntryAssign(name, entry->id, entry->seq_num.value(),
+                               rpc_value, entry->flags);
+  } else {
+    msg = Message::EntryUpdate(entry->id, entry->seq_num.value(), rpc_value);
+  }
+  lock.unlock();
+  queue_outgoing(msg, nullptr, nullptr);
+}
+
+// Calls the RPC `name` with the serialized `params`. Returns a combined call
+// uid (entry id in the upper 16 bits, per-entry call counter in the lower 16)
+// for use with GetRpcResult(). Returns 0 if the entry doesn't exist, has no
+// RPC value, or (on a client) no outgoing queue is configured. On the server
+// the call is processed locally; on a client it is queued for sending.
+unsigned int Storage::CallRpc(StringRef name, StringRef params) {
+  std::unique_lock<std::mutex> lock(m_mutex);
+  auto i = m_entries.find(name);
+  if (i == m_entries.end()) return 0;
+  auto& entry = i->getValue();
+  // entries can exist without a value; only RPC-typed values are callable
+  if (!entry->value || !entry->value->IsRpc()) return 0;
+  // a client needs the dispatcher's queue to send the call; invoking an empty
+  // std::function would throw std::bad_function_call
+  if (!m_server && !m_queue_outgoing) return 0;
+
+  ++entry->rpc_call_uid;
+  if (entry->rpc_call_uid > 0xffff) entry->rpc_call_uid = 0;
+  unsigned int combined_uid = (entry->id << 16) | entry->rpc_call_uid;
+  auto msg = Message::ExecuteRpc(entry->id, entry->rpc_call_uid, params);
+  if (m_server) {
+    // RPCs are unlikely to be used locally on the server, but handle it
+    // gracefully anyway.
+    auto rpc_callback = entry->rpc_callback;
+    lock.unlock();
+    m_rpc_server.ProcessRpc(
+        name, msg, rpc_callback, 0xffffU, [this](std::shared_ptr<Message> msg) {
+          std::lock_guard<std::mutex> lock(m_mutex);
+          m_rpc_results.insert(std::make_pair(
+              std::make_pair(msg->id(), msg->seq_num_uid()), msg->str()));
+          m_rpc_results_cond.notify_all();
+        });
+  } else {
+    auto queue_outgoing = m_queue_outgoing;
+    lock.unlock();
+    queue_outgoing(msg, nullptr, nullptr);
+  }
+  return combined_uid;
+}
+
+// Retrieves (and removes) the result of the RPC call identified by call_uid.
+// When blocking, waits on m_rpc_results_cond until the result arrives;
+// returns false if it is not available (non-blocking) or if termination is
+// flagged.
+bool Storage::GetRpcResult(bool blocking, unsigned int call_uid,
+                           std::string* result) {
+  const auto key = std::make_pair(call_uid >> 16, call_uid & 0xffff);
+  std::unique_lock<std::mutex> lock(m_mutex);
+  auto it = m_rpc_results.find(key);
+  while (it == m_rpc_results.end()) {
+    if (!blocking || m_terminating) return false;
+    m_rpc_results_cond.wait(lock);
+    if (m_terminating) return false;
+    it = m_rpc_results.find(key);
+  }
+  result->swap(it->getSecond());
+  m_rpc_results.erase(it);
+  return true;
+}
diff --git a/src/Storage.h b/src/Storage.h
new file mode 100644
index 0000000..c87a37b
--- /dev/null
+++ b/src/Storage.h
@@ -0,0 +1,173 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef NT_STORAGE_H_
+#define NT_STORAGE_H_
+
+#include <atomic>
+#include <condition_variable>
+#include <cstddef>
+#include <fstream>
+#include <functional>
+#include <iosfwd>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "llvm/DenseMap.h"
+#include "llvm/StringMap.h"
+#include "atomic_static.h"
+#include "Message.h"
+#include "Notifier.h"
+#include "ntcore_cpp.h"
+#include "RpcServer.h"
+#include "SequenceNumber.h"
+
+namespace nt {
+
+class NetworkConnection;
+class StorageTest;
+
+/* Central store of all NetworkTables entries; entry state is guarded by
+ * m_mutex. Implements the user-facing entry operations, persistent file
+ * save/load and RPC bookkeeping, and hands outgoing network messages to a
+ * queue function configured via SetOutgoing().
+ */
+class Storage {
+  friend class StorageTest;
+
+ public:
+  static Storage& GetInstance() {
+    ATOMIC_STATIC(Storage, instance);
+    return instance;
+  }
+  ~Storage();
+
+  // Accessors required by Dispatcher. A function pointer is used for
+  // generation of outgoing messages to break a dependency loop between
+  // Storage and Dispatcher; in operation this is always set to
+  // Dispatcher::QueueOutgoing.
+  typedef std::function<void(std::shared_ptr<Message> msg,
+                             NetworkConnection* only,
+                             NetworkConnection* except)> QueueOutgoingFunc;
+  void SetOutgoing(QueueOutgoingFunc queue_outgoing, bool server);
+  void ClearOutgoing();
+
+  // Required for wire protocol 2.0 to get the entry type of an entry when
+  // receiving entry updates (because the length/type is not provided in the
+  // message itself). Not used in wire protocol 3.0.
+  NT_Type GetEntryType(unsigned int id) const;
+
+  void ProcessIncoming(std::shared_ptr<Message> msg, NetworkConnection* conn,
+                       std::weak_ptr<NetworkConnection> conn_weak);
+  void GetInitialAssignments(NetworkConnection& conn,
+                             std::vector<std::shared_ptr<Message>>* msgs);
+  void ApplyInitialAssignments(NetworkConnection& conn,
+                               llvm::ArrayRef<std::shared_ptr<Message>> msgs,
+                               bool new_server,
+                               std::vector<std::shared_ptr<Message>>* out_msgs);
+
+  // User functions. These are the actual implementations of the corresponding
+  // user API functions in ntcore_cpp.
+  std::shared_ptr<Value> GetEntryValue(StringRef name) const;
+  bool SetEntryValue(StringRef name, std::shared_ptr<Value> value);
+  void SetEntryTypeValue(StringRef name, std::shared_ptr<Value> value);
+  void SetEntryFlags(StringRef name, unsigned int flags);
+  unsigned int GetEntryFlags(StringRef name) const;
+  void DeleteEntry(StringRef name);
+  void DeleteAllEntries();
+  std::vector<EntryInfo> GetEntryInfo(StringRef prefix, unsigned int types);
+  void NotifyEntries(StringRef prefix,
+                     EntryListenerCallback only = nullptr) const;
+
+  // Filename-based save/load functions. Used both by periodic saves and
+  // accessible directly via the user API.
+  const char* SavePersistent(StringRef filename, bool periodic) const;
+  const char* LoadPersistent(
+      StringRef filename,
+      std::function<void(std::size_t line, const char* msg)> warn);
+
+  // Stream-based save/load functions (exposed for testing purposes). These
+  // implement the guts of the filename-based functions.
+  void SavePersistent(std::ostream& os, bool periodic) const;
+  bool LoadPersistent(
+      std::istream& is,
+      std::function<void(std::size_t line, const char* msg)> warn);
+
+  // RPC configuration needs to come through here as RPC definitions are
+  // actually special Storage value types.
+  void CreateRpc(StringRef name, StringRef def, RpcCallback callback);
+  void CreatePolledRpc(StringRef name, StringRef def);
+
+  unsigned int CallRpc(StringRef name, StringRef params);
+  bool GetRpcResult(bool blocking, unsigned int call_uid, std::string* result);
+
+ private:
+  Storage();
+  Storage(Notifier& notifier, RpcServer& rpcserver);
+  Storage(const Storage&) = delete;
+  Storage& operator=(const Storage&) = delete;
+
+  // Data for each table entry.
+  struct Entry {
+    explicit Entry(llvm::StringRef name_) : name(name_) {}
+    bool IsPersistent() const { return (flags & NT_PERSISTENT) != 0; }
+
+    // We redundantly store the name so that it's available when accessing the
+    // raw Entry* via the ID map.
+    std::string name;
+
+    // The current value and flags.
+    std::shared_ptr<Value> value;
+    unsigned int flags = 0;
+
+    // Unique ID for this entry as used in network messages. The value is
+    // assigned by the server, so on the client this is 0xffff until an
+    // entry assignment is received back from the server.
+    unsigned int id = 0xffff;
+
+    // Sequence number for update resolution.
+    SequenceNumber seq_num;
+
+    // RPC callback function. Null if either not an RPC or if the RPC is
+    // polled.
+    RpcCallback rpc_callback;
+
+    // Last UID used when calling this RPC (primarily for client use). This
+    // is incremented for each call.
+    unsigned int rpc_call_uid = 0;
+  };
+
+  typedef llvm::StringMap<std::unique_ptr<Entry>> EntriesMap;
+  typedef std::vector<Entry*> IdMap;
+  typedef llvm::DenseMap<std::pair<unsigned int, unsigned int>, std::string>
+      RpcResultMap;
+
+  mutable std::mutex m_mutex;
+  EntriesMap m_entries;
+  IdMap m_idmap;
+  RpcResultMap m_rpc_results;
+  // If any persistent values have changed
+  mutable bool m_persistent_dirty = false;
+
+  // condition variable and termination flag for blocking on a RPC result;
+  // the flag is initialized in-class so its starting value never depends on
+  // how the constructors are written
+  std::atomic_bool m_terminating{false};
+  std::condition_variable m_rpc_results_cond;
+
+  // configured by dispatcher at startup
+  QueueOutgoingFunc m_queue_outgoing;
+  bool m_server = true;
+
+  // references to singletons (we don't grab them directly for testing purposes)
+  Notifier& m_notifier;
+  RpcServer& m_rpc_server;
+
+  bool GetPersistentEntries(
+      bool periodic,
+      std::vector<std::pair<std::string, std::shared_ptr<Value>>>* entries)
+      const;
+
+  ATOMIC_STATIC_DECL(Storage)
+};
+
+} // namespace nt
+
+#endif // NT_STORAGE_H_
diff --git a/src/Value.cpp b/src/Value.cpp
new file mode 100644
index 0000000..8095fa9
--- /dev/null
+++ b/src/Value.cpp
@@ -0,0 +1,210 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "nt_Value.h"
+#include "Value_internal.h"
+#include "support/timestamp.h"
+
+using namespace nt;
+
+// Creates an unassigned value stamped with the current time.
+Value::Value() {
+  m_val.last_change = Now();
+  m_val.type = NT_UNASSIGNED;
+}
+
+// Creates a value of the given type stamped with the current time. For the
+// array types the array pointer starts out null; the Make*Array factories
+// allocate and fill it.
+Value::Value(NT_Type type, const private_init&) {
+  m_val.type = type;
+  m_val.last_change = Now();
+  switch (type) {
+    case NT_BOOLEAN_ARRAY:
+      m_val.data.arr_boolean.arr = nullptr;
+      break;
+    case NT_DOUBLE_ARRAY:
+      m_val.data.arr_double.arr = nullptr;
+      break;
+    case NT_STRING_ARRAY:
+      m_val.data.arr_string.arr = nullptr;
+      break;
+    default:
+      break;
+  }
+}
+
+// Frees the array allocated by the Make*Array factories. For string arrays
+// only the NT_String view array is freed; the characters it points at are not
+// owned by that array.
+Value::~Value() {
+  switch (m_val.type) {
+    case NT_BOOLEAN_ARRAY:
+      delete[] m_val.data.arr_boolean.arr;
+      break;
+    case NT_DOUBLE_ARRAY:
+      delete[] m_val.data.arr_double.arr;
+      break;
+    case NT_STRING_ARRAY:
+      delete[] m_val.data.arr_string.arr;
+      break;
+    default:
+      break;
+  }
+}
+
+// Creates a boolean array value holding a copy of `value`.
+std::shared_ptr<Value> Value::MakeBooleanArray(llvm::ArrayRef<int> value) {
+  auto result = std::make_shared<Value>(NT_BOOLEAN_ARRAY, private_init());
+  auto& storage = result->m_val.data.arr_boolean;
+  storage.arr = new int[value.size()];
+  storage.size = value.size();
+  std::copy(value.begin(), value.end(), storage.arr);
+  return result;
+}
+
+// Creates a double array value holding a copy of `value`.
+std::shared_ptr<Value> Value::MakeDoubleArray(llvm::ArrayRef<double> value) {
+  auto result = std::make_shared<Value>(NT_DOUBLE_ARRAY, private_init());
+  auto& storage = result->m_val.data.arr_double;
+  storage.arr = new double[value.size()];
+  storage.size = value.size();
+  std::copy(value.begin(), value.end(), storage.arr);
+  return result;
+}
+
+// Creates a string array value holding a copy of `value`.
+// The previous implementation pointed the NT_String view at the caller's
+// strings instead of the value's own m_string_array copy, leaving dangling
+// pointers once the caller's strings went away. Delegating to the
+// rvalue-vector overload makes the view point into the owned copy.
+std::shared_ptr<Value> Value::MakeStringArray(
+    llvm::ArrayRef<std::string> value) {
+  return MakeStringArray(std::vector<std::string>(value.begin(), value.end()));
+}
+
+// Creates a string array value, taking ownership of the strings in `value`
+// (which is left empty). The NT_String view in m_val points into the value's
+// own m_string_array, so it stays valid for the value's lifetime.
+std::shared_ptr<Value> Value::MakeStringArray(
+    std::vector<std::string>&& value) {
+  auto val = std::make_shared<Value>(NT_STRING_ARRAY, private_init());
+  val->m_string_array = std::move(value);
+  value.clear();
+  const auto& owned = val->m_string_array;
+  NT_String* view = new NT_String[owned.size()];
+  for (std::size_t idx = 0; idx < owned.size(); ++idx) {
+    view[idx].str = const_cast<char*>(owned[idx].c_str());
+    view[idx].len = owned[idx].size();
+  }
+  val->m_val.data.arr_string.arr = view;
+  val->m_val.data.arr_string.size = owned.size();
+  return val;
+}
+
+// Converts `in` into the C representation `out`. String and array data are
+// copied into std::malloc'd buffers owned by `out`. out->type is set to the
+// input's type on success and left NT_UNASSIGNED for unassigned or unknown
+// types.
+void nt::ConvertToC(const Value& in, NT_Value* out) {
+  out->type = NT_UNASSIGNED;
+  const auto type = in.type();
+  switch (type) {
+    case NT_BOOLEAN:
+      out->data.v_boolean = in.GetBoolean() ? 1 : 0;
+      break;
+    case NT_DOUBLE:
+      out->data.v_double = in.GetDouble();
+      break;
+    case NT_STRING:
+      ConvertToC(in.GetString(), &out->data.v_string);
+      break;
+    case NT_RAW:
+      ConvertToC(in.GetRaw(), &out->data.v_raw);
+      break;
+    case NT_RPC:
+      ConvertToC(in.GetRpc(), &out->data.v_raw);
+      break;
+    case NT_BOOLEAN_ARRAY: {
+      auto src = in.GetBooleanArray();
+      int* dst = static_cast<int*>(std::malloc(src.size() * sizeof(int)));
+      std::copy(src.begin(), src.end(), dst);
+      out->data.arr_boolean.arr = dst;
+      out->data.arr_boolean.size = src.size();
+      break;
+    }
+    case NT_DOUBLE_ARRAY: {
+      auto src = in.GetDoubleArray();
+      double* dst =
+          static_cast<double*>(std::malloc(src.size() * sizeof(double)));
+      std::copy(src.begin(), src.end(), dst);
+      out->data.arr_double.arr = dst;
+      out->data.arr_double.size = src.size();
+      break;
+    }
+    case NT_STRING_ARRAY: {
+      auto src = in.GetStringArray();
+      NT_String* dst =
+          static_cast<NT_String*>(std::malloc(src.size() * sizeof(NT_String)));
+      out->data.arr_string.arr = dst;
+      for (std::size_t idx = 0; idx < src.size(); ++idx)
+        ConvertToC(src[idx], &dst[idx]);
+      out->data.arr_string.size = src.size();
+      break;
+    }
+    default:  // NT_UNASSIGNED and unknown types
+      return;
+  }
+  out->type = type;
+}
+
+// Copies `in` into a newly std::malloc'd, NUL-terminated buffer stored in
+// `out`; out->len excludes the terminator.
+void nt::ConvertToC(llvm::StringRef in, NT_String* out) {
+  const std::size_t n = in.size();
+  char* buf = static_cast<char*>(std::malloc(n + 1));
+  std::memcpy(buf, in.data(), n);
+  buf[n] = '\0';
+  out->str = buf;
+  out->len = n;
+}
+
+// Builds a C++ Value from a C NT_Value. Returns nullptr for NT_UNASSIGNED and
+// for unknown types.
+std::shared_ptr<Value> nt::ConvertFromC(const NT_Value& value) {
+  switch (value.type) {
+    case NT_BOOLEAN:
+      return Value::MakeBoolean(value.data.v_boolean != 0);
+    case NT_DOUBLE:
+      return Value::MakeDouble(value.data.v_double);
+    case NT_STRING:
+      return Value::MakeString(ConvertFromC(value.data.v_string));
+    case NT_RAW:
+      return Value::MakeRaw(ConvertFromC(value.data.v_raw));
+    case NT_RPC:
+      return Value::MakeRpc(ConvertFromC(value.data.v_raw));
+    case NT_BOOLEAN_ARRAY: {
+      const auto& a = value.data.arr_boolean;
+      return Value::MakeBooleanArray(llvm::ArrayRef<int>(a.arr, a.size));
+    }
+    case NT_DOUBLE_ARRAY: {
+      const auto& a = value.data.arr_double;
+      return Value::MakeDoubleArray(llvm::ArrayRef<double>(a.arr, a.size));
+    }
+    case NT_STRING_ARRAY: {
+      const auto& a = value.data.arr_string;
+      std::vector<std::string> strings;
+      strings.reserve(a.size);
+      for (std::size_t idx = 0; idx < a.size; ++idx)
+        strings.push_back(ConvertFromC(a.arr[idx]));
+      return Value::MakeStringArray(std::move(strings));
+    }
+    default:  // NT_UNASSIGNED and unknown types
+      return nullptr;
+  }
+}
+
+// Deep equality: values are equal when they have the same type and the same
+// contents. NOTE(review): scalar doubles compare with ==, while double arrays
+// compare bitwise via memcmp, so NaN and -0.0 elements behave differently
+// from the scalar case -- confirm which semantics is intended.
+bool nt::operator==(const Value& lhs, const Value& rhs) {
+  if (lhs.type() != rhs.type()) return false;
+  const auto& a = lhs.m_val;
+  const auto& b = rhs.m_val;
+  switch (lhs.type()) {
+    case NT_UNASSIGNED:
+      return true;  // XXX: is this better being false instead?
+    case NT_BOOLEAN:
+      return a.data.v_boolean == b.data.v_boolean;
+    case NT_DOUBLE:
+      return a.data.v_double == b.data.v_double;
+    case NT_STRING:
+    case NT_RAW:
+    case NT_RPC:
+      return lhs.m_string == rhs.m_string;
+    case NT_BOOLEAN_ARRAY: {
+      const auto& x = a.data.arr_boolean;
+      const auto& y = b.data.arr_boolean;
+      return x.size == y.size &&
+             std::memcmp(x.arr, y.arr, x.size * sizeof(x.arr[0])) == 0;
+    }
+    case NT_DOUBLE_ARRAY: {
+      const auto& x = a.data.arr_double;
+      const auto& y = b.data.arr_double;
+      return x.size == y.size &&
+             std::memcmp(x.arr, y.arr, x.size * sizeof(x.arr[0])) == 0;
+    }
+    case NT_STRING_ARRAY:
+      return lhs.m_string_array == rhs.m_string_array;
+    default:
+      // assert(false && "unknown value type");
+      return false;
+  }
+}
diff --git a/src/Value_internal.h b/src/Value_internal.h
new file mode 100644
index 0000000..f09748c
--- /dev/null
+++ b/src/Value_internal.h
@@ -0,0 +1,30 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef NT_VALUE_INTERNAL_H_
+#define NT_VALUE_INTERNAL_H_
+
+#include <memory>
+#include <string>
+
+#include "llvm/StringRef.h"
+#include "ntcore_c.h"
+
+namespace nt {
+
+class Value;
+// Conversions between the C++ Value/llvm::StringRef types and the C API structs.
+void ConvertToC(const Value& in, NT_Value* out);
+std::shared_ptr<Value> ConvertFromC(const NT_Value& value); // nullptr for an unknown type
+void ConvertToC(llvm::StringRef in, NT_String* out);
+inline llvm::StringRef ConvertFromC(const NT_String& str) {
+ return llvm::StringRef(str.str, str.len); // non-owning view: str must outlive it
+}
+
+} // namespace nt
+
+#endif // NT_VALUE_INTERNAL_H_
diff --git a/src/WireDecoder.cpp b/src/WireDecoder.cpp
new file mode 100644
index 0000000..138225e
--- /dev/null
+++ b/src/WireDecoder.cpp
@@ -0,0 +1,206 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "WireDecoder.h"
+
+#include <cassert>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+
+#include "llvm/MathExtras.h"
+#include "leb128.h"
+
+using namespace nt;
+
+static double ReadDouble(const char*& buf) {
+  // Decodes one double stored as 8 bytes in network (big-endian) byte order
+  // starting at buf, and advances buf past those 8 bytes.
+  //
+  // The 64-bit pattern is assembled with explicit shifts, so it does not
+  // depend on the host's byte order; the final BitsToDouble step then
+  // reinterprets those bits as the host's double.
+  //
+  // Wire layout (b0 is read first):
+  //   b0 b1 b2 b3 b4 b5 b6 b7
+  //   ^ most significant    ^ least significant
+  //
+  // Bytes are read through an unsigned char pointer so that values >= 0x80
+  // are not sign-extended when widened.
+  const unsigned char* bytes = reinterpret_cast<const unsigned char*>(buf);
+  std::uint64_t bits = 0;
+  for (int i = 0; i < 8; ++i) {
+    // Make room for the next byte below the ones already accumulated.
+    bits <<= 8;
+    bits |= bytes[i];
+  }
+
+  // Advance the caller's cursor past the bytes just consumed.
+  buf += 8;
+
+  return llvm::BitsToDouble(bits);
+}
+
+WireDecoder::WireDecoder(raw_istream& is, unsigned int proto_rev)
+    : m_proto_rev(proto_rev),
+      m_error(nullptr),
+      m_is(is),
+      // 1K scratch buffer to start; malloc (not new) so that Realloc() can
+      // grow it with realloc.
+      m_buf(static_cast<char*>(std::malloc(1024))),
+      m_allocated(1024) {}
+
+WireDecoder::~WireDecoder() { std::free(m_buf); } // scratch buffer came from malloc/realloc
+
+bool WireDecoder::ReadDouble(double* val) {
+  const char* bytes;  // points into the scratch buffer until the next Read()
+  const bool ok = Read(&bytes, 8);  // a double is 8 bytes on the wire
+  if (ok) *val = ::ReadDouble(bytes);
+  return ok;
+}
+
+void WireDecoder::Realloc(std::size_t len) {
+  if (m_allocated >= len) return;  // already large enough
+  std::size_t newlen = m_allocated ? m_allocated : len;  // grow by doubling
+  while (newlen < len) newlen = newlen > SIZE_MAX / 2 ? len : newlen * 2;  // no wrap to 0
+  char* newbuf = static_cast<char*>(std::realloc(m_buf, newlen));
+  if (!newbuf) std::abort();  // Read() must never write through a null/short buffer
+  m_buf = newbuf; m_allocated = newlen;
+}
+
+bool WireDecoder::ReadType(NT_Type* type) {
+  // Reads a one-byte wire type code and converts it to the NT_Type enum.
+  // On an unrecognized code, sets *type to NT_UNASSIGNED, records an error
+  // and returns false. Raw (0x03) and RPC (0x20) are accepted here for any
+  // protocol revision; ReadValue() is what rejects them before 3.0.
+  //
+  // Wire type codes and the NT_Type each one decodes to.
+  static const struct {
+    unsigned int wire;
+    NT_Type decoded;
+  } kTypeMap[] = {
+      {0x00, NT_BOOLEAN},
+      {0x01, NT_DOUBLE},
+      {0x02, NT_STRING},
+      {0x03, NT_RAW},
+      {0x10, NT_BOOLEAN_ARRAY},
+      {0x11, NT_DOUBLE_ARRAY},
+      {0x12, NT_STRING_ARRAY},
+      {0x20, NT_RPC},
+  };
+
+  unsigned int code;
+  if (!Read8(&code)) return false;
+
+  // Only eight entries, so a linear scan is plenty.
+  for (const auto& entry : kTypeMap) {
+    if (entry.wire == code) {
+      *type = entry.decoded;
+      return true;
+    }
+  }
+
+  *type = NT_UNASSIGNED;
+  m_error = "unrecognized value type";
+  return false;
+}
+
+std::shared_ptr<Value> WireDecoder::ReadValue(NT_Type type) {
+ switch (type) { // each case returns the value, or nullptr on failure (m_error set for protocol errors)
+ case NT_BOOLEAN: {
+ unsigned int v; // one byte; any nonzero value means true
+ if (!Read8(&v)) return nullptr;
+ return Value::MakeBoolean(v != 0);
+ }
+ case NT_DOUBLE: {
+ double v; // 8 bytes, big-endian
+ if (!ReadDouble(&v)) return nullptr;
+ return Value::MakeDouble(v);
+ }
+ case NT_STRING: {
+ std::string v; // length-prefixed; see ReadString()
+ if (!ReadString(&v)) return nullptr;
+ return Value::MakeString(std::move(v));
+ }
+ case NT_RAW: {
+ if (m_proto_rev < 0x0300u) { // raw values only exist from protocol 3.0 on
+ m_error = "received raw value in protocol < 3.0";
+ return nullptr;
+ }
+ std::string v; // same length-prefixed encoding as NT_STRING
+ if (!ReadString(&v)) return nullptr;
+ return Value::MakeRaw(std::move(v));
+ }
+ case NT_RPC: {
+ if (m_proto_rev < 0x0300u) { // RPC values only exist from protocol 3.0 on
+ m_error = "received RPC value in protocol < 3.0";
+ return nullptr;
+ }
+ std::string v; // same length-prefixed encoding as NT_STRING
+ if (!ReadString(&v)) return nullptr;
+ return Value::MakeRpc(std::move(v));
+ }
+ case NT_BOOLEAN_ARRAY: {
+ // 1-byte element count (so at most 255 elements)
+ unsigned int size;
+ if (!Read8(&size)) return nullptr;
+
+ // array values: one byte per element, nonzero meaning true
+ const char* buf; // valid only until the next Read()
+ if (!Read(&buf, size)) return nullptr;
+ std::vector<int> v(size);
+ for (unsigned int i = 0; i < size; ++i)
+ v[i] = buf[i] ? 1 : 0;
+ return Value::MakeBooleanArray(std::move(v));
+ }
+ case NT_DOUBLE_ARRAY: {
+ // 1-byte element count (so at most 255 elements)
+ unsigned int size;
+ if (!Read8(&size)) return nullptr;
+
+ // array values: 8 bytes per element, read in one go
+ const char* buf; // ::ReadDouble advances this past each element
+ if (!Read(&buf, size * 8)) return nullptr;
+ std::vector<double> v(size);
+ for (unsigned int i = 0; i < size; ++i)
+ v[i] = ::ReadDouble(buf);
+ return Value::MakeDoubleArray(std::move(v));
+ }
+ case NT_STRING_ARRAY: {
+ // 1-byte element count (so at most 255 elements)
+ unsigned int size;
+ if (!Read8(&size)) return nullptr;
+
+ // array values: each element is a length-prefixed string
+ std::vector<std::string> v(size);
+ for (unsigned int i = 0; i < size; ++i) {
+ if (!ReadString(&v[i])) return nullptr;
+ }
+ return Value::MakeStringArray(std::move(v));
+ }
+ default: // includes NT_UNASSIGNED
+ m_error = "invalid type when trying to read value";
+ return nullptr;
+ }
+}
+
+bool WireDecoder::ReadString(std::string* str) {
+  size_t len;  // 16-bit big-endian prefix before protocol 3.0, ULEB128 from 3.0 on
+  if (m_proto_rev >= 0x0300u) {
+    unsigned long n;
+    if (!ReadUleb128(&n)) return false;
+    len = n;
+  } else {
+    unsigned int n;
+    if (!Read16(&n)) return false;
+    len = n;
+  }
+  const char* bytes;  // points into the scratch buffer until the next Read()
+  if (!Read(&bytes, len)) return false;
+  str->assign(bytes, len);
+  return true;
+}
diff --git a/src/WireDecoder.h b/src/WireDecoder.h
new file mode 100644
index 0000000..c520be7
--- /dev/null
+++ b/src/WireDecoder.h
@@ -0,0 +1,149 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef NT_WIREDECODER_H_
+#define NT_WIREDECODER_H_
+
+#include <cstddef>
+
+#include "nt_Value.h"
+#include "leb128.h"
+//#include "Log.h"
+#include "raw_istream.h"
+
+namespace nt {
+
+/* Decodes NetworkTables wire data into native values.
+ * This class is designed to read from a raw_istream, which provides a blocking
+ * read interface; a read interrupted partway cannot be resumed. Read functions
+ * return false if raw_istream.read() returned false (end of the input data
+ * stream). ReadType() and ReadValue() also record protocol errors in error().
+ * Not copyable: each decoder owns a malloc'd scratch buffer.
+ */
+class WireDecoder {
+ public:
+ explicit WireDecoder(raw_istream& is, unsigned int proto_rev);
+ ~WireDecoder();
+ /* Change the protocol revision (affects string and raw/RPC decoding). */
+ void set_proto_rev(unsigned int proto_rev) { m_proto_rev = proto_rev; }
+
+ /* Get the active protocol revision. */
+ unsigned int proto_rev() const { return m_proto_rev; }
+
+ /* Clears error indicator. */
+ void Reset() { m_error = nullptr; }
+
+ /* Returns error indicator (a string describing the error). Returns nullptr
+ * if no error has occurred.
+ */
+ const char* error() const { return m_error; }
+ /* Sets the error indicator; the pointer is stored, not copied. */
+ void set_error(const char* error) { m_error = error; }
+
+ /* Reads the specified number of bytes.
+ * @param buf pointer to read data (output parameter)
+ * @param len number of bytes to read
+ * Caution: *buf points into a scratch buffer that the next Read() reuses or moves.
+ */
+ bool Read(const char** buf, std::size_t len) {
+ if (len > m_allocated) Realloc(len);
+ *buf = m_buf;
+ bool rv = m_is.read(m_buf, len);
+#if 0 // disabled debug logging (the Log.h include above is commented out)
+ nt::Logger& logger = nt::Logger::GetInstance();
+ if (logger.min_level() <= NT_LOG_DEBUG4 && logger.HasLogger()) {
+ std::ostringstream oss;
+ oss << "read " << len << " bytes:" << std::hex;
+ if (!rv)
+ oss << "error";
+ else {
+ for (std::size_t i=0; i < len; ++i)
+ oss << ' ' << (unsigned int)((*buf)[i]);
+ }
+ logger.Log(NT_LOG_DEBUG4, __FILE__, __LINE__, oss.str().c_str());
+ }
+#endif
+ return rv;
+ }
+
+ /* Reads a single byte (0-255). */
+ bool Read8(unsigned int* val) {
+ const char* buf;
+ if (!Read(&buf, 1)) return false;
+ *val = (*reinterpret_cast<const unsigned char*>(buf)) & 0xff;
+ return true;
+ }
+
+ /* Reads a 16-bit big-endian word. */
+ bool Read16(unsigned int* val) {
+ const char* buf;
+ if (!Read(&buf, 2)) return false;
+ unsigned int v = (*reinterpret_cast<const unsigned char*>(buf)) & 0xff;
+ ++buf;
+ v <<= 8;
+ v |= (*reinterpret_cast<const unsigned char*>(buf)) & 0xff;
+ *val = v;
+ return true;
+ }
+
+ /* Reads a 32-bit big-endian word. */
+ bool Read32(unsigned long* val) {
+ const char* buf;
+ if (!Read(&buf, 4)) return false;
+ unsigned int v = (*reinterpret_cast<const unsigned char*>(buf)) & 0xff;
+ ++buf;
+ v <<= 8;
+ v |= (*reinterpret_cast<const unsigned char*>(buf)) & 0xff;
+ ++buf;
+ v <<= 8;
+ v |= (*reinterpret_cast<const unsigned char*>(buf)) & 0xff;
+ ++buf;
+ v <<= 8;
+ v |= (*reinterpret_cast<const unsigned char*>(buf)) & 0xff;
+ *val = v;
+ return true;
+ }
+
+ /* Reads an 8-byte big-endian double. */
+ bool ReadDouble(double* val);
+
+ /* Reads a ULEB128-encoded unsigned integer directly from the stream. */
+ bool ReadUleb128(unsigned long* val) {
+ return nt::ReadUleb128(m_is, val);
+ }
+ /* Decode a type byte, a length-prefixed string, or a value of a given type. */
+ bool ReadType(NT_Type* type);
+ bool ReadString(std::string* str);
+ std::shared_ptr<Value> ReadValue(NT_Type type);
+
+ WireDecoder(const WireDecoder&) = delete;
+ WireDecoder& operator=(const WireDecoder&) = delete;
+
+ protected:
+ /* The protocol revision. E.g. 0x0200 for version 2.0. */
+ unsigned int m_proto_rev;
+
+ /* Error indicator; nullptr when no error has occurred. */
+ const char* m_error;
+
+ private:
+ /* Grow the temporary buffer to hold at least len bytes. */
+ void Realloc(std::size_t len);
+
+ /* input stream (referenced, not owned) */
+ raw_istream& m_is;
+
+ /* temporary buffer; malloc'd so that it can be grown with realloc */
+ char* m_buf;
+
+ /* allocated size of temporary buffer, in bytes */
+ std::size_t m_allocated;
+};
+
+} // namespace nt
+
+#endif // NT_WIREDECODER_H_
diff --git a/src/WireEncoder.cpp b/src/WireEncoder.cpp
new file mode 100644
index 0000000..610a53f
--- /dev/null
+++ b/src/WireEncoder.cpp
@@ -0,0 +1,208 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "WireEncoder.h"
+
+#include <cassert>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+
+#include "llvm/MathExtras.h"
+#include "leb128.h"
+
+using namespace nt;
+
+WireEncoder::WireEncoder(unsigned int proto_rev)
+    : m_proto_rev(proto_rev), m_error(nullptr) {
+  // The output buffer (m_data) starts out empty.
+}
+
+void WireEncoder::WriteDouble(double val) {
+  // Appends the 64-bit pattern of val as 8 bytes, most significant byte
+  // first (network byte order), mirroring WireDecoder's ReadDouble.
+  // Explicit shifts keep the output independent of host byte order.
+  const std::uint64_t bits = llvm::DoubleToBits(val);
+
+  char bytes[8];
+  for (int i = 0; i < 8; ++i) {
+    // Byte i carries bits 63-8i down to 56-8i of the pattern.
+    bytes[i] = static_cast<char>((bits >> (56 - 8 * i)) & 0xff);
+  }
+
+  m_data.append(bytes, bytes + 8);
+}
+
+void WireEncoder::WriteUleb128(unsigned long val) {
+  (void)nt::WriteUleb128(m_data, val);  // append the encoding; byte count unused
+}
+
+void WireEncoder::WriteType(NT_Type type) {
+  // Appends the one-byte wire code for type to the output buffer.
+  //
+  // Raw and RPC types exist only from protocol 3.0 on. For older protocols,
+  // and for types with no wire code, nothing is written and m_error is set
+  // to a description of the problem.
+  const bool pre_v3 = m_proto_rev < 0x0300u;
+  char code;
+  switch (type) {
+    // Scalar types.
+    case NT_BOOLEAN: code = 0x00; break;
+    case NT_DOUBLE: code = 0x01; break;
+    case NT_STRING: code = 0x02; break;
+    case NT_RAW:
+      if (pre_v3) {
+        m_error = "raw type not supported in protocol < 3.0";
+        return;
+      }
+      code = 0x03;
+      break;
+
+    // Array types.
+    case NT_BOOLEAN_ARRAY: code = 0x10; break;
+    case NT_DOUBLE_ARRAY: code = 0x11; break;
+    case NT_STRING_ARRAY: code = 0x12; break;
+
+    // RPC type.
+    case NT_RPC:
+      if (pre_v3) {
+        m_error = "RPC type not supported in protocol < 3.0";
+        return;
+      }
+      code = 0x20;
+      break;
+
+    default:
+      m_error = "unrecognized type";
+      return;
+  }
+
+  m_data.push_back(code);
+}
+
+std::size_t WireEncoder::GetValueSize(const Value& value) const {
+  // Returns the number of bytes WriteValue() would append for value. Arrays
+  // carry a one-byte element count, so at most 255 elements are counted.
+  auto clamp_count = [](std::size_t n) -> std::size_t {
+    return n > 0xff ? 0xff : n;
+  };
+  const bool pre_v3 = m_proto_rev < 0x0300u;
+
+  switch (value.type()) {
+    case NT_BOOLEAN:
+      return 1;
+    case NT_DOUBLE:
+      return 8;
+    case NT_STRING:
+      return GetStringSize(value.GetString());
+    case NT_RAW:
+      // Raw and RPC values are not written at all before protocol 3.0.
+      return pre_v3 ? 0 : GetStringSize(value.GetRaw());
+    case NT_RPC:
+      return pre_v3 ? 0 : GetStringSize(value.GetRpc());
+    case NT_BOOLEAN_ARRAY:
+      // count byte + one byte per element
+      return 1 + clamp_count(value.GetBooleanArray().size());
+    case NT_DOUBLE_ARRAY:
+      // count byte + eight bytes per element
+      return 1 + clamp_count(value.GetDoubleArray().size()) * 8;
+    case NT_STRING_ARRAY: {
+      // count byte + each element's own length-prefixed encoding
+      auto strings = value.GetStringArray();
+      const std::size_t count = clamp_count(strings.size());
+      std::size_t total = 1;
+      for (std::size_t i = 0; i < count; ++i)
+        total += GetStringSize(strings[i]);
+      return total;
+    }
+    default:
+      return 0;
+  }
+}
+
+void WireEncoder::WriteValue(const Value& value) {
+ switch (value.type()) { // writes the payload only (no type byte)
+ case NT_BOOLEAN:
+ Write8(value.GetBoolean() ? 1 : 0); // one byte, 1 or 0
+ break;
+ case NT_DOUBLE:
+ WriteDouble(value.GetDouble()); // 8 bytes, big-endian
+ break;
+ case NT_STRING:
+ WriteString(value.GetString());
+ break;
+ case NT_RAW:
+ if (m_proto_rev < 0x0300u) { // raw values only exist from protocol 3.0 on
+ m_error = "raw values not supported in protocol < 3.0";
+ return;
+ }
+ WriteString(value.GetRaw());
+ break;
+ case NT_RPC:
+ if (m_proto_rev < 0x0300u) { // RPC values only exist from protocol 3.0 on
+ m_error = "RPC values not supported in protocol < 3.0";
+ return;
+ }
+ WriteString(value.GetRpc());
+ break;
+ case NT_BOOLEAN_ARRAY: {
+ auto v = value.GetBooleanArray();
+ std::size_t size = v.size();
+ if (size > 0xff) size = 0xff; // count is 1 byte: elements past 255 are dropped
+ Write8(size);
+ // elements: one byte each, 1 or 0
+ for (std::size_t i = 0; i < size; ++i)
+ Write8(v[i] ? 1 : 0);
+ break;
+ }
+ case NT_DOUBLE_ARRAY: {
+ auto v = value.GetDoubleArray();
+ std::size_t size = v.size();
+ if (size > 0xff) size = 0xff; // count is 1 byte: elements past 255 are dropped
+ Write8(size);
+ // elements: 8 bytes each
+ for (std::size_t i = 0; i < size; ++i)
+ WriteDouble(v[i]);
+ break;
+ }
+ case NT_STRING_ARRAY: {
+ auto v = value.GetStringArray();
+ std::size_t size = v.size();
+ if (size > 0xff) size = 0xff; // count is 1 byte: elements past 255 are dropped
+ Write8(size);
+ // elements: each a length-prefixed string
+ for (std::size_t i = 0; i < size; ++i)
+ WriteString(v[i]);
+ break;
+ }
+ default: // e.g. NT_UNASSIGNED: nothing is written
+ m_error = "unrecognized type when writing value";
+ return;
+ }
+}
+
+std::size_t WireEncoder::GetStringSize(llvm::StringRef str) const {
+  // Protocol 3.0+: ULEB128 length prefix followed by all of str's bytes.
+  if (m_proto_rev >= 0x0300u) return SizeUleb128(str.size()) + str.size();
+  // Older protocols: 2-byte length prefix, contents capped at 0xffff bytes.
+  std::size_t len = str.size();
+  if (len > 0xffff) len = 0xffff;
+  return 2 + len;
+}
+
+void WireEncoder::WriteString(llvm::StringRef str) {
+  // Length prefix first: 16-bit before protocol 3.0 (truncating the string to
+  // 0xffff bytes), ULEB128 from 3.0 on. Then the (possibly truncated) bytes.
+  std::size_t len = str.size();
+  if (m_proto_rev >= 0x0300u) {
+    WriteUleb128(len);
+  } else {
+    if (len > 0xffff) len = 0xffff;
+    Write16(len);
+  }
+  m_data.append(str.data(), str.data() + len);
+}
diff --git a/src/WireEncoder.h b/src/WireEncoder.h
new file mode 100644
index 0000000..40a4ca0
--- /dev/null
+++ b/src/WireEncoder.h
@@ -0,0 +1,105 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef NT_WIREENCODER_H_
+#define NT_WIREENCODER_H_
+
+#include <cassert>
+#include <cstddef>
+
+#include "llvm/SmallVector.h"
+#include "llvm/StringRef.h"
+#include "nt_Value.h"
+
+namespace nt {
+
+/* Encodes native data for network transmission.
+ * Writes are appended to an internal memory buffer so that the data can be
+ * sent to the network in one burst after a number of writes have been
+ * performed; no operation does I/O, so none block. Errors: see error().
+ */
+class WireEncoder {
+ public:
+ explicit WireEncoder(unsigned int proto_rev);
+
+ /* Change the protocol revision (affects string length prefixes and raw/RPC support). */
+ void set_proto_rev(unsigned int proto_rev) { m_proto_rev = proto_rev; }
+
+ /* Get the active protocol revision. */
+ unsigned int proto_rev() const { return m_proto_rev; }
+
+ /* Clears buffer and error indicator. */
+ void Reset() {
+ m_data.clear();
+ m_error = nullptr;
+ }
+
+ /* Returns error indicator (a string describing the error). Returns nullptr
+ * if no error has occurred.
+ */
+ const char* error() const { return m_error; }
+
+ /* Returns pointer to start of memory buffer with written data. */
+ const char* data() const { return m_data.data(); }
+
+ /* Returns number of bytes written to memory buffer. */
+ std::size_t size() const { return m_data.size(); }
+ /* Returns the written bytes as a non-owning view; later writes or Reset() may invalidate it. */
+ llvm::StringRef ToStringRef() const {
+ return llvm::StringRef(m_data.data(), m_data.size());
+ }
+
+ /* Writes the low byte of val. */
+ void Write8(unsigned int val) { m_data.push_back((char)(val & 0xff)); }
+
+ /* Writes the low 16 bits of val, most significant byte first. */
+ void Write16(unsigned int val) {
+ m_data.append({(char)((val >> 8) & 0xff), (char)(val & 0xff)});
+ }
+
+ /* Writes the low 32 bits of val, most significant byte first. */
+ void Write32(unsigned long val) {
+ m_data.append({(char)((val >> 24) & 0xff),
+ (char)((val >> 16) & 0xff),
+ (char)((val >> 8) & 0xff),
+ (char)(val & 0xff)});
+ }
+
+ /* Writes a double as 8 bytes, most significant byte first. */
+ void WriteDouble(double val);
+
+ /* Writes a ULEB128-encoded unsigned integer. */
+ void WriteUleb128(unsigned long val);
+ /* WriteType()/WriteValue() write nothing and set error() for unsupported types. */
+ void WriteType(NT_Type type);
+ void WriteValue(const Value& value);
+ void WriteString(llvm::StringRef str);
+
+ /* Returns how many bytes WriteValue() would append for value, without
+ * writing anything.
+ */
+ std::size_t GetValueSize(const Value& value) const;
+
+ /* Returns how many bytes WriteString() would append for str, without
+ * writing anything.
+ */
+ std::size_t GetStringSize(llvm::StringRef str) const;
+
+ protected:
+ /* The protocol revision. E.g. 0x0200 for version 2.0. */
+ unsigned int m_proto_rev;
+
+ /* Error indicator; nullptr when no error has occurred. */
+ const char* m_error;
+
+ private:
+ llvm::SmallVector<char, 256> m_data; // encoded output; Reset() clears it
+};
+
+} // namespace nt
+
+#endif // NT_WIREENCODER_H_
diff --git a/src/atomic_static.h b/src/atomic_static.h
new file mode 100644
index 0000000..b00ccda
--- /dev/null
+++ b/src/atomic_static.h
@@ -0,0 +1,49 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef NT_ATOMIC_STATIC_H_
+#define NT_ATOMIC_STATIC_H_
+
+#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
+// Compilers with C++11 "magic statics": ATOMIC_STATIC is a plain local static.
+// Just use a local static. This is thread-safe per
+// http://preshing.com/20130930/double-checked-locking-is-fixed-in-cpp11/
+// (ATOMIC_STATIC_DECL and ATOMIC_STATIC_INIT expand to nothing in this branch.)
+// Per https://msdn.microsoft.com/en-us/library/Hh567368.aspx "Magic Statics"
+// are supported in Visual Studio 2015 but not in earlier versions.
+#define ATOMIC_STATIC(cls, inst) static cls inst
+#define ATOMIC_STATIC_DECL(cls)
+#define ATOMIC_STATIC_INIT(cls)
+
+#else
+// From http://preshing.com/20130930/double-checked-locking-is-fixed-in-cpp11/
+#include <atomic>
+#include <mutex>
+// Double-checked locking; use inside a member function of cls (names m_instance unqualified).
+#define ATOMIC_STATIC(cls, inst) \
+ cls* inst##tmp = m_instance.load(std::memory_order_acquire); \
+ if (inst##tmp == nullptr) { \
+ std::lock_guard<std::mutex> lock(m_instance_mutex); \
+ inst##tmp = m_instance.load(std::memory_order_relaxed); \
+ if (inst##tmp == nullptr) { \
+ inst##tmp = new cls; \
+ m_instance.store(inst##tmp, std::memory_order_release); \
+ } \
+ } \
+ cls& inst = *inst##tmp
+// Place inside the body of cls: declares the static members ATOMIC_STATIC uses.
+#define ATOMIC_STATIC_DECL(cls) \
+ static std::atomic<cls*> m_instance; \
+ static std::mutex m_instance_mutex;
+// Place in exactly one .cpp file: defines those static members.
+#define ATOMIC_STATIC_INIT(cls) \
+ std::atomic<cls*> cls::m_instance; \
+ std::mutex cls::m_instance_mutex;
+// NOTE: in this fallback the instance is allocated with new and never deleted.
+#endif
+
+#endif // NT_ATOMIC_STATIC_H_
diff --git a/src/leb128.cpp b/src/leb128.cpp
new file mode 100644
index 0000000..3e99842
--- /dev/null
+++ b/src/leb128.cpp
@@ -0,0 +1,119 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "leb128.h"
+
+#include "raw_istream.h"
+
+namespace nt {
+
+/**
+ * Get size of unsigned LEB128 data
+ * @val: value
+ *
+ * The encoding spends one byte per 7-bit group of val, and always at least
+ * one byte (0 encodes as the single byte 0x00). The algorithm is taken from
+ * Appendix C of the DWARF 3 spec; for the encoding itself see section
+ * "7.6 - Variable Length Data".
+ * Return the number of bytes required.
+ */
+std::size_t SizeUleb128(unsigned long val) {
+  // The lowest 7-bit group always needs a byte, even when val is zero.
+  std::size_t bytes = 1;
+  // Every further non-zero remainder after dropping 7 bits costs one byte.
+  for (val >>= 7; val != 0; val >>= 7) ++bytes;
+  return bytes;
+}
+
+/**
+ * Write unsigned LEB128 data
+ * @dest: buffer the encoded bytes are appended to
+ * @val: value to be stored
+ *
+ * Encode an unsigned LEB128 datum: 7 bits per byte, least-significant group
+ * first, with the high bit set on every byte except the last. The algorithm
+ * is taken from Appendix C of the DWARF 3 spec (section "7.6 - Variable
+ * Length Data"). Return the number of bytes appended.
+ */
+std::size_t WriteUleb128(llvm::SmallVectorImpl<char>& dest, unsigned long val) {
+  std::size_t written = 0;
+
+  for (;;) {
+    unsigned char group = static_cast<unsigned char>(val & 0x7f);
+    val >>= 7;
+    const bool more = (val != 0);
+    if (more) group |= 0x80;  // continuation bit: another byte follows
+
+    dest.push_back(static_cast<char>(group));
+    ++written;
+    if (!more) break;
+  }
+
+  return written;
+}
+
+/**
+ * Read unsigned LEB128 data
+ * @addr: the address where the ULEB128 data is stored
+ * @ret: address to store the result
+ *
+ * Decode an unsigned LEB128 encoded datum. The algorithm is taken
+ * from Appendix C of the DWARF 3 spec. For information on the
+ * encodings refer to section "7.6 - Variable Length Data". Return
+ * the number of bytes read. No bounds checking is done: addr must hold a
+ * complete encoding. Bits beyond the width of unsigned long are discarded.
+ */
+std::size_t ReadUleb128(const char* addr, unsigned long* ret) {
+  unsigned long result = 0;
+  int shift = 0;
+  std::size_t count = 0;
+  const int kBits = static_cast<int>(sizeof(result) * 8);
+  while (1) {
+    unsigned char byte = *reinterpret_cast<const unsigned char*>(addr);
+    addr++;
+    count++;
+    // Widen first (an int shifted by 25+ bits overflows); drop groups past kBits.
+    if (shift < kBits) {
+      result |= static_cast<unsigned long>(byte & 0x7f) << shift;
+      shift += 7;
+    }
+    if (!(byte & 0x80)) break;
+  }
+  *ret = result;
+  return count;
+}
+
+/**
+ * Read unsigned LEB128 data from a stream
+ * @is: the input stream where the ULEB128 data is to be read from
+ * @ret: address to store the result
+ *
+ * Decode an unsigned LEB128 encoded datum. The algorithm is taken
+ * from Appendix C of the DWARF 3 spec. For information on the
+ * encodings refer to section "7.6 - Variable Length Data". Return
+ * false on stream error, true on success. Bits beyond the width of
+ * unsigned long are discarded.
+ */
+bool ReadUleb128(raw_istream& is, unsigned long* ret) {
+  unsigned long result = 0;
+  int shift = 0;
+  const int kBits = static_cast<int>(sizeof(result) * 8);
+  while (1) {
+    unsigned char byte;
+    if (!is.read(reinterpret_cast<char*>(&byte), 1)) return false;
+    // Widen first (an int shifted by 25+ bits overflows); drop groups past kBits.
+    if (shift < kBits) {
+      result |= static_cast<unsigned long>(byte & 0x7f) << shift;
+      shift += 7;
+    }
+    if (!(byte & 0x80)) break;
+  }
+  *ret = result;
+  return true;
+}
+
+} // namespace nt
diff --git a/src/leb128.h b/src/leb128.h
new file mode 100644
index 0000000..73ca245
--- /dev/null
+++ b/src/leb128.h
@@ -0,0 +1,26 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef NT_LEB128_H_
+#define NT_LEB128_H_
+
+#include <cstddef>
+
+#include "llvm/SmallVector.h"
+
+namespace nt {
+
+class raw_istream;
+// Unsigned LEB128: 7 payload bits per byte, high bit set when more bytes follow.
+std::size_t SizeUleb128(unsigned long val); // encoded length in bytes
+std::size_t WriteUleb128(llvm::SmallVectorImpl<char>& dest, unsigned long val); // appends; returns bytes written
+std::size_t ReadUleb128(const char* addr, unsigned long* ret); // no bounds check; returns bytes read
+bool ReadUleb128(raw_istream& is, unsigned long* ret); // false on stream error
+
+} // namespace nt
+
+#endif // NT_LEB128_H_
diff --git a/src/llvm/SmallPtrSet.cpp b/src/llvm/SmallPtrSet.cpp
new file mode 100644
index 0000000..d23599a
--- /dev/null
+++ b/src/llvm/SmallPtrSet.cpp
@@ -0,0 +1,338 @@
+//===- llvm/ADT/SmallPtrSet.cpp - 'Normally small' pointer set ------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the SmallPtrSet class. See SmallPtrSet.h for an
+// overview of the algorithm.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/SmallPtrSet.h"
+#include "llvm/DenseMapInfo.h"
+#include "llvm/MathExtras.h"
+#include <algorithm>
+#include <cstdlib>
+
+using namespace llvm;
+
+void SmallPtrSetImplBase::shrink_and_clear() {
+ assert(!isSmall() && "Can't shrink a small set!");
+ free(CurArray);
+ // Empty the set and size a fresh bucket array from the old element count:
+ // 2^(ceil(log2(NumElements)) + 1) buckets, or 32 when NumElements <= 16.
+ CurArraySize = NumElements > 16 ? 1 << (Log2_32_Ceil(NumElements) + 1) : 32;
+ NumElements = NumTombstones = 0;
+ // Install the new array and mark every bucket empty (all-ones bytes, presumably
+ // getEmptyMarker()'s pattern). Allocation failure is only caught by the assert.
+ CurArray = (const void**)malloc(sizeof(void*) * CurArraySize);
+ assert(CurArray && "Failed to allocate memory?");
+ memset(CurArray, -1, CurArraySize*sizeof(void*));
+}
+
+std::pair<const void *const *, bool>
+SmallPtrSetImplBase::insert_imp(const void *Ptr) {
+ if (isSmall()) { // small mode: elements packed at the front of SmallArray
+ // Check to see if it is already in the set.
+ for (const void **APtr = SmallArray, **E = SmallArray+NumElements;
+ APtr != E; ++APtr)
+ if (*APtr == Ptr)
+ return std::make_pair(APtr, false);
+
+ // Not present. If the small array still has room, append it.
+ if (NumElements < CurArraySize) {
+ SmallArray[NumElements++] = Ptr;
+ return std::make_pair(SmallArray + (NumElements - 1), true);
+ }
+ // Small array is full: fall through; the load check below will call Grow().
+ }
+ // Keep empty buckets available so FindBucketFor()'s probe loop terminates.
+ if (LLVM_UNLIKELY(NumElements * 4 >= CurArraySize * 3)) {
+ // At least 3/4 full: grow to 2x the size, or to 128 buckets if under 64.
+ Grow(CurArraySize < 64 ? 128 : CurArraySize*2);
+ } else if (LLVM_UNLIKELY(CurArraySize - (NumElements + NumTombstones) <
+ CurArraySize / 8)) {
+ // If fewer than 1/8 of the buckets are empty (meaning many hold
+ // tombstones), rehash at the same size to clear the tombstones.
+ Grow(CurArraySize);
+ }
+
+ // There is room now. Find Ptr's bucket, or the slot where it belongs.
+ const void **Bucket = const_cast<const void**>(FindBucketFor(Ptr));
+ if (*Bucket == Ptr)
+ return std::make_pair(Bucket, false); // Already inserted, good.
+
+ // Not present: insert it, reusing the slot if it held a tombstone.
+ if (*Bucket == getTombstoneMarker())
+ --NumTombstones;
+ *Bucket = Ptr;
+ ++NumElements; // Track density.
+ return std::make_pair(Bucket, true);
+}
+
+bool SmallPtrSetImplBase::erase_imp(const void * Ptr) {
+ if (isSmall()) {
+ // Check to see if it is in the set.
+ for (const void **APtr = SmallArray, **E = SmallArray+NumElements;
+ APtr != E; ++APtr)
+ if (*APtr == Ptr) {
+ // Found: move the last element into this slot and mark the last slot empty.
+ *APtr = E[-1];
+ E[-1] = getEmptyMarker();
+ --NumElements;
+ return true;
+ }
+
+ return false; // not in the small set
+ }
+
+ // Large mode: look Ptr up in the hash table.
+ void **Bucket = const_cast<void**>(FindBucketFor(Ptr));
+ if (*Bucket != Ptr) return false; // Not in the set.
+
+ // Leave a tombstone (not an empty marker) so other probe chains stay intact.
+ *Bucket = getTombstoneMarker();
+ --NumElements;
+ ++NumTombstones;
+ return true;
+}
+
+const void * const *SmallPtrSetImplBase::FindBucketFor(const void *Ptr) const {
+ unsigned Bucket = DenseMapInfo<void *>::getHashValue(Ptr) & (CurArraySize-1); // mask assumes a power-of-two size
+ unsigned ArraySize = CurArraySize;
+ unsigned ProbeAmt = 1; // probe step grows 1, 2, 3, ... each iteration
+ const void *const *Array = CurArray;
+ const void *const *Tombstone = nullptr;
+ while (1) { // ends only at Ptr or an empty bucket, so the table must never be full
+ // If we found an empty bucket, the pointer doesn't exist in the set.
+ // Return a tombstone if we've seen one so far, or the empty bucket if
+ // not.
+ if (LLVM_LIKELY(Array[Bucket] == getEmptyMarker()))
+ return Tombstone ? Tombstone : Array+Bucket;
+
+ // Found Ptr's bucket?
+ if (LLVM_LIKELY(Array[Bucket] == Ptr))
+ return Array+Bucket;
+
+ // If this is a tombstone, remember it. If Ptr ends up not in the set, we
+ // prefer to return it rather than a slot that would need more probing.
+ if (Array[Bucket] == getTombstoneMarker() && !Tombstone)
+ Tombstone = Array+Bucket; // Remember the first tombstone found.
+
+ // It's a hash collision or a tombstone. Reprobe.
+ Bucket = (Bucket + ProbeAmt++) & (ArraySize-1);
+ }
+}
+
+/// Grow - Allocate a hash table of NewSize buckets and rehash every element
+/// into it. NewSize may equal the current size (insert_imp does this to purge tombstones).
+void SmallPtrSetImplBase::Grow(unsigned NewSize) {
+ // Remember the old storage; the caller chooses NewSize (see insert_imp).
+ unsigned OldSize = CurArraySize;
+
+ const void **OldBuckets = CurArray;
+ bool WasSmall = isSmall();
+
+ // Install the new array. Clear all the buckets to empty.
+ CurArray = (const void**)malloc(sizeof(void*) * NewSize);
+ assert(CurArray && "Failed to allocate memory?");
+ CurArraySize = NewSize;
+ memset(CurArray, -1, NewSize*sizeof(void*));
+
+ // Copy over all the elements.
+ if (WasSmall) {
+ // Small sets keep their elements packed at the front of the array.
+ for (const void **BucketPtr = OldBuckets, **E = OldBuckets+NumElements;
+ BucketPtr != E; ++BucketPtr) {
+ const void *Elt = *BucketPtr;
+ *const_cast<void**>(FindBucketFor(Elt)) = const_cast<void*>(Elt);
+ }
+ } else {
+ // Copy over all valid entries.
+ for (const void **BucketPtr = OldBuckets, **E = OldBuckets+OldSize;
+ BucketPtr != E; ++BucketPtr) {
+ // Copy over the element if it is valid.
+ const void *Elt = *BucketPtr;
+ if (Elt != getTombstoneMarker() && Elt != getEmptyMarker())
+ *const_cast<void**>(FindBucketFor(Elt)) = const_cast<void*>(Elt);
+ }
+
+ free(OldBuckets); // only the old heap array is freed; the small array is not
+ NumTombstones = 0; // rehashing dropped every tombstone
+ }
+}
+
+SmallPtrSetImplBase::SmallPtrSetImplBase(const void **SmallStorage,
+ const SmallPtrSetImplBase& that) {
+ SmallArray = SmallStorage;
+ // Copy constructor: duplicates that's bucket array, markers included.
+ // If that is small, copy into our own small storage.
+ if (that.isSmall()) {
+ CurArray = SmallArray;
+ // Otherwise allocate a heap array of the same size as that's.
+ } else {
+ CurArray = (const void**)malloc(sizeof(void*) * that.CurArraySize);
+ assert(CurArray && "Failed to allocate memory?");
+ }
+
+ // Copy over the new array size
+ CurArraySize = that.CurArraySize;
+
+ // Copy over the contents from the other set
+ memcpy(CurArray, that.CurArray, sizeof(void*)*CurArraySize);
+
+ NumElements = that.NumElements;
+ NumTombstones = that.NumTombstones;
+}
+
+SmallPtrSetImplBase::SmallPtrSetImplBase(const void **SmallStorage,
+ unsigned SmallSize,
+ SmallPtrSetImplBase &&that) {
+ SmallArray = SmallStorage;
+ // Move constructor: take over that's contents, then reset that to small mode.
+ // Copy over the basic members.
+ CurArraySize = that.CurArraySize;
+ NumElements = that.NumElements;
+ NumTombstones = that.NumTombstones;
+
+ // When small, just copy into our small buffer.
+ if (that.isSmall()) {
+ CurArray = SmallArray;
+ memcpy(CurArray, that.CurArray, sizeof(void *) * CurArraySize);
+ } else {
+ // Otherwise, we steal the large memory allocation and no copy is needed.
+ CurArray = that.CurArray;
+ that.CurArray = that.SmallArray;
+ }
+ // NOTE(review): only that's size/counts are reset; that.SmallArray's old contents are not cleared.
+ // Make the "that" object small and empty.
+ that.CurArraySize = SmallSize;
+ assert(that.CurArray == that.SmallArray);
+ that.NumElements = 0;
+ that.NumTombstones = 0;
+}
+
+/// CopyFrom - implement operator= from a SmallPtrSet that has the same
+/// pointer type, but may have a different small size.
+///
+/// Reuses our current heap array when it already has RHS's size, otherwise
+/// switches to inline storage (RHS small) or (re)allocates a heap array, and
+/// then copies RHS's bucket array verbatim. Allocation failure aborts rather
+/// than letting the memcpy below write through a null pointer.
+void SmallPtrSetImplBase::CopyFrom(const SmallPtrSetImplBase &RHS) {
+  assert(&RHS != this && "Self-copy should be handled by the caller.");
+
+  if (isSmall() && RHS.isSmall())
+    assert(CurArraySize == RHS.CurArraySize &&
+           "Cannot assign sets with different small sizes");
+
+  if (RHS.isSmall()) {
+    // Becoming small: release any heap array and use the inline storage.
+    if (!isSmall())
+      free(CurArray);
+    CurArray = SmallArray;
+  } else if (CurArraySize != RHS.CurArraySize) {
+    // Becoming large with a different size: allocate or resize the heap array.
+    if (isSmall()) {
+      CurArray = (const void**)malloc(sizeof(void*) * RHS.CurArraySize);
+    } else {
+      const void **T = (const void**)realloc(CurArray,
+                                             sizeof(void*) * RHS.CurArraySize);
+      // On failure realloc leaves the old block allocated; release it.
+      if (!T)
+        free(CurArray);
+      CurArray = T;
+    }
+    assert(CurArray && "Failed to allocate memory?");
+    // assert() is compiled out under NDEBUG; fail fast instead.
+    if (!CurArray)
+      abort();
+  }
+
+  // Copy over the new array size
+  CurArraySize = RHS.CurArraySize;
+
+  // Copy over the contents from the other set
+  memcpy(CurArray, RHS.CurArray, sizeof(void*)*CurArraySize);
+
+  NumElements = RHS.NumElements;
+  NumTombstones = RHS.NumTombstones;
+}
+
+/// MoveFrom - Replace this set's contents with those of \p RHS.
+///
+/// Any heap array we own is released first. RHS's heap array is then
+/// stolen, or, if RHS is small, its inline buckets are copied into ours.
+/// RHS is left empty and small with capacity \p SmallSize.
+void SmallPtrSetImplBase::MoveFrom(unsigned SmallSize,
+                                   SmallPtrSetImplBase &&RHS) {
+  assert(&RHS != this && "Self-move should be handled by the caller.");
+
+  // Everything we hold is about to be replaced; drop our heap buffer.
+  if (!isSmall())
+    free(CurArray);
+
+  if (!RHS.isSmall()) {
+    CurArray = RHS.CurArray;
+    RHS.CurArray = RHS.SmallArray;
+  } else {
+    // Inline storage cannot be stolen; copy its buckets instead.
+    CurArray = SmallArray;
+    memcpy(CurArray, RHS.CurArray, sizeof(void*) * RHS.CurArraySize);
+  }
+
+  CurArraySize = RHS.CurArraySize;
+  NumElements = RHS.NumElements;
+  NumTombstones = RHS.NumTombstones;
+
+  RHS.CurArraySize = SmallSize;
+  RHS.NumElements = 0;
+  RHS.NumTombstones = 0;
+  assert(RHS.CurArray == RHS.SmallArray);
+}
+
+/// swap - Exchange the contents of this set and \p RHS.
+///
+/// Two large sets just trade heap pointers and counters. If exactly one side
+/// is small, its inline buckets are copied into the other side's inline
+/// storage and the heap pointer moves across. Two small sets swap their
+/// inline buckets element by element. The mixed and both-small cases assume
+/// both sets share the same small size (see the FIXME below).
+void SmallPtrSetImplBase::swap(SmallPtrSetImplBase &RHS) {
+ if (this == &RHS) return;
+
+ // We can only avoid copying elements if neither set is small.
+ if (!this->isSmall() && !RHS.isSmall()) {
+ std::swap(this->CurArray, RHS.CurArray);
+ std::swap(this->CurArraySize, RHS.CurArraySize);
+ std::swap(this->NumElements, RHS.NumElements);
+ std::swap(this->NumTombstones, RHS.NumTombstones);
+ return;
+ }
+
+ // FIXME: From here on we assume that both sets have the same small size.
+
+ // If only RHS is small, copy the small elements into LHS and move the pointer
+ // from LHS to RHS.
+ if (!this->isSmall() && RHS.isSmall()) {
+ // Copies RHS.CurArraySize buckets into our inline storage; relies on the
+ // same-small-size assumption above to stay in bounds.
+ std::copy(RHS.SmallArray, RHS.SmallArray+RHS.CurArraySize,
+ this->SmallArray);
+ std::swap(this->NumElements, RHS.NumElements);
+ std::swap(this->CurArraySize, RHS.CurArraySize);
+ // RHS now owns our heap array (and its tombstone count); we become small.
+ RHS.CurArray = this->CurArray;
+ RHS.NumTombstones = this->NumTombstones;
+ this->CurArray = this->SmallArray;
+ this->NumTombstones = 0;
+ return;
+ }
+
+ // If only LHS is small, copy the small elements into RHS and move the pointer
+ // from RHS to LHS.
+ if (this->isSmall() && !RHS.isSmall()) {
+ std::copy(this->SmallArray, this->SmallArray+this->CurArraySize,
+ RHS.SmallArray);
+ std::swap(RHS.NumElements, this->NumElements);
+ std::swap(RHS.CurArraySize, this->CurArraySize);
+ // We take RHS's heap array (and its tombstone count); RHS becomes small.
+ this->CurArray = RHS.CurArray;
+ this->NumTombstones = RHS.NumTombstones;
+ RHS.CurArray = RHS.SmallArray;
+ RHS.NumTombstones = 0;
+ return;
+ }
+
+ // Both are small, just swap the small elements.
+ // NOTE(review): NumTombstones is not swapped here; presumably small-mode
+ // sets never hold tombstones -- confirm against erase().
+ assert(this->isSmall() && RHS.isSmall());
+ assert(this->CurArraySize == RHS.CurArraySize);
+ std::swap_ranges(this->SmallArray, this->SmallArray+this->CurArraySize,
+ RHS.SmallArray);
+ std::swap(this->NumElements, RHS.NumElements);
+}
+
+/// Release the heap-allocated bucket array, if any. Inline storage
+/// (SmallArray) was supplied by the constructor's caller and is not freed.
+SmallPtrSetImplBase::~SmallPtrSetImplBase() {
+  if (isSmall())
+    return;
+  free(CurArray);
+}
diff --git a/src/llvm/SmallVector.cpp b/src/llvm/SmallVector.cpp
new file mode 100644
index 0000000..6aa709e
--- /dev/null
+++ b/src/llvm/SmallVector.cpp
@@ -0,0 +1,41 @@
+//===- llvm/ADT/SmallVector.cpp - 'Normally small' vectors ----------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the SmallVector class.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/SmallVector.h"
+using namespace llvm;
+
+/// grow_pod - This is an implementation of the grow() method which only works
+/// on POD-like datatypes and is out of line to reduce code duplication.
+///
+/// \param FirstEl        address of the inline buffer. If BeginX still points
+///                       there, the elements are copied into a fresh heap
+///                       block; otherwise the existing heap block is
+///                       realloc'ed in place.
+/// \param MinSizeInBytes minimum capacity the caller needs.
+/// \param TSize          size of one element; added to the doubled capacity so
+///                       a zero-capacity vector still grows.
+///
+/// Allocation failure aborts, instead of continuing with a null buffer once
+/// assert() is compiled out.
+void SmallVectorBase::grow_pod(void *FirstEl, size_t MinSizeInBytes,
+                               size_t TSize) {
+  size_t CurSizeBytes = size_in_bytes();
+  size_t NewCapacityInBytes = 2 * capacity_in_bytes() + TSize; // Always grow.
+  if (NewCapacityInBytes < MinSizeInBytes)
+    NewCapacityInBytes = MinSizeInBytes;
+
+  // Growing out of the inline buffer needs a new block; a heap block can be
+  // resized in place.
+  const bool WasInline = (this->BeginX == FirstEl);
+  void *NewElts = WasInline ? malloc(NewCapacityInBytes)
+                            : realloc(this->BeginX, NewCapacityInBytes);
+  assert(NewElts && "Out of memory");
+  if (!NewElts)
+    abort();
+
+  // Copy the elements over. No need to run dtors on PODs.
+  if (WasInline)
+    memcpy(NewElts, this->BeginX, CurSizeBytes);
+
+  this->EndX = (char*)NewElts+CurSizeBytes;
+  this->BeginX = NewElts;
+  this->CapacityX = (char*)this->BeginX + NewCapacityInBytes;
+}
diff --git a/src/llvm/StringExtras.cpp b/src/llvm/StringExtras.cpp
new file mode 100644
index 0000000..74b47a5
--- /dev/null
+++ b/src/llvm/StringExtras.cpp
@@ -0,0 +1,58 @@
+//===-- StringExtras.cpp - Implement the StringExtras header --------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the StringExtras.h header
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/StringExtras.h"
+#include "llvm/SmallVector.h"
+using namespace llvm;
+
+/// StrInStrNoCase - Portable version of strcasestr. Locates the first
+/// occurrence of string 's2' in string 's1', ignoring case (via
+/// StringRef::equals_lower). Returns the offset of s2 in s1, or npos if s2
+/// cannot be found. An empty s2 is found at offset 0.
+StringRef::size_type llvm::StrInStrNoCase(StringRef s1, StringRef s2) {
+  size_t N = s2.size(), M = s1.size();
+  if (N > M)
+    return StringRef::npos;
+  // Try every start offset at which s2 still fits entirely inside s1.
+  for (size_t i = 0, e = M - N + 1; i != e; ++i)
+    if (s1.substr(i, N).equals_lower(s2))
+      return i;
+  return StringRef::npos;
+}
+
+/// getToken - Extract the first token from \p Source.
+///
+/// Leading characters that appear in \p Delimiters are skipped, and the token
+/// ends at the next delimiter character (or at the end of \p Source).
+/// Returns {token, remainder}, where the remainder starts at the delimiter
+/// that ended the token. If \p Source contains no token, the token is empty.
+std::pair<StringRef, StringRef> llvm::getToken(StringRef Source,
+                                               StringRef Delimiters) {
+  const StringRef::size_type TokenBegin = Source.find_first_not_of(Delimiters);
+  const StringRef::size_type TokenEnd =
+      Source.find_first_of(Delimiters, TokenBegin);
+
+  StringRef Token = Source.slice(TokenBegin, TokenEnd);
+  StringRef Rest = Source.substr(TokenEnd);
+  return std::make_pair(Token, Rest);
+}
+
+/// SplitString - Append to \p OutFragments every non-empty token of \p Source,
+/// where tokens are separated by runs of characters from \p Delimiters.
+void llvm::SplitString(StringRef Source,
+                       SmallVectorImpl<StringRef> &OutFragments,
+                       StringRef Delimiters) {
+  // getToken returns an empty token once the input is exhausted.
+  for (std::pair<StringRef, StringRef> S = getToken(Source, Delimiters);
+       !S.first.empty(); S = getToken(S.second, Delimiters))
+    OutFragments.push_back(S.first);
+}
diff --git a/src/llvm/StringMap.cpp b/src/llvm/StringMap.cpp
new file mode 100644
index 0000000..5649834
--- /dev/null
+++ b/src/llvm/StringMap.cpp
@@ -0,0 +1,244 @@
+//===--- StringMap.cpp - String Hash table map implementation -------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the StringMap class.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/StringMap.h"
+#include "llvm/StringExtras.h"
+//#include "llvm/Support/Compiler.h"
+#include <cassert>
+using namespace llvm;
+
+/// Construct an empty map whose entries have a fixed header of \p itemSize
+/// bytes (the key characters follow it). A non-zero \p InitSize allocates
+/// that many buckets up front; zero defers allocation to the first insertion.
+StringMapImpl::StringMapImpl(unsigned InitSize, unsigned itemSize) {
+  ItemSize = itemSize;
+
+  if (!InitSize) {
+    // No size requested: keep the table unallocated for now.
+    TheTable = nullptr;
+    NumBuckets = 0;
+    NumItems = 0;
+    NumTombstones = 0;
+    return;
+  }
+
+  init(InitSize);
+}
+
+/// init - Allocate an empty table of \p InitSize buckets (16 if zero).
+///
+/// \p InitSize must be zero or a power of two (asserted). The allocation
+/// holds NumBuckets+1 entry pointers (the extra one is a non-null sentinel)
+/// followed by NumBuckets+1 cached hash values, all zero-initialized.
+/// Allocation failure aborts rather than dereferencing null below.
+void StringMapImpl::init(unsigned InitSize) {
+  assert((InitSize & (InitSize-1)) == 0 &&
+         "Init Size must be a power of 2 or zero!");
+  NumBuckets = InitSize ? InitSize : 16;
+  NumItems = 0;
+  NumTombstones = 0;
+
+  // One entry pointer plus one hash value per slot; the hash array starts
+  // right after the pointer array (see TheTable + NumBuckets + 1 elsewhere).
+  TheTable = (StringMapEntryBase **)calloc(NumBuckets+1,
+                                           sizeof(StringMapEntryBase *) +
+                                           sizeof(unsigned));
+  if (!TheTable)
+    abort();
+
+  // Allocate one extra bucket, set it to look filled so the iterators stop at
+  // end.
+  TheTable[NumBuckets] = (StringMapEntryBase*)2;
+}
+
+
+/// LookupBucketFor - Look up the bucket that the specified string should end
+/// up in. If it already exists as a key in the map, the Item pointer for the
+/// specified bucket will be non-null. Otherwise, it will be null. In either
+/// case, the FullHashValue field of the bucket will be set to the hash value
+/// of the string.
+///
+/// NOTE(review): when the key is absent but a tombstone was passed during
+/// probing, the returned bucket holds the tombstone marker (non-null), not
+/// null; callers presumably test for getTombstoneVal() too -- confirm.
+/// Allocates a 16-bucket table on first use.
+unsigned StringMapImpl::LookupBucketFor(StringRef Name) {
+ unsigned HTSize = NumBuckets;
+ if (HTSize == 0) { // Hash table unallocated so far?
+ init(16);
+ HTSize = NumBuckets;
+ }
+ unsigned FullHashValue = HashString(Name);
+ // HTSize is a power of two, so masking is the modulo.
+ unsigned BucketNo = FullHashValue & (HTSize-1);
+ // Cached full hashes live right after the NumBuckets+1 entry pointers.
+ unsigned *HashTable = (unsigned *)(TheTable + NumBuckets + 1);
+
+ unsigned ProbeAmt = 1;
+ int FirstTombstone = -1;
+ while (1) {
+ StringMapEntryBase *BucketItem = TheTable[BucketNo];
+ // If we found an empty bucket, this key isn't in the table yet, return it.
+ if (!BucketItem) {
+ // If we found a tombstone, we want to reuse the tombstone instead of an
+ // empty bucket. This reduces probing.
+ if (FirstTombstone != -1) {
+ HashTable[FirstTombstone] = FullHashValue;
+ return FirstTombstone;
+ }
+
+ HashTable[BucketNo] = FullHashValue;
+ return BucketNo;
+ }
+
+ if (BucketItem == getTombstoneVal()) {
+ // Skip over tombstones. However, remember the first one we see.
+ if (FirstTombstone == -1) FirstTombstone = BucketNo;
+ } else if (HashTable[BucketNo] == FullHashValue) {
+ // If the full hash value matches, check deeply for a match. The common
+ // case here is that we are only looking at the buckets (for item info
+ // being non-null and for the full hash value) not at the items. This
+ // is important for cache locality.
+
+ // Do the comparison like this because Name isn't necessarily
+ // null-terminated! The key bytes are stored ItemSize bytes past the
+ // start of the entry.
+ char *ItemStr = (char*)BucketItem+ItemSize;
+ if (Name == StringRef(ItemStr, BucketItem->getKeyLength())) {
+ // We found a match!
+ return BucketNo;
+ }
+ }
+
+ // Okay, we didn't find the item. Probe to the next bucket.
+ BucketNo = (BucketNo+ProbeAmt) & (HTSize-1);
+
+ // Use quadratic probing, it has fewer clumping artifacts than linear
+ // probing and has good cache behavior in the common case. (The step grows
+ // by one each time: offsets 1, 3, 6, 10, ... from the home bucket.)
+ ++ProbeAmt;
+ }
+}
+
+
+/// FindKey - Return the bucket index holding \p Key, or -1 if the key is not
+/// in the map. Unlike LookupBucketFor this never allocates or writes to the
+/// table.
+int StringMapImpl::FindKey(StringRef Key) const {
+  if (NumBuckets == 0)
+    return -1; // Table never allocated, so it cannot contain anything.
+
+  const unsigned Mask = NumBuckets - 1;
+  const unsigned FullHashValue = HashString(Key);
+  const unsigned *HashTable = (unsigned *)(TheTable + NumBuckets + 1);
+
+  unsigned BucketNo = FullHashValue & Mask;
+  for (unsigned ProbeAmt = 1;; ++ProbeAmt) {
+    StringMapEntryBase *BucketItem = TheTable[BucketNo];
+    // An empty bucket ends the probe sequence: the key is absent.
+    if (!BucketItem)
+      return -1;
+
+    // Tombstones are skipped. For live entries, compare the cached full hash
+    // first so the entry itself is only touched on a likely match. The key
+    // is compared by length and bytes because Key need not be
+    // null-terminated.
+    if (BucketItem != getTombstoneVal() &&
+        HashTable[BucketNo] == FullHashValue) {
+      const char *ItemStr = (char*)BucketItem + ItemSize;
+      if (Key == StringRef(ItemStr, BucketItem->getKeyLength()))
+        return BucketNo;
+    }
+
+    // Same probe sequence as LookupBucketFor: the step grows by one each time.
+    BucketNo = (BucketNo + ProbeAmt) & Mask;
+  }
+}
+
+/// RemoveKey - Unlink entry \p V from the table without destroying it.
+/// \p V must currently be in the map; this is checked only by an assertion.
+void StringMapImpl::RemoveKey(StringMapEntryBase *V) {
+  // The key bytes follow the fixed-size entry header.
+  StringRef Key((char*)V + ItemSize, V->getKeyLength());
+  StringMapEntryBase *Removed = RemoveKey(Key);
+  (void)Removed; // Only inspected by the assertion below.
+  assert(V == Removed && "Didn't find key?");
+}
+
+/// RemoveKey - Remove the entry for \p Key from the table and return it,
+/// leaving a tombstone in its bucket. The entry itself is not destroyed.
+/// Returns null if the key is not present.
+StringMapEntryBase *StringMapImpl::RemoveKey(StringRef Key) {
+  const int Bucket = FindKey(Key);
+  if (Bucket == -1)
+    return nullptr;
+
+  StringMapEntryBase *Removed = TheTable[Bucket];
+  // A tombstone (not null) keeps probe chains through this bucket intact.
+  TheTable[Bucket] = getTombstoneVal();
+  ++NumTombstones;
+  --NumItems;
+  assert(NumItems + NumTombstones <= NumBuckets);
+  return Removed;
+}
+
+
+
+/// RehashTable - Grow the table, redistributing values into the buckets with
+/// the appropriate mod-of-hashtable-size.
+///
+/// The table doubles when more than 3/4 of the buckets hold live items, and
+/// is rebuilt at the same size (purging tombstones) when 1/8 or fewer of the
+/// buckets are empty. Otherwise nothing happens.
+///
+/// \param BucketNo a bucket index in the current table.
+/// \returns the index that bucket's entry occupies afterwards (\p BucketNo
+///          itself when no rehash was needed).
+///
+/// Allocation failure aborts rather than dereferencing null.
+unsigned StringMapImpl::RehashTable(unsigned BucketNo) {
+  unsigned NewSize;
+  unsigned *HashTable = (unsigned *)(TheTable + NumBuckets + 1);
+
+  if (NumItems * 4 > NumBuckets * 3) {
+    NewSize = NumBuckets*2;
+  } else if (NumBuckets - (NumItems + NumTombstones) <= NumBuckets / 8) {
+    NewSize = NumBuckets;
+  } else {
+    return BucketNo;
+  }
+
+  unsigned NewBucketNo = BucketNo;
+  // Allocate one extra bucket which will always be non-empty. This allows the
+  // iterators to stop at end. The hash array follows the pointer array.
+  StringMapEntryBase **NewTableArray =
+      (StringMapEntryBase **)calloc(NewSize+1, sizeof(StringMapEntryBase *) +
+                                               sizeof(unsigned));
+  if (!NewTableArray)
+    abort();
+  unsigned *NewHashArray = (unsigned *)(NewTableArray + NewSize + 1);
+  NewTableArray[NewSize] = (StringMapEntryBase*)2;
+
+  // Rehash all the items into their new buckets. Luckily :) we already have
+  // the hash values available, so we don't have to rehash any strings.
+  const unsigned NewMask = NewSize - 1;
+  for (unsigned I = 0, E = NumBuckets; I != E; ++I) {
+    StringMapEntryBase *Bucket = TheTable[I];
+    if (!Bucket || Bucket == getTombstoneVal())
+      continue;
+
+    // Start at the home bucket and probe with the same growing step as
+    // lookups until an empty slot turns up (the new table has no tombstones).
+    unsigned FullHash = HashTable[I];
+    unsigned NewBucket = FullHash & NewMask;
+    for (unsigned ProbeSize = 1; NewTableArray[NewBucket]; ++ProbeSize)
+      NewBucket = (NewBucket + ProbeSize) & NewMask;
+
+    NewTableArray[NewBucket] = Bucket;
+    NewHashArray[NewBucket] = FullHash;
+    if (I == BucketNo)
+      NewBucketNo = NewBucket;
+  }
+
+  free(TheTable);
+
+  TheTable = NewTableArray;
+  NumBuckets = NewSize;
+  NumTombstones = 0;
+  return NewBucketNo;
+}
diff --git a/src/llvm/StringRef.cpp b/src/llvm/StringRef.cpp
new file mode 100644
index 0000000..f12318c
--- /dev/null
+++ b/src/llvm/StringRef.cpp
@@ -0,0 +1,393 @@
+//===-- StringRef.cpp - Lightweight String References ---------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/StringRef.h"
+#include "llvm/SmallVector.h"
+#include <bitset>
+#include <climits>
+
+using namespace llvm;
+
+// Out-of-line definition of the static data member StringRef::npos, which
+// is needed if npos is odr-used (for example bound to a const reference).
+// MSVC emits references to this into the translation units which reference it,
+// so the definition is skipped there -- presumably to avoid a duplicate-symbol
+// error; TODO confirm.
+#ifndef _MSC_VER
+const size_t StringRef::npos;
+#endif
+
+/// Lowercase an ASCII letter; every other byte is returned unchanged.
+/// Unlike ::tolower, this does not depend on the current locale.
+static char ascii_tolower(char x) {
+  const bool IsUpper = 'A' <= x && x <= 'Z';
+  return IsUpper ? static_cast<char>(x + ('a' - 'A')) : x;
+}
+
+/// Uppercase an ASCII letter; every other byte is returned unchanged.
+static char ascii_toupper(char x) {
+  const bool IsLower = 'a' <= x && x <= 'z';
+  return IsLower ? static_cast<char>(x - ('a' - 'A')) : x;
+}
+
+/// True for the ASCII decimal digits '0'..'9' only.
+static bool ascii_isdigit(char x) {
+  return '0' <= x && x <= '9';
+}
+
+// strncasecmp() is not available on non-POSIX systems, so define a portable
+// ASCII-only replacement. Compares exactly Length bytes of LHS and RHS after
+// lowercasing each one; returns -1, 0 or 1, ordering bytes as unsigned char.
+static int ascii_strncasecmp(const char *LHS, const char *RHS, size_t Length) {
+  for (const char *const End = LHS + Length; LHS != End; ++LHS, ++RHS) {
+    const unsigned char L = ascii_tolower(*LHS);
+    const unsigned char R = ascii_tolower(*RHS);
+    if (L != R)
+      return L < R ? -1 : 1;
+  }
+  return 0;
+}
+
+/// compare_lower - Three-way compare with \p RHS, ignoring ASCII case.
+/// Returns -1, 0 or 1. When one string is a case-insensitive prefix of the
+/// other, the shorter one orders first.
+int StringRef::compare_lower(StringRef RHS) const {
+  const size_t Common = std::min(Length, RHS.Length);
+  const int Res = ascii_strncasecmp(Data, RHS.Data, Common);
+  if (Res != 0)
+    return Res;
+  if (Length < RHS.Length)
+    return -1;
+  return Length > RHS.Length ? 1 : 0;
+}
+
+/// Check if this string starts with the given \p Prefix, ignoring ASCII case.
+bool StringRef::startswith_lower(StringRef Prefix) const {
+  if (Prefix.Length > Length)
+    return false;
+  return ascii_strncasecmp(Data, Prefix.Data, Prefix.Length) == 0;
+}
+
+/// Check if this string ends with the given \p Suffix, ignoring ASCII case.
+bool StringRef::endswith_lower(StringRef Suffix) const {
+  if (Suffix.Length > Length)
+    return false;
+  const char *Tail = end() - Suffix.Length;
+  return ascii_strncasecmp(Tail, Suffix.Data, Suffix.Length) == 0;
+}
+
+/// compare_numeric - Compare strings, handle embedded numbers.
+///
+/// Three-way comparison (-1/0/1). Runs of ASCII digits that start at the same
+/// offset in both strings are compared first by run length and then bytewise,
+/// so "a9" < "a10". Leading zeros count toward the length, so "a010" > "a9".
+/// All other bytes compare as unsigned char; a proper prefix orders first.
+int StringRef::compare_numeric(StringRef RHS) const {
+ for (size_t I = 0, E = std::min(Length, RHS.Length); I != E; ++I) {
+ // Check for sequences of digits.
+ if (ascii_isdigit(Data[I]) && ascii_isdigit(RHS.Data[I])) {
+ // The longer sequence of numbers is considered larger.
+ // This doesn't really handle prefixed zeros well.
+ size_t J;
+ // J cannot get past E: at index E at most one string has a character,
+ // so ld != rd or !rd ends the loop there.
+ for (J = I + 1; J != E + 1; ++J) {
+ bool ld = J < Length && ascii_isdigit(Data[J]);
+ bool rd = J < RHS.Length && ascii_isdigit(RHS.Data[J]);
+ if (ld != rd)
+ return rd ? -1 : 1;
+ if (!rd)
+ break;
+ }
+ // The two number sequences have the same length (J-I), just memcmp them.
+ if (int Res = compareMemory(Data + I, RHS.Data + I, J - I))
+ return Res < 0 ? -1 : 1;
+ // Identical number sequences, continue search after the numbers.
+ // (J - 1 because the outer loop's ++I moves on to J.)
+ I = J - 1;
+ continue;
+ }
+ if (Data[I] != RHS.Data[I])
+ return (unsigned char)Data[I] < (unsigned char)RHS.Data[I] ? -1 : 1;
+ }
+ if (Length == RHS.Length)
+ return 0;
+ return Length < RHS.Length ? -1 : 1;
+}
+
+//===----------------------------------------------------------------------===//
+// String Operations
+//===----------------------------------------------------------------------===//
+
+/// Return a copy of this string with ASCII letters converted to lowercase.
+std::string StringRef::lower() const {
+  std::string Result;
+  Result.reserve(Length);
+  for (size_type I = 0; I != Length; ++I)
+    Result.push_back(ascii_tolower(Data[I]));
+  return Result;
+}
+
+/// Return a copy of this string with ASCII letters converted to uppercase.
+std::string StringRef::upper() const {
+  std::string Result;
+  Result.reserve(Length);
+  for (size_type I = 0; I != Length; ++I)
+    Result.push_back(ascii_toupper(Data[I]));
+  return Result;
+}
+
+//===----------------------------------------------------------------------===//
+// String Searching
+//===----------------------------------------------------------------------===//
+
+
+/// find - Search for the first string \arg Str in the string, starting at
+/// index \p From.
+///
+/// Short haystacks (< 16 bytes), empty needles and needles longer than 255
+/// bytes use a naive scan. Otherwise a Boyer-Moore-Horspool bad-character
+/// skip table (one byte per entry) is used.
+///
+/// \return - The index of the first occurrence of \arg Str, or npos if not
+/// found.
+size_t StringRef::find(StringRef Str, size_t From) const {
+  size_t N = Str.size();
+  if (N > Length)
+    return npos;
+
+  // For short haystacks or unsupported needles fall back to the naive algorithm
+  if (Length < 16 || N > 255 || N == 0) {
+    for (size_t e = Length - N + 1, i = std::min(From, e); i != e; ++i)
+      if (substr(i, N).equals(Str))
+        return i;
+    return npos;
+  }
+
+  if (From >= Length)
+    return npos;
+
+  // Build the bad char heuristic table, with uint8_t to reduce cache thrashing.
+  // Bytes that do not occur in Str[0..N-2] let the window skip all N bytes.
+  uint8_t BadCharSkip[256];
+  std::memset(BadCharSkip, N, 256);
+  for (size_t i = 0; i != N-1; ++i)
+    BadCharSkip[(uint8_t)Str[i]] = N-1-i;
+
+  // size_t, not unsigned: positions in haystacks of 4 GiB or more must not be
+  // truncated.
+  size_t Len = Length-From, Pos = From;
+  while (Len >= N) {
+    if (substr(Pos, N).equals(Str)) // See if this is the correct substring.
+      return Pos;
+
+    // Otherwise skip based on the last byte of the current window.
+    uint8_t Skip = BadCharSkip[(uint8_t)(*this)[Pos+N-1]];
+    Len -= Skip;
+    Pos += Skip;
+  }
+
+  return npos;
+}
+
+/// rfind - Search for the last occurrence of \arg Str in the string.
+///
+/// \return - The index of the last occurrence of \arg Str, or npos if not
+/// found. An empty \arg Str is found at size().
+size_t StringRef::rfind(StringRef Str) const {
+  const size_t N = Str.size();
+  if (N > Length)
+    return npos;
+  // Try every start position from the last feasible one down to 0.
+  for (size_t i = Length - N + 1; i-- != 0;)
+    if (substr(i, N).equals(Str))
+      return i;
+  return npos;
+}
+
+/// Build a 256-bit membership table with one bit set for every byte value
+/// that occurs in \p Chars. Shared by the character-set overloads below so
+/// each of them runs in O(size() + Chars.size()).
+static std::bitset<1 << CHAR_BIT> makeCharBits(StringRef Chars) {
+  std::bitset<1 << CHAR_BIT> CharBits;
+  for (size_t i = 0, e = Chars.size(); i != e; ++i)
+    CharBits.set((unsigned char)Chars[i]);
+  return CharBits;
+}
+
+/// find_first_of - Find the first character at or after \p From that is in
+/// \arg Chars, or npos if not found.
+///
+/// Note: O(size() + Chars.size())
+StringRef::size_type StringRef::find_first_of(StringRef Chars,
+                                              size_t From) const {
+  const std::bitset<1 << CHAR_BIT> CharBits = makeCharBits(Chars);
+  for (size_type i = std::min(From, Length), e = Length; i != e; ++i)
+    if (CharBits.test((unsigned char)Data[i]))
+      return i;
+  return npos;
+}
+
+/// find_first_not_of - Find the first character at or after \p From that is
+/// not \arg C, or npos if not found.
+StringRef::size_type StringRef::find_first_not_of(char C, size_t From) const {
+  for (size_type i = std::min(From, Length), e = Length; i != e; ++i)
+    if (Data[i] != C)
+      return i;
+  return npos;
+}
+
+/// find_first_not_of - Find the first character at or after \p From that is
+/// not in the string \arg Chars, or npos if not found.
+///
+/// Note: O(size() + Chars.size())
+StringRef::size_type StringRef::find_first_not_of(StringRef Chars,
+                                                  size_t From) const {
+  const std::bitset<1 << CHAR_BIT> CharBits = makeCharBits(Chars);
+  for (size_type i = std::min(From, Length), e = Length; i != e; ++i)
+    if (!CharBits.test((unsigned char)Data[i]))
+      return i;
+  return npos;
+}
+
+/// find_last_of - Find the last character before index \p From that is in
+/// \arg Chars, or npos if not found.
+///
+/// Note: O(size() + Chars.size())
+StringRef::size_type StringRef::find_last_of(StringRef Chars,
+                                             size_t From) const {
+  const std::bitset<1 << CHAR_BIT> CharBits = makeCharBits(Chars);
+  // Walk indices min(From, Length) - 1 down to 0.
+  for (size_type i = std::min(From, Length); i-- != 0;)
+    if (CharBits.test((unsigned char)Data[i]))
+      return i;
+  return npos;
+}
+
+/// find_last_not_of - Find the last character before index \p From that is
+/// not \arg C, or npos if not found.
+StringRef::size_type StringRef::find_last_not_of(char C, size_t From) const {
+  for (size_type i = std::min(From, Length); i-- != 0;)
+    if (Data[i] != C)
+      return i;
+  return npos;
+}
+
+/// find_last_not_of - Find the last character before index \p From that is
+/// not in \arg Chars, or npos if not found.
+///
+/// Note: O(size() + Chars.size())
+StringRef::size_type StringRef::find_last_not_of(StringRef Chars,
+                                                 size_t From) const {
+  const std::bitset<1 << CHAR_BIT> CharBits = makeCharBits(Chars);
+  for (size_type i = std::min(From, Length); i-- != 0;)
+    if (!CharBits.test((unsigned char)Data[i]))
+      return i;
+  return npos;
+}
+
+/// Split this string at each occurrence of \p Separators (via the pair-
+/// returning split), appending the pieces to \p A. At most \p MaxSplit splits
+/// are made when MaxSplit >= 0, and the unsplit tail becomes the last piece.
+/// Empty pieces are kept only when \p KeepEmpty is true.
+void StringRef::split(SmallVectorImpl<StringRef> &A,
+                      StringRef Separators, int MaxSplit,
+                      bool KeepEmpty) const {
+  StringRef Rest = *this;
+
+  // A null data() means "nothing left". That distinguishes "a," (which
+  // splits into "a" + "") from "a" (which splits into just "a").
+  int Splits = 0;
+  while (Rest.data() != nullptr && (MaxSplit < 0 || Splits < MaxSplit)) {
+    std::pair<StringRef, StringRef> Parts = Rest.split(Separators);
+    if (KeepEmpty || !Parts.first.empty())
+      A.push_back(Parts.first);
+    Rest = Parts.second;
+    ++Splits;
+  }
+
+  // If the split limit stopped the loop, append the remaining tail.
+  if (Rest.data() != nullptr && (KeepEmpty || !Rest.empty()))
+    A.push_back(Rest);
+}
+
+//===----------------------------------------------------------------------===//
+// Helpful Algorithms
+//===----------------------------------------------------------------------===//
+
+/// count - Return the number of non-overlapped occurrences of \arg Str in
+/// the string.
+///
+/// Matches are counted left to right and the search resumes just past each
+/// match, so "aaaa".count("aa") is 2 (the previous implementation counted
+/// overlapping matches and returned 3). For an empty \arg Str the result is
+/// size() + 1 (one empty match at every position), as before.
+size_t StringRef::count(StringRef Str) const {
+  const size_t N = Str.size();
+  if (N > Length)
+    return 0;
+  // Handle the empty needle up front: advancing by N == 0 would not terminate.
+  if (N == 0)
+    return Length + 1;
+
+  size_t Count = 0;
+  for (size_t Pos = find(Str, 0); Pos != npos; Pos = find(Str, Pos + N))
+    ++Count;
+  return Count;
+}
+
+/// Detect a radix prefix on \p Str: "0x" -> 16, "0b" -> 2, "0o" -> 8 (these
+/// two-character prefixes are stripped from \p Str), a bare leading "0" -> 8
+/// (left in place, since it is a valid octal digit), otherwise 10.
+static unsigned GetAutoSenseRadix(StringRef &Str) {
+  if (!Str.startswith("0"))
+    return 10;
+
+  if (Str.startswith("0x") || Str.startswith("0b") || Str.startswith("0o")) {
+    const char Marker = Str[1];
+    Str = Str.substr(2);
+    if (Marker == 'x')
+      return 16;
+    return Marker == 'b' ? 2 : 8;
+  }
+
+  return 8;
+}
+
+
+/// GetAsUnsignedInteger - Convert \p Str, written in \p Radix, to an unsigned
+/// long long.
+///
+/// A \p Radix of 0 auto-detects the base from a "0x"/"0b"/"0o"/"0" prefix.
+/// Digits above 9 are letters of either case (a/A = 10 ... z/Z = 35).
+/// Returns true on error (empty input, a character that is not a valid digit
+/// in the radix, or overflow) and false on success. \p Result may be
+/// modified even when an error is reported.
+bool llvm::getAsUnsignedInteger(StringRef Str, unsigned Radix,
+                                unsigned long long &Result) {
+  if (Radix == 0)
+    Radix = GetAutoSenseRadix(Str);
+
+  // Empty strings (after the radix autosense) are invalid.
+  if (Str.empty())
+    return true;
+
+  Result = 0;
+  for (size_t I = 0, E = Str.size(); I != E; ++I) {
+    const char C = Str[I];
+    unsigned Digit;
+    if (C >= '0' && C <= '9')
+      Digit = C - '0';
+    else if (C >= 'a' && C <= 'z')
+      Digit = C - 'a' + 10;
+    else if (C >= 'A' && C <= 'Z')
+      Digit = C - 'A' + 10;
+    else
+      return true; // Not alphanumeric.
+
+    if (Digit >= Radix)
+      return true; // Not a digit in this radix.
+
+    // Accumulate; after a wraparound, dividing back by the radix yields less
+    // than the previous value, which flags the overflow.
+    const unsigned long long Prev = Result;
+    Result = Result * Radix + Digit;
+    if (Result / Radix < Prev)
+      return true;
+  }
+
+  return false;
+}
+
+/// Parse \p Str as a signed integer in \p Radix (0 = auto-detect). A leading
+/// '-' negates the value. Returns true on error, including magnitudes that do
+/// not fit in long long; "-0" is accepted.
+bool llvm::getAsSignedInteger(StringRef Str, unsigned Radix,
+                              long long &Result) {
+  unsigned long long ULLVal;
+  const bool Negative = !Str.empty() && Str.front() == '-';
+
+  if (!Negative) {
+    // A magnitude with the sign bit set does not fit in a signed value.
+    if (getAsUnsignedInteger(Str, Radix, ULLVal) || (long long)ULLVal < 0)
+      return true;
+    Result = ULLVal;
+    return false;
+  }
+
+  // Negate in unsigned arithmetic so there is no signed-overflow UB. The
+  // magnitude is acceptable only if its negation is negative or zero ("-0").
+  if (getAsUnsignedInteger(Str.substr(1), Radix, ULLVal) ||
+      (long long)-ULLVal > 0)
+    return true;
+
+  Result = -ULLVal;
+  return false;
+}
diff --git a/src/networktables/NetworkTable.cpp b/src/networktables/NetworkTable.cpp
new file mode 100644
index 0000000..7eb3772
--- /dev/null
+++ b/src/networktables/NetworkTable.cpp
@@ -0,0 +1,381 @@
+#include "networktables/NetworkTable.h"
+
+#include <algorithm>
+
+#include "llvm/SmallString.h"
+#include "llvm/StringMap.h"
+#include "tables/ITableListener.h"
+#include "tables/TableKeyNotDefinedException.h"
+#include "ntcore.h"
+
+using llvm::StringRef;
+
+const char NetworkTable::PATH_SEPARATOR_CHAR{'/'};  // joins table/key names
+std::string NetworkTable::s_ip_address{};  // server contacted in client mode
+std::string NetworkTable::s_persistent_filename{"networktables.ini"};
+bool NetworkTable::s_client{false};  // false: Initialize starts a server
+bool NetworkTable::s_running{false};  // true between Initialize and Shutdown
+unsigned int NetworkTable::s_port{NT_DEFAULT_PORT};
+
+void NetworkTable::Initialize() {  // (re)start NT as client or server
+  if (s_running) Shutdown();  // restart so changed settings take effect
+  if (!s_client)
+    nt::StartServer(s_persistent_filename, "", s_port);
+  else
+    nt::StartClient(s_ip_address.c_str(), s_port);
+  s_running = true;
+}
+
+void NetworkTable::Shutdown() {  // stops the side selected by current s_client
+  if (!s_running) return;  // NOTE(review): mode switched while running -> mismatch
+  if (!s_client)
+    nt::StopServer();
+  else
+    nt::StopClient();
+  s_running = false;
+}
+
+void NetworkTable::SetClientMode() { s_client = true; }  // used by next Initialize
+
+void NetworkTable::SetServerMode() { s_client = false; }  // used by next Initialize
+
+void NetworkTable::SetTeam(int team) {  // server host: roboRIO-<team>-FRC.local
+  char tmp[30];  // "roboRIO-" (8) + any int (<= 11) + "-FRC.local" (10) + NUL
+#ifdef _MSC_VER
+  sprintf_s(tmp, "roboRIO-%d-FRC.local", team);  // no trailing '\n' in a host
+#else
+  std::snprintf(tmp, sizeof(tmp), "roboRIO-%d-FRC.local", team);
+#endif
+  SetIPAddress(tmp);
+}
+
+void NetworkTable::SetIPAddress(StringRef address) { s_ip_address = address; }  // read by next Initialize (client)
+
+void NetworkTable::SetPort(unsigned int port) { s_port = port; }  // read by next Initialize
+
+void NetworkTable::SetPersistentFilename(StringRef filename) {  // server mode file
+  s_persistent_filename = filename;
+}
+
+void NetworkTable::SetNetworkIdentity(StringRef name) {  // forwards to ntcore
+  nt::SetNetworkIdentity(name);
+}
+
+void NetworkTable::GlobalDeleteAll() { nt::DeleteAllEntries(); }  // all tables
+
+void NetworkTable::Flush() { nt::Flush(); }
+
+void NetworkTable::SetUpdateRate(double interval) {
+  nt::SetUpdateRate(interval);  // units not shown here; presumably seconds
+}
+
+const char* NetworkTable::SavePersistent(llvm::StringRef filename) {
+  return nt::SavePersistent(filename);  // ntcore's result passed through as-is
+}
+
+const char* NetworkTable::LoadPersistent(
+    llvm::StringRef filename,
+    std::function<void(size_t line, const char* msg)> warn) {
+  return nt::LoadPersistent(filename, warn);  // warn receives per-line issues
+}
+
+std::shared_ptr<NetworkTable> NetworkTable::GetTable(StringRef key) {
+  if (!s_running) Initialize();  // start networking on first table access
+  llvm::SmallString<128> table_path;  // root table is "", others "/<key>"
+  if (!key.empty()) {
+    table_path += PATH_SEPARATOR_CHAR;
+    table_path += key;
+  }
+  return std::make_shared<NetworkTable>(table_path, private_init());
+}
+
+NetworkTable::NetworkTable(StringRef path, const private_init&)
+    : m_path(path) {}
+
+NetworkTable::~NetworkTable() {  // unregister every listener this table added
+  for (const auto& registration : m_listeners)
+    nt::RemoveEntryListener(registration.second);
+}
+
+void NetworkTable::AddTableListener(ITableListener* listener) {  // new+update
+  AddTableListenerEx(listener, NT_NOTIFY_NEW | NT_NOTIFY_UPDATE);
+}
+
+void NetworkTable::AddTableListener(ITableListener* listener,
+                                    bool immediateNotify) {
+  const unsigned int flags = NT_NOTIFY_NEW | NT_NOTIFY_UPDATE |
+                             (immediateNotify ? NT_NOTIFY_IMMEDIATE : 0);
+  AddTableListenerEx(listener, flags);
+}
+
+void NetworkTable::AddTableListenerEx(ITableListener* listener,
+                                      unsigned int flags) {  // direct keys only
+  std::lock_guard<std::mutex> lock(m_mutex);  // serializes m_listeners changes
+  llvm::SmallString<128> path(m_path);
+  path += PATH_SEPARATOR_CHAR;
+  std::size_t prefix_len = path.size();  // length of "<table path>/"
+  unsigned int id = nt::AddEntryListener(
+      path,
+      [=](unsigned int /*uid*/, StringRef name,
+          std::shared_ptr<nt::Value> value, unsigned int flags_) {
+        StringRef relative_key = name.substr(prefix_len);
+        if (relative_key.find(PATH_SEPARATOR_CHAR) != StringRef::npos) return;  // sub-table key
+        listener->ValueChangedEx(this, relative_key, value, flags_);  // [=] captures this
+      },
+      flags);
+  m_listeners.emplace_back(listener, id);  // undone by Remove.../~NetworkTable
+}
+
+void NetworkTable::AddTableListener(StringRef key, ITableListener* listener,
+                                    bool immediateNotify) {  // a single key
+  const unsigned int flags = NT_NOTIFY_NEW | NT_NOTIFY_UPDATE |
+                             (immediateNotify ? NT_NOTIFY_IMMEDIATE : 0);
+  AddTableListenerEx(key, listener, flags);
+}
+
+void NetworkTable::AddTableListenerEx(StringRef key, ITableListener* listener,
+                                      unsigned int flags) {  // exactly one key
+  std::lock_guard<std::mutex> lock(m_mutex);  // serializes m_listeners changes
+  llvm::SmallString<128> path(m_path);
+  path += PATH_SEPARATOR_CHAR;
+  std::size_t prefix_len = path.size();  // length of "<table path>/"
+  path += key;  // full entry name; the listener below matches it as a prefix
+  unsigned int id = nt::AddEntryListener(
+      path,
+      [=](unsigned int /*uid*/, StringRef name, std::shared_ptr<nt::Value> value,
+          unsigned int flags_) {
+        if (name != path) return;  // prefix match is not enough: exact key only
+        listener->ValueChangedEx(this, name.substr(prefix_len), value, flags_);
+      },
+      flags);
+  m_listeners.emplace_back(listener, id);
+}
+
+void NetworkTable::AddSubTableListener(ITableListener* listener) {
+  AddSubTableListener(listener, false);  // i.e. without NT_NOTIFY_LOCAL
+}
+
+void NetworkTable::AddSubTableListener(ITableListener* listener,
+                                       bool localNotify) {  // once per sub-table
+  std::lock_guard<std::mutex> lock(m_mutex);  // serializes m_listeners changes
+  llvm::SmallString<128> path(m_path);
+  path += PATH_SEPARATOR_CHAR;
+  std::size_t prefix_len = path.size();  // length of "<table path>/"
+
+  // The lambda needs to be copyable, but StringMap is not, so use
+  // a shared_ptr to it. It records the sub-tables already reported.
+  auto notified_tables = std::make_shared<llvm::StringMap<char>>();
+
+  unsigned int flags = NT_NOTIFY_NEW | NT_NOTIFY_IMMEDIATE;
+  if (localNotify) flags |= NT_NOTIFY_LOCAL;
+  unsigned int id = nt::AddEntryListener(
+      path,
+      [=](unsigned int /*uid*/, StringRef name,
+          std::shared_ptr<nt::Value> /*value*/, unsigned int flags_) mutable {
+        StringRef relative_key = name.substr(prefix_len);
+        auto end_sub_table = relative_key.find(PATH_SEPARATOR_CHAR);
+        if (end_sub_table == StringRef::npos) return;  // plain key, no table
+        StringRef sub_table_key = relative_key.substr(0, end_sub_table);
+        if (notified_tables->find(sub_table_key) != notified_tables->end())
+          return;  // this sub-table was already reported; skip duplicates
+        notified_tables->insert(std::make_pair(sub_table_key, '\0'));
+        listener->ValueChangedEx(this, sub_table_key, nullptr, flags_);
+      },
+      flags);
+  m_listeners.emplace_back(listener, id);
+}
+
+void NetworkTable::RemoveTableListener(ITableListener* listener) {
+  std::lock_guard<std::mutex> lock(m_mutex);
+  auto matches_begin = std::stable_partition(  // matching entries go last
+      m_listeners.begin(), m_listeners.end(),
+      [=](const Listener& x) { return x.first != listener; });
+  // Unlike remove_if, partitioning keeps the matched ids intact to release.
+  for (auto i = matches_begin; i != m_listeners.end(); ++i)
+    nt::RemoveEntryListener(i->second);
+  m_listeners.erase(matches_begin, m_listeners.end());
+}
+
+std::shared_ptr<ITable> NetworkTable::GetSubTable(StringRef key) const {
+  llvm::SmallString<128> sub_path(m_path);  // "<this table>/<key>"
+  sub_path += PATH_SEPARATOR_CHAR;
+  sub_path += key;
+  return std::make_shared<NetworkTable>(sub_path, private_init());
+}
+
+bool NetworkTable::ContainsKey(StringRef key) const {
+  llvm::SmallString<128> entry_name(m_path);
+  entry_name += PATH_SEPARATOR_CHAR;
+  entry_name += key;
+  return nt::GetEntryValue(entry_name) != nullptr;  // any value, any type
+}
+
+bool NetworkTable::ContainsSubTable(StringRef key) const {
+  llvm::SmallString<128> prefix(m_path);
+  prefix += PATH_SEPARATOR_CHAR;
+  prefix += key;
+  prefix += PATH_SEPARATOR_CHAR;  // trailing '/' so "ab" doesn't match "abc/x"
+  return !nt::GetEntryInfo(prefix, 0).empty();  // any entry under the prefix
+}
+
+std::vector<std::string> NetworkTable::GetKeys(int types) const {
+  std::vector<std::string> keys;  // names relative to this table
+  llvm::SmallString<128> prefix(m_path);
+  prefix += PATH_SEPARATOR_CHAR;
+  for (auto& info : nt::GetEntryInfo(prefix, types)) {
+    StringRef relative_key = StringRef(info.name).substr(prefix.size());
+    // A further separator means the entry lives in a sub-table; skip it.
+    if (relative_key.find(PATH_SEPARATOR_CHAR) == StringRef::npos)
+      keys.push_back(relative_key);
+  }
+  return keys;
+}
+
+std::vector<std::string> NetworkTable::GetSubTables() const {
+  std::vector<std::string> keys;  // each sub-table name appears once
+  const std::string path = StringRef(m_path).str() + PATH_SEPARATOR_CHAR;
+  for (auto& entry : nt::GetEntryInfo(path, 0)) {
+    std::size_t end = entry.name.find(PATH_SEPARATOR_CHAR, path.size());
+    if (end == std::string::npos) continue;  // plain key, not a sub-table
+    std::string sub_table = entry.name.substr(path.size(), end - path.size());
+    if (std::find(keys.begin(), keys.end(), sub_table) == keys.end())
+      keys.push_back(sub_table);  // many entries usually share one sub-table
+  }
+  return keys;
+}
+
+void NetworkTable::SetPersistent(StringRef key) {  // sets NT_PERSISTENT
+  SetFlags(key, NT_PERSISTENT);
+}
+
+void NetworkTable::ClearPersistent(StringRef key) {  // clears NT_PERSISTENT
+  ClearFlags(key, NT_PERSISTENT);
+}
+
+bool NetworkTable::IsPersistent(StringRef key) const {  // tests NT_PERSISTENT
+  return (GetFlags(key) & NT_PERSISTENT) != 0;
+}
+
+void NetworkTable::SetFlags(StringRef key, unsigned int flags) {  // OR in flags
+  llvm::SmallString<128> path(m_path);
+  path += PATH_SEPARATOR_CHAR;
+  path += key;
+  nt::SetEntryFlags(path, nt::GetEntryFlags(path) | flags);  // full name, not key
+}
+
+void NetworkTable::ClearFlags(StringRef key, unsigned int flags) {  // AND-NOT
+  llvm::SmallString<128> entry_name(m_path);
+  entry_name += PATH_SEPARATOR_CHAR;
+  entry_name += key;
+  nt::SetEntryFlags(entry_name, ~flags & nt::GetEntryFlags(entry_name));
+}
+
+unsigned int NetworkTable::GetFlags(StringRef key) const {
+  llvm::SmallString<128> entry_name(m_path);
+  entry_name += PATH_SEPARATOR_CHAR;
+  entry_name += key;
+  return nt::GetEntryFlags(entry_name);
+}
+
+void NetworkTable::Delete(StringRef key) {
+  llvm::SmallString<128> entry_name(m_path);
+  entry_name += PATH_SEPARATOR_CHAR;
+  entry_name += key;
+  nt::DeleteEntry(entry_name);
+}
+
+bool NetworkTable::PutNumber(StringRef key, double value) {
+  llvm::SmallString<128> entry_name(m_path);
+  entry_name += PATH_SEPARATOR_CHAR;
+  entry_name += key;
+  return nt::SetEntryValue(entry_name, nt::Value::MakeDouble(value));
+}
+
+double NetworkTable::GetNumber(StringRef key) const {  // throws if not a number
+  llvm::SmallString<128> entry_name(m_path);
+  entry_name += PATH_SEPARATOR_CHAR;
+  entry_name += key;
+  auto value = nt::GetEntryValue(entry_name);
+  if (value && value->type() == NT_DOUBLE)
+    return value->GetDouble();
+  throw TableKeyNotDefinedException(entry_name);
+}
+
+double NetworkTable::GetNumber(StringRef key, double defaultValue) const {
+  llvm::SmallString<128> entry_name(m_path);
+  entry_name += PATH_SEPARATOR_CHAR;
+  entry_name += key;
+  auto value = nt::GetEntryValue(entry_name);
+  if (value && value->type() == NT_DOUBLE)
+    return value->GetDouble();
+  return defaultValue;  // missing entry or non-double value
+}
+
+bool NetworkTable::PutString(StringRef key, StringRef value) {
+  llvm::SmallString<128> entry_name(m_path);
+  entry_name += PATH_SEPARATOR_CHAR;
+  entry_name += key;
+  return nt::SetEntryValue(entry_name, nt::Value::MakeString(value));
+}
+
+std::string NetworkTable::GetString(StringRef key) const {  // throws if absent
+  llvm::SmallString<128> entry_name(m_path);
+  entry_name += PATH_SEPARATOR_CHAR;
+  entry_name += key;
+  auto value = nt::GetEntryValue(entry_name);
+  if (value && value->type() == NT_STRING)
+    return value->GetString();
+  throw TableKeyNotDefinedException(entry_name);
+}
+
+std::string NetworkTable::GetString(StringRef key,
+                                    StringRef defaultValue) const {
+  llvm::SmallString<128> entry_name(m_path);
+  entry_name += PATH_SEPARATOR_CHAR;
+  entry_name += key;
+  auto value = nt::GetEntryValue(entry_name);
+  if (value && value->type() == NT_STRING)
+    return value->GetString();
+  return defaultValue;  // missing entry or non-string value
+}
+
+bool NetworkTable::PutBoolean(StringRef key, bool value) {
+  llvm::SmallString<128> entry_name(m_path);
+  entry_name += PATH_SEPARATOR_CHAR;
+  entry_name += key;
+  return nt::SetEntryValue(entry_name, nt::Value::MakeBoolean(value));
+}
+
+bool NetworkTable::GetBoolean(StringRef key) const {  // throws if not a bool
+  llvm::SmallString<128> entry_name(m_path);
+  entry_name += PATH_SEPARATOR_CHAR;
+  entry_name += key;
+  auto value = nt::GetEntryValue(entry_name);
+  if (value && value->type() == NT_BOOLEAN)
+    return value->GetBoolean();
+  throw TableKeyNotDefinedException(entry_name);
+}
+
+bool NetworkTable::GetBoolean(StringRef key, bool defaultValue) const {
+  llvm::SmallString<128> entry_name(m_path);
+  entry_name += PATH_SEPARATOR_CHAR;
+  entry_name += key;
+  auto value = nt::GetEntryValue(entry_name);
+  if (value && value->type() == NT_BOOLEAN)
+    return value->GetBoolean();
+  return defaultValue;  // missing entry or non-boolean value
+}
+
+bool NetworkTable::PutValue(StringRef key, std::shared_ptr<nt::Value> value) {
+  llvm::SmallString<128> entry_name(m_path);
+  entry_name += PATH_SEPARATOR_CHAR;
+  entry_name += key;
+  return nt::SetEntryValue(entry_name, value);  // value of any type
+}
+
+std::shared_ptr<nt::Value> NetworkTable::GetValue(StringRef key) const {
+  llvm::SmallString<128> entry_name(m_path);
+  entry_name += PATH_SEPARATOR_CHAR;
+  entry_name += key;
+  return nt::GetEntryValue(entry_name);  // null if the entry does not exist
+}
diff --git a/src/ntcore_c.cpp b/src/ntcore_c.cpp
new file mode 100644
index 0000000..d7f5456
--- /dev/null
+++ b/src/ntcore_c.cpp
@@ -0,0 +1,770 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "ntcore.h"
+
+#include <cassert>
+#include <cstdlib>
+
+#include "Value_internal.h"
+
+using namespace nt;
+
+// Conversion helpers
+
+static void ConvertToC(llvm::StringRef in, char** out) {  // malloc'd C string
+  *out = static_cast<char*>(std::malloc(in.size() + 1));  // +1 for the NUL
+  std::memcpy(*out, in.data(), in.size());  // fresh buffer: cannot overlap
+  (*out)[in.size()] = '\0';
+}
+
+static void ConvertToC(const EntryInfo& in, NT_EntryInfo* out) {  // deep copy
+  ConvertToC(in.name, &out->name);  // heap buffer; freed by DisposeEntryInfo
+  out->type = in.type;
+  out->flags = in.flags;
+  out->last_change = in.last_change;
+}
+
+static void ConvertToC(const ConnectionInfo& in, NT_ConnectionInfo* out) {
+  ConvertToC(in.remote_id, &out->remote_id);
+  ConvertToC(in.remote_name, &out->remote_name);  // bare char* (malloc'd)
+  out->remote_port = in.remote_port;
+  out->last_update = in.last_update;
+  out->protocol_version = in.protocol_version;
+}
+
+static void ConvertToC(const RpcParamDef& in, NT_RpcParamDef* out) {
+  ConvertToC(in.name, &out->name);
+  ConvertToC(*in.def_value, &out->def_value);  // NOTE: def_value must be non-null
+}
+
+static void ConvertToC(const RpcResultDef& in, NT_RpcResultDef* out) {
+  ConvertToC(in.name, &out->name);
+  out->type = in.type;
+}
+
+static void ConvertToC(const RpcDefinition& in, NT_RpcDefinition* out) {
+  out->version = in.version;
+  ConvertToC(in.name, &out->name);
+  // Parameters: one malloc'd array (NT_DisposeRpcDefinition frees it).
+  out->num_params = in.params.size();
+  out->params = static_cast<NT_RpcParamDef*>(
+      std::malloc(in.params.size() * sizeof(*out->params)));
+  for (size_t k = 0; k < in.params.size(); ++k)
+    ConvertToC(in.params[k], out->params + k);
+  // Results: same ownership as the parameter array.
+  out->num_results = in.results.size();
+  out->results = static_cast<NT_RpcResultDef*>(
+      std::malloc(in.results.size() * sizeof(*out->results)));
+  for (size_t k = 0; k < in.results.size(); ++k)
+    ConvertToC(in.results[k], out->results + k);
+}
+
+static void ConvertToC(const RpcCallInfo& in, NT_RpcCallInfo* out) {
+  ConvertToC(in.name, &out->name);
+  ConvertToC(in.params, &out->params);  // packed params as an NT_String
+  out->rpc_id = in.rpc_id;
+  out->call_uid = in.call_uid;
+}
+
+static void DisposeConnectionInfo(NT_ConnectionInfo *info) {  // strings only
+  std::free(info->remote_id.str);
+  std::free(info->remote_name);
+}
+
+static void DisposeEntryInfo(NT_EntryInfo *info) {  // frees the name buffer
+  std::free(info->name.str);
+}
+
+static RpcParamDef ConvertFromC(const NT_RpcParamDef& in) {  // C -> C++ copy
+  RpcParamDef param;
+  param.name = ConvertFromC(in.name);
+  param.def_value = ConvertFromC(in.def_value);
+  return param;
+}
+
+static RpcResultDef ConvertFromC(const NT_RpcResultDef& in) {
+  RpcResultDef result;
+  result.name = ConvertFromC(in.name);
+  result.type = in.type;
+  return result;
+}
+
+static RpcDefinition ConvertFromC(const NT_RpcDefinition& in) {
+  RpcDefinition def;
+  def.version = in.version;
+  def.name = ConvertFromC(in.name);
+
+  def.params.reserve(in.num_params);
+  for (auto* p = in.params; p != in.params + in.num_params; ++p)
+    def.params.push_back(ConvertFromC(*p));
+
+  def.results.reserve(in.num_results);
+  for (auto* r = in.results; r != in.results + in.num_results; ++r)
+    def.results.push_back(ConvertFromC(*r));
+
+  return def;
+}
+
+extern "C" {
+
+/*
+ * Table Functions
+ */
+
+void NT_GetEntryValue(const char *name, size_t name_len,
+                      struct NT_Value *value) {
+  NT_InitValue(value);  // output stays valid (unassigned) if entry is missing
+  auto entry_value = nt::GetEntryValue(StringRef(name, name_len));
+  if (entry_value)
+    ConvertToC(*entry_value, value);
+}
+
+int NT_SetEntryValue(const char *name, size_t name_len,
+                     const struct NT_Value *value) {  // nonzero on success
+  return nt::SetEntryValue(StringRef(name, name_len), ConvertFromC(*value));
+}
+
+void NT_SetEntryTypeValue(const char *name, size_t name_len,
+                          const struct NT_Value *value) {
+  nt::SetEntryTypeValue(StringRef(name, name_len), ConvertFromC(*value));
+}
+
+void NT_SetEntryFlags(const char *name, size_t name_len, unsigned int flags) {
+  nt::SetEntryFlags(StringRef(name, name_len), flags);
+}
+
+unsigned int NT_GetEntryFlags(const char *name, size_t name_len) {
+  return nt::GetEntryFlags(StringRef(name, name_len));
+}
+
+void NT_DeleteEntry(const char *name, size_t name_len) {
+  nt::DeleteEntry(StringRef(name, name_len));
+}
+
+void NT_DeleteAllEntries(void) { nt::DeleteAllEntries(); }
+
+struct NT_EntryInfo *NT_GetEntryInfo(const char *prefix, size_t prefix_len,
+                                     unsigned int types, size_t *count) {
+  auto entries = nt::GetEntryInfo(StringRef(prefix, prefix_len), types);
+  *count = entries.size();
+  if (entries.empty()) return nullptr;  // no allocation for an empty result
+
+  // Presumably released by the caller via NT_DisposeEntryInfoArray.
+  auto *info = static_cast<NT_EntryInfo *>(
+      std::malloc(entries.size() * sizeof(NT_EntryInfo)));
+  for (size_t i = 0; i < entries.size(); ++i) ConvertToC(entries[i], info + i);
+  return info;
+}
+
+void NT_Flush(void) { nt::Flush(); }  // forwards to nt::Flush
+
+/*
+ * Callback Creation Functions
+ */
+
+void NT_SetListenerOnStart(void (*on_start)(void *data), void *data) {
+  nt::SetListenerOnStart([=]() { on_start(data); });  // binds fn + user data
+}
+
+void NT_SetListenerOnExit(void (*on_exit)(void *data), void *data) {
+  nt::SetListenerOnExit([=]() { on_exit(data); });  // binds fn + user data
+}
+
+unsigned int NT_AddEntryListener(const char *prefix, size_t prefix_len,
+                                 void *data,
+                                 NT_EntryListenerCallback callback,
+                                 unsigned int flags) {  // returns listener uid
+  return nt::AddEntryListener(
+      StringRef(prefix, prefix_len),
+      [=](unsigned int uid, StringRef name, std::shared_ptr<Value> value,
+          unsigned int flags_) {
+        callback(uid, data, name.data(), name.size(), &value->value(), flags_);  // NOTE(review): assumes value is never null
+      },
+      flags);
+}
+
+void NT_RemoveEntryListener(unsigned int entry_listener_uid) {
+  nt::RemoveEntryListener(entry_listener_uid);
+}
+
+unsigned int NT_AddConnectionListener(void *data,
+                                      NT_ConnectionListenerCallback callback,
+                                      int immediate_notify) {
+  return nt::AddConnectionListener(
+      [=](unsigned int uid, bool connected, const ConnectionInfo &conn) {
+        NT_ConnectionInfo conn_c;  // temporary C copy, valid only during call
+        ConvertToC(conn, &conn_c);
+        callback(uid, data, connected ? 1 : 0, &conn_c);
+        DisposeConnectionInfo(&conn_c);  // callback must not keep pointers
+      },
+      immediate_notify != 0);
+}
+
+void NT_RemoveConnectionListener(unsigned int conn_listener_uid) {
+  nt::RemoveConnectionListener(conn_listener_uid);
+}
+
+int NT_NotifierDestroyed() { return nt::NotifierDestroyed(); }  // bool as 0/1
+
+/*
+ * Remote Procedure Call Functions
+ */
+
+void NT_CreateRpc(const char *name, size_t name_len, const char *def,
+                  size_t def_len, void *data, NT_RpcCallback callback) {
+  nt::CreateRpc(
+      StringRef(name, name_len), StringRef(def, def_len),
+      [=](StringRef name, StringRef params) -> std::string {
+        size_t reply_len;  // filled in by the C callback
+        char* reply = callback(data, name.data(), name.size(),
+                               params.data(), params.size(), &reply_len);
+        std::string reply_copy(reply, reply_len);  // callback buffer is malloc'd
+        std::free(reply);
+        return reply_copy;
+      });
+}
+
+void NT_CreatePolledRpc(const char *name, size_t name_len, const char *def,
+                        size_t def_len) {  // calls are retrieved via NT_PollRpc
+  nt::CreatePolledRpc(StringRef(name, name_len), StringRef(def, def_len));
+}
+
+int NT_PollRpc(int blocking, NT_RpcCallInfo* call_info) {
+  RpcCallInfo polled;
+  const bool have_call = nt::PollRpc(blocking != 0, &polled);
+  if (have_call)
+    ConvertToC(polled, call_info);  // *call_info untouched when nothing polled
+  return have_call ? 1 : 0;
+}
+
+void NT_PostRpcResponse(unsigned int rpc_id, unsigned int call_uid,
+                        const char *result, size_t result_len) {
+  nt::PostRpcResponse(rpc_id, call_uid, StringRef(result, result_len));
+}
+
+unsigned int NT_CallRpc(const char *name, size_t name_len,
+                        const char *params, size_t params_len) {  // call uid
+  return nt::CallRpc(StringRef(name, name_len), StringRef(params, params_len));
+}
+
+char *NT_GetRpcResult(int blocking, unsigned int call_uid, size_t *result_len) {
+  std::string result;
+  if (!nt::GetRpcResult(blocking != 0, call_uid, &result)) return nullptr;
+
+  // Hand back a malloc'd, NUL-terminated copy; its length goes to *result_len.
+  char *result_cstr;
+  ConvertToC(result, &result_cstr);
+  *result_len = result.size();
+  return result_cstr;
+}
+
+char *NT_PackRpcDefinition(const NT_RpcDefinition *def, size_t *packed_len) {
+  auto packed = nt::PackRpcDefinition(ConvertFromC(*def));
+
+  // Return a malloc'd copy of the packed bytes plus their length.
+  char *packed_cstr;
+  ConvertToC(packed, &packed_cstr);
+  *packed_len = packed.size();
+  return packed_cstr;
+}
+
+int NT_UnpackRpcDefinition(const char *packed, size_t packed_len,
+                           NT_RpcDefinition *def) {
+  nt::RpcDefinition unpacked;
+  const bool ok = nt::UnpackRpcDefinition(StringRef(packed, packed_len), &unpacked);
+  if (!ok) return 0;  // *def left untouched on failure
+
+  // Deep copy into *def; presumably released with NT_DisposeRpcDefinition.
+  ConvertToC(unpacked, def);
+  return 1;
+}
+
+char *NT_PackRpcValues(const NT_Value **values, size_t values_len,
+                       size_t *packed_len) {  // returns malloc'd packed bytes
+  // Convert each C value into a shared nt::Value.
+  std::vector<std::shared_ptr<Value>> values_v;
+  values_v.reserve(values_len);
+  for (size_t i = 0; i < values_len; ++i)
+    values_v.push_back(ConvertFromC(*values[i]));
+
+  // Pack them with the C++ API.
+  auto packed = nt::PackRpcValues(values_v);
+
+  // Copy the packed bytes into a malloc'd buffer and report their length.
+  *packed_len = packed.size();
+  char *packed_cstr;
+  ConvertToC(packed, &packed_cstr);
+  return packed_cstr;
+}
+
+NT_Value **NT_UnpackRpcValues(const char *packed, size_t packed_len,
+                              const NT_Type *types, size_t types_len) {
+  auto values_v = nt::UnpackRpcValues(StringRef(packed, packed_len),
+                                      ArrayRef<NT_Type>(types, types_len));
+  if (values_v.size() == 0) return nullptr;  // empty result -> no allocation
+
+  // Array of separately malloc'd NT_Values; no count is returned (NOTE(review): callers presumably rely on types_len)
+  NT_Value** values = static_cast<NT_Value**>(
+      std::malloc(values_v.size() * sizeof(NT_Value*)));
+  for (size_t i = 0; i < values_v.size(); ++i) {
+    values[i] = static_cast<NT_Value*>(std::malloc(sizeof(NT_Value)));
+    ConvertToC(*values_v[i], values[i]);
+  }
+  return values;
+}
+
+/*
+ * Client/Server Functions
+ */
+
+void NT_SetNetworkIdentity(const char *name, size_t name_len) {
+  nt::SetNetworkIdentity(StringRef(name, name_len));
+}
+
+void NT_StartServer(const char *persist_filename, const char *listen_address,
+                    unsigned int port) {  // C strings forwarded unchanged
+  nt::StartServer(persist_filename, listen_address, port);
+}
+
+void NT_StopServer(void) { nt::StopServer(); }
+
+void NT_StartClient(const char *server_name, unsigned int port) {
+  nt::StartClient(server_name, port);
+}
+
+void NT_StopClient(void) {
+  nt::StopClient();
+}
+
+void NT_StopRpcServer(void) {
+  nt::StopRpcServer();
+}
+
+void NT_StopNotifier(void) {
+  nt::StopNotifier();
+}
+
+void NT_SetUpdateRate(double interval) {
+  nt::SetUpdateRate(interval);
+}
+
+struct NT_ConnectionInfo *NT_GetConnections(size_t *count) {
+  auto conns = nt::GetConnections();
+  *count = conns.size();
+  if (conns.empty()) return nullptr;  // no allocation when not connected
+
+  // Presumably released by the caller via NT_DisposeConnectionInfoArray.
+  auto *out = static_cast<NT_ConnectionInfo *>(
+      std::malloc(conns.size() * sizeof(NT_ConnectionInfo)));
+  for (size_t i = 0; i < conns.size(); ++i) ConvertToC(conns[i], out + i);
+  return out;
+}
+
+/*
+ * Persistent Functions
+ */
+
+const char *NT_SavePersistent(const char *filename) {
+  return nt::SavePersistent(filename);  // ntcore's result returned unchanged
+}
+
+const char *NT_LoadPersistent(const char *filename,
+                              void (*warn)(size_t line, const char *msg)) {  // NOTE(review): confirm a NULL warn is tolerated
+  return nt::LoadPersistent(filename, warn);
+}
+
+/*
+ * Utility Functions
+ */
+
+void NT_SetLogger(NT_LogFunc func, unsigned int min_level) {
+  nt::SetLogger(func, min_level);  // forwards the C callback unchanged
+}
+
+void NT_DisposeValue(NT_Value *value) {  // frees owned memory; safe to repeat
+  switch (value->type) {
+    case NT_UNASSIGNED:
+    case NT_BOOLEAN:
+    case NT_DOUBLE:
+      break;  // scalars own no heap memory
+    case NT_STRING:
+    case NT_RAW:
+    case NT_RPC:
+      std::free(value->data.v_string.str);  // all three store bytes in v_string
+      break;
+    case NT_BOOLEAN_ARRAY:
+      std::free(value->data.arr_boolean.arr);
+      break;
+    case NT_DOUBLE_ARRAY:
+      std::free(value->data.arr_double.arr);
+      break;
+    case NT_STRING_ARRAY: {
+      for (size_t i = 0; i < value->data.arr_string.size; i++)
+        std::free(value->data.arr_string.arr[i].str);  // elements first
+      std::free(value->data.arr_string.arr);
+      break;
+    }
+    default:
+      assert(false && "unknown value type");
+  }
+  value->type = NT_UNASSIGNED;  // a second dispose now frees nothing
+  value->last_change = 0;
+}
+
+void NT_InitValue(NT_Value *value) {  // empty (unassigned) value, owns nothing
+  value->type = NT_UNASSIGNED;
+  value->last_change = 0;
+}
+
+void NT_DisposeString(NT_String *str) {  // frees the buffer, leaves *str empty
+  std::free(str->str);
+  str->len = 0;
+  str->str = nullptr;
+}
+
+void NT_InitString(NT_String *str) {  // empty string; nothing to free
+  str->len = 0;
+  str->str = nullptr;
+}
+
+enum NT_Type NT_GetType(const char *name, size_t name_len) {
+  auto entry_value = nt::GetEntryValue(StringRef(name, name_len));
+  // A missing entry is reported as NT_UNASSIGNED.
+  return entry_value ? entry_value->type() : NT_Type::NT_UNASSIGNED;
+}
+
+void NT_DisposeConnectionInfoArray(NT_ConnectionInfo *arr, size_t count) {
+  for (size_t i = 0; i < count; i++) DisposeConnectionInfo(&arr[i]);  // strings
+  std::free(arr);  // then the array itself
+}
+
+void NT_DisposeEntryInfoArray(NT_EntryInfo *arr, size_t count) {
+  for (size_t i = 0; i < count; i++) DisposeEntryInfo(&arr[i]);  // names
+  std::free(arr);  // then the array itself
+}
+
+void NT_DisposeRpcDefinition(NT_RpcDefinition *def) {  // frees, then empties
+  NT_DisposeString(&def->name);
+
+  for (size_t i = 0; i < def->num_params; ++i) {
+    NT_DisposeString(&def->params[i].name);
+    NT_DisposeValue(&def->params[i].def_value);
+  }
+  std::free(def->params);
+  def->params = nullptr;
+  def->num_params = 0;
+
+  for (size_t i = 0; i < def->num_results; ++i)
+    NT_DisposeString(&def->results[i].name);
+  std::free(def->results);
+  def->results = nullptr;
+  def->num_results = 0;
+}
+
+void NT_DisposeRpcCallInfo(NT_RpcCallInfo *call_info) {  // frees both strings
+  NT_DisposeString(&call_info->name);
+  NT_DisposeString(&call_info->params);
+}
+
+/* Interop Utility Functions */
+
+/* Array and Struct Allocations */
+
+/* Allocates a char array of the specified size.*/
+char *NT_AllocateCharArray(size_t size) {
+  return static_cast<char *>(
+      std::malloc(size * sizeof(char)));
+}
+
+/* Allocates an int array of the specified size (booleans are stored as int). */
+int *NT_AllocateBooleanArray(size_t size) {
+  return static_cast<int *>(
+      std::malloc(size * sizeof(int)));
+}
+
+/* Allocates a double array of the specified size. */
+double *NT_AllocateDoubleArray(size_t size) {
+  return static_cast<double *>(
+      std::malloc(size * sizeof(double)));
+}
+
+/* Allocates an NT_String array of the specified size. */
+struct NT_String *NT_AllocateStringArray(size_t size) {
+  // Element contents are left uninitialized.
+  return static_cast<NT_String *>(
+      std::malloc(size * sizeof(NT_String)));
+}
+
+void NT_FreeCharArray(char *v_char) { std::free(v_char); }
+void NT_FreeDoubleArray(double *v_double) { std::free(v_double); }
+void NT_FreeBooleanArray(int *v_boolean) { std::free(v_boolean); }
+void NT_FreeStringArray(struct NT_String *v_string, size_t arr_size) {
+  for (NT_String *s = v_string; s != v_string + arr_size; ++s) std::free(s->str);
+  std::free(v_string);  // element buffers first, then the array
+}
+
+int NT_SetEntryDouble(const char *name, size_t name_len, double v_double,
+                      int force) {
+  StringRef key(name, name_len);
+  auto value = Value::MakeDouble(v_double);
+  if (force == 0) {
+    // force == 0: report whether nt::SetEntryValue accepted the value.
+    return nt::SetEntryValue(key, value);
+  }
+  nt::SetEntryTypeValue(key, value);  // forced write; always reports 1
+  return 1;
+}
+
+int NT_SetEntryBoolean(const char *name, size_t name_len, int v_boolean,
+                       int force) {
+  StringRef key(name, name_len);
+  auto value = Value::MakeBoolean(v_boolean != 0);
+  if (force == 0) {
+    // force == 0: report whether nt::SetEntryValue accepted the value.
+    return nt::SetEntryValue(key, value);
+  }
+  nt::SetEntryTypeValue(key, value);  // forced write; always reports 1
+  return 1;
+}
+
+int NT_SetEntryString(const char *name, size_t name_len, const char *str,
+                      size_t str_len, int force) {
+  StringRef key(name, name_len);
+  auto value = Value::MakeString(StringRef(str, str_len));
+  if (force == 0) {
+    // force == 0: report whether nt::SetEntryValue accepted the value.
+    return nt::SetEntryValue(key, value);
+  }
+  nt::SetEntryTypeValue(key, value);  // forced write; always reports 1
+  return 1;
+}
+
+int NT_SetEntryRaw(const char *name, size_t name_len, const char *raw,
+                   size_t raw_len, int force) {
+  StringRef key(name, name_len);
+  auto value = Value::MakeRaw(StringRef(raw, raw_len));
+  if (force == 0) {
+    // force == 0: report whether nt::SetEntryValue accepted the value.
+    return nt::SetEntryValue(key, value);
+  }
+  nt::SetEntryTypeValue(key, value);  // forced write; always reports 1
+  return 1;
+}
+
+int NT_SetEntryBooleanArray(const char *name, size_t name_len, const int *arr,
+                            size_t size, int force) {
+  StringRef key(name, name_len);
+  auto value =
+      Value::MakeBooleanArray(llvm::makeArrayRef(arr, size));
+  if (force == 0) {
+    // force == 0: report whether nt::SetEntryValue accepted the value.
+    return nt::SetEntryValue(key, value);
+  }
+  // force != 0: nt::SetEntryTypeValue, and success is always reported.
+  nt::SetEntryTypeValue(key, value);
+  return 1;
+}
+
+int NT_SetEntryDoubleArray(const char *name, size_t name_len, const double *arr,
+                           size_t size, int force) {
+  StringRef key(name, name_len);
+  auto value =
+      Value::MakeDoubleArray(llvm::makeArrayRef(arr, size));
+  if (force == 0) {
+    // force == 0: report whether nt::SetEntryValue accepted the value.
+    return nt::SetEntryValue(key, value);
+  }
+  // force != 0: nt::SetEntryTypeValue, and success is always reported.
+  nt::SetEntryTypeValue(key, value);
+  return 1;
+}
+
+int NT_SetEntryStringArray(const char *name, size_t name_len,
+                           const struct NT_String *arr, size_t size,
+                           int force) {
+  // Copy the C strings into owned std::strings first.
+  std::vector<std::string> strings;
+  strings.reserve(size);
+  for (size_t i = 0; i < size; ++i) strings.push_back(ConvertFromC(arr[i]));
+
+  StringRef key(name, name_len);
+  auto value = Value::MakeStringArray(std::move(strings));
+  if (force == 0) {
+    return nt::SetEntryValue(key, value);  // report the non-forced result
+  }
+  nt::SetEntryTypeValue(key, value);  // forced write; always reports 1
+  return 1;
+}
+
+enum NT_Type NT_GetValueType(const struct NT_Value *value) {
+  // A null pointer is reported as unassigned rather than dereferenced.
+  return value ? value->type : NT_Type::NT_UNASSIGNED;
+}
+
+int NT_GetValueBoolean(const struct NT_Value *value,
+                       unsigned long long *last_change, int *v_boolean) {
+  if (value == nullptr || value->type != NT_Type::NT_BOOLEAN) return 0;
+  *last_change = value->last_change;
+  *v_boolean = value->data.v_boolean;
+  return 1;  // outputs are written only on success
+}
+
+int NT_GetValueDouble(const struct NT_Value *value,
+                      unsigned long long *last_change, double *v_double) {
+  if (value == nullptr || value->type != NT_Type::NT_DOUBLE) return 0;
+  *v_double = value->data.v_double;
+  *last_change = value->last_change;
+  return 1;  // outputs are written only on success
+}
+
+char *NT_GetValueString(const struct NT_Value *value,
+                        unsigned long long *last_change, size_t *str_len) {
+  if (value == nullptr || value->type != NT_Type::NT_STRING) return nullptr;
+  *last_change = value->last_change;
+  *str_len = value->data.v_string.len;
+  char *copy = static_cast<char *>(std::malloc(value->data.v_string.len + 1));
+  std::memcpy(copy, value->data.v_string.str, value->data.v_string.len + 1);
+  return copy;  // len + 1 bytes copied: assumes a NUL-terminated source
+}
+
+char *NT_GetValueRaw(const struct NT_Value *value,
+                     unsigned long long *last_change, size_t *raw_len) {
+  if (value == nullptr || value->type != NT_Type::NT_RAW) return nullptr;
+  *last_change = value->last_change;
+  *raw_len = value->data.v_string.len;
+  char *copy = static_cast<char *>(std::malloc(value->data.v_string.len + 1));
+  std::memcpy(copy, value->data.v_string.str, value->data.v_string.len + 1);
+  return copy;  // len + 1 bytes copied: assumes a NUL-terminated source
+}
+
+int *NT_GetValueBooleanArray(const struct NT_Value *value,
+                             unsigned long long *last_change,
+                             size_t *arr_size) {
+  if (!value || value->type != NT_Type::NT_BOOLEAN_ARRAY) return nullptr;
+  *last_change = value->last_change;
+  *arr_size = value->data.arr_boolean.size;
+  const size_t bytes = value->data.arr_boolean.size * sizeof(int);
+  int *arr = static_cast<int *>(std::malloc(bytes));  // caller owns the copy
+  std::memcpy(arr, value->data.arr_boolean.arr, bytes);
+  return arr;
+}
+
+double *NT_GetValueDoubleArray(const struct NT_Value *value,
+                               unsigned long long *last_change,
+                               size_t *arr_size) {
+  if (!value || value->type != NT_Type::NT_DOUBLE_ARRAY) return nullptr;
+  *last_change = value->last_change;
+  *arr_size = value->data.arr_double.size;
+  const size_t bytes = value->data.arr_double.size * sizeof(double);
+  // The returned array is a malloc'd copy owned by the caller.
+  double *arr = static_cast<double *>(std::malloc(bytes));
+  std::memcpy(arr, value->data.arr_double.arr, bytes);
+  return arr;
+}
+
+NT_String *NT_GetValueStringArray(const struct NT_Value *value,
+                                  unsigned long long *last_change,
+                                  size_t *arr_size) {
+  if (!value || value->type != NT_Type::NT_STRING_ARRAY) return nullptr;
+  *last_change = value->last_change;
+  *arr_size = value->data.arr_string.size;
+  NT_String *arr = static_cast<NT_String *>(  // deep copy, element by element
+      std::malloc(value->data.arr_string.size * sizeof(NT_String)));
+  for (size_t i = 0; i < value->data.arr_string.size; ++i) {
+    const NT_String &src = value->data.arr_string.arr[i];
+    arr[i].len = src.len;
+    arr[i].str = static_cast<char *>(std::malloc(src.len + 1));
+    std::memcpy(arr[i].str, src.str, src.len + 1);
+  }
+  return arr;
+}
+
+int NT_GetEntryBoolean(const char *name, size_t name_len,
+                       unsigned long long *last_change, int *v_boolean) {
+  auto entry_value = nt::GetEntryValue(StringRef(name, name_len));
+  if (!entry_value || !entry_value->IsBoolean()) return 0;
+  *last_change = entry_value->last_change();
+  *v_boolean = entry_value->GetBoolean();
+  return 1;
+}
+
+int NT_GetEntryDouble(const char *name, size_t name_len,
+                      unsigned long long *last_change, double *v_double) {
+  auto entry_value = nt::GetEntryValue(StringRef(name, name_len));
+  if (!entry_value || !entry_value->IsDouble()) return 0;
+  *v_double = entry_value->GetDouble();
+  *last_change = entry_value->last_change();
+  return 1;
+}
+
+char *NT_GetEntryString(const char *name, size_t name_len,
+                        unsigned long long *last_change, size_t *str_len) {
+  auto entry_value = nt::GetEntryValue(StringRef(name, name_len));
+  if (!entry_value || !entry_value->IsString()) return nullptr;
+  *last_change = entry_value->last_change();
+  struct NT_String copy;  // its buffer is handed to the caller
+  nt::ConvertToC(entry_value->GetString(), &copy);
+  *str_len = copy.len;
+  return copy.str;
+}
+
+char *NT_GetEntryRaw(const char *name, size_t name_len,
+                     unsigned long long *last_change, size_t *raw_len) {
+  auto entry_value = nt::GetEntryValue(StringRef(name, name_len));
+  if (!entry_value || !entry_value->IsRaw()) return nullptr;
+  *last_change = entry_value->last_change();
+  struct NT_String copy;  // its buffer is handed to the caller
+  nt::ConvertToC(entry_value->GetRaw(), &copy);
+  *raw_len = copy.len;
+  return copy.str;
+}
+
+int *NT_GetEntryBooleanArray(const char *name, size_t name_len,
+                             unsigned long long *last_change,
+                             size_t *arr_size) {
+  auto entry_value = nt::GetEntryValue(StringRef(name, name_len));
+  if (!entry_value || !entry_value->IsBooleanArray()) return nullptr;
+  *last_change = entry_value->last_change();
+  auto elems = entry_value->GetBooleanArray();
+  *arr_size = elems.size();
+  int *arr = static_cast<int *>(std::malloc(elems.size() * sizeof(int)));
+  std::copy(elems.begin(), elems.end(), arr);
+  return arr;
+}
+
+double *NT_GetEntryDoubleArray(const char *name, size_t name_len,
+                               unsigned long long *last_change,
+                               size_t *arr_size) {
+  auto entry_value = nt::GetEntryValue(StringRef(name, name_len));
+  if (!entry_value || !entry_value->IsDoubleArray()) return nullptr;
+  *last_change = entry_value->last_change();
+  auto elems = entry_value->GetDoubleArray();
+  *arr_size = elems.size();
+  double *arr =
+      static_cast<double *>(std::malloc(elems.size() * sizeof(double)));
+  std::copy(elems.begin(), elems.end(), arr);
+  return arr;
+}
+
+NT_String *NT_GetEntryStringArray(const char *name, size_t name_len,
+                                  unsigned long long *last_change,
+                                  size_t *arr_size) {
+  auto entry_value = nt::GetEntryValue(StringRef(name, name_len));
+  if (!entry_value || !entry_value->IsStringArray()) return nullptr;
+  *last_change = entry_value->last_change();
+  auto elems = entry_value->GetStringArray();
+  *arr_size = elems.size();
+  NT_String *arr =
+      static_cast<NT_String *>(std::malloc(elems.size() * sizeof(NT_String)));
+  for (size_t i = 0; i < elems.size(); ++i) {
+    ConvertToC(elems[i], &arr[i]);  // presumably a malloc'd copy per element
+  }
+  return arr;
+}
+
+} // extern "C"
diff --git a/src/ntcore_cpp.cpp b/src/ntcore_cpp.cpp
new file mode 100644
index 0000000..c50b3d4
--- /dev/null
+++ b/src/ntcore_cpp.cpp
@@ -0,0 +1,277 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "ntcore.h"
+
+#include <cassert>
+#include <cstdio>
+#include <cstdlib>
+
+#include "Dispatcher.h"
+#include "Log.h"
+#include "Notifier.h"
+#include "RpcServer.h"
+#include "Storage.h"
+#include "WireDecoder.h"
+#include "WireEncoder.h"
+
+namespace nt {
+
+/*
+ * Table Functions
+ */
+
+// Returns the current value of entry `name` from the Storage singleton.
+std::shared_ptr<Value> GetEntryValue(StringRef name) {
+  auto& storage = Storage::GetInstance();
+  return storage.GetEntryValue(name);
+}
+
+// Forwards to Storage::SetEntryValue and returns its result.
+bool SetEntryValue(StringRef name, std::shared_ptr<Value> value) {
+  auto& storage = Storage::GetInstance();
+  const bool ok = storage.SetEntryValue(name, value);
+  return ok;
+}
+
+// Forwards to Storage::SetEntryTypeValue for entry `name`.
+void SetEntryTypeValue(StringRef name, std::shared_ptr<Value> value) {
+  auto& storage = Storage::GetInstance();
+  storage.SetEntryTypeValue(name, value);
+}
+
+// Forwards to Storage::SetEntryFlags for entry `name`.
+void SetEntryFlags(StringRef name, unsigned int flags) {
+  auto& storage = Storage::GetInstance();
+  storage.SetEntryFlags(name, flags);
+}
+
+// Returns the flags Storage holds for entry `name`.
+unsigned int GetEntryFlags(StringRef name) {
+  auto& storage = Storage::GetInstance();
+  return storage.GetEntryFlags(name);
+}
+
+// Forwards to Storage::DeleteEntry for entry `name`.
+void DeleteEntry(StringRef name) {
+  auto& storage = Storage::GetInstance();
+  storage.DeleteEntry(name);
+}
+
+// Forwards to Storage::DeleteAllEntries.
+void DeleteAllEntries() {
+  auto& storage = Storage::GetInstance();
+  storage.DeleteAllEntries();
+}
+
+// Returns Storage's EntryInfo list filtered by `prefix` and `types`.
+std::vector<EntryInfo> GetEntryInfo(StringRef prefix, unsigned int types) {
+  auto& storage = Storage::GetInstance();
+  return storage.GetEntryInfo(prefix, types);
+}
+
+// Forwards to Dispatcher::Flush.
+void Flush() {
+  auto& dispatcher = Dispatcher::GetInstance();
+  dispatcher.Flush();
+}
+
+/*
+ * Callback Creation Functions
+ */
+
+// Registers `on_start` with the Notifier singleton via SetOnStart.
+void SetListenerOnStart(std::function<void()> on_start) {
+  auto& notifier = Notifier::GetInstance();
+  notifier.SetOnStart(on_start);
+}
+
+// Registers `on_exit` with the Notifier singleton via SetOnExit.
+void SetListenerOnExit(std::function<void()> on_exit) {
+  auto& notifier = Notifier::GetInstance();
+  notifier.SetOnExit(on_exit);
+}
+
+// Registers `callback` for entries under `prefix`, starts the Notifier, and
+// returns the listener uid. With NT_NOTIFY_IMMEDIATE set, Storage is also
+// asked to NotifyEntries for `prefix` right away (presumably delivering the
+// entries that already exist — confirm in Storage).
+unsigned int AddEntryListener(StringRef prefix, EntryListenerCallback callback,
+                              unsigned int flags) {
+  auto& notifier = Notifier::GetInstance();
+  const unsigned int uid = notifier.AddEntryListener(prefix, callback, flags);
+  notifier.Start();
+  const bool immediate = (flags & NT_NOTIFY_IMMEDIATE) != 0;
+  if (immediate) {
+    Storage::GetInstance().NotifyEntries(prefix, callback);
+  }
+  return uid;
+}
+
+// Unregisters the entry listener identified by `entry_listener_uid`.
+void RemoveEntryListener(unsigned int entry_listener_uid) {
+  auto& notifier = Notifier::GetInstance();
+  notifier.RemoveEntryListener(entry_listener_uid);
+}
+
+// Registers `callback` for connection events, starts the Notifier, and
+// returns the listener uid. If `immediate_notify` is set, the Dispatcher is
+// asked to NotifyConnections for this callback right away.
+unsigned int AddConnectionListener(ConnectionListenerCallback callback,
+                                   bool immediate_notify) {
+  auto& notifier = Notifier::GetInstance();
+  const unsigned int uid = notifier.AddConnectionListener(callback);
+  notifier.Start();
+  if (immediate_notify) {
+    Dispatcher::GetInstance().NotifyConnections(callback);
+  }
+  return uid;
+}
+
+// Unregisters the connection listener identified by `conn_listener_uid`.
+void RemoveConnectionListener(unsigned int conn_listener_uid) {
+  auto& notifier = Notifier::GetInstance();
+  notifier.RemoveConnectionListener(conn_listener_uid);
+}
+
+// Reports Notifier::destroyed().
+bool NotifierDestroyed() {
+  const bool gone = Notifier::destroyed();
+  return gone;
+}
+
+/*
+ * Remote Procedure Call Functions
+ */
+
+// Forwards to Storage::CreateRpc with the definition and callback.
+void CreateRpc(StringRef name, StringRef def, RpcCallback callback) {
+  auto& storage = Storage::GetInstance();
+  storage.CreateRpc(name, def, callback);
+}
+
+// Forwards to Storage::CreatePolledRpc.
+void CreatePolledRpc(StringRef name, StringRef def) {
+  auto& storage = Storage::GetInstance();
+  storage.CreatePolledRpc(name, def);
+}
+
+// Forwards to RpcServer::PollRpc and returns its result.
+bool PollRpc(bool blocking, RpcCallInfo* call_info) {
+  auto& server = RpcServer::GetInstance();
+  return server.PollRpc(blocking, call_info);
+}
+
+// Forwards a response for (rpc_id, call_uid) to RpcServer::PostRpcResponse.
+void PostRpcResponse(unsigned int rpc_id, unsigned int call_uid,
+                     StringRef result) {
+  auto& server = RpcServer::GetInstance();
+  server.PostRpcResponse(rpc_id, call_uid, result);
+}
+
+// Forwards to Storage::CallRpc and returns its call uid.
+unsigned int CallRpc(StringRef name, StringRef params) {
+  auto& storage = Storage::GetInstance();
+  return storage.CallRpc(name, params);
+}
+
+// Forwards to Storage::GetRpcResult and returns its result.
+bool GetRpcResult(bool blocking, unsigned int call_uid, std::string* result) {
+  auto& storage = Storage::GetInstance();
+  return storage.GetRpcResult(blocking, call_uid, result);
+}
+
+// Serializes `def` with a WireEncoder at protocol revision 0x0300:
+// version byte, name, parameter count and per-parameter (type, name, default
+// value), then result count and per-result (type, name). Both counts are
+// written as single bytes, so entries beyond the 255th are dropped.
+// Every parameter's def_value is dereferenced and must be non-null.
+std::string PackRpcDefinition(const RpcDefinition& def) {
+  WireEncoder enc(0x0300);
+  enc.Write8(def.version);
+  enc.WriteString(def.name);
+
+  unsigned int num_params = def.params.size();
+  num_params = num_params > 0xff ? 0xff : num_params;
+  enc.Write8(num_params);
+  for (std::size_t idx = 0; idx != num_params; ++idx) {
+    const auto& param = def.params[idx];
+    enc.WriteType(param.def_value->type());
+    enc.WriteString(param.name);
+    enc.WriteValue(*param.def_value);
+  }
+
+  unsigned int num_results = def.results.size();
+  num_results = num_results > 0xff ? 0xff : num_results;
+  enc.Write8(num_results);
+  for (std::size_t idx = 0; idx != num_results; ++idx) {
+    const auto& result = def.results[idx];
+    enc.WriteType(result.type);
+    enc.WriteString(result.name);
+  }
+
+  return enc.ToStringRef();
+}
+
+// Parses a definition produced by PackRpcDefinition (protocol revision
+// 0x0300) into *def. Returns false on any decode failure; in that case *def
+// may be partially filled.
+bool UnpackRpcDefinition(StringRef packed, RpcDefinition* def) {
+  raw_mem_istream is(packed.data(), packed.size());
+  WireDecoder dec(is, 0x0300);
+  if (!dec.Read8(&def->version)) return false;
+  if (!dec.ReadString(&def->name)) return false;
+
+  // parameters: one count byte, then (type, name, default value) triples
+  unsigned int params_size;
+  if (!dec.Read8(&params_size)) return false;
+  def->params.clear();
+  def->params.reserve(params_size);
+  for (std::size_t i = 0; i < params_size; ++i) {
+    RpcParamDef pdef;
+    NT_Type type;
+    if (!dec.ReadType(&type)) return false;
+    if (!dec.ReadString(&pdef.name)) return false;
+    pdef.def_value = dec.ReadValue(type);
+    if (!pdef.def_value) return false;
+    def->params.emplace_back(std::move(pdef));
+  }
+
+  // results: one count byte, then (type, name) pairs
+  unsigned int results_size;
+  if (!dec.Read8(&results_size)) return false;
+  def->results.clear();
+  def->results.reserve(results_size);
+  for (std::size_t i = 0; i < results_size; ++i) {
+    RpcResultDef rdef;
+    if (!dec.ReadType(&rdef.type)) return false;
+    if (!dec.ReadString(&rdef.name)) return false;
+    def->results.emplace_back(std::move(rdef));
+  }
+
+  return true;
+}
+
+// Concatenates the wire encoding (protocol revision 0x0300) of every value.
+// Each element is dereferenced, so none may be null.
+std::string PackRpcValues(ArrayRef<std::shared_ptr<Value>> values) {
+  WireEncoder enc(0x0300);
+  for (const auto& value : values) {
+    enc.WriteValue(*value);
+  }
+  return enc.ToStringRef();
+}
+
+// Decodes one value per entry of `types` from `packed`. Returns an empty
+// vector if any value fails to decode.
+std::vector<std::shared_ptr<Value>> UnpackRpcValues(StringRef packed,
+                                                    ArrayRef<NT_Type> types) {
+  raw_mem_istream is(packed.data(), packed.size());
+  WireDecoder dec(is, 0x0300);
+  std::vector<std::shared_ptr<Value>> decoded;
+  decoded.reserve(types.size());
+  for (std::size_t idx = 0; idx < types.size(); ++idx) {
+    auto item = dec.ReadValue(types[idx]);
+    if (!item) return {};
+    decoded.emplace_back(std::move(item));
+  }
+  return decoded;
+}
+
+/*
+ * Client/Server Functions
+ */
+
+// Forwards `name` to Dispatcher::SetIdentity.
+void SetNetworkIdentity(StringRef name) {
+  auto& dispatcher = Dispatcher::GetInstance();
+  dispatcher.SetIdentity(name);
+}
+
+// Forwards to Dispatcher::StartServer with the persistence file, listen
+// address and port.
+void StartServer(StringRef persist_filename, const char *listen_address,
+                 unsigned int port) {
+  auto& dispatcher = Dispatcher::GetInstance();
+  dispatcher.StartServer(persist_filename, listen_address, port);
+}
+
+// Stops the Dispatcher (the same call StopClient makes).
+void StopServer() {
+  auto& dispatcher = Dispatcher::GetInstance();
+  dispatcher.Stop();
+}
+
+// Forwards to Dispatcher::StartClient with the server name and port.
+void StartClient(const char *server_name, unsigned int port) {
+  auto& dispatcher = Dispatcher::GetInstance();
+  dispatcher.StartClient(server_name, port);
+}
+
+// Stops the Dispatcher (the same call StopServer makes).
+void StopClient() {
+  auto& dispatcher = Dispatcher::GetInstance();
+  dispatcher.Stop();
+}
+
+// Forwards to RpcServer::Stop.
+void StopRpcServer() {
+  auto& server = RpcServer::GetInstance();
+  server.Stop();
+}
+
+// Forwards to Notifier::Stop.
+void StopNotifier() {
+  auto& notifier = Notifier::GetInstance();
+  notifier.Stop();
+}
+
+// Forwards `interval` to Dispatcher::SetUpdateRate (units defined there).
+void SetUpdateRate(double interval) {
+  auto& dispatcher = Dispatcher::GetInstance();
+  dispatcher.SetUpdateRate(interval);
+}
+
+// Returns the Dispatcher's current list of ConnectionInfo.
+std::vector<ConnectionInfo> GetConnections() {
+  auto& dispatcher = Dispatcher::GetInstance();
+  return dispatcher.GetConnections();
+}
+
+/*
+ * Persistent Functions
+ */
+
+// Forwards to Storage::SavePersistent with a second argument of false
+// (presumably "not a periodic save" — confirm in Storage). Returns Storage's
+// result string pointer.
+const char* SavePersistent(StringRef filename) {
+  auto& storage = Storage::GetInstance();
+  return storage.SavePersistent(filename, false);
+}
+
+// Forwards to Storage::LoadPersistent; `warn` receives (line, message) pairs
+// from the loader. Returns Storage's result string pointer.
+const char* LoadPersistent(
+    StringRef filename,
+    std::function<void(size_t line, const char* msg)> warn) {
+  auto& storage = Storage::GetInstance();
+  return storage.LoadPersistent(filename, warn);
+}
+
+// Installs `func` as the Logger callback and sets the minimum level passed
+// to Logger::set_min_level.
+void SetLogger(LogFunc func, unsigned int min_level) {
+  auto& logger = Logger::GetInstance();
+  logger.SetLogger(func);
+  logger.set_min_level(min_level);
+}
+
+} // namespace nt
diff --git a/src/raw_istream.cpp b/src/raw_istream.cpp
new file mode 100644
index 0000000..f300b9e
--- /dev/null
+++ b/src/raw_istream.cpp
@@ -0,0 +1,20 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "raw_istream.h"
+
+#include <cstring>
+
+using namespace nt;
+
+// All-or-nothing read: copies exactly `len` bytes and advances, or returns
+// false without consuming anything when fewer than `len` bytes remain.
+bool raw_mem_istream::read(void* data, std::size_t len) {
+  if (m_left < len) return false;
+  std::memcpy(data, m_cur, len);
+  m_left -= len;
+  m_cur += len;
+  return true;
+}
diff --git a/src/raw_istream.h b/src/raw_istream.h
new file mode 100644
index 0000000..f8adc23
--- /dev/null
+++ b/src/raw_istream.h
@@ -0,0 +1,40 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef NT_RAW_ISTREAM_H_
+#define NT_RAW_ISTREAM_H_
+
+#include <cstddef>
+
+namespace nt {
+
+// Abstract, non-copyable byte source. read() returns false when the request
+// cannot be satisfied; close() releases the source.
+class raw_istream {
+ public:
+  raw_istream() = default;
+  raw_istream(const raw_istream&) = delete;
+  raw_istream& operator=(const raw_istream&) = delete;
+  virtual ~raw_istream() = default;
+
+  virtual bool read(void* data, std::size_t len) = 0;
+  virtual void close() = 0;
+};
+
+// raw_istream over an in-memory buffer. The buffer is referenced, not
+// copied, so it must outlive the stream.
+class raw_mem_istream : public raw_istream {
+ public:
+  raw_mem_istream(const char* mem, std::size_t len) : m_cur(mem), m_left(len) {}
+  ~raw_mem_istream() override = default;
+
+  bool read(void* data, std::size_t len) override;
+  void close() override {}
+
+ private:
+  const char* m_cur;   // next unread byte
+  std::size_t m_left;  // bytes remaining from m_cur
+};
+
+} // namespace nt
+
+#endif // NT_RAW_ISTREAM_H_
diff --git a/src/raw_socket_istream.cpp b/src/raw_socket_istream.cpp
new file mode 100644
index 0000000..a8e71c5
--- /dev/null
+++ b/src/raw_socket_istream.cpp
@@ -0,0 +1,26 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "raw_socket_istream.h"
+
+using namespace nt;
+
+// Loops on receive() until exactly `len` bytes have arrived. Returns false
+// as soon as receive() yields 0 bytes (its error detail is discarded).
+bool raw_socket_istream::read(void* data, std::size_t len) {
+  char* out = static_cast<char*>(data);
+  for (std::size_t pos = 0; pos < len;) {
+    NetworkStream::Error err;
+    const std::size_t got =
+        m_stream.receive(out + pos, len - pos, &err, m_timeout);
+    if (got == 0) return false;
+    pos += got;
+  }
+  return true;
+}
+
+// Closes the referenced NetworkStream itself.
+void raw_socket_istream::close() {
+  m_stream.close();
+}
diff --git a/src/raw_socket_istream.h b/src/raw_socket_istream.h
new file mode 100644
index 0000000..91bcc1f
--- /dev/null
+++ b/src/raw_socket_istream.h
@@ -0,0 +1,32 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef NT_RAW_SOCKET_ISTREAM_H_
+#define NT_RAW_SOCKET_ISTREAM_H_
+
+#include "raw_istream.h"
+
+#include "tcpsockets/NetworkStream.h"
+
+namespace nt {
+
+// raw_istream that reads from a referenced NetworkStream; `timeout` is
+// passed to every receive() call.
+class raw_socket_istream : public raw_istream {
+ public:
+  raw_socket_istream(NetworkStream& stream, int timeout = 0)
+      : m_stream(stream), m_timeout(timeout) {}
+  ~raw_socket_istream() override = default;
+
+  bool read(void* data, std::size_t len) override;
+  void close() override;
+
+ private:
+  NetworkStream& m_stream;
+  int m_timeout;
+};
+
+} // namespace nt
+
+#endif // NT_RAW_SOCKET_ISTREAM_H_
diff --git a/src/support/ConcurrentQueue.h b/src/support/ConcurrentQueue.h
new file mode 100644
index 0000000..fa99477
--- /dev/null
+++ b/src/support/ConcurrentQueue.h
@@ -0,0 +1,79 @@
+//
+// Copyright (c) 2013 Juan Palacios juan.palacios.puyana@gmail.com
+// Subject to the BSD 2-Clause License
+// - see < http://opensource.org/licenses/BSD-2-Clause>
+//
+
+#ifndef NT_SUPPORT_CONCURRENT_QUEUE_H_
+#define NT_SUPPORT_CONCURRENT_QUEUE_H_
+
+#include <queue>
+#include <thread>
+#include <mutex>
+#include <condition_variable>
+
+// Mutex-protected FIFO queue. pop() blocks until an item is available;
+// push()/emplace() wake one waiting popper after releasing the lock.
+// empty() and size() are snapshots that may be stale as soon as they return.
+template <typename T>
+class ConcurrentQueue {
+ public:
+  bool empty() const {
+    std::unique_lock<std::mutex> mlock(mutex_);
+    return queue_.empty();
+  }
+
+  typename std::queue<T>::size_type size() const {
+    std::unique_lock<std::mutex> mlock(mutex_);
+    return queue_.size();
+  }
+
+  // Blocks until an item is available, then removes and returns it.
+  T pop() {
+    std::unique_lock<std::mutex> mlock(mutex_);
+    while (queue_.empty()) {
+      cond_.wait(mlock);
+    }
+    auto item = std::move(queue_.front());
+    queue_.pop();
+    return item;
+  }
+
+  // Blocks until an item is available, then moves it into `item`. Moving
+  // (not copying) lets this work with move-only T such as unique_ptr.
+  void pop(T& item) {
+    std::unique_lock<std::mutex> mlock(mutex_);
+    while (queue_.empty()) {
+      cond_.wait(mlock);
+    }
+    item = std::move(queue_.front());
+    queue_.pop();
+  }
+
+  void push(const T& item) {
+    std::unique_lock<std::mutex> mlock(mutex_);
+    queue_.push(item);
+    mlock.unlock();
+    cond_.notify_one();
+  }
+
+  // T is the class parameter, not deduced, so this is a plain rvalue
+  // reference: std::move is the correct cast.
+  void push(T&& item) {
+    std::unique_lock<std::mutex> mlock(mutex_);
+    queue_.push(std::move(item));
+    mlock.unlock();
+    cond_.notify_one();
+  }
+
+  template <typename... Args>
+  void emplace(Args&&... args) {
+    std::unique_lock<std::mutex> mlock(mutex_);
+    queue_.emplace(std::forward<Args>(args)...);
+    mlock.unlock();
+    cond_.notify_one();
+  }
+
+  ConcurrentQueue() = default;
+  ConcurrentQueue(const ConcurrentQueue&) = delete;
+  ConcurrentQueue& operator=(const ConcurrentQueue&) = delete;
+
+ private:
+  std::queue<T> queue_;
+  mutable std::mutex mutex_;
+  std::condition_variable cond_;
+};
+
+#endif // NT_SUPPORT_CONCURRENT_QUEUE_H_
diff --git a/src/support/timestamp.cpp b/src/support/timestamp.cpp
new file mode 100644
index 0000000..6dd4387
--- /dev/null
+++ b/src/support/timestamp.cpp
@@ -0,0 +1,89 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+#include "timestamp.h"
+
+#ifdef _WIN32
+#include <cassert>
+#include <exception>
+#include <windows.h>
+#else
+#include <chrono>
+#endif
+
+// Current wall-clock time in 100-nanosecond intervals since the Unix epoch
+// (January 1, 1970 UTC). Sampled once at startup as the base for nt::Now().
+static unsigned long long zerotime() {
+#ifdef _WIN32
+  FILETIME ft;
+  unsigned long long tmpres = 0;
+  // 100-nanosecond intervals since January 1, 1601 (UTC)
+  // which means 0.1 us
+  GetSystemTimeAsFileTime(&ft);
+  tmpres |= ft.dwHighDateTime;
+  tmpres <<= 32;
+  tmpres |= ft.dwLowDateTime;
+  // January 1st, 1970 - January 1st, 1601 UTC ~ 369 years
+  // or 116444736000000000 100-ns intervals (11644473600 s * 10^7)
+  static const unsigned long long deltaepoch = 116444736000000000ull;
+  tmpres -= deltaepoch;
+  return tmpres;
+#else
+  // 100-ns intervals since the Unix epoch. system_clock is the standard clock
+  // whose epoch is the Unix epoch; high_resolution_clock may be an alias of
+  // steady_clock (e.g. libc++), whose epoch is unspecified.
+  using namespace std::chrono;
+  return duration_cast<nanoseconds>(
+             system_clock::now().time_since_epoch()).count() / 100u;
+#endif
+}
+
+// Monotonic tick source. On POSIX the ticks are 100-ns intervals of
+// steady_clock; on Windows they are raw QueryPerformanceCounter counts, which
+// nt::Now() rescales by the counter frequency.
+static unsigned long long timestamp() {
+#ifndef _WIN32
+  const auto since_epoch = std::chrono::steady_clock::now().time_since_epoch();
+  const auto ns =
+      std::chrono::duration_cast<std::chrono::nanoseconds>(since_epoch).count();
+  return ns / 100u;
+#else
+  LARGE_INTEGER counter;
+  QueryPerformanceCounter(&counter);
+  // the absolute starting value is arbitrary; only differences are used
+  return static_cast<unsigned long long>(counter.QuadPart);
+#endif
+}
+
+#ifdef _WIN32
+// Performance-counter frequency in counts per second. Terminates the process
+// if the counter is unavailable or reports a zero frequency.
+static unsigned long long update_frequency() {
+  LARGE_INTEGER freq;
+  const bool queried = QueryPerformanceFrequency(&freq) != 0;
+  if (queried && freq.QuadPart != 0) {
+    return static_cast<unsigned long long>(freq.QuadPart);
+  }
+  // log something
+  std::terminate();
+}
+#endif
+
+// Time bases captured once during static initialization, in definition order:
+// zerotime_val is the startup time base from zerotime(), offset_val is the
+// timestamp() tick count at (about) the same moment. nt::Now() adds the
+// ticks elapsed since offset_val to zerotime_val.
+// NOTE(review): code in other translation units that calls nt::Now() during
+// its own static initialization may observe these before they are set.
+static const unsigned long long zerotime_val = zerotime();
+static const unsigned long long offset_val = timestamp();
+#ifdef _WIN32
+// Performance-counter counts per second, used to rescale timestamp() deltas.
+static const unsigned long long frequency_val = update_frequency();
+#endif
+
+// Current time in 100-ns intervals: the base sampled at startup
+// (zerotime_val) plus the monotonic time elapsed since then.
+unsigned long long nt::Now() {
+#ifdef _WIN32
+  assert(offset_val > 0u);
+  assert(frequency_val > 0u);
+  unsigned long long delta = timestamp() - offset_val;
+  // frequency_val is counts per second; convert counts to 100-ns units
+  // (10,000,000 per second). Splitting into whole seconds plus remainder
+  // keeps the multiplication in range: delta * 10^7 directly overflows
+  // 64 bits once delta exceeds ~1.8e12 counts (about 51 hours at 10 MHz).
+  const unsigned long long kUnitsPerSecond = 10000000ull;
+  unsigned long long whole_seconds = delta / frequency_val;
+  unsigned long long remainder_counts = delta % frequency_val;
+  unsigned long long delta_100ns =
+      whole_seconds * kUnitsPerSecond +
+      remainder_counts * kUnitsPerSecond / frequency_val;
+  return delta_100ns + zerotime_val;
+#else
+  return zerotime_val + timestamp() - offset_val;
+#endif
+}
+
+// C-linkage entry point returning nt::Now().
+unsigned long long NT_Now() {
+  const unsigned long long now = nt::Now();
+  return now;
+}
diff --git a/src/support/timestamp.h b/src/support/timestamp.h
new file mode 100644
index 0000000..215aa88
--- /dev/null
+++ b/src/support/timestamp.h
@@ -0,0 +1,28 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+#ifndef NT_SUPPORT_TIMESTAMP_H_
+#define NT_SUPPORT_TIMESTAMP_H_
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* C-callable wrapper that returns nt::Now(). */
+unsigned long long NT_Now(void);
+
+#ifdef __cplusplus
+}
+#endif
+
+#ifdef __cplusplus
+namespace nt {
+
+// Returns a timestamp in 100-nanosecond units: a time base sampled at program
+// start plus the monotonic time elapsed since then.
+unsigned long long Now();
+
+} // namespace nt
+#endif
+
+#endif // NT_SUPPORT_TIMESTAMP_H_
diff --git a/src/tables/ITableListener.cpp b/src/tables/ITableListener.cpp
new file mode 100644
index 0000000..df1f273
--- /dev/null
+++ b/src/tables/ITableListener.cpp
@@ -0,0 +1,9 @@
+#include "tables/ITableListener.h"
+
+#include "ntcore_c.h"
+
+// Adapter from the flag-based notification to the legacy ValueChanged()
+// callback, which only receives whether NT_NOTIFY_NEW was set.
+void ITableListener::ValueChangedEx(ITable* source, llvm::StringRef key,
+                                    std::shared_ptr<nt::Value> value,
+                                    unsigned int flags) {
+  const bool is_new = (flags & NT_NOTIFY_NEW) != 0;
+  ValueChanged(source, key, value, is_new);
+}
diff --git a/src/tables/TableKeyNotDefinedException.cpp b/src/tables/TableKeyNotDefinedException.cpp
new file mode 100644
index 0000000..30bedd0
--- /dev/null
+++ b/src/tables/TableKeyNotDefinedException.cpp
@@ -0,0 +1,19 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "tables/TableKeyNotDefinedException.h"
+
+// Builds the message "Unknown Table Key: <key>".
+TableKeyNotDefinedException::TableKeyNotDefinedException(llvm::StringRef key)
+    : msg("Unknown Table Key: ") {
+  msg.append(key.data(), key.size());
+}
+
+// Returns the stored message; the pointer is valid while the exception lives.
+const char* TableKeyNotDefinedException::what() const NT_NOEXCEPT {
+  const std::string& text = msg;
+  return text.c_str();
+}
+
+// Nothing to release by hand; the msg member cleans itself up.
+TableKeyNotDefinedException::~TableKeyNotDefinedException() NT_NOEXCEPT {
+}
diff --git a/src/tcpsockets/NetworkAcceptor.h b/src/tcpsockets/NetworkAcceptor.h
new file mode 100644
index 0000000..4702b7f
--- /dev/null
+++ b/src/tcpsockets/NetworkAcceptor.h
@@ -0,0 +1,26 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef TCPSOCKETS_NETWORKACCEPTOR_H_
+#define TCPSOCKETS_NETWORKACCEPTOR_H_
+
+#include <memory>
+
+#include "NetworkStream.h"
+
+// Abstract, non-copyable source of incoming NetworkStream connections.
+// (This header uses std::unique_ptr and therefore needs <memory>.)
+class NetworkAcceptor {
+ public:
+  NetworkAcceptor() = default;
+  virtual ~NetworkAcceptor() = default;
+
+  // Begins listening. TCPAcceptor returns 0 on success and nonzero on error.
+  virtual int start() = 0;
+  // Stops accepting; TCPAcceptor also uses it to wake a blocked accept().
+  virtual void shutdown() = 0;
+  // Returns the next connection, or nullptr (TCPAcceptor: on error or after
+  // shutdown()).
+  virtual std::unique_ptr<NetworkStream> accept() = 0;
+
+  NetworkAcceptor(const NetworkAcceptor&) = delete;
+  NetworkAcceptor& operator=(const NetworkAcceptor&) = delete;
+};
+
+#endif // TCPSOCKETS_NETWORKACCEPTOR_H_
diff --git a/src/tcpsockets/NetworkStream.h b/src/tcpsockets/NetworkStream.h
new file mode 100644
index 0000000..63aedb4
--- /dev/null
+++ b/src/tcpsockets/NetworkStream.h
@@ -0,0 +1,39 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef TCPSOCKETS_NETWORKSTREAM_H_
+#define TCPSOCKETS_NETWORKSTREAM_H_
+
+#include <cstddef>
+
+#include "llvm/StringRef.h"
+
+// Abstract, non-copyable byte stream over a network connection.
+class NetworkStream {
+ public:
+  // Codes reported through the `err` out-parameter of send()/receive().
+  enum Error {
+    kConnectionClosed = 0,
+    kConnectionReset = -1,
+    kConnectionTimedOut = -2
+  };
+
+  NetworkStream() = default;
+  NetworkStream(const NetworkStream&) = delete;
+  NetworkStream& operator=(const NetworkStream&) = delete;
+  virtual ~NetworkStream() = default;
+
+  virtual std::size_t send(const char* buffer, std::size_t len, Error* err) = 0;
+  virtual std::size_t receive(char* buffer, std::size_t len, Error* err,
+                              int timeout = 0) = 0;
+  virtual void close() = 0;
+
+  virtual llvm::StringRef getPeerIP() const = 0;
+  virtual int getPeerPort() const = 0;
+  virtual void setNoDelay() = 0;
+};
+
+#endif // TCPSOCKETS_NETWORKSTREAM_H_
diff --git a/src/tcpsockets/SocketError.cpp b/src/tcpsockets/SocketError.cpp
new file mode 100644
index 0000000..9619edd
--- /dev/null
+++ b/src/tcpsockets/SocketError.cpp
@@ -0,0 +1,31 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#include "SocketError.h"
+
+#ifdef _WIN32
+#include <windows.h>
+#else
+#include <string.h>
+#endif
+
+namespace tcpsockets {
+
+std::string SocketStrerror(int code) {
+#ifdef _WIN32
+ LPSTR errstr = nullptr;
+ FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM,
+ 0, code, 0, (LPSTR)&errstr, 0, 0);
+ std::string rv(errstr);
+ LocalFree(errstr);
+ return rv;
+#else
+ return strerror(code);
+#endif
+}
+
+} // namespace tcpsockets
diff --git a/src/tcpsockets/SocketError.h b/src/tcpsockets/SocketError.h
new file mode 100644
index 0000000..267e8da
--- /dev/null
+++ b/src/tcpsockets/SocketError.h
@@ -0,0 +1,37 @@
+/*----------------------------------------------------------------------------*/
+/* Copyright (c) FIRST 2015. All Rights Reserved. */
+/* Open Source Software - may be modified and shared by FRC teams. The code */
+/* must be accompanied by the FIRST BSD license file in the root directory of */
+/* the project. */
+/*----------------------------------------------------------------------------*/
+
+#ifndef TCPSOCKETS_SOCKETERROR_H_
+#define TCPSOCKETS_SOCKETERROR_H_
+
+#include <string>
+
+#ifdef _WIN32
+#include <WinSock2.h>
+#else
+#include <errno.h>
+#endif
+
+namespace tcpsockets {
+
+static inline int SocketErrno() {
+#ifdef _WIN32
+ return WSAGetLastError();
+#else
+ return errno;
+#endif
+}
+
+std::string SocketStrerror(int code);
+
+static inline std::string SocketStrerror() {
+ return SocketStrerror(SocketErrno());
+}
+
+} // namespace tcpsockets
+
+#endif // TCPSOCKETS_SOCKETERROR_H_
diff --git a/src/tcpsockets/TCPAcceptor.cpp b/src/tcpsockets/TCPAcceptor.cpp
new file mode 100644
index 0000000..3b7e16e
--- /dev/null
+++ b/src/tcpsockets/TCPAcceptor.cpp
@@ -0,0 +1,182 @@
+/*
+ TCPAcceptor.cpp
+
+ TCPAcceptor class definition. TCPAcceptor provides methods to passively
+ establish TCP/IP connections with clients.
+
+ ------------------------------------------
+
+ Copyright © 2013 [Vic Hargrave - http://vichargrave.com]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+#include "TCPAcceptor.h"
+
+#include <cstdio>
+#include <cstring>
+#ifdef _WIN32
+#include <WinSock2.h>
+#pragma comment(lib, "Ws2_32.lib")
+#else
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#include <unistd.h>
+#include <fcntl.h>
+#endif
+
+#include "llvm/SmallString.h"
+#include "../Log.h"
+#include "SocketError.h"
+
+using namespace tcpsockets;
+
+// Records the port and optional listen address; no socket is created until
+// start(). A null or empty `address` means "listen on all interfaces".
+// m_lsd starts at -1 rather than 0 so that "no socket" can never alias
+// descriptor 0 (stdin); the destructor's m_lsd > 0 test skips it either way.
+TCPAcceptor::TCPAcceptor(int port, const char* address)
+    : m_lsd(-1),
+      m_port(port),
+      m_address(address ? address : ""),
+      m_listening(false),
+      m_shutdown(false) {
+#ifdef _WIN32
+  WSAData wsaData;
+  WORD wVersionRequested = MAKEWORD(2, 2);
+  WSAStartup(wVersionRequested, &wsaData);
+#endif
+}
+
+// Shuts down and closes the listening socket when one is held (m_lsd > 0),
+// then balances the constructor's WSAStartup on Windows.
+TCPAcceptor::~TCPAcceptor() {
+  const bool have_socket = m_lsd > 0;
+  if (have_socket) {
+    shutdown();
+#ifdef _WIN32
+    closesocket(m_lsd);
+#else
+    close(m_lsd);
+#endif
+  }
+#ifdef _WIN32
+  WSACleanup();
+#endif
+}
+
+// Creates, binds and listens on an IPv4 TCP socket for m_port (on m_address
+// if one was given, otherwise INADDR_ANY). Returns 0 on success (or if
+// already listening) and a nonzero value after logging an error.
+int TCPAcceptor::start() {
+  if (m_listening) return 0;
+
+  m_lsd = socket(PF_INET, SOCK_STREAM, 0);
+  if (m_lsd < 0) {
+    // Report the real cause instead of a misleading bind() failure later.
+    ERROR("socket() failed: " << SocketStrerror());
+    return -1;
+  }
+
+  struct sockaddr_in address;
+  std::memset(&address, 0, sizeof(address));
+  address.sin_family = PF_INET;
+  if (m_address.size() > 0) {
+#ifdef _WIN32
+    llvm::SmallString<128> addr_copy(m_address);
+    addr_copy.push_back('\0');
+    int size = sizeof(address);
+    bool parsed = WSAStringToAddress(addr_copy.data(), PF_INET, nullptr,
+                                     (struct sockaddr*)&address, &size) == 0;
+#else
+    bool parsed =
+        inet_pton(PF_INET, m_address.c_str(), &(address.sin_addr)) == 1;
+#endif
+    // An unparseable address must not silently fall back to the zeroed
+    // sin_addr (INADDR_ANY), which would listen on every interface.
+    if (!parsed) {
+      ERROR("invalid listen address: " << m_address);
+      return -1;
+    }
+  } else {
+    address.sin_addr.s_addr = INADDR_ANY;
+  }
+  address.sin_port = htons(m_port);
+
+  int optval = 1;
+  setsockopt(m_lsd, SOL_SOCKET, SO_REUSEADDR, (char*)&optval, sizeof optval);
+
+  int result = bind(m_lsd, (struct sockaddr*)&address, sizeof(address));
+  if (result != 0) {
+    ERROR("bind() failed: " << SocketStrerror());
+    return result;
+  }
+
+  result = listen(m_lsd, 5);
+  if (result != 0) {
+    ERROR("listen() failed: " << SocketStrerror());
+    return result;
+  }
+  m_listening = true;
+  return result;
+}
+
+// Marks the acceptor shut down and unblocks any thread waiting in accept().
+// If start() never succeeded there is no listening socket to disturb: m_lsd
+// may still be its initial sentinel (possibly descriptor 0, i.e. stdin), and
+// accept() already returns early when !m_listening or m_shutdown.
+void TCPAcceptor::shutdown() {
+  m_shutdown = true;
+  if (!m_listening) return;
+#ifdef _WIN32
+  ::shutdown(m_lsd, SD_BOTH);
+
+  // this is ugly, but the easiest way to do this
+  // force wakeup of accept() with a non-blocking connect to ourselves
+  struct sockaddr_in address;
+
+  std::memset(&address, 0, sizeof(address));
+  address.sin_family = PF_INET;
+  llvm::SmallString<128> addr_copy;
+  if (m_address.size() > 0)
+    addr_copy = m_address;
+  else
+    addr_copy = "127.0.0.1";
+  addr_copy.push_back('\0');
+  int size = sizeof(address);
+  if (WSAStringToAddress(addr_copy.data(), PF_INET, nullptr,
+                         (struct sockaddr*)&address, &size) != 0)
+    return;
+  address.sin_port = htons(m_port);
+
+  SOCKET sd = socket(AF_INET, SOCK_STREAM, 0);
+  if (sd == INVALID_SOCKET) return;
+
+  // Set socket to non-blocking
+  u_long mode = 1;
+  ioctlsocket(sd, FIONBIO, &mode);
+
+  // Try to connect
+  ::connect(sd, (struct sockaddr*)&address, sizeof(address));
+
+  // Close
+  ::closesocket(sd);
+
+#else
+  ::shutdown(m_lsd, SHUT_RDWR);
+  // Replace the descriptor with /dev/null so a blocked accept() returns.
+  int nullfd = ::open("/dev/null", O_RDONLY);
+  if (nullfd >= 0) {
+    ::dup2(nullfd, m_lsd);
+    ::close(nullfd);
+  }
+#endif
+}
+
+// Blocks for the next incoming connection and wraps it in a TCPStream.
+// Returns nullptr if not listening, if accept() fails (logged unless shutting
+// down), or if shutdown() happened while waiting (the late connection is
+// closed).
+std::unique_ptr<NetworkStream> TCPAcceptor::accept() {
+  if (m_shutdown || !m_listening) return nullptr;
+
+  struct sockaddr_in peer;
+  std::memset(&peer, 0, sizeof(peer));
+#ifdef _WIN32
+  int peer_len = sizeof(peer);
+#else
+  socklen_t peer_len = sizeof(peer);
+#endif
+  int sd = ::accept(m_lsd, reinterpret_cast<struct sockaddr*>(&peer),
+                    &peer_len);
+  if (sd < 0) {
+    if (!m_shutdown) ERROR("accept() failed: " << SocketStrerror());
+    return nullptr;
+  }
+  if (!m_shutdown) {
+    return std::unique_ptr<NetworkStream>(new TCPStream(sd, &peer));
+  }
+  // shut down while blocked: discard the connection
+#ifdef _WIN32
+  closesocket(sd);
+#else
+  close(sd);
+#endif
+  return nullptr;
+}
diff --git a/src/tcpsockets/TCPAcceptor.h b/src/tcpsockets/TCPAcceptor.h
new file mode 100644
index 0000000..d6a6ccc
--- /dev/null
+++ b/src/tcpsockets/TCPAcceptor.h
@@ -0,0 +1,50 @@
+/*
+ TCPAcceptor.h
+
+ TCPAcceptor class interface. TCPAcceptor provides methods to passively
+ establish TCP/IP connections with clients.
+
+ ------------------------------------------
+
+ Copyright © 2013 [Vic Hargrave - http://vichargrave.com]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+#ifndef TCPSOCKETS_TCPACCEPTOR_H_
+#define TCPSOCKETS_TCPACCEPTOR_H_
+
+#include <atomic>
+#include <memory>
+#include <string>
+
+#include "NetworkAcceptor.h"
+#include "TCPStream.h"
+
+// NetworkAcceptor backed by an IPv4 TCP listening socket.
+class TCPAcceptor : public NetworkAcceptor {
+ public:
+  TCPAcceptor(int port, const char* address);
+  ~TCPAcceptor() override;
+
+  int start() override;
+  void shutdown() override;
+  std::unique_ptr<NetworkStream> accept() override;
+
+ private:
+  int m_lsd;                    // listening socket descriptor
+  int m_port;
+  std::string m_address;        // empty means all interfaces
+  bool m_listening;
+  std::atomic_bool m_shutdown;  // read by accept() across threads
+};
+
+#endif
diff --git a/src/tcpsockets/TCPConnector.cpp b/src/tcpsockets/TCPConnector.cpp
new file mode 100644
index 0000000..665500d
--- /dev/null
+++ b/src/tcpsockets/TCPConnector.cpp
@@ -0,0 +1,167 @@
+/*
+ TCPConnector.cpp
+
+ TCPConnector class definition. TCPConnector provides methods to actively
+ establish TCP/IP connections with a server.
+
+ ------------------------------------------
+
+ Copyright © 2013 [Vic Hargrave - http://vichargrave.com]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License
+*/
+
+#include "TCPConnector.h"
+
+#include <errno.h>
+#include <fcntl.h>
+#include <cstdio>
+#include <cstring>
+#ifdef _WIN32
+#include <WinSock2.h>
+#include <WS2tcpip.h>
+#else
+#include <arpa/inet.h>
+#include <netdb.h>
+#include <netinet/in.h>
+#include <sys/select.h>
+#include <unistd.h>
+#endif
+
+#include "TCPStream.h"
+
+#include "llvm/SmallString.h"
+#include "../Log.h"
+#include "SocketError.h"
+
+using namespace tcpsockets;
+
+// Look up `hostname` as an IPv4 (AF_INET, SOCK_STREAM) address using
+// getaddrinfo(). On success the first result's address is copied into *addr.
+// Returns the getaddrinfo() status code: 0 on success, an EAI_* value on error
+// (in which case *addr is left untouched).
+static int ResolveHostName(const char* hostname, struct in_addr* addr) {
+  struct addrinfo hints;
+  std::memset(&hints, 0, sizeof(hints));  // flags/protocol/pointers all zero
+  hints.ai_family = AF_INET;
+  hints.ai_socktype = SOCK_STREAM;
+
+  struct addrinfo* found = nullptr;
+  const int status = getaddrinfo(hostname, nullptr, &hints, &found);
+  if (status != 0) return status;
+
+  const struct sockaddr_in* ipv4 =
+      reinterpret_cast<const struct sockaddr_in*>(found->ai_addr);
+  std::memcpy(addr, &ipv4->sin_addr, sizeof(struct in_addr));
+  freeaddrinfo(found);
+  return 0;
+}
+
+// Open a TCP connection to `server` (a host name or dotted IPv4 address) on
+// `port`.
+//
+// timeout == 0 performs an ordinary blocking connect(). Any other value is the
+// number of seconds to wait: the socket is switched to non-blocking mode,
+// select() waits for it to become writable, and SO_ERROR reports the outcome.
+// A successfully connected socket is put back into blocking mode before it is
+// wrapped in a TCPStream.
+//
+// Returns the connected stream, or nullptr on failure (the reason is logged).
+// The socket descriptor is closed on every failure path.
+std::unique_ptr<NetworkStream> TCPConnector::connect(const char* server,
+                                                     int port, int timeout) {
+#ifdef _WIN32
+  // Winsock must be initialized before any socket call; done once per process.
+  struct WSAHelper {
+    WSAHelper() {
+      WSAData wsaData;
+      WORD wVersionRequested = MAKEWORD(2, 2);
+      WSAStartup(wVersionRequested, &wsaData);
+    }
+    ~WSAHelper() { WSACleanup(); }
+  };
+  static WSAHelper helper;
+#endif
+  // Releases a descriptor on a failure path so failed attempts do not leak.
+  auto close_socket = [](int s) {
+#ifdef _WIN32
+    closesocket(s);
+#else
+    ::close(s);
+#endif
+  };
+
+  struct sockaddr_in address;
+  std::memset(&address, 0, sizeof(address));
+  address.sin_family = AF_INET;
+  if (ResolveHostName(server, &(address.sin_addr)) != 0) {
+    // Name lookup failed; fall back to parsing a numeric address. Fail rather
+    // than silently connecting to the zeroed (0.0.0.0) address.
+#ifdef _WIN32
+    llvm::SmallString<128> addr_copy(server);
+    addr_copy.push_back('\0');
+    int size = sizeof(address);
+    if (WSAStringToAddress(addr_copy.data(), PF_INET, nullptr,
+                           (struct sockaddr*)&address, &size) != 0) {
+      ERROR("could not resolve " << server << " address");
+      return nullptr;
+    }
+#else
+    if (inet_pton(PF_INET, server, &(address.sin_addr)) != 1) {
+      ERROR("could not resolve " << server << " address");
+      return nullptr;
+    }
+#endif
+  }
+  address.sin_port = htons(port);
+
+  int sd = socket(AF_INET, SOCK_STREAM, 0);
+  if (sd < 0) {
+    ERROR("socket() failed: " << SocketStrerror());
+    return nullptr;
+  }
+
+  if (timeout == 0) {
+    if (::connect(sd, (struct sockaddr*)&address, sizeof(address)) != 0) {
+      ERROR("connect() to " << server << " port " << port
+                            << " failed: " << SocketStrerror());
+      close_socket(sd);
+      return nullptr;
+    }
+    return std::unique_ptr<NetworkStream>(new TCPStream(sd, &address));
+  }
+
+  // Set socket to non-blocking so connect() returns immediately and the wait
+  // can be bounded by select().
+#ifdef _WIN32
+  u_long mode = 1;
+  ioctlsocket(sd, FIONBIO, &mode);
+#else
+  int flags = fcntl(sd, F_GETFL, nullptr);
+  fcntl(sd, F_SETFL, flags | O_NONBLOCK);
+#endif
+
+  // Connect with time limit
+  int result = ::connect(sd, (struct sockaddr*)&address, sizeof(address));
+  if (result < 0) {
+    int my_errno = SocketErrno();
+#ifdef _WIN32
+    if (my_errno == WSAEWOULDBLOCK || my_errno == WSAEINPROGRESS) {
+#else
+    if (my_errno == EWOULDBLOCK || my_errno == EINPROGRESS) {
+#endif
+      // Connection in progress: wait for writability, then ask the socket
+      // whether the connection actually succeeded.
+      fd_set sdset;
+      struct timeval tv;
+      tv.tv_sec = timeout;
+      tv.tv_usec = 0;
+      FD_ZERO(&sdset);
+      FD_SET(sd, &sdset);
+      if (select(sd + 1, nullptr, &sdset, nullptr, &tv) > 0) {
+        int valopt = 0;
+        socklen_t len = sizeof(valopt);
+        // valopt is only meaningful if getsockopt() itself succeeded.
+        if (getsockopt(sd, SOL_SOCKET, SO_ERROR, (char*)(&valopt), &len) != 0) {
+          ERROR("getsockopt() for " << server << " port " << port
+                                    << " failed: " << SocketStrerror());
+        } else if (valopt) {
+          ERROR("select() to " << server << " port " << port << " error "
+                               << valopt << " - " << SocketStrerror(valopt));
+        } else {
+          result = 0;  // connection established
+        }
+      } else {
+        INFO("connect() to " << server << " port " << port << " timed out");
+      }
+    } else {
+      ERROR("connect() to " << server << " port " << port << " error "
+                            << SocketErrno() << " - " << SocketStrerror());
+    }
+  }
+
+  if (result != 0) {
+    close_socket(sd);
+    return nullptr;
+  }
+
+  // Return socket to blocking mode
+#ifdef _WIN32
+  mode = 0;
+  ioctlsocket(sd, FIONBIO, &mode);
+#else
+  flags = fcntl(sd, F_GETFL, nullptr);
+  fcntl(sd, F_SETFL, flags & ~O_NONBLOCK);
+#endif
+
+  return std::unique_ptr<NetworkStream>(new TCPStream(sd, &address));
+}
diff --git a/src/tcpsockets/TCPConnector.h b/src/tcpsockets/TCPConnector.h
new file mode 100644
index 0000000..ebac859
--- /dev/null
+++ b/src/tcpsockets/TCPConnector.h
@@ -0,0 +1,37 @@
+/*
+ TCPConnector.h
+
+ TCPConnector class interface. TCPConnector provides methods to actively
+ establish TCP/IP connections with a server.
+
+ ------------------------------------------
+
+ Copyright © 2013 [Vic Hargrave - http://vichargrave.com]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License
+*/
+
+#ifndef TCPSOCKETS_TCPCONNECTOR_H_
+#define TCPSOCKETS_TCPCONNECTOR_H_
+
+#include <memory>
+
+#include "NetworkStream.h"
+
+// Active (client) side of the TCP socket layer. It has no state; only a
+// static factory is exposed.
+class TCPConnector {
+ public:
+ // Connects to `server` (host name or dotted IPv4 address) on `port`.
+ // timeout is in seconds; 0 (the default) does a plain blocking connect().
+ // Returns nullptr if the connection cannot be established.
+ static std::unique_ptr<NetworkStream> connect(const char* server, int port,
+ int timeout = 0);
+};
+
+#endif
diff --git a/src/tcpsockets/TCPStream.cpp b/src/tcpsockets/TCPStream.cpp
new file mode 100644
index 0000000..3149be6
--- /dev/null
+++ b/src/tcpsockets/TCPStream.cpp
@@ -0,0 +1,157 @@
+/*
+ TCPStream.cpp
+
+ TCPStream class definition. TCPStream provides methods to transfer
+ data between peers over a TCP/IP connection.
+
+ ------------------------------------------
+
+ Copyright © 2013 [Vic Hargrave - http://vichargrave.com]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+#include "TCPStream.h"
+
+#ifdef _WIN32
+#include <WinSock2.h>
+#else
+#include <arpa/inet.h>
+#include <netinet/tcp.h>
+#include <unistd.h>
+#endif
+
+// Wrap the connected descriptor `sd` and record the peer's IPv4 address and
+// port for getPeerIP()/getPeerPort(). The stream takes ownership of `sd`;
+// the destructor closes it.
+TCPStream::TCPStream(int sd, struct sockaddr_in* address) : m_sd(sd) {
+  // Start empty so a failed conversion yields an empty m_peerIP instead of
+  // copying an uninitialized buffer.
+  char ip[50] = "";
+#ifdef _WIN32
+  // WSAAddressToString() formats every component of the sockaddr and appends
+  // ":port" when sin_port is nonzero; format a copy with the port cleared so
+  // m_peerIP holds only the address, matching the inet_ntop() branch.
+  struct sockaddr_in addr_only = *address;
+  addr_only.sin_port = 0;
+  DWORD size = sizeof(ip);
+  if (WSAAddressToString((struct sockaddr*)&addr_only,
+                         sizeof(struct sockaddr_in), nullptr, ip,
+                         &size) != 0) {
+    ip[0] = '\0';
+  }
+#else
+  if (inet_ntop(PF_INET, &(address->sin_addr), ip, sizeof(ip)) == nullptr) {
+    ip[0] = '\0';
+  }
+#endif
+  m_peerIP = ip;
+  m_peerPort = ntohs(address->sin_port);
+}
+
+// The stream owns its socket, so destroying it shuts the connection down.
+TCPStream::~TCPStream() {
+  TCPStream::close();
+}
+
+// Write up to `len` bytes from `buffer`. Returns the number of bytes the OS
+// accepted (possibly fewer than len). On failure returns 0 and sets *err:
+// kConnectionClosed if the stream was already closed, kConnectionReset on a
+// socket error.
+std::size_t TCPStream::send(const char* buffer, std::size_t len, Error* err) {
+  if (m_sd < 0) {
+    *err = kConnectionClosed;
+    return 0;
+  }
+#ifdef _WIN32
+  WSABUF wsaBuf;
+  wsaBuf.buf = const_cast<char*>(buffer);
+  wsaBuf.len = (ULONG)len;
+  DWORD rv;
+  bool result = true;
+  // Retry (with a short sleep) for as long as the socket reports it would block.
+  while (WSASend(m_sd, &wsaBuf, 1, &rv, 0, nullptr, nullptr) == SOCKET_ERROR) {
+    if (WSAGetLastError() != WSAEWOULDBLOCK) {
+      result = false;
+      break;
+    }
+    Sleep(1);
+  }
+  if (!result) {
+    char Buffer[128];
+#ifdef _MSC_VER
+    sprintf_s(Buffer, "Send() failed: WSA error=%d\n", WSAGetLastError());
+#else
+    std::snprintf(Buffer, 128, "Send() failed: WSA error=%d\n", WSAGetLastError());
+#endif
+    OutputDebugStringA(Buffer);
+    *err = kConnectionReset;
+    return 0;
+  }
+#else
+  // Writing to a socket whose peer has disconnected raises SIGPIPE, whose
+  // default action terminates the process. Where MSG_NOSIGNAL exists, use
+  // send() with it so the call fails with EPIPE and is reported as a reset.
+#ifdef MSG_NOSIGNAL
+  ssize_t rv = ::send(m_sd, buffer, len, MSG_NOSIGNAL);
+#else
+  ssize_t rv = write(m_sd, buffer, len);
+#endif
+  if (rv < 0) {
+    *err = kConnectionReset;
+    return 0;
+  }
+#endif
+  return static_cast<std::size_t>(rv);
+}
+
+// Read up to `len` bytes into `buffer`.
+//
+// When timeout > 0, first wait at most `timeout` seconds for the socket to
+// become readable; if it does not, return 0 with *err = kConnectionTimedOut.
+// Otherwise read directly. Returns the byte count from read()/recv(). A
+// socket error yields 0 with *err = kConnectionReset, and a closed stream
+// yields 0 with *err = kConnectionClosed.
+std::size_t TCPStream::receive(char* buffer, std::size_t len, Error* err,
+                               int timeout) {
+  if (m_sd < 0) {
+    *err = kConnectionClosed;
+    return 0;
+  }
+
+  const bool must_wait = timeout > 0;
+  if (must_wait && !WaitForReadEvent(timeout)) {
+    *err = kConnectionTimedOut;
+    return 0;
+  }
+
+#ifdef _WIN32
+  int nread = recv(m_sd, buffer, len, 0);
+#else
+  ssize_t nread = read(m_sd, buffer, len);
+#endif
+  if (nread < 0) {
+    *err = kConnectionReset;
+    return 0;
+  }
+  return static_cast<std::size_t>(nread);
+}
+
+// Shut down both directions and release the descriptor. Calling this on an
+// already-closed stream just leaves m_sd at the -1 sentinel.
+void TCPStream::close() {
+  if (m_sd < 0) {
+    m_sd = -1;
+    return;
+  }
+#ifdef _WIN32
+  ::shutdown(m_sd, SD_BOTH);
+  closesocket(m_sd);
+#else
+  ::shutdown(m_sd, SHUT_RDWR);
+  ::close(m_sd);
+#endif
+  m_sd = -1;
+}
+
+// Peer address text captured by the constructor.
+llvm::StringRef TCPStream::getPeerIP() const {
+  return llvm::StringRef(m_peerIP);
+}
+
+// Peer port in host byte order, captured by the constructor.
+int TCPStream::getPeerPort() const {
+  return m_peerPort;
+}
+
+// Turn on TCP_NODELAY (disable Nagle's algorithm) so small writes are sent
+// without coalescing delay. A setsockopt() failure is ignored.
+void TCPStream::setNoDelay() {
+  const int enable = 1;
+  setsockopt(m_sd, IPPROTO_TCP, TCP_NODELAY,
+             reinterpret_cast<const char*>(&enable), sizeof(enable));
+}
+
+// Wait up to `timeout` seconds for m_sd to become readable. Returns true if
+// select() reports it readable; false on timeout and also if select() fails.
+bool TCPStream::WaitForReadEvent(int timeout) {
+  fd_set readable;
+  FD_ZERO(&readable);
+  FD_SET(m_sd, &readable);
+
+  struct timeval limit;
+  limit.tv_sec = timeout;
+  limit.tv_usec = 0;
+
+  const int ready = select(m_sd + 1, &readable, nullptr, nullptr, &limit);
+  return ready > 0;
+}
diff --git a/src/tcpsockets/TCPStream.h b/src/tcpsockets/TCPStream.h
new file mode 100644
index 0000000..21ef6fd
--- /dev/null
+++ b/src/tcpsockets/TCPStream.h
@@ -0,0 +1,67 @@
+/*
+ TCPStream.h
+
+ TCPStream class interface. TCPStream provides methods to transfer
+ data between peers over a TCP/IP connection.
+
+ ------------------------------------------
+
+ Copyright © 2013 [Vic Hargrave - http://vichargrave.com]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+#ifndef TCPSOCKETS_TCPSTREAM_H_
+#define TCPSOCKETS_TCPSTREAM_H_
+
+#include <cstddef>
+#include <string>
+
+#ifdef _WIN32
+#include <winsock2.h>
+#else
+#include <sys/socket.h>
+#endif
+
+#include "NetworkStream.h"
+
+// NetworkStream over a connected TCP socket. Only the friend classes
+// TCPAcceptor and TCPConnector can construct one (the real constructor is
+// private), and the stream owns its descriptor: the destructor calls close().
+class TCPStream : public NetworkStream {
+ int m_sd;  // socket descriptor; set to -1 by close()
+ std::string m_peerIP;  // peer address as formatted by the constructor
+ int m_peerPort;  // peer port in host byte order (ntohs of sin_port)
+
+ public:
+ friend class TCPAcceptor;
+ friend class TCPConnector;
+
+ ~TCPStream();
+
+ // Writes up to len bytes and returns the count written. On failure returns 0
+ // and sets *err (kConnectionClosed if already closed, else kConnectionReset).
+ std::size_t send(const char* buffer, std::size_t len, Error* err) override;
+ // Reads up to len bytes. If timeout > 0, first waits up to timeout seconds
+ // for data and sets *err = kConnectionTimedOut if none arrives; otherwise
+ // reads directly. Socket errors return 0 with *err = kConnectionReset.
+ std::size_t receive(char* buffer, std::size_t len, Error* err,
+ int timeout = 0) override;
+ // Shuts down and closes the socket; repeated calls are harmless.
+ void close() override;
+
+ llvm::StringRef getPeerIP() const override;
+ int getPeerPort() const override;
+ // Enables TCP_NODELAY (disables Nagle's algorithm); errors are ignored.
+ void setNoDelay() override;
+
+ // Non-copyable: each stream owns a socket descriptor.
+ TCPStream(const TCPStream& stream) = delete;
+ TCPStream& operator=(const TCPStream&) = delete;
+ private:
+ // Returns true if m_sd becomes readable within timeout seconds.
+ bool WaitForReadEvent(int timeout);
+
+ TCPStream(int sd, struct sockaddr_in* address);
+ TCPStream() = delete;
+};
+
+#endif