import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns

def rosenbrock(x, y, a=1, b=100):
    """Evaluate the Rosenbrock function elementwise.

    f(x, y) = (a - x)**2 + b * (y - x**2)**2

    Accepts scalars or broadcastable NumPy arrays. The global minimum
    f = 0 lies at (a, a**2).

    Args:
        x, y: Coordinates (scalars or broadcastable arrays).
        a: Position parameter of the minimum (default 1).
        b: Weight of the curved-valley term (default 100).

    Returns:
        The function value(s), with the broadcast shape of ``x`` and ``y``.
    """
    # Previously ``b`` was bound but the literal 100 was used; use the parameter.
    return np.square(a - x) + b * np.square(y - np.square(x))

def ackley(x, y):
    """Evaluate the two-dimensional Ackley function elementwise.

    Accepts scalars or broadcastable NumPy arrays. The global minimum
    (value 0, up to floating-point rounding) is at the origin.
    """
    # Exponentially decaying envelope driven by the RMS distance from the origin.
    radial = np.sqrt(0.5 * (np.square(x) + np.square(y)))
    # Mean of the two cosine ripples that create the many local minima.
    ripple = 0.5 * (np.cos(2 * np.pi * x) + np.cos(2 * np.pi * y))
    envelope = -20 * np.exp(-0.2 * radial)
    return envelope - np.exp(ripple) + np.e + 20

def evaluate_model(x, y, model_file):
    """Evaluate a serialized TensorFlow graph at every point of a grid.

    Args:
        x, y: Coordinates, broadcastable against each other (e.g. the pair
            returned by ``np.ogrid``).
        model_file: Path to a serialized ``GraphDef`` that contains tensors
            named ``Input:0`` and ``Output:0``.

    Returns:
        Array with the broadcast shape of ``x`` and ``y``. Each entry holds the
        model output for the corresponding ``(x, y)`` pair. The result is
        floating point (at least float32), so integer grids do not truncate
        the network's outputs.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
        with open(model_file, 'rb') as f:
            gd = tf.GraphDef()
            gd.ParseFromString(f.read())
        inp, out = tf.import_graph_def(gd, name='', return_elements=['Input:0', 'Output:0'])
        del gd  # proto no longer needed once imported
        xx, yy = np.broadcast_arrays(x, y)
        # Allocating with the raw input dtype would silently truncate the
        # outputs for integer grids; promote to a floating type instead.
        # Float inputs keep their dtype.
        zz = np.empty(xx.size, np.result_type(xx.dtype, yy.dtype, np.float32))
        # NOTE(review): one sess.run per point, feeding a single length-2
        # vector. Presumably the Input tensor has a fixed per-sample shape;
        # batch the grid if the graph allows it.
        for i, xy in enumerate(np.stack([xx.ravel(), yy.ravel()], axis=1)):
            zz[i] = sess.run(out, feed_dict={inp: xy})
        return zz.reshape(xx.shape)

def plot_reconstruction(x, y, model):
    """Plot the surface a trained model predicts over a grid.

    Args:
        x, y: Broadcastable grid coordinates (e.g. from ``np.ogrid``).
        model: Path to the serialized model graph, passed to
            ``evaluate_model``.

    Returns:
        The 3-D axes holding the surface. Callers can adjust limits and ticks
        and then save ``ax.figure``.
    """
    sns.set_context('notebook', font_scale=1.4)
    fig = plt.figure(figsize=(6, 4))
    pal = sns.color_palette()
    ax = fig.add_subplot(111, projection='3d')
    z = evaluate_model(x, y, model)
    # Use the second colour of the active seaborn palette for the surface.
    plot_surface(x, y, z, pal[1], ax)
    # Apply the margins before showing, so the figure is shown with its
    # final layout rather than being re-laid out afterwards.
    fig.subplots_adjust(left=-0.05, right=0.99, bottom=0.03, top=1.04)
    fig.show()
    return ax

def plot_surface(x, y, z, color, ax):
    """Draw a single-colour, translucent surface on a 3-D axes and style it.

    The panes of all three axes get the axes' background colour, and the axes
    background itself is made transparent. The axes are labelled ``$x$``,
    ``$y$``, ``$z$`` and the view is fixed at elevation 35, azimuth -60.

    Args:
        x, y, z: Coordinate arrays, broadcast against each other
            (``np.ogrid``-style inputs are fine).
        color: Matplotlib colour specification for the surface.
        ax: A 3-D axes (created with ``projection='3d'``).
    """
    xx, yy, zz = np.broadcast_arrays(x, y, z)
    bgcolor = ax.get_facecolor()
    # ``ax.w_[xyz]axis`` are deprecated aliases that recent Matplotlib
    # releases removed; ``ax.[xyz]axis`` refer to the same 3-D axis objects.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(bgcolor)
    ax.set_facecolor('none')
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
    ax.set_zlabel('$z$')
    ax.xaxis.labelpad = 12.0
    ax.yaxis.labelpad = 12.0
    ax.plot_surface(xx, yy, zz, rstride=1, cstride=1, linewidth=0, alpha=0.75, color=color)
    ax.view_init(elev=35, azim=-60)

# Model variants evaluated for every benchmark problem.
_MODELS = ('NN', '2x2', '3x3', '4x4', '5x5', '6x6', '7x7')


def _save_reconstructions(problem, x, y, zmax, xticks, yticks, zticks):
    """Plot, style and save the reconstruction of each model in ``_MODELS``.

    Loads ``eval_mem/<problem>/train/<model>/gfnn_best.tf`` for each model and
    writes ``fig/<problem>_reconst_<model>``. The file extension comes from
    the ``savefig.format`` rc setting.
    """
    # savefig does not create missing directories.
    os.makedirs('fig', exist_ok=True)
    for model in _MODELS:
        ax = plot_reconstruction(x, y, f'eval_mem/{problem}/train/{model}/gfnn_best.tf')
        ax.set_zlim(0, zmax)
        ax.set_xticks(xticks)
        ax.set_yticks(yticks)
        ax.set_zticks(zticks)
        ax.figure.savefig(f'fig/{problem}_reconst_{model}')
        plt.close(ax.figure)


def main():
    """Configure plotting and save reconstruction figures for all benchmarks."""
    # Setup plotting
    sns.set(context='notebook',
            style='darkgrid',
            palette='deep',
            font='serif',
            font_scale=1.4,
            color_codes=True)
    # Figure file format (output names below carry no extension)
    plt.rc('savefig', format='pdf')
    # Use LaTeX to render text in plots (needs a working LaTeX installation)
    plt.rc('text', usetex=True)
    plt.rc('ps', usedistiller='xpdf')
    plt.rc('font', family='serif')
    plt.close('all')

    # A complex step makes np.ogrid return that many points, endpoints included.
    res = 50j

    # Rosenbrock
    x, y = np.ogrid[-2:2:res, -1:3:res]
    _save_reconstructions('rosenbrock', x, y, 2500,
                          np.linspace(-2, 2, 5),
                          np.linspace(-1, 3, 5),
                          np.linspace(0, 2500, 6))
    # Ackley
    x, y = np.ogrid[-5:5:res, -5:5:res]
    _save_reconstructions('ackley', x, y, 15,
                          np.linspace(-4, 4, 5),
                          np.linspace(-4, 4, 5),
                          np.linspace(0, 15, 4))
    # Ackley (small region)
    x, y = np.ogrid[-1:1:res, -1:1:res]
    _save_reconstructions('ackleysmall', x, y, 5,
                          np.linspace(-1, 1, 5),
                          np.linspace(-1, 1, 5),
                          np.linspace(0, 5, 6))

if __name__ == '__main__':
    main()
