#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from datetime import datetime
import gc
from itertools import repeat
from multiprocessing import Pool
import random
from pathlib import Path
import sys

import numpy as np
import quaternion

import rice_features as F
from dataio import read_data_directory, save_data_archive
from captures import make_quaternion_from_axis_angle, rotate_coordinates, \
    reindex_data, prerotate_quaternions, postrotate_quaternions, \
    animate_capture


# %% PATHS & SETUP

# Seed applied to both `random` and `numpy.random` at the start of `main`,
# so that a dataset build is reproducible.
SEED = 10000

# Input directory with the raw captures, relative to the dataset root.
CAPTURES_DIR = 'captures'
# Output directories for each split, relative to the dataset root.
TRAIN_DIR = 'dataset/train'
VALIDATION_DIR = 'dataset/validation'
TEST_DIR = 'dataset/test'

# %% DATASET CONFIGURATION

# Fractions of the captures held out for validation and for test; the
# remaining captures form the training split.
validation_split = 0.15
test_split = 0.15

# Number of evenly spaced base angles about the z axis (each one jittered).
num_rotations = 1

# Each displacement moves a capture's mean head (x, y) to a random point of
# the square [-displacement_radius, displacement_radius]^2.
displacement_radius = 150.0
num_displacements = 1

# Random variations generated per training capture, and their ranges.
num_variations = 6
height_decrease = 10
height_increase = 10
scale_down = 1.15
scale_up = 1.15
speed_down = 1.15
speed_up = 1.15
trajectory_max_amplitude = 10.0
trajectory_min_frequency = 0.0
trajectory_max_frequency = 1.0
orientation_max_amplitude = np.pi / 12.0  # +/- 15 degrees
orientation_min_frequency = 0.0
orientation_max_frequency = 1.0

# %% CAPTURE ANIMATION


def rice_animate_capture(capture, **kwargs):
    """Animate a capture after centring it on its mean head position.

    The head and both hands are shifted by the mean of the head position
    columns, except for the last coordinate (presumably the vertical axis --
    confirm against ``rice_features``), which is left as recorded.

    Parameters
    ----------
    capture : pandas.DataFrame
        Capture containing the head and hand position/rotation columns
        named in ``rice_features``.
    **kwargs
        Passed through unchanged to ``captures.animate_capture``.

    Returns
    -------
    Whatever ``captures.animate_capture`` returns.
    """
    # Ask for an explicit, writable copy: under pandas copy-on-write,
    # ``.values`` may be a read-only view and the assignment below would
    # then raise.
    mean_pos = capture[F.HEAD_POS].mean().to_numpy(copy=True)
    # Do not shift the last coordinate.
    mean_pos[-1] = 0
    head = (capture[F.HEAD_POS] - mean_pos, capture[F.HEAD_ROT])
    controlling_hand = (capture[F.CONTROLLING_HAND_POS] - mean_pos,
                        capture[F.CONTROLLING_HAND_ROT])
    secondary_hand = (capture[F.SECONDARY_HAND_POS] - mean_pos,
                      capture[F.SECONDARY_HAND_ROT])
    return animate_capture(head, controlling_hand, secondary_hand,
                           capture.index, **kwargs)

# %% ROTATE


def _rotate_capture(capture, angle, jitter):
    """Return a copy of ``capture`` rotated about the z axis.

    The rotation angle is ``angle + jitter`` and the axis ``[0, 0, 1]``
    passes through the capture's mean head position. All three position
    groups are rotated about that pivot. The same quaternion is applied to
    the three orientation groups via ``postrotate_quaternions``.
    """
    pivot = np.average(capture[F.HEAD_POS], axis=0)
    q = make_quaternion_from_axis_angle([0., 0., 1.], angle + jitter)
    result = capture.copy()
    position_groups = (F.HEAD_POS, F.CONTROLLING_HAND_POS,
                       F.SECONDARY_HAND_POS)
    for columns in position_groups:
        centred = capture[columns].values - pivot
        result[columns] = rotate_coordinates(centred.T, q).T + pivot
    rotation_groups = (F.HEAD_ROT, F.CONTROLLING_HAND_ROT,
                       F.SECONDARY_HAND_ROT)
    for columns in rotation_groups:
        result[columns] = postrotate_quaternions(capture[columns].values, q)
    return result


def rotate_captures(captures, angles, jitter=0.):
    """Rotate every capture once per base angle, using a process pool.

    Parameters
    ----------
    captures : sequence
        Captures to rotate (see ``_rotate_capture``).
    angles : iterable of float
        Base rotation angles about the z axis.
    jitter : float, optional
        Half-width of the uniform random offset drawn independently for
        each capture and added to its base angle.

    Returns
    -------
    list
        ``len(angles) * len(captures)`` rotated captures, grouped by angle.
    """
    result = []
    with Pool() as workers:
        for base_angle in angles:
            offsets = np.random.uniform(-1., 1., len(captures)) * jitter
            tasks = [(capture, base_angle, offset)
                     for capture, offset in zip(captures, offsets)]
            result += workers.starmap(_rotate_capture, tasks)
    return result

# %% DISPLACE


def _displace_capture(capture, x, y):
    """Return a copy of ``capture`` translated so its mean head (x, y) is (x, y).

    Only the first two position coordinates change. The third is left
    untouched, and all three position groups move by the same offset.
    """
    centre = np.average(capture[F.HEAD_POS], axis=0)
    offset = np.array([x - centre[0], y - centre[1], 0.])
    moved = capture.copy()
    for columns in (F.HEAD_POS, F.CONTROLLING_HAND_POS,
                    F.SECONDARY_HAND_POS):
        moved[columns] = capture[columns] + offset
    return moved


def displace_captures(captures, radius):
    """Move each capture to a random (x, y) target in a process pool.

    Targets are drawn uniformly from the square ``[-radius, radius]^2``,
    one per capture (see ``_displace_capture``). Returns a new list.
    """
    count = len(captures)
    targets_x = np.random.uniform(-1, 1, count) * radius
    targets_y = np.random.uniform(-1, 1, count) * radius
    with Pool() as workers:
        displaced = workers.starmap(_displace_capture,
                                    zip(captures, targets_x, targets_y))
    return list(displaced)

# %% VARIATIONS


def height_var(decrease, increase=None):
    """Return a variation that shifts all z columns by one random offset.

    Each call of the returned function draws an offset uniformly from
    ``[-decrease, increase]``. It adds that offset to the head and both hand
    z columns of a copy of the capture.

    Raises
    ------
    ValueError
        If either bound is negative. ``increase`` defaults to ``decrease``.
    """
    increase = decrease if increase is None else increase
    if any(bound < 0 for bound in (decrease, increase)):
        raise ValueError('Height variations cannot be negative.')

    def fun(capture):
        shifted = capture.copy()
        offset = np.random.uniform(-decrease, increase)
        z_columns = [F.HEAD_Z, F.CONTROLLING_HAND_Z, F.SECONDARY_HAND_Z]
        shifted[z_columns] = shifted[z_columns] + offset
        return shifted
    return fun


def scale_var(downscale, upscale=None):
    """Return a variation that rescales positions about the mean head position.

    Each call of the returned function draws one factor in
    ``[1 / downscale, upscale]``. It scales the head and both hands by that
    factor, measured from the capture's mean head position.

    Parameters
    ----------
    downscale : float
        Largest shrink ratio (``>= 1``); the smallest factor is
        ``1 / downscale``.
    upscale : float, optional
        Largest growth factor (``>= 1``); defaults to ``downscale``.

    Raises
    ------
    ValueError
        If either bound is below one.
    """
    upscale = downscale if upscale is None else upscale
    if downscale < 1 or upscale < 1:
        raise ValueError('Scaling values must be greater or equal to one.')

    def fun(capture):
        centre = np.average(capture[F.HEAD_POS], axis=0)
        offset = np.random.uniform(1. - downscale, upscale - 1.)
        # Map the symmetric draw onto a multiplicative factor: non-negative
        # offsets grow the capture, negative ones shrink it reciprocally.
        if offset >= 0.:
            factor = offset + 1.
        else:
            factor = 1. / (1. - offset)
        scaled = capture.copy()
        for columns in (F.HEAD_POS, F.CONTROLLING_HAND_POS,
                        F.SECONDARY_HAND_POS):
            scaled[columns] = (capture[columns] - centre) * factor + centre
        return scaled
    return fun


def trajectory_var(max_amplitude, min_frequency, max_frequency=None):
    """Return a variation that adds a random sinusoid to every position column.

    For each position coordinate independently, the returned function draws
    a frequency ``f`` in ``[min_frequency, max_frequency]``, an amplitude
    ``a`` in ``[0, max_amplitude]`` and a phase ``p`` in ``[0, 2*pi]``. It
    then adds ``a * sin(f * t + p)``, where ``t`` are the raw index values.
    NOTE(review): no ``2*pi`` factor is applied to ``f``, so its unit depends
    on the index unit -- confirm.

    Raises
    ------
    ValueError
        If the amplitude is negative or the frequency range is invalid.
        ``max_frequency`` defaults to ``min_frequency``.
    """
    max_frequency = min_frequency if max_frequency is None else max_frequency
    if max_amplitude < 0:
        raise ValueError('Amplitude cannot be negative.')
    if min_frequency < 0 or max_frequency < min_frequency:
        raise ValueError('Invalid frequency range.')

    def fun(capture):
        times = capture.index.values
        perturbed = capture.copy()
        columns = F.HEAD_POS + F.CONTROLLING_HAND_POS + F.SECONDARY_HAND_POS
        for column in columns:
            freq = np.random.uniform(min_frequency, max_frequency)
            amp = np.random.uniform(0, max_amplitude)
            phase = np.random.uniform(0, 2 * np.pi)
            wave = amp * np.sin(freq * times + phase)
            perturbed[column] = capture[column] + wave
        return perturbed
    return fun


def orientation_var(max_amplitude, min_frequency, max_frequency=None):
    """Return a variation that adds a smooth random wobble to orientations.

    For each orientation group (head and both hands), the returned function
    builds a time-varying rotation and composes it with the recorded
    orientation via ``prerotate_quaternions``.

    - Axis: each x/y/z component follows its own random sinusoid. The axis
      is normalised to unit length at every frame.
    - Angle: ``a * sin(f * t + p)``, where ``a`` is drawn from
      ``[0, max_amplitude]`` and ``t`` are the raw index values.

    Parameters
    ----------
    max_amplitude : float
        Largest rotation angle (radians) of the wobble.
    min_frequency, max_frequency : float
        Range for the random frequencies. ``max_frequency`` defaults to
        ``min_frequency``.

    Raises
    ------
    ValueError
        If the amplitude is negative or the frequency range is invalid.
    """
    if max_frequency is None:
        max_frequency = min_frequency
    if max_amplitude < 0:
        raise ValueError('Amplitude cannot be negative.')
    if min_frequency < 0 or max_frequency < min_frequency:
        raise ValueError('Invalid frequency range.')

    def fun(capture):
        variation = capture.copy()
        t = capture.index.values
        for rot in [F.HEAD_ROT, F.CONTROLLING_HAND_ROT, F.SECONDARY_HAND_ROT]:
            # Rotation axes: rotation vectors are always 3-D, whatever the
            # number of columns the orientation itself is stored in.
            axes = np.empty((len(t), 3))
            for i in range(3):
                f = np.random.uniform(min_frequency, max_frequency)
                p = np.random.uniform(0, 2 * np.pi)
                axes[:, i] = np.sin(f * t + p)
            # Normalise to unit axes so that the rotation angle is exactly
            # ``angles`` below. A degenerate all-zero axis stays zero, which
            # gives the identity rotation for that frame.
            norms = np.linalg.norm(axes, axis=1, keepdims=True)
            norms[norms == 0] = 1.
            axes = axes / norms
            # Rotation angles
            a = np.random.uniform(0, max_amplitude)
            f = np.random.uniform(min_frequency, max_frequency)
            p = np.random.uniform(0, 2 * np.pi)
            angles = a * np.sin(f * t + p)
            rot_vectors = axes * angles[:, np.newaxis]
            qs = quaternion.from_rotation_vector(rot_vectors)
            variation[rot] = prerotate_quaternions(capture[rot].values.copy(),
                                                   qs)
        return variation
    return fun


def speed_var(slow_down, speed_up=None):
    """Return a variation that randomly changes the playback speed of a capture.

    Each call of the returned function draws a speed factor ``s`` in
    ``[1 / slow_down, speed_up]``. It resamples the capture to
    ``round(len(capture) / s)`` frames. The new index keeps the original
    start value and spans the original duration divided by ``s``, so
    ``s > 1`` speeds the capture up and ``s < 1`` slows it down.

    Position, rotation (slerp) and ``F.ENVIRONMENT`` columns are
    interpolated with ``reindex_data``. Every other column takes the value of
    the last original frame at or before each resampling instant.

    Parameters
    ----------
    slow_down : float
        Largest slow-down ratio (``>= 1``).
    speed_up : float, optional
        Largest speed-up ratio (``>= 1``); defaults to ``slow_down``.

    Raises
    ------
    ValueError
        If either ratio is below one.
    """
    if speed_up is None:
        speed_up = slow_down
    if slow_down < 1 or speed_up < 1:
        raise ValueError('Speed values must be greater or equal to one.')

    def fun(capture):
        s = np.random.uniform(1. - slow_down, speed_up - 1.)
        s = s + 1. if s >= 0. else 1. / (1. - s)
        # NOTE(review): assumes the index is sorted ascending, as
        # ``np.searchsorted`` requires.
        idx = capture.index.values
        idx_min, idx_max = np.min(idx), np.max(idx)
        num_frames = int(np.round(len(idx) / s))
        # Resampling instants on the original timeline.
        idx_scaled = np.linspace(idx_min, idx_max, num_frames)
        # Last original frame at or before each instant. ``side='right'``
        # maps an exact timestamp match to that frame itself rather than to
        # the previous one (otherwise s == 1 would shift labels by a frame).
        idx_map = np.searchsorted(idx, idx_scaled, side='right') - 1
        idx_map = np.clip(idx_map, 0, len(idx) - 1)
        variation = capture.iloc[idx_map, :].copy()
        # Same start, duration divided by s (valid for any start value, not
        # only for indices starting at 0).
        idx_new = np.linspace(idx_min, idx_min + (idx_max - idx_min) / s,
                              num_frames, dtype=idx.dtype)
        variation.index = idx_new
        variation.index.name = capture.index.name
        for pos, rot in [(F.HEAD_POS, F.HEAD_ROT),
                         (F.CONTROLLING_HAND_POS, F.CONTROLLING_HAND_ROT),
                         (F.SECONDARY_HAND_POS, F.SECONDARY_HAND_ROT)]:
            variation[pos] = reindex_data(capture[pos], idx_scaled).values
            variation[rot] = reindex_data(capture[rot], idx_scaled,
                                          slerp=True).values
        variation[F.ENVIRONMENT] = reindex_data(
            capture[F.ENVIRONMENT], idx_scaled).values
        return variation
    return fun


def _make_capture_variations(capture, variation_ops, num_variations,
                             seed=None):
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
    variations = []
    for _ in range(num_variations):
        variation = capture.copy()
        random.shuffle(variation_ops)
        for variation_op, variation_params in variation_ops:
            variation = variation_op(*variation_params)(variation)
        variations.append(variation)
    return variations


def make_variations(captures, variation_ops, num_variations):
    """Generate random variations of every capture, using a process pool.

    Each capture gets its own seed drawn from ``numpy.random``, so the
    output is reproducible for a given global NumPy seed.

    Parameters
    ----------
    captures : sequence
        Captures to vary.
    variation_ops : iterable of (callable, tuple)
        Variation factories and their arguments (see
        ``_make_capture_variations``).
    num_variations : int
        Variations generated per capture.

    Returns
    -------
    list
        ``len(captures) * num_variations`` captures, grouped by source
        capture.
    """
    variation_ops = list(variation_ops)
    # Plain Python ints: ``random.seed`` rejects NumPy integers on 3.11+.
    seeds = np.random.randint(0, 1000000, len(captures)).tolist()
    with Pool() as pool:
        per_capture = pool.starmap(_make_capture_variations,
                                   zip(captures,
                                       repeat(variation_ops),
                                       repeat(num_variations),
                                       seeds))
    # Flatten in a single pass; ``sum(lists, [])`` is quadratic.
    return [variation
            for variations in per_capture
            for variation in variations]

# %% FUSION


def locate_gestures(captures, classes=None):
    """Find the contiguous runs of each gesture class in every capture.

    Parameters
    ----------
    captures : sequence of pandas.DataFrame
        Captures with a ``F.GESTURE_CLASS`` column.
    classes : array-like, optional
        Classes to look for. Defaults to every class present in
        ``captures`` (no classes when ``captures`` is empty).

    Returns
    -------
    gestures_dense : list of numpy.ndarray
        One ``(num_gestures, 2)`` int32 array per capture. Each row holds the
        ``[start, end)`` row positions (not index labels) of one maximal run
        of a single class from ``classes``.
    located_gestures : dict
        Maps each class to a list of ``(i_capture, i_gesture)`` pairs that
        index into ``gestures_dense``.
    """
    if classes is None:
        per_capture = [np.unique(c[F.GESTURE_CLASS].values) for c in captures]
        # ``np.concatenate`` rejects an empty list, so handle "no captures".
        classes = (np.unique(np.concatenate(per_capture, axis=0))
                   if per_capture else [])
    classes = np.asarray(classes)[:, np.newaxis]
    gestures_dense = []
    located_gestures = {c[0]: [] for c in classes}
    for i_capture, capture in enumerate(captures):
        capture_classes = capture[F.GESTURE_CLASS].values
        # hits[k, j] is True when frame j belongs to classes[k].
        hits = capture_classes[np.newaxis, :] == classes
        # Zero margins so runs touching either end still produce a +1/-1
        # transition; after np.diff, a +1 at column j marks a run starting at
        # frame j and a -1 marks a run ending just before frame j.
        hits = np.c_[np.zeros(len(classes)), hits, np.zeros(len(classes))]
        hits_diff = np.diff(hits, axis=1)
        starts, = np.where(np.logical_or.reduce(hits_diff > 0, axis=0))
        ends, = np.where(np.logical_or.reduce(hits_diff < 0, axis=0))
        # Each frame has one class, so every start has a matching end.
        assert len(starts) == len(ends)
        capture_dense = np.empty((len(starts), 2), dtype=np.int32)
        for i_gesture, (start, end) in enumerate(zip(starts, ends)):
            capture_dense[i_gesture] = (start, end)
            gesture_class = capture_classes[start]
            located_gestures[gesture_class].append((i_capture, i_gesture))
        gestures_dense.append(capture_dense)
    return gestures_dense, located_gestures


# %% MAIN


def _log(msg):
    print('{} - {}'.format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), msg))


def main(dataset_dir):
    """Build the augmented train/validation/test datasets under ``dataset_dir``.

    Steps:

    1. Load every capture from ``<dataset_dir>/CAPTURES_DIR``.
    2. Shuffle the captures and split them into train/validation/test.
    3. Rotate and displace every split.
    4. Add random variations to the training split only.
    5. Save each split to its own directory beneath ``dataset_dir``.

    Parameters
    ----------
    dataset_dir : str or pathlib.Path
        Dataset root; inputs are read from and outputs written beneath it.
    """
    random.seed(SEED)
    np.random.seed(SEED)

    _log('Loading data...')
    base_captures = read_data_directory(
        Path(dataset_dir, CAPTURES_DIR), index_col=F.INDEX)
    num_captures = len(base_captures)
    _log('Data loaded (examples: {}).'.format(num_captures))

    # Train / validation / test split
    num_validation_captures = int(np.round(num_captures * validation_split))
    num_test_captures = int(np.round(num_captures * test_split))
    num_train_captures = num_captures - (num_validation_captures +
                                         num_test_captures)
    validation_end = num_train_captures + num_validation_captures
    base_captures_shuffle = base_captures.copy()
    random.shuffle(base_captures_shuffle)
    train_captures = base_captures_shuffle[:num_train_captures]
    validation_captures = \
        base_captures_shuffle[num_train_captures:validation_end]
    # Slice from ``validation_end`` instead of ``[-num_test_captures:]``:
    # when the test count rounds to 0, ``[-0:]`` would select every capture.
    test_captures = base_captures_shuffle[validation_end:]
    del base_captures, base_captures_shuffle
    _log('Data split (train: {}, validation: {}, test: {}).'.format(
        len(train_captures), len(validation_captures), len(test_captures)))

    _log('Applying rotations...')
    angles = np.arange(num_rotations) * 2. * np.pi / num_rotations
    jitter = np.pi / num_rotations
    train_captures_rotated = rotate_captures(train_captures, angles, jitter)
    validation_captures_rotated = rotate_captures(validation_captures, angles,
                                                  jitter)
    test_captures_rotated = rotate_captures(test_captures, angles, jitter)
    del train_captures, validation_captures, test_captures
    _log('Rotations applied (train: {}, validation: {}, test: {}).'.format(
        len(train_captures_rotated),
        len(validation_captures_rotated),
        len(test_captures_rotated)))

    _log('Applying displacements...')
    train_captures_displaced = []
    validation_captures_displaced = []
    test_captures_displaced = []
    for _ in range(num_displacements):
        train_captures_displaced.extend(
            displace_captures(train_captures_rotated,
                              displacement_radius))
        validation_captures_displaced.extend(
            displace_captures(validation_captures_rotated,
                              displacement_radius))
        test_captures_displaced.extend(
            displace_captures(test_captures_rotated,
                              displacement_radius))
    del train_captures_rotated, validation_captures_rotated, \
        test_captures_rotated
    _log('Displacements applied (train: {}, validation: {}, test: {}).'.format(
        len(train_captures_displaced),
        len(validation_captures_displaced),
        len(test_captures_displaced)))

    _log('Applying variations...')
    variation_ops = [
        (height_var, (height_decrease,
                      height_increase)),
        (scale_var, (scale_down,
                     scale_up)),
        (speed_var, (speed_down,
                     speed_up)),
        (trajectory_var, (trajectory_max_amplitude,
                          trajectory_min_frequency,
                          trajectory_max_frequency)),
        (orientation_var, (orientation_max_amplitude,
                           orientation_min_frequency,
                           orientation_max_frequency)),
    ]
    # The training set keeps the unvaried captures and adds the variations.
    train_captures_variations = train_captures_displaced \
        + make_variations(train_captures_displaced,
                          variation_ops, num_variations)
    # Validation and test data are left without synthetic variations.
    validation_captures_variations = validation_captures_displaced
    test_captures_variations = test_captures_displaced
    del train_captures_displaced, validation_captures_displaced, \
        test_captures_displaced
    _log('Variations applied (train: {}, validation: {}, test: {}).'.format(
        len(train_captures_variations),
        len(validation_captures_variations),
        len(test_captures_variations)))

    _log('Saving data...')
    for data, directory in [(train_captures_variations, TRAIN_DIR),
                            (validation_captures_variations, VALIDATION_DIR),
                            (test_captures_variations, TEST_DIR)]:
        # Release intermediates from the previous split before saving.
        gc.collect()
        directory_path = Path(dataset_dir, directory)
        directory_path.mkdir(parents=True, exist_ok=True)
        save_data_archive(data, directory_path, 'examples')
        _log('{} examples saved to {}.'.format(len(data), directory_path))
    _log('Data saved.')


if __name__ == '__main__':
    # Expect the dataset root as the single positional argument.
    if len(sys.argv) >= 2:
        main(sys.argv[1])
    else:
        print('Usage: {} <dataset dir>'.format(sys.argv[0]), file=sys.stderr)
        sys.exit(1)
