Add plots for all joints in arm UI

Signed-off-by: milind-u <milind.upadhyay@gmail.com>
Change-Id: Ib0e1ad0383d66e54344bcaf6ca5c45d23daad32c
diff --git a/y2023/control_loops/python/graph_edit.py b/y2023/control_loops/python/graph_edit.py
index 9340d1a..1b7f1d2 100644
--- a/y2023/control_loops/python/graph_edit.py
+++ b/y2023/control_loops/python/graph_edit.py
@@ -156,10 +156,15 @@
         self.spline_edit = 0
         self.edit_control1 = True
 
-        self.roll_joint_thetas = None
-        self.roll_joint_point = None
+        self.joint_thetas = None
+        self.joint_points = None
         self.fig = plt.figure()
-        self.ax = self.fig.add_subplot(111)
+        self.axes = [
+            self.fig.add_subplot(3, 1, 1),
+            self.fig.add_subplot(3, 1, 2),
+            self.fig.add_subplot(3, 1, 3)
+        ]
+        self.fig.subplots_adjust(hspace=1.0)
         plt.show(block=False)
 
         self.index = 0
@@ -340,16 +345,20 @@
 
         set_color(cr, Color(0.0, 1.0, 0.5))
 
-        # Create the roll joint plot
-        if self.roll_joint_thetas:
-            self.ax.clear()
-            self.ax.plot(*self.roll_joint_thetas)
-            if self.roll_joint_point:
-                self.ax.scatter([self.roll_joint_point[0]],
-                                [self.roll_joint_point[1]],
-                                s=10,
-                                c="red")
-            plt.title("Roll Joint Angle")
+        # Create the plots
+        if self.joint_thetas:
+            if self.joint_points:
+                titles = ["Proximal", "Distal", "Roll joint"]
+                for i in range(len(self.joint_points)):
+                    self.axes[i].clear()
+                    self.axes[i].plot(self.joint_thetas[0],
+                                      self.joint_thetas[1][i])
+                    self.axes[i].scatter([self.joint_points[i][0]],
+                                         [self.joint_points[i][1]],
+                                         s=10,
+                                         c="red")
+                    self.axes[i].set_title(titles[i])
+            plt.title("Joint Angle")
             plt.xlabel("t (0 to 1)")
             plt.ylabel("theta (rad)")
 
@@ -372,21 +381,26 @@
         event.y = y / scale + self.center[1]
 
         for segment in self.segments:
-            self.roll_joint_thetas = segment.roll_joint_thetas()
+            self.joint_thetas = segment.joint_thetas()
 
             hovered_t = segment.intersection(event)
             if hovered_t:
                 min_diff = np.inf
                 closest_t = None
-                closest_theta = None
-                for i in range(len(self.roll_joint_thetas[0])):
-                    t = self.roll_joint_thetas[0][i]
+                closest_thetas = None
+                for i in range(len(self.joint_thetas[0])):
+                    t = self.joint_thetas[0][i]
                     diff = abs(t - hovered_t)
                     if diff < min_diff:
                         min_diff = diff
                         closest_t = t
-                        closest_theta = self.roll_joint_thetas[1][i]
-                self.roll_joint_point = (closest_t, closest_theta)
+                        closest_thetas = [
+                            self.joint_thetas[1][0][i],
+                            self.joint_thetas[1][1][i],
+                            self.joint_thetas[1][2][i]
+                        ]
+                self.joint_points = [(closest_t, closest_theta)
+                                     for closest_theta in closest_thetas]
                 break
 
         event.x = o_x