# -*- coding: utf-8 -*-

from itertools import combinations

import numpy as np
import tensorflow as tf

from TimeNet import TimeNet, TimeNetConfig
from util import freeze_session
from rice_helpers import make_input_head_relative, \
    make_orientation_class_vectors, quantize_position, quantize_rotation


def make_runtime_graph(time_net):
    """Freeze a single-step inference version of a trained ``TimeNet``.

    A new network is built from a ``TimeNetConfig`` that inherits the
    structural/output settings of ``time_net.config`` but uses
    ``step_size=1``, ``batch_size=1`` and zero for every training-only term
    (losses, weight decay, dropout). The trained weights are copied into it
    and its session is passed to ``freeze_session``. The runtime session is
    always closed, even if freezing fails.

    Args:
        time_net: trained ``TimeNet`` whose config, weights, data dimension,
            class names, dtype and input filter are reused.

    Returns:
        Whatever ``freeze_session`` returns for the runtime session
        (presumably a frozen graph definition -- TODO confirm).
    """
    trained_config = time_net.config
    inherited_fields = (
        'convolution_stack', 'dilated_stack', 'dense_stack',
        'local_recurrent_stack', 'recurrent_stack', 'softmax',
        'default_class', 'output_smoothing', 'progress_quantization',
        'confidence_quantization',
    )
    inherited = {field: getattr(trained_config, field)
                 for field in inherited_fields}
    runtime_config = TimeNetConfig(
        step_size=1,
        batch_size=1,
        # Training-only terms are switched off for the runtime network
        progress_loss=0,
        confidence_loss=0,
        weight_decay=0,
        consistence_loss=0,
        dropout=0,
        **inherited
    )
    runtime_net = TimeNet(time_net.data_dim, time_net.class_names,
                          runtime_config, dtype=time_net.dtype,
                          stateful=False, fixed_step=True, mask_value=None,
                          use_dropout=False,
                          input_filter=time_net.input_filter)
    try:
        runtime_net.weights = time_net.weights
        input_aliases, output_aliases = runtime_net.aliases
        # State and local variables are handed to freeze_session by name
        kept = (runtime_net.get_collection(TimeNet.STATE_VARIABLES)
                + runtime_net.get_collection(tf.GraphKeys.LOCAL_VARIABLES))
        kept_names = [var.op.name for var in kept]
        # Output aliases are listed first, followed by the input aliases
        node_names = [alias.op.name
                      for group in (output_aliases, input_aliases)
                      for alias in group]
        return freeze_session(runtime_net.session,
                              keep_var_names=kept_names,
                              output_names=node_names)
    finally:
        runtime_net.close_session()


def make_class_progress(labels, classes=None, default=-1):
    """Compute per-frame progress through each run of equal labels.

    Every sequence is split into runs of identical consecutive labels. Within
    a run, progress goes linearly from 0 (first frame) to 1 (last frame); a
    run of length 1 gets progress 0. Frames whose label is not in ``classes``
    get ``default`` instead.

    Args:
        labels: list of 1-D label arrays.
        classes: classes to compute progress for (scalar or sequence). When
            ``None``, all non-negative labels found in ``labels`` are used.
        default: value assigned to frames of other classes.

    Returns:
        list of float arrays, one per input sequence, of the same lengths.
    """
    if not labels:
        return []
    if classes is None:
        found = np.unique(np.concatenate(labels, axis=0))
        classes = found[found >= 0]
    if np.isscalar(classes):
        classes = [classes]
    progresses = []
    for seq in labels:
        selected = np.isin(seq, classes)
        # Index of the first frame of every run
        run_starts = np.concatenate([[0], np.flatnonzero(np.diff(seq)) + 1])
        run_lengths = np.diff(np.append(run_starts, len(seq)))
        progress = np.concatenate(
            [np.linspace(0, 1, run_len) for run_len in run_lengths])
        progress[~selected] = default
        progresses.append(progress)
    return progresses


def make_class_confidence(labels, margin_ratio, background_class=-1):
    """Build per-frame confidence targets that fade in/out at run edges.

    Each sequence is split into runs of identical consecutive labels. For a
    run of length ``L``, the first ``round(L * margin_in)`` frames ramp
    linearly from 0 to 1 and the last ``round(L * margin_out)`` frames ramp
    from 1 to 0 (fades are capped at ``L``; where they overlap the minimum is
    kept). All other frames have confidence 1.

    Runs of ``background_class`` do not use their own margins: a background
    run fades in over as many frames as the preceding run fades out, and
    fades out over as many frames as the following run fades in (0 at the
    start/end of the sequence).

    Args:
        labels: list of 1-D label arrays.
        margin_ratio: fade length as a fraction of the run length; either a
            scalar (both ends) or a pair ``(margin_in, margin_out)``. Values
            are clipped to [0, 1].
        background_class: label value treated as background.

    Returns:
        list of float arrays, one per input sequence, of the same lengths.
        Empty sequences yield empty arrays.
    """
    if not labels:
        return []
    margin_ratio = np.clip(margin_ratio, 0, 1)
    # np.ndim also recognises 0-d arrays, which np.isscalar does not
    if np.ndim(margin_ratio) == 0:
        margin_in = margin_out = margin_ratio
    else:
        margin_in, margin_out = margin_ratio
    if margin_in <= 0 and margin_out <= 0:
        return [np.ones(len(label), dtype=float) for label in labels]
    labels_conf = []
    for label in labels:
        label = np.asarray(label)
        label_conf = np.ones(len(label), dtype=float)
        if len(label) == 0:
            # No runs to fade; label[run_start] below would raise IndexError
            labels_conf.append(label_conf)
            continue
        run_start = np.concatenate(
            [[0], np.flatnonzero(np.diff(label)) + 1], axis=0)
        run_class = label[run_start]
        run_length = np.diff(np.append(run_start, len(label)))
        fade_in = np.round(run_length * margin_in).astype(int)
        fade_out = np.round(run_length * margin_out).astype(int)
        # Background runs borrow their fades from the neighbouring runs.
        # Two adjacent runs never share a label, so no neighbour read here
        # has already been overwritten.
        n_runs = len(run_class)
        for i in np.flatnonzero(run_class == background_class):
            fade_in[i] = fade_out[i - 1] if i > 0 else 0
            fade_out[i] = fade_in[i + 1] if i < n_runs - 1 else 0
        # Fades cannot be longer than the run they belong to
        fade_in = np.minimum(fade_in, run_length)
        fade_out = np.minimum(fade_out, run_length)
        for start, length, n_in, n_out in zip(run_start, run_length,
                                              fade_in, fade_out):
            run = label_conf[start:start + length]  # view into label_conf
            if n_in > 0:
                run[:n_in] = np.minimum(run[:n_in], np.linspace(0, 1, n_in))
            if n_out > 0:
                run[-n_out:] = np.minimum(run[-n_out:],
                                          np.linspace(1, 0, n_out))
        labels_conf.append(label_conf)
    return labels_conf


def estimate_label_confidence(labels, hits, num_bins=20):
    """Estimate per-frame confidence from per-class hit rates over progress.

    Label sequences are split into runs of identical consecutive labels (a
    run never crosses a sequence boundary). Each frame's relative position
    in its run (0 at the first frame, 1 at the last) is assigned to one of
    ``num_bins`` equal-width bins, and for every class the mean of ``hits``
    in each bin is computed; bins without samples get 0. Each class gets its
    own curve, including negative (e.g. background) labels. Every frame then
    receives its class curve linearly interpolated at its relative position,
    with bin values anchored at the bin centres.

    Args:
        labels: list of 1-D integer label arrays.
        hits: list of 1-D arrays with the same lengths as ``labels`` holding
            per-frame values to average (presumably 0/1 hits -- confirm
            against callers).
        num_bins: number of progress bins (default 20).

    Returns:
        list of float arrays, one per input sequence, of the same lengths.

    Raises:
        ValueError: if ``num_bins`` is smaller than 1.
    """
    if num_bins < 1:
        raise ValueError("num_bins must be at least 1.")
    if not labels:
        return []
    labels_all = np.concatenate(labels, axis=0)
    hits_all = np.concatenate(hits, axis=0)
    total = len(labels_all)
    durs_all = np.array([len(label) for label in labels])
    label_limits = np.concatenate(
        [[0], np.cumsum(durs_all)[:-1]], axis=0).astype(int)
    confs_all = np.zeros(total)
    if total == 0:
        return np.split(confs_all, label_limits[1:])
    labels_diff = np.concatenate([[1], np.diff(labels_all)], axis=0)
    # Enforce a run boundary at every sequence start; trailing empty
    # sequences start at `total`, which is not a valid frame index.
    labels_diff[label_limits[label_limits < total]] = 1
    idx, = np.where(labels_diff != 0)
    durations = np.diff(np.concatenate([idx, [total]], axis=0))
    classes = labels_all[idx]
    # Dense row per distinct class: negative labels get their own row
    # instead of wrapping around to the last class.
    class_ids, class_rows = np.unique(classes, return_inverse=True)
    n_class_labels = np.zeros((len(class_ids), num_bins), dtype=int)
    class_time_weights = np.zeros((len(class_ids), num_bins), dtype=float)
    t = np.linspace(0, 1, num_bins + 1)[:-1]  # left edge of every bin
    for i, dur, row in zip(idx, durations, class_rows):
        ti = np.linspace(0, 1, dur)
        # Bin k covers [t[k], t[k + 1]); the final point 1.0 goes to the
        # last bin.
        m = np.clip(np.searchsorted(t, ti, side='right') - 1, 0, num_bins - 1)
        np.add.at(class_time_weights[row], m, hits_all[i:i + dur])
        np.add.at(n_class_labels[row], m, 1)
    # Empty bins keep a weight of 0 instead of producing 0/0 = NaN
    class_time_weights /= np.maximum(n_class_labels, 1)
    # Compute confidence estimations, anchoring each bin at its centre
    t_interp = t + 0.5 / num_bins
    for i, dur, row in zip(idx, durations, class_rows):
        ti = np.linspace(0, 1, dur)
        confs_all[i:i + dur] = np.interp(ti, t_interp,
                                         class_time_weights[row])
    return np.split(confs_all, label_limits[1:])


def trim_class_margins(labels, start_margin_ratio, end_margin_ratio=None,
                       classes=None, replacement=-1):
    """Relabel the start and end margins of selected class runs.

    Each sequence is split into runs of identical consecutive labels. Inside
    a run, a frame's relative position goes linearly from 0 (first frame) to
    1 (last frame). For runs of a selected class, frames with relative
    position below ``start_margin_ratio`` or above ``1 - end_margin_ratio``
    are replaced by ``replacement``. Runs of other classes are left untouched.

    Args:
        labels: list of 1-D label arrays.
        start_margin_ratio: leading fraction of each run to relabel, in [0, 1].
        end_margin_ratio: trailing fraction of each run to relabel, in [0, 1];
            defaults to ``start_margin_ratio``.
        classes: classes to trim (scalar or sequence). When ``None``, all
            non-negative labels found in ``labels`` are trimmed.
        replacement: label written into trimmed frames.

    Returns:
        list of new label arrays (copies, same dtype as the inputs).

    Raises:
        ValueError: if a margin ratio is outside [0, 1].
    """
    if not labels:
        return []
    if classes is None:
        classes = np.unique(np.concatenate(labels, axis=0))
        classes = classes[classes >= 0]
    if np.isscalar(classes):
        classes = [classes]
    if end_margin_ratio is None:
        end_margin_ratio = start_margin_ratio
    if not 0 <= start_margin_ratio <= 1:
        raise ValueError("The start margin ratio must be between 0 and 1.")
    if not 0 <= end_margin_ratio <= 1:
        raise ValueError("The end margin ratio must be between 0 and 1.")
    labels_new = []
    for label in labels:
        label = np.asarray(label)
        selected = np.isin(label, classes)
        run_starts = np.concatenate(
            [[0], np.flatnonzero(np.diff(label)) + 1], axis=0)
        run_lengths = np.diff(np.append(run_starts, len(label)))
        relative_pos = np.concatenate(
            [np.linspace(0, 1, n) for n in run_lengths])
        inside = (relative_pos >= start_margin_ratio) \
            & (relative_pos <= (1. - end_margin_ratio))
        label_new = label.copy()
        # Previously hard-coded to -1; honour the caller's replacement label
        label_new[selected & ~inside] = replacement
        labels_new.append(label_new)
    return labels_new


def normalized_class_weights(labels, norm_factor=1, class_rel_weights=None):
    """Compute inverse-frequency class weights from label sequences.

    With ``n_c`` the number of frames labelled ``c`` (floored at 1), ``p``
    equal to ``norm_factor`` and ``C`` the number of classes, class ``c``
    gets weight ``sum_k(n_k ** p) / (C * n_c ** p)``. The result is then
    multiplied element-wise by ``class_rel_weights``. Labels that are
    negative or ``>= C`` are ignored.

    Args:
        labels: list of 1-D integer label arrays.
        norm_factor: exponent applied to the class counts. It may be
            negative or fractional.
        class_rel_weights: per-class multipliers. Their length defines ``C``.
            When ``None``, ``C`` is the largest non-negative label plus one
            and all multipliers are 1.

    Returns:
        float array of ``C`` class weights. It is empty when no classes
        exist.
    """
    if len(labels):
        labels_all = np.concatenate(labels, axis=0)
    else:
        labels_all = np.zeros(0, dtype=int)
    if class_rel_weights is None:
        known = labels_all[labels_all >= 0]
        # One class per id in [0, max id]; no non-negative labels -> none
        num_known = int(np.max(known)) + 1 if known.size else 0
        class_rel_weights = np.ones(num_known)
    else:
        class_rel_weights = np.asarray(class_rel_weights)
    num_classes = len(class_rel_weights)
    in_range = (labels_all >= 0) & (labels_all < num_classes)
    classes_count = np.bincount(labels_all[in_range].astype(int),
                                minlength=num_classes).clip(1)
    # Float base so negative exponents work (numpy rejects int ** -k)
    scaled_count = classes_count.astype(float) ** norm_factor
    class_weights = np.sum(scaled_count) / (num_classes * scaled_count)
    class_weights *= class_rel_weights
    return class_weights


def rice_input_filter(drag_pos=0, drag_yaw=0, drop_head=False,
                      drop_secondary_hand=False):
    """Create an input preprocessing function for a TimeNet-style model.

    The returned function is called as ``input_filter(self, input_)``.
    ``self`` must expose ``config.step_size``, ``fixed_step``,
    ``STATE_VARIABLES`` and ``STATE_UPDATES``. The returned function is
    presumably what gets passed to TimeNet's ``input_filter`` argument.

    The input is first made head-relative via ``make_input_head_relative``.
    The secondary hand is optionally removed, and hand-crafted kinematic
    features are appended.

    The slicing below assumes each tracked element occupies 7 consecutive
    channels: 3 position components followed by 4 rotation components
    (presumably a quaternion -- TODO confirm). Channels after the tracked
    elements are passed through unscaled.

    Args:
        drag_pos: forwarded to ``make_input_head_relative``.
        drag_yaw: forwarded to ``make_input_head_relative``.
        drop_head: forwarded to ``make_input_head_relative``. Also decides
            where the secondary hand's channels start and whether element 0
            is treated as the (unscaled) head.
        drop_secondary_hand: if true, the 7 channels of the secondary hand
            are removed after the head-relative transform.

    Returns:
        A function ``input_filter(self, input_)`` that returns the
        transformed input tensor.
    """
    def input_filter(self, input_):
        """Transform ``input_`` into the feature tensor fed to the network.

        Assumes ``input_`` is rank 3 and indexed as ``[:, time, channel]``.
        Axis 1 is the axis the input buffer below is updated along.

        Output channels, in order:
        - Base features: per element, position (hands scaled) and rotation
          (scaled), followed by any remaining channels.
        - Extra features: position lengths, velocities + norms,
          accelerations + norms, azimuths, latitudes, pairwise distances.

        Side effects: creates the non-trainable variable
        ``PreprocessingInputBuffer``. The buffer is registered in
        ``self.STATE_VARIABLES`` and its updated value in
        ``self.STATE_UPDATES``.
        """
        # TODO Are these right? (hand-picked feature scale factors used below)
        HAND_SCALE = 0.1
        VEL_SCALE = 0.01
        ACC_SCALE = 0.001
        ROTATION_SCALE = 10

        res, _, _, = make_input_head_relative(
            input_, self.config.step_size, drag_pos, drag_yaw, drop_head,
            self.STATE_VARIABLES, self.STATE_UPDATES)
        if drop_secondary_hand:
            # Secondary hand is the second element without the head, the
            # third one with it; drop its 7 channels
            if drop_head:
                i_second_hand = 7
            else:
                i_second_hand = 14
            res = tf.concat([res[..., :i_second_hand],
                             res[..., i_second_hand + 7:]], axis=-1)

        # Number of tracked 7-channel elements remaining in `res`
        num_elems = 1
        if not drop_head:
            num_elems += 1
        if not drop_secondary_hand:
            num_elems += 1

        # Buffer of previous inputs: keeps the last `buffer_size` frames of
        # `res` (zero-initialised) so that velocities and accelerations can
        # be computed across calls. Requires static batch/channel dims in
        # `res.shape` (assumption -- get_variable needs a full shape).
        buffer_size = 2
        buffer_shape = res.shape.as_list()
        buffer_shape[1] = buffer_size
        input_buffer = tf.get_variable("PreprocessingInputBuffer",
                                       shape=buffer_shape, dtype=res.dtype,
                                       initializer=tf.zeros_initializer(),
                                       trainable=False)
        # Number of new frames shifted into the buffer per call: static
        # when the step is fixed, otherwise taken from the runtime length
        if self.fixed_step:
            update_size = min(self.config.step_size, buffer_size)
        else:
            update_size = tf.minimum(tf.shape(res)[1], buffer_size)
        input_buffer_update = tf.concat(
            [input_buffer[:, update_size:, :], res[:, -update_size:, :]],
            axis=1, name="PreprocessingInputBufferUpdated")
        tf.add_to_collection(self.STATE_VARIABLES, input_buffer)
        tf.add_to_collection(self.STATE_UPDATES, input_buffer_update)
        # Previous frames followed by the current ones along time
        res_buffer = tf.concat([input_buffer, res], axis=1)

        # Additional features
        extra_features = []

        # Add lengths (Euclidean norm of each element's position)
        for i in range(num_elems):
            length = tf.reduce_sum(tf.square(res[:, :, 7 * i:7 * i + 3]),
                                   axis=-1, keepdims=True)
            length = tf.sqrt(length)
            # Do not scale head coordinates
            if drop_head or i > 0:
                length *= HAND_SCALE
            extra_features.append(length)

        # Add velocity: first difference along time, x[t] - x[t-1], using
        # the buffered frames for the first time steps
        for i in range(num_elems):
            vel = (res_buffer[:, buffer_size:, 7 * i:7 * i + 3] -
                   res_buffer[:, buffer_size - 1:-1, 7 * i:7 * i + 3])
            vel *= VEL_SCALE
            extra_features.append(vel)
            vel_norm = tf.reduce_sum(tf.square(vel), axis=-1, keepdims=True)
            vel_norm = tf.sqrt(vel_norm)
            extra_features.append(vel_norm)

        # Add accelerations: second difference x[t] - 2 x[t-1] + x[t-2]
        for i in range(num_elems):
            mid = res_buffer[:, buffer_size - 1:-1, 7 * i:7 * i + 3]
            acc = (res_buffer[:, buffer_size:, 7 * i:7 * i + 3] -
                   (mid + mid) +
                   res_buffer[:, buffer_size - 2:-2, 7 * i:7 * i + 3])
            acc *= ACC_SCALE
            extra_features.append(acc)
            acc_norm = tf.reduce_sum(tf.square(acc), axis=-1, keepdims=True)
            acc_norm = tf.sqrt(acc_norm)
            extra_features.append(acc_norm)

        # Add azimuths: atan2(y, x) of each element's position
        for i in range(num_elems):
            az = tf.atan2(res[:, :, 7 * i + 1],
                          res[:, :, 7 * i])
            extra_features.append(tf.expand_dims(az, axis=-1))

        # Add latitudes: atan2(z, sqrt(x^2 + y^2))
        # NOTE(review): unlike every other feature this reads the raw
        # `input_`, not the head-relative `res`. When drop_head or
        # drop_secondary_hand is set, channel 7 * i of `input_` may not be
        # element i of `res` -- confirm whether `res` was intended.
        for i in range(num_elems):
            lat = tf.atan2(input_[:, :, 7 * i + 2],
                           tf.sqrt(tf.square(input_[:, :, 7 * i]) +
                                   tf.square(input_[:, :, 7 * i + 1])))
            extra_features.append(tf.expand_dims(lat, axis=-1))

        # Add pairwise distances between element positions
        for i, j in combinations(range(num_elems), 2):
            v = res[:, :, 7 * i:7 * i + 3] - res[:, :, 7 * j:7 * j + 3]
            dist = tf.reduce_sum(tf.square(v), axis=-1, keepdims=True)
            dist = tf.sqrt(dist)
            dist *= HAND_SCALE
            extra_features.append(dist)

        # Manually scale coordinates (the features above were computed from
        # the unscaled `res`)
        base_features = []
        for i in range(num_elems):
            pos = res[:, :, 7 * i:7 * i + 3]
            rot = res[:, :, 7 * i + 3:7 * i + 7]
            # Do not scale head coordinates
            if drop_head or i > 0:
                pos *= HAND_SCALE
            rot *= ROTATION_SCALE
            base_features.append(pos)
            base_features.append(rot)
        # Remaining channels after the tracked elements pass through as-is
        base_features.append(res[:, :, 7 * num_elems:])

        res = tf.concat(base_features + extra_features, axis=-1)

        # Disabled experiments kept for reference.
        # NOTE(review): in the quantization block, `base_idx = 7 * 1` looks
        # like it was meant to be `7 * i`; fix before re-enabling.

        # # Remove rotations
        # stripped = []
        # for i in range(num_elems):
        #     stripped.append(res[:, :, 7 * i:7 * i + 3])
        # res = tf.concat(stripped + [res[:, :, 7 * num_elems:]], axis=-1)

        # # Quantization
        # split_x = [-10, 10, 30]
        # split_y = [-30, -10, 10, 30]
        # split_z = [-30, -10, 10, 30]
        # base_orientation_vector = tf.constant([1, 0, 0], dtype=res.dtype)
        # class_vectors = make_orientation_class_vectors(max_components=2)
        # rot_shape = [res.get_shape()[0].value, -1, len(class_vectors)]
        # features = []
        # for i in range(num_elems):
        #     base_idx = 7 * 1
        #     position = res[..., base_idx:base_idx + 3]
        #     rotation = res[..., base_idx + 3:base_idx + 7]
        #     pos_quant = quantize_position(position,
        #                                   split_x, split_y, split_z)
        #     features.append(pos_quant)
        #     rot_quant = quantize_rotation(rotation, base_orientation_vector,
        #                                   class_vectors)
        #     rot_quant = tf.reshape(rot_quant, rot_shape)
        #     features.append(rot_quant)
        # features.append(res[..., 7 * num_elems:])
        # res = tf.concat(features, axis=-1)

        return res
    return input_filter
