Add a flywheel controller test

Signed-off-by: Maxwell Henderson <mxwhenderson@gmail.com>
Change-Id: Ifcead97b1abb04e2a5e013ede85571f0283b7ed8
diff --git a/frc971/control_loops/flywheel/BUILD b/frc971/control_loops/flywheel/BUILD
index 0cf5d00..f077682 100644
--- a/frc971/control_loops/flywheel/BUILD
+++ b/frc971/control_loops/flywheel/BUILD
@@ -1,3 +1,4 @@
+load("//aos:config.bzl", "aos_config")
 load("//aos/flatbuffers:generate.bzl", "static_flatbuffer")
 load("@com_github_google_flatbuffers//:typescript.bzl", "flatbuffer_ts_library")
 
@@ -46,3 +47,65 @@
         "//frc971/control_loops:profiled_subsystem",
     ],
 )
+
+cc_test(
+    name = "flywheel_controller_test",  # Closed-loop FlywheelController simulation test.
+    srcs = ["flywheel_controller_test.cc"],
+    data = [
+        ":flywheel_controller_test_config",  # Loaded at runtime via aos::configuration::ReadConfig().
+    ],
+    target_compatible_with = ["@platforms//os:linux"],
+    deps = [
+        ":flywheel_controller",
+        ":flywheel_controller_test_plants",
+        ":flywheel_test_plant",
+        "//aos/testing:googletest",
+        "//frc971/control_loops:control_loop_test",
+    ],
+)
+
+aos_config(
+    name = "flywheel_controller_test_config",  # Consumed by flywheel_controller_test.
+    src = "flywheel_controller_test_config_source.json",
+    flatbuffers = [
+        "//frc971/input:joystick_state_fbs",
+        "//frc971/input:robot_state_fbs",
+        "//aos/logging:log_message_fbs",
+        "//aos/events:event_loop_fbs",
+        ":flywheel_controller_status_fbs",  # Schema for the /loop status channel.
+    ],
+    target_compatible_with = ["@platforms//os:linux"],
+)
+
+genrule(
+    name = "genrule_flywheel_controller_test",
+    outs = [
+        "flywheel_controller_test_plant.h",  # Order matters: the generator splits $(OUTS) into argv[1:3] and argv[3:5].
+        "flywheel_controller_test_plant.cc",
+        "integral_flywheel_controller_test_plant.h",
+        "integral_flywheel_controller_test_plant.cc",
+    ],
+    cmd = "$(location //frc971/control_loops/python:flywheel_controller_test) $(OUTS)",
+    target_compatible_with = ["@platforms//os:linux"],
+    tools = [
+        "//frc971/control_loops/python:flywheel_controller_test",
+    ],
+)
+
+cc_library(
+    name = "flywheel_controller_test_plants",  # Compiles the genrule_flywheel_controller_test outputs.
+    srcs = [
+        "flywheel_controller_test_plant.cc",
+        "integral_flywheel_controller_test_plant.cc",
+    ],
+    hdrs = [
+        "flywheel_controller_test_plant.h",
+        "integral_flywheel_controller_test_plant.h",
+    ],
+    target_compatible_with = ["@platforms//os:linux"],
+    visibility = ["//visibility:public"],  # NOTE(review): only flywheel_controller_test uses this here; consider narrowing.
+    deps = [
+        "//frc971/control_loops:hybrid_state_feedback_loop",
+        "//frc971/control_loops:state_feedback_loop",
+    ],
+)
diff --git a/frc971/control_loops/flywheel/flywheel_controller_status.fbs b/frc971/control_loops/flywheel/flywheel_controller_status.fbs
index 0f95818..5e079e3 100644
--- a/frc971/control_loops/flywheel/flywheel_controller_status.fbs
+++ b/frc971/control_loops/flywheel/flywheel_controller_status.fbs
@@ -20,3 +20,5 @@
   // The angular velocity of the flywheel computed using delta x / delta t
   dt_angular_velocity:double (id: 5);
 }
+
+root_type FlywheelControllerStatus;
diff --git a/frc971/control_loops/flywheel/flywheel_controller_test.cc b/frc971/control_loops/flywheel/flywheel_controller_test.cc
new file mode 100644
index 0000000..26aabc9
--- /dev/null
+++ b/frc971/control_loops/flywheel/flywheel_controller_test.cc
@@ -0,0 +1,120 @@
+#include "frc971/control_loops/flywheel/flywheel_controller.h"
+
+#include <chrono>
+#include <memory>
+
+#include "glog/logging.h"
+#include "gtest/gtest.h"
+
+#include "aos/configuration.h"
+#include "frc971/control_loops/control_loop_test.h"
+#include "frc971/control_loops/flywheel/flywheel_controller_test_plant.h"
+#include "frc971/control_loops/flywheel/flywheel_test_plant.h"
+#include "frc971/control_loops/flywheel/integral_flywheel_controller_test_plant.h"
+
+namespace frc971 {
+namespace control_loops {
+namespace flywheel {
+namespace testing {
+
+// Closed-loop test of FlywheelController: every dt() a phased loop steps a
+// simulated FlywheelPlant with the controller's last voltage, feeds the new
+// position back to the controller, and publishes the controller status on
+// "/loop".
+class FlywheelTest : public ::frc971::testing::ControlLoopTest {
+ public:
+  FlywheelTest()
+      : ::frc971::testing::ControlLoopTest(
+            aos::configuration::ReadConfig(
+                "frc971/control_loops/flywheel/"
+                "flywheel_controller_test_config.json"),
+            std::chrono::microseconds(5050)),
+        test_event_loop_(MakeEventLoop("test")),
+        flywheel_plant_(std::make_unique<FlywheelPlant>(
+            MakeFlywheelTestPlant(), kBemf, kResistance)),
+        flywheel_controller_(MakeIntegralFlywheelTestLoop(), kBemf,
+                             kResistance),
+        flywheel_controller_sender_(
+            test_event_loop_->MakeSender<FlywheelControllerStatus>("/loop")),
+        flywheel_controller_fetcher_(
+            test_event_loop_->MakeFetcher<FlywheelControllerStatus>("/loop")) {
+    phased_loop_handle_ =
+        test_event_loop_->AddPhasedLoop([this](int) { Simulate(); }, dt());
+  }
+
+  // Checks that the most recently published angular velocity is within 0.1
+  // of the goal.
+  void VerifyNearGoal() {
+    ASSERT_TRUE(flywheel_controller_fetcher_.Fetch());
+    EXPECT_NEAR(flywheel_controller_fetcher_->angular_velocity(), goal_, 0.1);
+  }
+
+  // Sets the angular velocity goal handed to the controller every cycle.
+  void set_goal(double goal) { goal_ = goal; }
+
+ private:
+  // Advances the plant and the controller by one cycle, then publishes the
+  // controller status.
+  void Simulate() {
+    const aos::monotonic_clock::time_point timestamp =
+        test_event_loop_->context().monotonic_event_time;
+    ::Eigen::Matrix<double, 1, 1> flywheel_U;
+    flywheel_U << flywheel_voltage_ + flywheel_plant_->voltage_offset();
+
+    // Abort if the applied voltage implies a battery current beyond +/-200 A.
+    // This catches a controller commanding an unrealistic voltage.
+    CHECK_NEAR(flywheel_plant_->battery_current(flywheel_U), 0.0, 200.0);
+
+    flywheel_plant_->Update(flywheel_U);
+
+    flywheel_controller_.set_position(flywheel_plant_->Y(0, 0), timestamp);
+    flywheel_controller_.set_goal(goal_);
+    flywheel_controller_.Update(false);
+
+    // Publish on /loop so VerifyNearGoal() checks the status as it was sent.
+    auto builder = flywheel_controller_sender_.MakeBuilder();
+    builder.CheckOk(
+        builder.Send(flywheel_controller_.SetStatus(builder.fbb())));
+
+    // Applied to the plant at the start of the next cycle.
+    flywheel_voltage_ = flywheel_controller_.voltage();
+  }
+
+  ::std::unique_ptr<::aos::EventLoop> test_event_loop_;
+  ::aos::PhasedLoopHandler *phased_loop_handle_ = nullptr;
+
+  std::unique_ptr<FlywheelPlant> flywheel_plant_;
+  FlywheelController flywheel_controller_;
+
+  aos::Sender<FlywheelControllerStatus> flywheel_controller_sender_;
+  aos::Fetcher<FlywheelControllerStatus> flywheel_controller_fetcher_;
+
+  // Voltage most recently requested by the controller.
+  double flywheel_voltage_ = 0.0;
+  double goal_ = 0.0;
+};
+
+// With a zero goal the flywheel should stay near zero velocity.
+TEST_F(FlywheelTest, DoNothing) {
+  set_goal(0);
+  RunFor(std::chrono::seconds(2));
+  VerifyNearGoal();
+}
+
+// Spins up to a positive goal and checks convergence after 4 seconds.
+TEST_F(FlywheelTest, PositiveTest) {
+  set_goal(700.0);
+  RunFor(std::chrono::seconds(4));
+  VerifyNearGoal();
+}
+
+// Spins up to a negative goal and checks convergence after 8 seconds.
+TEST_F(FlywheelTest, NegativeTest) {
+  set_goal(-700.0);
+  RunFor(std::chrono::seconds(8));
+  VerifyNearGoal();
+}
+}  // namespace testing
+}  // namespace flywheel
+}  // namespace control_loops
+}  // namespace frc971
diff --git a/frc971/control_loops/flywheel/flywheel_controller_test_config_source.json b/frc971/control_loops/flywheel/flywheel_controller_test_config_source.json
new file mode 100644
index 0000000..1a9b727
--- /dev/null
+++ b/frc971/control_loops/flywheel/flywheel_controller_test_config_source.json
@@ -0,0 +1,27 @@
+{
+  "channels": [
+    {
+      "name": "/aos",
+      "type": "aos.JoystickState"
+    },
+    {
+      "name": "/aos",
+      "type": "aos.logging.LogMessageFbs",
+      "frequency": 400
+    },
+    {
+      "name": "/aos",
+      "type": "aos.RobotState",
+      "frequency": 250
+    },
+    {
+      "name": "/aos",
+      "type": "aos.timing.Report"
+    },
+    {
+      "name": "/loop",
+      "type": "frc971.control_loops.flywheel.FlywheelControllerStatus",
+      "frequency": 200
+    }
+  ]
+}
diff --git a/frc971/control_loops/python/BUILD b/frc971/control_loops/python/BUILD
index 699eb5a..ed8ef3d 100644
--- a/frc971/control_loops/python/BUILD
+++ b/frc971/control_loops/python/BUILD
@@ -276,3 +276,19 @@
         "@pip//python_gflags",
     ],
 )
+
+py_binary(
+    name = "flywheel_controller_test",  # Plant generator run by //frc971/control_loops/flywheel:genrule_flywheel_controller_test.
+    srcs = [
+        "flywheel_controller_test.py",
+    ],
+    legacy_create_init = False,
+    target_compatible_with = ["@platforms//cpu:x86_64"],
+    deps = [
+        ":controls",
+        ":flywheel",
+        ":python_init",
+        "@pip//glog",
+        "@pip//python_gflags",
+    ],
+)
diff --git a/frc971/control_loops/python/flywheel_controller_test.py b/frc971/control_loops/python/flywheel_controller_test.py
new file mode 100644
index 0000000..4f4d929
--- /dev/null
+++ b/frc971/control_loops/python/flywheel_controller_test.py
@@ -0,0 +1,48 @@
+#!/usr/bin/python3
+
+# Generates a test flywheel for flywheel_controller_test
+
+from frc971.control_loops.python import control_loop
+from frc971.control_loops.python import flywheel
+
+import numpy
+import sys
+import gflags
+import glog
+
+FLAGS = gflags.FLAGS
+
+try:
+    gflags.DEFINE_bool('plot', False, 'If true, plot the loop response.')
+except gflags.DuplicateFlagError:
+    pass
+
+kFlywheel = flywheel.FlywheelParams(name='FlywheelTest',
+                                    motor=control_loop.Falcon(),
+                                    G=(60.0 / 48.0),
+                                    J=0.0035,
+                                    q_pos=0.01,
+                                    q_vel=10.0,
+                                    q_voltage=4.0,
+                                    r_pos=0.01,
+                                    controller_poles=[.95])
+
+
+def main(argv):
+    if FLAGS.plot:
+        R = numpy.matrix([[0.0], [500.0], [0.0]])
+        flywheel.PlotSpinup(params=kFlywheel, goal=R, iterations=400)
+        return 0
+
+    # Write the generated constants out to a file.
+    if len(argv) != 5:
+        glog.fatal('Expected .h file name and .cc file name')
+    else:
+        namespaces = ['frc971', 'control_loops', 'flywheel']
+        flywheel.WriteFlywheel(kFlywheel, argv[1:3], argv[3:5], namespaces)
+
+
+if __name__ == '__main__':
+    argv = FLAGS(sys.argv)
+    glog.init()
+    sys.exit(main(argv))