#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import sys
import random
import math
from pathlib import Path
import argparse
import numpy as np

# Configuration. These module-level defaults are overwritten by parse_args()
# and read as globals by main() and several helper functions in this file.

# Random seed for Python `random`, NumPy, TensorFlow and the Mann model.
SEED = 100
# Model architecture (layer-size lists are None until parse_args() sets them).
ENCODER_LAYERS = None  # encoder layer sizes
HIDDEN_LAYERS = None  # hidden layer sizes
DECODER_LAYERS = None  # decoder layer sizes
NUM_GATED = 1  # number of gated parameterisations
GATING_HIDDEN_LAYERS = None  # hidden layer sizes of the gating network
GATING_INDICES = None  # input columns gathered as the gating network input
FUNCTIONED_OUTPUT = True  # use the grid-functioned output layer
USE_BIAS = True  # whether to use bias units
REMOVE_STYLES = False  # drop the style columns from the network input
REMOVE_PHASE = False  # drop the phase values from the network output
# Training settings.
BATCH_SIZE = 32
LEARNING_RATE = 0.0001
MAX_STEPS = 1_000_000
PATIENCE = -1  # evaluations without improvement before stopping; negative means no limit
EVALUATION_STEPS = 10_000  # training steps between evaluations
FREEZE_STEPS = 0  # steps per frozen graph (passed through to Mann.train)
SHUFFLE_BUFFER_SIZE = 200_000
DROPOUT_KEEP_PROB = 1.0  # Mann receives dropout = 1 - DROPOUT_KEEP_PROB
REGULARIZATION = 0.0  # regularization weight
NOISE = 0.0  # scale of Gaussian noise added to input and output for training
USE_WEIGHTS = True  # whether per-example training weights are used

# Data layout (required command-line arguments; None until parse_args() runs).
NUM_FEET = None
NUM_JOINTS = None
PAST_TRAJECTORY_WINDOW = None
FUTURE_TRAJECTORY_WINDOW = None
NUM_STYLES = None
WITH_TERRAIN = None
WITH_CONTACT = None
IMPORTANT_JOINTS = None  # indices of joints whose loss terms are boosted
IMPORTANCE_BOOST = 2  # loss weight of an important joint relative to the other joints
USE_GPU = True  # when False, CUDA devices are hidden before TensorFlow is imported

def main(train_data_dir, test_data_dir, logdir):
    """Train a MANN network on TFRecord data and export its runtime graph.

    Reads its configuration from the module-level globals (normally set by
    ``parse_args``). It computes per-feature normalisation statistics from
    the training records and pools the scales per feature group. It then
    builds a ``mann.Mann`` model with the parse/pre/post-processing and loss
    hooks defined here, and trains it.

    Outputs: ``x_scale.txt``, ``x_mean.txt``, ``y_scale.txt``, ``y_mean.txt``,
    ``model_summary.txt`` and ``command_line.txt`` are written to ``logdir``.
    ``mann_rt.tf`` is written to ``net.logdir``.

    Args:
        train_data_dir: Directory containing the ``*.tfrecord`` training files.
        test_data_dir: Directory containing the ``*.tfrecord`` test files.
        logdir: Output directory; created (with parents) if missing.

    Raises:
        ValueError: If ``logdir`` exists but is not a directory.
    """
    # These environment variables are set before TensorFlow is imported,
    # presumably so that they take effect when TensorFlow initialises.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    if not USE_GPU:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    import tensorflow as tf
    from mann import Mann
    if hasattr(tf, 'get_logger'):
        # Clear the handlers of the TF logger's parent logger
        # (presumably to avoid duplicated log output).
        tf.get_logger().parent.handlers.clear()
    # Random seed
    random.seed(SEED)
    np.random.seed(SEED)
    tf.set_random_seed(SEED)
    trajectory_window = PAST_TRAJECTORY_WINDOW + FUTURE_TRAJECTORY_WINDOW

    # Input (X) layout, as consecutive [start, end) column ranges.
    # Trajectory
    x_trajectory_start = 0
    # Past positions
    x_past_pos_start = x_trajectory_start
    x_past_pos_end = x_past_pos_start + 2 * PAST_TRAJECTORY_WINDOW
    # Future positions
    x_future_pos_start = x_past_pos_end
    x_future_pos_end = x_future_pos_start + 2 * FUTURE_TRAJECTORY_WINDOW
    # Past directions
    x_past_dir_start = x_future_pos_end
    x_past_dir_end = x_past_dir_start + 2 * PAST_TRAJECTORY_WINDOW
    # Future directions
    x_future_dir_start = x_past_dir_end
    x_future_dir_end = x_future_dir_start + 2 * FUTURE_TRAJECTORY_WINDOW
    # Style
    x_style_start = x_future_dir_end
    x_style_end = x_style_start + NUM_STYLES * trajectory_window
    # End of trajectory
    x_trajectory_end = x_style_end
    # Joints
    x_joint_start = x_trajectory_end
    # Position
    x_joint_pos_start = x_joint_start
    x_joint_pos_end = x_joint_pos_start + 3 * NUM_JOINTS
    # Velocity
    x_joint_vel_start = x_joint_pos_end
    x_joint_vel_end = x_joint_vel_start + 3 * NUM_JOINTS
    # End of joints
    x_joint_end = x_joint_vel_end

    # Terrain comes last in the input vector: 3 values per trajectory point
    # (presumably left, center and right samples), present only with terrain.
    x_terr_start = x_joint_end
    x_past_terr_start = x_terr_start
    x_past_terr_end = x_past_terr_start + (3 * PAST_TRAJECTORY_WINDOW if WITH_TERRAIN else 0)
    x_future_terr_start = x_past_terr_end
    x_future_terr_end = x_future_terr_start + (3 * FUTURE_TRAJECTORY_WINDOW if WITH_TERRAIN else 0)
    # End of terrain
    x_terr_end = x_future_terr_end
    # End of input
    x_end = x_terr_end
    input_size = x_end

    # Output (Y) layout, as consecutive [start, end) column ranges.
    y_root_vel_start = 0
    y_root_vel_end = y_root_vel_start + 2
    y_root_angvel_start = y_root_vel_end
    y_root_angvel_end = y_root_angvel_start + 1
    y_phase_start = y_root_angvel_end
    y_phase_end = y_phase_start + NUM_FEET
    # Contact
    y_contact_start = y_phase_end
    y_contact_end = y_contact_start + (NUM_FEET if WITH_CONTACT else 0)
    # Trajectory
    y_trajectory_start = y_contact_end
    # Trajectory future positions
    y_future_pos_start = y_trajectory_start
    y_future_pos_end = y_future_pos_start + 2 * FUTURE_TRAJECTORY_WINDOW
    # Trajectory future directions
    y_future_dir_start = y_future_pos_end
    y_future_dir_end = y_future_dir_start + 2 * FUTURE_TRAJECTORY_WINDOW
    # End of trajectory
    y_trajectory_end = y_future_dir_end
    # Joints
    y_joint_start = y_trajectory_end
    # Position
    y_joint_pos_start = y_joint_start
    y_joint_pos_end = y_joint_pos_start + 3 * NUM_JOINTS
    # Velocity
    y_joint_vel_start = y_joint_pos_end
    y_joint_vel_end = y_joint_vel_start + 3 * NUM_JOINTS
    # Rotation, stored as two 3-vectors (X and Y axes) per joint.
    y_joint_xy_start = y_joint_vel_end
    y_joint_xy_end = y_joint_xy_start + 6 * NUM_JOINTS
    # End of joints
    y_joint_end = y_joint_xy_end
    # End of output
    y_end = y_joint_end
    output_size = y_end

    logdir = Path(logdir)
    if logdir.exists() and not logdir.is_dir():
        raise ValueError('{} is not a valid directory.'.format(logdir))
    # exist_ok avoids a race between the existence check and creation.
    logdir.mkdir(parents=True, exist_ok=True)

    print('Preprocessing data...')
    train_tfrecord_pattern = str(Path(train_data_dir, '*.tfrecord'))
    test_tfrecord_pattern = str(Path(test_data_dir, '*.tfrecord'))
    x_mean, x_scale, y_mean, y_scale = compute_mean_std_records(train_tfrecord_pattern, input_size, output_size)
    # Readjust scaling: every feature group shares a single scale (the mean of
    # its per-feature standard deviations). Empty groups (no styles, terrain or
    # contacts) are skipped by the helper.
    _pool_scale_ranges(x_scale, [
        (x_past_pos_start, x_past_pos_end),
        (x_future_pos_start, x_future_pos_end),
        (x_past_dir_start, x_past_dir_end),
        (x_future_dir_start, x_future_dir_end),
        (x_style_start, x_style_end),
        (x_joint_pos_start, x_joint_pos_end),
        (x_joint_vel_start, x_joint_vel_end),
        # Past and future terrain are pooled together.
        (x_past_terr_start, x_future_terr_end),
    ])
    # Near-constant features would otherwise be divided by ~0.
    x_scale[np.abs(x_scale) < 1e-6] = 1
    _pool_scale_ranges(y_scale, [
        (y_root_vel_start, y_root_vel_end),
        (y_root_angvel_start, y_root_angvel_end),
        (y_phase_start, y_phase_end),
        (y_contact_start, y_contact_end),
        (y_future_pos_start, y_future_pos_end),
        (y_future_dir_start, y_future_dir_end),
        (y_joint_pos_start, y_joint_pos_end),
        (y_joint_vel_start, y_joint_vel_end),
        (y_joint_xy_start, y_joint_xy_end),
    ])
    y_scale[np.abs(y_scale) < 1e-6] = 1
    # Save to file
    x_scale.tofile(str(Path(logdir, 'x_scale.txt')), sep='\n')
    x_mean.tofile(str(Path(logdir, 'x_mean.txt')), sep='\n')
    y_scale.tofile(str(Path(logdir, 'y_scale.txt')), sep='\n')
    y_mean.tofile(str(Path(logdir, 'y_mean.txt')), sep='\n')

    # Parse function
    parse_example_fn = lambda self, example_proto: parse_example_weight(
        example_proto, input_size, output_size)
    # Remove phase: the network is trained without the phase columns.
    output_size_net = output_size
    y_mean_net = y_mean
    y_scale_net = y_scale
    if REMOVE_PHASE:
        parse_example_fn = example_remove_phase(parse_example_fn, y_phase_start, y_phase_end)
        y_mean_net = np.r_[y_mean[:y_phase_start], y_mean[y_phase_end:]]
        y_scale_net = np.r_[y_scale[:y_phase_start], y_scale[y_phase_end:]]
        output_size_net = output_size - NUM_FEET

    # Preprocess function
    preprocess_fn = None

    def preprocess_scaled_fn(self, input_, input_gated):
        """Optionally replace the gating input with the GATING_INDICES input columns."""
        if GATING_INDICES:
            input_gated = tf.gather(input_, GATING_INDICES, axis=1)
        return input_, input_gated
    if REMOVE_STYLES:
        preprocess_scaled_fn = remove_styles(preprocess_scaled_fn, x_style_start, x_style_end)

    # Output and postprocessing
    postprocess_fn = lambda self, output: output
    if REMOVE_PHASE:
        postprocess_fn = add_zero_phases(postprocess_fn, y_phase_start, y_phase_end)
    # Normalize X and Y axes
    postprocess_fn = normalize_xy(postprocess_fn, y_joint_xy_start, y_joint_xy_end)

    # Loss function
    def loss_fn(self, label_unscaled, output_unscaled, output_logits):
        """Per-example sum of squared errors, optionally boosting important joints."""
        d = tf.squared_difference(label_unscaled, output_unscaled)
        if REMOVE_PHASE:
            # Re-insert (zero) phase columns so the full-output indices below apply.
            d = add_zero_phases(lambda self, x: x, y_phase_start, y_phase_end)(self, d)
        if IMPORTANT_JOINTS and IMPORTANCE_BOOST != 1:
            # Scale joint values
            d_prev = d[:, :y_joint_pos_start]
            joint_pos = tf.reshape(d[:, y_joint_pos_start:y_joint_pos_end], (-1, NUM_JOINTS, 3))
            d_pos = tf.reshape(joint_importance_scale(joint_pos), (-1, NUM_JOINTS * 3))
            joint_vel = tf.reshape(d[:, y_joint_vel_start:y_joint_vel_end], (-1, NUM_JOINTS, 3))
            d_vel = tf.reshape(joint_importance_scale(joint_vel), (-1, NUM_JOINTS * 3))
            joint_xy = tf.reshape(d[:, y_joint_xy_start:y_joint_xy_end], (-1, NUM_JOINTS, 6))
            d_xy = tf.reshape(joint_importance_scale(joint_xy), (-1, NUM_JOINTS * 6))
            d_post = d[:, y_joint_xy_end:]
            d = tf.concat([d_prev, d_pos, d_vel, d_xy, d_post], axis=1)
        return tf.reduce_sum(d, axis=1)

    print('Creating model in {}...'.format(logdir))
    net = Mann(input_size, output_size_net, ENCODER_LAYERS, HIDDEN_LAYERS, DECODER_LAYERS,
               NUM_GATED, GATING_HIDDEN_LAYERS,
               logdir, dtype=tf.float32, functioned_output=FUNCTIONED_OUTPUT, use_bias=USE_BIAS,
               regularization=REGULARIZATION, dropout=1 - DROPOUT_KEEP_PROB, noise=NOISE, use_weights=USE_WEIGHTS,
               input_mean=x_mean, input_scale=x_scale,
               output_mean=y_mean_net, output_scale=y_scale_net,
               parse_example_fn=parse_example_fn,
               preprocess_fn=preprocess_fn,
               preprocess_scaled_fn=preprocess_scaled_fn,
               postprocess_fn=postprocess_fn,
               hidden_activation_fn=tf.nn.relu,
               output_activation_fn=None,
               loss_fn=loss_fn,
               devices=get_devices(),
               seed=SEED)

    # Write model summary
    with Path(logdir, 'model_summary.txt').open('w') as f:
        f.write(f'input_size: {net.input_size}\n')
        f.write(f'output_size: {net.output_size}\n')
        f.write(f'actual_input_size: {net.actual_input_size}\n')
        f.write(f'actual_input_gate_size: {net.actual_input_gate_size}\n')
        f.write(f'actual_output_size: {net.actual_output_size}\n')
        f.write(f'num_gated: {net.num_gated}\n')
        f.write(f'gating_hidden_layers: {net.gating_hidden_layers}\n')
        f.write(f'encoder_layers: {net.encoder_layers}\n')
        f.write(f'hidden_layers: {net.hidden_layers}\n')
        f.write(f'decoder_layers: {net.decoder_layers}\n')
        f.write(f'functioned_output: {net.functioned_output}\n')
        f.write(f'use_bias: {net.use_bias}\n')
        f.write(f'num_params: {net.num_params}\n')

    # Save command line
    with Path(logdir, 'command_line.txt').open('w') as f:
        print(' '.join(sys.argv), file=f)

    print('Training...')
    net.train(training_record_file_pattern=train_tfrecord_pattern,
              test_record_file_pattern=test_tfrecord_pattern,
              max_steps=MAX_STEPS,
              batch_size=BATCH_SIZE,
              shuffle_buffer_size=SHUFFLE_BUFFER_SIZE,
              learning_rate=LEARNING_RATE,
              evaluation_steps=EVALUATION_STEPS,
              patience=PATIENCE,
              freeze_steps=FREEZE_STEPS)
    print()
    tf.train.write_graph(net.make_runtime_graph(), net.logdir, 'mann_rt.tf', as_text=False)
    print('Done.')


def _pool_scale_ranges(scale, ranges):
    """Replace each non-empty ``scale[start:end]`` (in place) with its mean.

    Args:
        scale: 1-D NumPy array of per-feature scales; modified in place.
        ranges: Iterable of ``(start, end)`` slice bounds. Empty slices are
            skipped, which avoids taking the mean of an empty array (NaN plus
            a RuntimeWarning).

    Returns:
        ``scale`` itself, for convenience.
    """
    for start, end in ranges:
        if end > start:
            scale[start:end] = scale[start:end].mean()
    return scale

def get_devices():
    """Return the names of the local GPU devices.

    Falls back to the names of all local devices when no GPU is listed.
    """
    from tensorflow.python.client import device_lib
    local_devices = device_lib.list_local_devices()
    gpu_names = [dev.name for dev in local_devices if dev.device_type == 'GPU']
    return gpu_names or [dev.name for dev in local_devices]

def compute_mean_std_records(record_file_pattern, x_size, y_size):
    """Compute per-column mean and standard deviation of all X and Y vectors.

    Streams every example in the files matching ``record_file_pattern``
    through a TF1 input pipeline, in batches of 10000, inside a private graph
    and session. Statistics are accumulated with ``running_moments``.

    Args:
        record_file_pattern: Glob pattern of the TFRecord files.
        x_size: Length of each example's ``x`` vector.
        y_size: Length of each example's ``y`` vector.

    Returns:
        Tuple ``(x_mean, x_std, y_mean, y_std)`` as evaluated by the session.
    """
    import tensorflow as tf
    workers = os.cpu_count()
    graph = tf.Graph()
    with graph.as_default(), tf.Session() as session:
        files = tf.data.Dataset.list_files(record_file_pattern)
        records = files.apply(tf.data.experimental.parallel_interleave(
            tf.data.TFRecordDataset, cycle_length=workers, sloppy=False))
        examples = records.map(lambda proto: parse_example(proto, x_size, y_size),
                               num_parallel_calls=workers)
        batches = examples.batch(10000).prefetch(1)
        x_batch, y_batch = batches.make_one_shot_iterator().get_next()
        (x_mean, x_std), x_update = running_moments(x_batch)
        (y_mean, y_std), y_update = running_moments(y_batch)
        session.run(tf.global_variables_initializer())
        # Consume the whole dataset; the iterator signals exhaustion by raising.
        try:
            while True:
                session.run([x_update, y_update])
        except tf.errors.OutOfRangeError:
            pass
        return session.run((x_mean, x_std, y_mean, y_std))

def running_moments(data):
    """Build graph ops that accumulate per-column mean and std over batches.

    Each run of the returned ``update`` op folds one batch of ``data`` into
    running statistics held in non-trainable variables, using a batched
    Welford update. This is algebraically the Chan et al. parallel
    combination of means and second moments. The variables must be
    initialised before the first update.

    Args:
        data: Tensor whose axis 0 indexes samples and whose axis 1 has a
            static size (``data.shape[1].value``). Presumably a 2-D float32
            ``(batch, features)`` tensor; TODO confirm for other ranks.

    Returns:
        ``((mean, std), update)``. ``mean`` and ``std`` read the current
        accumulated state; ``std`` is the population standard deviation,
        ``sqrt(M2 / count)``. ``update`` consumes one batch.
    """
    import tensorflow as tf
    # Welford algorithm
    size = data.shape[1].value
    first = tf.Variable(True, dtype=tf.bool, trainable=False)  # True until the first batch is folded in
    m = tf.Variable([0.] * size, dtype=tf.float32, trainable=False)  # running mean
    s = tf.Variable([0.] * size, dtype=tf.float32, trainable=False)  # running sum of squared deviations (M2)
    k = tf.Variable(0, dtype=tf.float32, trainable=False)  # number of samples seen so far
    d = tf.cast(tf.shape(data)[0], tf.float32)  # number of samples in this batch
    k_new = k + d
    data_mean = tf.reduce_mean(data, axis=0)
    # On the first batch, seed the mean with the batch mean (so m_new == data_mean).
    m_val = tf.cond(first, lambda: data_mean, lambda: m)
    m_new = m_val + d * (data_mean - m_val) / k_new
    # Batched Welford step: sum((x - m_old) * (x - m_new)) equals
    # M2_batch + k * d / k_new * (data_mean - m_old)^2.
    # Using `m` (still 0 on the first batch) rather than `m_val` is harmless
    # there, because sum((x - c) * (x - data_mean)) equals M2_batch for any c.
    # The sum is non-negative in exact arithmetic; abs() presumably guards
    # against round-off.
    s_new = s + tf.abs(tf.reduce_sum((data - m) * (data - m_new), axis=0))
    # Commit the new state. `first` is cleared only after m, s and k are assigned.
    # NOTE(review): nothing explicitly orders the reads of m/s/k above before
    # these assigns within the same session run. Confirm this is safe for the
    # variable implementation in use.
    with tf.control_dependencies([tf.assign(m, m_new), tf.assign(s, s_new), tf.assign(k, k_new)]):
        update = tf.group([tf.assign(first, False)])
    mean = tf.identity(m)
    # Population std (divides by the count, not count - 1); NaN if no data was seen.
    std = tf.sqrt(s / k)
    return (mean, std), update

def parse_example(example_proto, x_size, y_size):
    """Decode one serialized example into its fixed-length ``x`` and ``y`` vectors.

    Args:
        example_proto: Serialized example (scalar string tensor).
        x_size: Length of the float32 ``x`` feature.
        y_size: Length of the float32 ``y`` feature.

    Returns:
        Tuple ``(x, y)`` of float32 tensors.
    """
    import tensorflow as tf
    spec = {
        'x': tf.FixedLenFeature((x_size,), tf.float32),
        'y': tf.FixedLenFeature((y_size,), tf.float32),
    }
    decoded = tf.parse_single_example(example_proto, spec)
    return tuple(decoded[key] for key in ('x', 'y'))

def parse_example_weight(example_proto, input_size, output_size):
    """Decode one serialized example into ``x``, ``y`` and its sample weight.

    Args:
        example_proto: Serialized example (scalar string tensor).
        input_size: Length of the float32 ``x`` feature.
        output_size: Length of the float32 ``y`` feature.

    Returns:
        Tuple ``(x, y, weight)``. ``weight`` is a float32 scalar that
        defaults to 1.0 when the example has no ``weight`` feature.
    """
    import tensorflow as tf
    spec = {
        'x': tf.FixedLenFeature((input_size,), tf.float32),
        'y': tf.FixedLenFeature((output_size,), tf.float32),
        'weight': tf.FixedLenFeature((), tf.float32, default_value=1.),
    }
    decoded = tf.parse_single_example(example_proto, spec)
    return tuple(decoded[key] for key in ('x', 'y', 'weight'))

def example_remove_phase(parse_fn, y_phase_start, y_phase_end):
    """Wrap ``parse_fn`` so the phase columns are dropped from the label vector.

    Args:
        parse_fn: ``fn(self, example_proto) -> (x, y, weight)``.
        y_phase_start: First phase column in ``y``.
        y_phase_end: One past the last phase column in ``y``.

    Returns:
        A function with the same signature whose ``y`` lacks columns
        ``[y_phase_start, y_phase_end)``.
    """
    def parse_without_phase(self, example_proto):
        import tensorflow as tf
        features, labels, weight = parse_fn(self, example_proto)
        kept = [labels[:y_phase_start], labels[y_phase_end:]]
        return features, tf.concat(kept, axis=-1), weight
    return parse_without_phase

def remove_styles(preprocess_scaled_fn, x_style_start, x_style_end):
    """Wrap ``preprocess_scaled_fn`` so the style columns are cut from the main input.

    The wrapped function runs first, so it still sees the full input. After
    that, columns ``[x_style_start, x_style_end)`` are removed from ``x``.
    The gating input ``g`` is passed through unchanged.
    """
    def without_styles(self, x, g):
        import tensorflow as tf
        x, g = preprocess_scaled_fn(self, x, g)
        before, after = x[:, :x_style_start], x[:, x_style_end:]
        return tf.concat([before, after], axis=1), g
    return without_styles

def joint_importance_scale(joint_batch):
    """Weight per-joint values so the joints in ``IMPORTANT_JOINTS`` count more.

    Within this file ``joint_batch`` is built as ``(batch, NUM_JOINTS, channels)``.
    Every joint's values are multiplied by the weight from
    ``_joint_importance_weights``. When no important joints are configured,
    ``joint_batch`` is returned unchanged.

    Raises:
        ValueError: If an index in ``IMPORTANT_JOINTS`` is outside ``[0, NUM_JOINTS)``.
    """
    if not IMPORTANT_JOINTS:
        return joint_batch
    import tensorflow as tf
    per_joint = _joint_importance_weights(NUM_JOINTS, IMPORTANT_JOINTS, IMPORTANCE_BOOST)
    # Shape (NUM_JOINTS, 1) broadcasts over the batch axis and the channel axis.
    weights = tf.constant([[w] for w in per_joint], dtype=joint_batch.dtype)
    return joint_batch * weights


def _joint_importance_weights(num_joints, important_joints, boost):
    """Return one loss weight per joint, boosting the important ones by ``boost``.

    Duplicate indices are counted once. The previous ``scatter_nd`` mask
    summed duplicates, e.g. from overlapping ranges such as ``0-3,2``. The
    weights sum to ``num_joints``, so the overall loss scale is preserved.

    Args:
        num_joints: Total number of joints.
        important_joints: Iterable of joint indices to boost.
        boost: Weight of an important joint relative to the others.

    Returns:
        List of ``num_joints`` floats.

    Raises:
        ValueError: If an index is outside ``[0, num_joints)``.
    """
    important = set(important_joints)
    invalid = sorted(j for j in important if not 0 <= j < num_joints)
    if invalid:
        raise ValueError('important joint indices out of range [0, {}): {}'.format(num_joints, invalid))
    low_weight = num_joints / (num_joints + (boost - 1) * len(important))
    high_weight = boost * low_weight
    return [high_weight if j in important else low_weight for j in range(num_joints)]

def add_zero_phases(postprocess_fn, y_phase_start, y_phase_end):
    """Wrap ``postprocess_fn`` so that zero-valued phase columns are re-inserted.

    After ``postprocess_fn`` runs, ``y_phase_end - y_phase_start`` columns of
    zeros are spliced in at column ``y_phase_start`` of its output.
    """
    num_phases = y_phase_end - y_phase_start

    def with_zero_phases(self, output):
        import tensorflow as tf
        processed = postprocess_fn(self, output)
        batch_rows = tf.shape(processed)[0]
        zeros = tf.zeros([batch_rows, num_phases], processed.dtype)
        head, tail = processed[:, :y_phase_start], processed[:, y_phase_start:]
        return tf.concat([head, zeros, tail], axis=-1)
    return with_zero_phases

def normalize_xy(postprocess_fn, y_joint_xy_start, y_joint_xy_end, epsilon=1e-12):
    """Wrap ``postprocess_fn`` so the joint X/Y axis vectors come out unit length.

    After ``postprocess_fn`` runs, the slice
    ``output[:, y_joint_xy_start:y_joint_xy_end]`` is reshaped to
    ``(rows, NUM_JOINTS, 2, 3)``. Each trailing 3-vector is divided by its
    Euclidean norm; all other columns are unchanged.

    Args:
        postprocess_fn: ``fn(self, output)`` applied first.
        y_joint_xy_start: First column of the slice.
        y_joint_xy_end: One past the last column of the slice. The slice
            width must be ``6 * NUM_JOINTS``.
        epsilon: Lower bound on the squared norm. An all-zero vector then
            yields zeros instead of NaN (0 / 0).

    Returns:
        The wrapped post-processing function.
    """
    def new_postprocess_fn(self, output):
        import tensorflow as tf
        output = postprocess_fn(self, output)
        output_xy = output[:, y_joint_xy_start:y_joint_xy_end]
        output_xy = tf.reshape(output_xy, (-1, NUM_JOINTS, 2, 3))
        squared_norm = tf.reduce_sum(tf.square(output_xy), axis=-1, keepdims=True)
        # The clamp keeps the result (and any gradient through it) finite for
        # degenerate vectors. Vectors with squared_norm >= epsilon are unaffected.
        output_xy /= tf.sqrt(tf.maximum(squared_norm, epsilon))
        output_xy = tf.reshape(output_xy, (-1, NUM_JOINTS * 6))
        return tf.concat([output[:, :y_joint_xy_start], output_xy, output[:, y_joint_xy_end:]], axis=-1)
    return new_postprocess_fn

def parse_args(args=None):
    """Parse command-line options into the module-level configuration globals.

    Args:
        args: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Tuple ``(train_data_dir, test_data_dir, log_dir)``.
    """
    global NUM_FEET, NUM_JOINTS, PAST_TRAJECTORY_WINDOW, FUTURE_TRAJECTORY_WINDOW, NUM_STYLES, WITH_TERRAIN, WITH_CONTACT
    global ENCODER_LAYERS, HIDDEN_LAYERS, DECODER_LAYERS, NUM_GATED, GATING_HIDDEN_LAYERS, GATING_INDICES
    global FUNCTIONED_OUTPUT, USE_BIAS, REMOVE_STYLES, REMOVE_PHASE
    global BATCH_SIZE, LEARNING_RATE, MAX_STEPS, PATIENCE, EVALUATION_STEPS, FREEZE_STEPS
    global SHUFFLE_BUFFER_SIZE, DROPOUT_KEEP_PROB, REGULARIZATION, USE_WEIGHTS, IMPORTANT_JOINTS, IMPORTANCE_BOOST, NOISE
    global SEED, USE_GPU

    parser = argparse.ArgumentParser(usage='%(prog)s [-h] <argument> ...',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                     argument_default=argparse.SUPPRESS)
    dir_group = parser.add_argument_group('directories')
    dir_group.add_argument('--train-data-dir', help='directory with *.tfrecord files for training', required=True)
    dir_group.add_argument('--test-data-dir', help='directory with *.tfrecord files for test', required=True)
    dir_group.add_argument('--log-dir', help='directory where training results are saved', required=True)
    data_group = parser.add_argument_group('data features')
    data_group.add_argument('--num-feet', help='number of feet in the character', type=int, required=True)
    data_group.add_argument('--num-joints', help='number of skeletal joints', type=int, required=True)
    data_group.add_argument('--past-trajectory-window', help='size of past trajectory window', type=int, required=True)
    data_group.add_argument('--future-trajectory-window', help='size of future trajectory window', type=int, required=True)
    data_group.add_argument('--num-styles', help='number of style labels', type=int, required=True)
    data_group.add_argument('--with-terrain', help='whether to use terrain information', type=str2bool, required=True)
    data_group.add_argument('--with-contact', help='whether to use contact information', type=str2bool, required=True)
    model_group = parser.add_argument_group('model architecture')
    model_group.add_argument('--encoder-layers', help='encoder layer sizes', metavar='SIZE[,...]', type=num_list, default='none')
    model_group.add_argument('--hidden-layers', help='hidden layer sizes', metavar='SIZE[,...]', type=num_list, default='none')
    model_group.add_argument('--decoder-layers', help='decoder layer sizes', metavar='SIZE[,...]', type=num_list, default='none')
    model_group.add_argument('--num-gated', help='number of gated parameterisations', type=int, default=NUM_GATED)
    model_group.add_argument('--gating-hidden-layers', help='gating network hidden layer sizes', metavar='SIZE[,...]', type=num_list, default='none')
    model_group.add_argument('--gating-indices', help='gating network input indices', metavar='INDEX[,...]', type=num_list, default='none')
    model_group.add_argument('--functioned-output', help='use grid-functioned output layer', type=str2bool, default=str(FUNCTIONED_OUTPUT).lower())
    model_group.add_argument('--use-bias', help='whether to use bias units', type=str2bool, default=str(USE_BIAS).lower())
    model_group.add_argument('--remove-styles', help='whether to remove style labels', type=str2bool, default=str(REMOVE_STYLES).lower())
    model_group.add_argument('--remove-phase', help='whether to remove phase values from output', type=str2bool, default=str(REMOVE_PHASE).lower())
    train_group = parser.add_argument_group('training settings')
    train_group.add_argument('--batch-size', help='training batch size.', type=int, default=BATCH_SIZE)
    train_group.add_argument('--learning-rate', help='learning rate', type=float, default=LEARNING_RATE)
    train_group.add_argument('--max-steps', help='maximum number of training steps', type=int, default=MAX_STEPS)
    train_group.add_argument('--patience', help='number of evaluations without improvement before stopping (negative for no limit)', type=int, default=PATIENCE)
    train_group.add_argument('--evaluation-steps', help='number of steps per evaluation', type=int, default=EVALUATION_STEPS)
    train_group.add_argument('--freeze-steps', help='number of steps per frozen graph', type=int, default=FREEZE_STEPS)
    train_group.add_argument('--shuffle-buffer-size', help='size of the data shuffling buffer', type=int, default=SHUFFLE_BUFFER_SIZE)
    train_group.add_argument('--dropout', help='dropout rate', type=float, default=round(1 - DROPOUT_KEEP_PROB, 6))
    train_group.add_argument('--regularization', help='regularization weight', type=float, default=REGULARIZATION)
    train_group.add_argument('--noise', help='scale of Gaussian noise added to input and output for training', type=float, default=NOISE)
    # String default (like the other boolean flags) so it goes through str2bool.
    train_group.add_argument('--use-weights', help='whether to use training weights', type=str2bool, default=str(USE_WEIGHTS).lower())
    train_group.add_argument('--important-joints', help='indices of important joints', metavar='INDEX[,...]', type=num_list, default='none')
    train_group.add_argument('--importance-boost', help='loss boost for important joints', type=int, default=IMPORTANCE_BOOST)
    other_group = parser.add_argument_group('other options')
    other_group.add_argument('--seed', help='random seed', type=int, default=SEED)
    other_group.add_argument('--use-gpu', help='whether to use the GPU', type=str2bool, default=str(USE_GPU).lower())
    parsed_args = parser.parse_args(args=args)

    train_data_dir = parsed_args.train_data_dir
    test_data_dir = parsed_args.test_data_dir
    log_dir = parsed_args.log_dir
    NUM_FEET = parsed_args.num_feet
    NUM_JOINTS = parsed_args.num_joints
    PAST_TRAJECTORY_WINDOW = parsed_args.past_trajectory_window
    FUTURE_TRAJECTORY_WINDOW = parsed_args.future_trajectory_window
    NUM_STYLES = parsed_args.num_styles
    WITH_TERRAIN = parsed_args.with_terrain
    WITH_CONTACT = parsed_args.with_contact
    ENCODER_LAYERS = parsed_args.encoder_layers or []
    HIDDEN_LAYERS = parsed_args.hidden_layers or []
    DECODER_LAYERS = parsed_args.decoder_layers or []
    # An int count: `or []` would have turned --num-gated 0 into a list.
    NUM_GATED = parsed_args.num_gated
    GATING_HIDDEN_LAYERS = parsed_args.gating_hidden_layers or []
    GATING_INDICES = parsed_args.gating_indices or []
    FUNCTIONED_OUTPUT = parsed_args.functioned_output
    USE_BIAS = parsed_args.use_bias
    REMOVE_STYLES = parsed_args.remove_styles
    REMOVE_PHASE = parsed_args.remove_phase
    BATCH_SIZE = parsed_args.batch_size
    LEARNING_RATE = parsed_args.learning_rate
    MAX_STEPS = parsed_args.max_steps
    PATIENCE = parsed_args.patience
    EVALUATION_STEPS = parsed_args.evaluation_steps
    FREEZE_STEPS = parsed_args.freeze_steps
    SHUFFLE_BUFFER_SIZE = parsed_args.shuffle_buffer_size
    DROPOUT_KEEP_PROB = 1 - parsed_args.dropout
    REGULARIZATION = parsed_args.regularization
    NOISE = parsed_args.noise
    USE_WEIGHTS = parsed_args.use_weights
    IMPORTANT_JOINTS = parsed_args.important_joints or []
    IMPORTANCE_BOOST = parsed_args.importance_boost
    SEED = parsed_args.seed
    USE_GPU = parsed_args.use_gpu

    return train_data_dir, test_data_dir, log_dir

def num_list(v):
    """Parse a comma-separated list of integers and inclusive ranges.

    Intended as an argparse ``type``. For example, ``'1,3-5'`` gives
    ``[1, 3, 4, 5]``. An empty string or ``'none'`` (case-insensitive) gives
    ``[]``.

    Raises:
        argparse.ArgumentTypeError: If an item is not an integer or an
            ascending ``start-end`` range (e.g. ``'5-3'`` is rejected).
    """
    try:
        v = v.strip().lower()
        if not v or v == 'none':
            return []
        nums = []
        for item in v.split(','):
            if '-' in item:
                start, end = map(int, item.split('-'))
                if end < start:
                    # A descending range would otherwise silently expand to nothing.
                    raise ValueError('descending range {!r}'.format(item))
                nums.extend(range(start, end + 1))
            else:
                nums.append(int(item))
        return nums
    except (AttributeError, ValueError) as err:
        raise argparse.ArgumentTypeError('expected a comma-separated list of numbers.') from err

def str2bool(v):
    """Interpret ``v`` (any value, compared as trimmed lowercase text) as a boolean.

    Accepts yes/true/t/y/1 and no/false/f/n/0. Intended as an argparse ``type``.

    Raises:
        argparse.ArgumentTypeError: For any other value.
    """
    truth_table = dict.fromkeys(('yes', 'true', 't', 'y', '1'), True)
    truth_table.update(dict.fromkeys(('no', 'false', 'f', 'n', '0'), False))
    result = truth_table.get(str(v).strip().lower())
    if result is None:
        raise argparse.ArgumentTypeError('boolean value expected.')
    return result


if __name__ == '__main__':
    # parse_args() fills in the configuration globals and returns the three directories.
    train_dir, test_dir, log_dir = parse_args()
    main(train_dir, test_dir, log_dir)
