# -*- coding: utf-8 -*-

import numpy as np
import tensorflow as tf


def freeze_session(session, keep_var_names=None, output_names=None,
                   clear_devices=True):
    """Freeze the variables of a TF1 session's graph into constants.

    Every global variable of ``session.graph`` whose op name is not listed in
    ``keep_var_names`` is converted to a constant holding its current value
    in ``session``, via ``convert_variables_to_constants``.

    Args:
        session: The TF1 session whose graph and variable values are frozen.
        keep_var_names: Op names of global variables that must NOT be
            converted to constants. ``None`` means convert all of them.
        output_names: Names of output nodes to keep. The op names of all
            global variables are appended to this list before conversion.
            The caller's list is never modified.
        clear_devices: If true, the ``device`` field of every node in the
            exported GraphDef is cleared before freezing.

    Returns:
        The GraphDef returned by ``convert_variables_to_constants``.
    """
    from tensorflow.python.framework.graph_util import \
        convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        keep_var_names = keep_var_names or []
        global_var_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(
            set(global_var_names).difference(keep_var_names))
        # Build a fresh list: the previous ``output_names += ...`` extended
        # the caller's list in place whenever a non-empty list was passed.
        # Variable names are appended, presumably so the extracted subgraph
        # keeps every (now constant) variable node — TODO confirm.
        output_names = list(output_names or []) + global_var_names
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session,
                                                      input_graph_def,
                                                      output_names,
                                                      freeze_var_names)
        return frozen_graph


def split_indices(num, validation_split, test_split=None):
    """Randomly partition ``range(num)`` into disjoint index sets.

    The indices are shuffled with ``np.random.permutation`` (the global NumPy
    RNG). Validation takes the first ``round(num * validation_split)``
    shuffled indices. Test, when requested, takes the next
    ``round(num * test_split)``. Training takes everything left over.

    Args:
        num: Total number of examples (an integer).
        validation_split: Fraction of examples assigned to validation.
        test_split: Fraction of examples assigned to test, or ``None`` to
            produce no test set.

    Returns:
        ``(idx_train, idx_validation, idx_test)`` when ``test_split`` is
        given, otherwise ``(idx_train, idx_validation)``. The returned index
        arrays never share an element. If the rounded split sizes add up to
        ``num`` or more, the training set is empty and the later splits are
        truncated to whatever indices remain.
    """
    idx = np.random.permutation(num)
    num_validation = int(np.round(len(idx) * validation_split))
    idx_validation = idx[:num_validation]
    num_test = 0
    if test_split is not None:
        num_test = int(np.round(len(idx) * test_split))
        idx_test = idx[num_validation:(num_validation + num_test)]
    # Slice training indices from the front offset instead of using
    # ``idx[-num_train:]``. With num_train == 0 that expression is
    # ``idx[0:]`` (the whole array), and a negative num_train (rounding
    # overflow) also yields indices overlapping the other splits.
    idx_train = idx[(num_validation + num_test):]
    if test_split is not None:
        return idx_train, idx_validation, idx_test
    return idx_train, idx_validation


def examples_subset(examples, labels, idx):
    """Select the examples and labels located at the positions in ``idx``.

    Args:
        examples: Indexable collection of examples.
        labels: Indexable collection of labels, parallel to ``examples``.
        idx: Iterable of positions to pick, in the desired output order.

    Returns:
        A pair ``(picked_examples, picked_labels)`` of plain lists.
    """
    def _take(collection):
        # ``idx`` is iterated once per collection, examples first.
        return [collection[position] for position in idx]

    return _take(examples), _take(labels)
