import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns

# Global plot styling: apply the seaborn theme first, then override
# individual Matplotlib rc parameters on top of it.
sns.set(
    context='notebook', style='darkgrid', palette='deep',
    font='serif', font_scale=1.4, color_codes=True,
)
# Saved figures default to PDF
plt.rcParams['savefig.format'] = 'pdf'
# Render all text in plots with LaTeX
plt.rcParams['text.usetex'] = True
plt.rcParams['ps.usedistiller'] = 'xpdf'
plt.rcParams['font.family'] = 'serif'
# Start from a clean slate: discard any figure that is still open
plt.close('all')

def rosenbrock(x, y, a=1, b=100):
    """Evaluate the Rosenbrock function ``(a - x)^2 + b * (y - x^2)^2``.

    ``x`` and ``y`` may be scalars or arrays; they are combined with NumPy
    broadcasting, so open grids (e.g. from ``np.ogrid``) yield a full 2-D
    result.

    Parameters:
        x, y: Coordinates at which to evaluate the function.
        a, b: Shape parameters of the function. The defaults (1 and 100)
            reproduce the previously hard-coded behavior.

    Returns:
        The function value(s), broadcast over ``x`` and ``y``.
    """
    return np.square(a - x) + b * np.square(y - np.square(x))

def ackley(x, y):
    """Evaluate the two-dimensional Ackley function at ``(x, y)``.

    ``x`` and ``y`` may be scalars or arrays and are combined with NumPy
    broadcasting. The value at the origin is zero (up to rounding).
    """
    # Root-mean-square distance from the origin drives the exponential bowl.
    rms_radius = np.sqrt(0.5 * (np.square(x) + np.square(y)))
    # Mean of the two cosines produces the periodic ripples.
    mean_cosine = 0.5 * (np.cos(2 * np.pi * x) + np.cos(2 * np.pi * y))
    return -20 * np.exp(-0.2 * rms_radius) - np.exp(mean_cosine) + np.e + 20

def evaluate_model(x, y, model_file):
    """Evaluate a serialized TensorFlow graph on every point of an (x, y) grid.

    The graph definition is read from ``model_file`` and imported into a
    fresh graph; its tensors ``Input:0`` and ``Output:0`` are used as the
    feed and fetch targets. ``x`` and ``y`` are broadcast against each other
    and each resulting ``(x, y)`` pair is fed to the model individually.

    Returns:
        An array with the broadcast shape of ``x`` and ``y`` and the dtype of
        the broadcast ``x``, holding the model output for each point.
    """
    graph = tf.Graph()
    with graph.as_default(), tf.Session() as session:
        graph_def = tf.GraphDef()
        with open(model_file, 'rb') as stream:
            graph_def.ParseFromString(stream.read())
        input_tensor, output_tensor = tf.import_graph_def(
            graph_def, name='', return_elements=['Input:0', 'Output:0'])
        # The parsed definition is no longer needed once imported.
        del graph_def

        grid_x, grid_y = np.broadcast_arrays(x, y)
        # One row per grid point: column 0 holds x, column 1 holds y.
        points = np.column_stack((grid_x.ravel(), grid_y.ravel()))
        values = np.empty(grid_x.size, grid_x.dtype)
        # NOTE(review): assumes the model yields a single value per fed
        # point so it fits into one element of `values` -- confirm.
        for index, point in enumerate(points):
            values[index] = session.run(output_tensor,
                                        feed_dict={input_tensor: point})
        return values.reshape(grid_x.shape)

def plot_error(x, y, z, z2, vmin=None, vmax=None):
    """Draw ``z`` as a 3-D surface colored by the absolute error ``|z - z2|``.

    A new 6x4 inch figure with a single 3-D axes is created; ``vmin`` and
    ``vmax`` are forwarded to ``plot_surface_error`` as the color
    normalization limits.

    Returns:
        The 3-D axes containing the surface.
    """
    sns.set_context('notebook', font_scale=1.4)
    figure = plt.figure(figsize=(6, 4))
    axes = figure.add_subplot(1, 1, 1, projection='3d')
    plot_surface_error(x, y, z, np.absolute(z - z2), vmin, vmax, axes)
    figure.show()
    # Margins are tuned by hand so the 3-D axes fill the figure.
    figure.subplots_adjust(left=-0.05, right=0.99, bottom=0.03, top=1.04)
    return axes

def plot_errormap(x, y, z, z2, xticks, yticks, vmin=None, vmax=None):
    """Draw the absolute error ``|z - z2|`` as a 2-D image over the x/y plane.

    Parameters:
        x, y: Grid coordinates of the samples. The error array is assumed to
            be indexed ``[i, j]`` for the point ``(x[i], y[j])``, i.e. the
            layout produced by ``np.ogrid`` / ``'ij'`` indexing.
        z, z2: The two surfaces to compare (same 2-D shape).
        xticks, yticks: Tick positions for the two axes.
        vmin: Lower color limit; ``None`` keeps the historical limit of 0.
        vmax: Upper color limit; ``None`` lets Matplotlib autoscale.

    Returns:
        The axes holding the image (available as ``ax.images[0]``).
    """
    sns.set_context('notebook', font_scale=1.6)
    fig = plt.figure(figsize=(4.9, 4))
    ax = fig.add_subplot(111)
    ax.grid(False)
    z_err = np.abs(z - z2)
    # The image must span the sampled domain, not the tick range: ticks may
    # cover only part of the data (e.g. ticks on [-4, 4] for data on [-5, 5]).
    extent = (np.min(x), np.max(x), np.min(y), np.max(y))
    # imshow draws rows along the vertical axis, starting at the top. The
    # error array has x along rows, so transpose it and put the origin at
    # the bottom to make x run left-to-right and y bottom-to-top, matching
    # the axis labels below.
    ax.imshow(np.transpose(z_err),
              vmin=0 if vmin is None else vmin, vmax=vmax,
              extent=extent, origin='lower',
              aspect='equal', interpolation='bilinear')
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    fig.show()
    fig.subplots_adjust(left=0.175, right=0.825, bottom=0.11, top=1.0)
    return ax

def plot_colorbar(mappable):
    """Draw a stand-alone horizontal colorbar for ``mappable``.

    A new 4x4 inch figure is created and its only axes is used as the
    colorbar axes; the margins squeeze that axes into a thin horizontal strip.

    Returns:
        The axes holding the colorbar.
    """
    sns.set_context('notebook', font_scale=1.6)
    figure = plt.figure(figsize=(4, 4))
    # The freshly created figure is current, so gca() creates its sole axes.
    colorbar_axes = plt.gca()
    figure.colorbar(mappable, cax=colorbar_axes, orientation='horizontal')
    figure.subplots_adjust(left=0.06, right=0.94, bottom=0.45, top=0.55)
    return colorbar_axes

def plot_surface_error(x, y, z, err, vmin, vmax, ax):
    """Plot the surface ``z(x, y)`` on 3-D axes, colored by ``err``.

    Parameters:
        x, y, z: Surface coordinates; broadcast against each other before
            plotting, so open grids are accepted.
        err: Per-point values mapped to face colors through the ``plasma``
            colormap.
        vmin, vmax: Limits of the linear color normalization applied to
            ``err`` (``None`` lets the normalization autoscale).
        ax: The 3-D axes to draw on; it is styled in place.
    """
    xx, yy, zz = np.broadcast_arrays(x, y, z)
    # Move the axes background color onto the three panes and make the axes
    # background itself transparent.
    bgcolor = ax.get_facecolor()
    # Use xaxis/yaxis/zaxis rather than the w_xaxis/w_yaxis/w_zaxis aliases,
    # which were deprecated and removed in recent Matplotlib releases.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(bgcolor)
    ax.set_facecolor('none')
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
    ax.set_zlabel('$z$')
    ax.xaxis.labelpad = 12.0
    ax.yaxis.labelpad = 12.0
    cmap = mpl.cm.plasma
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)
    # Stride 1 draws every grid cell; face colors encode the error.
    ax.plot_surface(xx, yy, zz, rstride=1, cstride=1, linewidth=0,
                    alpha=0.75, facecolors=cmap(norm(err)))
    ax.view_init(elev=35, azim=-60)

def plot_error_boxen(error_file, ylims):
    """Draw a boxen plot of the per-model error distribution.

    Parameters:
        error_file: Excel workbook whose sheet ``'Error'`` provides the
            columns ``'Model'`` (category) and ``'Error'`` (value).
        ylims: ``(lower, upper)`` limits for the y axis.

    Returns:
        The axes containing the plot.
    """
    sns.set_context('notebook', font_scale=1.8)
    error_df = pd.read_excel(error_file, sheet_name='Error')
    plt.figure(figsize=(8, 4))
    # Pass x/y by keyword: recent seaborn releases no longer accept them
    # positionally.
    ax = sns.boxenplot(x='Model', y='Error', data=error_df)
    ax.set_ylim(*ylims)
    ax.set_xlabel('')
    ax.set_ylabel('')
    # Fixed labels: assumes exactly these seven model categories, in this
    # order, appear in the data.
    ax.set_xticklabels(['MLP', '$2\\times 2$', '$3\\times 3$', '$4\\times 4$',
                        '$5\\times 5$', '$6\\times 6$', '$7\\times 7$'])
    ax.figure.subplots_adjust(left=0.07, right=0.93, bottom=0.09, top=0.995)
    return ax

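# Example invocations of plot_error_boxen, one per benchmark (left disabled;
# uncomment a line to draw the corresponding boxen plot):
# plot_error_boxen('eval_mem/rosenbrock/results/all_cubic/results.xlsx', (-14, 206))
# plot_error_boxen('eval_mem/ackley/results/all_cubic/results.xlsx', (-0.2, 3.2))
# plot_error_boxen('eval_mem/ackleysmall/results/all_cubic/results.xlsx', (-0.05, 1.1))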

def _plot_benchmark(name, func, x, y, vmax, xticks, yticks, zticks, zlim,
                    models=('NN', '2x2', '3x3', '4x4', '5x5', '6x6', '7x7')):
    """Render and save every error figure for one benchmark function.

    For each model in ``models`` the trained graph
    ``eval_mem/<name>/train/<model>/gfnn_best.tf`` is evaluated on the grid
    and compared against ``func(x, y)``; a 3-D error surface and a 2-D error
    map are saved under ``fig/``. Finally a shared colorbar, taken from the
    last error map, is saved as its own figure.
    """
    z = func(x, y)
    errormap_ax = None
    for model in models:
        z_model = evaluate_model(
            x, y, f'eval_mem/{name}/train/{model}/gfnn_best.tf')

        ax = plot_error(x, y, z_model, z, 0, vmax)
        ax.set_zlim(*zlim)
        ax.set_xticks(xticks)
        ax.set_yticks(yticks)
        ax.set_zticks(zticks)
        ax.figure.savefig(f'fig/{name}_error_{model}')
        # Close the specific figure rather than whichever one is "current".
        plt.close(ax.figure)

        errormap_ax = plot_errormap(x, y, z_model, z, xticks, yticks, 0, vmax)
        errormap_ax.figure.savefig(f'fig/{name}_errormap_{model}')
        plt.close(errormap_ax.figure)

    if errormap_ax is None:
        # No model was plotted, so there is no image to derive a colorbar from.
        return
    # All error maps share the same color limits, so one colorbar serves all.
    colorbar_ax = plot_colorbar(errormap_ax.images[0])
    colorbar_ax.figure.savefig(f'fig/{name}_errormap_colorbar')
    plt.close(colorbar_ax.figure)


def main():
    """Generate the error figures for the Rosenbrock and Ackley benchmarks."""
    # Reconstruction resolution (complex step: number of points per axis)
    RES = 50j

    # Fail early and cheaply instead of after the first model evaluation.
    os.makedirs('fig', exist_ok=True)

    # Rosenbrock
    x, y = np.ogrid[-2:2:RES, -1:3:RES]
    _plot_benchmark('rosenbrock', rosenbrock, x, y,
                    vmax=500,
                    xticks=np.linspace(-2, 2, 5),
                    yticks=np.linspace(-1, 3, 5),
                    zticks=np.linspace(0, 2500, 6),
                    zlim=(0, 2500))

    # Ackley
    x, y = np.ogrid[-5:5:RES, -5:5:RES]
    _plot_benchmark('ackley', ackley, x, y,
                    vmax=4,
                    xticks=np.linspace(-4, 4, 5),
                    yticks=np.linspace(-4, 4, 5),
                    zticks=np.linspace(0, 15, 4),
                    zlim=(0, 15))

    # Ackley (small region)
    x, y = np.ogrid[-1:1:RES, -1:1:RES]
    _plot_benchmark('ackleysmall', ackley, x, y,
                    vmax=1,
                    xticks=np.linspace(-1, 1, 5),
                    yticks=np.linspace(-1, 1, 5),
                    zticks=np.linspace(0, 5, 6),
                    zlim=(0, 5))

if __name__ == '__main__':
    main()
