# -*- coding: utf-8 -*-

from itertools import combinations, product

import numpy as np
import tensorflow as tf


def make_class_progress(labels, classes=None, default=-1):
    """Compute the relative progress through each constant-label segment.

    Every label sequence is split into runs of identical consecutive values.
    Each element gets its position within its run mapped linearly onto
    [0, 1] (a run of length one yields 0). Elements whose label is not one of
    ``classes`` receive ``default`` instead.

    Args:
        labels: list of 1-D label arrays.
        classes: a class id or a collection of ids to track. When None, every
            non-negative label appearing in ``labels`` is tracked.
        default: value written for elements outside ``classes``.

    Returns:
        A list of float arrays, one per input sequence (an empty list when
        ``labels`` is empty).
    """
    if not labels:
        return []
    if classes is None:
        all_labels = np.unique(np.concatenate(labels, axis=0))
        classes = all_labels[all_labels >= 0]
    if np.isscalar(classes):
        classes = [classes]
    progresses = []
    for seq in labels:
        is_tracked = np.logical_or.reduce(
            np.expand_dims(seq, 1) == np.expand_dims(classes, 0), axis=1)
        # Index where each run of equal labels begins
        boundaries = np.where(np.diff(seq) != 0)[0] + 1
        starts = np.concatenate([[0], boundaries], axis=0)
        run_lengths = np.diff(np.concatenate([starts, [len(seq)]], axis=0))
        progress = np.concatenate(
            [np.linspace(0, 1, run) for run in run_lengths])
        progress[~is_tracked] = default
        progresses.append(progress)
    return progresses


def trim_class_margins(labels, start_margin_ratio, end_margin_ratio=None,
                       classes=None, replacement=-1):
    """Replace the leading and trailing part of every class segment.

    Each label sequence is split into runs of identical consecutive values.
    For runs whose label is in ``classes``, elements whose relative position
    inside the run (0 at the first element, 1 at the last) is below
    ``start_margin_ratio`` or above ``1 - end_margin_ratio`` are set to
    ``replacement``. Labels outside ``classes`` are left untouched.

    Args:
        labels: list of 1-D label sequences (arrays or array-likes).
        start_margin_ratio: fraction of each run to trim at its start, in
            [0, 1].
        end_margin_ratio: fraction to trim at the end, in [0, 1]; defaults to
            ``start_margin_ratio``.
        classes: a class id or collection of ids to trim. When None, every
            non-negative label appearing in ``labels`` is used.
        replacement: value written into the trimmed positions.

    Returns:
        A list of new label arrays (inputs are not modified); an empty list
        when ``labels`` is empty.

    Raises:
        ValueError: if a margin ratio lies outside [0, 1].
    """
    if not labels:
        return []
    if end_margin_ratio is None:
        end_margin_ratio = start_margin_ratio
    # Validate once instead of on every loop iteration
    if start_margin_ratio < 0 or start_margin_ratio > 1:
        raise ValueError("The start margin ratio must be between 0 and 1.")
    if end_margin_ratio < 0 or end_margin_ratio > 1:
        raise ValueError("The end margin ratio must be between 0 and 1.")
    if classes is None:
        classes = np.unique(np.concatenate(labels, axis=0))
        classes = classes[classes >= 0]
    if np.isscalar(classes):
        classes = [classes]
    labels_new = []
    for label in labels:
        # Accept plain lists too; boolean-mask assignment needs an ndarray
        label = np.asarray(label)
        in_classes = np.logical_or.reduce(
            np.expand_dims(label, 1) == np.expand_dims(classes, 0), axis=1)
        run_starts = np.concatenate(
            [[0], np.where(np.diff(label) != 0)[0] + 1], axis=0)
        run_lengths = np.diff(np.concatenate(
            [run_starts, [len(label)]], axis=0))
        relative_pos = np.concatenate(
            [np.linspace(0, 1, n) for n in run_lengths])
        valids = ((relative_pos >= start_margin_ratio) &
                  (relative_pos <= 1. - end_margin_ratio))
        label_new = label.copy()
        # Honour the `replacement` argument (previously hard-coded to -1)
        label_new[in_classes & ~valids] = replacement
        labels_new.append(label_new)
    return labels_new


def _is_tf(*args):
    """Return True if any argument is a ``tf.Tensor``.

    Tuples and lists are searched recursively, so nested containers holding
    a tensor also count.
    """
    for candidate in args:
        if isinstance(candidate, tf.Tensor):
            return True
        if isinstance(candidate, (tuple, list)) and _is_tf(*candidate):
            return True
    return False


def _quat_product(q1, q2, name="QuaternionProduct"):
    """Hamilton product ``q1 * q2`` of quaternions in (w, x, y, z) layout.

    The four components live on the last axis; leading axes broadcast.
    Non-tensor inputs are converted with ``np.asarray``. The result is a
    tf.Tensor if either input is a tf.Tensor, otherwise a NumPy array.
    """
    use_tf = _is_tf(q1, q2)
    if not _is_tf(q1):
        q1 = np.asarray(q1)
    if not _is_tf(q2):
        q2 = np.asarray(q2)
    if use_tf:
        # Static (TF1 Dimension) shape check; NumPy inputs are not checked
        assert q1.shape[-1].value == 4 and q2.shape[-1].value == 4, \
            "Quaternions must have 4 elements"
    with tf.variable_scope(name):
        aw, ax, ay, az = (q1[..., k] for k in range(4))
        bw, bx, by, bz = (q2[..., k] for k in range(4))
        w = aw * bw - ax * bx - ay * by - az * bz
        x = aw * bx + ax * bw + ay * bz - az * by
        y = aw * by - ax * bz + ay * bw + az * bx
        z = aw * bz + ax * by - ay * bx + az * bw
        stack = tf.stack if use_tf else np.stack
        return stack([w, x, y, z], axis=-1)


def _quat_conj(q, name="QuaternionConjugate"):
    """Conjugate of quaternion(s) ``q`` in (w, x, y, z) layout.

    Keeps the scalar part w and negates the vector part (x, y, z).
    """
    use_tf = _is_tf(q)
    if not use_tf:
        q = np.asarray(q)
    concat = tf.concat if use_tf else np.concatenate
    with tf.variable_scope(name):
        return concat([q[..., :1], -q[..., 1:]], axis=-1)


def _vector_quat_rotate(v, q, name="Rotate"):
    """Rotate 3-vector(s) ``v`` by quaternion(s) ``q``.

    Computes ``q * (0, v) * conj(q)`` and returns its vector part. This is a
    pure rotation only when ``q`` has unit norm (not checked here).
    """
    use_tf = _is_tf(v, q)
    with tf.variable_scope(name):
        if not use_tf:
            # Prepend a zero scalar part on the last axis
            pad_width = [(0, 0)] * (np.ndim(v) - 1) + [(1, 0)]
            pure = np.pad(v, pad_width, "constant")
        else:
            zero_shape = tf.concat([tf.shape(v)[:-1], [1]], axis=0)
            pure = tf.concat([tf.zeros(zero_shape, dtype=v.dtype), v],
                             axis=-1)
        rotated = _quat_product(_quat_product(q, pure), _quat_conj(q))
        return rotated[..., 1:]


def _normalize(x, name="Normalize"):
    """Scale ``x`` to unit Euclidean norm along its last axis.

    Zero-norm entries are not guarded against.
    """
    norm_fn = tf.norm if _is_tf(x) else np.linalg.norm
    with tf.variable_scope(name):
        return x / norm_fn(x, axis=-1, keepdims=True)


# Not necessary since tf.atan2 was introduced
def _atan2(s, c, eps=1e-23):
    """Two-argument arctangent of ``s`` (sine side) and ``c`` (cosine side).

    Legacy helper. With TensorFlow inputs it builds atan2 out of ``tf.atan``:
    both arguments are first pushed ``eps`` away from zero in the direction
    of their sign (zero counts as positive), so that ``s / c`` is finite.
    The result is then shifted by +/-pi where ``c < 0``. NumPy inputs go
    straight to ``np.arctan2``.
    """
    with_tf = _is_tf(s, c)
    if with_tf:
        # Treat sign(0) as +1 so exact zeros are moved away from zero too.
        # tf.not_equal is used instead of `!=`: with TF1 tensor semantics
        # `!=` compares object identity and yields a Python bool, which made
        # the replacement a no-op.
        s_sign = tf.sign(s)
        s_sign = tf.where(tf.not_equal(s_sign, 0), s_sign,
                          tf.ones_like(s_sign))
        c_sign = tf.sign(c)
        c_sign = tf.where(tf.not_equal(c_sign, 0), c_sign,
                          tf.ones_like(c_sign))
        s += eps * s_sign
        c += eps * c_sign
        theta = tf.atan(s / c)
        # Left half-plane: move the atan result into the correct quadrant
        theta = tf.where(c >= 0, theta, theta + tf.sign(s) * np.pi)
        return theta
    else:
        return np.arctan2(s, c)


def _yaw_from_quat(q):
    """Yaw angle (radians, about the z axis) of quaternion(s) ``q``.

    ``q`` uses (w, x, y, z) layout on its last axis. The result is
    ``atan2(2 (wz + xy), 1 - 2 (y^2 + z^2))``.
    """
    arctan2 = tf.atan2 if _is_tf(q) else np.arctan2
    qw, qx, qy, qz = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    return arctan2(2. * (qw * qz + qx * qy), 1. - 2. * (qy * qy + qz * qz))


def _pyr_from_quat(q):
    qw = q[..., 0]
    qx = q[..., 1]
    qy = q[..., 2]
    qz = q[..., 3]
    sing_test = qz * qx - qw * qy
    yaw_y = 2.0 * (qw * qz + qx * qy)
    yaw_x = 1.0 - 2.0 * (qy * qy + qz * qz)
    SINGULARITY_THRESHOLD = 0.499999995
    pitch = np.arcsin(2.0 * sing_test)
    yaw = np.arctan2(yaw_y, yaw_x)
    roll = np.arctan2(-2.0 * (qw * qx + qy * qz),
                      1.0 - 2.0 * (qx * qx + qy * qy))
    m = sing_test < -SINGULARITY_THRESHOLD
    if np.any(m):
        pitch[m] = -90.0
        yaw[m] = np.arctan2(yaw_y[m], yaw_x[m])
        roll[m] = -yaw[m] - (2.0 * np.arctan2(qx[m], qw[m]))
    m = sing_test > SINGULARITY_THRESHOLD
    if np.any(m):
        pitch[m] = 90.0
        yaw[m] = np.arctan2(yaw_y[m], yaw_x[m])
        roll[m] = yaw[m] - (2.0 * np.arctan2(qx[m], qw[m]))
    res = np.stack([pitch, yaw, roll], axis=-1) * 180.0 / np.pi
    res %= 360.0
    res[res > 180.0] = 360.0 - res[res > 180.0]
    return res


def _quat_from_yaw(yaw):
    """Quaternion(s) (w, x, y, z) for a rotation of ``yaw`` radians about z.

    Returns ``(cos(yaw/2), 0, 0, sin(yaw/2))`` stacked on a new last axis,
    as a tf.Tensor for tensor input and as a NumPy array otherwise.
    """
    # tf and np expose identically named cos/sin/zeros_like/stack
    lib = tf if _is_tf(yaw) else np
    half = yaw / 2
    return lib.stack([lib.cos(half), lib.zeros_like(yaw),
                      lib.zeros_like(yaw), lib.sin(half)], axis=-1)


def _lerp(start, end, alpha, name="Lerp"):
    """Linear interpolation ``start * (1 - alpha) + end * alpha``.

    Works on anything supporting arithmetic (NumPy arrays, tf tensors,
    scalars); ops are created inside a TF variable scope named ``name``.
    """
    with tf.variable_scope(name):
        weight_start = 1 - alpha
        return start * weight_start + end * alpha


def _quat_slerp(q1, q2, alpha, normalize=False, name="QuatSlerp", eps=1e-6):
    """Spherical linear interpolation between quaternions ``q1`` and ``q2``.

    Quaternions use (w, x, y, z) layout on the last axis; leading axes
    broadcast. Where ``dot(q1, q2) < 0``, ``q2`` is negated so that the
    shorter arc is taken. Where the two are (nearly) parallel, ``sin(theta)``
    vanishes; the result then falls back to plain linear interpolation
    instead of NaN.

    Args:
        q1, q2: quaternions (NumPy array-likes or tf tensors).
        alpha: interpolation factor (0 -> q1, 1 -> q2). A scalar, or an array
            broadcastable against the quaternions with a trailing axis of
            size 1 (the same convention as ``_lerp``/``_quat_lerp``).
        normalize: normalize both inputs first.
        name: TF variable scope name.
        eps: threshold on ``sin(theta)`` below which linear interpolation is
            used.

    Returns:
        Interpolated quaternions with the broadcast shape of the inputs.
    """
    with_tf = _is_tf(q1, q2, alpha)
    if with_tf:
        assert q1.shape[-1].value == 4 and q2.shape[-1].value == 4, \
            "Quaternions must have 4 elements"
    else:
        q1 = np.asarray(q1)
        q2 = np.asarray(q2)
    with tf.variable_scope(name):
        if normalize:
            q1 = _normalize(q1)
            q2 = _normalize(q2)
        if with_tf:
            # keepdims so theta broadcasts against the (..., 4) quaternions
            dot = tf.reduce_sum(q1 * q2, axis=-1, keepdims=True, name="Dot")
            sign = tf.where(dot >= 0, tf.ones_like(dot), -tf.ones_like(dot))
            q2 = q2 * sign
            # Round-off can push the dot product slightly above 1
            dot = tf.clip_by_value(dot * sign, 0., 1.)
            theta = tf.acos(dot, name="Theta")
            sin_theta = tf.sin(theta)
            small = sin_theta < eps
            safe_sin = tf.where(small, tf.ones_like(sin_theta), sin_theta)
            slerp = (tf.sin((1 - alpha) * theta) * q1 +
                     tf.sin(alpha * theta) * q2) / safe_sin
            lerp = q1 * (1 - alpha) + q2 * alpha
            # Arithmetic blend avoids tf.where shape restrictions in TF1
            small_f = tf.cast(small, slerp.dtype)
            result = small_f * lerp + (1 - small_f) * slerp
        else:
            dot = np.sum(q1 * q2, axis=-1, keepdims=True)
            flip = dot < 0
            q2 = np.where(flip, -q2, q2)
            dot = np.clip(np.where(flip, -dot, dot), 0., 1.)
            theta = np.arccos(dot)
            sin_theta = np.sin(theta)
            small = sin_theta < eps
            safe_sin = np.where(small, np.ones_like(sin_theta), sin_theta)
            slerp = (np.sin((1 - alpha) * theta) * q1 +
                     np.sin(alpha * theta) * q2) / safe_sin
            lerp = q1 * (1 - alpha) + q2 * alpha
            result = np.where(small, lerp, slerp)
        return result


def _quat_lerp(q1, q2, alpha, normalize=False, name="QuatLerp"):
    """Normalized linear interpolation (nlerp) between quaternions.

    Quaternions use (w, x, y, z) layout on the last axis. Where
    ``dot(q1, q2) < 0``, ``q2`` is negated so that the interpolation follows
    the shorter arc. The linearly interpolated result is then renormalized.

    Args:
        q1, q2: quaternions (NumPy array-likes or tf tensors).
        alpha: interpolation factor (0 -> q1, 1 -> q2), broadcastable against
            the (..., 4) quaternions.
        normalize: normalize both inputs first.
        name: TF variable scope name.

    Returns:
        Unit quaternions with the broadcast shape of the inputs.
    """
    with_tf = _is_tf(q1, q2, alpha)
    if with_tf:
        assert q1.shape[-1].value == 4 and q2.shape[-1].value == 4, \
            "Quaternions must have 4 elements"
    else:
        q1 = np.asarray(q1)
        q2 = np.asarray(q2)
    with tf.variable_scope(name):
        if normalize:
            q1 = _normalize(q1)
            q2 = _normalize(q2)
        if with_tf:
            dot = tf.reduce_sum(q1 * q2, axis=-1, keepdims=True, name="Dot")
            # Multiply by a broadcast sign instead of a TF1 tf.where whose
            # condition (...,) only matches (..., 4) for rank <= 2 inputs.
            q2 = q2 * tf.where(dot >= 0, tf.ones_like(dot),
                               -tf.ones_like(dot))
        else:
            dot = np.sum(q1 * q2, axis=-1, keepdims=True)
            q2 = np.where(dot < 0, -q2, q2)
        return _normalize(_lerp(q1, q2, alpha))


def _quat_canonical(q, normalize=False):
    """Flip quaternion(s) so that the scalar part w is non-negative.

    ``q`` and ``-q`` encode the same rotation, so this picks one
    representative. ``q`` uses (w, x, y, z) layout on the last axis. The
    NumPy path works on a copy; the TF path needs a statically known rank.
    """
    if normalize:
        q = _normalize(q)
    keep = q[..., 0] >= 0
    if not _is_tf(q):
        flipped = q.copy()
        flipped[~keep] *= -1
        return flipped
    # Expand the mask to all four components for tf.where
    keep = tf.tile(keep[..., tf.newaxis], [1] * len(keep.shape) + [4])
    return tf.where(keep, q, -q)


def _quat_to_xvector(q, normalize=False):
    """Image of the x axis (1, 0, 0) under the rotation by quaternion(s) ``q``.

    Expanded closed form of ``q * (0, 1, 0, 0) * conj(q)``:
    (w^2 + x^2 - y^2 - z^2, 2(wz + xy), 2(-wy + xz)). This equals a rotated
    unit vector only for unit ``q``; pass ``normalize=True`` otherwise.
    """
    if normalize:
        q = _normalize(q)
    w, x, y, z = (q[..., i] for i in range(4))
    components = [w * w + x * x - y * y - z * z,
                  2 * (w * z + x * y),
                  2 * (-w * y + x * z)]
    stack = tf.stack if _is_tf(q) else np.stack
    return stack(components, axis=-1)


def _quat_from_xvector(v, normalize=False):
    """Shortest-arc quaternion rotating the x axis (1, 0, 0) onto ``v``.

    Acts as an inverse of ``_quat_to_xvector``: the returned quaternion maps
    (1, 0, 0) onto the direction of ``v``. Twist about ``v`` is not
    determined by ``v``; the minimal-angle rotation is chosen. Uses the
    half-way construction q ~ (|v| + v_x, (1, 0, 0) x v) =
    (|v| + v_x, 0, -v_z, v_y), which does not require a unit ``v``. When
    ``v`` is (numerically) anti-parallel to x, or zero, the construction
    degenerates. A 180 degree rotation about z, (0, 0, 0, 1), is returned in
    that case.

    Args:
        v: 3-vectors on the last axis (NumPy array-like or tf tensor).
        normalize: normalize ``v`` first. The construction is scale-invariant,
            so this only affects round-off.

    Returns:
        Unit quaternions (w, x, y, z) with shape ``v.shape[:-1] + (4,)``.
    """
    # This was previously an unimplemented stub that always returned None.
    if normalize:
        v = _normalize(v)
    with_tf = _is_tf(v)
    if not with_tf:
        v = np.asarray(v)
    vx = v[..., 0]
    vy = v[..., 1]
    vz = v[..., 2]
    if with_tf:
        v_norm = tf.norm(v, axis=-1, keepdims=True)
        w = v_norm[..., 0] + vx
        q = tf.stack([w, tf.zeros_like(w), -vz, vy], axis=-1)
        q_norm = tf.norm(q, axis=-1, keepdims=True)
        degenerate = tf.cast(q_norm <= 1e-12 * v_norm, q.dtype)
        fallback = tf.constant([0., 0., 0., 1.], dtype=q.dtype)
        # Adding `degenerate` keeps the division finite where q_norm == 0
        q = q / (q_norm + degenerate)
        return (1. - degenerate) * q + degenerate * fallback
    v_norm = np.linalg.norm(v, axis=-1, keepdims=True)
    w = v_norm[..., 0] + vx
    q = np.stack([w, np.zeros_like(w), -vz, vy], axis=-1)
    q_norm = np.linalg.norm(q, axis=-1, keepdims=True)
    degenerate = q_norm <= 1e-12 * v_norm
    fallback = np.array([0., 0., 0., 1.], dtype=q.dtype)
    safe_norm = np.where(degenerate, np.ones_like(q_norm), q_norm)
    return np.where(degenerate, fallback, q / safe_norm)


def _make_drag_vector(drag, t, dtype=None):
    res = (np.array(drag) ** (np.arange(t) + 1))
    if dtype:
        res = res.astype(dtype=dtype)
    return res


def _make_drag_matrix(drag, t, dtype=None):
    ii, jj = np.meshgrid(np.arange(t) + 1, np.arange(t) + 1, indexing="ij")
    exp = jj - ii
    m = exp >= 0
    res = (1 - drag) * (np.array(drag) ** (exp * m)) * m
    if dtype:
        res = res.astype(dtype=dtype)
    return res


def make_input_head_relative(input_, step_size=0, drag_pos=0, drag_yaw=0,
                             drop_head=False, state_variables_collection=None,
                             state_updates_collection=None):
    """Express head/hand poses relative to an exponentially dragged head frame.

    The first 21 features of ``input_`` are read as three consecutive blocks
    of 7 values (position xyz, rotation quaternion wxyz): head, hand 1 and
    hand 2. The head position and head yaw are smoothed over time with an
    exponential moving average ("drag"). Each block is then re-expressed
    relative to the smoothed position/yaw via ``do_relative``. Features
    after index 21 are appended unchanged.

    Args:
        input_: NumPy array-like or tf tensor, presumably shaped
            (batch, time, features). The first two axes are read as batch
            size and number of steps; in TF mode the batch size must be
            statically known.
        step_size: number of steps the drag weights are built for. They are
            truncated to the actual step count, so this must be at least the
            time dimension whenever a drag is used.
        drag_pos, drag_yaw: drag factors, clipped to [0, 1]. 0 disables
            smoothing; values closer to 1 keep more of the previous state.
        drop_head: if True, the head block is left out of the output.
        state_variables_collection, state_updates_collection: when both are
            given, the smoothing starts from non-trainable TF variables (added
            to the first collection). The smoothed value at the last step is
            added to the second collection, presumably so the caller can
            assign it back between calls.

    Returns:
        Tuple ``(y, dragged_pos_batch, dragged_yaw_batch)``: the relative
        features, the smoothed head positions and the smoothed head yaw.
    """
    with_state_collections = (state_variables_collection and
                              state_updates_collection)
    with_tf = _is_tf(input_) or with_state_collections
    drag_pos = np.clip(drag_pos or 0, 0, 1)
    drag_yaw = np.clip(drag_yaw or 0, 0, 1)
    y = input_
    if _is_tf(y):
        tf_dtype = y.dtype
        np_dtype = y.dtype.as_numpy_dtype
    else:
        y = np.asarray(y)
        tf_dtype = tf.as_dtype(y.dtype)
        np_dtype = y.dtype
    with tf.variable_scope("HeadDrag"):
        if with_tf:
            # Static batch size: it sizes the state variables below
            batch_size = y.get_shape()[0].value
            batch_step_size = tf.shape(y)[1]
        else:
            batch_size, batch_step_size = np.shape(y)[:2]

        # Head position (features 0-2) and head yaw from its quaternion (3-6)
        pos_batch = y[..., 0:3]
        yaw_batch = _yaw_from_quat(y[..., 3:7])

        if drag_pos > 0:
            # Initial smoothed position: a persistent variable in stateful TF
            # mode, zeros otherwise
            if with_tf:
                if with_state_collections:
                    dragged_pos = tf.get_variable(
                        "DraggedPosition", [batch_size, 3], dtype=tf_dtype,
                        initializer=tf.zeros_initializer(), trainable=False)
                else:
                    dragged_pos = tf.zeros([batch_size, 3], dtype=tf_dtype)
            else:
                dragged_pos = np.zeros([batch_size, 3], dtype=np_dtype)
            # Decayed initial state: (N x 3) x (T) -> (N x T x 3), with
            # element [n, t, c] = dragged_pos[n, c] * drag_pos ** (t + 1).
            # Broadcasting keeps the time and coordinate axes aligned for any
            # T (a flat matmul + reshape interleaves them when T > 1).
            drag_pos_vector = _make_drag_vector(drag_pos, step_size,
                                                dtype=np_dtype)
            if with_tf:
                drag_pos_vector = tf.slice(drag_pos_vector, [0],
                                           [batch_step_size])
                pos_base = (tf.expand_dims(dragged_pos, 1) *
                            tf.reshape(drag_pos_vector, [1, -1, 1]))
            else:
                drag_pos_vector = drag_pos_vector[:batch_step_size]
                pos_base = (dragged_pos[:, np.newaxis, :] *
                            drag_pos_vector[np.newaxis, :, np.newaxis])
            # Smoothed observations: (N x T x 3) x (T x T) -> (N x T x 3),
            # computed as a (3N x T) @ (T x T) matmul on flattened coordinates
            drag_pos_mat = _make_drag_matrix(drag_pos, step_size,
                                             dtype=np_dtype)
            if with_tf:
                drag_pos_mat = tf.slice(drag_pos_mat, [0, 0],
                                        [batch_step_size, batch_step_size])
                pos_res_batch = tf.reshape(tf.transpose(pos_batch, [2, 0, 1]),
                                           [-1, batch_step_size])
                drag_pos_res_matmul = tf.matmul(pos_res_batch, drag_pos_mat)
                drag_pos_matmul = tf.transpose(
                    tf.reshape(drag_pos_res_matmul,
                               [3, batch_size, batch_step_size]),
                    [1, 2, 0])
            else:
                drag_pos_mat = drag_pos_mat[:batch_step_size, :batch_step_size]
                pos_res_batch = np.reshape(np.transpose(pos_batch, [2, 0, 1]),
                                           [-1, batch_step_size])
                drag_pos_res_matmul = np.matmul(pos_res_batch, drag_pos_mat)
                drag_pos_matmul = np.transpose(
                    np.reshape(drag_pos_res_matmul,
                               [3, batch_size, batch_step_size]),
                    [1, 2, 0])
            dragged_pos_batch = pos_base + drag_pos_matmul
            # Smoothed position at the last step
            dragged_pos_update = dragged_pos_batch[:, -1, :]
            if with_state_collections:
                tf.add_to_collection(state_variables_collection, dragged_pos)
                tf.add_to_collection(state_updates_collection,
                                     dragged_pos_update)
        else:
            dragged_pos_batch = pos_batch
        if with_tf:
            dragged_pos_batch = tf.identity(dragged_pos_batch,
                                            name="DraggedPositionBatchFinal")

        if drag_yaw > 0:
            if with_tf:
                if with_state_collections:
                    dragged_yaw = tf.get_variable(
                        "DraggedYaw", [batch_size], dtype=tf_dtype,
                        initializer=tf.zeros_initializer(), trainable=False)
                else:
                    dragged_yaw = tf.zeros([batch_size], dtype=tf_dtype)
            else:
                dragged_yaw = np.zeros([batch_size], dtype=np_dtype)
            # Decayed initial yaw: (N x 1) @ (1 x T) -> (N x T)
            drag_yaw_vector = _make_drag_vector(drag_yaw, step_size,
                                                dtype=np_dtype)
            if with_tf:
                drag_yaw_vector = tf.slice(drag_yaw_vector, [0],
                                           [batch_step_size])
                yaw_base = tf.matmul(tf.expand_dims(dragged_yaw, 1),
                                     tf.expand_dims(drag_yaw_vector, 0))
            else:
                drag_yaw_vector = drag_yaw_vector[:batch_step_size]
                yaw_base = np.matmul(np.expand_dims(dragged_yaw, 1),
                                     np.expand_dims(drag_yaw_vector, 0))
            # Smoothed observed yaw: (N x T) @ (T x T) -> (N x T). Angles are
            # averaged linearly, without handling the +/-pi wrap-around.
            drag_yaw_mat = _make_drag_matrix(drag_yaw, step_size,
                                             dtype=np_dtype)
            if with_tf:
                drag_yaw_mat = tf.slice(drag_yaw_mat, [0, 0],
                                        [batch_step_size, batch_step_size])
                dragged_yaw_batch = yaw_base \
                    + tf.matmul(yaw_batch, drag_yaw_mat)
            else:
                drag_yaw_mat = drag_yaw_mat[:batch_step_size, :batch_step_size]
                dragged_yaw_batch = yaw_base \
                    + np.matmul(yaw_batch, drag_yaw_mat)
            dragged_yaw_update = dragged_yaw_batch[:, -1]
            if with_state_collections:
                tf.add_to_collection(state_variables_collection, dragged_yaw)
                tf.add_to_collection(state_updates_collection,
                                     dragged_yaw_update)
        else:
            dragged_yaw_batch = yaw_batch
        if with_tf:
            dragged_yaw_batch = tf.identity(dragged_yaw_batch,
                                            name="DraggedYawBatchFinal")

    with tf.variable_scope("HeadRelative"):
        # Make relative
        elems = []
        i_elems = [0, 1, 2]  # Head, hand 1, hand 2
        if drop_head:
            i_elems = i_elems[1:]
        for i in i_elems:
            idx = i * 7
            pos = y[..., idx:idx + 3]
            rot = y[..., idx + 3:idx + 7]
            pos_rel, rot_rel = do_relative(pos, rot, dragged_pos_batch,
                                           dragged_yaw_batch)
            elems += [pos_rel, rot_rel]
        if with_tf:
            y = tf.concat(elems + [y[..., 21:]], axis=-1,
                          name="ProcessedInput")
        else:
            y = np.concatenate(elems + [y[..., 21:]], axis=-1)
    return y, dragged_pos_batch, dragged_yaw_batch


def do_relative(position, rotation, base_position, base_yaw):
    """Express a pose in the frame defined by ``base_position``/``base_yaw``.

    The offset from the base position is rotated by ``-base_yaw`` about z.
    The rotation is left-multiplied by the inverse yaw quaternion and made
    canonical (w >= 0). ``undo_relative`` performs the reverse mapping.

    Returns:
        Tuple ``(relative_position, relative_rotation)``.
    """
    inverse_yaw = _quat_from_yaw(-base_yaw)
    offset = position - base_position
    relative_position = _vector_quat_rotate(offset, inverse_yaw)
    relative_rotation = _quat_canonical(_quat_product(inverse_yaw, rotation))
    return relative_position, relative_rotation


def undo_relative(position, rotation, base_position, base_yaw):
    """Map a pose given relative to ``base_position``/``base_yaw`` back out.

    The relative position is rotated by ``base_yaw`` about z and shifted by
    the base position. The rotation is left-multiplied by the yaw quaternion
    and made canonical (w >= 0). This reverses ``do_relative``.

    Returns:
        Tuple ``(position, rotation)`` in the original frame.
    """
    yaw_rotation = _quat_from_yaw(base_yaw)
    absolute_position = (_vector_quat_rotate(position, yaw_rotation) +
                         base_position)
    absolute_rotation = _quat_canonical(_quat_product(yaw_rotation, rotation))
    return absolute_position, absolute_rotation


def quantize_position(position, split_x, split_y, split_z):
    """One-hot bin each coordinate of ``position`` by per-axis split values.

    For an axis with split values ``s_1, ..., s_k``, k + 1 indicator features
    are produced: coord < s_1, then "not in an earlier bin and coord < s_i",
    and finally "none of the above" (i.e. coord >= s_k when the splits are
    increasing). Axes whose split is empty/None are skipped. Indicators of
    all axes are stacked on a new last axis and cast to ``position``'s dtype.
    """
    use_tf = _is_tf(position, split_x, split_y, split_z)
    if use_tf:
        leading_shape = tf.shape(position)[:-1]
    else:
        leading_shape = np.shape(position)[:-1]
    bins = []
    for axis, splits in enumerate((split_x, split_y, split_z)):
        if not splits:
            continue
        coord = position[..., axis]
        # Elements not yet assigned to a bin on this axis
        if use_tf:
            remaining = tf.fill(leading_shape, True)
        else:
            remaining = np.full(leading_shape, True, dtype=bool)
        for threshold in splits:
            below = coord < threshold
            bins.append(remaining & below)
            remaining = remaining & (~below)
        bins.append(remaining)
    if use_tf:
        return tf.cast(tf.stack(bins, axis=-1), position.dtype)
    return np.stack(bins, axis=-1).astype(position.dtype)


def make_orientation_class_vectors(max_components=1, dtype=None):
    """Unit direction vectors used as orientation class prototypes.

    With ``max_components >= 1``: the 6 signed coordinate axes. ``>= 2``
    adds the 12 normalized two-axis diagonals; ``>= 3`` adds the 8 normalized
    three-axis diagonals (26 vectors in total). Within each group, sign -1
    comes before +1. An empty (0, 3) array is returned when
    ``max_components < 1``. The result is cast to ``dtype`` if given
    (float64 otherwise).
    """
    if max_components < 1:
        return np.zeros((0, 3), dtype=dtype)
    signs = [-1., 1.]
    rows = [[s if k == axis else 0. for k in range(3)]
            for axis in range(3) for s in signs]
    if max_components >= 2:
        for axes in combinations(range(3), 2):
            for sign_pair in product(signs, repeat=2):
                row = [0., 0., 0.]
                for axis, s in zip(axes, sign_pair):
                    row[axis] = s
                rows.append(row)
    if max_components >= 3:
        rows.extend(list(triple) for triple in product(signs, repeat=3))
    vectors = np.array(rows, dtype=float)
    vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)
    if dtype is not None:
        vectors = vectors.astype(dtype)
    return vectors


def quantize_rotation(rotation, base_orientation_vector,
                      orientation_class_vectors):
    """One-hot orientation class of a rotated reference vector.

    ``base_orientation_vector`` is rotated by the quaternion(s)
    ``rotation``. It is then compared by dot product with every row of
    ``orientation_class_vectors`` (classes along the second-to-last axis),
    and the best-matching class is one-hot encoded. The one-hot result has
    ``rotation``'s dtype and a trailing axis of size num_classes.
    """
    use_tf = _is_tf(rotation, base_orientation_vector,
                    orientation_class_vectors)
    if use_tf:
        num_classes = tf.shape(orientation_class_vectors)[-2]
    else:
        num_classes = orientation_class_vectors.shape[-2]
    rotated = _vector_quat_rotate(base_orientation_vector, rotation)
    new_axis = tf.newaxis if use_tf else np.newaxis
    # (..., 1, C, 3) against (..., 1, 3): one score per class
    prototypes = orientation_class_vectors[..., new_axis, :, :]
    rotated = rotated[..., new_axis, :]
    if use_tf:
        scores = tf.reduce_sum(prototypes * rotated, axis=-1)
        best = tf.argmax(scores, axis=-1)
        return tf.one_hot(best, num_classes, dtype=rotation.dtype)
    scores = np.sum(prototypes * rotated, axis=-1)
    best = np.argmax(scores, axis=-1)
    one_hot = best[..., np.newaxis] == np.arange(num_classes)[np.newaxis, :]
    return one_hot.astype(rotation.dtype)
