# -*- coding: utf-8 -*-

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def _latex_escape(text):
    escape_chars = ["\\", "{", "}", "$", "%", "&", "~", "_", "^", "#"]
    for char in escape_chars:
        text = text.replace(char, "\\" + char)
    return text


def make_confusion_matrix(labels, predictions, class_names):
    """Build a confusion matrix of labels (rows) versus predictions (columns).

    Parameters
    ----------
    labels : sequence
        One entry per sequence. Each entry is either an array of class ids
        or a scalar class id, which is broadcast to the length of the
        matching prediction entry.
    predictions : sequence
        One entry per sequence. Each entry is either an array of class ids
        or a scalar class id (treated as a one-element array).
    class_names : sequence
        Names used for both axes; class id ``i`` maps to ``class_names[i]``.

    Returns
    -------
    pandas.DataFrame
        Integer counts with index named ``"Class"`` and columns named
        ``"Prediction"``. Pairs where either id lies outside
        ``[0, len(class_names))`` are ignored. An all-zero matrix is
        returned when there is nothing to count.
    """
    n_classes = len(class_names)
    # The builtin ``int`` replaces the ``np.int`` alias, which was removed
    # from NumPy (it was only ever an alias of the builtin).
    predictions = [np.array([p], dtype=int) if np.isscalar(p) else p
                   for p in predictions]
    labels = [np.full(len(p), l, dtype=int) if np.isscalar(l) else l
              for l, p in zip(labels, predictions)]
    mat = np.zeros((n_classes, n_classes), dtype=int)
    if labels:
        # np.concatenate rejects an empty list, hence the guard above.
        labels = np.concatenate(labels)
        predictions = np.concatenate(predictions)
        valid = ((labels >= 0) & (labels < n_classes)
                 & (predictions >= 0) & (predictions < n_classes))
        # np.add.at accumulates repeated (row, column) pairs, which plain
        # fancy-index assignment would not.
        np.add.at(mat, (labels[valid], predictions[valid]), 1)
    mat = pd.DataFrame(mat, index=class_names, columns=class_names)
    # rename_axis returns fresh axis objects, so naming one axis can never
    # leak into the other (or into ``class_names`` if it is an Index).
    return mat.rename_axis(index="Class", columns="Prediction")


def make_class_metrics(confusion_matrix):
    """Compute per-class precision and recall from a confusion matrix.

    Parameters
    ----------
    confusion_matrix : pandas.DataFrame
        Square matrix of counts with true classes on the rows and predicted
        classes on the columns.

    Returns
    -------
    pandas.DataFrame
        Columns ``"Precision"`` and ``"Recall"``, indexed by the rows of
        ``confusion_matrix`` with the index named ``"Class"``. A class with
        no predictions (or no samples) gets a precision (or recall) of 0.
    """
    # Work on the raw counts: the metrics are positional (class i is row i
    # and column i), so label-based alignment between row and column
    # labels must not interfere with the arithmetic.
    counts = np.asarray(confusion_matrix)
    hits = np.diag(counts)
    # Clamp the denominators to 1 so empty rows/columns yield 0, not NaN.
    precision = hits / np.maximum(counts.sum(axis=0), 1)
    recall = hits / np.maximum(counts.sum(axis=1), 1)
    metrics = pd.DataFrame({"Precision": precision, "Recall": recall},
                           index=confusion_matrix.index,
                           columns=["Precision", "Recall"])
    # rename_axis builds a new index, leaving the caller's index name alone.
    return metrics.rename_axis("Class")


def make_jaccard(labels, predictions, class_names, skip_margin=None):
    """Tabulate the Jaccard score of every class.

    ``jaccard_score`` is evaluated once per class id (the position of the
    class in ``class_names``), forwarding ``skip_margin`` unchanged.

    Returns a DataFrame with a single ``"Score"`` column, indexed by
    ``class_names`` with the index named ``"Class"``.
    """
    per_class = []
    for class_id, _ in enumerate(class_names):
        per_class.append(
            jaccard_score(labels, predictions, class_id, skip_margin))
    table = pd.DataFrame({"Score": per_class}, index=class_names)
    table.index.name = "Class"
    return table


def jaccard_score(labels, predictions, for_class, skip_margin=None):
    """Frame-wise Jaccard index (intersection over union) for one class.

    Parameters
    ----------
    labels, predictions : sequence of arrays
        Per-sequence class ids; all sequences are concatenated before
        scoring.
    for_class : int
        Class id being scored.
    skip_margin : int, optional
        When not ``None`` and non-negative, frames within this many
        positions of a change in the concatenated label mask are left out
        of the score.

    Returns
    -------
    float
        ``|label & prediction| / |label | prediction|``, or ``1.0`` when the
        class appears in neither labels nor predictions.
    """
    truth = np.concatenate([seq == for_class for seq in labels], axis=0)
    guess = np.concatenate([seq == for_class for seq in predictions], axis=0)

    margin_enabled = skip_margin is not None and skip_margin >= 0
    if margin_enabled:
        keep = np.ones(truth.shape, dtype=bool)
        # Positions where the label mask flips between neighbouring frames.
        boundaries = np.nonzero(truth[:-1] != truth[1:])[0] + 1
        for boundary in boundaries:
            start = max(boundary - skip_margin, 0)
            keep[start:boundary + skip_margin] = False
        truth, guess = truth[keep], guess[keep]

    intersection = np.count_nonzero(np.logical_and(truth, guess))
    union = np.count_nonzero(np.logical_or(truth, guess))
    return intersection / union if union > 0 else 1.0


def make_hit_counts(labels, predictions, class_names):
    """Count, per class, how many labelled runs were (mostly) recognised.

    All sequences are concatenated; a *run* is a maximal stretch of
    consecutive frames labelled with the class.

    Parameters
    ----------
    labels, predictions : sequence of arrays
        Per-sequence class ids.
    class_names : sequence
        Class id ``i`` is reported under ``class_names[i]``.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``class_names`` (index named ``"Class"``) with columns

        * ``Count`` - number of runs of the class,
        * ``SomeHits`` - runs with at least one correctly predicted frame,
        * ``MostHits`` - runs with at least half of their frames correct,
        * ``SomeHitRatio`` / ``MostHitRatio`` - the two previous columns
          divided by ``Count`` (0 when the class has no runs).
    """
    counts = []
    some_hits = []
    most_hits = []
    some_ratios = []
    most_ratios = []
    for class_id in range(len(class_names)):
        label_all = np.concatenate([label == class_id for label in labels],
                                   axis=0)
        pred_all = np.concatenate([pred == class_id for pred in predictions],
                                  axis=0)
        # The number of non-class frames seen so far is constant inside a
        # run and grows between runs, so it serves as a run identifier.
        run_of_frame = np.cumsum(~label_all)
        run_ids, run_sizes = np.unique(run_of_frame[label_all],
                                       return_counts=True)
        hit_ids, hit_sizes = np.unique(run_of_frame[label_all & pred_all],
                                       return_counts=True)
        # hit_ids is a subset of the sorted run_ids, so searchsorted gives
        # the exact position of every hit run; runs without hits stay at 0.
        run_acc = np.zeros(len(run_ids))
        pos = np.searchsorted(run_ids, hit_ids)
        run_acc[pos] = hit_sizes / run_sizes[pos]
        n_runs = len(run_ids)
        n_some = len(hit_ids)
        n_most = int(np.count_nonzero(run_acc >= 0.5))
        counts.append(n_runs)
        some_hits.append(n_some)
        most_hits.append(n_most)
        some_ratios.append(float(n_some) / max(n_runs, 1))
        most_ratios.append(float(n_most) / max(n_runs, 1))
    df = pd.DataFrame(index=class_names)
    df.index.name = "Class"
    df["Count"] = counts
    df["SomeHits"] = some_hits
    df["MostHits"] = most_hits
    df["SomeHitRatio"] = some_ratios
    df["MostHitRatio"] = most_ratios
    return df


def plot_confusion_matrix(confusion_matrix, nodiag=False):
    """Draw a confusion matrix as an annotated heatmap on a new figure.

    Parameters
    ----------
    confusion_matrix : pandas.DataFrame or array-like with ``copy()``
        Counts with true classes on the rows and predictions on the
        columns. The input is copied and never modified.
    nodiag : bool, optional
        When true, the diagonal cells are masked out of the heatmap.
    """
    confusion_matrix = pd.DataFrame(confusion_matrix.copy())
    mask = None
    if nodiag:
        mask = np.eye(len(confusion_matrix)) > 0
    # Escape the tick labels for LaTeX rendering. The axes are replaced by
    # plain assignment because ``set_axis(..., inplace=True)`` is no longer
    # accepted by current pandas releases.
    confusion_matrix.index = pd.Index(
        [_latex_escape(str(a)) for a in confusion_matrix.index],
        name="Class")
    confusion_matrix.columns = pd.Index(
        [_latex_escape(str(a)) for a in confusion_matrix.columns],
        name="Prediction")
    plt.figure()
    sns.heatmap(confusion_matrix, annot=True, fmt=",.0f", cbar=False,
                vmin=0, robust=True, linewidths=0.1, mask=mask)
    plt.yticks(rotation=0)
    plt.xticks(rotation=45, ha="right")
    plt.title("Confusion matrix")
    # Leave a small margin on the right of the figure.
    plt.gcf().tight_layout(rect=(0, 0, .95, 1))


def plot_accuracy_mass(labels, predictions, class_names, class_filter=None,
                       axes=None, fps=None):
    """Plot hit/miss shares against the time elapsed inside a labelled run.

    Labels and predictions of all sequences are concatenated, with a ``-1``
    entry inserted after every sequence. A run ("gesture") is a stretch of
    equal consecutive labels in that concatenation, and the relative time
    of a frame is its offset from the start of its run. Relative time is
    split into 30 bins; for each bin the stacked areas show the fraction of
    hit and missed frames, scaled by the share of runs that have at least
    one frame in the bin.

    Parameters
    ----------
    labels, predictions : sequence of arrays
        Per-sequence class ids (assumed 1-D integer arrays - TODO confirm).
    class_names : sequence
        Class names; labels outside ``[0, len(class_names))`` are ignored.
    class_filter : scalar or sequence, optional
        Only frames whose label is one of these class ids are used.
    axes : matplotlib axes, optional
        Target axes; a new figure is created when not given.
    fps : number, optional
        When given, relative time is expressed in seconds (frame offset
        divided by ``fps``); otherwise in frames.

    Returns
    -------
    None
        Nothing is drawn when no frame passes the filters.
    """
    # Concatenate labels and predictions with intermediate invalid labels
    invalid_label = np.array([-1])
    labels_all = sum(([elem, invalid_label] for elem in labels), [])
    labels_all = np.concatenate(labels_all, axis=0)
    predictions_all = sum(([elem, invalid_label] for elem in predictions), [])
    predictions_all = np.concatenate(predictions_all, axis=0)
    if class_filter is not None:
        if np.isscalar(class_filter):
            class_filter = [class_filter]
        class_filter = np.asarray(class_filter)
        # Frame-by-class comparison table, reduced to "matches any class"
        m = np.expand_dims(labels_all, 1) == np.expand_dims(class_filter, 0)
        m = np.logical_or.reduce(m, axis=1)
    else:
        m = np.full(labels_all.shape, True)
    # Filter invalid labels
    m &= (labels_all >= 0) & (labels_all < len(class_names))
    if np.count_nonzero(m) <= 0:
        # No valid examples found!
        return
    # Make relative time positions
    # (runs are derived from the unfiltered labels, so the -1 separators
    # keep runs of different sequences apart)
    labels_diff = np.diff(labels_all)
    labels_idx = np.concatenate([[0], np.where(labels_diff != 0)[0] + 1],
                                axis=0)
    labels_length = np.diff(np.concatenate([labels_idx, [len(labels_all)]],
                                           axis=0))
    gesture_ids = np.repeat(np.arange(len(labels_length)), labels_length)
    if fps is None:
        time_scale = 1.
    else:
        time_scale = 1. / fps
    relative_time = np.concatenate([np.arange(lab_len) * time_scale
                                    for lab_len in labels_length],
                                   axis=0)
    # Compute hit and miss
    labels_all = labels_all[m]
    predictions_all = predictions_all[m]
    gesture_ids = gesture_ids[m]
    relative_time = relative_time[m]
    hits = labels_all == predictions_all
    # 31 edges -> 30 equally wide bins over the observed time range
    bins = np.linspace(np.min(relative_time), np.max(relative_time), 31)
    hist_total, _ = np.histogram(relative_time, bins)
    hist_hits, _ = np.histogram(relative_time[hits], bins)
    hist_miss = hist_total - hist_hits
    # Compute gestures on each bin
    total_gestures = float(len(np.unique(gesture_ids)))
    bin_ratio = np.full(len(bins) - 1, .0)
    for i_bin, (bin_start, bin_end) in enumerate(zip(bins[:-1], bins[1:])):
        # Last bin is closed interval
        if i_bin < len(bins) - 2:
            bin_m = (relative_time >= bin_start) & (relative_time < bin_end)
        else:
            bin_m = (relative_time >= bin_start) & (relative_time <= bin_end)
        bin_gestures = len(np.unique(gesture_ids[bin_m]))
        bin_ratio[i_bin] = bin_gestures / total_gestures
    # Compute accuracy
    # (only bins that contain at least one frame are kept)
    hist_mask = hist_total > 0
    hist_acc = (hist_hits[hist_mask] / hist_total[hist_mask])
    hist_err = (hist_miss[hist_mask] / hist_total[hist_mask])
    # Scale bins
    hist_acc = hist_acc * bin_ratio[hist_mask]
    hist_err = hist_err * bin_ratio[hist_mask]
    # Slightly stretch ends to match begin and end
    hist_locs = ((bins[:-1] + bins[1:]) / 2.)[hist_mask]
    if len(hist_locs) > 0:
        hist_locs[0] = bins[0]
        hist_locs[-1] = bins[-1]
    # Plot
    axes = axes or plt.figure().gca()
    axes.stackplot(hist_locs, hist_acc, hist_err, labels=["Hit", "Miss"],
                   colors=["g", "r"])
    axes.legend(loc="upper right")
    axes.set_xlim(hist_locs[0], hist_locs[-1])
    axes.set_ylim(0, 1)
    time_unit = "s" if fps is not None else "frames"
    axes.set_xlabel("Time ({})".format(time_unit))
    axes.set_ylabel("Gestures")
    # Ticks as percent
    axes.set_yticklabels(["${:3.0f}\\%$".format(y * 100)
                          for y in axes.get_yticks()])
    # Name the class in the title only when exactly one class was selected
    title_suffix = ""
    if class_filter is not None and len(class_filter) == 1:
        plot_class_name = _latex_escape(class_names[class_filter[0]])
        title_suffix = " for {}".format(plot_class_name)
    axes.set_title("Accuracy distribution{}".format(title_suffix))


def plot_confidence_roc(labels, predictions, predictions_confs, class_names,
                        class_filter=None, axes=None):
    """Plot accuracy and retained predictions against a confidence threshold.

    For 101 thresholds between 0 and 1 two curves are drawn: the accuracy
    of the predictions whose confidence reaches the threshold, and the
    fraction of all (filtered) predictions that reach it.

    Parameters
    ----------
    labels, predictions, predictions_confs : sequence of arrays
        Per-sequence class ids and prediction confidences; each list is
        concatenated along axis 0.
    class_names : sequence
        Class names; labels outside ``[0, len(class_names))`` are ignored.
    class_filter : scalar or sequence, optional
        Only frames whose label is one of these class ids are used.
    axes : matplotlib axes, optional
        Target axes; a new figure is created when not given.

    Returns
    -------
    None
        Nothing is drawn when ``labels`` is empty or when no frame passes
        the filters.
    """
    if len(labels) <= 0:
        return
    labels_all = np.concatenate(labels, axis=0)
    predictions_all = np.concatenate(predictions, axis=0)
    predictions_confs_all = np.concatenate(predictions_confs, axis=0)
    # Filter by class and invalid labels
    if class_filter is not None:
        if np.isscalar(class_filter):
            class_filter = [class_filter]
        m = np.expand_dims(labels_all, 1) == np.expand_dims(class_filter, 0)
        m = np.logical_or.reduce(m, axis=1)
    else:
        m = np.full(labels_all.shape, True)
    m &= (labels_all >= 0) & (labels_all < len(class_names))
    labels_all = labels_all[m]
    predictions_all = predictions_all[m]
    predictions_confs_all = predictions_confs_all[m]
    if len(labels_all) <= 0:
        # Nothing left to evaluate; bail out like for empty input instead
        # of dividing by zero below.
        return
    # Compute ROC
    conf_threshold = np.linspace(0, 1, 101)
    hits = labels_all == predictions_all
    # One row per frame, one column per threshold
    conf_mask = (predictions_confs_all[:, np.newaxis] >=
                 conf_threshold[np.newaxis, :])
    kept_count = np.count_nonzero(conf_mask, axis=0)
    hits_count = np.count_nonzero(hits[:, np.newaxis] & conf_mask, axis=0)
    # Only the accuracy denominator is clamped: the retained fraction must
    # be able to reach 0 when no prediction passes a threshold.
    acc_conf = hits_count / np.maximum(kept_count, 1)
    predictions_conf = kept_count / len(labels_all)
    # Plot
    if axes is None:
        axes = plt.figure().gca()
    pred_plt, = axes.plot(conf_threshold, predictions_conf,
                          label="Predictions", c="r")
    acc_plt, = axes.plot(conf_threshold, acc_conf, label="Accuracy", c="g")
    axes.set_xlim(0, 1)
    axes.set_ylim(0, 1)
    axes.set_xlabel("Confidence threshold")
    axes.set_ylabel("")
    axes.legend([acc_plt, pred_plt], ["Accuracy", "Predictions"],
                loc="lower left", frameon=True)
    # Ticks as percent
    axes.set_xticklabels(["${:3.0f}\\%$".format(x * 100)
                          for x in axes.get_xticks()])
    axes.set_yticklabels(["${:3.0f}\\%$".format(y * 100)
                          for y in axes.get_yticks()])
    # Name the class in the title only when exactly one class was selected
    title_suffix = ""
    if class_filter is not None and len(class_filter) == 1:
        plot_class_name = _latex_escape(class_names[class_filter[0]])
        title_suffix = " for {}".format(plot_class_name)
    axes.set_title("Confidence/accuracy trade-off{}".format(title_suffix))


def plot_progress_accuracy(labels, label_progs, predictions, prediction_progs,
                           class_names, class_filter=None, hits_only=False,
                           axes=None):
    """Plot predicted gesture progress and accuracy against true progress.

    Frames are grouped into 30 bins of true progress. Per bin the plot
    shows the classification accuracy, the median progress prediction
    (true progress plus median error) and two error bands (2.5%-97.5% and
    25%-75% quantiles), next to the ideal diagonal.

    Parameters
    ----------
    labels, label_progs, predictions, prediction_progs : sequence of arrays
        Per-sequence class ids and progress values; each list is
        concatenated along axis 0.
    class_names : sequence
        Class names; labels outside ``[0, len(class_names))`` are ignored.
    class_filter : scalar or sequence, optional
        Only frames whose label is one of these class ids are used.
    hits_only : bool, optional
        When true, the progress error statistics only use correctly
        classified frames (the accuracy curve always uses all frames).
    axes : matplotlib axes, optional
        Target axes; a new figure is created when not given.

    Returns
    -------
    None
        Nothing is drawn when ``labels`` is empty.

    Raises
    ------
    ValueError
        If no frame with a valid label and a true progress within
        ``[0, 1]`` remains after filtering.
    """
    if len(labels) <= 0:
        return
    labels_all = np.concatenate(labels, axis=0)
    label_progs_all = np.concatenate(label_progs, axis=0)
    predictions_all = np.concatenate(predictions, axis=0)
    prediction_progs_all = np.concatenate(prediction_progs, axis=0)
    # Filter by class and invalid labels
    if class_filter is not None:
        if np.isscalar(class_filter):
            class_filter = [class_filter]
        m = np.expand_dims(labels_all, 1) == np.expand_dims(class_filter, 0)
        m = np.logical_or.reduce(m, axis=1)
    else:
        m = np.full(labels_all.shape, True)
    m &= (labels_all >= 0) & (labels_all < len(class_names))
    m &= (label_progs_all >= 0) & (label_progs_all <= 1)
    labels_all = labels_all[m]
    label_progs_all = label_progs_all[m]
    predictions_all = predictions_all[m]
    prediction_progs_all = prediction_progs_all[m]
    if len(labels_all) <= 0:
        raise ValueError("No valid progress data found.")
    # Binning
    bins = np.linspace(0, 1, 31)
    df = pd.DataFrame(index=np.arange(len(labels_all)))
    df["Label"] = labels_all
    df["LabelProgress"] = label_progs_all
    df["Prediction"] = predictions_all
    df["PredictionProgress"] = prediction_progs_all
    df["Hit"] = df["Label"] == df["Prediction"]
    df["PredictionError"] = df["PredictionProgress"] - df["LabelProgress"]
    # include_lowest keeps frames with a progress of exactly 0 in the first
    # bin; by default pd.cut intervals are open on the left.
    df["Bin"] = pd.cut(df["LabelProgress"], bins, include_lowest=True)
    # observed=False keeps empty bins, so every statistic has exactly one
    # entry per bin and lines up with plot_x.
    acc = df.groupby("Bin", observed=False)["Hit"].mean().fillna(0)
    if hits_only:
        df_bin = df[df["Hit"]].groupby("Bin", observed=False)
    else:
        df_bin = df.groupby("Bin", observed=False)
    pred_prog_err = df_bin["PredictionError"].median().fillna(0)
    pred_prog_low_ex = df_bin["PredictionError"].quantile(0.025).fillna(0)
    pred_prog_low_q = df_bin["PredictionError"].quantile(0.25).fillna(0)
    pred_prog_high_ex = df_bin["PredictionError"].quantile(0.975).fillna(0)
    pred_prog_high_q = df_bin["PredictionError"].quantile(0.75).fillna(0)
    # Plot
    if axes is None:
        axes = plt.figure().gca()
    # Bin centres, with the outer two stretched to the ends of the range
    plot_x = bins[:-1] + np.diff(bins) / 2.0
    plot_x[0] = 0
    plot_x[-1] = 1
    acc_plt, = axes.plot(plot_x, acc, "g--")
    # Ideal progress: the diagonal
    prog_plt, = axes.plot([0, 1], [0, 1], "b")
    err_ex_plt = axes.fill_between(plot_x, plot_x + pred_prog_high_ex,
                                   plot_x + pred_prog_low_ex,
                                   alpha=0.25, color="c")
    err_plt = axes.fill_between(plot_x, plot_x + pred_prog_high_q,
                                plot_x + pred_prog_low_q,
                                alpha=0.5, color="c")
    prog_pred_plt, = axes.plot(plot_x, plot_x + pred_prog_err, "c")

    axes.set_xlim(0, 1)
    axes.set_ylim(0, 1)
    axes.set_xlabel("Gesture progress")
    axes.set_ylabel("")
    axes.legend([prog_plt, (err_ex_plt, err_plt, prog_pred_plt), acc_plt],
                ["Progress", "Prediction", "Accuracy"],
                loc="lower right", frameon=True)
    # Ticks as percent
    axes.set_xticklabels(["${:3.0f}\\%$".format(x * 100)
                          for x in axes.get_xticks()])
    axes.set_yticklabels(["${:3.0f}\\%$".format(y * 100)
                          for y in axes.get_yticks()])
    # Name the class in the title only when exactly one class was selected
    title_suffix = ""
    if class_filter is not None and len(class_filter) == 1:
        plot_class_name = _latex_escape(class_names[class_filter[0]])
        title_suffix = " for {}".format(plot_class_name)
    hit_suffix = " on hit" if hits_only else ""
    axes.set_title("Gesture progress prediction{}{}".format(
        hit_suffix, title_suffix))


def plot_confidence_accuracy(labels, label_confs, predictions,
                             prediction_confs, class_names, class_filter=None,
                             axes=None):
    """Plot true and predicted confidence and accuracy over gesture progress.

    The progress of a frame is its relative position (0 to 1) inside its
    run of equal consecutive labels within a sequence. Frames are grouped
    into 30 progress bins. Per bin the plot shows the classification
    accuracy, the mean true confidence, the median predicted confidence
    (mean true confidence plus median error) and two error bands
    (2.5%-97.5% and 25%-75% quantiles).

    Parameters
    ----------
    labels, label_confs, predictions, prediction_confs : sequence of arrays
        Per-sequence class ids and confidence values; each list is
        concatenated along axis 0.
    class_names : sequence
        Class names; labels outside ``[0, len(class_names))`` are ignored.
    class_filter : scalar or sequence, optional
        Only frames whose label is one of these class ids are used.
    axes : matplotlib axes, optional
        Target axes; a new figure is created when not given.

    Returns
    -------
    None
        Nothing is drawn when ``labels`` is empty.

    Raises
    ------
    ValueError
        If no frame with a valid label and a true confidence within
        ``[0, 1]`` remains after filtering.
    """
    if len(labels) <= 0:
        return
    labels_all = np.concatenate(labels, axis=0)
    label_confs_all = np.concatenate(label_confs, axis=0)
    predictions_all = np.concatenate(predictions, axis=0)
    prediction_confs_all = np.concatenate(prediction_confs, axis=0)
    # Length of every run of equal consecutive labels, sequence by sequence
    label_lens_all = np.concatenate(
        [np.diff(np.concatenate([[0],
                                 np.where(np.diff(label) != 0)[0] + 1,
                                 [len(label)]], axis=0))
         for label in labels], axis=0)
    # Relative position of every frame inside its run
    label_progs_all = np.concatenate(
        [np.linspace(0, 1, l) for l in label_lens_all], axis=0)
    # Filter by class and invalid labels
    if class_filter is not None:
        if np.isscalar(class_filter):
            class_filter = [class_filter]
        m = np.expand_dims(labels_all, 1) == np.expand_dims(class_filter, 0)
        m = np.logical_or.reduce(m, axis=1)
    else:
        m = np.full(labels_all.shape, True)
    m &= (labels_all >= 0) & (labels_all < len(class_names))
    m &= (label_confs_all >= 0) & (label_confs_all <= 1)
    labels_all = labels_all[m]
    label_confs_all = label_confs_all[m]
    predictions_all = predictions_all[m]
    prediction_confs_all = prediction_confs_all[m]
    label_progs_all = label_progs_all[m]
    if len(labels_all) <= 0:
        raise ValueError("No valid confidence data found.")
    # Binning
    bins = np.linspace(0, 1, 31)
    df = pd.DataFrame(index=np.arange(len(labels_all)))
    df["Label"] = labels_all
    df["LabelProgress"] = label_progs_all
    df["LabelConfidence"] = label_confs_all
    df["Prediction"] = predictions_all
    df["PredictionConfidence"] = prediction_confs_all
    df["Hit"] = df["Label"] == df["Prediction"]
    df["PredictionError"] = df["PredictionConfidence"] - df["LabelConfidence"]
    # include_lowest keeps the first frame of every run (progress exactly
    # 0) in the first bin; by default pd.cut intervals are open on the left.
    df["Bin"] = pd.cut(df["LabelProgress"], bins, include_lowest=True)
    # observed=False keeps empty bins, so every statistic has exactly one
    # entry per bin and lines up with plot_x.
    df_bin = df.groupby("Bin", observed=False)
    acc = df_bin["Hit"].mean().fillna(0)
    conf = df_bin["LabelConfidence"].mean().fillna(0)
    pred_conf_err = df_bin["PredictionError"].median().fillna(0)
    pred_conf_low_ex = df_bin["PredictionError"].quantile(0.025).fillna(0)
    pred_conf_low_q = df_bin["PredictionError"].quantile(0.25).fillna(0)
    pred_conf_high_ex = df_bin["PredictionError"].quantile(0.975).fillna(0)
    pred_conf_high_q = df_bin["PredictionError"].quantile(0.75).fillna(0)
    # Plot
    if axes is None:
        axes = plt.figure().gca()
    # Bin centres, with the outer two stretched to the ends of the range
    plot_x = bins[:-1] + np.diff(bins) / 2.0
    plot_x[0] = 0
    plot_x[-1] = 1
    acc_plt, = axes.plot(plot_x, acc, "g--")
    conf_plt, = axes.plot(plot_x, conf, "b")
    err_ex_plt = axes.fill_between(plot_x, conf + pred_conf_high_ex,
                                   conf + pred_conf_low_ex,
                                   alpha=0.25, color="c")
    err_plt = axes.fill_between(plot_x, conf + pred_conf_high_q,
                                conf + pred_conf_low_q,
                                alpha=0.5, color="c")
    conf_pred_plt, = axes.plot(plot_x, conf + pred_conf_err, "c")

    axes.set_xlim(0, 1)
    axes.set_ylim(0, 1)
    # The x axis is the binned gesture progress; the curves are confidences.
    axes.set_xlabel("Gesture progress")
    axes.set_ylabel("")
    axes.legend([conf_plt, (err_ex_plt, err_plt, conf_pred_plt), acc_plt],
                ["Confidence", "Prediction", "Accuracy"],
                loc="lower right", frameon=True)
    # Ticks as percent
    axes.set_xticklabels(["${:3.0f}\\%$".format(x * 100)
                          for x in axes.get_xticks()])
    axes.set_yticklabels(["${:3.0f}\\%$".format(y * 100)
                          for y in axes.get_yticks()])
    # Name the class in the title only when exactly one class was selected
    title_suffix = ""
    if class_filter is not None and len(class_filter) == 1:
        plot_class_name = _latex_escape(class_names[class_filter[0]])
        title_suffix = " for {}".format(plot_class_name)
    axes.set_title("Gesture confidence prediction{}".format(title_suffix))


def plot_sensitivity(sensitivity, name=None):
    """Draw a sensitivity table as a heatmap with the colour bar on the left.

    Parameters
    ----------
    sensitivity : pandas.DataFrame or array-like with ``copy()``
        One row per frame distance and one column per feature. A DataFrame
        keeps its own index; anything else is indexed ``0, -1, -2, ...``.
    name : str, optional
        When given (and non-empty), appended to the plot title.
    """
    if isinstance(sensitivity, pd.DataFrame):
        frame_index = sensitivity.index
    else:
        frame_index = -np.arange(len(sensitivity))
    df = pd.DataFrame(sensitivity.copy(), index=frame_index)
    # Escape the feature names for LaTeX rendering
    escaped = {col: _latex_escape(str(col)) for col in df.columns}
    df.rename(columns=escaped, inplace=True)
    df.index.name = "Frame distance"
    df.columns.name = "Features"

    # Narrow column for the colour bar, wide column for the heatmap
    fig, (ax_cbar, ax_sens) = plt.subplots(
        1, 2, gridspec_kw={"width_ratios": (.025, .9)})

    # About ten x tick labels, at a step that is a multiple of five; label
    # every column when that step would be too small.
    xtick_step = int(round((len(df) / 10.0) / 5)) * 5
    xticklabels = xtick_step if xtick_step > 1 else True
    sns.heatmap(df.T, ax=ax_sens, xticklabels=xticklabels, vmin=0,
                robust=False, cbar_ax=ax_cbar)

    ax_sens.invert_xaxis()
    ax_sens.yaxis.tick_right()
    ax_sens.yaxis.set_label_position("right")
    for tick_label in ax_sens.get_yticklabels():
        tick_label.set_rotation(0)
        tick_label.set_ha("left")
    ax_cbar.yaxis.tick_left()
    ax_cbar.yaxis.set_label_position("left")

    suffix = " for {}".format(_latex_escape(name)) if name else ""
    ax_sens.set_title("Sensitivity analysis" + suffix)
    ax_sens.figure.tight_layout()
