#!/usr/bin/env python3

from pathlib import Path
from multiprocessing.pool import Pool
import itertools
import sys
import os

import numpy as np
import quaternion
import pandas as pd

from seg2seg import SegmentsClosestPoints

# Hide CUDA devices from TensorFlow and reduce its C++ log output. TensorFlow is only
# imported lazily (inside float_feature / write_records), so these are set before it loads.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Joint whose transform is used as the character root
ROOT = 'root'
# Head joint name; not referenced by the active code paths in this file
HEAD = 'head'
# NPC joints whose transforms go into the examples, in feature order
NPC_JOINTS = [
    'COG',
    'spine_00',
    'spine_01',
    'spine_02',
    'spine_03',
    'sternum',
    'neck',
    'head',
    'l_shoulder',
    'l_arm',
    'l_elbow',
    'l_wrist',
    'r_shoulder',
    'r_arm',
    'r_elbow',
    'r_wrist',
    'l_leg',
    'l_knee',
    'l_ankle',
    'l_ball',
    'r_leg',
    'r_knee',
    'r_ankle',
    'r_ball',
    'r_weapon_01',
]
# Sword tip joints. *_SWORD_VECTOR is the tip-to-hilt vector in the tip joint's local
# frame (normalized to get the sword direction; its norm is used as the sword length).
# *_SWORD_OFFSET is a local-frame displacement added to the tip joint's position.
NPC_SWORD_TIP = 'r_weapon_01'
NPC_SWORD_VECTOR = [90., 0., 0.]
NPC_SWORD_OFFSET = [-20., 0., 0.]
PLAYER_SWORD_TIP = 'l_weapon_01'
PLAYER_SWORD_VECTOR = [-90., 0., 0.]
PLAYER_SWORD_OFFSET = [0., 0., 0.]
# Centre-of-gravity joint; when truthy, its Z translation replaces the root's Z in the
# reference frame used for the sword trajectories
COG = 'COG'
# Style channels, in the column order of the style feature vector
STYLES = [
    'stand',
    'block',
]
# Number of samples per sword trajectory and frame spacing between samples
TRAJECTORY_WINDOW = 3
TRAJECTORY_STEP_SIZE = 3
# Normalization of the player sword-tip grid coordinates: X is mapped from
# COORD_DEPTH_RANGE onto [0, 1]; Y and Z are divided by COORD_WIDTH / COORD_HEIGHT
# and shifted by 0.5
COORD_DEPTH_RANGE = 40, 110
COORD_WIDTH = 120
COORD_HEIGHT = 200

# Inclusive upper bound on how many times a sword-block frame gets repeated
STRETCH_BLOCK_MAX = 15

# Notify column names: STANDING and STYLE get a ':notify:' prefix in get_styles,
# SWORD_BLOCK is already a full column name
STANDING = 'PfnnStanding'
STYLE = 'PfnnStyle:Name'
SWORD_BLOCK = ':notify:SwordBlock'


def main(data_dir, output_dir, time_scale):
    """Convert every ``*.csv`` in ``data_dir`` into a ``.tfrecord`` file in ``output_dir``.

    Files are processed in a process pool; results are written in input order.

    Args:
        data_dir: directory containing the motion CSV files.
        output_dir: destination directory (created if missing).
        time_scale: resampling factor forwarded to process_file / change_speed.
    """
    np.random.seed(100)
    out_path = Path(output_dir)
    # Create the destination before any worker finishes
    if not out_path.exists():
        out_path.mkdir(parents=True)
    print('Preprocessing files...')
    csv_files = list(Path(data_dir).glob('*.csv'))
    jobs = [(csv_file, time_scale) for csv_file in csv_files]
    with Pool() as pool:
        for csv_file, result in zip(csv_files, pool.imap(process_file, jobs)):
            inputs, outputs, weights, angle, coords = result
            record_name = csv_file.stem
            print('Saving records for "{}"...'.format(record_name))
            write_records(inputs, outputs, weights, angle, coords, out_path, record_name)
    print('Done')


def process_file(args):
    """Load one motion CSV and turn it into training examples.

    Args:
        args: ``(data_file, time_scale)`` pair, packed into a single argument so the
            function can be mapped with ``Pool.imap``.

    Returns:
        ``(inputs, outputs, weights, angle, coords)``. ``inputs``, ``outputs``, ``angle``
        and ``coords`` come from make_examples; ``weights`` is currently always None.

    Raises:
        ValueError: if the CSV lacks any joint this script needs.
    """
    import zlib

    data_file, time_scale = args
    dtype = np.float32
    # The parent's np.random.seed does not make pool workers reproducible: the
    # file-to-worker assignment varies between runs, and spawned workers do not inherit
    # the seed at all. Seed from the file name so the random block stretching is
    # repeatable per file.
    np.random.seed(zlib.crc32(Path(data_file).name.encode('utf-8')))
    df = pd.read_csv(data_file, index_col=0)
    df = change_speed(df, time_scale)
    df = stretch_sword_blocks(df)
    # Get joint transforms
    joint_transforms, data_joints = get_transforms(df)
    # Validate explicitly (asserts are stripped under -O and give no detail)
    required_joints = {ROOT, NPC_SWORD_TIP, PLAYER_SWORD_TIP, *NPC_JOINTS}
    if COG:
        required_joints.add(COG)
    missing_joints = sorted(required_joints.difference(data_joints))
    if missing_joints:
        raise ValueError('{}: missing joints: {}'.format(data_file, ', '.join(missing_joints)))
    # Fix sword positions (offsets are given in each tip joint's local frame)
    displace_joint(joint_transforms, data_joints.index(NPC_SWORD_TIP), NPC_SWORD_OFFSET)
    displace_joint(joint_transforms, data_joints.index(PLAYER_SWORD_TIP), PLAYER_SWORD_OFFSET)
    # Extract root transforms (extract_root_from_head offers an alternative source)
    root_tx = joint_transforms[:, data_joints.index(ROOT)]
    # Fix root axes: normalize the homogeneous scale, then remap the rotation axes
    rotation_fix_transform = root_rotation_fix(root_tx.dtype)
    root_transforms_fixed = root_tx.copy()
    root_transforms_fixed /= root_transforms_fixed[:, 3:, 3:]
    root_transforms_fixed[:, :3, :3] = root_transforms_fixed[:, :3, :3] @ rotation_fix_transform
    # Sword reference frame: the fixed root, optionally with Z copied from the COG joint
    sword_root_tx = root_transforms_fixed
    if COG:
        sword_root_tx = sword_root_tx.copy()
        sword_root_tx[:, 2, 3] = joint_transforms[:, data_joints.index(COG), 2, 3]
    # Swords transforms
    npc_sword_tip_tx = joint_transforms[:, data_joints.index(NPC_SWORD_TIP)]
    player_sword_tip_tx = joint_transforms[:, data_joints.index(PLAYER_SWORD_TIP)]
    # Get ordered joint transforms
    joint_transforms = joint_transforms[:, [data_joints.index(joint) for joint in NPC_JOINTS]]
    # Get style data
    style = get_styles(df, dtype)

    inputs, outputs, angle, coords = make_examples(root_transforms_fixed, joint_transforms, sword_root_tx,
                                                   npc_sword_tip_tx, player_sword_tip_tx, style,
                                                   TRAJECTORY_WINDOW, TRAJECTORY_STEP_SIZE,
                                                   dtype=dtype)

    # Example weights are disabled. compute_weights_block(df, dtype) or
    # compute_distance_weights(inputs[:, 4 * 3 * TRAJECTORY_WINDOW + 6]) (the sword
    # close-point distance column) can be plugged in here.
    weights = None

    return inputs, outputs, weights, angle, coords


def displace_joint(joint_transforms, joint_idx, offset):
    """Shift one joint's translation in place by ``offset`` given in that joint's local frame.

    Args:
        joint_transforms: array indexed ``[frame, joint, 4, 4]``; modified in place.
        joint_idx: index of the joint to displace.
        offset: 3-vector, rotated by the joint's rotation before being added.
    """
    joint_frames = joint_transforms[:, joint_idx]
    world_offset = rotate_vector(offset, joint_frames)
    joint_transforms[:, joint_idx, :3, 3] += world_offset


def stretch_sword_blocks(df):
    """Randomly repeat each frame flagged with a sword-block notify.

    Every frame where SWORD_BLOCK > 0 is repeated between 1 and STRETCH_BLOCK_MAX times
    (drawn from the global NumPy RNG); other frames are kept once. Returns ``df``
    unchanged when the column is absent.
    """
    if SWORD_BLOCK not in df:
        return df
    frame_idx = np.arange(len(df))
    block_frames = np.flatnonzero(df[SWORD_BLOCK] > 0)
    repeat_counts = np.ones_like(frame_idx)
    repeat_counts[block_frames] = np.random.randint(1, STRETCH_BLOCK_MAX + 1, size=block_frames.shape)
    return df.iloc[np.repeat(frame_idx, repeat_counts)]


def extract_root_from_head(head_tx):
    """Return a copy of ``head_tx`` (indexed ``[frame, 4, 4]``) with the Z translation zeroed."""
    projected = head_tx.copy()
    projected[:, 2, 3] = 0
    return projected


def root_rotation_fix(dtype=None):
    """Return the 3x3 rotation that transforms Maya axes to UE4 axes for the root.

    Args:
        dtype: dtype of the result; defaults to float32. Compared against None
            explicitly because non-structured ``np.dtype`` objects can be falsy
            (``len(dtype) == 0``), so ``dtype or ...`` could silently discard them.

    Returns:
        ``axes_ue @ inv(axes_maya)`` as a (3, 3) array.
    """
    if dtype is None:
        dtype = np.float32
    # Axes given in columns
    axes_ue = np.array([[ 1,  0,  0],
                        [ 0,  1,  0],
                        [ 0,  0,  1]], dtype=dtype)
    axes_maya = np.array([[ 0,  0,  1],
                          [-1,  0,  0],
                          [ 0, -1,  0]], dtype=dtype)
    return axes_ue @ np.linalg.inv(axes_maya)


def change_speed(df, time_scale):
    """Resample every joint track and notify column to ``round(len(df) * time_scale)`` frames.

    Positions are linearly interpolated and quaternions slerped between the two
    neighbouring source frames. ``:notify`` columns take the value of the nearest source
    frame. Columns that are neither joint tracks nor notifies are dropped.

    Args:
        df: DataFrame with ``<joint>:position:{x,y,z}`` and
            ``<joint>:quaternion:{w,x,y,z}`` columns plus optional ``:notify...`` columns.
        time_scale: ratio of output frames to input frames.

    Returns:
        A new DataFrame with a fresh RangeIndex.
    """
    num_frames = len(df)
    # Fractional source-frame position of every output frame
    r = np.linspace(0, num_frames - 1, round(num_frames * time_scale))
    s1 = np.floor(r).astype(np.intp)
    s2 = np.clip(s1 + 1, 0, num_frames - 1)
    nearest = np.round(r).astype(np.intp)
    t = r - s1
    t_col = t[:, np.newaxis]
    df_r = []
    # 'root:position:x' -> 'root', 'root:cs:position:x' -> 'root:cs'
    joints = [col.rsplit(':', 2)[0] for col in df if col.endswith(':position:x')]
    for joint in joints:
        pos_cols = [joint + ':position:' + axis for axis in 'xyz']
        quat_cols = [joint + ':quaternion:' + axis for axis in 'wxyz']
        pos = df[pos_cols].values
        quat = quaternion.as_quat_array(df[quat_cols].values)
        pos_r = pos[s1] * (1 - t_col) + pos[s2] * t_col
        # Call the library ufunc directly rather than the np.slerp_vectorized alias
        # that numpy-quaternion monkey-patches onto numpy
        quat_r = quaternion.as_float_array(quaternion.slerp_evaluate(quat[s1], quat[s2], t))
        df_r.append(pd.DataFrame(pos_r, columns=pos_cols))
        df_r.append(pd.DataFrame(quat_r, columns=quat_cols))
    notify_cols = [col for col in df if col.startswith(':notify')]
    df_r.append(df[notify_cols].iloc[nearest].reset_index(drop=True))
    return pd.concat(df_r, axis=1)


# Transforms are multiplied FROM THE LEFT: Transform * Object
# Rightmost transform is applied first: Transform3 * Transform2 * Transform1
def get_transforms(df, dtype=None):
    """Build homogeneous 4x4 transforms for every joint that has ``:cs:`` columns.

    Args:
        df: DataFrame with ``<joint>:cs:position:{x,y,z}`` and
            ``<joint>:cs:quaternion:{w,x,y,z}`` columns.
        dtype: dtype of the result; defaults to float32 (checked with ``is None``,
            since ``np.dtype`` objects can be falsy).

    Returns:
        ``(joint_tx, joints)``: an array of shape ``(len(df), len(joints), 4, 4)`` and
        the joint names in column order.
    """
    if dtype is None:
        dtype = np.float32
    joints = [col.split(':')[0] for col in df if col.endswith(':cs:position:x')]
    joint_tx = np.zeros((len(df), len(joints), 4, 4), dtype=dtype)
    # Bottom row of every transform is [0, 0, 0, 1]
    joint_tx[..., 3, 3] = 1
    for i, joint in enumerate(joints):
        pos = df[[joint + ':cs:position:' + axis for axis in 'xyz']].values
        quat = df[[joint + ':cs:quaternion:' + axis for axis in 'wxyz']].values
        quat = quaternion.as_quat_array(quat)
        joint_tx[:, i, :3, :3] = quaternion.as_rotation_matrix(quat)
        joint_tx[:, i, :3, 3] = pos
    return joint_tx, joints


def get_sword_trajectories(sword_tx, sword_vector, root_tx, num_steps, step_size):
    """Return root-relative sword-tip trajectory positions and unit sword directions.

    Args:
        sword_tx: per-frame sword tip transforms.
        sword_vector: 3-vector in the tip's local frame giving the sword direction.
        root_tx: per-frame reference transforms.
        num_steps, step_size: trajectory window parameters (see get_trajectories).

    Returns:
        ``(positions, directions)``, each indexed ``[frame, step, 3]``.
    """
    windows = get_trajectories(sword_tx, root_tx, num_steps, step_size)
    positions = windows[..., :3, 3]
    directions = rotate_vector(sword_vector, windows)
    lengths = np.linalg.norm(directions, axis=-1, keepdims=True)
    return positions, directions / lengths


def get_trajectories(tx, root_tx, num_steps, step_size):
    """Express a sliding window of past transforms relative to the current root.

    For frame ``f`` the result holds ``tx[f - (num_steps - 1) * step_size], ...,
    tx[f - step_size], tx[f]``, each left-multiplied by ``inv(root_tx[f])`` and
    normalized by its homogeneous scale. Frames without a full window behind them
    (including every frame of a clip shorter than one window) get identity transforms.

    Args:
        tx: (num_frames, 4, 4) transforms to sample.
        root_tx: (num_frames, 4, 4) reference transforms.
        num_steps: samples per window (>= 1).
        step_size: frame spacing between samples (>= 0).

    Returns:
        (num_frames, num_steps, 4, 4) array.

    Raises:
        ValueError: for ``num_steps < 1`` or ``step_size < 0`` (a negative stride would
            make ``as_strided`` read outside the array).
    """
    from numpy.lib.stride_tricks import as_strided
    tx = np.asarray(tx)
    root_tx = np.asarray(root_tx)
    assert tx.ndim == 3 and tx.shape[1] == 4 and tx.shape[2] == 4
    assert root_tx.ndim == 3 and root_tx.shape[1] == 4 and root_tx.shape[2] == 4
    assert len(tx) == len(root_tx)
    if num_steps < 1 or step_size < 0:
        raise ValueError('Invalid trajectory window: num_steps={}, step_size={}.'.format(num_steps, step_size))
    num_frames = len(tx)
    # Leading frames without a full window, clamped so short clips become all padding
    num_pad = min((num_steps - 1) * step_size, num_frames)
    num_trajectories = num_frames - num_pad
    pad = np.tile(np.eye(4, dtype=tx.dtype)[np.newaxis, np.newaxis], (num_pad, num_steps, 1, 1))
    if num_trajectories == 0:
        return pad
    # Window i samples frames i, i + step_size, ..., i + (num_steps - 1) * step_size
    strides = (tx.strides[0], step_size * tx.strides[0], tx.strides[1], tx.strides[2])
    tx_windows = as_strided(tx, (num_trajectories, num_steps, 4, 4), strides, writeable=False)
    # Make relative to the root of each window's last (current) frame. Slicing with
    # [num_pad:] instead of [-num_trajectories:] avoids the [-0:] whole-array pitfall.
    root_tx_inv = np.linalg.inv(root_tx[num_pad:])
    tx_windows = root_tx_inv[:, np.newaxis] @ tx_windows
    # Normalize homogeneous scale
    tx_windows /= tx_windows[:, :, -1:, -1:]
    # Pad with identity transforms at the beginning to keep same size
    return np.concatenate([pad, tx_windows], axis=0)


def rotate_vector(v, tx):
    """Rotate the 3-vector ``v`` by the rotation part of homogeneous transform(s) ``tx``.

    The rotation block is divided by ``tx[..., 3, 3]`` first. Unlike before, the
    caller's array is no longer normalized in place.

    Args:
        v: 3-vector.
        tx: transform(s) of shape ``(..., 4, 4)``.

    Returns:
        Array of shape ``(..., 3)``.
    """
    tx = np.asarray(tx)
    v = np.asarray(v)
    # Previously written as `assert a, b`, which made the shape test the assert message
    assert tx.ndim >= 2 and tx.shape[-1] == 4 and tx.shape[-2] == 4, 'tx must have shape (..., 4, 4)'
    assert v.ndim == 1 and len(v) == 3, 'v must be a 3-vector'
    rotation = tx[..., :3, :3] / tx[..., 3:, 3:]
    return (rotation @ v[:, np.newaxis])[..., 0]


def transform_vector(v, tx):
    """Apply homogeneous transform(s) ``tx`` to the 3D point ``v``.

    Args:
        v: 3-vector (point; homogeneous coordinate 1 is appended).
        tx: transform(s) of shape ``(..., 4, 4)``.

    Returns:
        Array of shape ``(..., 3)``, divided by the resulting homogeneous coordinate.
    """
    tx = np.asarray(tx)
    v = np.asarray(v)
    # Previously written as `assert a, b`, which made the shape test the assert message
    assert tx.ndim >= 2 and tx.shape[-1] == 4 and tx.shape[-2] == 4, 'tx must have shape (..., 4, 4)'
    assert v.ndim == 1 and len(v) == 3, 'v must be a 3-vector'
    v_h = np.array([[*v, 1]]).T
    v_tx = tx @ v_h
    # Out-of-place division so integer inputs work (in-place true division on ints fails)
    return v_tx[..., :3, 0] / v_tx[..., 3:, 0]


def get_styles(df, dtype=None):
    """Build the per-frame style vector, one column per entry of STYLES.

    The 'stand' column is copied from the ``:notify:`` + STANDING column if present.
    Each ':'-separated tag in the ``:notify:`` + STYLE column sets that style's column
    to 1. Runs of frames where two styles overlap are replaced by a linear blend from
    the frame before the run to the frame after it.

    Args:
        df: frame DataFrame with optional notify columns.
        dtype: dtype of the result; defaults to float32 (checked with ``is None``,
            since ``np.dtype`` objects can be falsy).

    Returns:
        Array of shape ``(len(df), len(STYLES))``.

    Raises:
        ValueError: on a style tag not listed in STYLES, or if more than two styles overlap.
    """
    if dtype is None:
        dtype = np.float32
    num_frames = len(df)
    style_data = np.zeros((num_frames, len(STYLES)), dtype)
    standing_col = ':notify:' + STANDING
    style_col = ':notify:' + STYLE
    if 'stand' in STYLES and standing_col in df:
        style_data[:, STYLES.index('stand')] = df[standing_col].values
    if style_col not in df:
        return style_data
    style_values = df[style_col].fillna('').values
    for i, style in enumerate(style_values):
        if not style:
            continue
        style_tags = style.split(':')
        unknown_tags = [tag for tag in style_tags if tag not in STYLES]
        if unknown_tags:
            raise ValueError('Unknown style tag(s) {} at frame {}.'.format(unknown_tags, i))
        style_data[i, [STYLES.index(tag) for tag in style_tags]] = 1
    # Overlapping areas
    overlap = np.sum(style_data, axis=1)
    if np.all(overlap <= 1):
        # No overlapping
        return style_data
    if np.any(overlap > 2):
        raise ValueError('No more than two styles can overlap.')
    # Split overlapping frames into contiguous [begin, end) runs
    overlap, = np.nonzero(overlap > 1)
    overlap_border = np.diff(overlap) > 1
    overlap_begin = overlap[np.insert(overlap_border, 0, True)]
    overlap_end = overlap[np.append(overlap_border, True)] + 1
    assert len(overlap_begin) == len(overlap_end)
    for begin, end in zip(overlap_begin, overlap_end):
        alpha = np.linspace(0, 1, end - begin)[:, np.newaxis]
        style_data[begin:end] = (style_data[max(begin - 1, 0)] * (1 - alpha)
                                 + style_data[min(end, num_frames - 1)] * alpha)
    return style_data


def compute_weights_block(df, dtype=None):
    """Turn the sword-block notify into smooth per-frame example weights.

    Block flags are scaled by 10, convolved with a 90-tap Gaussian window
    (L2-normalized) and offset by 1, so frames far from any block get weight 1.

    Args:
        df: frame DataFrame.
        dtype: dtype of the result; defaults to float32 (checked with ``is None``,
            since ``np.dtype`` objects can be falsy).

    Returns:
        Array of ``len(df)`` weights, or None when SWORD_BLOCK is not a column.
    """
    if dtype is None:
        dtype = np.float32
    if SWORD_BLOCK not in df:
        return None
    weights = df[SWORD_BLOCK].values.astype(dtype)
    weights *= 10
    if len(weights) == 0:
        return weights + 1
    smooth_win = np.linspace(-3, 3, 90, dtype=dtype)
    smooth_win = np.exp(-np.square(smooth_win))
    smooth_win /= np.linalg.norm(smooth_win)
    # np.convolve(mode='same') returns max(len(a), len(v)) samples, i.e. 90 weights for
    # clips shorter than the window. Take the centered len(weights)-long slice of the
    # full convolution instead (identical to 'same' whenever len(weights) >= 90).
    full = np.convolve(weights, smooth_win, mode='full')
    start = (len(smooth_win) - 1) // 2
    weights = full[start:start + len(weights)]
    weights += 1
    return weights


def compute_distance_weights(distances, max_weight=6, min_weight=1):
    """Weight examples inversely to their (max-normalized) distance.

    Distances are divided by their maximum (at least 1e-6), clipped to
    ``[1 / max_weight, 1 / min_weight]`` and inverted, so the closest examples get
    ``max_weight`` and the farthest get ``min_weight``.

    Args:
        distances: array-like of non-negative distances (may be empty).
        max_weight: largest weight produced (default 6).
        min_weight: smallest weight produced (default 1).

    Returns:
        Array of weights with the same shape as ``distances``.

    Raises:
        ValueError: unless ``0 < min_weight <= max_weight``.
    """
    if not 0 < min_weight <= max_weight:
        raise ValueError('Need 0 < min_weight <= max_weight, got {} and {}.'.format(min_weight, max_weight))
    distances = np.asarray(distances)
    # .max() of an empty array raises, so use a neutral scale for empty input
    scale = max(distances.max(), 1e-6) if distances.size else 1.0
    normalized = np.clip(distances / scale, 1 / max_weight, 1 / min_weight)
    return 1 / normalized


def make_examples(root_transforms, joint_transforms, sword_root_transforms,
                  npc_sword_tip_tx, player_sword_tip_tx, style,
                  trajectory_window, trajectory_step_size,
                  first_idx=None, last_idx=None, dtype=None):
    """Build per-frame (input, output) feature vectors.

    For every usable frame ``f`` the input holds, in order:
      * NPC and player sword-tip trajectory positions and directions
        (``trajectory_window`` samples each, relative to ``sword_root_transforms``);
      * the closest points between the two swords and their distance;
      * four player-sword angles;
      * the style vector;
      * root-relative joint positions and velocities.

    The output holds:
      * the root XY translation and the XY-plane angle of the root X axis, from
        frame ``f`` to ``f + 1``;
      * the next frame's joint positions and velocities;
      * per-joint X/Y axis directions.

    Args:
        root_transforms: per-frame (4, 4) root transforms.
        joint_transforms: transforms indexed ``[frame, joint, 4, 4]``.
        sword_root_transforms: per-frame reference frame for the sword trajectories.
        npc_sword_tip_tx, player_sword_tip_tx: per-frame sword tip transforms.
        style: per-frame style vectors (``len(STYLES)`` values each).
        trajectory_window, trajectory_step_size: trajectory sampling parameters.
        first_idx, last_idx: optional frame range. It is shrunk so every example has a
            full trajectory window, a previous frame and a next frame.
        dtype: dtype of the returned arrays (default ``root_transforms.dtype``; checked
            with ``is None`` because ``np.dtype`` objects can be falsy).

    Returns:
        ``(inputs, outputs, angle, coords)``:
          * inputs of shape ``(num_examples, input_size)``;
          * outputs of shape ``(num_examples, output_size)``;
          * the player-tip angle, wrapped to ``[0, 2*pi)``;
          * normalized player-tip grid coordinates, shape ``(num_examples, 3)``.
    """
    if dtype is None:
        dtype = root_transforms.dtype
    num_joints = joint_transforms.shape[1]
    num_frames = len(root_transforms)
    # Frames covered by one trajectory window: get_trajectories samples
    # f - (window - 1) * step, ..., f and pads earlier frames with identity transforms.
    # (The previous formula window * (step - 1) + 1 only matched this for window == step.)
    trajectory_size = (trajectory_window - 1) * trajectory_step_size + 1
    # Check indices
    if first_idx is None: first_idx = 0
    if last_idx is None: last_idx = num_frames - 1
    # Need a full trajectory window and a previous frame (velocities)...
    first_idx = max(first_idx + 1, trajectory_size)
    # ...and a next frame (outputs)
    last_idx = min(last_idx, num_frames - 2)
    num_examples = max(last_idx - first_idx + 1, 0)
    # Input size
    input_npc_sword_pos_idx = 0
    input_npc_sword_pos_end = input_npc_sword_pos_idx + 3 * trajectory_window
    input_npc_sword_dir_idx = input_npc_sword_pos_end
    input_npc_sword_dir_end = input_npc_sword_dir_idx + 3 * trajectory_window
    input_player_sword_pos_idx = input_npc_sword_dir_end
    input_player_sword_pos_end = input_player_sword_pos_idx + 3 * trajectory_window
    input_player_sword_dir_idx = input_player_sword_pos_end
    input_player_sword_dir_end = input_player_sword_dir_idx + 3 * trajectory_window
    input_npc_close_pos_idx = input_player_sword_dir_end
    input_npc_close_pos_end = input_npc_close_pos_idx + 3
    input_player_close_pos_idx = input_npc_close_pos_end
    input_player_close_pos_end = input_player_close_pos_idx + 3
    input_sword_close_dist_idx = input_player_close_pos_end
    input_sword_close_dist_end = input_sword_close_dist_idx + 1
    input_player_tip_elevation_start = input_sword_close_dist_end
    input_player_tip_elevation_end = input_player_tip_elevation_start + 1
    input_player_tip_azimuth_start = input_player_tip_elevation_end
    input_player_tip_azimuth_end = input_player_tip_azimuth_start + 1
    input_player_hilt_elevation_start = input_player_tip_azimuth_end
    input_player_hilt_elevation_end = input_player_hilt_elevation_start + 1
    input_player_hilt_azimuth_start = input_player_hilt_elevation_end
    input_player_hilt_azimuth_end = input_player_hilt_azimuth_start + 1
    input_style_idx = input_player_hilt_azimuth_end
    input_style_end = input_style_idx + len(STYLES)
    input_joint_pos_idx = input_style_end
    input_joint_pos_end = input_joint_pos_idx + 3 * num_joints
    input_joint_vel_idx = input_joint_pos_end
    input_joint_vel_end = input_joint_vel_idx + 3 * num_joints
    input_size = input_joint_vel_end
    # Output size
    output_root_vel_idx = 0
    output_root_vel_end = output_root_vel_idx + 2
    output_root_angvel_idx = output_root_vel_end
    output_root_angvel_end = output_root_angvel_idx + 1
    output_joint_pos_idx = output_root_angvel_end
    output_joint_pos_end = output_joint_pos_idx + 3 * num_joints
    output_joint_vel_idx = output_joint_pos_end
    output_joint_vel_end = output_joint_vel_idx + 3 * num_joints
    output_joint_xy_idx = output_joint_vel_end
    output_joint_xy_end = output_joint_xy_idx + 6 * num_joints
    output_size = output_joint_xy_end
    # Make example arrays
    inputs = np.empty((num_examples, input_size), dtype=dtype)
    outputs = np.empty((num_examples, output_size), dtype=dtype)
    angle = np.empty((num_examples,), dtype=dtype)
    coords = np.empty((num_examples, 3), dtype=dtype)
    if num_examples <= 0:
        return inputs, outputs, angle, coords
    # Root inverse and relative joint transforms
    root_inv = np.linalg.inv(root_transforms)
    joint_rootrel = root_inv[:, np.newaxis] @ joint_transforms
    joint_rootrel /= joint_rootrel[:, :, 3:, 3:]
    # Get sword trajectories
    npc_sword_pos, npc_sword_dir = get_sword_trajectories(
        npc_sword_tip_tx, NPC_SWORD_VECTOR, sword_root_transforms,
        trajectory_window, trajectory_step_size)
    player_sword_pos, player_sword_dir = get_sword_trajectories(
        player_sword_tip_tx, PLAYER_SWORD_VECTOR, sword_root_transforms,
        trajectory_window, trajectory_step_size)
    npc_sword_length = np.linalg.norm(NPC_SWORD_VECTOR)
    player_sword_length = np.linalg.norm(PLAYER_SWORD_VECTOR)
    # Loop invariants: selector for the first two rotation columns, and the X axis
    xy_selector = np.array([[1, 0], [0, 1], [0, 0]])
    x_axis = np.array([1., 0., 0.])
    for i_example in range(num_examples):
        i_frame = i_example + first_idx
        # Sword trajectories
        inputs[i_example, input_npc_sword_pos_idx:input_npc_sword_pos_end] = npc_sword_pos[i_frame].ravel()
        inputs[i_example, input_npc_sword_dir_idx:input_npc_sword_dir_end] = npc_sword_dir[i_frame].ravel()
        inputs[i_example, input_player_sword_pos_idx:input_player_sword_pos_end] = player_sword_pos[i_frame].ravel()
        inputs[i_example, input_player_sword_dir_idx:input_player_sword_dir_end] = player_sword_dir[i_frame].ravel()
        # Sword close points (each sword is the segment from its tip to its hilt at frame f)
        npc_tip_pos = npc_sword_pos[i_frame, -1]
        npc_hilt_pos = npc_tip_pos + npc_sword_length * npc_sword_dir[i_frame, -1]
        player_tip_pos = player_sword_pos[i_frame, -1]
        player_hilt_pos = player_tip_pos + player_sword_length * player_sword_dir[i_frame, -1]
        npc_close_point, player_close_point = SegmentsClosestPoints(
            npc_tip_pos, npc_hilt_pos, player_tip_pos, player_hilt_pos)
        close_dist = np.linalg.norm(player_close_point - npc_close_point)
        inputs[i_example, input_npc_close_pos_idx:input_npc_close_pos_end] = npc_close_point
        inputs[i_example, input_player_close_pos_idx:input_player_close_pos_end] = player_close_point
        inputs[i_example, input_sword_close_dist_idx:input_sword_close_dist_end] = close_dist
        # Player sword angles
        inputs[i_example, input_player_tip_elevation_start] = np.arctan2(player_tip_pos[..., 2], player_tip_pos[..., 0])
        inputs[i_example, input_player_tip_azimuth_start] = np.arctan2(player_tip_pos[..., 1], player_tip_pos[..., 0])
        inputs[i_example, input_player_hilt_elevation_start] = np.arctan2(player_hilt_pos[..., 2], player_hilt_pos[..., 0])
        inputs[i_example, input_player_hilt_azimuth_start] = np.arctan2(player_hilt_pos[..., 1], player_hilt_pos[..., 0])
        # Example angle
        angle[i_example] = np.arctan2(player_tip_pos[..., 1], player_tip_pos[..., 2])
        # Example grid coords
        coords[i_example, 0] = (player_tip_pos[..., 0] - COORD_DEPTH_RANGE[0]) / (COORD_DEPTH_RANGE[1] - COORD_DEPTH_RANGE[0])
        coords[i_example, 1] = player_tip_pos[..., 1] / COORD_WIDTH + 0.5
        coords[i_example, 2] = player_tip_pos[..., 2] / COORD_HEIGHT + 0.5
        # Style
        inputs[i_example, input_style_idx:input_style_end] = style[i_frame].ravel()
        # Joints
        joint_prev, joint_cur, joint_next = joint_rootrel[i_frame - 1:i_frame + 2]
        inputs[i_example, input_joint_pos_idx:input_joint_pos_end] = (joint_cur[:, :3, 3]).ravel()
        inputs[i_example, input_joint_vel_idx:input_joint_vel_end] = (joint_cur[:, :3, 3] - joint_prev[:, :3, 3]).ravel()
        outputs[i_example, output_joint_pos_idx:output_joint_pos_end] = joint_next[:, :3, 3].ravel()
        outputs[i_example, output_joint_vel_idx:output_joint_vel_end] = (joint_next[:, :3, 3] - joint_cur[:, :3, 3]).ravel()
        # Compute X and Y directions for each joint
        joint_xy = joint_next[:, :3, :3] @ xy_selector
        joint_xy = np.transpose(joint_xy, (0, 2, 1))
        outputs[i_example, output_joint_xy_idx:output_joint_xy_end] = joint_xy.ravel()
        # Output updates: root motion from frame f to f + 1 expressed in frame f's root
        root_update_tx = root_inv[i_frame] @ root_transforms[i_frame + 1]
        root_update_tx /= root_update_tx[3:, 3:]
        root_vel = root_update_tx[:2, 3]
        outputs[i_example, output_root_vel_idx:output_root_vel_end] = root_vel
        root_angvel_x, root_angvel_y = (root_update_tx[:3, :3] @ x_axis)[:2]
        outputs[i_example, output_root_angvel_idx:output_root_angvel_end] = np.arctan2(root_angvel_y, root_angvel_x)
    # Make sure angle is in [0, 2 * pi)
    angle = (angle + 2 * np.pi) % (2 * np.pi)
    return inputs, outputs, angle, coords


def float_feature(value):
    """Wrap a scalar or array of numbers in a ``tf.train.Feature`` holding a FloatList."""
    import tensorflow as tf
    float_list = tf.train.FloatList(value=np.atleast_1d(value))
    return tf.train.Feature(float_list=float_list)


def write_records(inputs, outputs, weights, phase, coords, output_dir, file_name):
    """Write one ``tf.train.Example`` per row to ``<output_dir>/<file_name>.tfrecord``.

    Each example has float features 'x' (input row) and 'y' (output row), plus
    'weight', 'phase' and 'coords' for those arrays that are not None.

    Args:
        inputs, outputs: row-aligned example arrays.
        weights, phase, coords: optional per-example arrays (or None).
        output_dir: destination directory, created if missing.
        file_name: record file stem.

    Raises:
        ValueError: if ``output_dir`` exists but is not a directory.
    """
    import tensorflow as tf
    output_dir = Path(output_dir)
    if output_dir.exists() and not output_dir.is_dir():
        raise ValueError('Invalid directory {}.'.format(output_dir))
    output_dir.mkdir(parents=True, exist_ok=True)
    output_dir = output_dir.resolve()
    output_path = output_dir / '{}.tfrecord'.format(file_name)
    # tf.python_io does not exist in TensorFlow 2.x; tf.io.TFRecordWriter is available
    # in newer 1.x releases and in 2.x, so prefer it and fall back for old versions.
    writer_cls = getattr(getattr(tf, 'io', None), 'TFRecordWriter', None)
    if writer_cls is None:
        writer_cls = tf.python_io.TFRecordWriter
    # Write records
    with writer_cls(str(output_path)) as writer:
        for i_example, (input_item, output_item) in enumerate(zip(inputs, outputs)):
            feature = {
                'x': float_feature(input_item),
                'y': float_feature(output_item),
            }
            if weights is not None:
                feature['weight'] = float_feature(weights[i_example])
            if phase is not None:
                feature['phase'] = float_feature(phase[i_example])
            if coords is not None:
                feature['coords'] = float_feature(coords[i_example])
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())


def plot_styles(styles):
    """Plot each style channel over frames in a new Matplotlib figure.

    Args:
        styles: array of shape ``(num_frames, len(STYLES))``.

    Raises:
        ValueError: if the number of columns does not match STYLES.
    """
    if styles.shape[1] != len(STYLES):
        raise ValueError('Expected {} style columns, got {}.'.format(len(STYLES), styles.shape[1]))
    # plt was previously used without being imported anywhere in this module
    import matplotlib.pyplot as plt
    plt.figure()
    for i, style in enumerate(STYLES):
        plt.plot(styles[:, i], label=style)
    plt.legend()
    plt.tight_layout()


def plot_axes(tx, pos=None):
    """Draw the X/Y/Z axes (red/green/blue) of ``tx``'s rotation in a new 3D figure.

    Args:
        tx: array whose upper-left 3x3 block is the rotation to draw.
        pos: origin of the drawn axes; defaults to the origin.
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers '3d' projection on old Matplotlib
    if pos is None: pos = np.zeros(3, dtype=tx.dtype)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    rotation = tx[:3, :3]
    axes = rotation @ np.eye(3, dtype=rotation.dtype)
    for i, axis in enumerate(axes.T):
        ax.plot3D([pos[0], pos[0] + axis[0]],
                  [pos[1], pos[1] + axis[1]],
                  [pos[2], pos[2] + axis[2]],
                  c='rgb'[i])
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_xlim(-1, 1); ax.set_xticks([-1, 0, 1])
    ax.set_ylim(-1, 1); ax.set_yticks([-1, 0, 1])
    ax.set_zlim(-1, 1); ax.set_zticks([-1, 0, 1])
    ax.invert_yaxis()
    try:
        ax.set_aspect('equal')
    except NotImplementedError:
        # Matplotlib 3.1-3.5 reject 'equal' aspect on 3D axes; an equal box aspect
        # (available from 3.3) is the closest substitute.
        if hasattr(ax, 'set_box_aspect'):
            ax.set_box_aspect((1, 1, 1))


if __name__ == '__main__':
    # Command line: <data dir> <output dir> [time scale (default 1.0)]
    if len(sys.argv) < 3:
        print('Usage: {} <data dir> <output dir> [time scale]'.format(sys.argv[0]), file=sys.stderr)
        # sys.exit rather than the site-provided exit(), which is absent under `python -S`
        sys.exit(1)
    time_scale = float(sys.argv[3]) if len(sys.argv) > 3 else 1.0
    # Reject non-positive / NaN scales up front instead of failing inside a pool worker
    if not time_scale > 0:
        print('Time scale must be positive, got {}.'.format(time_scale), file=sys.stderr)
        sys.exit(1)
    main(sys.argv[1], sys.argv[2], time_scale)
