LogReader Before Send Callback

Adds LogReader::AddBeforeSendCallback(), which registers a per-channel
callback that can inspect or mutate a logged message right before
LogReader sends it during replay.

Change-Id: I94a6b9fa2074c0a9aa8ea23cbc979e6cbce4bd05
Signed-off-by: James Kuszmaul <james.kuszmaul@bluerivertech.com>
diff --git a/aos/events/event_loop.h b/aos/events/event_loop.h
index e6e5499..e53bc21 100644
--- a/aos/events/event_loop.h
+++ b/aos/events/event_loop.h
@@ -159,9 +159,6 @@
  public:
   using SharedSpan = std::shared_ptr<const absl::Span<const uint8_t>>;
 
-  // This looks a little ugly with no space, but please leave it so clang-format
-  // doesn't keep changing it. Bug is filed at
-  // <https://github.com/llvm/llvm-project/issues/55457>.
   enum class [[nodiscard]] Error{
       // Represents success and no error
       kOk,
@@ -217,7 +214,6 @@
              monotonic_clock::time_point monotonic_remote_time,
              realtime_clock::time_point realtime_remote_time,
              uint32_t remote_queue_index, const UUID &remote_boot_uuid);
-
   const Channel *channel() const { return channel_; }
 
   // Returns the time_points that the last message was sent at.
diff --git a/aos/events/logging/log_reader.cc b/aos/events/logging/log_reader.cc
index 75a53e1..d0957bf 100644
--- a/aos/events/logging/log_reader.cc
+++ b/aos/events/logging/log_reader.cc
@@ -7,6 +7,7 @@
 #include <sys/uio.h>
 
 #include <climits>
+#include <utility>
 #include <vector>
 
 #include "absl/strings/escaping.h"
@@ -378,6 +379,8 @@
     }
     states_.resize(configuration()->nodes()->size());
   }
+
+  before_send_callbacks_.resize(configuration()->channels()->size());
 }
 
 LogReader::~LogReader() {
@@ -620,7 +623,8 @@
             ? nullptr
             : std::make_unique<TimestampMapper>(std::move(filtered_parts)),
         filters_.get(), std::bind(&LogReader::NoticeRealtimeEnd, this), node,
-        State::ThreadedBuffering::kNo, MaybeMakeReplayChannelIndices(node));
+        State::ThreadedBuffering::kNo, MaybeMakeReplayChannelIndices(node),
+        before_send_callbacks_);
     State *state = states_[node_index].get();
     state->SetNodeEventLoopFactory(
         event_loop_factory_->GetNodeEventLoopFactory(node),
@@ -804,7 +808,8 @@
             ? nullptr
             : std::make_unique<TimestampMapper>(std::move(filtered_parts)),
         filters_.get(), std::bind(&LogReader::NoticeRealtimeEnd, this), node,
-        State::ThreadedBuffering::kYes, MaybeMakeReplayChannelIndices(node));
+        State::ThreadedBuffering::kYes, MaybeMakeReplayChannelIndices(node),
+        before_send_callbacks_);
     State *state = states_[node_index].get();
 
     state->SetChannelCount(logged_configuration()->channels()->size());
@@ -1771,13 +1776,16 @@
     message_bridge::MultiNodeNoncausalOffsetEstimator *multinode_filters,
     std::function<void()> notice_realtime_end, const Node *node,
     LogReader::State::ThreadedBuffering threading,
-    std::unique_ptr<const ReplayChannelIndices> replay_channel_indices)
+    std::unique_ptr<const ReplayChannelIndices> replay_channel_indices,
+    const std::vector<std::function<void(void *message)>>
+        &before_send_callbacks)
     : timestamp_mapper_(std::move(timestamp_mapper)),
       notice_realtime_end_(notice_realtime_end),
       node_(node),
       multinode_filters_(multinode_filters),
       threading_(threading),
-      replay_channel_indices_(std::move(replay_channel_indices)) {
+      replay_channel_indices_(std::move(replay_channel_indices)),
+      before_send_callbacks_(before_send_callbacks) {
   // If timestamp_mapper_ is nullptr, then there are no log parts associated
   // with this node. If there are no log parts for the node, there will be no
   // log data, and so we do not need to worry about the replay channel filters.
@@ -1883,7 +1891,7 @@
   timing_statistics_sender_.CheckOk(builder.Send(timing_builder.Finish()));
 }
 
-bool LogReader::State::Send(const TimestampedMessage &timestamped_message) {
+bool LogReader::State::Send(const TimestampedMessage &&timestamped_message) {
   aos::RawSender *sender = channels_[timestamped_message.channel_index].get();
   CHECK(sender);
   uint32_t remote_queue_index = 0xffffffff;
@@ -1973,6 +1981,16 @@
                  ->boot_uuid());
   }
 
+  // Right before sending allow the user to process the message.
+  if (before_send_callbacks_[timestamped_message.channel_index]) {
+    // Only channels that are forwarded and sent from this State's node will be
+    // in the queue_index_map_
+    if (queue_index_map_[timestamped_message.channel_index]) {
+      before_send_callbacks_[timestamped_message.channel_index](
+          timestamped_message.data->mutable_data());
+    }
+  }
+
   // Send!  Use the replayed queue index here instead of the logged queue index
   // for the remote queue index.  This makes re-logging work.
   const RawSender::Error err = sender->Send(
@@ -2426,5 +2444,14 @@
   }
 }
 
+// Returns true if any per-node State currently exists (they are built in
+bool LogReader::AreStatesInitialized() const {  // Register()).
+  bool any_constructed = false;
+  for (const auto &node_state : states_) {
+    any_constructed = any_constructed || static_cast<bool>(node_state);
+  }
+  return any_constructed;
+}
+
 }  // namespace logger
 }  // namespace aos
diff --git a/aos/events/logging/log_reader.h b/aos/events/logging/log_reader.h
index fb6d6f9..b691b30 100644
--- a/aos/events/logging/log_reader.h
+++ b/aos/events/logging/log_reader.h
@@ -12,6 +12,7 @@
 
 #include "aos/condition.h"
 #include "aos/events/event_loop.h"
+#include "aos/events/event_loop_tmpl.h"
 #include "aos/events/logging/logfile_sorting.h"
 #include "aos/events/logging/logfile_utils.h"
 #include "aos/events/logging/logger_generated.h"
@@ -352,6 +353,48 @@
   // Only applies when running against a SimulatedEventLoopFactory.
   void SetRealtimeReplayRate(double replay_rate);
 
+  // Adds a callback for a channel to be called right before sending a message.
+  // This allows a user to mutate a message or do any processing when a specific
+  // type of message is sent on a channel. The name and type of the channel
+  // correspond to the logged_configuration's name and type.
+  //
+  // Note, only one callback can be registered per channel in the current
+  // implementation, and it must be added before Register() is called. If the
+  // channel is forwarded, the callback is only called on the sender's node.
+  //
+  // See multinode_logger_test for examples of usage.
+  template <typename Callback>
+  void AddBeforeSendCallback(std::string_view channel_name,
+                             Callback &&callback) {
+    CHECK(!AreStatesInitialized())
+        << ": Cannot add callbacks after calling Register";
+
+    // Decay first so an lvalue callable (Callback deduced as F&) works too.
+    using MessageType = typename std::remove_pointer<
+        typename event_loop_internal::watch_message_type_trait<
+            decltype(&std::decay_t<Callback>::operator())>::message_type>::type;
+
+    const Channel *channel = configuration::GetChannel(
+        logged_configuration(), channel_name,
+        MessageType::GetFullyQualifiedName(), "", nullptr);
+
+    CHECK(channel != nullptr)
+        << ": Channel { \"name\": \"" << channel_name << "\", \"type\": \""
+        << MessageType::GetFullyQualifiedName()
+        << "\" } not found in config for application.";
+    auto channel_index =
+        configuration::ChannelIndex(logged_configuration(), channel);
+
+    CHECK(!before_send_callbacks_[channel_index])
+        << ": Before Send Callback already registered for channel "
+        << "{ \"name\": \"" << channel_name << "\", \"type\": \""
+        << MessageType::GetFullyQualifiedName() << "\" }";
+
+    before_send_callbacks_[channel_index] = [callback](void *message) {
+      callback(flatbuffers::GetMutableRoot<MessageType>(message));
+    };
+  }
+
  private:
   void Register(EventLoop *event_loop, const Node *node);
 
@@ -433,7 +476,9 @@
           message_bridge::MultiNodeNoncausalOffsetEstimator *multinode_filters,
           std::function<void()> notice_realtime_end, const Node *node,
           ThreadedBuffering threading,
-          std::unique_ptr<const ReplayChannelIndices> replay_channel_indices);
+          std::unique_ptr<const ReplayChannelIndices> replay_channel_indices,
+          const std::vector<std::function<void(void *message)>>
+              &before_send_callbacks);
 
     // Connects up the timestamp mappers.
     void AddPeer(State *peer);
@@ -659,7 +704,7 @@
     }
 
     // Sends a buffer on the provided channel index.
-    bool Send(const TimestampedMessage &timestamped_message);
+    bool Send(const TimestampedMessage &&timestamped_message);
 
     void MaybeSetClockOffset();
     std::chrono::nanoseconds clock_offset() const { return clock_offset_; }
@@ -838,8 +883,14 @@
     // indices of the channels to replay for the Node represented by
     // the instance of LogReader::State.
     std::unique_ptr<const ReplayChannelIndices> replay_channel_indices_;
+    const std::vector<std::function<void(void *message)>>
+        before_send_callbacks_;
   };
 
+  // Checks if any of the States have been constructed yet.
+  // This happens during Register
+  bool AreStatesInitialized() const;
+
   // If a ReplayChannels was passed to LogReader then creates a
   // ReplayChannelIndices for the given node. Otherwise, returns a nullptr.
   std::unique_ptr<const ReplayChannelIndices> MaybeMakeReplayChannelIndices(
@@ -889,6 +940,10 @@
   // name and type of channels to replay which is used when creating States.
   const ReplayChannels *replay_channels_ = nullptr;
 
+  // The callbacks that will be called before sending a message indexed by the
+  // channel index from the logged_configuration
+  std::vector<std::function<void(void *message)>> before_send_callbacks_;
+
   // If true, the replay timer will ignore any missing data.  This is used
   // during startup when we are bootstrapping everything and trying to get to
   // the start of all the log files.
diff --git a/aos/events/logging/logfile_utils.h b/aos/events/logging/logfile_utils.h
index 6337005..67b6b84 100644
--- a/aos/events/logging/logfile_utils.h
+++ b/aos/events/logging/logfile_utils.h
@@ -474,6 +474,12 @@
   // pointer to track.
   absl::Span<const uint8_t> span;
 
+  // Returns a writable pointer to the bytes referenced by span by casting away
+  // const. Only intended for LogReader's before send callbacks, which mutate
+  // the message in place just before Send(); this relies on Send having a
+  // single caller and nothing mutating the data after Send is called.
+  uint8_t *mutable_data() { return const_cast<uint8_t *>(span.data()); }
+
   char actual_data[];
 
  private:
diff --git a/aos/events/logging/logger_test.cc b/aos/events/logging/logger_test.cc
index 1c48d2c..f2cd702 100644
--- a/aos/events/logging/logger_test.cc
+++ b/aos/events/logging/logger_test.cc
@@ -129,6 +129,62 @@
   EXPECT_EQ(pong_count, ping_count);
 }
 
+// Tests that a before send callback can mutate messages during replay.
+TEST_F(LoggerTest, MutateCallback) {
+  const ::std::string tmpdir = aos::testing::TestTmpDir();
+  const ::std::string base_name = tmpdir + "/logfile";
+  const ::std::string config =
+      absl::StrCat(base_name, kSingleConfigSha256, ".bfbs");
+  const ::std::string logfile = base_name + "_data.part0.bfbs";
+  // Remove any files left over from a previous run.
+  unlink(config.c_str());
+  unlink(logfile.c_str());
+
+  LOG(INFO) << "Logging data to " << logfile;
+
+  {
+    std::unique_ptr<EventLoop> logger_event_loop =
+        event_loop_factory_.MakeEventLoop("logger");
+
+    event_loop_factory_.RunFor(chrono::milliseconds(95));
+
+    Logger logger(logger_event_loop.get());
+    logger.set_separate_config(false);
+    logger.set_polling_period(std::chrono::milliseconds(100));
+    logger.StartLoggingOnRun(base_name);
+    event_loop_factory_.RunFor(chrono::milliseconds(20000));
+  }
+
+  // Even though it doesn't make any difference here, exercise the logic for
+  // passing in a separate config.
+  LogReader reader(logfile, &config_.message());
+
+  reader.AddBeforeSendCallback("/test", [](aos::examples::Ping *ping) {
+    ping->mutate_value(ping->value() + 1);
+  });
+
+  // This sends out the fetched messages and advances time to the start of the
+  // log file.
+  reader.Register();
+
+  EXPECT_THAT(reader.LoggedNodes(), ::testing::ElementsAre(nullptr));
+
+  std::unique_ptr<EventLoop> test_event_loop =
+      reader.event_loop_factory()->MakeEventLoop("log_reader");
+
+  // Check each replayed (callback-mutated) ping value: values must be
+  // consecutive starting at 11 and end at 2010, i.e. 2000 pings in total.
+  int ping_count = 10;
+  test_event_loop->MakeWatcher("/test",
+                               [&ping_count](const examples::Ping &ping) {
+                                 ++ping_count;
+                                 EXPECT_EQ(ping.value(), ping_count);
+                               });
+
+  reader.event_loop_factory()->RunFor(std::chrono::seconds(100));
+  EXPECT_EQ(ping_count, 2010);
+}
+
 // Tests calling StartLogging twice.
 TEST_F(LoggerDeathTest, ExtraStart) {
   const ::std::string tmpdir = aos::testing::TestTmpDir();
diff --git a/aos/events/logging/multinode_logger_test.cc b/aos/events/logging/multinode_logger_test.cc
index 44736a5..677fc72 100644
--- a/aos/events/logging/multinode_logger_test.cc
+++ b/aos/events/logging/multinode_logger_test.cc
@@ -551,6 +551,253 @@
   reader.Deregister();
 }
 
+// Tests that the before send callback works across multiple nodes when the
+// logged channel is also remapped on one of them (pi1).
+TEST_P(MultinodeLoggerTest, MultiNodeRemapMutateCallback) {
+  time_converter_.StartEqual();
+  std::vector<std::string> actual_filenames;
+
+  {
+    LoggerState pi1_logger = MakeLogger(pi1_);
+    LoggerState pi2_logger = MakeLogger(pi2_);
+
+    event_loop_factory_.RunFor(chrono::milliseconds(95));
+
+    StartLogger(&pi1_logger);
+    StartLogger(&pi2_logger);
+
+    event_loop_factory_.RunFor(chrono::milliseconds(20000));
+    pi1_logger.AppendAllFilenames(&actual_filenames);
+    pi2_logger.AppendAllFilenames(&actual_filenames);
+  }
+
+  const std::vector<LogFile> sorted_parts = SortParts(logfiles_);
+
+  LogReader reader(sorted_parts, &config_.message());
+  // Remap just on pi1.
+  reader.RemapLoggedChannel<examples::Pong>(
+      "/test", configuration::GetNode(reader.configuration(), "pi1"));
+
+  SimulatedEventLoopFactory log_reader_factory(reader.configuration());
+
+  int pong_count = 0;
+  // Adds a callback that increments each pong's value right before it is sent
+  // and records the new value so the watchers below can check it.
+  reader.AddBeforeSendCallback("/test",
+                               [&pong_count](aos::examples::Pong *pong) {
+                                 pong->mutate_value(pong->value() + 1);
+                                 pong_count = pong->value();
+                               });
+
+  // This sends out the fetched messages and advances time to the start of the
+  // log file.
+  reader.Register(&log_reader_factory);
+
+  const Node *pi1 =
+      configuration::GetNode(log_reader_factory.configuration(), "pi1");
+  const Node *pi2 =
+      configuration::GetNode(log_reader_factory.configuration(), "pi2");
+
+  EXPECT_THAT(reader.LoggedNodes(),
+              ::testing::ElementsAre(
+                  configuration::GetNode(reader.logged_configuration(), pi1),
+                  configuration::GetNode(reader.logged_configuration(), pi2)));
+
+  std::unique_ptr<EventLoop> pi1_event_loop =
+      log_reader_factory.MakeEventLoop("test", pi1);
+  std::unique_ptr<EventLoop> pi2_event_loop =
+      log_reader_factory.MakeEventLoop("test", pi2);
+
+  pi1_event_loop->MakeWatcher("/original/test",
+                              [&pong_count](const examples::Pong &pong) {
+                                EXPECT_EQ(pong_count, pong.value());
+                              });
+
+  pi2_event_loop->MakeWatcher("/test",
+                              [&pong_count](const examples::Pong &pong) {
+                                EXPECT_EQ(pong_count, pong.value());
+                              });
+
+  reader.event_loop_factory()->RunFor(std::chrono::seconds(100));
+  reader.Deregister();
+
+  EXPECT_EQ(pong_count, 2011);
+}
+
+// Tests that the before send callback mutates messages seen on every node of a
+// multi-node log.
+TEST_P(MultinodeLoggerTest, MultiNodeMutateCallback) {
+  time_converter_.StartEqual();
+  std::vector<std::string> actual_filenames;
+
+  {
+    LoggerState pi1_logger = MakeLogger(pi1_);
+    LoggerState pi2_logger = MakeLogger(pi2_);
+
+    event_loop_factory_.RunFor(chrono::milliseconds(95));
+
+    StartLogger(&pi1_logger);
+    StartLogger(&pi2_logger);
+
+    event_loop_factory_.RunFor(chrono::milliseconds(20000));
+    pi1_logger.AppendAllFilenames(&actual_filenames);
+    pi2_logger.AppendAllFilenames(&actual_filenames);
+  }
+
+  const std::vector<LogFile> sorted_parts = SortParts(logfiles_);
+
+  LogReader reader(sorted_parts, &config_.message());
+
+  int pong_count = 0;
+  // Adds a callback that increments each pong's value right before it is sent
+  // and records the new value so the watchers below can check it.
+  reader.AddBeforeSendCallback("/test",
+                               [&pong_count](aos::examples::Pong *pong) {
+                                 pong->mutate_value(pong->value() + 1);
+                                 pong_count = pong->value();
+                               });
+
+  SimulatedEventLoopFactory log_reader_factory(reader.configuration());
+
+  // This sends out the fetched messages and advances time to the start of the
+  // log file.
+  reader.Register(&log_reader_factory);
+
+  const Node *pi1 =
+      configuration::GetNode(log_reader_factory.configuration(), "pi1");
+  const Node *pi2 =
+      configuration::GetNode(log_reader_factory.configuration(), "pi2");
+
+  EXPECT_THAT(reader.LoggedNodes(),
+              ::testing::ElementsAre(
+                  configuration::GetNode(reader.logged_configuration(), pi1),
+                  configuration::GetNode(reader.logged_configuration(), pi2)));
+
+  std::unique_ptr<EventLoop> pi1_event_loop =
+      log_reader_factory.MakeEventLoop("test", pi1);
+  std::unique_ptr<EventLoop> pi2_event_loop =
+      log_reader_factory.MakeEventLoop("test", pi2);
+
+  pi1_event_loop->MakeWatcher("/test",
+                              [&pong_count](const examples::Pong &pong) {
+                                EXPECT_EQ(pong_count, pong.value());
+                              });
+
+  pi2_event_loop->MakeWatcher("/test",
+                              [&pong_count](const examples::Pong &pong) {
+                                EXPECT_EQ(pong_count, pong.value());
+                              });
+
+  reader.event_loop_factory()->RunFor(std::chrono::seconds(100));
+  reader.Deregister();
+
+  EXPECT_EQ(pong_count, 2011);
+}
+
+// Tests that, for a forwarded channel, the before send callback is only called
+// on the sending node, not on the nodes that receive the forwarded copy.
+TEST_P(MultinodeLoggerTest, OnlyDoBeforeSendCallbackOnSenderNode) {
+  time_converter_.StartEqual();
+  {
+    LoggerState pi1_logger = MakeLogger(pi1_);
+    LoggerState pi2_logger = MakeLogger(pi2_);
+
+    event_loop_factory_.RunFor(chrono::milliseconds(95));
+
+    StartLogger(&pi1_logger);
+    StartLogger(&pi2_logger);
+
+    event_loop_factory_.RunFor(chrono::milliseconds(20000));
+  }
+
+  LogReader reader(SortParts(logfiles_));
+
+  int ping_count = 0;
+  // Adds a callback which counts each ping and overwrites its value with that
+  // count right before the message is sent.
+  reader.AddBeforeSendCallback("/test",
+                               [&ping_count](aos::examples::Ping *ping) {
+                                 ++ping_count;
+                                 ping->mutate_value(ping_count);
+                               });
+
+  SimulatedEventLoopFactory log_reader_factory(reader.configuration());
+  log_reader_factory.set_send_delay(chrono::microseconds(0));
+
+  reader.Register(&log_reader_factory);
+
+  const Node *pi1 =
+      configuration::GetNode(log_reader_factory.configuration(), "pi1");
+  const Node *pi2 =
+      configuration::GetNode(log_reader_factory.configuration(), "pi2");
+
+  std::unique_ptr<EventLoop> pi1_event_loop =
+      log_reader_factory.MakeEventLoop("test", pi1);
+  pi1_event_loop->SkipTimingReport();
+  std::unique_ptr<EventLoop> pi2_event_loop =
+      log_reader_factory.MakeEventLoop("test", pi2);
+  pi2_event_loop->SkipTimingReport();
+
+  MessageCounter<examples::Ping> pi1_ping(pi1_event_loop.get(), "/test");
+  MessageCounter<examples::Ping> pi2_ping(pi2_event_loop.get(), "/test");
+
+  std::unique_ptr<MessageCounter<message_bridge::RemoteMessage>>
+      pi1_ping_timestamp;
+  if (!shared()) {
+    pi1_ping_timestamp =
+        std::make_unique<MessageCounter<message_bridge::RemoteMessage>>(
+            pi1_event_loop.get(),
+            "/pi1/aos/remote_timestamps/pi2/test/aos-examples-Ping");
+  }
+
+  log_reader_factory.Run();
+
+  EXPECT_EQ(pi1_ping.count(), 2000u);
+  EXPECT_EQ(pi2_ping.count(), 2000u);
+  // If the BeforeSendCallback is called on both nodes, then the ping count
+  // would be 4002 instead of 2001
+  EXPECT_EQ(ping_count, 2001);
+  if (!shared()) {
+    EXPECT_EQ(pi1_ping_timestamp->count(), 2000u);
+  }
+
+  reader.Deregister();
+}
+
+// Tests that adding a before send callback after Register() is a fatal error.
+TEST_P(MultinodeLoggerDeathTest, AddCallbackAfterRegister) {
+  time_converter_.StartEqual();
+  std::vector<std::string> actual_filenames;
+
+  {
+    LoggerState pi1_logger = MakeLogger(pi1_);
+    LoggerState pi2_logger = MakeLogger(pi2_);
+
+    event_loop_factory_.RunFor(chrono::milliseconds(95));
+
+    StartLogger(&pi1_logger);
+    StartLogger(&pi2_logger);
+
+    event_loop_factory_.RunFor(chrono::milliseconds(20000));
+    pi1_logger.AppendAllFilenames(&actual_filenames);
+    pi2_logger.AppendAllFilenames(&actual_filenames);
+  }
+
+  const std::vector<LogFile> sorted_parts = SortParts(logfiles_);
+
+  LogReader reader(sorted_parts, &config_.message());
+  SimulatedEventLoopFactory log_reader_factory(reader.configuration());
+  reader.Register(&log_reader_factory);
+  EXPECT_DEATH(
+      {
+        reader.AddBeforeSendCallback("/test", [](aos::examples::Pong *) {
+          LOG(FATAL) << "This should not be called";
+        });
+      },
+      "Cannot add callbacks after calling Register");
+  reader.Deregister();
+}
+
 // Test that if we feed the replay with a mismatched node list that we die on
 // the LogReader constructor.
 TEST_P(MultinodeLoggerDeathTest, MultiNodeBadReplayConfig) {