#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from pathlib import Path
from multiprocessing.pool import Pool
import sys
import os

import numpy as np
import quaternion
import pandas as pd

# Environment variables are set at import time, before TensorFlow is imported
# (TensorFlow is only imported lazily inside write_records / float_feature).
# NOTE(review): '-1' presumably hides all CUDA devices so TensorFlow runs on CPU,
# and log level '3' presumably silences TensorFlow's native logging - confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Fraction of each file's examples written to the 'test' subdirectory by main();
# when it is 0 all records go directly into the output directory.
# TEST_SPLIT = 0
TEST_SPLIT = 0.15
# Not referenced by the code visible in this file.
MIRROR = True
# Joint whose transform is used as the root reference (see extract_root_tx).
ROOT_REF = 'Hips'
# Joints exported for every example, in this order; process_file asserts that
# every one of them is present in the input data.
JOINTS = [
    'Hips',
    'Spine',
    'Spine1',
    'Neck',
    'Head',
    'Head_end',
    'LeftShoulder',
    'LeftArm',
    'LeftForeArm',
    'LeftHand',
    'LeftHand_end',
    'RightShoulder',
    'RightArm',
    'RightForeArm',
    'RightHand',
    'RightHand_end',
    'LeftUpLeg',
    'LeftLeg',
    'LeftFoot',
    'LeftFoot_end',
    'RightUpLeg',
    'RightLeg',
    'RightFoot',
    'RightFoot_end',
    'Tail',
    'Tail1',
    'Tail1_end',
]
# Indices into JOINTS: 10/15 are 'LeftHand_end'/'RightHand_end',
# 19/23 are 'LeftFoot_end'/'RightFoot_end'.
FRONT_FOOT_IDX = [10, 15]
BACK_FOOT_IDX = [19, 23]
FOOT_IDX = FRONT_FOOT_IDX + BACK_FOOT_IDX
# IMPORTANT_JOINTS = [8, 9, 10, 13, 14, 15, 17, 18, 19, 21, 22, 23]
# Style names searched for (as substrings) in the style notify column by get_styles.
STYLES = ['Stand', 'Walk', 'Sit', 'Down', 'Up', 'Jump']
# Frames showing none of the styles outside this list are treated as "standing"
# when phases are computed in process_file.
STANDING_STYLES = ['Stand', 'Sit', 'Down', 'Up']
# Frames with any of these styles are masked out of the generated examples.
# INVALID_STYLES = []
# INVALID_STYLES = ['Up', 'Jump']
INVALID_STYLES = ['Sit', 'Down', 'Up', 'Jump']
# Gait names; get_gaits looks for columns ending in GAIT_PREFIX + gait name.
GAITS = ['Walk', 'Trot', 'Gallop']
# Gaits whose frames are masked out and whose style column is dropped.
INVALID_GAITS = []
# INVALID_GAITS = ['Trot']
# When False, frames with no active gait are masked out by process_file.
USE_NO_GAIT = True
# USE_NO_GAIT = False
# Not referenced by the code visible in this file.
USE_GAIT_WEIGHT = True
# Per-gait values N_i used by gait_weights: w_i is proportional to N_i ** -GAIT_WEIGHT_EXP.
GAIT_INV_WEIGHTS = {'Stand': 20640, 'Walk': 11782, 'Trot': 333, 'Gallop': 660}
GAIT_WEIGHT_EXP = 0.7
# Stride, in frames, of the synthetic step notifies inserted into standing segments.
STANDING_PHASE_MIN_CYCLE = 6  # 1/4 s @ 24 fps
# Number of trajectory samples before / after the current frame in each example.
PAST_TRAJECTORY_WINDOW = 6
FUTURE_TRAJECTORY_WINDOW = 6
# Number of frames between consecutive trajectory samples.
TRAJECTORY_STEP_SIZE = 4  # 1/6 s @ 24 fps
# Not referenced by the code visible in this file.
STANDING_SPEED_THRESHOLD = 0.1

# Step notify columns start with ':notify:<i>:DogStep' for i in 1..NUM_PHASES.
STEP_NOTIFY = 'DogStep'
NUM_PHASES = 4
# Not referenced by the code visible in this file.
STANDING = 'PfnnStanding'
# The style column is named ':notify:<NUM_PHASES + 1>:DogStyle:Name'.
STYLE_NOTIFY = 'DogStyle:Name'
# Suffix prefix of gait columns, e.g. a column ending in ':DogLocomotionWalk'.
GAIT_PREFIX = ':DogLocomotion'

def main(data_dir, output_dir):
    """Preprocess every ``*.CSV`` file under ``data_dir`` into TFRecord files.

    Files are processed in parallel with a process pool. When ``TEST_SPLIT`` is
    positive, the examples of each file are randomly split between the
    ``train`` and ``test`` subdirectories of ``output_dir``; otherwise all
    examples are written directly into ``output_dir``.

    Args:
        data_dir: Directory searched recursively for ``*.CSV`` files.
        output_dir: Directory where records are written (created if missing).
    """
    np.random.seed(0)
    # Make output directories in advance
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    if TEST_SPLIT > 0:
        train_dir = output_dir / 'train'
        test_dir = output_dir / 'test'
        train_dir.mkdir(parents=True, exist_ok=True)
        test_dir.mkdir(parents=True, exist_ok=True)
    # Preprocess in multiple processes
    print('Preprocessing files...')
    # Sorted so that the seeded random split does not depend on the order in
    # which the file system happens to list the files
    data_files = sorted(Path(data_dir).glob('**/*.CSV'))
    with Pool() as pool:
        for data_file, examples in zip(data_files, pool.imap(process_file, data_files)):
            inputs, phase, outputs, weights = examples
            data_name = data_file.stem
            print(f'Saving records for "{data_name}"...')
            if TEST_SPLIT > 0:
                n_total = len(inputs)
                n_test = int(TEST_SPLIT * n_total)
                # Split at an explicit position: slicing with -n_test breaks
                # when n_test is 0 (p[:-0] is empty and p[-0:] is everything)
                n_train = n_total - n_test
                p = np.random.permutation(n_total)
                idx_train = p[:n_train]
                idx_test = p[n_train:]
                write_records(inputs[idx_train], phase[idx_train], outputs[idx_train], weights[idx_train], train_dir, data_name)
                write_records(inputs[idx_test], phase[idx_test], outputs[idx_test], weights[idx_test], test_dir, data_name)
            else:
                write_records(inputs, phase, outputs, weights, output_dir, data_name)
    print('Done')

def process_file(data_file):
    """Read one animation CSV file and turn it into training examples.

    The file is loaded with its first column as index. Styles, gaits, step
    notifies (converted to phases), joint transforms and the root transform are
    extracted, frames with invalid styles or gaits are masked out, and the
    examples are produced by ``make_examples``.

    Args:
        data_file: Path of the CSV file to process.

    Returns:
        Tuple ``(inputs, phase, outputs, weights)`` as returned by
        ``make_examples``.
    """
    dtype = np.float32
    df = pd.read_csv(data_file, index_col=0)
    # Get styles (a file without a style column is treated as all 'Stand')
    style_col = f':notify:{NUM_PHASES + 1}:{STYLE_NOTIFY}'
    styles = get_styles(df.get(style_col, pd.Series('Stand', index=df.index, name=style_col)))
    assert (styles > 0).any(axis=1).all(), data_file.stem
    # Get gaits
    gaits = get_gaits(df)
    # Compute phases (a missing step notify column is replaced by zeros)
    step_notifies = []
    for i in range(NUM_PHASES):
        col_name = f':notify:{i + 1}:{STEP_NOTIFY}'
        col = next((c for c in df.columns if c.startswith(col_name)), None)
        if col:
            step_notifies.append(df[col])
        else:
            step_notifies.append(pd.Series(0, index=df.index, name=col_name))
    # A frame is "standing" when none of the non-standing styles is active.
    # axis must be passed by keyword: it is keyword-only in pandas >= 2.0.
    moving_styles = [s for s in STYLES if s not in STANDING_STYLES]
    standing = ~(styles[moving_styles] > 0).any(axis=1)
    phases = [compute_phase(s, standing) for s in step_notifies]
    # Get joint transforms
    joint_transforms, data_joints = get_joint_transforms(df)
    assert ROOT_REF in data_joints
    assert set(data_joints).issuperset(JOINTS)
    # Extract root transforms
    root_tx = extract_root_tx(df, ROOT_REF)
    # Get ordered joint transforms
    joint_idx = [data_joints.index(joint) for joint in JOINTS]
    joint_transforms = joint_transforms[:, joint_idx]
    # Use standing and gaits as styles (column 0 is "no gait")
    no_gait = 1 - gaits.values.sum(1, keepdims=True).clip(0, 1)
    gait_styles = np.concatenate([no_gait, gaits.values], axis=1)
    # Filter by styles
    mask = ~(styles[INVALID_STYLES] != 0).any(axis=1).values
    # Filter gaits
    if not USE_NO_GAIT:
        mask &= ~(gait_styles[:, 0] != 0)
        gait_styles[:, 0] = 0
    for i, g in enumerate(GAITS):
        if g in INVALID_GAITS:
            mask &= ~(gait_styles[:, i + 1] != 0)
            gait_styles[:, i + 1] = 0
    # Remove the columns of invalid gaits
    gait_styles = gait_styles[:, [0] + [i + 1 for i, g in enumerate(GAITS) if g not in INVALID_GAITS]]
    # Make examples
    inputs, phase, outputs, ws = make_examples(
        root_tx, joint_transforms, gait_styles, np.stack(phases, axis=1),
        PAST_TRAJECTORY_WINDOW, FUTURE_TRAJECTORY_WINDOW, TRAJECTORY_STEP_SIZE,
        mask=mask, dtype=dtype)
    # Return examples
    return inputs, phase, outputs, ws

def get_gaits(df):
    """Collect one float32 column per gait in ``GAITS`` from ``df``.

    For each gait the first column whose name ends with ``GAIT_PREFIX`` plus
    the gait name is used; a gait without such a column yields zeros. The
    result has a fresh 0..n-1 index. Rows where the gait values add up to more
    than 1 are blanked and then filled by interpolation.
    """
    num_rows = len(df)
    columns = []
    for gait in GAITS:
        suffix = f'{GAIT_PREFIX}{gait}'
        match = next((name for name in df if name.endswith(suffix)), None)
        if match is None:
            values = np.zeros(num_rows, np.float32)
        else:
            values = df[match].astype(np.float32)
        columns.append(pd.Series(values, name=gait).reset_index(drop=True))
    result = pd.concat(columns, axis=1)
    result.columns = GAITS
    # Rows with more than one active gait are ambiguous: blank and interpolate
    ambiguous = result.sum(1) > 1
    result[ambiguous] = np.nan
    result.interpolate(inplace=True)
    return result

def get_styles(styles):
    """Convert a series of style strings into one float32 column per style.

    Each column of the result holds 1 where the corresponding name in
    ``STYLES`` occurs in the style string and 0 where it does not. Rows where
    more than one style is active are blanked and then filled by interpolation.

    Args:
        styles: String series with the style notify of every frame.

    Returns:
        DataFrame indexed like ``styles`` with one column per entry of ``STYLES``.
    """
    df = pd.DataFrame(index=styles.index, columns=STYLES, dtype=np.float32)
    for s in STYLES:
        # Cast the boolean match to float so that the NaN assignment and the
        # interpolation below operate on numeric columns, as in get_gaits
        df[s] = styles.str.contains(s).astype(np.float32)
    df[df.sum(axis=1) > 1] = np.nan
    df.interpolate(inplace=True)
    return df

def compute_phase(notifies, standing):
    """Turn a step notify series into a per-frame phase value.

    Extra notifies are inserted every ``STANDING_PHASE_MIN_CYCLE`` frames in
    standing segments, and the first and last frames are always marked. The
    returned value is 0 on frames with a notify and grows linearly towards 1
    between consecutive notifies.

    NOTE(review): assumes ``notifies`` only holds 0/1 values and that
    ``notifies`` and ``standing`` share the same index - confirm with caller.

    Args:
        notifies: Series with a non-zero value on frames where a step happens.
        standing: Boolean series, true on frames considered standing.

    Returns:
        Series with the phase of each frame.
    """
    # Label each run of consecutive equal `standing` values with an increasing id
    g = (standing.astype(np.float32).diff() != 0).cumsum()
    st = pd.concat([g, standing], axis=1)
    # For each run: whether it is a standing run, and its first and last index
    blocks = st.groupby(g).apply(
        lambda gr: pd.Series([gr.iloc[:, 1].all(), gr.first_valid_index(), gr.last_valid_index()]))
    # Work on a copy so the caller's series is not modified
    notifies = notifies.copy()
    for _, (stand, start, end) in blocks.iterrows():
        if stand:
            # Mark a notify at regular intervals inside the standing run
            # NOTE(review): whether start/end are taken as labels or positions
            # here depends on the index type and pandas version - confirm
            notifies[start:end:STANDING_PHASE_MIN_CYCLE] = 1
    # Always have a notify at both ends so that every frame lies between two notifies
    notifies.iloc[0] = 1
    notifies.iloc[-1] = 1
    # Running count of notifies up to and including each frame
    c = notifies.cumsum()
    # Keep the count only on notify frames, interpolate linearly in between and
    # subtract the count to leave the fractional progress since the last notify
    return (c * notifies).replace(0, np.nan).interpolate() - c

# Transforms are multiplied FROM THE LEFT: Transform * Object
# Rightmost transform is applied first: Transform3 * Transform2 * Transform1
def get_joint_transforms(df, dtype=None):
    """Build 4x4 transform matrices for every joint found in ``df``.

    Joints are discovered from the columns ending in ``:cs:position:x``; for
    each of them the position (x, y, z) and quaternion (w, x, y, z) columns are
    combined into a homogeneous matrix with the rotation in the upper-left 3x3
    block and the position in the last column.

    Returns:
        Tuple ``(transforms, joints)`` where ``transforms`` has shape
        ``(len(df), len(joints), 4, 4)`` and ``joints`` lists the joint names
        in column order.
    """
    if not dtype:
        dtype = np.float32
    pos_suffix = ':cs:position:x'
    joints = [name.split(':')[0] for name in df if name.endswith(pos_suffix)]
    # Starts as all zeros, so only the non-zero entries need to be filled in
    transforms = np.zeros((len(df), len(joints), 4, 4), dtype=dtype)
    for j, joint in enumerate(joints):
        translation = df[[f'{joint}:cs:position:{axis}' for axis in 'xyz']].values
        wxyz = df[[f'{joint}:cs:quaternion:{comp}' for comp in 'wxyz']].values
        rotation = quaternion.as_rotation_matrix(quaternion.as_quat_array(wxyz))
        transforms[:, j, :3, :3] = rotation
        transforms[:, j, :3, 3] = translation
        # Homogeneous bottom row is [0, 0, 0, 1]
        transforms[:, j, 3, 3] = 1
    return transforms, joints

def extract_root_tx(df, root_ref):
    """Compute the per-frame root transform from the ``root_ref`` joint.

    The root keeps the X/Y position of the reference joint with Z set to 0, and
    its rotation is reduced to a rotation about the Z axis by the angle at
    which the joint's rotated X axis points in the XY plane.

    Returns:
        Array of shape ``(len(df), 4, 4)`` with homogeneous transforms.
    """
    pos_cols = [f'{root_ref}:cs:position:{axis}' for axis in 'xyz']
    quat_cols = [f'{root_ref}:cs:quaternion:{comp}' for comp in 'wxyz']
    # Copy so that zeroing the height does not touch the frame's data
    position = df[pos_cols].values.copy()
    position[:, 2] = 0
    full_rotation = quaternion.as_rotation_matrix(quaternion.as_quat_array(df[quat_cols].values))
    # Direction of the rotated X axis, used to measure the heading angle
    x_axis = full_rotation @ [1., 0., 0.]
    yaw = np.arctan2(x_axis[:, 1], x_axis[:, 0])
    half = yaw / 2
    zeros = np.zeros_like(yaw)
    # Quaternion (w, x, y, z) of a rotation by `yaw` about the Z axis
    yaw_quat = quaternion.from_float_array(np.stack([np.cos(half), zeros, zeros, np.sin(half)], axis=1))
    yaw_rotation = quaternion.as_rotation_matrix(yaw_quat)
    upper = np.concatenate([yaw_rotation, position[:, :, np.newaxis]], axis=2)
    bottom = np.broadcast_to(np.array([0, 0, 0, 1]), (len(upper), 1, 4))
    return np.concatenate([upper, bottom], axis=1)

def make_examples(root_transforms, joint_transforms, styles, phase,
                  past_trajectory_window, future_trajectory_window, trajectory_step_size,
                  with_terrain=False, with_contact=False,
                  first_idx=None, last_idx=None, mask=None, mirror=False, dtype=None):
    """Build one training example per usable frame.

    Each input row is laid out as: trajectory positions (2 per sample),
    trajectory directions (2 per sample), styles (one block per style, each
    with one value per sample), joint positions (3 per joint), joint velocities
    (3 per joint) and, if ``with_terrain``, a zeroed terrain placeholder
    (3 per sample). Trajectory samples cover the past and future windows.

    Each output row is laid out as: root velocity (2), root rotation angle (1),
    phase change (one per phase), if ``with_contact`` a zeroed contact
    placeholder (4), future trajectory positions and directions (2 per future
    sample each), next joint positions (3 per joint), next joint velocities
    (3 per joint) and next joint X and Y axes (6 per joint).

    Everything is expressed relative to the root transform of the frame.

    Args:
        root_transforms: Array (frames, 4, 4) of root transforms.
        joint_transforms: Array (frames, joints, 4, 4) of joint transforms.
        styles: Array (frames, num_styles) of style values.
        phase: Array (frames, num_phases) of phase values.
        past_trajectory_window: Number of past trajectory samples.
        future_trajectory_window: Number of future trajectory samples.
        trajectory_step_size: Frames between trajectory samples.
        with_terrain: Reserve (zeroed) terrain features in the inputs.
        with_contact: Reserve (zeroed) contact features in the outputs.
        first_idx: First frame to consider (defaults to the first frame).
        last_idx: Last frame to consider (defaults to the last frame).
        mask: Boolean array (frames,); an example is kept only if every frame
            of its trajectory window is valid. Defaults to all valid.
        mirror: Accepted for interface compatibility; not used here.
        dtype: Data type of the produced arrays (defaults to the root dtype).

    Returns:
        Tuple ``(inputs, phase, outputs, weights)`` restricted to the kept
        examples.
    """
    dtype = dtype or root_transforms.dtype
    num_joints = joint_transforms.shape[1]
    num_frames = len(root_transforms)
    # Check indices
    past_trajectory_size = past_trajectory_window * trajectory_step_size
    future_trajectory_size = future_trajectory_window * trajectory_step_size
    total_trajectory_window = past_trajectory_window + future_trajectory_window
    if first_idx is None:
        first_idx = 0
    if last_idx is None:
        last_idx = num_frames - 1
    if mask is None:
        # Builtin bool: the np.bool alias was removed in NumPy 1.24
        mask = np.ones(num_frames, bool)
    # Leave room for the previous/next frame and the whole trajectory window
    first_idx = max(first_idx + 1, past_trajectory_size + 1)
    last_idx = min(last_idx, num_frames - future_trajectory_size - 2)
    num_examples = max(last_idx - first_idx + 1, 0)
    # A frame is usable only if its whole trajectory window is valid
    mask2 = mask.copy()
    for i in range(len(mask)):
        mask2[i] = mask[max(i - past_trajectory_size, 0):i + future_trajectory_size + 1].all()
    mask = mask2[first_idx:last_idx + 1]
    # Input size
    input_traj_pos_idx = 0
    input_traj_pos_end = input_traj_pos_idx + 2 * total_trajectory_window
    input_traj_dir_idx = input_traj_pos_end
    input_traj_dir_end = input_traj_pos_end + 2 * total_trajectory_window
    input_style_idx = input_traj_dir_end
    input_style_end = input_style_idx + styles.shape[1] * total_trajectory_window
    input_joint_pos_idx = input_style_end
    input_joint_pos_end = input_joint_pos_idx + 3 * num_joints
    input_joint_vel_idx = input_joint_pos_end
    input_joint_vel_end = input_joint_vel_idx + 3 * num_joints
    input_terrain_idx = input_joint_vel_end
    input_terrain_end = input_terrain_idx + (3 * total_trajectory_window if with_terrain else 0)
    input_size = input_terrain_end
    # Output size
    output_root_vel_idx = 0
    output_root_vel_end = output_root_vel_idx + 2
    output_root_angvel_idx = output_root_vel_end
    output_root_angvel_end = output_root_angvel_idx + 1
    output_phase_idx = output_root_angvel_end
    output_phase_end = output_phase_idx + phase.shape[1]
    output_contact_idx = output_phase_end
    output_contact_end = output_contact_idx + (4 if with_contact else 0)
    output_traj_pos_idx = output_contact_end
    output_traj_pos_end = output_traj_pos_idx + 2 * future_trajectory_window
    output_traj_dir_idx = output_traj_pos_end
    output_traj_dir_end = output_traj_dir_idx + 2 * future_trajectory_window
    output_joint_pos_idx = output_traj_dir_end
    output_joint_pos_end = output_joint_pos_idx + 3 * num_joints
    output_joint_vel_idx = output_joint_pos_end
    output_joint_vel_end = output_joint_vel_idx + 3 * num_joints
    output_joint_xy_idx = output_joint_vel_end
    output_joint_xy_end = output_joint_xy_idx + 6 * num_joints
    output_size = output_joint_xy_end
    # Work on a copy of the root transforms
    root_transforms_fixed = root_transforms.copy()
    # Phase change between consecutive frames, wrapped to be non-negative
    phase_diff = np.diff(phase, axis=0)
    phase_diff[phase_diff < 0] += 1
    root_inv = np.linalg.inv(root_transforms_fixed)
    forward = np.array([[1], [0], [0]], dtype=root_inv.dtype)
    # Make example arrays
    inputs = np.empty((num_examples, input_size), dtype=dtype)
    outputs = np.empty((num_examples, output_size), dtype=dtype)
    for i_example in range(num_examples):
        i_frame = i_example + first_idx
        # Current trajectory, relative to the root at the current frame
        trajectory_start_idx = i_frame - past_trajectory_size
        trajectory_end_idx = i_frame + future_trajectory_size
        trajectory_tx = root_transforms_fixed[trajectory_start_idx:trajectory_end_idx]
        trajectory_tx = root_inv[i_frame] @ trajectory_tx
        trajectory_tx /= trajectory_tx[:, 3:, 3:]
        trajectory_pos = trajectory_tx[:, :2, 3]
        trajectory_rot = (trajectory_tx[:, :3, :3] @ forward)[:, :2, 0]
        trajectory_rot /= np.linalg.norm(trajectory_rot, axis=-1, keepdims=True)
        inputs[i_example, input_traj_pos_idx:input_traj_pos_end] = trajectory_pos[::trajectory_step_size].ravel()
        inputs[i_example, input_traj_dir_idx:input_traj_dir_end] = trajectory_rot[::trajectory_step_size].ravel()
        # Future trajectory, relative to the root at the next frame
        future_trajectory_tx = root_transforms_fixed[i_frame + 1:trajectory_end_idx + 1]
        future_trajectory_tx = root_inv[i_frame + 1] @ future_trajectory_tx
        future_trajectory_tx /= future_trajectory_tx[:, 3:, 3:]
        future_trajectory_pos = future_trajectory_tx[:, :2, 3]
        future_trajectory_rot = (future_trajectory_tx[:, :3, :3] @ forward)[:, :2, 0]
        future_trajectory_rot /= np.linalg.norm(future_trajectory_rot, axis=-1, keepdims=True)
        outputs[i_example, output_traj_pos_idx:output_traj_pos_end] = future_trajectory_pos[::trajectory_step_size].ravel()
        outputs[i_example, output_traj_dir_idx:output_traj_dir_end] = future_trajectory_rot[::trajectory_step_size].ravel()
        # Styles
        inputs[i_example, input_style_idx:input_style_end] = styles[trajectory_start_idx:trajectory_end_idx:trajectory_step_size].T.ravel()
        # Joints of the previous, current and next frame, each relative to its own root
        joint_frame = root_inv[i_frame - 1:i_frame + 2, np.newaxis] @ joint_transforms[i_frame - 1:i_frame + 2]
        joint_frame /= joint_frame[:, :, 3:, 3:]
        joint_prev, joint_cur, joint_next = joint_frame
        inputs[i_example, input_joint_pos_idx:input_joint_pos_end] = (joint_cur[:, :3, 3]).ravel()
        inputs[i_example, input_joint_vel_idx:input_joint_vel_end] = (joint_cur[:, :3, 3] - joint_prev[:, :3, 3]).ravel()
        outputs[i_example, output_joint_pos_idx:output_joint_pos_end] = joint_next[:, :3, 3].ravel()
        outputs[i_example, output_joint_vel_idx:output_joint_vel_end] = (joint_next[:, :3, 3] - joint_cur[:, :3, 3]).ravel()
        # Compute X and Y directions for each joint
        joint_xy = joint_next[:, :3, :3] @ np.array([[1, 0], [0, 1], [0, 0]])
        joint_xy = np.transpose(joint_xy, (0, 2, 1))
        outputs[i_example, output_joint_xy_idx:output_joint_xy_end] = joint_xy.ravel()
        # Output updates
        root_vel = trajectory_pos[past_trajectory_size + 1] - trajectory_pos[past_trajectory_size]
        outputs[i_example, output_root_vel_idx:output_root_vel_end] = root_vel
        root_angvel_x, root_angvel_y = trajectory_rot[past_trajectory_size + 1]
        outputs[i_example, output_root_angvel_idx:output_root_angvel_end] = np.arctan2(root_angvel_y, root_angvel_x)
        outputs[i_example, output_phase_idx:output_phase_end] = phase_diff[i_frame]
        # Terrain
        if with_terrain:
            # TODO: terrain features are not computed yet; only this example's
            # terrain columns are zeroed (not whole rows of the array)
            inputs[i_example, input_terrain_idx:input_terrain_end] = 0
        # Contact
        if with_contact:
            # TODO: contact features are not computed yet; only this example's
            # contact columns are zeroed (not whole rows of the array)
            outputs[i_example, output_contact_idx:output_contact_end] = 0
    weights = compute_weights(inputs, past_trajectory_window, future_trajectory_window, num_joints, styles.shape[1])
    return inputs[mask], phase[first_idx:last_idx + 1][mask], outputs[mask], weights[mask]

def get_mirror_joint_idx():
    """Return, for each joint in ``JOINTS``, the index of its mirrored joint.

    Joints whose name starts with ``Left`` map to the joint with the same name
    starting with ``Right`` and vice versa; all other joints map to themselves.
    """
    sides = (('Left', 'Right'), ('Right', 'Left'))
    mirror_idx = np.arange(len(JOINTS))
    for idx, name in enumerate(JOINTS):
        for side, opposite in sides:
            if name.startswith(side):
                mirror_idx[idx] = JOINTS.index(f'{opposite}{name[len(side):]}')
                break
    return mirror_idx

# Rotation for the root that transforms Maya axes to UE4 axes
def root_rotation_fix(dtype=None):
    """Return the 3x3 matrix ``axes_ue @ inv(axes_maya)``.

    Both axis sets are given in columns; the UE axes are the identity, so the
    result is the inverse of the Maya axes matrix.
    """
    if not dtype:
        dtype = np.float32
    # Axes given in columns
    ue_axes = np.eye(3, dtype=dtype)
    maya_axes = np.array(
        [
            [0, 0, 1],
            [-1, 0, 0],
            [0, -1, 0],
        ],
        dtype=dtype,
    )
    maya_inverse = np.linalg.inv(maya_axes)
    return ue_axes @ maya_inverse

def compute_weights(x, past_trajectory_window, future_trajectory_window, num_joints, num_styles):
    """Compute one sample weight per input row from its future trajectory.

    The future trajectory positions are read from the columns of ``x`` that
    follow the ``2 * past_trajectory_window`` past position values. The weight
    is ``1 + (length / 100) * (1 + |angle|)``, where ``length`` is the summed
    distance between consecutive future positions and ``angle`` is the
    direction of the last future position. ``num_joints`` and ``num_styles``
    are currently unused.

    Alternatives tried previously (kept for reference):
      velocity weight: (length + 4) / 32
      angle weight 1:  |angle| + 0.25
      angle weight 2:  1.25 - cos(1.5 * angle)
      back foot weight: (summed back foot joint speed + 1) / 5
    """
    start = 2 * past_trajectory_window
    stop = start + 2 * future_trajectory_window
    future_xy = x[:, start:stop].reshape(-1, future_trajectory_window, 2)
    # Length of the path through the future trajectory samples
    segments = np.diff(future_xy, axis=1)
    path_length = np.linalg.norm(segments, axis=-1).sum(axis=-1)
    # Direction of the last future sample as seen from the origin
    last_point = future_xy[:, -1]
    heading = np.arctan2(last_point[:, 1], last_point[:, 0])
    return 1 + (path_length / 100) * (1 + np.abs(heading))

def gait_weights(dtype=None):
    """Return normalized weights for 'Stand' followed by every valid gait.

    Weights follow w_i = {N_i ^ a} / {\\sum_j N_j ^ a} with a = -GAIT_WEIGHT_EXP,
    where N_i comes from ``GAIT_INV_WEIGHTS``. When ``USE_NO_GAIT`` is false
    the first ('Stand') weight is forced to zero before normalizing.
    """
    valid_gaits = [g for g in GAITS if g not in INVALID_GAITS]
    counts = np.array([GAIT_INV_WEIGHTS[name] for name in ['Stand'] + valid_gaits])
    weights = np.power(counts, -GAIT_WEIGHT_EXP)
    if not USE_NO_GAIT:
        weights[0] = 0
    # Guard against dividing by zero if every weight has been zeroed
    total = max(weights.sum(), 1e-8)
    weights = weights / total
    if dtype is None:
        return weights
    return weights.astype(dtype)

def write_records(inputs, phases, outputs, weights, output_dir, file_name):
    """Write examples to ``<output_dir>/<file_name>.tfrecord``.

    Each record is a ``tf.train.Example`` with float features ``x``, ``phase``,
    ``y`` and ``weight``.

    Args:
        inputs: Per-example input vectors.
        phases: Per-example phase vectors.
        outputs: Per-example output vectors.
        weights: Per-example scalar weights.
        output_dir: Directory where the record file is written.
        file_name: Name of the record file, without extension.

    Raises:
        ValueError: If ``output_dir`` exists but is not a directory.
    """
    # All four sequences are indexed together, so check weights too: a shorter
    # weights array would otherwise fail midway and leave a partial file
    assert len(inputs) == len(phases) == len(outputs) == len(weights)
    # Imported lazily so that only the writing process loads TensorFlow
    import tensorflow as tf
    output_dir = Path(output_dir)
    if output_dir.exists() and not output_dir.is_dir():
        raise ValueError(f'Invalid directory {output_dir}.')
    output_dir = output_dir.resolve()
    # Write records
    output_path = Path(output_dir, f'{file_name}.tfrecord')
    with tf.io.TFRecordWriter(str(output_path)) as writer:
        for x, phase, y, weight in zip(inputs, phases, outputs, weights):
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'x': float_feature(x),
                        'phase': float_feature(phase),
                        'y': float_feature(y),
                        'weight': float_feature([weight]),
                    }))
            writer.write(example.SerializeToString())

def float_feature(value):
    """Wrap an iterable of floats in a ``tf.train.Feature`` holding a float list."""
    import tensorflow as tf
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)

# Command line entry point: expects the data directory and the output directory
if __name__ == '__main__':
    if len(sys.argv) != 3:
        print(f'Usage: {sys.argv[0]} <data dir> <output dir>', file=sys.stderr)
        # sys.exit is always available, unlike the exit() helper that is only
        # injected by the site module
        sys.exit(1)
    main(sys.argv[1], sys.argv[2])
