#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import random
import math
from pathlib import Path
import argparse

import numpy as np


# Module-level configuration. parse_args() overwrites every setting below
# except SPLINE_CLOSED with command-line values; main() reads them.
SEED = 100
SPLINE_POINTS = None  # list of ints once parse_args() has run ([] means none)
SPLINE_CLOSED = False  # NOTE(review): not read by main(), which derives closure from the spline mode
HIDDEN_LAYERS = None  # list of hidden-layer sizes once parse_args() has run
USE_BIAS = True
MAX_STEPS = 1_000_000
BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 200_000
EVALUATION_STEPS = 10_000
PATIENCE = -1  # negative means no early-stopping limit (per the --patience help)
DROPOUT_KEEP_PROB = 0.7  # the CLI takes the complementary dropout rate instead
REGULARIZATION = 0.1
LEARNING_RATE = 0.0001
LINEAR_EVALUATIONS = 0  # replaced by a list of ints in parse_args(); main() iterates it
CONSTANT_EVALUATIONS = 0  # replaced by a list of ints in parse_args(); main() iterates it

# Data layout settings; all three are required on the command line.
NUM_JOINTS = None
TRAJECTORY_WINDOW = None
NUM_STYLES = None


def main(train_data_dir, test_data_dir, logdir):
    """Train a GFNN on TFRecord data and export its runtime graphs.

    All hyper-parameters are read from the module-level settings, which
    ``parse_args`` fills in from the command line. The function computes
    normalization statistics from the training records and saves them to
    ``logdir``. It then trains the network and writes the cubic runtime graph,
    plus the optional linear/constant variants.

    Args:
        train_data_dir: directory containing the ``*.tfrecord`` training files.
        test_data_dir: directory containing the ``*.tfrecord`` test files.
        logdir: output directory; created (with parents) if it is missing.

    Raises:
        ValueError: if more than three spline points are configured, or if
            ``logdir`` exists but is not a directory.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    # TensorFlow is imported only after TF_CPP_MIN_LOG_LEVEL is set,
    # presumably so that the log level applies when the library loads.
    import tensorflow as tf
    from gfnn import Gfnn

    # Random seed
    random.seed(SEED)
    np.random.seed(SEED)
    tf.set_random_seed(SEED)

    # Choose mode from the number of configured spline points:
    #   0 -> replaced by [1, 1, 1]; enabled linear/constant counts become 1
    #   1 -> records carry a 'phase' feature and the spline is closed
    #   2 -> a leading 1 is prepended (also to enabled eval counts)
    #   3 -> used as given; more than 3 is rejected
    spline_points = list(SPLINE_POINTS)
    # NOTE(review): the module-level SPLINE_CLOSED setting is not consulted.
    spline_closed = False
    linear_eval = LINEAR_EVALUATIONS
    constant_eval = CONSTANT_EVALUATIONS
    parse_example_fn = None
    preprocess_fn = None
    # Any non-positive linear count disables the linear export.
    if any(v <= 0 for v in linear_eval):
        linear_eval = 0
    # NOTE(review): only exact zeros disable the constant export; negative
    # counts pass through, unlike linear_eval -- confirm this is intended.
    if any(v == 0 for v in constant_eval):
        constant_eval = 0
    if len(spline_points) == 0:
        spline_points = [1, 1, 1]
        if linear_eval:
            linear_eval = 1
        if constant_eval:
            constant_eval = 1
    if len(spline_points) == 1:
        parse_example_fn = lambda self, example_proto: parse_example_phase(example_proto, self.input_size, self.output_size)
        # Wrap the phase angle into [0, 1) as a fraction of a full turn.
        preprocess_fn = lambda self, input_, coords: (input_, (coords % (2 * math.pi)) / (2 * math.pi))
        spline_closed = True
    if len(spline_points) == 2:
        spline_points.insert(0, 1)
        if linear_eval:
            linear_eval = [1] + list(np.broadcast_to(linear_eval, 2))
        if constant_eval:
            constant_eval = [1] + list(np.broadcast_to(constant_eval, 2))
    if len(spline_points) > 3:
        raise ValueError('invalid number of spline points')

    # Input layout: consecutive [idx, end) ranges inside the x vector.
    input_npc_sword_pos_idx = 0
    input_npc_sword_pos_end = input_npc_sword_pos_idx + 3 * TRAJECTORY_WINDOW
    input_npc_sword_dir_idx = input_npc_sword_pos_end
    input_npc_sword_dir_end = input_npc_sword_dir_idx + 3 * TRAJECTORY_WINDOW
    input_player_sword_pos_idx = input_npc_sword_dir_end
    input_player_sword_pos_end = input_player_sword_pos_idx + 3 * TRAJECTORY_WINDOW
    input_player_sword_dir_idx = input_player_sword_pos_end
    input_player_sword_dir_end = input_player_sword_dir_idx + 3 * TRAJECTORY_WINDOW
    input_npc_close_pos_idx = input_player_sword_dir_end
    input_npc_close_pos_end = input_npc_close_pos_idx + 3
    input_player_close_pos_idx = input_npc_close_pos_end
    input_player_close_pos_end = input_player_close_pos_idx + 3
    input_sword_close_dist_idx = input_player_close_pos_end
    input_sword_close_dist_end = input_sword_close_dist_idx + 1
    input_player_tip_elevation_start = input_sword_close_dist_end
    input_player_tip_elevation_end = input_player_tip_elevation_start + 1
    input_player_tip_azimuth_start = input_player_tip_elevation_end
    input_player_tip_azimuth_end = input_player_tip_azimuth_start + 1
    input_player_hilt_elevation_start = input_player_tip_azimuth_end
    input_player_hilt_elevation_end = input_player_hilt_elevation_start + 1
    input_player_hilt_azimuth_start = input_player_hilt_elevation_end
    input_player_hilt_azimuth_end = input_player_hilt_azimuth_start + 1
    input_style_idx = input_player_hilt_azimuth_end
    input_style_end = input_style_idx + NUM_STYLES
    input_joint_pos_idx = input_style_end
    input_joint_pos_end = input_joint_pos_idx + 3 * NUM_JOINTS
    input_joint_vel_idx = input_joint_pos_end
    input_joint_vel_end = input_joint_vel_idx + 3 * NUM_JOINTS
    input_size = input_joint_vel_end
    # Output layout: consecutive [idx, end) ranges inside the y vector.
    output_root_vel_idx = 0
    output_root_vel_end = output_root_vel_idx + 2
    output_root_angvel_idx = output_root_vel_end
    output_root_angvel_end = output_root_angvel_idx + 1
    output_joint_pos_idx = output_root_angvel_end
    output_joint_pos_end = output_joint_pos_idx + 3 * NUM_JOINTS
    output_joint_vel_idx = output_joint_pos_end
    output_joint_vel_end = output_joint_vel_idx + 3 * NUM_JOINTS
    output_joint_xy_idx = output_joint_vel_end
    output_joint_xy_end = output_joint_xy_idx + 6 * NUM_JOINTS
    output_size = output_joint_xy_end

    logdir = Path(logdir)
    if logdir.exists() and not logdir.is_dir():
        raise ValueError('{} is not a valid directory.'.format(logdir))
    logdir.mkdir(parents=True, exist_ok=True)

    print('Preprocessing data...')
    train_tfrecord_pattern = str(Path(train_data_dir, '*.tfrecord'))
    test_tfrecord_pattern = str(Path(test_data_dir, '*.tfrecord'))
    x_mean, x_scale, y_mean, y_scale = compute_mean_std_records(train_tfrecord_pattern, input_size, output_size)

    def _pool_scale(scale, groups):
        # Give every feature in a [start, end) group the group's mean scale,
        # in place; empty groups (e.g. NUM_STYLES == 0) are skipped.
        for start, end in groups:
            if end > start:
                scale[start:end] = scale[start:end].mean()
        # Replace (near-)zero scales with 1, presumably because Gfnn divides
        # by them when normalizing.
        scale[np.abs(scale) < 1e-6] = 1

    # Readjust scaling
    # X
    _pool_scale(x_scale, [
        (input_npc_sword_pos_idx, input_npc_sword_pos_end),
        (input_npc_sword_dir_idx, input_npc_sword_dir_end),
        (input_player_sword_pos_idx, input_player_sword_pos_end),
        (input_player_sword_dir_idx, input_player_sword_dir_end),
        (input_npc_close_pos_idx, input_npc_close_pos_end),
        (input_player_close_pos_idx, input_player_close_pos_end),
        (input_sword_close_dist_idx, input_sword_close_dist_end),
        # Tip/hilt elevation and azimuth share a single scale.
        (input_player_tip_elevation_start, input_player_hilt_azimuth_end),
        (input_style_idx, input_style_end),
        (input_joint_pos_idx, input_joint_pos_end),
        (input_joint_vel_idx, input_joint_vel_end),
    ])
    # Y
    _pool_scale(y_scale, [
        (output_root_vel_idx, output_root_vel_end),
        (output_root_angvel_idx, output_root_angvel_end),
        (output_joint_pos_idx, output_joint_pos_end),
        (output_joint_vel_idx, output_joint_vel_end),
        (output_joint_xy_idx, output_joint_xy_end),
    ])
    # Save to file
    x_scale.tofile(str(Path(logdir, 'x_scale.txt')), sep='\n')
    x_mean.tofile(str(Path(logdir, 'x_mean.txt')), sep='\n')
    y_scale.tofile(str(Path(logdir, 'y_scale.txt')), sep='\n')
    y_mean.tofile(str(Path(logdir, 'y_mean.txt')), sep='\n')

    # Alternative preprocess_scaled_fn options; none is currently passed to
    # Gfnn (preprocess_scaled_fn=None below).

    # Preprocessing to discard the NPC and player sword tip positions
    def preprocess_scaled_discard_swordtip(self, input_):
        return tf.concat([
            input_[..., :input_npc_sword_pos_idx],
            input_[..., input_npc_sword_dir_idx:input_player_sword_pos_idx],
            input_[..., input_player_sword_dir_idx:],
        ], axis=-1)

    # Preprocessing to discard joint velocity and sword tip positions
    def preprocess_scaled_discard_swordtip_velocity(self, input_):
        return tf.concat([
            input_[..., :input_npc_sword_pos_idx],
            input_[..., input_npc_sword_dir_idx:input_player_sword_pos_idx],
            input_[..., input_player_sword_dir_idx:input_joint_vel_idx],
            input_[..., input_joint_vel_end:],
        ], axis=-1)

    # Preprocessing to discard the joint positions and velocities, which are
    # the trailing input features. Slices by the *input* layout.
    def preprocess_scaled_discard_joints(self, input_):
        return input_[..., :input_joint_pos_idx]

    # Postprocessing to normalize X and Y axes: the xy block is viewed as two
    # 3-vectors per joint, and each vector is scaled to unit length.
    def postprocess_xy(self, output):
        output_xy = output[:, output_joint_xy_idx:output_joint_xy_end]
        output_xy = tf.reshape(output_xy, (-1, NUM_JOINTS, 2, 3))
        # l2_normalize clamps the squared norm with a small epsilon, so an
        # all-zero vector yields zeros instead of NaN (0 / 0).
        output_xy = tf.nn.l2_normalize(output_xy, axis=-1)
        output_xy = tf.reshape(output_xy, (-1, NUM_JOINTS * 6))
        return tf.concat([output[:, :output_joint_xy_idx], output_xy, output[:, output_joint_xy_end:]], axis=-1)

    print('Creating model in {}...'.format(logdir))
    net = Gfnn(input_size, output_size, HIDDEN_LAYERS,
               spline_points, spline_closed,
               logdir, dtype=tf.float32, use_bias=USE_BIAS, regularization=REGULARIZATION,
               dropout=1 - DROPOUT_KEEP_PROB,
               input_mean=x_mean, input_scale=x_scale,
               output_mean=y_mean, output_scale=y_scale,
               parse_example_fn=parse_example_fn,
               preprocess_fn=preprocess_fn,
               preprocess_scaled_fn=None,
               postprocess_fn=postprocess_xy,
               devices=get_devices(),
               seed=SEED)

    # Write model summary
    with Path(logdir, 'model_summary.txt').open('w') as f:
        f.write(f'input_size: {net.input_size}\n')
        f.write(f'output_size: {net.output_size}\n')
        f.write(f'spline_points: {net.spline_points}\n')
        f.write(f'spline_close: {net.spline_close}\n')
        f.write(f'hidden_layers: {net.hidden_layers}\n')
        f.write(f'use_bias: {net.use_bias}\n')
        f.write(f'num_params: {net.num_params}\n')
        f.write(f'linear_eval: {linear_eval}\n')
        f.write(f'constant_eval: {constant_eval}\n')

    print('Training...')
    net.train(training_record_file_pattern=train_tfrecord_pattern,
              test_record_file_pattern=test_tfrecord_pattern,
              max_steps=MAX_STEPS,
              batch_size=BATCH_SIZE,
              shuffle_buffer_size=SHUFFLE_BUFFER_SIZE,
              learning_rate=LEARNING_RATE,
              evaluation_steps=EVALUATION_STEPS,
              patience=PATIENCE)
    # Export the full graph, plus the linear/constant variants when enabled.
    tf.train.write_graph(net.make_runtime_graph(full=True, linear=False, constant=False),
                         net.logdir, 'gfnn_cubic.tf', as_text=False)
    if linear_eval:
        tf.train.write_graph(net.make_runtime_graph(full=False, linear=linear_eval, constant=False),
                             net.logdir, 'gfnn_linear.tf', as_text=False)
    if constant_eval:
        tf.train.write_graph(net.make_runtime_graph(full=False, linear=False, constant=constant_eval),
                             net.logdir, 'gfnn_constant.tf', as_text=False)
    print('Done.')

def get_devices():
    """Return the names of the local TensorFlow devices to train on.

    GPU devices are preferred. When no GPU is visible, the names of all
    local devices are returned instead.
    """
    from tensorflow.python.client import device_lib
    local = device_lib.list_local_devices()
    gpus = [dev for dev in local if dev.device_type == 'GPU']
    chosen = gpus if gpus else local
    return [dev.name for dev in chosen]

def compute_mean_std_records(record_file_pattern, x_size, y_size):
    """Compute per-feature mean and standard deviation over TFRecord files.

    Streams every record matching ``record_file_pattern`` through a private
    TensorFlow graph in batches of 10000. The statistics are accumulated with
    ``running_moments``.

    Args:
        record_file_pattern: glob pattern of the TFRecord files to read.
        x_size: length of the ``'x'`` feature vector in each record.
        y_size: length of the ``'y'`` feature vector in each record.

    Returns:
        Tuple ``(x_mean, x_std, y_mean, y_std)`` as evaluated by
        ``sess.run`` (NumPy arrays).
    """
    import tensorflow as tf
    # os.cpu_count() returns None when the count cannot be determined, but
    # cycle_length needs a positive integer.
    num_cpus = os.cpu_count() or 1
    with tf.Graph().as_default(), tf.Session() as sess:
        dataset = tf.data.Dataset.list_files(record_file_pattern)
        dataset = dataset.apply(tf.data.experimental.parallel_interleave(
            tf.data.TFRecordDataset, cycle_length=num_cpus, sloppy=False))
        dataset = dataset.map(lambda example_proto: parse_example(example_proto, x_size, y_size),
                              num_parallel_calls=num_cpus)
        dataset = dataset.batch(10000)
        dataset = dataset.prefetch(1)
        it = dataset.make_one_shot_iterator()
        x, y = it.get_next()
        (x_mean, x_std), x_upd = running_moments(x)
        (y_mean, y_std), y_upd = running_moments(y)
        sess.run(tf.global_variables_initializer())
        # Consume the whole dataset, folding each batch into the moments.
        while True:
            try:
                sess.run([x_upd, y_upd])
            except tf.errors.OutOfRangeError:
                break
        return sess.run((x_mean, x_std, y_mean, y_std))

def running_moments(data):
    """Build graph ops that accumulate per-column mean and std across batches.

    This is a batched Welford update. Each run of ``update`` folds the current
    ``data`` batch into a running mean ``m``, a running sum of squared
    deviations ``s``, and a sample count ``k``.

    Args:
        data: float32 tensor indexed as (batch, feature); the feature
            dimension (``data.shape[1]``) must be statically known.

    Returns:
        ``((mean, std), update)``. ``mean`` and ``std`` read the current
        statistics; ``std`` is ``sqrt(s / k)``, the population standard
        deviation. ``update`` consumes one batch. The created variables
        must be initialized before ``update`` runs.
    """
    import tensorflow as tf
    # Welford algorithm
    size = data.shape[1].value
    first = tf.Variable(True, dtype=tf.bool, trainable=False)
    m = tf.Variable([0.] * size, dtype=tf.float32, trainable=False)
    s = tf.Variable([0.] * size, dtype=tf.float32, trainable=False)
    k = tf.Variable(0, dtype=tf.float32, trainable=False)
    d = tf.cast(tf.shape(data)[0], tf.float32)
    k_new = k + d
    data_mean = tf.reduce_mean(data, axis=0)
    # Previous mean; on the very first batch the batch mean is used instead.
    m_val = tf.cond(first, lambda: data_mean, lambda: m)
    m_new = m_val + d * (data_mean - m_val) / k_new
    # Use m_val rather than a fresh read of ``m``. Such a read would not be
    # ordered against tf.assign(m, m_new) below and could observe the
    # updated mean. m_val feeds m_new, so it is always read before the
    # assignment. On the first batch it also avoids cancellation against 0.
    s_new = s + tf.abs(tf.reduce_sum((data - m_val) * (data - m_new), axis=0))
    with tf.control_dependencies([tf.assign(m, m_new), tf.assign(s, s_new), tf.assign(k, k_new)]):
        update = tf.group([tf.assign(first, False)])
    mean = tf.identity(m)
    std = tf.sqrt(s / k)
    return (mean, std), update

def parse_example(example_proto, x_size, y_size):
    """Decode one serialized example into its ``x`` and ``y`` float vectors.

    Args:
        example_proto: serialized example (string tensor) to decode.
        x_size: number of float32 values stored under the ``'x'`` key.
        y_size: number of float32 values stored under the ``'y'`` key.

    Returns:
        Tuple ``(x, y)`` of float32 tensors of shapes ``(x_size,)`` and
        ``(y_size,)``.
    """
    import tensorflow as tf
    x_spec = tf.FixedLenFeature((x_size,), tf.float32)
    y_spec = tf.FixedLenFeature((y_size,), tf.float32)
    decoded = tf.parse_single_example(example_proto, {'x': x_spec, 'y': y_spec})
    x_vec = decoded['x']
    y_vec = decoded['y']
    return x_vec, y_vec

def parse_example_phase(example_proto, input_size, output_size):
    """Decode one serialized example that also carries a phase and a weight.

    Args:
        example_proto: serialized example (string tensor) to decode.
        input_size: number of float32 values under the ``'x'`` key.
        output_size: number of float32 values under the ``'y'`` key.

    Returns:
        Tuple ``(x, phase, y, weight)``. ``phase`` is the scalar ``'phase'``
        feature expanded to a length-1 vector; ``weight`` defaults to 1.0
        when the record has no ``'weight'`` feature.
    """
    import tensorflow as tf
    spec = {
        'x': tf.FixedLenFeature((input_size,), tf.float32),
        'phase': tf.FixedLenFeature((), tf.float32),
        'y': tf.FixedLenFeature((output_size,), tf.float32),
        'weight': tf.FixedLenFeature((), tf.float32, default_value=1.),
    }
    decoded = tf.parse_single_example(example_proto, spec)
    # Add a leading axis so the scalar phase becomes a one-element vector.
    phase_vec = tf.expand_dims(decoded['phase'], 0)
    x_vec, y_vec, weight = decoded['x'], decoded['y'], decoded['weight']
    return x_vec, phase_vec, y_vec, weight

def parse_args(args=None):
    """Parse the command line and store the settings in module globals.

    Every module-level training/model setting used by ``main`` is assigned
    here, from the parsed value or from its command-line default.

    Args:
        args: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Tuple ``(train_data_dir, test_data_dir, log_dir)`` of strings.

    Exits (via ``parser.error``) when arguments are missing or invalid,
    including a ``--dropout`` rate outside ``[0, 1)``.
    """
    global SPLINE_POINTS, HIDDEN_LAYERS, USE_BIAS
    global NUM_JOINTS, TRAJECTORY_WINDOW, NUM_STYLES
    global MAX_STEPS, BATCH_SIZE, SHUFFLE_BUFFER_SIZE, EVALUATION_STEPS, PATIENCE
    global DROPOUT_KEEP_PROB, REGULARIZATION, LEARNING_RATE
    global LINEAR_EVALUATIONS, CONSTANT_EVALUATIONS
    global SEED

    parser = argparse.ArgumentParser(usage='%(prog)s [-h] <argument> ...',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                     argument_default=argparse.SUPPRESS)
    dir_group = parser.add_argument_group('directories')
    dir_group.add_argument('--train-data-dir', help='directory with *.tfrecord files for training', required=True)
    dir_group.add_argument('--test-data-dir', help='directory with *.tfrecord files for test', required=True)
    dir_group.add_argument('--log-dir', help='directory where training results are saved', required=True)
    data_group = parser.add_argument_group('data features')
    data_group.add_argument('--num-joints', help='number of skeletal joints', type=int, required=True)
    data_group.add_argument('--trajectory-window', help='size of trajectory window', type=int, required=True)
    data_group.add_argument('--num-styles', help='number of style labels', type=int, required=True)
    model_group = parser.add_argument_group('model architecture')
    model_group.add_argument('--spline-points', help='number of spline points', metavar='NUM_POINTS[,...]', type=num_list, default='none')
    model_group.add_argument('--hidden-layers', help='hidden layer sizes', metavar='SIZE[,...]', type=num_list, default='none')
    model_group.add_argument('--use-bias', help='whether to use bias units', type=str2bool, default=str(USE_BIAS).lower())
    train_group = parser.add_argument_group('training settings')
    train_group.add_argument('--batch-size', help='training batch size.', type=int, default=BATCH_SIZE)
    train_group.add_argument('--max-steps', help='maximum number of training steps', type=int, default=MAX_STEPS)
    train_group.add_argument('--patience', help='number of evaluations without improvement before stopping (negative for no limit)', type=int, default=PATIENCE)
    train_group.add_argument('--evaluation-steps', help='number of steps per evaluation', type=int, default=EVALUATION_STEPS)
    train_group.add_argument('--shuffle-buffer-size', help='size of the data shuffling buffer', type=int, default=SHUFFLE_BUFFER_SIZE)
    train_group.add_argument('--learning-rate', help='learning rate', type=float, default=LEARNING_RATE)
    train_group.add_argument('--regularization', help='regularization weight', type=float, default=REGULARIZATION)
    train_group.add_argument('--dropout', help='dropout rate', type=float, default=round(1 - DROPOUT_KEEP_PROB, 6))
    export_group = parser.add_argument_group('export settings')
    export_group.add_argument('--linear-eval', help='number of linear evaluations', metavar='NUM_EVAL[,...]',
                              type=num_list, default=str(LINEAR_EVALUATIONS))
    export_group.add_argument('--constant-eval', help='number of constant evaluations', metavar='NUM_EVAL[,...]',
                              type=num_list, default=str(CONSTANT_EVALUATIONS))
    other_group = parser.add_argument_group('other options')
    other_group.add_argument('--seed', help='random seed', type=int, default=SEED)
    parsed_args = parser.parse_args(args=args)

    # The keep probability is derived as 1 - dropout, so a rate outside
    # [0, 1) would give an invalid keep probability.
    if not 0 <= parsed_args.dropout < 1:
        parser.error('--dropout must be in the range [0, 1).')

    train_data_dir = parsed_args.train_data_dir
    test_data_dir = parsed_args.test_data_dir
    log_dir = parsed_args.log_dir
    SPLINE_POINTS = parsed_args.spline_points or []
    HIDDEN_LAYERS = parsed_args.hidden_layers or []
    USE_BIAS = parsed_args.use_bias
    NUM_JOINTS = parsed_args.num_joints
    TRAJECTORY_WINDOW = parsed_args.trajectory_window
    NUM_STYLES = parsed_args.num_styles
    MAX_STEPS = parsed_args.max_steps
    BATCH_SIZE = parsed_args.batch_size
    SHUFFLE_BUFFER_SIZE = parsed_args.shuffle_buffer_size
    EVALUATION_STEPS = parsed_args.evaluation_steps
    PATIENCE = parsed_args.patience
    DROPOUT_KEEP_PROB = 1 - parsed_args.dropout
    REGULARIZATION = parsed_args.regularization
    LEARNING_RATE = parsed_args.learning_rate
    LINEAR_EVALUATIONS = parsed_args.linear_eval
    CONSTANT_EVALUATIONS = parsed_args.constant_eval
    SEED = parsed_args.seed

    return train_data_dir, test_data_dir, log_dir

def num_list(v):
    """Parse a comma-separated list of integers (argparse ``type`` callback).

    A blank string or ``'none'`` (any case, surrounding whitespace ignored)
    yields ``[]``.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not a string or any item is
            not an integer.
    """
    try:
        v = v.strip().lower()
        if not v or v == 'none':
            return []
        return [int(item) for item in v.split(',')]
    except (AttributeError, TypeError, ValueError) as err:
        # Only parse failures are translated; interrupts and other errors
        # are no longer swallowed by a bare except.
        raise argparse.ArgumentTypeError('expected a comma-separated list of numbers.') from err

def str2bool(v):
    """Interpret ``v`` as a boolean (argparse ``type`` callback).

    The value is stringified, trimmed and lower-cased. It is then matched
    against common yes/no spellings.

    Raises:
        argparse.ArgumentTypeError: if the value matches neither spelling set.
    """
    token = str(v).strip().lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError('boolean value expected.')


if __name__ == '__main__':
    # parse_args() fills in the module-level settings and returns the three
    # directory arguments that main() takes positionally.
    main(*parse_args())
