#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import sys
import random
import math
from pathlib import Path
import argparse
import numpy as np

# Configuration defaults. parse_args() overwrites most of these from the
# command line before main() reads them.

# Reproducibility
SEED = 0

# Model architecture. None means "not configured"; parse_args() replaces the
# list-valued settings with lists.
ENCODER_LAYERS = None
HIDDEN_LAYERS = None
DECODER_LAYERS = None
SPLINE_POINTS = None
SPLINE_CLOSED = False  # not set by parse_args()
GRID_FEATURES = None
GATING_HIDDEN_LAYERS = None
GATING_FEATURES = None
FUNCTIONED_OUTPUT = True
USE_BIAS = True
OUTPUT_ACTIVATION = None
NORMALIZE_INPUT = None
NORMALIZE_OUTPUT = None

# Training
LOSS_FUNCTION = None
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
MAX_STEPS = 10 ** 6
PATIENCE = -1  # negative: no limit on evaluations without improvement
EVALUATION_STEPS = 10 ** 4
FREEZE_STEPS = 0
SHUFFLE_BUFFER_SIZE = 2 * 10 ** 5
DROPOUT_KEEP_PROB = 0.7
REGULARIZATION = 0.1
USE_WEIGHTS = True

# Export and runtime
LINEAR_EVALUATIONS = None
CONSTANT_EVALUATIONS = None
USE_GPU = True

def main(data_dir, logdir):
    """Train a GFNN model and export its runtime graphs.

    Configuration comes from the module-level globals (normally populated by
    ``parse_args``). Feature names are read from ``<data_dir>/x.names`` and
    ``<data_dir>/y.names``; training uses ``<data_dir>/train/*.tfrecord`` and
    evaluation ``<data_dir>/test/*.tfrecord``. Normalization statistics, the
    selected grid features, a model summary, the command line and the exported
    graphs are written into ``logdir``.

    Args:
        data_dir: directory holding the name files and the train/test records.
        logdir: output directory; created (with parents) if it does not exist.

    Raises:
        RuntimeError: for an inconsistent grid/gating configuration, an unknown
            feature name, activation or loss function, or when ``logdir``
            exists but is not a directory.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    if not USE_GPU:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    # TensorFlow is imported only after the environment variables above are set.
    import tensorflow as tf
    from gfnn import Gfnn
    if hasattr(tf, 'get_logger'):
        # Remove the handlers attached to the parent of TensorFlow's logger.
        tf.get_logger().parent.handlers.clear()

    def read_names(path):
        """Return the non-blank, whitespace-stripped lines of ``path`` as an array."""
        with path.open('r') as f:
            return np.array(list(filter(None, map(str.strip, f.read().split('\n')))))

    def feature_indices(features, names_lower, kind):
        """Map integer indices or case-insensitive feature names to column indices."""
        indices = []
        for feat in features:
            if isinstance(feat, int):
                indices.append(feat)
                continue
            try:
                indices.append(names_lower.index(feat.lower()))
            except ValueError:
                raise RuntimeError('unknown {} feature "{}".'.format(kind, feat)) from None
        return indices

    # Random seed
    random.seed(SEED)
    np.random.seed(SEED)
    tf.set_random_seed(SEED)

    # Grid parameters (SPLINE_POINTS is still None if parse_args() was not run)
    spline_points = list(SPLINE_POINTS or [])
    grid_dims = len(spline_points)
    # NOTE(review): SPLINE_CLOSED is not consulted; every grid dimension is open.
    spline_closed = [False] * grid_dims
    linear_eval = LINEAR_EVALUATIONS
    constant_eval = CONSTANT_EVALUATIONS
    if GRID_FEATURES:
        # Explicit grid features and a gating network are mutually exclusive.
        if GATING_HIDDEN_LAYERS or GATING_FEATURES:
            raise RuntimeError('invalid configuration.')
        if len(GRID_FEATURES) != grid_dims:
            raise RuntimeError('expected {} grid features but found {}.'.format(grid_dims, len(GRID_FEATURES)))

    # Read data feature names
    train_tfrecord_pattern = str(Path(data_dir, 'train', '*.tfrecord'))
    test_tfrecord_pattern = str(Path(data_dir, 'test', '*.tfrecord'))
    x_names = read_names(Path(data_dir, 'x.names'))
    x_names_lower = list(map(str.lower, x_names))
    input_size = len(x_names)
    y_names = read_names(Path(data_dir, 'y.names'))
    y_names_lower = list(map(str.lower, y_names))
    output_size = len(y_names)

    # Activation function
    activation_functions = {
        '': None, 'none': None, 'linear': None,
        'sigmoid': tf.math.sigmoid, 'tanh': tf.math.tanh, 'softmax': tf.math.softmax
    }
    out_act_name = (OUTPUT_ACTIVATION or '').strip().lower()
    if out_act_name not in activation_functions:
        raise RuntimeError('unknown activation function "{}".'.format(OUTPUT_ACTIVATION))
    output_activation_fn = activation_functions[out_act_name]

    # Loss function. Both variants compute the loss from the logits; the sigmoid
    # variant averages the per-output losses over axis 1.
    sigmoid_xent = lambda self, label_unscaled, out_unscaled, out_logits: (
        tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=label_unscaled, logits=out_logits), axis=1))
    softmax_xent = lambda self, label_unscaled, out_unscaled, out_logits: (
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=label_unscaled, logits=out_logits))
    # Plain "cross-entropy" picks softmax for a softmax output or for several
    # outputs without an explicit sigmoid activation, and sigmoid otherwise.
    xent_default = sigmoid_xent
    if out_act_name == 'softmax' or (out_act_name != 'sigmoid' and len(y_names) > 1):
        xent_default = softmax_xent
    loss_functions = {
        '': None, 'none': None, 'mse': None, 'squared-error': None,
        'cross-entropy': xent_default, 'xent': xent_default,
        'sigmoid-cross-entropy': sigmoid_xent, 'sigmoid-xent': sigmoid_xent, 'sigmoid': sigmoid_xent,
        'softmax-cross-entropy': softmax_xent, 'softmax-xent': softmax_xent, 'softmax': softmax_xent,
    }
    loss_name = (LOSS_FUNCTION or '').strip().lower()
    if loss_name not in loss_functions:
        raise RuntimeError('unknown loss function "{}".'.format(LOSS_FUNCTION))
    loss_fn = loss_functions[loss_name]

    # Make log directory
    logdir = Path(logdir)
    if logdir.exists() and not logdir.is_dir():
        raise RuntimeError('{} is not a valid directory.'.format(logdir))
    if not logdir.exists():
        logdir.mkdir(parents=True)

    # Normalization selection: None or True -> all columns, False -> none,
    # otherwise an explicit list of names/indices.
    if NORMALIZE_INPUT is None:
        norm_in = range(input_size)
    elif isinstance(NORMALIZE_INPUT, bool):
        norm_in = range(input_size) if NORMALIZE_INPUT else []
    else:
        norm_in = feature_indices(NORMALIZE_INPUT, x_names_lower, 'input')
    if NORMALIZE_OUTPUT is None:
        norm_out = range(output_size)
    elif isinstance(NORMALIZE_OUTPUT, bool):
        norm_out = range(output_size) if NORMALIZE_OUTPUT else []
    else:
        norm_out = feature_indices(NORMALIZE_OUTPUT, y_names_lower, 'output')

    # Reuse statistics saved by a previous run in the same logdir, if complete.
    x_mean_path = Path(logdir, 'x_mean.txt')
    x_scale_path = Path(logdir, 'x_scale.txt')
    y_mean_path = Path(logdir, 'y_mean.txt')
    y_scale_path = Path(logdir, 'y_scale.txt')
    if all(map(Path.is_file, [x_mean_path, x_scale_path, y_mean_path, y_scale_path])):
        x_mean = np.fromfile(str(x_mean_path), sep='\n', dtype=np.float32)
        x_scale = np.fromfile(str(x_scale_path), sep='\n', dtype=np.float32)
        y_mean = np.fromfile(str(y_mean_path), sep='\n', dtype=np.float32)
        y_scale = np.fromfile(str(y_scale_path), sep='\n', dtype=np.float32)
    else:
        # Identity normalization for columns not selected above.
        x_mean, x_scale = np.zeros(input_size, np.float32), np.ones(input_size, np.float32)
        y_mean, y_scale = np.zeros(output_size, np.float32), np.ones(output_size, np.float32)
        if norm_in or norm_out:
            print('Preprocessing data...')
            x_mean_c, x_scale_c, y_mean_c, y_scale_c = compute_mean_std_records(train_tfrecord_pattern, input_size, output_size)
            # Near-constant columns would blow up when divided by their std.
            x_scale_c[np.abs(x_scale_c) < 1e-6] = 1
            y_scale_c[np.abs(y_scale_c) < 1e-6] = 1
            x_mean[norm_in], x_scale[norm_in] = x_mean_c[norm_in], x_scale_c[norm_in]
            y_mean[norm_out], y_scale[norm_out] = y_mean_c[norm_out], y_scale_c[norm_out]

    # Save mean and scale
    x_mean.tofile(str(x_mean_path), sep='\n')
    x_scale.tofile(str(x_scale_path), sep='\n')
    y_mean.tofile(str(y_mean_path), sep='\n')
    y_scale.tofile(str(y_scale_path), sep='\n')

    # Parse function
    parse_example_fn = lambda self, example_proto: parse_example_dummy(
        example_proto, self.input_size, grid_dims, self.output_size)
    # Preprocess function: identity, wrapped below to derive grid coordinates.
    preprocess_fn = lambda self, x, coords: (x, coords)
    if GRID_FEATURES:
        grid_feats = feature_indices(GRID_FEATURES, x_names_lower, 'grid')
        preprocess_fn = gating_direct_coords(preprocess_fn, grid_feats, x_mean, x_scale)
    elif GATING_HIDDEN_LAYERS or GATING_FEATURES:
        if GATING_FEATURES:
            grid_feats = feature_indices(GATING_FEATURES, x_names_lower, 'gating')
        else:
            # Gate on every input feature.
            grid_feats = list(range(input_size))
        preprocess_fn = gating_network_coords(preprocess_fn, spline_points, GATING_HIDDEN_LAYERS, grid_feats, x_mean, x_scale)
    else:
        # No explicit choice: use the inputs most correlated with the outputs.
        if grid_dims > 0:
            grid_feats = feature_ranking(train_tfrecord_pattern, input_size, output_size)[:grid_dims]
        else:
            grid_feats = np.array([], dtype=np.int32)
        preprocess_fn = gating_direct_coords(preprocess_fn, grid_feats, x_mean, x_scale)

    # Save selected grid features
    with Path(logdir, 'grid_features.txt').open('w') as f:
        for feat in x_names[grid_feats]:
            f.write(feat)
            f.write('\n')

    preprocess_scaled_fn = None
    postprocess_fn = None

    print('Creating model in {}...'.format(logdir))
    net = Gfnn(input_size, output_size, ENCODER_LAYERS, HIDDEN_LAYERS, DECODER_LAYERS,
               spline_points, spline_closed,
               logdir, dtype=tf.float32, functioned_output=FUNCTIONED_OUTPUT, use_bias=USE_BIAS,
               regularization=REGULARIZATION, dropout=1 - DROPOUT_KEEP_PROB, use_weights=USE_WEIGHTS,
               input_mean=x_mean, input_scale=x_scale,
               output_mean=y_mean, output_scale=y_scale,
               parse_example_fn=parse_example_fn,
               preprocess_fn=preprocess_fn,
               preprocess_scaled_fn=preprocess_scaled_fn,
               postprocess_fn=postprocess_fn,
               hidden_activation_fn=tf.nn.relu,
               output_activation_fn=output_activation_fn,
               loss_fn=loss_fn,
               devices=get_devices(),
               seed=SEED)

    # Write model summary
    with Path(logdir, 'model_summary.txt').open('w') as f:
        f.write(f'input_size: {net.input_size}\n')
        f.write(f'output_size: {net.output_size}\n')
        f.write(f'spline_points: {net.spline_points}\n')
        f.write(f'spline_close: {net.spline_close}\n')
        f.write(f'encoder_layers: {net.encoder_layers}\n')
        f.write(f'hidden_layers: {net.hidden_layers}\n')
        f.write(f'decoder_layers: {net.decoder_layers}\n')
        f.write(f'functioned_output: {net.functioned_output}\n')
        f.write(f'use_bias: {net.use_bias}\n')
        f.write(f'num_params: {net.num_params}\n')
        f.write(f'linear_eval: {linear_eval}\n')
        f.write(f'constant_eval: {constant_eval}\n')

    # Save command line
    with Path(logdir, 'command_line.txt').open('w') as f:
        print(' '.join(sys.argv), file=f)

    print('Training...')
    net.train(training_record_file_pattern=train_tfrecord_pattern,
              test_record_file_pattern=test_tfrecord_pattern,
              max_steps=MAX_STEPS,
              batch_size=BATCH_SIZE,
              shuffle_buffer_size=SHUFFLE_BUFFER_SIZE,
              learning_rate=LEARNING_RATE,
              evaluation_steps=EVALUATION_STEPS,
              patience=PATIENCE,
              freeze_steps=FREEZE_STEPS)
    print()
    # Export the full graph, plus linear/constant variants when requested.
    tf.train.write_graph(net.make_runtime_graph(full=True, linear=False, constant=False),
                         net.logdir, 'gfnn_cubic.tf', as_text=False)
    if linear_eval:
        tf.train.write_graph(net.make_runtime_graph(full=False, linear=linear_eval, constant=False),
                             net.logdir, 'gfnn_linear.tf', as_text=False)
    if constant_eval:
        tf.train.write_graph(net.make_runtime_graph(full=False, linear=False, constant=constant_eval),
                             net.logdir, 'gfnn_constant.tf', as_text=False)
    print('Done.')

def get_devices():
    """Return the names of local GPU devices, or of all local devices if there is no GPU."""
    from tensorflow.python.client import device_lib
    local = device_lib.list_local_devices()
    gpus = [dev for dev in local if dev.device_type == 'GPU']
    return [dev.name for dev in (gpus or local)]

def compute_mean_std_records(record_file_pattern, x_size, y_size):
    """Compute per-column mean and standard deviation of x and y over TFRecords.

    Streams every record matching ``record_file_pattern`` in batches of 10000
    through ``running_moments`` inside a private graph and session.

    Args:
        record_file_pattern: glob pattern of TFRecord files.
        x_size: number of input values per example.
        y_size: number of output values per example.

    Returns:
        ``(x_mean, x_std, y_mean, y_std)`` as evaluated by ``sess.run``.
    """
    import tensorflow as tf
    workers = os.cpu_count()
    graph = tf.Graph()
    with graph.as_default(), tf.Session() as sess:
        files = tf.data.Dataset.list_files(record_file_pattern)
        records = files.apply(tf.data.experimental.parallel_interleave(
            tf.data.TFRecordDataset, cycle_length=workers, sloppy=False))
        examples = records.map(lambda proto: parse_example(proto, x_size, y_size),
                               num_parallel_calls=workers)
        batches = examples.batch(10000).prefetch(1)
        x_batch, y_batch = batches.make_one_shot_iterator().get_next()
        (x_mean, x_std), x_update = running_moments(x_batch)
        (y_mean, y_std), y_update = running_moments(y_batch)
        sess.run(tf.global_variables_initializer())
        # Fold batches into the accumulators until the iterator is exhausted.
        while True:
            try:
                sess.run([x_update, y_update])
            except tf.errors.OutOfRangeError:
                break
        return sess.run((x_mean, x_std, y_mean, y_std))

def running_moments(data):
    """Build graph ops that accumulate per-column mean and std over batches.

    Uses a batched Welford update. For each batch of ``d`` rows with previous
    count ``k``, mean ``m_old`` and new mean ``m_new``, the sum of squared
    deviations grows by ``sum((x - m_old) * (x - m_new))``, which equals the
    batch's own sum of squares plus ``d * k / (k + d) * (batch_mean - m_old)**2``.

    Args:
        data: 2-D tensor whose second dimension is statically known (read
            via ``data.shape[1].value``). It is assumed to be float32, since it
            is combined with float32 variables.

    Returns:
        ``((mean, std), update)``. Running ``update`` folds the current batch
        of ``data`` into the accumulators. ``mean`` and ``std`` read the
        accumulated statistics; ``std`` is the population std ``sqrt(s / k)``.
    """
    import tensorflow as tf
    size = data.shape[1].value
    first = tf.Variable(True, dtype=tf.bool, trainable=False)
    m = tf.Variable([0.] * size, dtype=tf.float32, trainable=False)
    s = tf.Variable([0.] * size, dtype=tf.float32, trainable=False)
    k = tf.Variable(0, dtype=tf.float32, trainable=False)
    d = tf.cast(tf.shape(data)[0], tf.float32)
    k_new = k + d
    data_mean = tf.reduce_mean(data, axis=0)
    # On the first batch, seed the running mean with the batch mean.
    m_val = tf.cond(first, lambda: data_mean, lambda: m)
    m_new = m_val + d * (data_mean - m_val) / k_new
    # Centre on m_val rather than the raw variable m. This equals m except on
    # the first batch, where m is the zero init and centring on it loses
    # precision. abs() only guards against tiny negative rounding.
    s_new = s + tf.abs(tf.reduce_sum((data - m_val) * (data - m_new), axis=0))
    # Every read of the old state must finish before any variable is
    # overwritten. Otherwise tf.assign(m, ...) is unordered with respect to
    # the subtraction above and could race with it.
    with tf.control_dependencies([m_new, s_new, k_new]):
        assigns = [tf.assign(m, m_new), tf.assign(s, s_new), tf.assign(k, k_new)]
    with tf.control_dependencies(assigns):
        update = tf.group([tf.assign(first, False)])
    mean = tf.identity(m)
    std = tf.sqrt(s / k)
    return (mean, std), update

def feature_ranking(record_file_pattern, x_size, y_size):
    """Rank input features by their mean absolute correlation with the outputs.

    Loads every record matching ``record_file_pattern`` into memory. It then
    computes the Pearson correlation of each input column with each output
    column and scores each input by the mean absolute correlation over outputs.

    Args:
        record_file_pattern: glob pattern of TFRecord files.
        x_size: number of input values per example.
        y_size: number of output values per example.

    Returns:
        Array of input feature indices, ordered from highest to lowest score.
        Ties keep ascending index order.
    """
    import tensorflow as tf
    with tf.Graph().as_default(), tf.Session() as sess:
        dataset = tf.data.Dataset.list_files(record_file_pattern)
        dataset = dataset.apply(tf.data.experimental.parallel_interleave(
            tf.data.TFRecordDataset, cycle_length=os.cpu_count(), sloppy=False))
        dataset = dataset.map(lambda example_proto: parse_example(example_proto, x_size, y_size),
                              num_parallel_calls=os.cpu_count())
        dataset = dataset.batch(10000)
        dataset = dataset.prefetch(1)
        it = dataset.make_one_shot_iterator()
        x, y = it.get_next()
        xs, ys = [], []
        try:
            while True:
                x_val, y_val = sess.run((x, y))
                xs.append(x_val)
                ys.append(y_val)
        except tf.errors.OutOfRangeError:
            pass
        xs = np.concatenate(xs)
        ys = np.concatenate(ys)
    # A constant column has zero variance, so np.corrcoef yields NaN (with a
    # RuntimeWarning) for its correlations. Count those as 0. Otherwise one
    # constant output would make every score NaN and the ranking arbitrary;
    # with this, constant inputs simply score 0.
    with np.errstate(divide='ignore', invalid='ignore'):
        corr = np.abs(np.corrcoef(xs.T, ys.T)[:x_size, x_size:])
    corr[np.isnan(corr)] = 0
    scores = corr.mean(1)
    return np.argsort(-scores, kind='mergesort')

def spca(x, y, d=None):
    """Supervised PCA projection using a linear kernel on the targets.

    Computes ``Q = Hx^T (y y^T) Hx`` with ``Hx`` the column-centred inputs.
    It returns the eigenvectors of ``Q`` with the largest absolute eigenvalues.

    Args:
        x: array-like of shape ``(n, p)``.
        y: array-like of shape ``(n,)`` or ``(n, q)``; 1-D targets are treated
            as a single column.
        d: number of components to keep (default: ``p``).

    Returns:
        Array of shape ``(p, min(d, p))`` whose columns are the projection directions.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    n = len(x)
    if d is None:
        d = x.shape[1]
    if y.ndim == 1:
        y = y.reshape(-1, 1)
    hx = x - x.sum(0) / n
    # Q = Hx^T y y^T Hx = b b^T with b = Hx^T y. This avoids forming the
    # n x n kernel matrix y y^T.
    b = hx.T @ y
    q = b @ b.T
    eigvals, eigvecs = np.linalg.eigh(q)
    eigtop = np.argsort(-np.abs(eigvals))[:d]
    return eigvecs[:, eigtop]

def pca(x, d=None):
    """Principal component directions of ``x``.

    Centres the columns of ``x`` and forms the sample covariance matrix. It
    returns the eigenvectors with the largest absolute eigenvalues. ``eigh`` is
    used because the covariance matrix is symmetric, so the result is real.

    Args:
        x: array-like of shape ``(n, p)``.
        d: number of components to keep (default: ``p``).

    Returns:
        Real array of shape ``(p, min(d, p))`` whose columns are the principal directions.
    """
    x = np.asarray(x)
    n = len(x)
    if d is None:
        d = x.shape[1]
    centered = x - x.mean(0)
    # max(..., 1) avoids dividing by zero for a single sample. The scale does
    # not affect the eigenvectors.
    q = (centered.T @ centered) / max(n - 1, 1)
    eigvals, eigvecs = np.linalg.eigh(q)
    eigtop = np.argsort(-np.abs(eigvals))[:d]
    return eigvecs[:, eigtop]

def parse_example(example_proto, x_size, y_size):
    """Parse one serialized ``tf.train.Example`` into its ``x`` and ``y`` float tensors.

    Both features are required fixed-length float32 vectors, of lengths
    ``x_size`` and ``y_size`` respectively.
    """
    import tensorflow as tf
    spec = {
        'x': tf.io.FixedLenFeature((x_size,), tf.float32),
        'y': tf.io.FixedLenFeature((y_size,), tf.float32),
    }
    parsed = tf.parse_single_example(example_proto, spec)
    return parsed['x'], parsed['y']

def parse_example_dummy(example_proto, x_size, grid_dims, y_size):
    """Parse one serialized example into ``(x, coords, y, weight)`` tensors.

    ``coords`` is a zero placeholder of length ``grid_dims``. ``weight`` is
    read from the optional ``'weight'`` feature and defaults to a vector of
    ones of length ``y_size``.
    """
    import tensorflow as tf
    features = {'x': tf.io.FixedLenFeature((x_size,), tf.float32),
                'y': tf.io.FixedLenFeature((y_size,), tf.float32),
                # FixedLenFeature's default must have the feature's shape, so
                # give one value per output. A scalar only fits y_size == 1.
                'weight': tf.io.FixedLenFeature((y_size,), tf.float32, default_value=[1.] * y_size)}
    parsed_features = tf.parse_single_example(example_proto, features)
    return parsed_features['x'], tf.zeros((grid_dims,), tf.float32), parsed_features['y'], parsed_features['weight']

def gating_direct_coords(preprocess_fn, gating_direct_indices, x_mean, x_scale):
    """Wrap ``preprocess_fn`` so that grid coordinates come directly from input features.

    The wrapped function first applies ``preprocess_fn``. It then takes the
    columns ``gating_direct_indices`` of ``x``, standardizes them with the
    matching entries of ``x_mean``/``x_scale``, and squashes them with a
    sigmoid. The incoming ``coords`` value is discarded.

    Returns:
        A function ``(self, x, coords) -> (x, coords)``.
    """
    def wrapped(self, x, coords):
        import tensorflow as tf
        inputs, _ = preprocess_fn(self, x, coords)
        selected = tf.gather(inputs, gating_direct_indices, axis=1)
        standardized = (selected - x_mean[gating_direct_indices]) / x_scale[gating_direct_indices]
        return inputs, tf.nn.sigmoid(standardized)
    return wrapped

def gating_network_coords(preprocess_fn, spline_points, gating_hidden_layers, gating_indices, x_mean, x_scale):
    """Wrap ``preprocess_fn`` so grid coordinates come from a learned gating network.

    The wrapped function performs these steps:

    1. Normalize the selected input columns with ``x_mean``/``x_scale``.
    2. Run them through an optional hidden stack (``self._make_module``).
    3. Apply an output layer (``self._make_layer``) with one sigmoid unit per
       spline dimension that has more than one point.
    4. Scatter those values into a coordinate tensor with one column per entry
       of ``spline_points``; the remaining columns are zero.

    Args:
        preprocess_fn: inner function ``(self, x, coords) -> (x, coords)``.
        spline_points: number of spline points per grid dimension.
        gating_hidden_layers: hidden layer sizes. If empty, the output layer is
            applied directly to the normalized inputs.
        gating_indices: input columns fed to the gating network. Empty or
            ``None`` means all inputs.
        x_mean, x_scale: per-input normalization arrays.

    Returns:
        A function ``(self, x, coords) -> (x, coords)``.
    """
    def new_preprocess_fn(self, x, coords):
        import tensorflow as tf
        x, coords = preprocess_fn(self, x, coords)
        if gating_indices is not None and len(gating_indices) > 0:
            x_gate = tf.gather(x, gating_indices, axis=1)
            x_gate_mean = x_mean[gating_indices]
            x_gate_scale = x_scale[gating_indices]
        else:
            x_gate = x
            x_gate_mean = x_mean
            x_gate_scale = x_scale
        x_gate = (x_gate - x_gate_mean) / x_gate_scale
        # Adds axis 1; presumably the layer helpers expect (batch, 1, features).
        # TODO confirm.
        x_gate = tf.expand_dims(x_gate, 1)
        if gating_hidden_layers:
            y_gate = self._make_module(x_gate, 'GatingEncoder', gating_hidden_layers, self._simple_vars)
        else:
            # No hidden layers: the output layer acts directly on the inputs.
            y_gate = x_gate
        with tf.variable_scope('GatingEncoderOutput'):
            # Only dimensions with more than one spline point get a learned coordinate.
            spline_idx = [i for i, s in enumerate(spline_points) if s > 1]
            y_gate = self._make_layer(y_gate, len(spline_idx), self._simple_vars)
            y_gate = tf.math.sigmoid(y_gate)
        y_gate = tf.squeeze(y_gate, 1)
        # Place the learned columns at their grid positions. scatter_nd fills
        # the other dimensions with zeros.
        out_shape = tf.stack([len(spline_points), tf.shape(y_gate)[0]])
        y_gate = tf.transpose(tf.scatter_nd(tf.expand_dims(spline_idx, 1), tf.transpose(y_gate), out_shape))
        return x, y_gate
    return new_preprocess_fn

def float_feature(value):
    """Wrap a scalar or sequence of numbers in a ``tf.train.Feature`` float list."""
    import tensorflow as tf
    values = np.atleast_1d(value)
    float_list = tf.train.FloatList(value=values)
    return tf.train.Feature(float_list=float_list)

def write_records(path, xs, ys):
    """Write paired samples to a TFRecord file as ``tf.train.Example`` protos.

    Each example has float-list features ``'x'`` and ``'y'``. Parent
    directories of ``path`` are created if needed. Pairs are formed with
    ``zip``, so extra items in the longer of ``xs``/``ys`` are ignored.
    """
    # TensorFlow is not imported at module level, so import it here like the
    # other helpers do.
    import tensorflow as tf
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with tf.io.TFRecordWriter(str(path)) as writer:
        for x, y in zip(xs, ys):
            example = tf.train.Example(features=tf.train.Features(feature={
                'x': float_feature(x), 'y': float_feature(y),
            }))
            writer.write(example.SerializeToString())

def parse_args(args=None):
    """Parse command-line options into the module-level configuration globals.

    Args:
        args: argument list to parse. ``None`` lets argparse use ``sys.argv[1:]``.

    Returns:
        ``(data_dir, log_dir)`` as given by ``--data-dir`` and ``--log-dir``.

    Exits via ``parser.error`` (argparse's usual status 2) on invalid arguments,
    including a ``--dropout`` outside ``[0, 1)``.
    """
    global ENCODER_LAYERS, SPLINE_POINTS, HIDDEN_LAYERS, DECODER_LAYERS
    global FUNCTIONED_OUTPUT, OUTPUT_ACTIVATION, USE_BIAS, NORMALIZE_INPUT, NORMALIZE_OUTPUT
    global GRID_FEATURES, GATING_HIDDEN_LAYERS, GATING_FEATURES
    global LOSS_FUNCTION, BATCH_SIZE, LEARNING_RATE, MAX_STEPS, PATIENCE, EVALUATION_STEPS, FREEZE_STEPS
    global SHUFFLE_BUFFER_SIZE, DROPOUT_KEEP_PROB, REGULARIZATION, USE_WEIGHTS
    global LINEAR_EVALUATIONS, CONSTANT_EVALUATIONS
    global SEED, USE_GPU

    parser = argparse.ArgumentParser(usage='%(prog)s [-h] <argument> ...',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                     argument_default=argparse.SUPPRESS)
    dir_group = parser.add_argument_group('directories')
    dir_group.add_argument('--data-dir', help='directory with *.tfrecord files for training and testing', required=True)
    dir_group.add_argument('--log-dir', help='directory where training results are saved', required=True)
    model_group = parser.add_argument_group('model architecture')
    model_group.add_argument('--encoder-layers', help='encoder layer sizes', metavar='SIZE[,...]', type=num_list, default='none')
    model_group.add_argument('--hidden-layers', help='hidden layer sizes', metavar='SIZE[,...]', type=num_list, default='none')
    model_group.add_argument('--decoder-layers', help='decoder layer sizes', metavar='SIZE[,...]', type=num_list, default='none')
    model_group.add_argument('--spline-points', help='number of spline points', metavar='NUM_POINTS[,...]', type=num_list, default='none')
    model_group.add_argument('--grid-features', help='grid features', metavar='FEATURE[,...]', type=feature_list, default='none')
    model_group.add_argument('--gating-hidden-layers', help='gating network hidden layer sizes', metavar='SIZE[,...]', type=num_list, default='none')
    model_group.add_argument('--gating-features', help='gating network input features', metavar='FEATURE[,...]', type=feature_list, default='none')
    model_group.add_argument('--functioned-output', help='use grid-functioned output layer', type=str2bool, default=str(FUNCTIONED_OUTPUT).lower())
    model_group.add_argument('--output-activation', help='output activation function', default='none')
    model_group.add_argument('--use-bias', help='whether to use bias units', type=str2bool, default=str(USE_BIAS).lower())
    model_group.add_argument('--normalize-input', help='whether to normalize input values', type=norm_selection, default='true')
    model_group.add_argument('--normalize-output', help='whether to normalize output values', type=norm_selection, default='true')
    train_group = parser.add_argument_group('training settings')
    train_group.add_argument('--loss-function', help='loss function', default='none')
    train_group.add_argument('--batch-size', help='training batch size', type=int, default=BATCH_SIZE)
    train_group.add_argument('--learning-rate', help='learning rate', type=float, default=LEARNING_RATE)
    train_group.add_argument('--max-steps', help='maximum number of training steps', type=int, default=MAX_STEPS)
    train_group.add_argument('--patience', help='number of evaluations without improvement before stopping (negative for no limit)', type=int, default=PATIENCE)
    train_group.add_argument('--evaluation-steps', help='number of steps per evaluation', type=int, default=EVALUATION_STEPS)
    train_group.add_argument('--freeze-steps', help='number of steps per frozen graph', type=int, default=FREEZE_STEPS)
    train_group.add_argument('--shuffle-buffer-size', help='size of the data shuffling buffer', type=int, default=SHUFFLE_BUFFER_SIZE)
    train_group.add_argument('--dropout', help='dropout rate', type=float, default=round(1 - DROPOUT_KEEP_PROB, 6))
    train_group.add_argument('--regularization', help='regularization weight', type=float, default=REGULARIZATION)
    train_group.add_argument('--use-weights', help='whether to use training weights', type=str2bool, default=str(USE_WEIGHTS).lower())
    export_group = parser.add_argument_group('export settings')
    export_group.add_argument('--linear-eval', help='number of linear evaluations', metavar='NUM_EVAL[,...]',
                              type=num_list, default=str(LINEAR_EVALUATIONS).lower())
    export_group.add_argument('--constant-eval', help='number of constant evaluations', metavar='NUM_EVAL[,...]',
                              type=num_list, default=str(CONSTANT_EVALUATIONS).lower())
    other_group = parser.add_argument_group('other options')
    other_group.add_argument('--seed', help='random seed', type=int, default=SEED)
    other_group.add_argument('--use-gpu', help='whether to use the GPU', type=str2bool, default=str(USE_GPU).lower())
    parsed_args = parser.parse_args(args=args)

    # A dropout rate outside [0, 1) would give a keep probability outside (0, 1].
    if not 0 <= parsed_args.dropout < 1:
        parser.error('--dropout must be in the range [0, 1).')

    data_dir = parsed_args.data_dir
    log_dir = parsed_args.log_dir
    ENCODER_LAYERS = parsed_args.encoder_layers or []
    HIDDEN_LAYERS = parsed_args.hidden_layers or []
    DECODER_LAYERS = parsed_args.decoder_layers or []
    SPLINE_POINTS = parsed_args.spline_points or []
    GRID_FEATURES = parsed_args.grid_features or []
    GATING_HIDDEN_LAYERS = parsed_args.gating_hidden_layers or []
    GATING_FEATURES = parsed_args.gating_features or []
    FUNCTIONED_OUTPUT = parsed_args.functioned_output
    OUTPUT_ACTIVATION = parsed_args.output_activation
    USE_BIAS = parsed_args.use_bias
    NORMALIZE_INPUT = parsed_args.normalize_input
    NORMALIZE_OUTPUT = parsed_args.normalize_output
    LOSS_FUNCTION = parsed_args.loss_function
    BATCH_SIZE = parsed_args.batch_size
    LEARNING_RATE = parsed_args.learning_rate
    MAX_STEPS = parsed_args.max_steps
    PATIENCE = parsed_args.patience
    EVALUATION_STEPS = parsed_args.evaluation_steps
    FREEZE_STEPS = parsed_args.freeze_steps
    SHUFFLE_BUFFER_SIZE = parsed_args.shuffle_buffer_size
    DROPOUT_KEEP_PROB = 1 - parsed_args.dropout
    REGULARIZATION = parsed_args.regularization
    USE_WEIGHTS = parsed_args.use_weights
    LINEAR_EVALUATIONS = parsed_args.linear_eval
    CONSTANT_EVALUATIONS = parsed_args.constant_eval
    SEED = parsed_args.seed
    USE_GPU = parsed_args.use_gpu

    return data_dir, log_dir

def num_list(v):
    """Parse a comma-separated list of integers and inclusive ``a-b`` ranges.

    ``'none'`` (any case) or an empty string yields ``[]``. Empty items are
    skipped. For example, ``'1,3-5'`` gives ``[1, 3, 4, 5]``.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not a string or an item is not
            an integer or a two-ended integer range.
    """
    try:
        v = v.strip().lower()
        if not v or v == 'none':
            return []
        nums = []
        for s in v.split(','):
            s = s.strip()
            if not s:
                continue
            if '-' in s:
                start, end = map(int, s.split('-'))
                nums.extend(range(start, end + 1))
            else:
                nums.append(int(s))
        return nums
    except (AttributeError, ValueError) as err:
        raise argparse.ArgumentTypeError('expected a comma-separated list of numbers.') from err

def feature_list(v):
    """Parse a comma-separated feature selection.

    Items that parse as integers, or as inclusive ``a-b`` integer ranges,
    become indices. Any other item is kept as a lower-cased feature name.
    Empty items are skipped, and ``'none'`` or an empty string yields ``[]``.
    """
    text = v.strip().lower()
    if text in ('', 'none'):
        return []
    selection = []
    for item in (part.strip() for part in text.split(',')):
        if item == '':
            continue
        try:
            if '-' not in item:
                selection.append(int(item))
                continue
            first, last = map(int, item.split('-'))
        except ValueError:
            selection.append(item)
            continue
        selection.extend(range(first, last + 1))
    return selection

def norm_selection(v):
    """Parse a normalization option as a boolean, or else as a feature selection.

    Returns ``True``/``False`` for a recognised boolean word. Otherwise it
    returns the list produced by ``feature_list``.
    """
    try:
        selection = str2bool(v)
    except argparse.ArgumentTypeError:
        selection = feature_list(v)
    return selection

def str2bool(v):
    """Interpret a yes/no style value as a bool (case- and whitespace-insensitive).

    Raises:
        argparse.ArgumentTypeError: if ``str(v)`` is not a recognised boolean word.
    """
    word = str(v).strip().lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if word in truthy:
        return True
    if word in falsy:
        return False
    raise argparse.ArgumentTypeError('boolean value expected.')

if __name__ == '__main__':
    # Populate the configuration globals from the command line, then train.
    data_dir, log_dir = parse_args()
    main(data_dir, log_dir)
