import os
import sys
from pathlib import Path
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import numpy as np
import tensorflow as tf
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns

# Number of sample points per axis of the evaluation grid; passed to
# get_problem_config, which uses it as the (rounded, non-negative) point
# count of an np.ogrid axis.
RESOLUTION = 50

def get_problem_config(problem, resolution):
    """Return the sampling grid, reference surface and plot settings for a problem.

    Parameters
    ----------
    problem : str
        Problem name, matched case-insensitively after stripping surrounding
        whitespace. Accepted: ``rosenbrock``/``rosenbrock2``,
        ``ackley``/``ackley2``/``ackley3`` and ``ackleysmall``/``ackleysmall2``.
    resolution : float
        Number of grid points per axis (``round(abs(resolution))`` is used).

    Returns
    -------
    tuple
        ``(x, y, z, vmax, xticks, yticks, zticks, zlim)``:

        - ``x`` and ``y`` are an open grid from ``np.ogrid``, shapes ``(n, 1)``
          and ``(1, n)``, endpoints included.
        - ``z`` is the exact function sampled on that grid.
        - ``vmax`` is the upper colour limit used for the error plots.
        - ``xticks``, ``yticks`` and ``zticks`` are tick positions.
        - ``zlim`` is the ``(low, high)`` z-axis range.

    Raises
    ------
    ValueError
        If ``problem`` is not one of the recognised names.
    """
    key = problem.strip().lower()
    # A complex step makes np.ogrid interpret it as a point count (inclusive).
    n_points = round(abs(resolution)) * 1j

    if key in ('rosenbrock', 'rosenbrock2'):
        func, (x_lo, x_hi), (y_lo, y_hi) = rosenbrock, (-2, 2), (-1, 3)
        tick_x, tick_y = (-2, 2), (-1, 3)
        color_max, z_top, n_zticks = 500, 2500, 6
    elif key in ('ackley', 'ackley2', 'ackley3'):
        func, (x_lo, x_hi), (y_lo, y_hi) = ackley, (-5, 5), (-5, 5)
        tick_x, tick_y = (-4, 4), (-4, 4)
        color_max, z_top, n_zticks = 4, 15, 4
    elif key in ('ackleysmall', 'ackleysmall2'):
        func, (x_lo, x_hi), (y_lo, y_hi) = ackley, (-1, 1), (-1, 1)
        tick_x, tick_y = (-1, 1), (-1, 1)
        color_max, z_top, n_zticks = 1, 5, 6
    else:
        raise ValueError(f'unknown problem "{problem}".')

    x, y = np.ogrid[x_lo:x_hi:n_points, y_lo:y_hi:n_points]
    z = func(x, y)
    xticks = np.linspace(tick_x[0], tick_x[1], 5)
    yticks = np.linspace(tick_y[0], tick_y[1], 5)
    zticks = np.linspace(0, z_top, n_zticks)
    return x, y, z, color_max, xticks, yticks, zticks, (0, z_top)

def rosenbrock(x, y, a=1, b=100):
    """Evaluate the 2-D Rosenbrock function elementwise.

    Computes ``(a - x)**2 + b * (y - x**2)**2``. With the default
    ``a=1, b=100`` the global minimum is 0 at ``(a, a**2) == (1, 1)``.

    Parameters
    ----------
    x, y : scalar or array_like
        Coordinates; NumPy broadcasting applies, so open grids from
        ``np.ogrid`` yield a full 2-D result.
    a, b : float, optional
        Shape constants of the function (defaults ``1`` and ``100``).

    Returns
    -------
    scalar or numpy.ndarray
        Function values with the broadcast shape of ``x`` and ``y``.
    """
    return np.square(a - x) + b * np.square(y - np.square(x))

def ackley(x, y):
    """Evaluate the 2-D Ackley function elementwise.

    Uses the standard constants a=20, b=0.2, c=2*pi, so the value at the
    origin is 0. Scalars and broadcastable NumPy arrays are accepted.
    """
    radial = -20 * np.exp(-0.2 * np.sqrt(0.5 * (np.square(x) + np.square(y))))
    periodic = np.exp(0.5 * (np.cos(2 * np.pi * x) + np.cos(2 * np.pi * y)))
    return radial - periodic + np.e + 20

def evaluate_model(x, y, model_file):
    """Evaluate a serialized TensorFlow graph at every point of a 2-D grid.

    The graph is read from ``model_file`` as a binary ``GraphDef`` and must
    expose tensors named ``Input:0`` and ``Output:0``. Every grid point is fed
    on its own as a length-2 ``[x, y]`` vector, one ``sess.run`` per point.

    Parameters
    ----------
    x, y : array_like
        Grid coordinates, broadcast together (``np.ogrid`` outputs work).
    model_file : str or path-like
        Path of the binary ``GraphDef`` file.

    Returns
    -------
    numpy.ndarray
        Model outputs with the broadcast grid shape and the dtype of the
        broadcast ``x`` array.

    Notes
    -----
    Written against the TensorFlow 1.x graph/session API (``tf.Session``,
    ``tf.GraphDef``). Under TensorFlow 2.x those names live in
    ``tf.compat.v1``.

    NOTE(review): each run's output is stored into a single float slot, so
    presumably the model emits one value per point. Confirm this before
    batching the loop; batching would also need an input that accepts a
    leading batch dimension.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
        with open(model_file, 'rb') as fh:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(fh.read())
        inp, out = tf.import_graph_def(graph_def, name='', return_elements=['Input:0', 'Output:0'])
        # The protobuf is not needed once its nodes are in the default graph.
        del graph_def
        grid_x, grid_y = np.broadcast_arrays(x, y)
        points = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)
        result = np.empty(grid_x.size, grid_x.dtype)
        for k, point in enumerate(points):
            result[k] = sess.run(out, feed_dict={inp: point})
        return result.reshape(grid_x.shape)

def plot_reconstruction(x, y, z):
    """Open a new 6x4 figure and draw ``z`` over ``(x, y)`` as a 3-D surface.

    The surface uses the second colour of the current seaborn palette and is
    drawn by ``plot_surface``. Returns the 3-D axes so the caller can adjust
    limits and ticks before saving.
    """
    sns.set_context('notebook', font_scale=1.4)
    fig = plt.figure(figsize=(6, 4))
    surface_color = sns.color_palette()[1]
    ax = fig.add_subplot(111, projection='3d')
    plot_surface(x, y, z, surface_color, ax)
    fig.show()
    margins = dict(left=-0.05, right=0.99, bottom=0.03, top=1.04)
    fig.subplots_adjust(**margins)
    return ax

def plot_surface(x, y, z, color, ax):
    """Draw a single-colour, semi-transparent surface on a 3-D axes.

    Parameters
    ----------
    x, y, z : array_like
        Surface coordinates. They are broadcast together first, so open grids
        from ``np.ogrid`` are accepted.
    color : matplotlib colour
        Colour used for the whole surface.
    ax : mpl_toolkits.mplot3d.Axes3D
        Target axes, modified in place (panes, labels, view angle).
    """
    xx, yy, zz = np.broadcast_arrays(x, y, z)
    # Paint the three back panes with the axes background colour, then make
    # the axes patch itself transparent.
    bgcolor = ax.get_facecolor()
    # ``Axes3D.w_xaxis``/``w_yaxis``/``w_zaxis`` were deprecated in
    # Matplotlib 3.1 and removed in 3.8; ``ax.xaxis`` etc. refer to the same
    # Axis3D objects on every version.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(bgcolor)
    ax.set_facecolor('none')
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
    ax.set_zlabel('$z$')
    # Push the x/y labels away from the tick labels.
    ax.xaxis.labelpad = 12.0
    ax.yaxis.labelpad = 12.0
    ax.plot_surface(xx, yy, zz, rstride=1, cstride=1, linewidth=0, alpha=0.75, color=color)
    ax.view_init(elev=35, azim=-60)


def plot_error(x, y, z, z2, vmin=None, vmax=None):
    """Open a new 6x4 figure showing surface ``z`` coloured by ``|z - z2|``.

    Parameters
    ----------
    x, y : array_like
        Grid coordinates.
    z : array_like
        Heights of the plotted surface.
    z2 : array_like
        Reference values; only used to compute the colouring error.
    vmin, vmax : float, optional
        Colour limits passed to ``plot_surface_error``.

    Returns
    -------
    The 3-D axes of the new figure.
    """
    sns.set_context('notebook', font_scale=1.4)
    fig = plt.figure(figsize=(6, 4))
    ax = fig.add_subplot(111, projection='3d')
    plot_surface_error(x, y, z, np.abs(z - z2), vmin, vmax, ax)
    fig.show()
    margins = dict(left=-0.05, right=0.99, bottom=0.03, top=1.04)
    fig.subplots_adjust(**margins)
    return ax

def plot_errormap(x, y, z, z2, xticks, yticks, vmin=None, vmax=None):
    """Open a new figure showing ``|z - z2|`` as a 2-D heat map.

    Parameters
    ----------
    x, y : array_like
        Grid coordinates. The error array is assumed to be indexed as
        ``err[i, j] <-> (x[i], y[j])``, i.e. its first axis runs along x.
        This matches the ``np.ogrid[x_range, y_range]`` grids built by
        ``get_problem_config``.
    z, z2 : array_like
        Two samplings of the surface on that grid.
    xticks, yticks : array_like
        Tick positions for the horizontal (x) and vertical (y) axes.
    vmin : float, optional
        Lower colour limit; defaults to 0 (the minimum possible error).
    vmax : float, optional
        Upper colour limit; ``None`` lets ``imshow`` autoscale.

    Returns
    -------
    matplotlib.axes.Axes
        The 2-D axes holding the image.
    """
    sns.set_context('notebook', font_scale=1.6)
    fig = plt.figure(figsize=(4.9, 4))
    ax = fig.add_subplot(111)
    ax.grid(False)
    z_err = np.abs(z - z2)
    extent = (x.min(), x.max(), y.min(), y.max())
    # imshow draws array rows vertically and columns horizontally, with row 0
    # at the top by default. The first axis of z_err follows x, so transpose
    # it to put x on the horizontal axis. origin='lower' makes y increase
    # upwards, matching ``extent`` and the axis labels.
    ax.imshow(z_err.T, vmin=0 if vmin is None else vmin, vmax=vmax,
              extent=extent, origin='lower',
              aspect='equal', interpolation='bilinear')
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    fig.show()
    fig.subplots_adjust(left=0.175, right=0.825, bottom=0.11, top=1.0)
    return ax

def plot_colorbar(mappable):
    """Draw a standalone horizontal colour bar for ``mappable`` in a new figure.

    The bar fills the figure's single axes, squeezed into a thin horizontal
    strip by the subplot margins. Returns the axes holding the colour bar.
    """
    sns.set_context('notebook', font_scale=1.6)
    fig = plt.figure(figsize=(4, 4))
    cbar_ax = fig.gca()
    fig.colorbar(mappable, cax=cbar_ax, orientation='horizontal')
    fig.subplots_adjust(left=0.06, right=0.94, bottom=0.45, top=0.55)
    return cbar_ax

def plot_surface_error(x, y, z, err, vmin, vmax, ax):
    """Draw a semi-transparent 3-D surface whose faces are coloured by ``err``.

    Parameters
    ----------
    x, y, z : array_like
        Surface coordinates, broadcast together before plotting.
    err : array_like
        Values mapped through the ``plasma`` colormap to give face colours.
        Presumably it has the broadcast shape of ``x, y, z``; confirm this
        at the call site.
    vmin, vmax : float or None
        Limits of the colour normalisation. ``None`` lets ``Normalize`` take
        that limit from ``err``. The norm is built with ``clip=False``.
    ax : mpl_toolkits.mplot3d.Axes3D
        Target axes, modified in place (panes, labels, view angle).
    """
    xx, yy, zz = np.broadcast_arrays(x, y, z)
    # Paint the three back panes with the axes background colour, then make
    # the axes patch itself transparent.
    bgcolor = ax.get_facecolor()
    # ``Axes3D.w_xaxis``/``w_yaxis``/``w_zaxis`` were deprecated in
    # Matplotlib 3.1 and removed in 3.8; ``ax.xaxis`` etc. refer to the same
    # Axis3D objects on every version.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(bgcolor)
    ax.set_facecolor('none')
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
    ax.set_zlabel('$z$')
    # Push the x/y labels away from the tick labels.
    ax.xaxis.labelpad = 12.0
    ax.yaxis.labelpad = 12.0
    cmap = mpl.cm.plasma
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)
    ax.plot_surface(xx, yy, zz, rstride=1, cstride=1, linewidth=0, alpha=0.75, facecolors=cmap(norm(err)))
    ax.view_init(elev=35, azim=-60)

def _setup_plot_style():
    """Apply the global seaborn/matplotlib style and close any open figures."""
    sns.set(context='notebook',
            style='darkgrid',
            palette='deep',
            font='serif',
            font_scale=1.4,
            color_codes=True)
    # Save as PDF when a savefig file name has no extension.
    plt.rc('savefig', format='pdf')
    # Render all text with LaTeX (needs a working LaTeX installation).
    plt.rc('text', usetex=True)
    plt.rc('ps', usedistiller='xpdf')
    plt.rc('font', family='serif')
    plt.close('all')


def _apply_3d_axes_limits(ax, xticks, yticks, zticks, zlim):
    """Set the z-range and the x/y/z tick positions of a 3-D axes."""
    ax.set_zlim(*zlim)
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    ax.set_zticks(zticks)


def _save_and_close(ax, path):
    """Save the figure that owns ``ax`` to ``path``, then close that figure."""
    fig = ax.figure
    fig.savefig(path)
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)


def main(problem, model_file, out_dir):
    """Render reconstruction and error plots of a trained model for one problem.

    Parameters
    ----------
    problem : str
        Problem name understood by ``get_problem_config``.
    model_file : str
        Path of the serialized model graph passed to ``evaluate_model``.
    out_dir : str
        Output directory, created if missing. Receives ``reconst``, ``error``
        and ``errormap``. These names have no extension, so the
        ``savefig.format`` rc setting (PDF) decides the format.
    """
    _setup_plot_style()

    Path(out_dir).mkdir(parents=True, exist_ok=True)
    x, y, z, vmax, xticks, yticks, zticks, zlim = get_problem_config(problem, RESOLUTION)
    z_model = evaluate_model(x, y, model_file)

    # 3-D surface of the model output.
    ax = plot_reconstruction(x, y, z_model)
    _apply_3d_axes_limits(ax, xticks, yticks, zticks, zlim)
    _save_and_close(ax, f'{out_dir}/reconst')

    # 3-D model surface coloured by its absolute error against the true function.
    ax = plot_error(x, y, z_model, z, 0, vmax)
    _apply_3d_axes_limits(ax, xticks, yticks, zticks, zlim)
    _save_and_close(ax, f'{out_dir}/error')

    # 2-D heat map of the same absolute error.
    ax = plot_errormap(x, y, z_model, z, xticks, yticks, 0, vmax)
    _save_and_close(ax, f'{out_dir}/errormap')

if __name__ == '__main__':
    # Expect exactly three positional arguments after the script name.
    cli_args = sys.argv[1:]
    if len(cli_args) == 3:
        main(*cli_args)
    else:
        print(f'Usage: {sys.argv[0]} <problem> <model file> <output dir>')
        sys.exit(1)
