import sys
from os.path import  realpath
from os.path import dirname

# sys.path entries must be directories: add the folder containing this file
# (not the file path itself) so modules stored next to it can be imported.
script_root = dirname(realpath(__file__))
sys.path.insert(1, script_root)

from matplotlib import pyplot as plt
from matplotlib.widgets import Button, CheckButtons, Slider, RadioButtons, RangeSlider, Button
import numpy as np
from collections.abc import Iterable

# Skeleton drawing specs. Each entry is (line colour, chain of marker labels);
# the plotting methods join consecutive labels of a chain with line segments.
hand_skeleton_graph = [
    ("black", ["WristIn", "WristOut", "HandOut", "HandIn", "Thumb1"]),
    ("red", ["WristIn", "Thumb1", "ThumbTip"]),
    ("green", ["WristIn", "HandIn", "Index2", "IndexTip"]),
    ("red", ["HandIn", "Middle2", "MiddleTip"]),
    ("green", ["HandOut", "Ring2", "RingTip"]),
    ("red", ["HandOut", "Pinky2", "PinkyTip"]),
]

body_skeleton_graph = [
    ("black", ["head", "torso", "pelvis"]),
    ("red", ["torso", "l_uarm", "l_larm", "l_hand"]),
    ("green", ["torso", "r_uarm", "r_larm", "r_hand"]),
    ("red", ["pelvis", "l_thigh", "l_shank", "l_foot"]),
    ("green", ["pelvis", "r_thigh", "r_shank", "r_foot"]),
]

# Key / label strings used to index trial data and to name plots.
R_VEL_KEY = "Right Hand"
L_VEL_KEY = "Left Hand"
FLAGS_KEY = "Flags"
IGNORE_HAND_KEY = "Ignore"
HVP_KEY = "Hand Velocity Peaks"
START_KEY = "Peak Starts"
END_KEY = "Peak Ends"
MANUAL_ANNOTATIONS_KEY = "Manual Annotations"
INITIAL_BALLISTIC_KEY = "Start Peak"
TERMINAL_BALLISTIC_KEY = "End Peak"

# Flag names (presumably the annotation flags a trial can be tagged with).
FLAGS = [
    "CORRECTION",
    "NO_GESTURE",
    "WRONG_TARGET",
    "MULTIPLE_GESTURES",
    "TARGET_TRACKING",
    "BODY_TRACKING",
    "HAND_TRACKING",
    "GESTURE_CLIPPED",
    "B_HANDSHAPE",
    "MISSING_FINGER_MARKER",
]

# Every marker label referenced by each skeleton graph.
body_skeleton_markers = {label for _colour, chain in body_skeleton_graph for label in chain}
hand_skeleton_markers = {label for _colour, chain in hand_skeleton_graph for label in chain}

# Supports both interactive use (a plot that can be tweaked manually) and non-interactive use (just generating images).
class TrialVisualiser:
    def __init__(self, match_participant_perspective = False, fig_size = (16,9),
                 width=1000, depth=1000, height=1000,
                 width_offset=None, depth_offset=None, height_offset=None, render_all=False
                ):
        """Create the figure and the 3D axes that trials are rendered into.

        Args:
            match_participant_perspective: when True the camera is later driven
                by the participant's head pose rather than framed on the target.
            fig_size: matplotlib figure size (inches).
            width, depth, height: view extents; the x limits are +/- ``depth``,
                the y limits +/- ``width`` and the z range spans ``2 * height - 1000``.
            width_offset, depth_offset, height_offset: extra margins; ``None``
                means 0, 0 and 200 respectively.
            render_all: when True the plotting methods keep previously drawn
                artists instead of updating/removing them.
        """
        self.width, self.height, self.depth = width, height, depth
        self.render_all = render_all
        self.width_offset = 0 if width_offset is None else width_offset
        self.depth_offset = 0 if depth_offset is None else depth_offset
        self.height_offset = 200 if height_offset is None else height_offset
        self.trials = {}
        self.match_participant_perspective = match_participant_perspective

        axes_kwargs = {
            "projection": "3d",
            "focal_length": 1.2,
            "computed_zorder": False,
            "xlabel": 'X',
            "ylabel": 'Y',
            "zlabel": 'Z',
        }
        self.fig, self.ax = plt.subplots(figsize=fig_size, subplot_kw=axes_kwargs)

        # Fixed (non-autoscaled) view volume with a box aspect matching the limits.
        view = self.ax
        view.set_autoscale_on(False)
        view.set_box_aspect((self.depth * 2, self.width * 2, self.height * 2))
        view.set_xlim(-self.depth, self.depth)
        view.set_ylim(-self.width, self.width)
        view.set_zlim(500, (self.height * 2) - 500)
        self.plots = {}

    def plot_target_positions(self):
        if "TARGETS" in self.plots:
            for _, plot in self.plots["TARGETS"].items():
                if type(plot) is list:
                    for p in plot:
                        if type(p) is list:
                            for point in p:
                                point.remove()
                        else:
                            p.remove()
                else:
                    plot.remove()

        if "TARGET" in self.trials[self.condition][self.trial_idx]:
            self.plots["TARGETS"] = {
                "PLANES": [],
                "LABELS": [],
            }
            self.active_target = self.trials[self.condition][self.trial_idx]["TARGET"]
            active_target_key = f"({self.active_target.row}{self.active_target.column})[{self.active_target.subid}]"

            subtarget = self.active_target.subtarget
            target_mesh = self.active_target.get_surface()

            if not np.any(np.isnan(target_mesh[0])):
                self.plots["TARGETS"]["PLANES"].append(self.ax.plot_surface(*target_mesh, color="black", zorder=10))
                self.plots["TARGETS"]["PLANES"].append(self.ax.plot(*self.active_target.get_border(),color="green", zorder=50))
                self.plots["TARGETS"]["LABELS"] = self.ax.text(self.active_target.top_right.x, self.active_target.top_right.y, self.active_target.top_right.z, active_target_key, color="black", zorder=100)
                self.plots["TARGETS"]["LED"] = self.ax.scatter(subtarget.x, subtarget.y, subtarget.z, color="lime", s=5, zorder=80)
            else:
                print("Unable to render target as it does not contain a valid position")
                
            if not self.match_participant_perspective:
                self.ax.set_xlim(-self.depth, self.depth)
                self.ax.set_ylim(-self.width, self.width)
                head = self.trials[self.condition][self.trial_idx]["FULL"].loc[self.frame, "head Z"] if "head Z" in self.trials[self.condition][self.trial_idx]["FULL"] else 1500
                self.ax.set_zlim(head-(self.height + self.height_offset), head+(self.height + self.height_offset))
                t = self.active_target if self.active_target.override is None else self.active_target.override
                self.ax.view_init(azim=t.orientation.yaw, elev=-t.orientation.pitch)

    def get_head_rotation(self, segment, joint, axis, local_coord_len):
        """Angle in degrees (``90 - arccos``) of ``segment - joint`` relative to one axis.

        Args:
            segment: 3D point at the end of a local axis vector.
            joint: 3D origin of that vector.
            axis: index (0, 1 or 2) of the coordinate to compare against.
            local_coord_len: length of the local axis vector, used to normalise.

        Returns:
            ``90 - degrees(arccos(component / local_coord_len))``. The cosine is
            clipped to [-1, 1] so float noise cannot turn a valid pose into NaN.
            NaN inputs still give NaN.
        """
        offset = np.asarray(segment) - np.asarray(joint)
        cosine = np.clip(offset[axis] / local_coord_len, -1.0, 1.0)
        return 90 - np.rad2deg(np.arccos(cosine))

    def get_axis_angle(self, axis, inverse=False):
        """Map a sine-like value to an angle in degrees, extending arcsin past +/-1.

        Values in [-1, 1] give ``arcsin(axis)``; values above 1 continue from
        90 degrees as ``90 + arcsin(axis - 1)`` and values below -1 mirror that
        as ``-90 + arcsin(axis + 1)``. With ``inverse`` the result is negated.
        """
        if axis > 1:
            radians = np.pi / 2 + np.arcsin(axis - 1)
        elif axis < -1:
            radians = -np.pi / 2 + np.arcsin(axis + 1)
        else:
            radians = np.arcsin(axis)

        degrees = np.rad2deg(radians)
        return 0 - degrees if inverse else degrees

    def plot_body(self):
        """Draw (or update) the body skeleton for the current frame.

        Each chain in ``body_skeleton_graph`` is drawn as a polyline through the
        joints that have columns in this trial. A chain is split where a joint
        has no columns, and pieces with fewer than two joints are skipped. A
        joint with any NaN coordinate is blanked in all three axes, so the line
        breaks there. When ``self.hand`` is set and the markers exist, a dashed
        blue line from that side's ``*_larm`` marker to its ``*_hand`` marker is
        drawn too.

        Unless ``self.render_all`` is set, existing line artists are updated in
        place and bones that are no longer drawn are removed.
        """
        prefix = "BODY_JOINTS"
        # Artists kept under ``prefix`` that are not skeleton bones; they must
        # survive the stale-bone cleanup at the end.
        non_bone_keys = {"FRC_RAY", "TORSO_NORMAL"}
        # Create the artist store first so clear_old_bones() can always index it.
        if prefix not in self.plots:
            self.plots[prefix] = {}

        def clear_old_bones(seg_labels=None):
            """Remove the given artists (default: all under ``prefix``) from the axes and ``self.plots``."""
            if seg_labels is None:
                seg_labels = list(self.plots[prefix].keys())
            for seg_label in seg_labels:
                seg = self.plots[prefix].pop(seg_label)
                if isinstance(seg, list):
                    for artist in seg:
                        artist.remove()
                else:
                    seg.remove()

        xyz = ("X", "Y", "Z")
        joint_frame = self.trials[self.condition][self.trial_idx]["FULL"].iloc[self.frame]
        joint_labels = [label for label in body_skeleton_markers if f"{label} X" in joint_frame]
        if not self.render_all and len(joint_labels) < 1:
            clear_old_bones()
            return

        x = joint_frame[[f"{label} X" for label in joint_labels]]
        y = joint_frame[[f"{label} Y" for label in joint_labels]]
        z = joint_frame[[f"{label} Z" for label in joint_labels]]

        # Blank every coordinate of a joint that has any missing coordinate, so
        # matplotlib breaks the line there.
        missing = x.isna().to_numpy() | y.isna().to_numpy() | z.isna().to_numpy()
        if missing.any():
            x.iloc[missing] = np.nan
            y.iloc[missing] = np.nan
            z.iloc[missing] = np.nan

        if self.hand is not None:
            side_key = "r" if self.hand == "RIGHT" else "l"
            origin_cols = [f"{side_key}_larm {axis}" for axis in xyz]
            tip_cols = [f"{side_key}_hand {axis}" for axis in xyz]
            # Only draw the ray when both markers exist in this trial.
            if all(col in joint_frame for col in origin_cols + tip_cols):
                ray_origin = joint_frame[origin_cols].values
                ray_tip = joint_frame[tip_cols].values
                ray_x = [ray_origin[0], ray_tip[0]]
                ray_y = [ray_origin[1], ray_tip[1]]
                ray_z = [ray_origin[2], ray_tip[2]]
                if "FRC_RAY" in self.plots[prefix]:
                    self.plots[prefix]["FRC_RAY"].set_data(ray_x, ray_y)
                    self.plots[prefix]["FRC_RAY"].set_3d_properties(ray_z)
                else:
                    self.plots[prefix]["FRC_RAY"] = self.ax.plot(ray_x, ray_y, ray_z, color="blue", linestyle="dashed", zorder=80)[0]

        if "TORSO_NORMAL" in self.plots[prefix]:
            self.plots[prefix]["TORSO_NORMAL"].remove()
            del self.plots[prefix]["TORSO_NORMAL"]

        drawn = set()  # bone keys drawn this call, across all chains
        for colour, chain in body_skeleton_graph:
            # Split the chain wherever a joint has no columns in this trial.
            bones = [[]]
            for segment in chain:
                if f"{segment} X" in x:
                    bones[-1].append(segment)
                else:
                    bones.append([])
            for bone in bones:
                if len(bone) < 2:
                    # A lone joint forms no segment (and has no bone[1] for its key).
                    continue
                cols = [[f"{b} {axis}" for b in bone] for axis in xyz]
                seg_label = f"{bone[0]} - {bone[1]}"
                drawn.add(seg_label)
                if seg_label in self.plots[prefix] and not self.render_all:
                    self.plots[prefix][seg_label].set_data(x[cols[0]].values, y[cols[1]].values)
                    self.plots[prefix][seg_label].set_3d_properties(z[cols[2]].values)
                else:
                    self.plots[prefix][seg_label] = self.ax.plot(x[cols[0]], y[cols[1]], z[cols[2]], color=colour, zorder=80)[0]

        if not self.render_all:
            # Bones drawn on an earlier call but not this one, e.g. after
            # switching to a trial that lacks some markers.
            stale = set(self.plots[prefix]) - drawn - non_bone_keys
            if stale:
                clear_old_bones(stale)

    def plot_eyes(self, clear=False):
        """Draw gaze-related artists for the current frame.

        Always plots the cyclops point (``CYCLOPS X/Y/Z`` columns). For each hand
        it draws a dashed purple ray from the cyclops point through that hand's
        index-finger tip, extended to 3000 units. Only the hand in ``self.hand``
        is drawn, or both when ``self.hand`` is None. When
        ``match_participant_perspective`` is set, the axis limits and camera
        are moved to the head marker.

        Args:
            clear: first remove previously drawn LED error/vector artists for
                both sides.
        """
        prefix = "GAZE_VECTORS"

        def clear_old(label):
            """Remove the LED error/vector/text artists stored for ``label``, if any."""
            if prefix in self.plots:
                for key in (f"{label}_LED_ERROR", f"{label}_LED_VECTOR", f"{label}_LED_ERROR_TXT"):
                    if key in self.plots[prefix]:
                        self.plots[prefix][key].remove()
                        del self.plots[prefix][key]

        def plot_eye(label, gaze_df):
            """Scatter one eye position (or the cyclops point), hiding it when unavailable."""
            if label == "CYCLOPS":
                x, y, z = [cyclops[0]], [cyclops[1]], [cyclops[2]]
            else:
                x = [*gaze_df.loc[[f"{label}_X_POS"]].values]
                y = [*gaze_df.loc[[f"{label}_Y_POS"]].values]
                z = [*gaze_df.loc[[f"{label}_Z_POS"]].values]

            key = f"{label}_EYES"
            if len(x) == len(y) == len(z) == 1 and not (np.isnan(x[0]) or np.isnan(y[0]) or np.isnan(z[0])):
                if key in self.plots[prefix]:
                    self.plots[prefix][key]._offsets3d = (x, y, z)
                else:
                    self.plots[prefix][key] = self.ax.scatter(x, y, z, color="blue" if label != 'CYCLOPS' else 'cyan')
            elif not self.render_all:
                if key in self.plots[prefix]:
                    self.plots[prefix][key].remove()
                    del self.plots[prefix][key]

        def plot_efrc_error(label, gaze_df, led_cyclops=None):
            """Draw the eye-to-index-fingertip ray for ``label`` ("LEFT"/"RIGHT").

            Artists are stored under ``"<label>_EFRC_*"`` keys. Returns True only
            when a "GAZE" label is skipped because of its error value.
            NOTE(review): the "GAZE" path never defines ``ray`` and would fail if
            re-enabled.
            """
            hand_side_key = "RH_R" if label == "RIGHT" else "LH_L"
            if "GAZE" != label:
                label = f"{label}_EFRC"
            trial = self.trials[self.condition][self.trial_idx]
            accuracy_df = trial["FULL"].iloc[self.frame]
            tip_x_label = f"{hand_side_key}IndexTip X"
            if "TARGET" not in trial or not any(tip_x_label in header for header in accuracy_df.index.values):
                if not self.render_all:
                    clear_old(label)
                return False

            err_label = f"{label}_COSINE_ERR"
            if err_label in accuracy_df.index:
                led_err = accuracy_df[err_label]
                if np.isnan(led_err) or led_err > 90:  # if no value, or exceeds 90 degrees, don't plot
                    clear_old(label)
                    return "GAZE" == label

            emitter = np.array([accuracy_df[f"{hand_side_key}IndexTip {axis}"] for axis in ("X", "Y", "Z")])
            if led_cyclops is None:
                eye_root_labels = [f"{self.side}_{axis}{'' if self.side == 'CYCLOPS' else '_POS'}" for axis in ["X", "Y", "Z"]]
                led_cyclops = np.array([*gaze_df.loc[eye_root_labels].values])

            if label != "GAZE":
                direction = emitter - led_cyclops
                direction = direction / np.linalg.norm(direction)
                ray = led_cyclops + (direction * 3000)  # extend 3000 units from the eye

            if label == "GAZE" or self.hand is None or f"{self.hand}_EFRC" == label:
                key = f"{label}_LED_VECTOR"
                xs = [led_cyclops[0], ray[0]]
                ys = [led_cyclops[1], ray[1]]
                zs = [led_cyclops[2], ray[2]]
                if key in self.plots[prefix]:
                    self.plots[prefix][key].set_data_3d(xs, ys, zs)
                else:
                    self.plots[prefix][key] = self.ax.plot(xs, ys, zs, linestyle="dashed", color="purple" if "GAZE" != label else "cyan", zorder=100)[0]
            return False

        if clear:
            # Per-hand ray artists live under "<SIDE>_EFRC_*" keys; the plain
            # "<SIDE>_*" keys are cleared as well for compatibility.
            for side_label in ("LEFT", "RIGHT", "LEFT_EFRC", "RIGHT_EFRC"):
                clear_old(side_label)

        if prefix not in self.plots:
            self.plots[prefix] = {}
        gaze_df = self.trials[self.condition][self.trial_idx]["FULL"].iloc[self.frame]

        cyclops = gaze_df[["CYCLOPS X", "CYCLOPS Y", "CYCLOPS Z"]].values
        if "CYCLOPS_POINT" in self.plots[prefix]:
            self.plots[prefix]["CYCLOPS_POINT"]._offsets3d = ([cyclops[0]], [cyclops[1]], [cyclops[2]])
        else:
            self.plots[prefix]["CYCLOPS_POINT"] = self.ax.scatter(*cyclops)

        plot_efrc_error("LEFT", gaze_df, cyclops)
        plot_efrc_error("RIGHT", gaze_df, cyclops)
        # plot_gaze = plot_efrc_error("GAZE", gaze_df, hand)

        # plot_eye("LEFT", gaze_df)
        # plot_eye("RIGHT", gaze_df)
        plot_eye("CYCLOPS", cyclops)

        if self.match_participant_perspective:
            local_coord_len = 50
            roll, azim, elev = 0, 0, 0
            if "head X" in gaze_df:
                head = gaze_df[["head X", "head Y", "head Z"]].values
                # Head transform stored as 16 columns "head 0".."head 15" that are
                # reshaped to 4x4 and transposed (column naming assumed — TODO confirm).
                rot_cols = [f"head {idx}" for idx in range(16)]
                if all(col in gaze_df for col in rot_cols):
                    rot_mat = gaze_df[rot_cols].values.reshape(4, 4).T
                    # Row-wise product + sum == rot_mat @ v; drop the homogeneous coordinate.
                    segment_x = (rot_mat * np.array([-local_coord_len, 0, 0, 1])).sum(axis=1)[:-1]
                    segment_y = (rot_mat * np.array([0, local_coord_len, 0, 1])).sum(axis=1)[:-1]
                    azim = 180 - self.get_head_rotation(segment_x, head, 0, local_coord_len)
                    elev = -self.get_head_rotation(segment_y, head, 2, local_coord_len)
                self.ax.set_xlim(head[0] - self.depth, head[0] + self.depth)
                self.ax.set_ylim(head[1] - self.width, head[1] + self.width)
                self.ax.set_zlim(head[2] - (self.height + self.height_offset), head[2] + (self.height + self.height_offset))
            self.ax.view_init(azim=azim, elev=elev, roll=roll)

    def plot_hand_markers(self, prefix, acc_label, clear=False):
        """Draw (or update) one hand's marker skeleton and its index-finger ray.

        Args:
            prefix: column-name prefix of this hand's markers (columns are
                ``f"{prefix}{marker} X"``/``Y``/``Z``). It is also the key of
                this hand's artists in ``self.plots``.
            acc_label: side label compared with ``self.hand``. It selects the
                ``f"{acc_label}_IFRC_COSINE_ERR"`` column and names the ray artists.
            clear: first remove the ray/error artists stored for "LEFT" and
                "RIGHT" under ``prefix``.

        Marker values of 0 are treated as missing. When the trial has a target
        and this hand has ``IndexTip`` and ``HandIn`` markers, a dashed red ray
        is drawn from HandIn through IndexTip, extended to 3000 units. The ray is
        skipped when the cosine error is NaN or above 90 degrees.
        """
        # Key suffixes of the ray/error artists; any other key under ``prefix`` is a bone.
        led_suffixes = ("_IFRC_LED_ERROR", "_LED_VECTOR", "_IFRC_ERROR_TXT")

        def clear_old_bones(seg_labels=None):
            """Remove the given artists (default: all under ``prefix``) from the axes and ``self.plots``."""
            if seg_labels is None:
                seg_labels = list(self.plots[prefix].keys())
            for seg_label in seg_labels:
                seg = self.plots[prefix].pop(seg_label)
                if isinstance(seg, list):
                    for artist in seg:
                        artist.remove()
                else:
                    seg.remove()

        def clear_old(side_label):
            """Remove the ray/error artists stored for ``side_label``, if any."""
            if prefix in self.plots:
                for key in (f"{side_label}_IFRC_LED_ERROR", f"{side_label}_LED_VECTOR", f"{side_label}_IFRC_ERROR_TXT"):
                    if key in self.plots[prefix]:
                        self.plots[prefix][key].remove()
                        del self.plots[prefix][key]

        xyz = ("X", "Y", "Z")
        trial = self.trials[self.condition][self.trial_idx]
        markers = trial["FULL"].iloc[self.frame]
        # Use only markers with all three coordinate columns so X/Y/Z stay aligned.
        present = [label for label in hand_skeleton_markers if all(f"{prefix}{label} {axis}" in markers for axis in xyz)]
        x_markers = markers[[f"{prefix}{label} X" for label in present]].replace(0, np.nan)
        y_markers = markers[[f"{prefix}{label} Y" for label in present]].replace(0, np.nan)
        z_markers = markers[[f"{prefix}{label} Z" for label in present]].replace(0, np.nan)

        if clear:
            clear_old("LEFT")
            clear_old("RIGHT")

        if prefix not in self.plots:
            self.plots[prefix] = {}

        # Blank all coordinates of a marker with any missing coordinate so the line breaks there.
        missing = x_markers.isna().to_numpy() | y_markers.isna().to_numpy() | z_markers.isna().to_numpy()
        if missing.any():
            x_markers.iloc[missing] = np.nan
            y_markers.iloc[missing] = np.nan
            z_markers.iloc[missing] = np.nan

        drawn = set()
        for colour, chain in hand_skeleton_graph:
            # Split the chain wherever a marker has no columns in this trial.
            bones = [[]]
            for segment in chain:
                if f"{prefix}{segment} X" in x_markers:
                    bones[-1].append(f"{prefix}{segment}")
                else:
                    bones.append([])
            for bone in bones:
                if not bone:
                    continue
                # Several chains start at the same marker (e.g. WristIn, HandOut),
                # so key each line on its whole marker chain.
                seg_label = " - ".join(bone)
                drawn.add(seg_label)
                cols = [[f"{b} {axis}" for b in bone] for axis in xyz]
                if seg_label in self.plots[prefix]:
                    self.plots[prefix][seg_label].set_data(x_markers[cols[0]].values, y_markers[cols[1]].values)
                    self.plots[prefix][seg_label].set_3d_properties(z_markers[cols[2]].values)
                else:
                    self.plots[prefix][seg_label] = self.ax.plot(x_markers[cols[0]], y_markers[cols[1]], z_markers[cols[2]], color=colour)[0]

        if not self.render_all:
            stale = {key for key in self.plots[prefix] if key not in drawn and not key.endswith(led_suffixes)}
            if stale:
                clear_old_bones(stale)

        accuracy_df = markers
        has_ray_markers = f"{prefix}IndexTip X" in accuracy_df and f"{prefix}HandIn X" in accuracy_df
        if "TARGET" not in trial or not has_ray_markers:
            if not self.render_all:
                clear_old(acc_label)
            return

        err_label = f"{acc_label}_IFRC_COSINE_ERR"
        if err_label in accuracy_df.index:
            led_err = accuracy_df[err_label]
            if np.isnan(led_err) or led_err > 90:  # if no value, or exceeds 90 degrees, don't plot
                clear_old(acc_label)
                return

        tip = np.array([accuracy_df[f"{prefix}IndexTip {axis}"] for axis in xyz])
        origin = np.array([accuracy_df[f"{prefix}HandIn {axis}"] for axis in xyz])
        direction = tip - origin
        direction = direction / np.linalg.norm(direction)
        ray = origin + (direction * 3000)

        if self.hand is None or self.hand == acc_label:
            key = f"{acc_label}_LED_VECTOR"
            xs = [origin[0], ray[0]]
            ys = [origin[1], ray[1]]
            zs = [origin[2], ray[2]]
            if key in self.plots[prefix]:
                self.plots[prefix][key].set_data_3d(xs, ys, zs)
            else:
                self.plots[prefix][key] = self.ax.plot(xs, ys, zs, linestyle="dashed", color="red", zorder=100)[0]

    def process_pointing_window_change(self, side, ballistic_idx, start, end):
        """Store a manually adjusted ballistic window and jump to the edge that moved.

        Args:
            side: key of the hand within the trial's ``HVP_KEY`` entry.
            ballistic_idx: 0 selects the initial ballistic window; any other
                value selects the terminal one.
            start, end: the new window bounds (used as frame indices).

        The bounds are written to the trial's manual annotations (created on
        first use), ``self.ballistic_ranges`` is updated and ``self.overriden``
        is set. If ``start`` moved, the view jumps to it; otherwise, if ``end``
        moved, the view jumps to ``end``.
        """
        side_peaks = self.trials[self.condition][self.trial_idx][HVP_KEY][side]
        if MANUAL_ANNOTATIONS_KEY not in side_peaks:
            side_peaks[MANUAL_ANNOTATIONS_KEY] = {INITIAL_BALLISTIC_KEY: [], TERMINAL_BALLISTIC_KEY: []}

        previous = self.ballistic_ranges[side][ballistic_idx]
        start_changed = start != previous[0]
        end_changed = end != previous[1]

        window_key = TERMINAL_BALLISTIC_KEY if ballistic_idx != 0 else INITIAL_BALLISTIC_KEY
        side_peaks[MANUAL_ANNOTATIONS_KEY][window_key] = [start, end]
        self.overriden = True
        self.ballistic_ranges[side][ballistic_idx] = (start, end)

        if start_changed:
            new_frame = start
        elif end_changed:
            new_frame = end
        else:
            new_frame = self.frame
        if new_frame != self.frame:
            self.update_frame(new_frame)

    def plot_angles(self):
        """Show elbow and shoulder angles for both arms in small inset axes.

        For each side ("l", "r") three angles are computed from the current
        frame's body markers: ELBOW, ARM_TORSO and SHOULDER_TRANSVERSE. Each one
        gets its own inset axes, created on first use. The inset shows two line
        segments (the projected limb unit vectors) and a text readout of the
        3D and planar angle in degrees. An angle whose markers are missing is
        blanked for that frame instead of raising.
        """
        xyz = ("X", "Y", "Z")

        def get_joint_angle(joint_frame, plane_labels, v0_labels, v1_labels, orthogonal_plane: bool = False, perpendicular_plane: bool = False, offset_label: str | None = None):
            """Angle between two limb vectors after projecting them onto a body plane.

            The plane passes through ``plane_labels[0]`` (p0) and is spanned by the
            directions p0 -> ``plane_labels[1]`` and p0 -> ``plane_labels[2]``.
            When ``offset_label`` is given, the second direction is
            ``plane_labels[2] - offset_label`` instead. ``plane_labels[3]`` fixes
            the in-plane y axis. ``orthogonal_plane`` turns the normal into the x
            axis, and ``perpendicular_plane`` swaps the normal and the y axis.
            ``v0_labels``/``v1_labels`` are (origin, end) marker pairs.

            Returns None if any marker is missing. Otherwise returns
            ``((p0, y_axis, x_axis, v0, v0_origin, v1, v1_origin, plane_normal),
            x_markers, y_markers, angle, planar_angle)``. The marker tuples hold
            the (v0, 0, v1) unit-vector components along the x/y axes, and both
            angles are in degrees.
            """
            needed = set(plane_labels) | set(v0_labels) | set(v1_labels)
            if offset_label is not None:
                needed.add(offset_label)
            if any(f"{joint_label} X" not in joint_frame for joint_label in needed):
                return None

            def point(joint_label):
                return np.asarray(joint_frame[[f"{joint_label} {axis}" for axis in xyz]].values, dtype=float)

            p0 = point(plane_labels[0])
            p1 = point(plane_labels[1]) - p0
            p1 = p1 / np.linalg.norm(p1)
            if offset_label is not None:
                p2 = point(plane_labels[2]) - point(offset_label)
            else:
                p2 = point(plane_labels[2]) - p0
            p2 = p2 / np.linalg.norm(p2)

            plane_normal = np.cross(p1, p2)
            # p1 and p2 need not be orthogonal; the projections below need a unit normal.
            plane_normal = plane_normal / np.linalg.norm(plane_normal)
            if plane_labels[3] == plane_labels[1]:
                y_axis = p1
            elif plane_labels[3] == plane_labels[2]:
                y_axis = p2
            else:  # only calculate if not a point already provided.
                y_point = point(plane_labels[3])
                y_point = y_point - (np.dot(y_point - p0, plane_normal) * plane_normal)
                # Direction from p0 to the projected point, not the point itself.
                y_axis = y_point - p0
                y_axis = y_axis / np.linalg.norm(y_axis)

            if orthogonal_plane:
                x_axis = np.copy(plane_normal)
                plane_normal = np.cross(x_axis, y_axis)
            else:
                x_axis = np.cross(plane_normal, y_axis)

            if perpendicular_plane:
                y_axis, plane_normal = np.copy(plane_normal), np.copy(y_axis)

            # Shift both vectors' endpoints onto the plane before measuring.
            v0_origin = point(v0_labels[0])
            v0_origin = v0_origin - (np.dot(v0_origin - p0, plane_normal) * plane_normal)
            v0 = point(v0_labels[1])
            v0 = (v0 - (np.dot(v0 - p0, plane_normal) * plane_normal)) - v0_origin
            v0_mag = np.linalg.norm(v0)

            v1_origin = point(v1_labels[0])
            v1_origin = v1_origin - (np.dot(v1_origin - p0, plane_normal) * plane_normal)
            v1 = point(v1_labels[1])
            v1 = (v1 - (np.dot(v1 - p0, plane_normal) * plane_normal)) - v1_origin
            v1_mag = np.linalg.norm(v1)

            # Clip so float noise just past +/-1 does not turn the angle into NaN.
            angle = np.rad2deg(np.arccos(np.clip(np.dot(v0, v1) / (v0_mag * v1_mag), -1.0, 1.0)))

            y_markers = (np.dot(v0 / v0_mag, y_axis), 0, np.dot(v1 / v1_mag, y_axis))
            x_markers = (np.dot(v0 / v0_mag, x_axis), 0, np.dot(v1 / v1_mag, x_axis))
            planar_cos = np.dot([x_markers[0], y_markers[0]], [x_markers[2], y_markers[2]])
            planar_angle = np.rad2deg(np.arccos(np.clip(planar_cos, -1.0, 1.0)))
            return (p0, y_axis, x_axis, v0, v0_origin, v1, v1_origin, plane_normal), x_markers, y_markers, angle, planar_angle

        motion_x = 0
        joint_angles = ["ELBOW", "ARM_TORSO", "SHOULDER_TRANSVERSE"]
        inset_width = 0.25 / len(joint_angles)
        for label in ["l", "r"]:
            prefix = f"{label}_JOINT_ANGLES"

            if prefix not in self.plots:
                self.plots[prefix] = {}
                for idx, p in enumerate(joint_angles):
                    inset = self.fig.add_axes([motion_x + (inset_width * idx), 0.25, inset_width, 0.05])
                    inset.set_autoscale_on(False)
                    inset.set_aspect(1)
                    inset.set_xbound(-1, 1)
                    if p == "ARM_TORSO":
                        inset.set_ybound(1, -1)
                    else:
                        inset.set_ybound(-1, 1)
                    self.plots[prefix][f"{p}_{label}_AX"] = inset

            joint_frame = self.trials[self.condition][self.trial_idx]["FULL"].iloc[self.frame]
            elbow = get_joint_angle(joint_frame, (f"{label}_larm", f"{label}_uarm", f"{label}_hand", f"{label}_uarm"), (f"{label}_larm", f"{label}_hand"), (f"{label}_larm", f"{label}_uarm"))
            shoulder_ext = get_joint_angle(joint_frame, ("torso", "pelvis", f"{label}_larm", "pelvis"), ("torso", "pelvis"), (f"{label}_uarm", f"{label}_larm"), offset_label=f"{label}_uarm")
            shoulder_rot = get_joint_angle(joint_frame, ("pelvis", "l_uarm", "r_uarm", "torso"), ("l_uarm", "r_uarm"), (f"{label}_uarm", f"{label}_larm"), perpendicular_plane=True)

            # The shoulder "planar" readouts are derived from the two 3D shoulder
            # angles (NaN if either is unavailable), replacing get_joint_angle's.
            shoulder_extension = np.nan if shoulder_ext is None else shoulder_ext[3]
            shoulder_rotation = np.nan if shoulder_rot is None else shoulder_rot[3]
            planar_shoulder_extension = np.cos(np.deg2rad(shoulder_rotation)) * shoulder_extension
            if label == "l":
                planar_shoulder_extension = 0 - planar_shoulder_extension
            planar_shoulder_rotation = np.sin(np.deg2rad(shoulder_rotation)) * shoulder_extension

            plots = [
                (f"ELBOW_{label}", elbow, None),
                (f"ARM_TORSO_{label}", shoulder_ext, planar_shoulder_extension),
                (f"SHOULDER_TRANSVERSE_{label}", shoulder_rot, planar_shoulder_rotation),
            ]
            for ax_label, result, planar_override in plots:
                plot_label = f"ANGLE_PLOT_{ax_label}"
                angle_ax = self.plots[prefix][f"{ax_label}_AX"]
                # Drop the previous frame's readout; a fresh one is added below.
                if f"{ax_label}_TEXT" in self.plots[prefix]:
                    self.plots[prefix][f"{ax_label}_TEXT"].remove()
                    del self.plots[prefix][f"{ax_label}_TEXT"]

                if result is None:
                    # Markers missing this frame: blank the plot rather than show stale data.
                    for part in ("_UPPER", "_LOWER"):
                        if f"{plot_label}{part}" in self.plots[prefix]:
                            self.plots[prefix][f"{plot_label}{part}"].set_data([], [])
                    continue

                projection, x_markers, y_markers, angle, planar_angle = result
                if planar_override is not None:
                    planar_angle = planar_override
                origin, y_axis, x_axis, v0, _v0_origin, v1, _v1_origin, normal = projection

                if f"{plot_label}_UPPER" in self.plots[prefix]:
                    # UPPER is the v1 segment (indices 1-2) and LOWER the v0 segment
                    # (indices 0-1), matching how they are created below.
                    self.plots[prefix][f"{plot_label}_UPPER"].set_data(x_markers[1:], y_markers[1:])
                    self.plots[prefix][f"{plot_label}_LOWER"].set_data(x_markers[:2], y_markers[:2])
                else:
                    self.plots[prefix][f"{plot_label}_UPPER"] = angle_ax.plot(x_markers[1:], y_markers[1:], color="red" if label == "l" else "green")[0]
                    self.plots[prefix][f"{plot_label}_LOWER"] = angle_ax.plot(x_markers[:2], y_markers[:2], color="pink" if label == "l" else "lime")[0]
                    angle_ax.set_title(ax_label[:-2])

                if "ARM_TORSION" in ax_label:
                    # NOTE(review): none of the labels above contain "ARM_TORSION",
                    # so these 3D debug vectors are currently never drawn.
                    debug_vectors = (
                        ("NORM", normal, {"color": "pink"}),
                        ("X", x_axis, {"color": "orange"}),
                        ("v0", v0, {"color": "orange", "linestyle": 'dashed'}),
                        ("Y", y_axis, {"color": "lime"}),
                        ("v1", v1, {"color": "lime", "linestyle": 'dashed'}),
                    )
                    for suffix, vec, style in debug_vectors:
                        xs, ys, zs = ([origin[i], (vec[i] * 500) + origin[i]] for i in range(3))
                        key = f"{plot_label}_{suffix}"
                        if key in self.plots[prefix]:
                            self.plots[prefix][key].set_data(xs, ys)
                            self.plots[prefix][key].set_3d_properties(zs)
                        else:
                            self.plots[prefix][key] = self.ax.plot(xs, ys, zs, **style)[0]

                self.plots[prefix][f"{ax_label}_TEXT"] = angle_ax.text(-1, -3, f"Angle: {angle:3.2f}\nPlanar: {planar_angle:3.2f}", color="black")
            motion_x = 0.25

    def plot_motion_graph(self, new_trial):
        motion_x = 0
        for label in [L_VEL_KEY, R_VEL_KEY]:
            for motion in ["Velocity"]:#, "Acceleration"]:
                prefix = f"{label}_{motion}"
                if not self.render_all and prefix in self.plots and "VLINE" in self.plots[prefix] and len(self.plots[prefix]["AX"].lines) > 0:
                    try:
                        self.plots[prefix]["VLINE"].remove()
                    except Exception:
                        pass

                frame_count = self.trials[self.condition][self.trial_idx]["FULL"].shape[0]
                if not self.render_all and new_trial and prefix in self.plots:
                    for line in self.plots[prefix]["AX"].lines:
                        line.remove()
                    if INITIAL_BALLISTIC_KEY in self.plots[prefix]:
                        try:
                            self.plots[prefix][INITIAL_BALLISTIC_KEY].remove()
                        except Exception:
                            print(f"Unable to remove {INITIAL_BALLISTIC_KEY} from {prefix} graph")
                    if "HOLD" in self.plots[prefix]:
                        try:
                            self.plots[prefix]["HOLD"].remove()
                        except Exception:
                            print(f"Unable to remove HOLD from {prefix} graph")
                    if TERMINAL_BALLISTIC_KEY in self.plots[prefix]:
                        try:
                            self.plots[prefix][TERMINAL_BALLISTIC_KEY].remove()
                        except Exception:
                            print(f"Unable to remove {TERMINAL_BALLISTIC_KEY} from {prefix} graph")
                if prefix not in self.plots:
                    self.plots[prefix] = {
                        "AX": self.fig.add_axes([0.65, 0.7-motion_x, 0.3, 0.1]),
                        "START": self.fig.add_axes([0.65, 0.7-motion_x-0.05, 0.3, 0.01]),
                        "END": self.fig.add_axes([0.65, 0.7-motion_x-0.075, 0.3, 0.01])
                    }
                    self.ballistic_ranges[label] = [(0,frame_count), (0,frame_count)]
                    if not self.render_all:
                        self.controls[f"{label}_START_RANGE"] = RangeSlider(self.plots[prefix]["START"], INITIAL_BALLISTIC_KEY, 0, frame_count-1, valstep=1)
                        self.controls[f"{label}_START_RANGE"].on_changed(lambda vals, side=label: self.process_pointing_window_change(side, 0, vals[0], vals[1]))
                        self.controls[f"{label}_END_RANGE"] = RangeSlider(self.plots[prefix]["END"], TERMINAL_BALLISTIC_KEY, 0, frame_count-1, valstep=1)
                        self.controls[f"{label}_END_RANGE"].on_changed(lambda vals, side=label: self.process_pointing_window_change(side, 1, vals[0], vals[1]))
                if new_trial:
                    if f"{label[0].lower()}_hand X {motion}" in self.trials[self.condition][self.trial_idx]["FULL"]: 
                        self.plots[prefix][f"{label}_{motion}_X"] = self.plots[prefix]["AX"].plot(self.trials[self.condition][self.trial_idx]["FULL"][f"{label[0].lower()}_hand X {motion}"], color="red", label="X")
                        self.plots[prefix][f"{label}_{motion}_Y"] = self.plots[prefix]["AX"].plot(self.trials[self.condition][self.trial_idx]["FULL"][f"{label[0].lower()}_hand Y {motion}"], color="blue", label="Y")
                        self.plots[prefix][f"{label}_{motion}_Z"] = self.plots[prefix]["AX"].plot(self.trials[self.condition][self.trial_idx]["FULL"][f"{label[0].lower()}_hand Z {motion}"], color="green", label="Z")
                        if motion == "Velocity":
                            windows = self.trials[self.condition][self.trial_idx][HVP_KEY]
                            smoothed = self.trials[self.condition][self.trial_idx]["FULL"][f"{label[0].lower()}_hand MEAN VELOCITY"]
                            self.plots[prefix][f"{label}_{motion}_MEAN"] = self.plots[prefix]["AX"].plot(smoothed, color="black", linestyle="dashed", label="MEAN", zorder=200)
                            
                            # get pointing regions
                            render_regions = False
                            if len(windows[label][START_KEY]) == 2 and len(windows[label][END_KEY]) == 2:
                                ballistic_0 = (windows[label][START_KEY][0], windows[label][END_KEY][0])
                                ballistic_1 = (windows[label][START_KEY][-1], windows[label][END_KEY][-1])
                                render_regions = True
                            else:
                                ballistic_0 = (0, frame_count)
                                ballistic_1 = (0, frame_count)
                            if MANUAL_ANNOTATIONS_KEY in windows[label]:
                                overrides = windows[label][MANUAL_ANNOTATIONS_KEY]
                                ballistic_0 = (overrides[INITIAL_BALLISTIC_KEY][0], overrides[INITIAL_BALLISTIC_KEY][1]) if len(overrides[INITIAL_BALLISTIC_KEY]) > 0 else (windows[label][START_KEY][0], windows[label][END_KEY][0])
                                ballistic_1 = (overrides[TERMINAL_BALLISTIC_KEY][0], overrides[TERMINAL_BALLISTIC_KEY][1]) if len(overrides[TERMINAL_BALLISTIC_KEY]) > 0 else (windows[label][START_KEY][-1], windows[label][END_KEY][-1])
                                render_regions = True

                            # plot the regions if not ignored
                            if render_regions and (IGNORE_HAND_KEY not in windows[label]):
                                self.plots[prefix][INITIAL_BALLISTIC_KEY] = self.plots[prefix]["AX"].axvspan(ballistic_0[0], ballistic_0[1], facecolor='firebrick', alpha=0.5)
                                self.plots[prefix]["HOLD"] = self.plots[prefix]["AX"].axvspan(ballistic_0[1], ballistic_1[0], facecolor='greenyellow', alpha=0.5)
                                self.plots[prefix][TERMINAL_BALLISTIC_KEY] = self.plots[prefix]["AX"].axvspan(ballistic_1[0], ballistic_1[1], facecolor='lightsteelblue', alpha=0.5)
                            
                            for idx, (key, frames) in enumerate(windows[label].items()):
                                if key == IGNORE_HAND_KEY:
                                    continue
                                for frame in frames:
                                    colour = "cyan" if key == START_KEY else ("purple" if key == END_KEY else "black")
                                    self.plots[prefix][f"EDGE {idx}-{frame}"] = self.plots[prefix]["AX"].axvline(frame, linestyle="-", color=colour)

                            self.ballistic_ranges[label][0] = ballistic_0
                            self.ballistic_ranges[label][1] = ballistic_1
                            if not self.render_all:
                                self.controls[f"{label}_START_RANGE"].eventson = False
                                self.controls[f"{label}_END_RANGE"].eventson = False
                                self.controls[f"{label}_START_RANGE"].set_val(ballistic_0)
                                self.controls[f"{label}_END_RANGE"].set_val(ballistic_1)
                                self.controls[f"{label}_START_RANGE"].eventson = True
                                self.controls[f"{label}_END_RANGE"].eventson = True
                    
                    self.plots[prefix]["AX"].set_xbound(0, frame_count)
                    self.plots[prefix]["AX"].set_ybound(-4 if motion == "Velocity" else -30, 4 if motion == "Velocity" else 30)
                    tick_count = frame_count // 50
                    ticks = [0]
                    tick_labels = [0]
                    for i in range(tick_count):
                        tick = (i+1)*50
                        tick_labels.append(tick)
                        ticks.append(frame_count / tick)
                    self.plots[prefix]["AX"].set_xticks(tick_labels, tick_labels)
                    
                if not self.render_all:
                    self.plots[prefix]["VLINE"] = self.plots[prefix]["AX"].axvline(self.frame)
            motion_x += 0.3

    def process_flag(self, flag):
        self.overriden = True
        if FLAGS_KEY not in self.trials[self.condition][self.trial_idx]:
            self.trials[self.condition][self.trial_idx][FLAGS_KEY] = {}
        if flag not in self.trials[self.condition][self.trial_idx][FLAGS_KEY]:
            print(f"Adding Flag: {flag}")
            self.trials[self.condition][self.trial_idx][FLAGS_KEY][flag] = True
        else:
            print(f"Removing Flag: {flag}")
            self.trials[self.condition][self.trial_idx][FLAGS_KEY].pop(flag)

    def process_ignore(self, label, side):
        self.overriden = True
        if "IGNORE" not in self.trials[self.condition][self.trial_idx][HVP_KEY][side]:
            self.trials[self.condition][self.trial_idx][HVP_KEY][side][IGNORE_HAND_KEY] = True
            print(f"Ignoring: ")
            self.controls[f"{label}_IGNORE"].label.set_text(f"{label} - IGNORED")
        else:
            self.trials[self.condition][self.trial_idx][HVP_KEY][side].pop(IGNORE_HAND_KEY)
            self.controls[f"{label}_IGNORE"].label.set_text(label) 


    def ignore_check(self, label):
        expected = [self.hand_dominance == "BOTH" or self.hand_dominance == label, True if np.isnan(self.trials[self.condition][self.trial_idx]["FULL"].iloc[self.frame][f"{label}_X_POS"]) else False] # lazy convert to py bool instead of np bool
        label = f"{label}_DOMINANCE"
        self.controls[label].eventson = False
        for f_idx, flag in enumerate(expected):
            self.controls[label].set_active(f_idx, flag)
        self.controls[label].eventson = True

    def plot_error_graph(self, new_trial):
        error_x = 0
        for label in ["LEFT", "RIGHT", "GAZE"]:
            side = "LH_L" if label == "LEFT" else "RH_R"
            vel_side = R_VEL_KEY if label == "RIGHT" else L_VEL_KEY
            plot_key = f"{label}_ERR"
            if plot_key in self.plots and "VLINE" in self.plots[plot_key] and "AX" in self.plots[plot_key] and len(self.plots[plot_key]["AX"].lines) > 0:
                self.plots[plot_key]["VLINE"].remove()
            if new_trial and plot_key in self.plots and "AX" in self.plots[plot_key]:
                for line in self.plots[plot_key]["AX"].lines:
                    line.remove()

            if plot_key not in self.plots:
                self.plots[plot_key] = {
                    "INPUT": self.fig.add_axes([0.5, 0.8-error_x-(0 if label != "GAZE" else 0.1), 0.1, 0.1 if label != "GAZE" else 0.2])
                }
                if label != "GAZE":
                    self.plots[plot_key]["AX"] = self.fig.add_axes([0.65, 0.8-error_x, 0.3, 0.1])
                    self.plots[plot_key]["INPUT"].set_title(label)
                    self.controls[f"{label}_IGNORE"] = Button(self.plots[plot_key]["INPUT"], "IGNORE")
                    self.controls[f"{label}_IGNORE"].on_clicked(lambda x, label=label, side=vel_side: self.process_ignore(label, side))
                    self.plots[plot_key]["DOMINANCE"] = self.fig.add_axes([0.5, 0.8-error_x-0.1, 0.1, 0.1])
                    actives=[self.hand_dominance == "BOTH" or self.hand_dominance == label, False]
                    self.controls[f"{label}_DOMINANCE"] = CheckButtons(self.plots[plot_key]["DOMINANCE"], ["DOMINANT HAND"], actives=actives)
                    self.controls[f"{label}_DOMINANCE"].on_clicked(lambda flag, label=label: self.ignore_check(label))
                else:
                    self.plots[plot_key]["INPUT"].set_title("Exclusion Flags")
                    self.controls[f"FLAGS"] = CheckButtons(self.plots[plot_key]["INPUT"], FLAGS)
                    self.controls[f"FLAGS"].on_clicked(self.process_flag)
            
            error_x += 0.3
            if new_trial:
                accuracy_df = self.trials[self.condition][self.trial_idx]["FULL"]
                if label != "GAZE":
                    label_text = "IGNORE" if IGNORE_HAND_KEY not in self.trials[self.condition][self.trial_idx][HVP_KEY][vel_side] else "IGNORED"
                    self.controls[f"{label}_IGNORE"].label.set_text(label_text)
                else:
                    active_flags=self.trials[self.condition][self.trial_idx][FLAGS_KEY] if FLAGS_KEY in self.trials[self.condition][self.trial_idx] else {}
                    if len(active_flags) == 0:
                        if "head X" not in accuracy_df:
                            active_flags["BODY_TRACKING"] = True
                            self.overriden = True
                        if np.isnan(self.active_target.top_right.x) and (self.active_target.override is not None and np.isnan(self.active_target.override.top_right.x)):
                            active_flags["TARGET_TRACKING"] = True
                            self.overriden = True
                    self.controls[f"FLAGS"].eventson = False
                    for f_idx, flag in enumerate(FLAGS):
                        self.controls[f"FLAGS"].set_active(f_idx, flag in active_flags)
                    self.controls[f"FLAGS"].eventson = True
                if "GAZE" != label:
                    for model in ["EFRC", "IFRC"]:
                        prefix = f"{label}_{model}_COSINE_ERR"
                        if prefix in set(accuracy_df.columns.values):
                            error = accuracy_df[prefix]
                            self.plots[plot_key][f"{label}_{model}"] = self.plots[plot_key]["AX"].plot(error, color="purple" if model == "EFRC" else "red", label=model)
                    self.plots[plot_key]["AX"].set_ybound(0, 180)
                    self.plots[plot_key]["AX"].set_xbound(0, accuracy_df.shape[0])

                    self.plots[plot_key]["AX"].get_xaxis().set_visible(False)

            if "GAZE" != label:
                self.plots[plot_key]["VLINE"] = self.plots[plot_key]["AX"].axvline(self.frame)

    def plot_joint_angles_graph(self, new_trial):
        ja_axis_x = 0

        for side in ["LEFT", "RIGHT"]:
            ja_axis_y = 0
            for joint in ["ELBOW", "SHOULDER"]:
                label = f"{side}_{joint}"
                plot_key = f"{label}_ANGLES"
                if plot_key in self.plots and "VLINE" in self.plots[plot_key] and len(self.plots[plot_key]["AX"].lines) > 0:
                    self.plots[plot_key]["VLINE"].remove()
                if new_trial and plot_key in self.plots:
                    for line in self.plots[plot_key]["AX"].lines:
                        line.remove()
                
                if plot_key not in self.plots:
                    self.plots[plot_key] = {
                        "AX": self.fig.add_axes([0.55+ja_axis_x, 0.30-ja_axis_y, 0.2, 0.1])
                    }
                ja_axis_y += 0.1

                if new_trial:
                    features_df = self.trials[self.condition][self.trial_idx]["FULL"]
                    if f"{side} {joint} X" in features_df:
                        self.plots[plot_key][f"{label}_X"] = self.plots[plot_key]["AX"].plot(features_df[f"{side} {joint} X"].values, color="red", label="X")
                        self.plots[plot_key][f"{label}_Y"] = self.plots[plot_key]["AX"].plot(features_df[f"{side} {joint} Y"].values, color="blue", label="Y")
                        self.plots[plot_key][f"{label}_Z"] = self.plots[plot_key]["AX"].plot(features_df[f"{side} {joint} Z"].values, color="green", label="Z")
                        self.plots[plot_key][f"{label}_LEGEND"] = self.plots[plot_key]["AX"].legend()
                    self.plots[plot_key]["AX"].set_xbound(0, features_df.shape[0])
                    self.plots[plot_key]["AX"].set_ybound(0, np.pi)

                self.plots[plot_key]["VLINE"] = self.plots[plot_key]["AX"].axvline(self.frame)

            ja_axis_x += 0.2

        plot_key = "HEAD_ANGLES"
        if plot_key in self.plots and "VLINE" in self.plots[plot_key] and len(self.plots[plot_key]["AX"].lines) > 0:
            self.plots[plot_key]["VLINE"].remove()
        if new_trial and plot_key in self.plots:
            for line in self.plots[plot_key]["AX"].lines:
                line.remove()
        if plot_key not in self.plots:
            self.plots[plot_key] = {
                "AX": self.fig.add_axes([0.65, 0.1, 0.2, 0.1])
            }
        if new_trial:
            features_df = self.trials[self.condition][self.trial_idx]["FULL"]
            if "HEAD X" in features_df:
                self.plots[plot_key][f"HEAD_X"] = self.plots[plot_key]["AX"].plot(features_df["HEAD X"].values, color="red", label="X")
                self.plots[plot_key][f"HEAD_Y"] = self.plots[plot_key]["AX"].plot(features_df["HEAD Y"].values, color="blue", label="Y")
                self.plots[plot_key][f"HEAD_Z"] = self.plots[plot_key]["AX"].plot(features_df["HEAD Z"].values, color="green", label="Z")
                self.plots[plot_key][f"HEAD_LEGEND"] = self.plots[plot_key]["AX"].legend()
            self.plots[plot_key]["AX"].set_xbound(0, features_df.shape[0])
            self.plots[plot_key]["AX"].set_ybound(0, np.pi)

        return None

    def compute_errors(self, df, target, prefix, origin_label, emitter_label):
        if f"{origin_label} X" in df and f"{emitter_label} X" in df:
            axes = ["X", "Y", "Z"]
            emitter = df.loc[:, [f"{emitter_label} {axis}" for axis in axes]].values
            origin = df.loc[:, [f"{origin_label} {axis}" for axis in axes]].values
            
            ray = emitter - origin
            ray = ray / np.linalg.norm(ray, axis=1, keepdims=True)
            correct_ray = target - origin
            correct_ray = correct_ray / np.linalg.norm(correct_ray, axis=1, keepdims=True)

            dot = np.clip(np.sum(correct_ray * ray, axis=1), -1.0, 1.0)
            cosine_error = np.rad2deg(np.arccos(dot))
            df[f"{prefix}_COSINE_ERR"] = cosine_error

    # Bind to radio buttons + button
    def update_condition(self, accuracy, distraction, trial = 0, frame = None):
        self.condition = (accuracy, distraction)
        if not self.animated:
            self.trial_slider.valmax = len(self.trials[self.condition]) - 1
            self.trial_slider.valstep = range(len(self.trials[self.condition]))
            self.trial_slider.ax.set_xlim(self.trial_slider.valmin, self.trial_slider.valmax)
        self.override(None)
        self.update_trial(trial, frame)
        return None

    def update_trial(self, i, frame = None):
        self.playing = None
        self.override(None)
        self.trial_idx = i
        self.new_trial = True
        if "Pointing Gesture" in self.trials[self.condition][self.trial_idx] and "Pointing Hand" in self.trials[self.condition][self.trial_idx]["Pointing Gesture"]:
            self.hand = self.trials[self.condition][self.trial_idx]["Pointing Gesture"]["Pointing Hand"]
        else:
            self.hand = None

        trial_df = self.trials[self.condition][self.trial_idx]["FULL"]
        if "Tobii3-Set-L2-R2 - 1 X" in trial_df:
            tobii_markers = {
                "L1": trial_df.loc[:, [f"Tobii3-Set-L2-R2 - 1 {j}" for j in ["X", "Y", "Z"]]].values,
                "R1": trial_df.loc[:, [f"Tobii3-Set-L2-R2 - 4 {j}" for j in ["X", "Y", "Z"]]].values,
                "L2": trial_df.loc[:, [f"Tobii3-Set-L2-R2 - 2 {j}" for j in ["X", "Y", "Z"]]].values,
                "R2": trial_df.loc[:, [f"Tobii3-Set-L2-R2 - 5 {j}" for j in ["X", "Y", "Z"]]].values,
                "L3": trial_df.loc[:, [f"Tobii3-Set-L2-R2 - 3 {j}" for j in ["X", "Y", "Z"]]].values,
                "R3": trial_df.loc[:, [f"Tobii3-Set-L2-R2 - 6 {j}" for j in ["X", "Y", "Z"]]].values,
            }
            for l, r in [("L1", "R1"), ("L2", "R2"), ("L3", "R3")]:
                point = (tobii_markers[l] + tobii_markers[r]) / 2
                tobii_markers[f"{l}_{r}"] = point
            line_to_eye = tobii_markers["L3_R3"] - tobii_markers["L1_R1"]
            line_to_eye = line_to_eye / np.linalg.norm(line_to_eye)
            line_to_eye = line_to_eye*20
            cyclops = tobii_markers["L3_R3"] + line_to_eye
            trial_df[["CYCLOPS X", "CYCLOPS Y", "CYCLOPS Z"]] = cyclops

        target = self.trials[self.condition][self.trial_idx]["TARGET"].subtarget.to_array()
        for prefix, origin, emitter in [("LEFT_IFRC", "LH_LHandIn", "LH_LIndexTip"), ("RIGHT_IFRC", "RH_RHandIn", "RH_RIndexTip"), ("LEFT_EFRC", "CYCLOPS", "LH_LIndexTip"), ("RIGHT_EFRC", "CYCLOPS", "RH_RIndexTip")]:
            self.compute_errors(trial_df, target, prefix, origin, emitter)

        self.recording_idx = self.trials[self.condition][self.trial_idx]["FULL"].index.values[0] if self.trials[self.condition][self.trial_idx]["FULL"].shape[0] > 0 else 0
        if not self.animated and self.trials[self.condition][self.trial_idx]["FULL"].shape[0] > 0:
            self.frame_slider.valmax = self.trials[self.condition][self.trial_idx]["FULL"].shape[0] - 1
            self.frame_slider.valstep = range(self.trials[self.condition][self.trial_idx]["FULL"].shape[0])
            self.frame_slider.ax.set_xlim(self.frame_slider.valmin, self.frame_slider.valmax)
            self.trial_slider.eventson = False
            self.trial_slider.set_val(self.trial_idx)
            self.trial_slider.eventson = True
            self.fig.suptitle(f"Particpant: {self.participant} - Trial: {self.trial_idx}")
        
        self.update_frame(frame if frame is not None else self.trials[self.condition][self.trial_idx]["FULL"].shape[0] // 2, clear_old=True)
        self.new_trial = False
        return None

    def update_frame(self, i, side = None, clear_old = False):
        def flatten(xs):
            for x in xs:
                if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
                    if isinstance(x, dict):
                        yield from flatten(x.values())
                    else:
                        yield from flatten(x)
                else:
                    yield x

        self.frame = i
        if self.frame_slider is not None and self.frame_slider.val != self.frame:
            self.frame_slider.eventson = False
            self.frame_slider.set_val(self.frame)
            self.frame_slider.eventson = True
        update_graphs = self.new_trial
        if side is not None:
            self.side = side
            update_graphs = True
        
        if self.new_trial and not self.render_all:
            self.plot_target_positions()

        if self.trials[self.condition][self.trial_idx]["FULL"].shape[0] > 0:
            if not self.render_all:
                self.plot_error_graph(update_graphs)
                self.plot_angles()
            self.plot_motion_graph(update_graphs)
            self.plot_hand_markers("RH_R", "RIGHT", clear_old)
            self.plot_hand_markers("LH_L", "LEFT", clear_old)
            self.plot_eyes(clear_old)
            self.plot_body()
        artists = list(flatten(self.plots.values()))
        return artists 

    def add_controls(self, frame_count, trial_count, condition):
        """Create the frame/trial navigation widgets and the condition selectors.

        Also resets ``self.side`` to ``"CYCLOPS"``, clears
        ``self.ballistic_ranges`` and moves the main axes to the upper left of
        the figure. The frame slider and its prev/next buttons are created only
        when ``frame_count > 0``; otherwise ``self.frame_slider`` is ``None``.
        With a single trial, the radio buttons offer only the values from
        ``condition``.

        Args:
            frame_count: Number of frames in the initial trial.
            trial_count: Number of trials in the initial condition.
            condition: ``(accuracy, distraction)`` pair to pre-select.

        Returns:
            dict: Widgets keyed ``FRAMES``, ``PREV_FRAME``, ``NEXT_FRAME`` (when
            present), ``TRIALS``, ``PREV_TRIAL``, ``NEXT_TRIAL``, ``ACCURACY``
            and ``DISTRACTION``.
        """
        self.side = "CYCLOPS"
        self.ballistic_ranges = {}
        self.ax.set_position([0, 0.4, 0.5, 0.6])
        widgets = {}

        def go_to_previous_frame(_):
            if self.frame > 0:
                self.update_frame(self.frame - 1)

        def go_to_next_frame(_):
            last_frame = self.trials[self.condition][self.trial_idx]["FULL"].shape[0] - 1
            if self.frame < last_frame:
                self.update_frame(self.frame + 1)

        def go_to_previous_trial(_):
            if self.trial_idx > 0:
                self.update_trial(self.trial_idx - 1)

        def go_to_next_trial(_):
            if self.trial_idx < len(self.trials[self.condition]) - 1:
                self.update_trial(self.trial_idx + 1)

        if frame_count > 0:
            self.frame_slider = Slider(
                ax=self.fig.add_axes([0.65, 0.25, 0.3, 0.01]),
                label='Frame',
                valmin=0,
                valmax=frame_count - 1,
                valstep=range(frame_count),
                valinit=0,
            )
            self.previous_frame_button = Button(self.fig.add_axes([0.65, 0.18, 0.08, 0.05]), "prev_frame")
            self.previous_frame_button.on_clicked(go_to_previous_frame)
            self.next_frame_button = Button(self.fig.add_axes([0.87, 0.18, 0.08, 0.05]), "next_frame")
            self.next_frame_button.on_clicked(go_to_next_frame)
            widgets.update(
                FRAMES=self.frame_slider,
                PREV_FRAME=self.previous_frame_button,
                NEXT_FRAME=self.next_frame_button,
            )
        else:
            self.frame_slider = None

        self.trial_slider = Slider(
            ax=self.fig.add_axes([0.05, 0.15, 0.4, 0.01]),
            label='Trial',
            valmin=0,
            valmax=trial_count - 1,
            valstep=range(trial_count),
            valinit=0,
        )
        self.previous_trial_button = Button(self.fig.add_axes([0.05, 0.08, 0.1, 0.05]), "prev_trial")
        self.previous_trial_button.on_clicked(go_to_previous_trial)
        self.next_trial_button = Button(self.fig.add_axes([0.35, 0.08, 0.1, 0.05]), "next_trial")
        self.next_trial_button.on_clicked(go_to_next_trial)
        widgets.update(
            TRIALS=self.trial_slider,
            PREV_TRIAL=self.previous_trial_button,
            NEXT_TRIAL=self.next_trial_button,
        )

        several_trials = trial_count > 1

        accuracy_axes = self.fig.add_axes([0.05, 0.0, 0.2, 0.08])
        accuracy_labels = ["ACCURATE", "CASUAL"] if several_trials else [condition[0]]
        self.accuracy_radio = RadioButtons(
            ax=accuracy_axes,
            labels=accuracy_labels,
            active=accuracy_labels.index(condition[0]),
        )
        widgets["ACCURACY"] = self.accuracy_radio

        distraction_axes = self.fig.add_axes([0.25, 0.0, 0.2, 0.08])
        distraction_labels = ["FOCUSED", "DISTRACTED"] if several_trials else [condition[1]]
        self.distraction_radio = RadioButtons(
            ax=distraction_axes,
            labels=distraction_labels,
            active=distraction_labels.index(condition[1]),
        )
        widgets["DISTRACTION"] = self.distraction_radio

        return widgets

    def override(self, event):
        if self.override_callback is not None and self.overriden:
            print("Calling override callback...")
            self.override_callback(self.participant, self.condition, self.trial_idx, self.trials[self.condition][self.trial_idx][HVP_KEY])
        else:
            print("Skipping override as overriden not set")
        self.overriden = False

    def show(self, trials, participant, initial_condition, initial_trial, initial_frame = None, hand_dominance = None, override_callback = None, show_target_array = False, hand=None):
        self.participant = participant
        self.show_target_array = show_target_array
        self.trial_idx = initial_trial
        self.trials = trials
        self.animated = False
        self.override_callback = override_callback
        self.overriden = False
        self.hand_dominance = hand_dominance
        self.hand = hand
        
        controls = self.add_controls(trials[initial_condition][initial_trial]["FULL"].shape[0], len(trials[initial_condition]), initial_condition)
        self.controls = controls
        
        self.update_condition(*initial_condition, initial_trial, initial_frame)
        self.fig.canvas.mpl_connect("close_event", self.override)
        if self.frame_slider is not None:
            self.frame_slider.on_changed(lambda i: self.update_frame(i))
        if len(trials[initial_condition])>1:
            self.trial_slider.on_changed(lambda i: self.update_trial(i))
            self.accuracy_radio.on_clicked(lambda i: self.update_condition(i, self.condition[1]))
            self.distraction_radio.on_clicked(lambda i: self.update_condition(self.condition[0], i))
        return controls
