import sys
from os.path import dirname, realpath
# Directory containing this file. realpath(__file__) alone is the path of the file itself,
# which is not a usable sys.path entry (import lookup expects directories).
script_root = dirname(realpath(__file__))
# Avoid piling up duplicate entries if this module is imported/reloaded more than once.
if script_root not in sys.path:
    sys.path.insert(1, script_root)

import numpy as np
from dependencies.Markers import Orientation, Marker

# Targets
class Target():
    """A square pointing target of side ``length`` mm, described by a centre and an orientation.

    The target's pose can be set in three ways:

    * directly via ``__init__`` (``origin`` / ``orientation`` arrays),
    * from a JSON dict via ``init_from_json`` (optionally with an ``OVERRIDES`` entry),
    * from a three-marker marker set via ``init_position_from_marker_set``.

    A target may carry an ``override`` Target (built from the JSON ``OVERRIDES`` entry) whose
    plane is used instead of this one in ``intersection``. It may also carry a selected
    sub-target: a point offset by multiples of ``led_separation`` from the centre (see
    ``get_subtarget``; presumably one LED of a 3x3 grid).
    """

    pitch = 15
    yaw = -35
    length = 190 # mm, side length of the square target (see get_corner)
    padding = 2 # mm
    led_separation = 85 # mm, spacing between neighbouring sub-targets (see get_subtarget)
    marker_strut_len = 53 # mm
    marker_strut_width = 10 # mm
    marker_strut_base_width = 12.4 # mm
    marker_width = 12.5
    marker_separation = 59 - (marker_width) # marker_strut_len + ((marker_strut_base_width + marker_width) / 2) # mm
    # Offset from the marker-strut corner to the target centre, rotated into place in
    # init_position_from_marker_set.
    marker_strut_origin_offset = [1.68, -((marker_strut_base_width/2) + (length/2)), 25-marker_strut_base_width/2] # mm [x, y, z], double check order

    target_relabelling_dict = {
        # id : y-x
        "11": "(20)",
        "1": "(10)",
        "12": "(00)",
        "4": "(21)",
        "3": "(11)",
        "2": "(01)",
        "5": "(22)",
        "7": "(12)",
        "6": "(02)",
        "14": "(23)",
        "15": "(13)",
        "13": "(03)",
        "10": "(24)",
        "9": "(14)",
        "8": "(04)",
    }

    # Short name -> (row, column) grid position.
    target_ids = {
        "TLM": ( 0,  0),
        "MLM": ( 1,  0),
        "BLM": ( 2,  0),
        "TL":  ( 0,  1),
        "ML":  ( 1,  1),
        "BL":  ( 2,  1),
        "TC":  ( 0,  2),
        "MC":  ( 1,  2),
        "BC":  ( 2,  2),
        "TR":  ( 0,  3),
        "MR":  ( 1,  3),
        "BR":  ( 2,  3),
        "TRM": ( 0,  4),
        "MRM": ( 1,  4),
        "BRM": ( 2,  4),
    }

    # Index (as a string) -> (row name, column name); rows only go up to "2".
    target_position_relabelling_dict = {
        "0": ("TOP", "LEFTMOST"),
        "1": ("MIDDLE", "LEFT"),
        "2": ("BOTTOM", "CENTRE"),
        "3": (None, "RIGHT"),
        "4": (None, "RIGHTMOST")
    }

    @staticmethod
    def get_target_label(row, col):
        """Return ``(row_name, column_name, "row_name - column_name")`` for string indices.

        Raises:
            KeyError: if ``row`` or ``col`` is not a key of ``target_position_relabelling_dict``.
        """
        row_name = Target.target_position_relabelling_dict[row][0]
        column_name = Target.target_position_relabelling_dict[col][1]
        return (row_name, column_name, f"{row_name} - {column_name}")

    def __init__(self, row, column, sub_id = None, origin = None, orientation = None, led = None):
        """Create the target at grid position (``row``, ``column``).

        Args:
            row, column: grid position, used for labels and ``json_key``.
            sub_id: optional sub-target index, decoded as in ``get_subtarget``.
            origin: optional array accepted by ``Marker.from_array`` (presumably [x, y, z]).
            orientation: optional array accepted by ``Orientation.from_array``.
            led: optional array; when given together with ``sub_id`` it replaces the computed
                sub-target position.
        """
        self.row = row
        self.column = column
        self.origin = None if origin is None else Marker.from_array(origin)
        self.orientation = None if orientation is None else Orientation.from_array(orientation)
        self.override = None
        self.subid = sub_id

        # Geometry is unknown until both origin and orientation exist. Initialise the attributes
        # so that checks such as intersection()'s "normal is None" guard work instead of raising.
        self.top_left = self.top_right = self.bottom_right = self.bottom_left = None
        self.normal = None

        if sub_id is None:
            self.sub_col = -1
            self.sub_row = -1
            self.subtarget = None
        else:
            self.set_subtarget(sub_id)
            if led is not None:
                self.subtarget = Marker.from_array(led)

        if self.origin is not None and self.orientation is not None:
            self._update_geometry()

    def _update_geometry(self):
        """Recompute the four corners and the unit plane normal from ``origin``/``orientation``."""
        yaw, pitch, roll = self.orientation.yaw, self.orientation.pitch, self.orientation.roll

        def corner(up, right):
            return Target.get_corner(self.origin, Target.length, yaw, pitch, roll, up, right)

        self.top_left = corner(True, False)
        self.top_right = corner(True, True)
        self.bottom_right = corner(False, True)
        self.bottom_left = corner(False, False)

        # Normal of the plane spanned by the bottom edge and the left edge.
        bottom_left = self.bottom_left.to_array()
        cross_product = np.cross(self.bottom_right.to_array() - bottom_left, self.top_left.to_array() - bottom_left)
        self.normal = cross_product / np.linalg.norm(cross_product)

    def get_subtarget(self, id):
        """Return the world position of sub-target ``id``, or None if the pose is unknown.

        ``id`` is decoded row-major: column ``id % 3``, row ``id // 3``. The local offset
        ``(0, -(1 - col) * led_separation, (1 - row) * led_separation)`` is passed to
        ``Marker.rotate_about_point`` with this target's origin and orientation.
        """
        if self.origin is None or self.orientation is None:
            return None
        sub_col = id % 3
        sub_row = id // 3
        local = Marker(0, (1 - sub_col) * -Target.led_separation, (1 - sub_row) * Target.led_separation)
        return local.rotate_about_point(self.origin, self.orientation.yaw, self.orientation.pitch, self.orientation.roll)

    def set_subtarget(self, id):
        """Select sub-target ``id`` here and on the override target, if there is one."""
        # Must use the same row-major decoding as get_subtarget() so the stored indices
        # describe the position that is actually computed.
        self.sub_col = id % 3
        self.sub_row = id // 3
        self.subtarget = self.get_subtarget(id)
        if self.override is not None:
            self.override.set_subtarget(id)

    @staticmethod
    def get_corner(origin, length, yaw, pitch, roll, up, right):
        """Return one corner of a square of side ``length`` centred on ``origin``.

        In local coordinates "right" is -y and "up" is +z; the point is then passed to
        ``Marker.rotate_about_point`` with ``origin`` and the given angles.
        """
        return Marker(0, -length/2 if right else length/2, length/2 if up else -length/2).rotate_about_point(origin, yaw, pitch, roll)

    def init_from_json(self, json_dict):
        """Set the pose from a dict with keys X, Y, Z, YAW, PITCH, ROLL.

        An optional ``OVERRIDES`` dict creates ``self.override``. For every key of
        ``json_dict`` that is present in ``OVERRIDES``, the first element of the override
        value is used. Every other key is taken from ``json_dict``. The caller's dict, and
        the nested ``OVERRIDES`` dict, are not modified.
        """
        json_dict = dict(json_dict)  # shallow copy: popping OVERRIDES must not affect the caller
        overrides = dict(json_dict.pop("OVERRIDES", {}))  # copy: filled in below
        self.origin = Marker(json_dict["X"], json_dict["Y"], json_dict["Z"])
        self.orientation = Orientation(json_dict["YAW"], json_dict["PITCH"], json_dict["ROLL"])
        if overrides:
            for feature, value in json_dict.items():
                if feature in overrides:
                    overrides[feature] = overrides[feature][0]
                else:
                    overrides[feature] = value
            self.override = Target(self.row, self.column)
            self.override.init_from_json(overrides)
        else:
            self.override = None
        self._update_geometry()

    def get_surface(self):
        """Return (X, Y, Z) 2x2 arrays of corner coordinates (bottom row first), e.g. for surface plots."""
        corners = ((self.bottom_left, self.bottom_right), (self.top_left, self.top_right))
        return tuple(
            np.array([[getattr(point, axis) for point in edge] for edge in corners])
            for axis in ("x", "y", "z")
        )

    def get_border(self):
        """Return (X, Y, Z) arrays tracing the closed outline BL -> BR -> TR -> TL -> BL."""
        outline = (self.bottom_left, self.bottom_right, self.top_right, self.top_left, self.bottom_left)
        return tuple(np.array([getattr(point, axis) for point in outline]) for axis in ("x", "y", "z"))

    def intersection(self, ray_origin, ray_intersect, cluster_error_offset = False, epsilon=1e-6):
        """Intersect the ray ``ray_origin -> ray_intersect`` with this target's plane.

        The plane uses the override target's normal and points when an override exists.

        Returns:
            None if the target has no normal, the two points coincide, or either contains NaN.
            Otherwise a dict that may contain:

            * ``"LED"``: the hit relative to the current sub-target, if one is set;
            * ``"CLUSTER"``: the hit relative to the target centre, if ``cluster_error_offset``.

            Each entry holds POS_X/POS_Y/POS_Z, POS_ERROR (distance to the reference point) and
            ERROR_YAW/ERROR_PITCH in degrees. A hit is only reported when the plane lies beyond
            ``ray_intersect`` along the ray (parameter > 1) and the ray is not parallel to it
            (``|n . d| > epsilon``).
        """
        if self.normal is None:
            return None

        ray_origin = np.asarray(ray_origin, dtype=float)
        ray_intersect = np.asarray(ray_intersect, dtype=float)

        if np.array_equal(ray_origin, ray_intersect): # Points cannot be same
            return None

        if np.any(np.isnan(ray_origin)) or np.any(np.isnan(ray_intersect)): # No point calculating
            return None

        # Adapted from: https://stackoverflow.com/a/18543221
        def line_plane_intersection(p0, p1, plane_point, plane_normal):
            u = p1 - p0
            dot = np.dot(plane_normal, u)
            if not abs(dot) > epsilon:  # line (nearly) parallel to the plane, or NaN
                return None
            fac = -np.dot(plane_normal, p0 - plane_point) / dot
            if not fac > 1:  # only hits past p1, i.e. in the direction of the pointing gesture
                return None
            intersect = p0 + u * fac
            error = intersect - plane_point
            diff = (intersect - p0) - u
            diff = diff / np.linalg.norm(diff)
            # Clip guards against |component| creeping past 1 through rounding.
            err_yaw = np.rad2deg(np.arcsin(np.clip(diff[0], -1.0, 1.0)))
            err_pitch = np.rad2deg(np.arcsin(np.clip(diff[2], -1.0, 1.0)))
            return { "POS_X": intersect[0], "POS_Y": intersect[1], "POS_Z": intersect[2], "POS_ERROR": np.linalg.norm(error), "ERROR_YAW": err_yaw, "ERROR_PITCH": err_pitch }

        intersections = {}
        reference = self if self.override is None else self.override
        target_normal = reference.normal

        if self.subtarget is not None:
            hit = line_plane_intersection(ray_origin, ray_intersect, reference.subtarget.to_array(), target_normal)
            if hit is not None:
                intersections["LED"] = hit

        if cluster_error_offset:
            hit = line_plane_intersection(ray_origin, ray_intersect, reference.origin.to_array(), target_normal)
            if hit is not None:
                intersections["CLUSTER"] = hit

        return intersections

    @staticmethod
    def _orientation_from_points(p0, p1, origin):
        """Return (yaw, pitch, roll) in degrees from two points relative to ``origin``.

        The direction ``p1 - origin`` gives yaw and pitch. The direction ``p0 - origin`` fixes
        roll. Adapted from https://gamedev.stackexchange.com/a/172153. Prints a debug line.
        """
        p0 = p0 - origin
        p0_mag = np.sqrt(np.sum(np.square(p0)))
        p0 = p0 / p0_mag
        p1 = p1 - origin
        p1_mag = np.sqrt(np.sum(np.square(p1)))
        p1 = p1 / p1_mag
        yaw = np.arctan2(p1[1], p1[0])
        pitch = -np.arcsin(p1[2])
        roll = np.arcsin((p0[0]*np.sin(yaw)) + (p0[1]*(-np.cos(yaw))))
        if p0[2] < 0:
            # arcsin only covers [-90 deg, 90 deg]. When p0 points below the horizontal, mirror
            # roll into the other half-circle. (+/-pi) - roll keeps sin(roll) unchanged for both
            # signs; sign(roll) * (pi - roll) would not for negative roll.
            roll = np.copysign(np.pi, roll) - roll
        angles = (np.rad2deg(yaw), np.rad2deg(pitch), np.rad2deg(roll))
        print(f"p0: {p0}, p1: {p1}, Corner: {origin}, Angle: {angles}, marker_sep: {Target.marker_separation}, mag: [p0: {p0_mag}, p1: {p1_mag}, expected: {np.sqrt(np.square(Target.marker_separation)*2)}]")
        return angles

    # Our marker sets provide 3 points, from which we can derive the orientation and the position of the target.
    # To get the location, we first calculate the corner of the marker set (where the 3 marker struts extend from).
    # This is done by offsetting the centroid of the triangle defined by the 3 points along the triangle's normal.
    # That normal also gives us the orientation of the target.
    # With the marker set corner and orientation we can then, based on where the marker set is mounted on the
    # target, calculate the target centre.
    def init_position_from_marker_set(self, markers):
        """Set the pose from a dict of marker positions.

        ``markers`` must contain keys containing "_X", "_Y" and "_Z". The first matching key of
        each is used; an IndexError is raised if any is missing. Prints debug output.
        """
        def marker_for(suffix):
            # First key containing the suffix (IndexError if none).
            return Marker.from_array(markers[[key for key in markers.keys() if suffix in key][0]])

        self.marker_x = marker_for("_X")
        self.marker_y = marker_for("_Y")
        self.marker_z = marker_for("_Z")
        p0 = self.marker_y.to_array()
        p1 = self.marker_z.to_array()
        p2 = self.marker_x.to_array()

        normal_co = np.cross(p0 - p2, p1 - p2)
        normal_vec = normal_co / np.linalg.norm(normal_co)
        centroid = (p0 + p1 + p2) / 3

        # NOTE(review): sqrt(6)/3 * s is the height of a regular tetrahedron with edge s. For three
        # markers on mutually orthogonal struts of length L, the centroid->corner distance is
        # L/sqrt(3). Confirm which geometry marker_separation describes.
        corner_distance = (np.sqrt(6) * Target.marker_separation) / 3
        self.marker_strut_corner = Marker.from_array(centroid + (normal_vec*-corner_distance))

        yaw, pitch, roll = Target._orientation_from_points(p1, p2, self.marker_strut_corner.to_array())
        print(f"Target Rotation: {{ YAW: {yaw}, PITCH: {pitch}, ROLL: {roll} }}")

        print(f"Markers: {markers}, Marker ORIGIN: {self.marker_strut_corner}")

        self.origin = Marker.from_array(Target.marker_strut_origin_offset).rotate_about_point(self.marker_strut_corner, yaw, pitch, roll)
        self.orientation = Orientation(yaw, pitch, roll)
        self._update_geometry()

    def __str__(self):
        return f"Target({self.row},{self.column})[centre({self.origin}), orientation({self.orientation})]"

    def __repr__(self):
        return self.__str__()

    def json_key(self):
        """Return the key used for this target in JSON, e.g. "(02)" for row 0, column 2."""
        return f"({self.row}{self.column})"

    def json_data(self):
        """Serialise the pose using the keys read by ``init_from_json`` (overrides not included)."""
        return { "X": self.origin.x, "Y": self.origin.y, "Z": self.origin.z, "YAW": self.orientation.yaw, "PITCH": self.orientation.pitch, "ROLL": self.orientation.roll }