import os
# Hide all GPUs from TensorFlow; this must run before tensorflow is imported.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import itertools
from pathlib import Path
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns

def read_data(data_dir):
    """Load the ``train`` and ``test`` splits stored below *data_dir*.

    Each split lives in its own subdirectory and is loaded with
    :func:`read_records`.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``.
    """
    root = Path(data_dir)
    x_train, y_train = read_records(root / 'train')
    x_test, y_test = read_records(root / 'test')
    train = (x_train, y_train)
    test = (x_test, y_test)
    return train, test

def read_records(data_dir, x_size=2, y_size=1):
    """Load every example from the ``*.tfrecord`` files in *data_dir*.

    Builds a throw-away TF1 graph/session whose input pipeline parses each
    record with :func:`parse_example` and drains it in batches of 10000.

    Args:
        data_dir: Directory holding the ``*.tfrecord`` files.
        x_size: Length of each example's ``x`` feature (default 2).
        y_size: Length of each example's ``y`` feature (default 1).

    Returns:
        ``(xs, ys)`` NumPy arrays with one row per example, stacked along
        axis 0. When no examples are read, both arrays are empty float32
        arrays of shape ``(0, x_size)`` and ``(0, y_size)``.
    """
    pattern = str(Path(data_dir, '*.tfrecord'))
    x_batches, y_batches = [], []
    with tf.Graph().as_default(), tf.Session() as sess:
        dataset = (tf.data.Dataset.list_files(pattern)
                   .flat_map(tf.data.TFRecordDataset)
                   .map(lambda example_proto: parse_example(example_proto, x_size, y_size),
                        num_parallel_calls=os.cpu_count())
                   .batch(10000)
                   .prefetch(1))
        x, y = dataset.make_one_shot_iterator().get_next()
        # The pipeline holds no variables, so no initializer is needed.
        # Drain the one-shot iterator until it reports exhaustion.
        try:
            while True:
                x_batch, y_batch = sess.run([x, y])
                x_batches.append(x_batch)
                y_batches.append(y_batch)
        except tf.errors.OutOfRangeError:
            pass
    if not x_batches:
        # np.concatenate([]) raises, so return correctly shaped empty arrays
        # (features are parsed as float32).
        return (np.empty((0, x_size), np.float32),
                np.empty((0, y_size), np.float32))
    return np.concatenate(x_batches), np.concatenate(y_batches)

def parse_example(example_proto, x_size, y_size):
    """Decode one serialized example into its ``x`` and ``y`` float32 tensors.

    Args:
        example_proto: Scalar string tensor holding a serialized example.
        x_size: Fixed length of the ``x`` feature.
        y_size: Fixed length of the ``y`` feature.

    Returns:
        Tuple ``(x, y)`` of float32 tensors with shapes ``(x_size,)`` and
        ``(y_size,)``.
    """
    import tensorflow as tf
    spec = {
        'x': tf.FixedLenFeature((x_size,), tf.float32),
        'y': tf.FixedLenFeature((y_size,), tf.float32),
    }
    decoded = tf.parse_single_example(example_proto, spec)
    x_tensor = decoded['x']
    y_tensor = decoded['y']
    return x_tensor, y_tensor

def rosenbrock(x, y, a=1, b=100):
    """Evaluate the two-dimensional Rosenbrock function elementwise.

        f(x, y) = (a - x)**2 + b * (y - x**2)**2

    Args:
        x, y: Scalars or NumPy arrays that broadcast against each other
            (e.g. the two halves of ``np.ogrid``).
        a: Location parameter; the global minimum 0 is at ``(a, a**2)``.
        b: Weight of the curved-valley term. It was previously declared but
            ignored in favor of a hard-coded 100.

    Returns:
        The function value(s), with the broadcast shape of ``x`` and ``y``.
    """
    return np.square(a - x) + b * np.square(y - np.square(x))

def ackley(x, y):
    """Evaluate the two-dimensional Ackley function elementwise.

        f(x, y) = -20 exp(-0.2 sqrt((x^2 + y^2) / 2))
                  - exp((cos 2*pi*x + cos 2*pi*y) / 2) + e + 20

    Args:
        x, y: Scalars or NumPy arrays that broadcast against each other.

    Returns:
        The function value(s); the global minimum is 0 at the origin.
    """
    two_pi = 2 * np.pi
    mean_square = 0.5 * (np.square(x) + np.square(y))
    mean_cosine = 0.5 * (np.cos(two_pi * x) + np.cos(two_pi * y))
    envelope = -20 * np.exp(-0.2 * np.sqrt(mean_square))
    ripple = np.exp(mean_cosine)
    return envelope - ripple + np.e + 20

def plot_problem(x, y, func, data_dir=None, ax=None):
    """Draw the surface ``z = func(x, y)`` on a 3D axes, optionally with samples.

    Args:
        x, y: Coordinate arrays that broadcast to a 2D grid (e.g. from
            ``np.ogrid``); passed to ``func`` and ``plot_surface``.
        func: Vectorized callable ``func(x, y) -> z``.
        data_dir: Optional directory with ``train``/``test`` subdirectories
            (see :func:`read_data`). When given, the samples are scattered
            over the surface: train in palette color 2, test in palette
            color 1.
        ax: Existing 3D axes to draw into. When omitted, a new figure with a
            3D axes is created, its margins adjusted, and the figure shown.

    Returns:
        The 3D axes that was drawn on.
    """
    z = func(x, y)
    fig = None
    if ax is None:
        sns.set_context('notebook', font_scale=1.4)
        fig = plt.figure(figsize=(6, 4))
        ax = fig.add_subplot(111, projection='3d')
    # Give the three panes the axes' current background color, then make the
    # axes background itself transparent. ``ax.w_xaxis`` & co. were
    # deprecated aliases of ``ax.xaxis`` & co. and were removed in
    # Matplotlib 3.8.
    bgcolor = ax.get_facecolor()
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(bgcolor)
    ax.set_facecolor('none')
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
    ax.set_zlabel('$z$')
    ax.xaxis.labelpad = 12.0
    ax.yaxis.labelpad = 12.0
    pal = sns.color_palette()
    if data_dir:
        _scatter_samples(ax, data_dir, x, y, z, pal)
    ax.plot_surface(x, y, z, rstride=1, cstride=1, linewidth=0, alpha=0.75, color=pal[0])
    ax.view_init(elev=35, azim=-60)
    if fig is not None:
        # Fix the layout before showing, so the shown figure already has the
        # final margins.
        fig.subplots_adjust(left=-0.05, right=0.99, bottom=0.03, top=1.04)
        fig.show()
    return ax


def _scatter_samples(ax, data_dir, x, y, z, pal):
    """Scatter the train/test samples read from *data_dir* onto *ax*.

    Also adds fully transparent points (RGBA all zero) at the 8 corners of
    the bounding box of ``(x, y, z)``. Presumably these keep the axis
    autoscaling covering the whole surface; confirm before removing them.
    """
    (x_train, y_train), (x_test, y_test) = read_data(data_dir)
    p_train = np.concatenate([x_train, y_train], axis=1)
    p_test = np.concatenate([x_test, y_test], axis=1)
    p_bounds = list(itertools.product(*((c.min(), c.max()) for c in (x, y, z))))
    n_train, n_bounds = len(p_train), len(p_bounds)
    points = np.concatenate([p_train, p_test, p_bounds], axis=0)
    colors = np.zeros((len(points), 4), np.float32)
    colors[:n_train, :3] = pal[2]
    colors[n_train:-n_bounds, :3] = pal[1]
    # Opaque samples; the trailing bounding-box points keep alpha 0.
    colors[:-n_bounds, 3] = 1.0
    ax.scatter3D(*points.T, s=5, c=colors)

def main():
    """Render every benchmark surface as a PDF under ``fig/``.

    Each problem is rendered twice. The first pass draws the bare surface;
    the second also draws the train/test samples read from ``data/<name>``
    and fixes the z limits. Output names are ``fig/<name>`` and
    ``fig/<name>_data``.
    """
    # Setup plotting
    sns.set(context='notebook',
            style='darkgrid',
            palette='deep',
            font='serif',
            font_scale=1.4,
            color_codes=True)
    # Figure file format
    plt.rc('savefig', format='pdf')
    # Use LaTeX to render text in plots
    plt.rc('text', usetex=True)
    plt.rc('ps', usedistiller='xpdf')
    plt.rc('font', family='serif')
    plt.close('all')

    # A complex step in np.ogrid means "number of points", not a stride.
    res = 50j
    # Each entry holds: name, function, x range, y range, and the
    # (start, stop, count) specs for the x/y/z ticks. The z tick span also
    # serves as the z limits of the plots that include data.
    problems = [
        ('rosenbrock', rosenbrock, (-2, 2), (-1, 3), (-2, 2, 5), (-1, 3, 5), (0, 2500, 6)),
        ('ackley', ackley, (-5, 5), (-5, 5), (-4, 4, 5), (-4, 4, 5), (0, 15, 4)),
        ('ackleysmall', ackley, (-1, 1), (-1, 1), (-1, 1, 5), (-1, 1, 5), (0, 5, 6)),
    ]
    # savefig does not create missing directories.
    os.makedirs('fig', exist_ok=True)
    for with_data in (False, True):
        for name, func, (x_lo, x_hi), (y_lo, y_hi), xticks, yticks, zticks in problems:
            x, y = np.ogrid[x_lo:x_hi:res, y_lo:y_hi:res]
            if with_data:
                ax = plot_problem(x, y, func, 'data/' + name)
                ax.set_zlim(zticks[0], zticks[1])
                out_name = name + '_data'
            else:
                ax = plot_problem(x, y, func)
                out_name = name
            ax.set_xticks(np.linspace(*xticks))
            ax.set_yticks(np.linspace(*yticks))
            ax.set_zticks(np.linspace(*zticks))
            ax.figure.savefig('fig/' + out_name)

# Generate the figures only when run as a script, not when imported.
if __name__ == '__main__':
    main()
