import os
import sys
import argparse
from pathlib import Path

# Parse arguments before importing libraries
def parse_args(args=None):
    """Parse the command line and return the script configuration.

    Args:
        args: Optional list of argument strings; ``None`` lets argparse read
            ``sys.argv``.

    Returns:
        A tuple ``(data, models, out, per_data, figure)`` where ``data`` is a
        list of ``(name, path)`` pairs, ``models`` is a list of
        ``(name, path, mode)`` triples, ``out`` is the output directory and
        the last two entries are booleans.

    Raises:
        ValueError: If two data entries or two model entries share a name.
    """
    parser = argparse.ArgumentParser(argument_default=None)
    parser.add_argument('--data', nargs='*', type=named_path,
                        metavar='[NAME:]PATH', help='data directories')
    parser.add_argument('--models', nargs='*', type=named_path_mode('CUBIC'),
                        metavar='[NAME:]PATH[?MODE]', help='model files')
    parser.add_argument('--out', type=Path, required=True, help='output directory')
    parser.add_argument('--per-data', type=str2bool, default='false',
                        help='write results for each data file')
    parser.add_argument('--figure', type=str2bool, default='true',
                        help='make error figure')
    namespace = parser.parse_args(args=args)

    # Names are used as table keys later on, so duplicates are rejected here.
    data = namespace.data or []
    unique_data_names = {entry[0] for entry in data}
    if len(unique_data_names) != len(data):
        raise ValueError('data file names must be unique.')

    models = namespace.models or []
    unique_model_names = {entry[0] for entry in models}
    if len(unique_model_names) != len(models):
        raise ValueError('model names must be unique.')

    return data, models, namespace.out, namespace.per_data, namespace.figure

def named_path(v):
    """Parse a ``[NAME:]PATH`` command-line value into a ``(name, path)`` pair.

    The value is split at the first ``:``; note that this also applies to
    values such as ``C:\\dir``, whose part before the colon becomes the name.
    When no name is given (or it is empty) the stem of the path is used.

    Args:
        v: The raw argument string.

    Returns:
        A tuple ``(name, path)`` with ``name`` a non-empty string and ``path``
        a ``pathlib.Path``.

    Raises:
        argparse.ArgumentTypeError: If the path or the resulting name is empty,
            or the value cannot be parsed at all.
    """
    try:
        name, path = v.split(':', 1) if ':' in v else ('', v)
        if not path:
            raise ValueError('empty path')
        path = Path(path)
        # Fall back to the file name without suffix when no name was given.
        name = name or path.stem
        if not name:
            raise ValueError('empty name')
        return name, path
    except Exception as e:
        # Only ordinary errors are converted; KeyboardInterrupt and SystemExit
        # must propagate instead of being reported as a malformed argument.
        raise argparse.ArgumentTypeError('value must have the format [NAME:]PATH.') from e

def named_path_mode(default_mode):
    """Build an argparse ``type`` callable for ``[NAME:]PATH[?MODE]`` values.

    Args:
        default_mode: Mode string used when the value has no ``?MODE`` suffix.

    Returns:
        A function mapping an argument string to ``(name, path, mode)``, where
        ``name`` and ``path`` come from ``named_path`` and ``mode`` is stripped
        and upper-cased. The function raises ``argparse.ArgumentTypeError`` for
        values that cannot be parsed.
    """
    def f(v):
        try:
            # Everything after the first '?' is the mode.
            name_path, mode = v.split('?', 1) if '?' in v else (v, default_mode)
            if not name_path:
                raise ValueError('empty name/path')
            name, path = named_path(name_path)
            return name, path, mode.strip().upper()
        except Exception as e:
            # Only ordinary errors are converted; KeyboardInterrupt and
            # SystemExit must propagate untouched.
            raise argparse.ArgumentTypeError('value must have the format [NAME:]PATH[?MODE].') from e
    return f

def str2bool(v):
    """Convert a textual yes/no style value to a bool.

    The value is converted with ``str``, stripped and lower-cased before it is
    compared against the accepted spellings.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    token = str(v).strip().lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected.')

# When run as a script, parse the command line right away so that the
# arguments are handled before the library imports below are executed.
if __name__ == '__main__':
    ARGS = parse_args()

# These variables are set before TensorFlow is imported.
# NOTE(review): presumably '-1' hides all CUDA devices (CPU-only execution) and
# '3' silences TensorFlow's native log output - confirm against the TF docs.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Setup plotting
# Exactly one seaborn context is active; the commented-out lines are the
# alternative contexts kept for quick switching.
#sns.set(color_codes=True, context="talk")
sns.set(color_codes=True, context="notebook")
#sns.set(color_codes=True, context="paper")
# Figure file format
plt.rc("savefig", format="pdf")
# Use Latex to render text in plots
# NOTE(review): presumably this needs a working LaTeX installation at run time - confirm.
plt.rc("text", usetex=True)
plt.rc("ps", usedistiller="xpdf")
plt.rc("font", family="serif")

def make_empty_dataframe(columns, dtypes, index=None):
    """Create a data frame with the given columns and no data.

    Args:
        columns: Column names.
        dtypes: One dtype per column, paired with ``columns`` in order.
        index: Optional index for the frame; ``None`` gives a frame without rows.

    Returns:
        A ``pandas.DataFrame`` with one column per name. When ``columns`` is
        empty, a frame with only the requested index is returned.
    """
    series = [pd.Series(name=col, dtype=dt, index=index) for col, dt in zip(columns, dtypes)]
    if not series:
        # pd.concat refuses an empty list, so build the column-less frame directly.
        return pd.DataFrame(index=index)
    return pd.concat(series, axis=1)

# Model analysis

def num_nn_params(input_size, output_size, hidden_layers, use_bias=True):
    """Count the parameters of a fully connected network.

    Args:
        input_size: Number of network inputs.
        output_size: Number of network outputs.
        hidden_layers: Sizes of the hidden layers, in order.
        use_bias: Whether each layer also has one bias value per unit.

    Returns:
        The total number of weights (plus biases when ``use_bias`` is true).
    """
    widths = [input_size, *hidden_layers, output_size]
    total = 0
    # Walk over consecutive layer pairs: (inputs of the layer, units of the layer).
    for fan_in, fan_out in zip(widths, widths[1:]):
        total += fan_in * fan_out
        if use_bias:
            total += fan_out
    return total

def num_gfnn_params(spline_points, input_size, output_size, hidden_layers, use_bias=True):
    """Count the parameters of a network replicated over a grid of spline points.

    Args:
        spline_points: Number of spline points per grid dimension.
        input_size: Number of network inputs.
        output_size: Number of network outputs.
        hidden_layers: Sizes of the hidden layers, in order.
        use_bias: Whether each layer has bias values.

    Returns:
        The parameter count of one network multiplied by the product of
        ``spline_points``.
    """
    per_network = num_nn_params(input_size, output_size, hidden_layers, use_bias)
    # Total number of grid points is the product over all dimensions.
    grid_points = 1
    for count in spline_points:
        grid_points = grid_points * count
    return per_network * grid_points

def flops_spline_old(spline_points, nn_params):
    """Estimate interpolation FLOPs for cubic mode (earlier formula).

    Args:
        spline_points: Number of spline points per grid dimension.
        nn_params: Number of parameters of a single network.

    Returns:
        The estimated number of floating point operations.
    """
    # Bigger dimensions are reduced to 4, dimensions are sorted by size
    sizes = sorted((min(p, 4) for p in spline_points), reverse=True)
    weight_flops = 0
    per_param_flops = 0
    for pos, size in enumerate(sizes):
        # Dimensions with a single point contribute nothing
        if size <= 1:
            continue
        # Weight computation fixed cost
        weight_flops += 17
        # Dimensions with size 2 or 3 aggregate weights
        if size < 4:
            weight_flops += 4 * size + 3 * size
        # Weight multiplication and reduction over this and the remaining dimensions
        block = np.prod(sizes[pos:])
        per_param_flops += block + (block // size) * (size - 1)
    return weight_flops + per_param_flops * nn_params

def flops_spline(spline_points, nn_params):
    """Estimate interpolation FLOPs for cubic mode.

    Args:
        spline_points: Number of spline points per grid dimension.
        nn_params: Number of parameters of a single network.

    Returns:
        The estimated number of floating point operations.
    """
    # Bigger dimensions are reduced to 4, dimensions are sorted by size
    sizes = sorted((min(p, 4) for p in spline_points), reverse=True)
    total = 0
    for pos, size in enumerate(sizes):
        if size > 1:
            # Weight computation fixed cost
            weight_cost = 17
            # Dimensions with size 2 or 3 aggregate interpolation weights
            if size < 4:
                weight_cost += 4 * size + 3 * size
            total += weight_cost
        # NOTE(review): dimensions of size 1 still add nn_params flops below,
        # unlike flops_spline_old - confirm that this is intended.
        # Flops to combine parameterisations across this dimension
        combine_cost = (size + (size - 1)) * nn_params
        # Number of interpolated parameterisations for this dimension
        remaining = np.prod(sizes[pos + 1:], dtype=int)
        total += combine_cost * remaining
    return total

def flops_linear(spline_points, nn_params):
    """Estimate interpolation FLOPs for linear mode.

    Only dimensions with more than one spline point are counted.

    Args:
        spline_points: Number of spline points per grid dimension.
        nn_params: Number of parameters of a single network.

    Returns:
        ``3 * d + 2**d * nn_params`` for ``d`` counted dimensions.
    """
    active_dims = len([d for d in spline_points if d > 1])
    corner_count = 2 ** active_dims
    return 3 * active_dims + corner_count * nn_params

def flops_constant(spline_points):
    """Estimate interpolation FLOPs for constant mode.

    Args:
        spline_points: Number of spline points per grid dimension.

    Returns:
        One operation per dimension that has more than one spline point.
    """
    return len([d for d in spline_points if d > 1])

def flops_nn(input_size, output_size, hidden_layers, use_bias=True):
    """Estimate the FLOPs of one forward pass through a fully connected network.

    Each layer is counted as two operations per weight, plus one operation per
    unit when biases are used.

    Args:
        input_size: Number of network inputs.
        output_size: Number of network outputs.
        hidden_layers: Sizes of the hidden layers, in order.
        use_bias: Whether each layer has bias values.

    Returns:
        The estimated number of floating point operations.
    """
    widths = [input_size, *hidden_layers, output_size]
    total = 0
    for fan_in, fan_out in zip(widths, widths[1:]):
        total += 2 * (fan_in * fan_out)
        if use_bias:
            total += fan_out
    return total

def get_model_metadata(model_files):
    """Collect size and cost metadata for every model into one table.

    Each model file is read as a serialized ``GraphDef`` and imported into its
    own TensorFlow graph and session. Its ``Meta/*`` tensors are evaluated and
    FLOP estimates are derived from them.

    Args:
        model_files: Sequence of ``(name, path, mode)`` triples, with ``mode``
            one of ``'CUBIC'``, ``'LINEAR'`` or ``'CONSTANT'``.

    Returns:
        A ``pandas.DataFrame`` indexed by model name (index name ``'model'``)
        with one row of metadata and FLOP estimates per model.

    Raises:
        ValueError: If a model has an unknown mode.
    """
    # Validate all modes before loading anything: the mode is part of a tensor
    # name requested from the graph, so an unknown mode would otherwise fail
    # inside the graph import before the explicit check could report it.
    for _, _, model_mode in model_files:
        if model_mode not in ('CUBIC', 'LINEAR', 'CONSTANT'):
            raise ValueError(f'unknown model mode "{model_mode}"')
    model_names = [model_name for model_name, _, _ in model_files]
    df = make_empty_dataframe(*zip(('input_size', int),
                                   ('output_size', int),
                                   ('hidden_layers', object),
                                   ('use_bias', bool),
                                   ('parameters_nn', int),
                                   ('spline_points', object),
                                   ('parameters_spline', int),
                                   ('mode', object),
                                   ('grid_evaluations', object),
                                   ('parameters_grid', object),
                                   ('flops_interp', int),
                                   ('flops_nn', int),
                                   ('flops_total', int)),
                              index=model_names)
    df.index.name = 'model'
    for model_name, model_file, model_mode in model_files:
        with tf.Graph().as_default(), tf.Session() as sess:
            # Load graph
            gd = tf.GraphDef()
            with Path(model_file).open('rb') as f:
                gd.MergeFromString(f.read())
            metadata = list(tf.import_graph_def(gd, name='', return_elements=[
                'Meta/NumParameters:0', 'Meta/SplinePoints:0', f'Meta/GridEvaluations{model_mode.title()}:0',
                'Meta/HiddenLayers:0', 'Meta/UseBias:0', 'Meta/InputSize:0', 'Meta/OutputSize:0'
            ]))
            # Get metadata
            params_spline, spline_points, grid_evals, hidden_layers, use_bias, input_size, output_size = sess.run(metadata)
            # The stored parameter count covers every spline point; divide to
            # get the size of a single network.
            params_nn = params_spline // np.prod(spline_points)
            params_grid = params_nn * np.prod(grid_evals)
            # Flops estimation
            if model_mode == 'CUBIC':
                flops_interp = flops_spline(spline_points, params_nn)
            elif model_mode == 'LINEAR':
                flops_interp = flops_linear(spline_points, params_nn)
            elif model_mode == 'CONSTANT':
                flops_interp = flops_constant(spline_points)
            else:
                # Kept as a safeguard; modes were already validated above.
                raise ValueError(f'unknown model mode "{model_mode}"')
            flops_net = flops_nn(input_size, output_size, hidden_layers, use_bias)
            flops_total = flops_interp + flops_net
            # List-like metadata is stored as space-separated text.
            df.loc[model_name, :] = [input_size, output_size, ' '.join(map(str, hidden_layers)), use_bias, params_nn,
                                     ' '.join(map(str, spline_points)), params_spline,
                                     model_mode, ' '.join(map(str, grid_evals)), params_grid,
                                     flops_interp, flops_net, flops_total]
    return df

def get_error_statistics(errors):
    """Summarise error values in a single-row data frame.

    Args:
        errors: Error values; every statistic is reduced along axis 0.

    Returns:
        A ``pandas.DataFrame`` with one row and the columns ``mean``, ``std``,
        ``min``, ``max``, ``p25``, ``p50`` and ``p75``.
    """
    reducers = [('mean', np.mean), ('std', np.std), ('min', np.min), ('max', np.max)]
    quantiles = [('p25', 25), ('p50', 50), ('p75', 75)]
    names = [name for name, _ in reducers] + [name for name, _ in quantiles]
    row = [reduce_fn(errors, axis=0) for _, reduce_fn in reducers]
    row += [np.percentile(errors, q, axis=0) for _, q in quantiles]
    return pd.DataFrame([row], columns=names)

def read_records(data_file):
    """Read every example from the ``*.tfrecord`` files in a directory.

    Args:
        data_file: Directory that contains the ``*.tfrecord`` files.

    Returns:
        A tuple ``(xs, ys)`` of arrays, each the concatenation of all batches
        read from the data set.
    """
    pattern = str(Path(data_file, '*.tfrecord'))
    with tf.Graph().as_default(), tf.Session() as sess:
        files = tf.data.Dataset.list_files(pattern)
        records = files.flat_map(tf.data.TFRecordDataset)
        # Each record is parsed with 2 'x' values and 1 'y' value.
        examples = records.map(lambda example_proto: parse_example(example_proto, 2, 1),
                               num_parallel_calls=os.cpu_count())
        batches = examples.batch(10000).prefetch(1)
        x, y = batches.make_one_shot_iterator().get_next()
        sess.run(tf.global_variables_initializer())
        fetched = []
        # Pull batches until the iterator signals that the data set is exhausted.
        try:
            while True:
                fetched.append(sess.run([x, y]))
        except tf.errors.OutOfRangeError:
            pass
        xs, ys = map(np.concatenate, zip(*fetched))
        return xs, ys

def parse_example(example_proto, x_size, y_size):
    """Parse one serialized example into its ``x`` and ``y`` features.

    Args:
        example_proto: The serialized example to parse.
        x_size: Number of float32 values in the ``x`` feature.
        y_size: Number of float32 values in the ``y`` feature.

    Returns:
        A tuple ``(x, y)`` with the parsed features.
    """
    import tensorflow as tf
    feature_spec = {key: tf.FixedLenFeature((size,), tf.float32)
                    for key, size in (('x', x_size), ('y', y_size))}
    parsed = tf.parse_single_example(example_proto, feature_spec)
    return parsed['x'], parsed['y']

def run_model(model_file, model_mode, x):
    """Evaluate a stored model on every example in ``x``.

    Args:
        model_file: Path of the file holding the serialized ``GraphDef``.
        model_mode: Mode name; its title-cased form selects the output tensor
            ``Output<Mode>:0``.
        x: Iterable of examples, each fed separately to the ``Input:0`` tensor.

    Returns:
        An array with the per-example outputs stacked along a new first axis.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
        graph_def = tf.GraphDef()
        with open(model_file, 'rb') as fh:
            graph_def.ParseFromString(fh.read())
        output_name = f'Output{model_mode.title()}:0'
        input_tensor, output_tensor = tf.import_graph_def(
            graph_def, name='', return_elements=['Input:0', output_name])
        predictions = []
        # The graph is run once per example.
        for example in x:
            predictions.append(sess.run(output_tensor, feed_dict={input_tensor: example}))
    return np.stack(predictions)

def main(data_files, model_files, out_dir, per_data, figure):
    """Evaluate all models on all data sets and write the results.

    The absolute error of every model on every example is collected, summary
    statistics are written to ``results.xlsx`` in ``out_dir`` and, optionally,
    one error figure per data set is saved next to it.

    Args:
        data_files: List of ``(name, path)`` pairs of data directories.
        model_files: List of ``(name, path, mode)`` triples of model files.
        out_dir: Output directory as a ``pathlib.Path``; created when missing.
        per_data: Whether to add one statistics sheet per data set.
        figure: Whether to save an error figure per data set.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    data_names = [data_name for data_name, _ in data_files]
    model_names = [model_name for model_name, _, _ in model_files]
    data_df = make_empty_dataframe(['num_examples'], [int], index=data_names)
    data_df.index.name = 'Data'
    model_df = get_model_metadata(model_files)
    data_frame_errors = []
    for data_name, data_file in data_files:
        print(f'Reading data {data_name}...')
        x, y = read_records(data_file)
        num_examples = len(x)
        data_df.loc[data_name, :] = num_examples
        if not model_files:
            # Nothing to evaluate; a per-model error table cannot be built
            # without any model columns.
            continue
        model_errors = make_empty_dataframe(model_names, [float] * len(model_files), index=range(num_examples))
        for model_name, model_file, model_mode in model_files:
            print(f'Evaluating {model_name} on {data_name}...')
            pred = run_model(model_file, model_mode, x)
            model_errors[model_name] = np.abs(y - pred)
        # Long format: one row per (model, example) pair, tagged with the data set.
        long_errors = model_errors.melt(var_name='Model', value_name='Error')
        long_errors.insert(0, 'Data', data_name)
        data_frame_errors.append(long_errors)
    # Save results
    print('Saving results...')
    error_df = None
    with pd.ExcelWriter(Path(out_dir, 'results.xlsx')) as writer:
        if model_files:
            model_df.to_excel(writer, sheet_name='Models')
        if data_files:
            data_df.to_excel(writer, sheet_name='Data')
        if data_files and model_files:
            error_df = pd.concat(data_frame_errors, ignore_index=True)
            overall_df = (error_df.groupby('Model')['Error']
                          .apply(get_error_statistics)
                          .reset_index(level=-1, drop=True)
                          .loc[model_names, :])
            overall_df.to_excel(writer, sheet_name='Overall')
            if per_data:
                data_stats = (error_df.groupby(['Data', 'Model'])['Error']
                              .apply(get_error_statistics)
                              .reset_index(level=-1, drop=True))
                for data_name, _ in data_files:
                    # Rows follow the order in which the models were given.
                    stats_df = data_stats.loc[data_name].loc[model_names]
                    stats_df.to_excel(writer, sheet_name=data_name)
            error_df.to_excel(writer, index=False, sheet_name='Error')
    if figure and error_df is not None:
        for data_name, data_error in error_df.groupby('Data'):
            # Use a fresh figure per data set so that the plots of different
            # data sets are not drawn on top of each other, and close it after
            # saving so figures do not accumulate.
            fig, ax = plt.subplots()
            sns.boxenplot(x='Model', y='Error', order=model_names, data=data_error, ax=ax)
            fig.savefig(Path(out_dir, f'error_{data_name}'))
            plt.close(fig)
    print('Done.')

# Script entry point: ARGS holds the tuple returned by parse_args() in the
# first `if __name__ == '__main__'` block of this file.
if __name__ == '__main__':
    main(*ARGS)
