# -*- coding: utf-8 -*-

import numpy as np
import quaternion
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d as p3
from matplotlib import animation


# Common capture data manipulation functions

def make_rotation_euler(pitch, yaw, roll):
    """Build a 4x4 homogeneous rotation matrix from Euler angles.

    The matrix is composed as ``Rz(yaw') . Ry(pitch') . Rx(roll')`` with
    right-handed elementary rotations, where the primed angles are the
    inputs shifted by pi (see below). Pitch rotates about y, yaw about z
    and roll about x.

    Parameters
    ----------
    pitch, yaw, roll : float
        Rotation angles in radians.

    Returns
    -------
    numpy.ndarray
        4x4 homogeneous matrix whose upper-left 3x3 block is a proper
        rotation (orthogonal, determinant +1) and whose translation is zero.
    """
    # The pi offsets cancel out algebraically: with proper rotation matrices,
    # Rz(yaw - pi) . Ry(pitch - pi) . Rx(pi - roll) == Rz(yaw) . Ry(-pitch) . Rx(-roll).
    # So they only flip the sense of pitch and roll, and zero angles still give
    # the identity. NOTE(review): presumably this matches the handedness of the
    # capture data; confirm against the data source.
    pitch, yaw, roll = -np.pi + pitch, -np.pi + yaw, np.pi - roll
    pitch_s, pitch_c = np.sin(pitch), np.cos(pitch)
    yaw_s, yaw_c = np.sin(yaw), np.cos(yaw)
    roll_s, roll_c = np.sin(roll), np.cos(roll)
    # Right-handed rotation about x. The original matrix lacked the minus
    # sign, which made it a non-orthogonal (sometimes singular) matrix.
    rx = np.array([[1, 0, 0, 0],
                   [0, roll_c, -roll_s, 0],
                   [0, roll_s, roll_c, 0],
                   [0, 0, 0, 1]])
    # Right-handed rotation about y.
    ry = np.array([[pitch_c, 0, pitch_s, 0],
                   [0, 1, 0, 0],
                   [-pitch_s, 0, pitch_c, 0],
                   [0, 0, 0, 1]])
    # Right-handed rotation about z (same missing-sign fix as rx).
    rz = np.array([[yaw_c, -yaw_s, 0, 0],
                   [yaw_s, yaw_c, 0, 0],
                   [0, 0, 1, 0],
                   [0, 0, 0, 1]])
    return rz.dot(ry).dot(rx)


def quaternion_from_pitch_yaw_roll(pitch, yaw, roll):
    """Return the quaternion for the rotation given by pitch, yaw and roll.

    The angles (radians) are turned into a 4x4 homogeneous matrix by
    ``make_rotation_euler`` and then converted with
    ``quaternion.from_rotation_matrix``.
    """
    homogeneous_rotation = make_rotation_euler(pitch, yaw, roll)
    return quaternion.from_rotation_matrix(homogeneous_rotation)


def make_rotation_matrix(q):
    """Convert quaternion(s) to 4x4 homogeneous rotation matrices.

    Parameters
    ----------
    q : quaternion or array of quaternions
        A single quaternion (``np.isscalar(q)`` is true) or a 1-D array of
        T quaternions.

    Returns
    -------
    numpy.ndarray
        Shape ``(4, 4)`` for a scalar ``q``, otherwise ``(T, 4, 4)``. The
        upper-left 3x3 block is ``quaternion.as_rotation_matrix(q)``, the
        translation is zero and the bottom-right entry is 1.
    """
    rotations = quaternion.as_rotation_matrix(np.atleast_1d(q))
    count = rotations.shape[0]
    # Preallocate the homogeneous matrices and copy the 3x3 rotations in.
    homogeneous = np.zeros((count, 4, 4))
    homogeneous[:, :3, :3] = rotations
    homogeneous[:, 3, 3] = 1
    return homogeneous[0] if np.isscalar(q) else homogeneous


def rotate_coordinates(coordinates, q):
    """Rotate sets of 3D points by quaternion(s).

    Parameters
    ----------
    coordinates : numpy.ndarray
        Points as columns, either ``(3, N)`` or a stack ``(M, 3, N)``.
    q : quaternion or array of quaternions
        Rotation(s), converted by ``make_rotation_matrix``. Matrices and
        point sets are broadcast against each other along the leading axis.

    Returns
    -------
    numpy.ndarray
        ``(3, N)`` when ``q`` is a scalar and ``coordinates`` is 2-D,
        otherwise a stack of rotated point sets ``(T, 3, N)``.

    Raises
    ------
    ValueError
        If the second-to-last axis of ``coordinates`` is not of size 3.
    """
    matrices = make_rotation_matrix(q)
    # Promote a single 3xN set to a stack of one.
    points = coordinates if coordinates.ndim >= 3 else np.array([coordinates])
    n_sets, n_rows, n_points = points.shape
    if n_rows != 3:
        raise ValueError('coordinates must be a 3xN or Mx3XN array.')
    # Append a row of ones to get homogeneous coordinates.
    with_ones = np.concatenate((points, np.ones((n_sets, 1, n_points))), 1)
    result = np.einsum('...ij,...jk->...ik', matrices, with_ones)
    single_input = np.isscalar(q) and coordinates.ndim < 3
    return result[0, :3, :] if single_input else result[:, :3, :]


def prerotate_quaternions(quaternions, q):
    """Left-multiply quaternions stored as float components by ``q``.

    Parameters
    ----------
    quaternions : numpy.ndarray
        A single quaternion as 4 floats, or an ``(N, 4)`` array of them
        (layout as understood by ``quaternion.as_quat_array``).
    q : quaternion
        Multiplied on the left: each result is ``q * p``.

    Returns
    -------
    numpy.ndarray
        Float components with the same shape as ``quaternions``.

    Raises
    ------
    ValueError
        If the input is not a 4-vector or an Nx4 array.
    """
    was_single = quaternions.ndim < 2
    batch = np.array([quaternions]) if was_single else quaternions
    _, width = batch.shape
    if width != 4:
        raise ValueError('quaternions must be a 4 or Nx4 array.')
    product = q * quaternion.as_quat_array(batch)
    components = quaternion.as_float_array(product)
    return components[0] if was_single else components


def postrotate_quaternions(quaternions, q):
    """Right-multiply quaternions stored as float components by ``q``.

    Parameters
    ----------
    quaternions : numpy.ndarray
        A single quaternion as 4 floats, or an ``(N, 4)`` array of them
        (layout as understood by ``quaternion.as_quat_array``).
    q : quaternion
        Multiplied on the right: each result is ``p * q``.

    Returns
    -------
    numpy.ndarray
        Float components with the same shape as ``quaternions``.

    Raises
    ------
    ValueError
        If the input is not a 4-vector or an Nx4 array.
    """
    was_single = quaternions.ndim < 2
    batch = np.array([quaternions]) if was_single else quaternions
    _, width = batch.shape
    if width != 4:
        raise ValueError('quaternions must be a 4 or Nx4 array.')
    product = quaternion.as_quat_array(batch) * q
    components = quaternion.as_float_array(product)
    return components[0] if was_single else components


def make_quaternion_from_axis_angle(axis, radians):
    """Build quaternion(s) ``(cos(a/2), sin(a/2) * axis)`` from axis/angle.

    Note that ``axis`` is used as given (it is not normalised), so a unit
    axis is needed to obtain a unit quaternion.

    Parameters
    ----------
    axis : array_like
        A 3-vector or an ``(N, 3)`` array of axes.
    radians : float or array_like
        Rotation angle(s); broadcast against the axes.

    Returns
    -------
    quaternion or numpy.ndarray of quaternions
        A single quaternion when ``radians`` is a scalar and ``axis`` is a
        single vector, otherwise an array of quaternions.

    Raises
    ------
    ValueError
        If the axes do not have 3 components.
    """
    axes = np.atleast_2d(axis)
    angles = np.atleast_1d(radians)
    # Column vector of angles, broadcast to the number of axes.
    angles = angles[:, np.newaxis] * np.ones((axes.shape[0], 1))
    if axes.shape[1] != 3:
        raise ValueError('axis must be a 3 vector or a Nx3 matrix.')
    half_angles = angles / 2
    scalar_part = np.cos(half_angles)
    vector_part = np.sin(half_angles) * axes
    result = quaternion.as_quat_array(
        np.concatenate((scalar_part, vector_part), 1))
    single = np.isscalar(radians) and np.ndim(axis) < 2
    return result[0] if single else result


def interpolate_slerp(data):
    """Fill interior gaps in a quaternion table by spherical interpolation.

    Rows that contain any NaN and lie strictly between the first and the
    last complete row are replaced by SLERP between the nearest complete
    rows before and after. The index values are the interpolation
    parameter. Leading and trailing incomplete rows are left untouched.

    Parameters
    ----------
    data : pandas.DataFrame
        Exactly four columns of quaternion components, in the layout
        expected by ``quaternion.as_quat_array`` (scalar part first).

    Returns
    -------
    pandas.DataFrame
        A copy of ``data`` with the interior gaps filled. The input is not
        modified. If there is no complete row, the copy is returned as is.

    Raises
    ------
    ValueError
        If ``data`` does not have exactly four columns.
    """
    if data.shape[1] != 4:
        raise ValueError('Need exactly 4 values for SLERP.')
    vals = data.values
    empty = np.any(np.isnan(vals), axis=1)
    data = data.copy()
    valid_loc = np.flatnonzero(~empty)
    if valid_loc.size == 0:
        # No complete row to interpolate from (previously this failed in
        # valid_loc.min() with an unrelated-looking reduction error).
        return data
    # Only gaps with a complete row on both sides can be interpolated.
    empty_loc = np.flatnonzero(empty)
    empty_loc = empty_loc[(empty_loc > valid_loc[0]) &
                          (empty_loc < valid_loc[-1])]
    if empty_loc.size == 0:
        return data
    valid_index = data.index[valid_loc].values
    empty_index = data.index[empty_loc].values
    # For each gap row, the complete rows immediately after and before it.
    interp_loc_end = np.searchsorted(valid_loc, empty_loc)
    interp_loc_start = interp_loc_end - 1
    valid_vals = vals[valid_loc]
    start_vals = valid_vals[interp_loc_start]
    end_vals = valid_vals[interp_loc_end]  # fancy indexing -> own copy
    # q and -q are the same rotation. Flip the end quaternion into the start's
    # hemisphere so that SLERP follows the shorter arc instead of spinning
    # the long way round.
    flip = np.einsum('ij,ij->i', start_vals, end_vals) < 0
    end_vals[flip] *= -1
    interpolated = quaternion.slerp(quaternion.as_quat_array(start_vals),
                                    quaternion.as_quat_array(end_vals),
                                    valid_index[interp_loc_start],
                                    valid_index[interp_loc_end],
                                    empty_index)
    data.iloc[empty_loc] = quaternion.as_float_array(interpolated)
    return data


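# Not used and not correctly implemented; kept for reference only.
# TODO: check the dot product of the quaternions being interpolated (and flip
# the sign of one of them when it is negative) before interpolating; see
# http://stackoverflow.com/a/17243783/1782792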
# def interpolate_quaternion_lerp(data):
#     raise Exception('Quaternion LERP is not implemented.')
#     if data.shape[1] != 4:
#         raise ValueError('Need exactly 4 values for quaternion LERP.')
#     data = data.interpolate('linear')
#     data /= np.linalg.norm((data), 1)
#     return data


def reindex_data(data, new_index, slerp=False):
    """Resample ``data`` onto ``new_index`` by interpolation.

    The union of the old and new index values is built first so that the
    original samples take part in the interpolation. The gaps are then
    filled either with pandas' index-aware linear interpolation or, when
    ``slerp`` is true, with ``interpolate_slerp``. Finally the result is
    restricted to ``new_index``.

    Parameters
    ----------
    data : pandas.DataFrame
        Data to resample.
    new_index : array_like
        Target index values.
    slerp : bool, optional
        Use quaternion SLERP instead of linear interpolation.

    Returns
    -------
    pandas.DataFrame
        ``data`` evaluated at ``new_index``.
    """
    union_index = np.unique(np.concatenate((data.index, new_index)))
    expanded = data.reindex(union_index)
    if not slerp:
        filled = expanded.interpolate(method='index')
    else:
        filled = interpolate_slerp(expanded)
    return filled.reindex(new_index)


def change_speed(data, factor):
    """Return a copy of ``data`` played back at ``factor`` times the speed.

    The index (presumably time) is divided by ``factor``. A negative factor
    plays the data backwards: the rows are reversed and the index is
    shifted so that it starts at zero.

    Parameters
    ----------
    data : pandas.DataFrame or pandas.Series
        Data with a numeric index.
    factor : float
        Speed multiplier; must be non-zero.

    Returns
    -------
    Same type as ``data``. The input is not modified.

    Raises
    ------
    ValueError
        If ``factor`` is zero.
    """
    if factor == 0:
        raise ValueError('Factor cannot be zero.')
    # Reverse positionally rather than by label. Reindexing by label fails
    # when the index contains duplicate timestamps.
    result = data.iloc[::-1].copy() if factor < 0 else data.copy()
    result.index = result.index / factor
    if factor < 0 and len(result) > 0:
        # The reversed, negated index is ascending; make it start at zero.
        result.index = result.index - result.index[0]
    return result


def animate_capture(head=None, hand1=None, hand2=None, time=None, labels=None,
                    fps=None, elevation=None, azimuth=None, fig=None,
                    size_inches=None):
    """Create a 3D matplotlib animation previewing a motion capture.

    Each tracked element (``head``, ``hand1``, ``hand2``) is either None or a
    ``(pos, rot)`` pair:
    - ``pos`` holds one 3D position per sample, shape ``(N, 3)``.
    - ``rot`` holds one quaternion per sample as 4 float components, shape
      ``(N, 4)``.

    Every element is drawn as an oriented vision cone with a one-second
    trail. When several elements are given, a grey frame line joins them,
    with vertical segments down to the ``z_min`` plane.

    The capture is resampled to ``fps`` frames per second by taking, for
    each frame time, the first sample at or after it. ``time`` is assumed to
    be sorted in ascending order.

    Parameters
    ----------
    head, hand1, hand2 : tuple of (array_like, array_like), optional
        ``(pos, rot)`` samples for each element. At least one is required,
        and all provided elements must have the same number of samples.
    time : array_like, optional
        Timestamp of each sample in seconds. Defaults to 60 samples/second.
    labels : sequence, optional
        Text shown for each sample in the upper right corner.
    fps : float, optional
        Frame rate of the animation (default 25).
    elevation, azimuth : float, optional
        Initial camera angles passed to ``view_init`` (defaults 30 and 15).
    fig : matplotlib.figure.Figure, optional
        Figure to draw into; a new figure is created when omitted.
    size_inches : tuple, optional
        ``(width, height)`` passed to ``Figure.set_size_inches``.

    Returns
    -------
    matplotlib.animation.FuncAnimation

    Raises
    ------
    ValueError
        If no element is given, if the elements have different sizes, or if
        ``time``/``labels`` do not match the number of samples.
    """
    # Only clear the notebook output when running under IPython.
    try:
        __IPYTHON__
        from IPython.display import clear_output
    except (NameError, ImportError):
        def clear_output(wait=False):
            pass

    if fps is None:
        fps = 25
    if elevation is None:
        elevation = 30
    if azimuth is None:
        azimuth = 15
    elems = [head, hand1, hand2]
    num_elem = sum(elem is not None for elem in elems)
    if num_elem <= 0:
        raise ValueError('At least one of head, hand 1 or hand 2 '
                         'must be provided.')
    num_data = None
    for elem in elems:
        if num_data is None:
            if elem is not None:
                num_data = len(elem[0])
        elif elem is not None and len(elem[0]) != num_data:
            raise ValueError('Provided values must have the same size.')
    if time is not None:
        if len(time) != num_data:
            raise ValueError('Invalid time value.')
        time = np.asarray(time)
    else:
        time = np.arange(num_data) * 1. / 60  # 60 fps by default
    if labels is not None:
        if len(labels) != num_data:
            raise ValueError('Invalid labels value.')
    else:
        # np.full guarantees blank strings. np.empty does not initialise
        # memory, and the np.str alias no longer exists in recent NumPy.
        labels = np.full(num_data, '', dtype=str)

    # To ms
    interval_ms = 1000. / fps
    t_min = time.min()
    t_max = time.max()
    # Keep at least one frame so single-sample captures still render.
    n_frames = max(int(np.ceil((t_max - t_min) * fps)), 1)
    trail_time = 1.  # Could make it a parameter I guess...
    n_trail = int(np.round(trail_time * fps))

    # Resample to the animation frame rate. The same sample indices apply to
    # every element, so compute them once.
    frame_times = np.linspace(t_min, t_max, n_frames)
    idx = np.searchsorted(time, frame_times)
    new_index = time[idx]
    new_labels = np.asarray(labels)[idx].astype(str)

    # Orientation figure coordinates
    vision_cone = np.array([[0, 20, 20, 20],
                            [0, 2, -2, 0],
                            [0, 0, 0, 10]])
    vision_cone_tris = [[0, 2, 1],
                        [0, 3, 1],
                        [0, 3, 2],
                        [1, 2, 3]]

    # Compute the cone coordinates of each element at every frame.
    elems_cones = []
    for i_elem, elem in enumerate(elems):
        if elem is None:
            elems_cones.append(None)
            continue
        pos, rot = elem
        pos = np.asarray(pos)[idx]
        rot = np.asarray(rot)[idx]
        elems[i_elem] = (pos, rot)

        # Orientation figures
        quat = quaternion.as_quat_array(rot)
        cone_coords = rotate_coordinates(vision_cone, quat)
        cone_coords += pos[:, :, np.newaxis]
        elems_cones.append(cone_coords)

    # Attaching 3D axis to the figure
    if fig is None:
        fig = plt.figure()
    if size_inches:
        fig.set_size_inches(*size_inches)
    # Recent matplotlib versions no longer let Axes3D(fig) attach itself to
    # the figure (which left it blank), so add it explicitly. The axes cover
    # the whole figure, like the old Axes3D default rect.
    ax = fig.add_axes([0, 0, 1, 1], projection='3d')

    plots = {
        'head': None,
        'hand1': None,
        'hand2': None,
        'frame': None,
        'counter': ax.text2D(-0.085, .085, '', ha='left', va='top'),
        'label': ax.text2D(.085, .085, '', ha='right', va='top')
    }
    # Per element: [current cone artist, cone colour, trail line]
    if head is not None:
        plots['head'] = [None, '#C23B22',
                         ax.plot([], [], [], color='#FF6961', lw=2)[0]]
    if hand1 is not None:
        plots['hand1'] = [None, '#779ECB',
                          ax.plot([], [], [], color='#AEC6CF', lw=2)[0]]
    if hand2 is not None:
        plots['hand2'] = [None, '#779ECB',
                          ax.plot([], [], [], color='#AEC6CF', lw=2)[0]]
    frame_idx = []  # Order of elements in frame (0: head, 1: hand1, 2: hand2)
    if num_elem > 1:
        plots['frame'] = ax.plot([], [], [], color='#AAAAAA', lw=1)[0]
        if num_elem >= 3:
            frame_idx = [1, 0, 2]
        else:
            frame_idx = [i for i, elem in enumerate(elems) if elem is not None]

    # Setting the axes properties
    x_min, x_max = -150, 150
    y_min, y_max = -150, 150
    z_min, z_max = 0, 200
    ax.view_init(elevation, azimuth)
    ax.set_xlim3d([x_min, x_max])
    ax.set_xlabel('X')
    ax.set_ylim3d([y_min, y_max])
    ax.set_ylabel('Y')
    ax.invert_yaxis()
    ax.set_zlim3d([z_min, z_max])
    ax.set_zlabel('Z')
    ax.set_title('Capture preview')

    def animate(i):
        i_start = max(i - n_trail, 0)
        plotted = []
        # Plot element cones
        for name, data, cone in zip(('head', 'hand1', 'hand2'),
                                    elems, elems_cones):
            if data is None or cone is None:
                continue
            pos, _ = data
            cone_plot, cone_color, trail_plot = plots[name]
            # NOTE: there is no .set_data() for 3 dim data...
            if cone_plot is not None:
                cone_plot.remove()
            cone_plot = ax.plot_trisurf(*cone[i, :, :],
                                        triangles=vision_cone_tris,
                                        color=cone_color)
            plots[name][0] = cone_plot
            trail_plot.set_data(pos[i_start:i + 1, 0:2].T)
            trail_plot.set_3d_properties(pos[i_start:i + 1, 2])
            plotted.extend([cone_plot, trail_plot])
        # Plot frame: vertical drop from the first element, then for each
        # further element a segment to it plus a drop to the floor and back.
        if plots['frame'] and frame_idx:
            pos, _ = elems[frame_idx[0]]
            xy = [pos[i, :2]] * 2
            z = [z_min, pos[i, 2]]
            for i_elem in frame_idx[1:]:
                pos, _ = elems[i_elem]
                xy.extend([pos[i, :2]] * 3)
                z.extend([pos[i, 2], z_min, pos[i, 2]])
            plots['frame'].set_data(np.array(xy).T)
            plots['frame'].set_3d_properties(z)
            plotted.append(plots['frame'])
        # Frame count
        plots['counter'].set_text('Time:\n{:.2f}s'.format(new_index[i]))
        plotted.append(plots['counter'])
        # Frame label
        plots['label'].set_text('{}'.format(new_labels[i]))
        plotted.append(plots['label'])
        return plotted

    # Blitting stays disabled: the cones are recreated every frame with
    # plot_trisurf rather than updated in place.
    anim = animation.FuncAnimation(fig, animate, frames=n_frames,
                                   interval=interval_ms, blit=False,
                                   repeat=False)
    clear_output()
    return anim
