Fix crash in message_bridge_server
* The crash in `message_bridge_server` was due to an
attempt to access an unset or empty `std::optional`
that led to an exception being raised.
* The fix here is to access the `std::optional` only
if a value had been set.
* The function `MaybeIncrementInvalidConnectionCount`
has been moved from `MessageBridgeServer` to
`MessageBridgeServerStatus` since it was altering
only the state of `MessageBridgeServerStatus`.
* A unit test has been added for coverage of this
function.
Change-Id: I54ed0968a78741d3ca8ea0d9312fddcfb599d633
Signed-off-by: James Kuszmaul <james.kuszmaul@bluerivertech.com>
diff --git a/aos/network/BUILD b/aos/network/BUILD
index db62286..e31b77f 100644
--- a/aos/network/BUILD
+++ b/aos/network/BUILD
@@ -613,6 +613,23 @@
],
)
+cc_test(
+ name = "message_bridge_server_status_test",
+ srcs = [
+ "message_bridge_server_status_test.cc",
+ ],
+ data = [
+ ":message_bridge_test_combined_timestamps_common_config",
+ ],
+ target_compatible_with = ["@platforms//os:linux"],
+ deps = [
+ ":message_bridge_server_status",
+ "//aos/events:simulated_event_loop",
+ "//aos/testing:googletest",
+ "//aos/testing:path",
+ ],
+)
+
flatbuffer_cc_library(
name = "web_proxy_fbs",
srcs = ["web_proxy.fbs"],
diff --git a/aos/network/message_bridge_server_lib.cc b/aos/network/message_bridge_server_lib.cc
index a2df830..caecc81 100644
--- a/aos/network/message_bridge_server_lib.cc
+++ b/aos/network/message_bridge_server_lib.cc
@@ -672,44 +672,13 @@
HandleData(message.get());
break;
case Message::kOverflow:
- MaybeIncrementInvalidConnectionCount(nullptr);
+ server_status_.MaybeIncrementInvalidConnectionCount(nullptr);
NodeDisconnected(message->header.rcvinfo.rcv_assoc_id);
break;
}
server_.FreeMessage(std::move(message));
}
-void MessageBridgeServer::MaybeIncrementInvalidConnectionCount(
- const Node *node) {
- server_status_.increment_invalid_connection_count();
-
- if (node == nullptr) {
- return;
- }
-
- if (!node->has_name()) {
- return;
- }
-
- const aos::Node *client_node = configuration::GetNode(
- event_loop_->configuration(), node->name()->string_view());
-
- if (client_node == nullptr) {
- return;
- }
-
- const int node_index =
- configuration::GetNodeIndex(event_loop_->configuration(), client_node);
-
- ServerConnection *connection =
- server_status_.nodes()[node_index].value().server_connection;
-
- if (connection != nullptr) {
- connection->mutate_invalid_connection_count(
- connection->invalid_connection_count() + 1);
- }
-}
-
void MessageBridgeServer::HandleData(const Message *message) {
VLOG(2) << "Received data of length " << message->size;
@@ -725,7 +694,7 @@
}
server_.Abort(message->header.rcvinfo.rcv_assoc_id);
- MaybeIncrementInvalidConnectionCount(nullptr);
+ server_status_.MaybeIncrementInvalidConnectionCount(nullptr);
return;
}
}
@@ -737,7 +706,7 @@
}
server_.Abort(message->header.rcvinfo.rcv_assoc_id);
- MaybeIncrementInvalidConnectionCount(connect->node());
+ server_status_.MaybeIncrementInvalidConnectionCount(connect->node());
return;
}
@@ -750,7 +719,7 @@
}
server_.Abort(message->header.rcvinfo.rcv_assoc_id);
- MaybeIncrementInvalidConnectionCount(connect->node());
+ server_status_.MaybeIncrementInvalidConnectionCount(connect->node());
return;
}
@@ -762,7 +731,7 @@
}
server_.Abort(message->header.rcvinfo.rcv_assoc_id);
- MaybeIncrementInvalidConnectionCount(connect->node());
+ server_status_.MaybeIncrementInvalidConnectionCount(connect->node());
return;
}
@@ -803,7 +772,7 @@
}
server_.Abort(message->header.rcvinfo.rcv_assoc_id);
- MaybeIncrementInvalidConnectionCount(connect->node());
+ server_status_.MaybeIncrementInvalidConnectionCount(connect->node());
return;
}
++channel_index;
diff --git a/aos/network/message_bridge_server_lib.h b/aos/network/message_bridge_server_lib.h
index b47a4e6..1c3903f 100644
--- a/aos/network/message_bridge_server_lib.h
+++ b/aos/network/message_bridge_server_lib.h
@@ -192,10 +192,6 @@
// received.
void HandleData(const Message *message);
- // Increments the invalid connection count overall, and per node if we know
- // which node (ie, node is not nullptr).
- void MaybeIncrementInvalidConnectionCount(const Node *node);
-
// The maximum number of channels we support on a single connection. We need
// to configure the SCTP socket with this before any clients connect, so we
// need an upper bound on the number of channels any of them will use.
diff --git a/aos/network/message_bridge_server_status.cc b/aos/network/message_bridge_server_status.cc
index 0e8c6b0..4df717f 100644
--- a/aos/network/message_bridge_server_status.cc
+++ b/aos/network/message_bridge_server_status.cc
@@ -506,4 +506,47 @@
kPingPeriod);
}
+void MessageBridgeServerStatus::MaybeIncrementInvalidConnectionCount(
+ const Node *node) {
+ increment_invalid_connection_count();
+
+ if (node == nullptr) {
+ return;
+ }
+
+ if (!node->has_name()) {
+ return;
+ }
+
+ const aos::Node *client_node = configuration::GetNode(
+ event_loop_->configuration(), node->name()->string_view());
+
+ if (client_node == nullptr) {
+ return;
+ }
+
+ const int node_index =
+ configuration::GetNodeIndex(event_loop_->configuration(), client_node);
+
+ const std::vector<std::optional<MessageBridgeServerStatus::NodeState>>
+ &server_nodes = nodes();
+ // There is a chance that there is no server node for the given client
+ // `node_index`. This can happen if the other node has a different
+ // configuration such that it starts forwarding messages to the current node,
+ // but the current node's configuration does not expect messages from the
+ // other node. This is likely to happen during a multi-node software update
+ // where the other node has been updated with a different config, while the
+ // current node's update hasn't yet completed. In such cases, we want to
+ // ensure that a server node exists before attempting to access it.
+ if (server_nodes[node_index]) {
+ ServerConnection *connection =
+ server_nodes[node_index].value().server_connection;
+
+ if (connection != nullptr) {
+ connection->mutate_invalid_connection_count(
+ connection->invalid_connection_count() + 1);
+ }
+ }
+}
+
} // namespace aos::message_bridge
diff --git a/aos/network/message_bridge_server_status.h b/aos/network/message_bridge_server_status.h
index 8bdf4a8..1248c27 100644
--- a/aos/network/message_bridge_server_status.h
+++ b/aos/network/message_bridge_server_status.h
@@ -113,6 +113,10 @@
// connection that got rejected.
void increment_invalid_connection_count() { ++invalid_connection_count_; }
+ // Increments the invalid connection count overall, and per node if we know
+ // which node (ie, node is not nullptr).
+ void MaybeIncrementInvalidConnectionCount(const Node *node);
+
private:
static constexpr std::chrono::nanoseconds kStatisticsPeriod =
std::chrono::seconds(1);
diff --git a/aos/network/message_bridge_server_status_test.cc b/aos/network/message_bridge_server_status_test.cc
new file mode 100644
index 0000000..8b9aefa
--- /dev/null
+++ b/aos/network/message_bridge_server_status_test.cc
@@ -0,0 +1,44 @@
+#include "aos/network/message_bridge_server_status.h"
+
+#include "gtest/gtest.h"
+
+#include "aos/events/simulated_event_loop.h"
+
+namespace aos::message_bridge::testing {
+
+TEST(MessageBridgeServerStatus, NoThrowOnInvalidServerNode) {
+ aos::FlatbufferDetachedBuffer<aos::Configuration> config(
+ aos::configuration::ReadConfig(
+ "message_bridge_test_combined_timestamps_common_config.json"));
+ aos::SimulatedEventLoopFactory factory(&config.message());
+ // Configure the server node to be `pi1` - for details
+ // on the configuration, refer to
+ // `message_bridge_test_combined_timestamps_common.json`.
+ std::unique_ptr<EventLoop> event_loop =
+ factory.GetNodeEventLoopFactory("pi1")->MakeEventLoop("test");
+ MessageBridgeServerStatus server_status(event_loop.get());
+ // We want to choose a client node such that there is no server for that
+ // client on this node. A simple way to do this is to choose the client node
+ // to be the same as the server node. There will never be a valid `NodeState`
+ // object assigned in `MessageBridgeServerStatus::nodes_`, which is an
+ // `std::vector` of `std::optional<NodeState> elements`. This is because a
+ // node will not be allowed to forward messages to itself since that would
+ // cause a loop of the same message being forwarded over-and-over again. We're
+ // making use of this property to simulate a multi-node software update
+ // scenario in which one node was upgraded to a config that had a valid
+ // connection to another node and started forwarding messages to the other
+ // node. Since the other node was in the process of being updated to the new
+ // software, it did not have the updated config yet, and couldn't find a
+ // server node corresponding to the client node. In this situation,
+ // `MaybeIncrementInvalidConnectionCount()` ended-up accessing an
+ // `std::optional` that was unset. As a regression test, we want to ensure
+ // that no exceptions are raised in this scenario, now that the proper checks
+ // have been added.
+ const aos::Node *client_node =
+ aos::configuration::GetNode(&config.message(), "pi1");
+ EXPECT_NE(client_node, nullptr);
+ EXPECT_NO_THROW(
+ server_status.MaybeIncrementInvalidConnectionCount(client_node));
+}
+
+} // namespace aos::message_bridge::testing
\ No newline at end of file