// blob: faa63985a1511b10857e29dda0d79dbb1f7878ce
#ifndef AOS_EVENTS_SIMULATED_NETWORK_BRIDGE_H_
#define AOS_EVENTS_SIMULATED_NETWORK_BRIDGE_H_
#include "aos/events/event_loop.h"
#include "aos/events/simulated_event_loop.h"
#include "aos/network/message_bridge_client_status.h"
#include "aos/network/message_bridge_server_status.h"
#include "aos/network/remote_message_generated.h"
#include "aos/network/timestamp_channel.h"
namespace aos {
namespace message_bridge {
// Forward declaration: this header only holds pointers (raw and
// std::unique_ptr) to RawMessageDelayer.
class RawMessageDelayer;
// This class moves messages between nodes.  It is implemented as a separate
// class because it would have been even harder to manage forwarding in the
// SimulatedEventLoopFactory.
class SimulatedMessageBridge {
 public:
  // Constructs the bridge for the provided factory.
  SimulatedMessageBridge(
      SimulatedEventLoopFactory *simulated_event_loop_factory);
  ~SimulatedMessageBridge();

  // Disables forwarding for this channel.  This should be used very rarely only
  // for things like the logger.
  void DisableForwarding(const Channel *channel);

  // Disconnects, connects, or explicitly sets the state of the connection
  // between `source` and `other`.
  // NOTE(review): these are defined out of line; confirm the exact direction
  // semantics of `source` vs. `other` against the implementation.
  void Disconnect(const Node *source, const Node *other);
  void Connect(const Node *source, const Node *other);
  void SetState(const Node *source, const Node *other,
                message_bridge::State state);

  // Disables generating and sending the messages which message_gateway sends.
  // The messages are the ClientStatistics, ServerStatistics and Timestamp
  // messages.
  //
  // `destroy_senders` defaults to kNo.  The overloads taking a node presumably
  // restrict the change to that node -- confirm against the implementation.
  enum class DestroySenders { kNo, kYes };
  void DisableStatistics(DestroySenders destroy_senders = DestroySenders::kNo);
  void DisableStatistics(const Node *node,
                         DestroySenders destroy_senders = DestroySenders::kNo);
  void EnableStatistics();
  void EnableStatistics(const Node *node);

 private:
  // A group of owned delayers together with a flag recording whether
  // forwarding has been disabled for the group.
  struct DelayersVector {
    std::vector<std::unique_ptr<RawMessageDelayer>> v;

    bool disable_forwarding = false;
  };

  // Per-node bookkeeping for the bridge: the node's event loop, its server and
  // client status objects, the delayers attached to it, and the desired
  // connection state towards every other node.
  struct State {
    // Sizes the per-node vectors from the number of nodes in the configuration.
    // Every boot UUID starts out as UUID::Zero() and every client/server state
    // starts out as CONNECTED.
    explicit State(NodeEventLoopFactory *node_factory)
        : node_factory_(node_factory) {
      const size_t num_nodes = node_factory->configuration()->nodes()->size();
      boot_uuids_.resize(num_nodes, UUID::Zero());
      client_state_.resize(num_nodes, message_bridge::State::CONNECTED);
      server_state_.resize(num_nodes, message_bridge::State::CONNECTED);
    }

    // Not copyable.
    State(const State &state) = delete;
    State &operator=(const State &state) = delete;

    // Remembers that statistics are disabled (and whether senders should be
    // destroyed), and forwards the request to whichever status objects
    // currently exist.
    void DisableStatistics(DestroySenders destroy_senders) {
      disable_statistics_ = true;
      destroy_senders_ = destroy_senders;

      const bool destroy = destroy_senders == DestroySenders::kYes;
      if (server_status) {
        server_status->DisableStatistics(destroy);
      }
      if (client_status) {
        client_status->DisableStatistics(destroy);
      }
    }

    // Clears the disabled flag and re-enables statistics on whichever status
    // objects currently exist.  destroy_senders_ is left untouched.
    void EnableStatistics() {
      disable_statistics_ = false;
      if (server_status) {
        server_status->EnableStatistics();
      }
      if (client_status) {
        client_status->EnableStatistics();
      }
    }

    // Registers a (non-owned) delayer in the source list.
    void AddSourceDelayer(RawMessageDelayer *delayer) {
      source_delayers_.emplace_back(delayer);
    }

    // Registers a (non-owned) delayer in the destination list.
    void AddDestinationDelayer(RawMessageDelayer *delayer) {
      destination_delayers_.emplace_back(delayer);
    }

    // Creates the "message_bridge" event loop from the node factory and
    // installs it with SetEventLoop().
    void MakeEventLoop() {
      // Message bridge isn't the thing that should be catching sent-too-fast,
      // and may need to be able to forward too-fast messages replayed from old
      // logfiles.
      SetEventLoop(node_factory_->MakeEventLoop(
          "message_bridge", {NodeEventLoopFactory::CheckSentTooFast::kNo,
                             NodeEventLoopFactory::ExclusiveSenders::kNo,
                             {}}));
    }

    // Takes ownership of `loop`.  Defined out of line.
    void SetEventLoop(std::unique_ptr<aos::EventLoop> loop);

    // Stores the send-data callback and, if a server status object exists,
    // hands it a copy.  May only be called while no callback is stored.
    void SetSendData(std::function<void()> fn) {
      CHECK(!fn_);
      fn_ = std::move(fn);
      if (server_status) {
        server_status->set_send_data(fn_);
      }
    }

    // Records that `channel` is served by the delayers in `v`.
    void AddDelayerWatcher(const Channel *channel, DelayersVector *v) {
      delayer_watchers_.emplace_back(channel, v);
    }

    // Records the boot UUID for the node at `node_index` and updates the
    // server and client status objects (when present) to match.
    //
    // A UUID of UUID::Zero() forces both sides to disconnected; any other
    // UUID re-applies the stored server/client state for that node.
    void SetBootUUID(size_t node_index, const UUID &boot_uuid) {
      boot_uuids_[node_index] = boot_uuid;
      const Node *node =
          node_factory_->configuration()->nodes()->Get(node_index);

      // Only touch the server side if there is a server connection to `node`.
      if (server_status &&
          server_status->FindServerConnection(node) != nullptr) {
        if (boot_uuid == UUID::Zero()) {
          server_status->Disconnect(node_index);
          server_status->ResetFilter(node_index);
        } else {
          ApplyServerState(node_index, server_state_[node_index]);
          server_status->ResetFilter(node_index);
          server_status->SetBootUUID(node_index, boot_uuid);
        }
      }

      if (client_status) {
        const int client_index =
            client_status->FindClientIndex(node->name()->string_view());
        client_status->SampleReset(client_index);
        if (boot_uuid == UUID::Zero()) {
          client_status->Disconnect(client_index);
        } else {
          ApplyClientState(client_index, client_state_[node_index]);
        }
      }
    }

    // Stores the desired server-side state towards `destination` and, if a
    // server connection to it exists and is in a different state, applies it.
    void SetServerState(const Node *destination, message_bridge::State state) {
      const size_t node_index = configuration::GetNodeIndex(
          node_factory_->configuration(), destination);
      server_state_[node_index] = state;

      if (!server_status) {
        return;
      }
      ServerConnection *connection =
          server_status->FindServerConnection(destination);
      if (connection == nullptr || state == connection->state()) {
        return;
      }
      ApplyServerState(node_index, state);
    }

    // Stores the desired client-side state towards `source` and, if the
    // client connection is in a different state, applies it.
    void SetClientState(const Node *source, message_bridge::State state) {
      const size_t node_index =
          configuration::GetNodeIndex(node_factory_->configuration(), source);
      client_state_[node_index] = state;

      if (!client_status) {
        return;
      }
      const int client_index =
          client_status->FindClientIndex(source->name()->string_view());
      ClientConnection *connection =
          client_status->GetClientConnection(source);
      // Fail loudly instead of dereferencing a null connection below.
      CHECK(connection != nullptr);

      // TODO(austin): Are there cases where we want to dedup 2 CONNECTED
      // calls?
      if (connection->state() != state) {
        ApplyClientState(client_index, state);
      }
    }

    // Calls Connect or Disconnect on server_status for `node_index` according
    // to `state`.  server_status must be non-null.
    void ApplyServerState(size_t node_index, message_bridge::State state) {
      switch (state) {
        case message_bridge::State::DISCONNECTED:
          server_status->Disconnect(node_index);
          break;
        case message_bridge::State::CONNECTED:
          server_status->Connect(node_index, event_loop->monotonic_now());
          break;
      }
    }

    // Calls Connect or Disconnect on client_status for `client_index`
    // according to `state`.  client_status must be non-null.
    void ApplyClientState(int client_index, message_bridge::State state) {
      switch (state) {
        case message_bridge::State::CONNECTED:
          client_status->Connect(client_index);
          break;
        case message_bridge::State::DISCONNECTED:
          client_status->Disconnect(client_index);
          break;
      }
    }

    // Last boot UUID recorded for each node, indexed by node index.
    std::vector<UUID> boot_uuids_;

    // Desired client and server connection state, indexed by node index.
    std::vector<message_bridge::State> client_state_;
    std::vector<message_bridge::State> server_state_;

    // (channel, delayers) pairs registered through AddDelayerWatcher().
    std::vector<std::pair<const Channel *, DelayersVector *>> delayer_watchers_;

    // Callback stored by SetSendData().
    std::function<void()> fn_;

    NodeEventLoopFactory *node_factory_;
    std::unique_ptr<aos::EventLoop> event_loop;
    ChannelTimestampSender timestamp_loggers;
    std::unique_ptr<MessageBridgeServerStatus> server_status;
    std::unique_ptr<MessageBridgeClientStatus> client_status;

    // List of delayers to update whenever this node starts or stops.
    // Source delayers (which are the ones fetching).
    std::vector<RawMessageDelayer *> source_delayers_;
    // Destination delayers (which are the ones sending on the receiving nodes).
    std::vector<RawMessageDelayer *> destination_delayers_;

    // Most recent statistics settings requested through DisableStatistics() /
    // EnableStatistics().
    bool disable_statistics_ = false;
    DestroySenders destroy_senders_ = DestroySenders::kNo;
  };

  // Map of nodes to event loops.  This is a member variable so that the
  // lifetime of the event loops matches the lifetime of the bridge.
  std::map<const Node *, State> event_loop_map_;

  // List of delayers used to resend the messages.
  std::vector<std::unique_ptr<DelayersVector>> delayers_list_;
};
} // namespace message_bridge
} // namespace aos
#endif // AOS_EVENTS_SIMULATED_NETWORK_BRIDGE_H_