#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
import os
import random
import sys

import numpy as np
import pandas as pd
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # nopep8
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
import yaml

from dataio import read_data_directory
from report import make_confusion_matrix, make_class_metrics, make_jaccard, \
    plot_confusion_matrix, plot_accuracy_mass, plot_progress_accuracy, \
    plot_confidence_accuracy, plot_confidence_roc, plot_sensitivity, \
    make_hit_counts
from TimeNet import TimeNet, TimeNetConfig, \
    ConvolutionConfig, PoolingConfig, DilatedStackConfig, \
    LocalRecurrentStackConfig
import rice_features as F
from rice_train_helpers import make_runtime_graph, rice_input_filter, \
    make_class_progress, make_class_confidence, estimate_label_confidence, \
    trim_class_margins, normalized_class_weights

# Setup plotting: seaborn theme with the "talk" context
sns.set(color_codes=True, context="talk")
# Optional matplotlib settings, disabled by default.
# Figure file format
#plt.rc("savefig", format="pdf")
# Use LaTeX to render text in plots
#plt.rc("text", usetex=True)
#plt.rc("ps", usedistiller="xpdf")
#plt.rc("font", family="serif")


def _opt(mapping, key, default):
    """Return ``mapping[key]``, or `default` when it is missing or falsy.

    A YAML key that is present but left empty loads as ``None``; this keeps
    such entries from overriding the built-in default.
    """
    return mapping.get(key, default) or default


def _ensure_dir(path):
    """Create directory `path` (with parents) if it is missing; return it."""
    if not path.exists():
        path.mkdir(0o755, parents=True)
    return path


def _delay_labels(label_str, delay, fill):
    """Shift a label sequence by `delay` frames, padding with `fill`.

    A positive delay overwrites the first `delay` labels with `fill`, a
    negative one overwrites the last ``-delay`` labels. The length of the
    sequence is preserved.
    """
    if delay > 0:
        padding = np.full(min(delay, len(label_str)), fill)
        return np.concatenate([padding, label_str[delay:]], axis=0)
    if delay < 0:
        padding = np.full(min(-delay, len(label_str)), fill)
        return np.concatenate([label_str[:delay], padding], axis=0)
    return label_str


def _encode_labels(labels_str, gestures, default_gesture, delay):
    """Convert sequences of gesture names into sequences of class ids.

    Names not listed in `gestures` are replaced by `default_gesture` (the
    arrays in `labels_str` are modified in place), the labels are shifted by
    `delay` frames and finally mapped to their index in `gestures`.
    """
    sorter = np.argsort(gestures)
    encoded = []
    for label_str in labels_str:
        unknown = np.in1d(label_str, gestures, invert=True)
        label_str[unknown] = default_gesture
        label_str = _delay_labels(label_str, delay, default_gesture)
        positions = np.searchsorted(gestures, label_str, sorter=sorter)
        encoded.append(sorter[positions])
    return encoded


def _save_figure(path, tight=True):
    """Save the current matplotlib figure to `path`, then close it.

    With `tight` set, ``tight_layout`` is applied first, but only when the
    figure actually contains axes.
    """
    fig = plt.gcf()
    try:
        if tight and fig.axes:
            fig.tight_layout()
        fig.savefig(str(path))
    finally:
        plt.close(fig)


def _try_plot(path, plot_fn, *args, **kwargs):
    """Call ``plot_fn(*args, **kwargs)`` and save the resulting figure.

    A ``ValueError`` raised while plotting or saving is reported on stderr
    instead of aborting the evaluation.

    Returns:
        True if the figure was saved, False if it was skipped.
    """
    try:
        plot_fn(*args, **kwargs)
        fig = plt.gcf()
        if fig.axes:
            fig.tight_layout()
        fig.savefig(str(path))
    except ValueError as err:
        print("Skipped figure {}: {}".format(path, err), file=sys.stderr)
        return False
    finally:
        plt.close(plt.gcf())
    return True


def main(config_name, config):
    """Run one experiment: build, train, export and evaluate a TimeNet model.

    Everything produced by the run (configuration copy, checkpoints, exported
    graphs, metrics and figures) is written to a fresh timestamped directory
    ``<OutputDir>/<config_name>_<timestamp>``.

    Args:
        config_name: Experiment name, used as prefix of the results directory.
        config: Experiment configuration mapping, as loaded from the YAML
            experiment file.

    Raises:
        ValueError: If the default gesture is also in the list of gestures.
        RuntimeError: If no training data is found.
    """
    # Features names
    FEATURES = F.DATA_COLUMNS
    # Training data directory
    TRAIN_DIR = "dataset/train"
    # Validation data directory
    VALIDATION_DIR = "dataset/validation"
    # Test data directory
    TEST_DIR = "dataset/test"
    # Data frames per second, None if unknown
    FPS = 90

    data_dir = config.get("DataDir", ".")
    dataset_dir = Path(data_dir, config["Dataset"])
    output_dir = config.get("OutputDir", ".")
    seed = config.get("RandomSeed", None)

    # Gestures; the default ("no gesture") class is always placed first, so
    # it gets class id 0
    no_gesture_config = config["DefaultGesture"]
    no_gesture = no_gesture_config["Name"]
    gestures_list = config["Gestures"]
    gestures = [g["Name"] for g in gestures_list]
    if no_gesture in gestures:
        raise ValueError(("The default gesture {} should not be "
                          "in the list of gestures").format(no_gesture))
    gestures.insert(0, no_gesture)
    gestures_weights = [no_gesture_config.get("Weight", 1.)]
    gestures_weights += [g.get("Weight", 1.) for g in gestures_list]
    gesture_trims = [(0, 0)]
    for g in gestures_list:
        # Empty "Trim", "Begin" or "End" entries count as no trimming
        trim = _opt(g, "Trim", {})
        gesture_trims.append((_opt(trim, "Begin", 0), _opt(trim, "End", 0)))
    gesture_progs = [_opt(no_gesture_config, "Progress", False)]
    gesture_progs += [_opt(g, "Progress", False) for g in gestures_list]

    # Training configuration
    training_config = config["Training"]
    batch_size = training_config["BatchSize"]
    step_size = training_config["StepSize"]
    max_epochs = training_config["MaxEpochs"]
    patience = training_config.get("Patience", -1)
    class_weights_norm_factor = float(
        _opt(training_config, "NormalizeClassWeights", 0))
    delay = _opt(training_config, "Delay", 0)
    confidence_margin_ratio = _opt(training_config, "ConfidenceMargin", 0)
    estimate_confidence = _opt(training_config, "EstimateConfidence", False)
    progress_loss = _opt(training_config, "ProgressLoss", 0)
    confidence_loss = _opt(training_config, "ConfidenceLoss", 0)
    weight_decay = _opt(training_config, "WeightDecay", 0)
    consistence_loss = _opt(training_config, "ConsistenceLoss", 0)
    dropout = _opt(training_config, "Dropout", 0)
    evaluation_step = _opt(training_config, "EvaluationStep", 1)
    dtype = tf.float32

    sensitivity_analysis = _opt(config, "SensitivityAnalysis", False)

    # Model configuration
    model_config = _opt(config, "Model", {})
    model_name = model_config.get("Name", None)
    drag_config = _opt(model_config, "Drag", {})
    drag_pos = _opt(drag_config, "Position", 0)
    drag_yaw = _opt(drag_config, "Yaw", 0)
    drop_head = _opt(model_config, "DropHead", False)
    drop_secondary_hand = _opt(model_config, "DropSecondaryHand", False)
    conv_stack = []
    for conv_layer in _opt(model_config, "ConvolutionStack", []):
        conv_config = _opt(conv_layer, "Convolution", {})
        conv = ConvolutionConfig(
            kernel_size=_opt(conv_config, "KernelSize", 0),
            channels=_opt(conv_config, "Channels", 0),
            stride=_opt(conv_config, "Stride", 1),
            dilation_rate=_opt(conv_config, "DilationRate", 1))
        pool_config = _opt(conv_layer, "Pooling", {})
        pool = PoolingConfig(
            type=_opt(pool_config, "Type", "MAX"),
            window_size=_opt(pool_config, "WindowSize", 0),
            stride=_opt(pool_config, "Stride", 1),
            dilation_rate=_opt(pool_config, "DilationRate", 1))
        conv_stack.append((conv, pool))
    dil_stack_config = _opt(model_config, "DilatedStack", {})
    dil_base = _opt(dil_stack_config, "Base", 0)
    dil_channels = _opt(dil_stack_config, "Channels", [])
    dense_stack = _opt(model_config, "DenseStack", [])
    local_rnn_stack = _opt(model_config, "LocalRecurrentStack", {})
    local_rnn_config = LocalRecurrentStackConfig(
        depth=_opt(local_rnn_stack, "Depth", 0),
        size=_opt(local_rnn_stack, "Size", 0))
    rnn_stack = _opt(model_config, "RecurrentStack", [])
    # Softmax is enabled unless explicitly set to false
    softmax = model_config.get("Softmax") is not False
    output_smoothing = _opt(model_config, "OutputSmoothing", 0)
    use_default_class = _opt(model_config, "UseDefaultClass", False)
    default_class = no_gesture if use_default_class else None
    progress_quantization = _opt(model_config, "ProgressQuantization", 0)
    confidence_quantization = _opt(model_config, "ConfidenceQuantization", 0)
    net_config = TimeNetConfig(
        convolution_stack=conv_stack,
        dilated_stack=DilatedStackConfig(base=dil_base, channels=dil_channels),
        dense_stack=dense_stack,
        local_recurrent_stack=local_rnn_config,
        recurrent_stack=rnn_stack,
        softmax=softmax,
        default_class=default_class,
        output_smoothing=output_smoothing,
        progress_quantization=progress_quantization,
        confidence_quantization=confidence_quantization,
        step_size=step_size,
        batch_size=batch_size,
        progress_loss=progress_loss,
        confidence_loss=confidence_loss,
        weight_decay=weight_decay,
        consistence_loss=consistence_loss,
        dropout=dropout,
    )
    restore_checkpoint_file = config.get("RestoreCheckpoint", None)

    # Results directories
    now_str = datetime.now().strftime("%Y%m%d%H%M%S")
    log_dir = Path(output_dir, "{}_{}".format(config_name, now_str))
    print("Results directory: {}".format(log_dir), flush=True)
    _ensure_dir(log_dir)
    tf_log_dir = _ensure_dir(Path(log_dir, "tflogs"))
    # Save process pid
    with Path(log_dir, "train.pid").open("w") as f:
        print(os.getpid(), file=f)
    # Save experiment configuration
    with Path(log_dir, "experiment.yaml").open("w") as f:
        yaml.dump(config, stream=f)

    # Random seed
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        tf.set_random_seed(seed)

    # Create the model
    data_dim = len(FEATURES)
    class_names = gestures
    mask_value = 0
    # Fixed-step mode is always enabled
    fixed_step = True

    # Scaling, head-relative, dragging
    input_filter = rice_input_filter(
        drag_pos=drag_pos, drag_yaw=drag_yaw, drop_head=drop_head,
        drop_secondary_hand=drop_secondary_hand)
    # Names of the features left after the input filter. The two drops are
    # cumulative: each one filters the result of the previous one.
    features_filtered = FEATURES
    if drop_head:
        features_filtered = [f for f in features_filtered
                             if f not in F.HEAD]
    if drop_secondary_hand:
        features_filtered = [f for f in features_filtered
                             if f not in F.SECONDARY_HAND]

    print("Creating model...", flush=True)
    net = TimeNet(data_dim, class_names, net_config, dtype=dtype,
                  stateful=True, fixed_step=fixed_step, mask_value=mask_value,
                  use_dropout=True, input_filter=input_filter, name=model_name,
                  seed=seed)
    # net.log_graph(tf_log_dir)
    with Path(log_dir, "config.txt").open("w") as f:
        for k, v in net_config._asdict().items():
            print("* {}: {}".format(k, v), file=f)
        print(file=f)
        print("Total parameters: {}".format(net.num_parameters), file=f)

    try:
        # Save untrained network
        untrained_graph = make_runtime_graph(net)
        tf.train.write_graph(untrained_graph, str(log_dir),
                             "{}.untrained.tf".format(net.name), as_text=False)
        # Restore checkpoint
        if restore_checkpoint_file:
            net.restore_checkpoint(restore_checkpoint_file)
            print("Restored checkpoint: {}.".format(restore_checkpoint_file))
    finally:
        net.close_session()

    print("Model created.")

    # Data load
    print("Loading data...", flush=True)
    df_train = read_data_directory(Path(dataset_dir, TRAIN_DIR),
                                   index_col=0)
    df_validation = read_data_directory(Path(dataset_dir, VALIDATION_DIR),
                                        index_col=0)
    df_test = read_data_directory(Path(dataset_dir, TEST_DIR),
                                  index_col=0)
    if len(df_train) <= 0:
        raise RuntimeError("No training data in dataset.")

    # Classes for which progress is tracked
    gesture_progs_idx = [i for i, p in enumerate(gesture_progs) if p]

    def prepare_split(frames):
        """Build (examples, labels, progress, confidence) for one data split.

        `frames` is the collection of data frames of the split; an empty
        split yields empty examples and labels.
        """
        frames = frames or []
        examples = [df[FEATURES].values.copy() for df in frames]
        labels_str = [df[F.GESTURE_CLASS].values.copy() for df in frames]
        # Replace unknown gestures, apply delay and find numeric labels
        labels = _encode_labels(labels_str, gestures, no_gesture, delay)
        # Progress data
        label_progs = make_class_progress(labels, classes=gesture_progs_idx)
        # Confidence margins
        label_confs = make_class_confidence(labels, confidence_margin_ratio)
        # Estimate confidence data from stored predictions, when every
        # frame of the split has them
        if estimate_confidence and \
                all(F.GESTURE_PREDICTION in df for df in frames):
            hits = [(df[F.GESTURE_CLASS].values ==
                     df[F.GESTURE_PREDICTION].values)
                    for df in frames]
            label_confs = estimate_label_confidence(labels, hits)
        # Trim (after progress and confidence are computed, so those are
        # based on the untrimmed labels)
        for i_gesture, (trim_begin, trim_end) in enumerate(gesture_trims):
            if trim_begin <= 0 and trim_end <= 0:
                continue
            labels = trim_class_margins(labels, trim_begin, trim_end,
                                        classes=i_gesture, replacement=0)
        return examples, labels, label_progs, label_confs

    train_data = prepare_split(df_train)
    validation_data = prepare_split(df_validation)
    test_data = prepare_split(df_test)

    print("Data loaded.", flush=True)

    # Training preparation: splits without data are not used
    if not (validation_data[0] and validation_data[1]):
        validation_data = None
    if not (test_data[0] and test_data[1]):
        test_data = None

    # Weight normalization
    class_weights = np.array(gestures_weights)
    if class_weights_norm_factor > 0:
        class_weights = normalized_class_weights(train_data[1],
                                                 class_weights_norm_factor,
                                                 class_weights)

    # Training
    try:
        checkpoint_file = Path(log_dir, "model.ckpt")
        final_metrics = net.train(*train_data,
                                  validation_data=validation_data,
                                  test_data=test_data,
                                  max_epochs=max_epochs,
                                  patience=patience,
                                  class_weights=class_weights,
                                  evaluation_step=evaluation_step,
                                  checkpoint_file=checkpoint_file,
                                  log_dir=tf_log_dir)
    finally:
        net.close_session()
    # Save train results
    final_metrics.unstack(1).to_csv(str(Path(log_dir, "metrics.csv")))
    net.train_history.to_csv(str(Path(log_dir, "train_history.csv")))

    # Export
    print("Saving model...")
    try:
        rt_graph = make_runtime_graph(net)
    finally:
        net.close_session()
    tf.train.write_graph(rt_graph, str(log_dir),
                         "{}.tf".format(net.name), as_text=False)
    tf.train.write_graph(rt_graph, str(log_dir),
                         "{}.txt".format(net.name), as_text=True)
    print("Model saved.")

    # Evaluate
    print()
    print("== Model evaluation ==")

    try:
        datasets = [("train", train_data)]
        if validation_data:
            datasets.append(("validation", validation_data))
        if test_data:
            datasets.append(("test", test_data))
        for name, (examples, labels, label_progs, label_confs) in datasets:
            results_dir = _ensure_dir(Path(log_dir, "results", name))
            banner = "*" * (len(name) + 4)
            print()
            print(banner)
            print("* " + name.upper() + " *")
            print(banner)
            print(flush=True)
            print("Predicting on dataset...")
            pred_dict = net.predict(examples, outputs=["ClassId",
                                                       "ClassProgress",
                                                       "ClassConfidence"])
            # NOTE(review): relies on the returned mapping preserving the
            # order of `outputs` - confirm against TimeNet.predict.
            predictions, prediction_progs, prediction_confs = \
                pred_dict.values()

            print("Building confusion matrix...")
            confusion_matrix = make_confusion_matrix(labels,
                                                     predictions,
                                                     class_names)
            confusion_matrix.to_csv(
                str(Path(results_dir, "confusion_matrix.csv")))
            plot_confusion_matrix(confusion_matrix)
            _save_figure(Path(results_dir, "confusion_matrix"), tight=False)

            print("Computing per-class metrics...")
            class_metrics = make_class_metrics(confusion_matrix)
            class_metrics.to_csv(str(Path(results_dir, "class_metrics.csv")))
            jaccard = make_jaccard(labels, predictions, class_names)
            jaccard.to_csv(str(Path(results_dir, "jaccard.csv")))
            jaccard_skip = make_jaccard(labels, predictions, class_names,
                                        skip_margin=15)
            jaccard_skip.to_csv(str(Path(results_dir, "jaccard_skip15.csv")))
            # Overall score: mean Jaccard index, first (default) entry
            # excluded
            score = np.mean(jaccard.values[1:])
            with Path(results_dir, "score.txt").open("w") as f:
                print(score, file=f)
            hit_counts = make_hit_counts(labels, predictions, class_names)
            hit_counts.to_csv(str(Path(results_dir, "hits.csv")))

            print("Plotting class prediction accuracy...")
            results_acc_dir = _ensure_dir(Path(results_dir, "accuracy"))
            plot_accuracy_mass(labels, predictions, class_names, fps=FPS)
            _save_figure(Path(results_acc_dir, "accuracy-allclasses"))
            plot_accuracy_mass(labels, predictions, class_names,
                               class_filter=np.arange(1, len(class_names)),
                               fps=FPS)
            _save_figure(Path(results_acc_dir, "accuracy-allgestures"))
            for class_id, class_name in enumerate(class_names):
                plot_accuracy_mass(labels, predictions, class_names,
                                   class_filter=class_id, fps=FPS)
                _save_figure(Path(results_acc_dir,
                                  "accuracy-class-{}".format(class_name)))

            if progress_loss > 0:
                print("Plotting progress prediction accuracy...")
                results_prog_dir = _ensure_dir(Path(results_dir, "progress"))
                prog_args = (labels, label_progs, predictions,
                             prediction_progs, class_names)
                # A ValueError on an aggregate plot skips the remaining
                # progress plots; per-class failures skip only that figure.
                if (_try_plot(Path(results_prog_dir, "progress-all"),
                              plot_progress_accuracy, *prog_args) and
                        _try_plot(Path(results_prog_dir, "progress_hit-all"),
                                  plot_progress_accuracy, *prog_args,
                                  hits_only=True)):
                    for class_id, class_name in enumerate(class_names):
                        _try_plot(
                            Path(results_prog_dir,
                                 "progress-class-{}".format(class_name)),
                            plot_progress_accuracy, *prog_args,
                            class_filter=class_id)
                        _try_plot(
                            Path(results_prog_dir,
                                 "progress_hit-class-{}".format(class_name)),
                            plot_progress_accuracy, *prog_args,
                            class_filter=class_id, hits_only=True)

            if confidence_loss > 0:
                print("Plotting confidence prediction accuracy...")
                conf_args = (labels, label_confs, predictions,
                             prediction_confs, class_names)
                if _try_plot(Path(results_dir, "confidence-all"),
                             plot_confidence_accuracy, *conf_args):
                    for class_id, class_name in enumerate(class_names):
                        _try_plot(
                            Path(results_dir,
                                 "confidence-class-{}".format(class_name)),
                            plot_confidence_accuracy, *conf_args,
                            class_filter=class_id)
                print("Plotting confidence-accuracy trade-off...")
                roc_args = (labels, predictions, prediction_confs,
                            class_names)
                if _try_plot(Path(results_dir, "confacc-all"),
                             plot_confidence_roc, *roc_args):
                    for class_id, class_name in enumerate(class_names):
                        _try_plot(
                            Path(results_dir,
                                 "confacc-class-{}".format(class_name)),
                            plot_confidence_roc, *roc_args,
                            class_filter=class_id)

            if sensitivity_analysis:
                print("Computing input sensitivity...")
                sens_dir = _ensure_dir(Path(results_dir, "sensitivity"))
                sens_class, sens_prog = net.input_sensitivity(
                    examples, labels, label_progs)
                # str() for consistency with the other writers and for
                # NumPy versions that do not accept path objects
                np.savez(str(Path(sens_dir, "sensitivity")),
                         sensitivity_class=sens_class,
                         sensitivity_progress=sens_prog)
                sens_class_df = []
                for s, class_name in zip(sens_class, class_names):
                    sens_df = pd.DataFrame(s, columns=features_filtered,
                                           index=-np.arange(len(s)))
                    sens_df.index.name = "Frame distance"
                    sens_df.columns.name = "Features"
                    file_name = "sensitivity-class-{}.csv".format(class_name)
                    sens_df.to_csv(str(Path(sens_dir, file_name)))
                    sens_class_df.append(sens_df)
                # Mean sensitivity over all classes
                sens_class_all_df = sum(sens_class_df) / len(sens_class_df)
                sens_class_all_df.to_csv(
                    str(Path(sens_dir, "sensitivity-allclasses.csv")))
                sens_prog_df = pd.DataFrame(sens_prog,
                                            columns=features_filtered,
                                            index=-np.arange(len(sens_prog)))
                sens_prog_df.index.name = "Frame distance"
                sens_prog_df.columns.name = "Features"
                sens_prog_df.to_csv(
                    str(Path(sens_dir, "sensitivity-progress.csv")))
                print("Plotting sensitivity analysis...")
                plot_sensitivity(sens_class_all_df, "classes")
                _save_figure(Path(sens_dir, "sensitivity-allclasses"),
                             tight=False)
                plot_sensitivity(sens_prog_df, "progress")
                _save_figure(Path(sens_dir, "sensitivity-progress"),
                             tight=False)
                for sens_df, class_name in zip(sens_class_df, class_names):
                    plot_sensitivity(sens_df, class_name)
                    fig_name = "sensitivity-class-{}".format(class_name)
                    _save_figure(Path(sens_dir, fig_name), tight=False)
            print()
            print("=" * 80)
    finally:
        net.close_session()


@contextmanager
def working_directory(path):
    """Temporarily change the current working directory to `path`.

    The previous working directory is restored when the ``with`` block
    exits, whether normally or through an exception.

    Args:
        path: Directory to switch to, as a ``Path`` or a string.
    """
    prev_dir = Path.cwd()
    # Path() also accepts strings, so callers need not wrap them
    os.chdir(str(Path(path).absolute()))
    try:
        yield
    finally:
        os.chdir(str(prev_dir))


if __name__ == "__main__":
    # Command line: a single argument, the YAML experiment file
    if len(sys.argv) < 2:
        print("Usage: {} <experiment file>".format(sys.argv[0]),
              file=sys.stderr)
        sys.exit(1)
    config_file = Path(sys.argv[1])
    # The experiment is named after the file, without extension
    config_name = config_file.stem
    with config_file.open("r") as f:
        # safe_load only builds plain data (no arbitrary Python objects) and,
        # unlike a bare yaml.load(f), does not need an explicit Loader
        config = yaml.safe_load(f)
    # Run from the directory of the experiment file, so that relative paths
    # in the configuration are resolved against it
    with working_directory(config_file.parent):
        main(config_name, config)
