# -*- coding: utf-8 -*-

from collections import namedtuple, OrderedDict
from datetime import datetime
from pathlib import Path
import random
import sys

import numpy as np
import pandas as pd
import tensorflow as tf


# Settings of one convolution layer:
#   kernel_size    - convolution kernel size
#   channels       - number of convolution output channels
#   stride         - stride size
#   dilation_rate  - dilation rate
ConvolutionConfig = namedtuple(
    "ConvolutionConfig",
    ("kernel_size", "channels", "stride", "dilation_rate"),
)

# Settings of one pooling layer:
#   type           - "MAX" or "AVG"
#   window_size    - pooling window size
#   stride         - stride size
#   dilation_rate  - dilation rate
PoolingConfig = namedtuple(
    "PoolingConfig",
    ("type", "window_size", "stride", "dilation_rate"),
)

# Settings of the dilated stack:
#   base      - dilated stack base
#   channels  - number of channels at each level of the stack
DilatedStackConfig = namedtuple("DilatedStackConfig", ("base", "channels"))

# Settings of the locally recurrent stack:
#   depth  - number of layers
#   size   - cell size
LocalRecurrentStackConfig = namedtuple(
    "LocalRecurrentStackConfig", ("depth", "size"))

# Full TimeNet configuration. The fields are grouped below by the part of
# the model they affect; the field order is part of the public interface.
TimeNetConfig = namedtuple("TimeNetConfig", (
    # --- Network structure ---
    "convolution_stack",        # list of (ConvolutionConfig, PoolingConfig) pairs
    "dilated_stack",            # dilated stack configuration
    "dense_stack",              # list of dense layer sizes
    "local_recurrent_stack",    # locally recurrent stack configuration
    "recurrent_stack",          # list of recurrent cell sizes
    # --- Output layer ---
    "softmax",                  # True: softmax output layer, False: sigmoid
    "default_class",            # if not None, a separate "default class"
                                # neuron is created for the specified class
    "output_smoothing",         # number of samples used to smooth the output
    "progress_quantization",    # if > 1, progress is quantized in the given
                                # number of classes
    "confidence_quantization",  # if > 1, confidence is quantized in the given
                                # number of classes
    # --- Parameter updates ---
    "step_size",                # time step size for each parameter update
                                # (None to always use full sequences)
    "batch_size",               # batch size for each parameter update
    # --- Loss terms ---
    "progress_loss",            # class progress loss weight
    "confidence_loss",          # class confidence loss weight
    "weight_decay",             # weight decay regularization
    "consistence_loss",         # consistence loss penalization
    "dropout",                  # dropout probability on each layer
))


class TimeNet(object):
    """Model that builds its own TensorFlow graph from a configuration.

    The string constants below are names of TensorFlow graph collections
    in which the model registers its tensors, variables and operations
    (through ``tf.add_to_collection``) so that they can be retrieved by
    name later.
    """

    # Output tensors of the model
    MODEL_OUTPUTS = "timenet_model_outputs"
    # Weight variables, with the placeholders and the grouped assign
    # operation used to overwrite them
    WEIGHT_VARIABLES = "timenet_weight_variables"
    WEIGHT_UPDATES = "timenet_weight_updates"
    WEIGHT_UPDATE_OP = "timenet_weight_update_op"
    # State variables (e.g. the input buffer) and their updated values
    STATE_VARIABLES = "timenet_state_variables"
    STATE_UPDATES = "timenet_state_updates"
    # NOTE(review): presumably the op applying the state updates - confirm
    STATE_UPDATE_OP = "timenet_state_update_op"
    STATE_RESET_OP = "timenet_state_reset_op"
    # Moving average variables and the operations that update them
    CURRENT_MOVING_AVERAGE_VARIABLES = "timenet_current_moving_average_variables"
    UPDATE_MOVING_AVERAGE_OP = "timenet_update_moving_average_op"
    # Metric accumulator variables and their updated values
    METRIC_VARIABLES = "timenet_metric_variables"
    METRIC_UPDATES = "timenet_metric_updates"
    # NOTE(review): presumably the op applying the metric updates - confirm
    METRIC_UPDATE_OP = "timenet_metric_update_op"
    METRIC_RESET_OP = "timenet_metric_reset_op"
    # Operation grouping the state reset, metric reset and local variable
    # initialization
    RESET_OP = "timenet_reset_op"
    # NOTE(review): presumably filled in by _make_aliases - confirm
    INPUT_ALIASES = "timenet_input_aliases"
    OUTPUT_ALIASES = "timenet_output_aliases"

    def __init__(self, data_dim, class_names, config, dtype=tf.float32,
                 stateful=True, fixed_step=False, mask_value=0,
                 use_dropout=True, input_filter=None, name=None, seed=None):
        """Store the settings and build the (finalized) model graph.

        Args:
            data_dim: Size of the last dimension of the input data.
            class_names: Iterable of class names; it is copied into a list.
            config: Model configuration (its ``step_size`` is checked here).
            dtype: Data type of the model, converted with ``tf.as_dtype``.
            stateful: Whether state and metric updates are added to the
                graph.
            fixed_step: Whether the time dimension of the inputs is fixed
                to ``config.step_size``.
            mask_value: Input value marking masked frames, or None to
                disable masking.
            use_dropout: Whether dropout is used by the layers.
            input_filter: Optional callable invoked as
                ``input_filter(model, input)`` on the input placeholder.
            name: Name of the root variable scope; the class name is used
                when empty.
            seed: Optional graph-level random seed.

        Raises:
            ValueError: If ``fixed_step`` is set but the configuration has
                no step size.
        """
        # Constructor settings
        self._data_dim = data_dim
        self._class_names = list(class_names)
        self._config = config
        self._dtype = tf.as_dtype(dtype)
        self._stateful = stateful
        self._fixed_step = fixed_step
        self._mask_value = mask_value
        self._use_dropout = use_dropout
        self._input_filter = input_filter

        if fixed_step and config.step_size is None:
            raise ValueError("You must provide a valid step size "
                             "with fixed step.")

        # Fall back to the class name when no explicit name is given
        self._name = str(name if name
                         else type(self).__name__.split(".")[-1])

        self._graph = tf.Graph()
        # The session is created lazily through the "session" property
        self._session = None

        # Graph handles, filled in while the graph is built:
        # placeholders
        self._input = self._input_buffer = self._dropout = None
        self._labels = self._label_progress = self._label_confidence = None
        self._class_weights = None
        # outputs
        self._outputs = OrderedDict()
        # training
        self._training = self._global_step = self._train_step = None
        # metrics
        self._metrics = OrderedDict()
        # gradients
        self._class_grad = None
        self._class_progress_grad = None
        self._class_progress_conf = None
        # saver
        self._saver = None

        # Build everything inside the private graph, then freeze it
        with self._graph.as_default():
            with tf.variable_scope(self._name):
                if seed is not None:
                    tf.set_random_seed(seed)
                self._make_time_net()
                self._make_aliases()
        self._graph.finalize()

        # Training history
        self._train_history = None
        self._best_loss = None
        self._best_weights = None
        self._latest_weights = None
        self._current_weights = None
        self.clear_train_history()

        # Logging
        self._last_log_len = 0

    def _make_time_net(self):
        """Build the complete model graph in the current default graph.

        In order, the method creates:

        * the training phase, global step and dropout placeholders;
        * the data input placeholder, its mask and lengths, the optional
          input filter and the input buffer;
        * the layer stacks (convolution, dilated, dense, locally
          recurrent and recurrent);
        * the outputs (class distribution, id, probability, name,
          progress and confidence), registered in ``MODEL_OUTPUTS``;
        * the label placeholders, the losses, the metrics and the Adam
          training step;
        * the initialization and reset operations, the weight update
          placeholders/assignments and the saver.

        The results are stored in the instance attributes (``_input``,
        ``_labels``, ``_outputs``, ``_metrics``, ``_train_step``,
        ``_saver``, ...) and in the graph collections named by the class
        constants. Nothing is returned.
        """
        config = self._config
        dtype = self._dtype
        num_classes = len(self._class_names)
        mask_value = self._mask_value

        # Training phase
        self._training = tf.placeholder_with_default(
            tf.constant(False, dtype=tf.bool), [], name="TrainingPhase")
        # Global training step
        self._global_step = tf.placeholder_with_default(
            tf.constant(0, dtype=tf.int32), [], name="GlobalStep")
        # Dropout value
        self._dropout = tf.placeholder_with_default(
            tf.constant(config.dropout, dtype=dtype), [], name="Dropout")
        # The time dimension is static only with a fixed step
        step_size = config.step_size if self._fixed_step else None
        with tf.variable_scope("Input"):
            # Data input
            self._input = tf.placeholder(dtype=dtype,
                                         shape=(config.batch_size,
                                                step_size,
                                                self._data_dim),
                                         name="Data")
            if self._fixed_step:
                input_step_size = step_size
            else:
                input_step_size = tf.shape(self._input)[1]
            # Make mask and lengths: a frame is valid when at least one of
            # its values differs from the mask value
            if mask_value is not None:
                input_mask = tf.reduce_any(
                    tf.not_equal(self._input, mask_value), axis=-1)
            else:
                input_mask = tf.fill(
                    (config.batch_size, input_step_size), True)
            input_lengths = tf.reduce_sum(
                tf.cast(input_mask, dtype=tf.int32), axis=1, name="Lengths")
            net_input = self._input
            with tf.variable_scope("InputFilter"):
                if self._input_filter is not None:
                    net_input = self._input_filter(self, net_input)
        with tf.variable_scope("Buffer"):
            input_buffer = self._make_input_buffer(net_input)
        # Make layers
        y = input_buffer
        # with tf.variable_scope("Normalization"):
        #     y = self._normalize(y, self._training, self._global_step)
        with tf.variable_scope("ConvolutionStack"):
            y = self._make_convolution_stack(y)
        with tf.variable_scope("DilatedStack"):
            y = self._make_dilated_stack(y)
        with tf.variable_scope("DenseStack"):
            y = self._make_dense_stack(y)
        with tf.variable_scope("LocalRecurrentStack"):
            y = self._make_local_recurrent_stack(y)
        # Discard extra data (keep only the last input_step_size steps)
        y = y[:, -input_step_size:, :]
        with tf.variable_scope("RecurrentStack"):
            y = self._make_recurrent_stack(y, input_lengths)
        # Output layer
        with tf.variable_scope("Output"):
            with tf.variable_scope("Logits"):
                class_logits = self._make_class_logits(y)
            # Outputs
            if config.softmax:
                class_dist = tf.nn.softmax(class_logits)
            else:
                class_dist = tf.sigmoid(class_logits)
            default_class_idx = None
            if config.default_class:
                default_class_idx = self._class_names.index(
                    config.default_class)
            if default_class_idx is not None:
                # Default class = 1 - max(non-default classes)
                default_class_prob = 1 - tf.reduce_max(class_dist, axis=-1,
                                                       keepdims=True)
                # Insert the default class at its position among the
                # other classes
                full_class_dist = tf.concat(
                    [class_dist[..., :default_class_idx],
                     default_class_prob,
                     class_dist[..., default_class_idx:]],
                    axis=-1)
            else:
                full_class_dist = class_dist
            with tf.variable_scope("Smoothing"):
                full_class_dist = self._make_smoothed_output(full_class_dist)
            # Apply state update _after_ every stateful component
            if self._stateful:
                with tf.variable_scope("StateUpdate"):
                    class_logits, full_class_dist = self._make_state_update(
                        [class_logits, full_class_dist])
            full_class_dist = tf.identity(full_class_dist,
                                          name="ClassDistribution")
            self._outputs["ClassDistribution"] = full_class_dist
            class_id = tf.cast(tf.argmax(full_class_dist, axis=-1),
                               dtype=tf.int32, name="ClassId")
            self._outputs["ClassId"] = class_id
            # Indices (batch, step, class id) used to pick the probability
            # of the predicted class
            grid = tf.meshgrid(tf.range(config.batch_size),
                               tf.range(input_step_size), indexing="ij")
            class_id_grid = tf.stack(grid + [class_id], axis=-1)
            class_probability = tf.gather_nd(full_class_dist, class_id_grid,
                                             name="ClassProbability")
            self._outputs["ClassProbability"] = class_probability
            class_names = tf.constant(self._class_names, dtype=tf.string,
                                      name="ClassNames")
            self._outputs["ClassNames"] = class_names
            class_name = tf.gather_nd(class_names,
                                      tf.expand_dims(class_id, axis=-1),
                                      name="ClassName")
            self._outputs["ClassName"] = class_name
            # Class progress
            with tf.variable_scope("ClassProgress"):
                prog_logits = self._make_progress_logits(y)
                prog_quant = config.progress_quantization or 0
                if prog_quant > 1:
                    # Quantized progress: expected value over evenly
                    # spaced levels in [0, 1]
                    prog_segs = tf.constant(np.linspace(0, 1, prog_quant),
                                            dtype=dtype)
                    prog_dist = tf.nn.softmax(prog_logits)
                    class_progs = tf.reduce_sum(prog_dist * prog_segs, axis=-1,
                                                name="ClassProgress")
                else:
                    # Non-quantized progress
                    class_progs = tf.sigmoid(tf.squeeze(prog_logits, axis=-1),
                                             name="ClassProgress")
            self._outputs["ClassProgress"] = class_progs
            # Class confidence
            with tf.variable_scope("ClassConfidence"):
                conf_logits = self._make_confidence_logits(y)
                conf_quant = config.confidence_quantization or 0
                if conf_quant > 1:
                    # Quantized confidence: expected value over evenly
                    # spaced levels in [0, 1]
                    conf_segs = tf.constant(np.linspace(0, 1, conf_quant),
                                            dtype=dtype)
                    conf_dist = tf.nn.softmax(conf_logits)
                    class_confs = tf.reduce_sum(conf_dist * conf_segs, axis=-1,
                                                name="ClassConfidence")
                else:
                    # Non-quantized confidence
                    class_confs = tf.sigmoid(tf.squeeze(conf_logits, axis=-1),
                                             name="ClassConfidence")
            self._outputs["ClassConfidence"] = class_confs
            # Save to outputs collection
            for out in self._outputs.values():
                tf.add_to_collection(self.MODEL_OUTPUTS, out)
        # Training
        with tf.variable_scope("Train"):
            # Class labels: values outside [0, num_classes) and masked
            # input frames are ignored
            self._labels = tf.placeholder(dtype=tf.int32,
                                          shape=(config.batch_size, step_size),
                                          name="Labels")
            labels_mask = (self._labels >= 0) & (self._labels < num_classes)
            labels_mask &= input_mask
            labels_valid = tf.where(labels_mask,
                                    self._labels,
                                    tf.zeros_like(self._labels))
            labels_mask_f = tf.cast(labels_mask, dtype=dtype)
            # Label progress: values outside [0, 1] are ignored
            self._label_progress = tf.placeholder(
                dtype=dtype, shape=(config.batch_size, step_size),
                name="LabelProgress")
            label_progs_mask = (self._label_progress >= 0) \
                & (self._label_progress <= 1)
            label_progs_mask &= labels_mask
            label_progs_valid = tf.where(label_progs_mask,
                                         self._label_progress,
                                         tf.zeros_like(self._label_progress))
            label_progs_mask_f = tf.cast(label_progs_mask, dtype=dtype)
            # Label confidence: values outside [0, 1] are replaced by 1
            self._label_confidence = tf.placeholder(
                dtype=dtype, shape=(config.batch_size, step_size),
                name="LabelConfidence")
            label_confs_mask = (self._label_confidence >= 0) \
                & (self._label_confidence <= 1)
            label_confs_mask &= labels_mask
            label_confs_valid = tf.where(label_confs_mask,
                                         self._label_confidence,
                                         tf.ones_like(self._label_confidence))
            label_confs_mask_f = tf.cast(label_confs_mask, dtype=dtype)
            # Label weights: class weights are normalized so that they sum
            # up to the number of classes
            self._class_weights = tf.placeholder_with_default(
                tf.ones(num_classes, dtype=dtype), [num_classes],
                name="ClassWeights")
            class_weights_norm = \
                num_classes * (self._class_weights /
                               tf.reduce_sum(self._class_weights))
            label_weights = tf.gather_nd(class_weights_norm,
                                         tf.expand_dims(labels_valid, axis=-1))
            label_weights *= labels_mask_f
            label_progress_weights = label_weights * label_progs_mask_f
            label_confidence_weights = label_weights * label_confs_mask_f
            # Metrics
            num_frames = tf.reduce_sum(labels_mask_f, name="NumFrames")
            prev_num_frames = tf.get_variable(
                "NumFramesPrevious", shape=num_frames.get_shape(),
                dtype=num_frames.dtype, initializer=tf.zeros_initializer(),
                trainable=False)
            tf.add_to_collection(self.METRIC_VARIABLES, prev_num_frames)
            total_frames = tf.add(prev_num_frames, num_frames,
                                  name="NumFramesUpdate")
            tf.add_to_collection(self.METRIC_UPDATES, total_frames)
            with tf.variable_scope("Loss"):
                # Update moving averages
                upd_moving_avg = tf.get_collection(
                    self.UPDATE_MOVING_AVERAGE_OP)
                with tf.control_dependencies(upd_moving_avg):
                    class_logits = tf.identity(class_logits)
                # Cross-entropy
                if config.softmax:
                    if default_class_idx is not None:
                        # Use tf.not_equal: in TF 1.x graph mode the Python
                        # "!=" operator on a Tensor is an object identity
                        # comparison (always True here), which would turn
                        # the mask into a constant 1.
                        default_class_mask = tf.not_equal(
                            labels_valid, default_class_idx)
                        default_class_mask_f = tf.cast(
                            default_class_mask, dtype=dtype)
                        # Map the labels to the indices of the logits,
                        # which do not include the default class. Frames
                        # labelled with the default class are masked out
                        # below, but their label is still clamped to a
                        # valid index because out-of-range labels make the
                        # sparse cross-entropy op fail (CPU) or return NaN
                        # (GPU).
                        labels_valid_classes = tf.where(
                            labels_valid < default_class_idx,
                            labels_valid,
                            tf.maximum(labels_valid - 1, 0))
                    else:
                        default_class_mask_f = tf.constant(1, dtype=dtype)
                        labels_valid_classes = labels_valid
                    batch_loss = \
                        tf.nn.sparse_softmax_cross_entropy_with_logits(
                            labels=labels_valid_classes, logits=class_logits)
                    batch_loss *= default_class_mask_f
                else:
                    labels_valid_1h = tf.one_hot(labels_valid, num_classes,
                                                 dtype=dtype)
                    if default_class_idx is not None:
                        # Drop the default class column to match the logits
                        labels_valid_1h = tf.concat(
                            [labels_valid_1h[..., :default_class_idx],
                             labels_valid_1h[..., default_class_idx + 1:]],
                            axis=-1)
                    batch_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=labels_valid_1h, logits=class_logits)
                    batch_loss = tf.reduce_sum(batch_loss, axis=-1)
                batch_loss *= label_weights
                # Weight by confidence
                batch_loss *= label_confs_valid
                batch_loss = tf.reduce_sum(batch_loss)
                batch_loss = tf.divide(batch_loss, tf.cast(num_frames, dtype),
                                       name="CrossEntropy")
                self._metrics["CrossEntropyLoss"] = \
                    self._make_averaged_metric("CrossEntropyLoss", batch_loss,
                                               num_frames, prev_num_frames,
                                               total_frames)
                # Progress loss
                if config.progress_loss > 0:
                    if prog_quant > 1:
                        # Quantized progress loss: the label is rounded to
                        # the nearest of the prog_quant levels (floor on
                        # half-steps, then (x + 1) // 2)
                        label_quant_progs_valid = tf.to_int32(
                            label_progs_valid // (1 / (2 * (prog_quant - 1))))
                        label_quant_progs_valid = \
                            (label_quant_progs_valid + 1) // 2
                        batch_progress_loss = \
                            tf.nn.sparse_softmax_cross_entropy_with_logits(
                                labels=label_quant_progs_valid,
                                logits=prog_logits)
                    else:
                        # Non-quantized progress loss
                        batch_progress_loss = tf.squared_difference(
                            class_progs, label_progs_valid)
                    batch_progress_loss *= label_progress_weights
                    batch_progress_loss *= config.progress_loss
                    batch_progress_loss = tf.reduce_sum(batch_progress_loss)
                    batch_progress_loss = tf.divide(
                        batch_progress_loss, tf.cast(num_frames, dtype),
                        name="Progress")
                    self._metrics["ProgressLoss"] = \
                        self._make_averaged_metric("ProgressLoss",
                                                   batch_progress_loss,
                                                   num_frames, prev_num_frames,
                                                   total_frames)
                    batch_loss += batch_progress_loss
                # Confidence loss
                if config.confidence_loss > 0:
                    if conf_quant > 1:
                        # Quantized confidence loss: the label is rounded
                        # to the nearest of the conf_quant levels
                        label_quant_confs_valid = tf.to_int32(
                            label_confs_valid // (1 / (2 * (conf_quant - 1))))
                        label_quant_confs_valid = \
                            (label_quant_confs_valid + 1) // 2
                        batch_confidence_loss = \
                            tf.nn.sparse_softmax_cross_entropy_with_logits(
                                labels=label_quant_confs_valid,
                                logits=conf_logits)
                    else:
                        # Non-quantized confidence loss
                        batch_confidence_loss = tf.squared_difference(
                            class_confs, label_confs_valid)
                    batch_confidence_loss *= label_confidence_weights
                    batch_confidence_loss *= config.confidence_loss
                    batch_confidence_loss = tf.reduce_sum(
                        batch_confidence_loss)
                    batch_confidence_loss = tf.divide(
                        batch_confidence_loss, tf.cast(num_frames, dtype),
                        name="Confidence")
                    self._metrics["ConfidenceLoss"] = \
                        self._make_averaged_metric("ConfidenceLoss",
                                                   batch_confidence_loss,
                                                   num_frames, prev_num_frames,
                                                   total_frames)
                    batch_loss += batch_confidence_loss
                # Regularization
                if config.weight_decay > 0:
                    reg_losses = tf.get_collection(
                        tf.GraphKeys.REGULARIZATION_LOSSES)
                    weight_decay = tf.constant(config.weight_decay,
                                               dtype=dtype)
                    reg = tf.multiply(weight_decay, tf.add_n(reg_losses),
                                      name="Regularization")
                    self._metrics["RegularizationLoss"] = \
                        self._make_averaged_metric("RegularizationLoss", reg,
                                                   num_frames, prev_num_frames,
                                                   total_frames)
                    batch_loss += reg
                # Consistence: penalize changes of the class distribution
                # between consecutive time steps
                if config.consistence_loss > 0:
                    # class_dist / full_class_dist?
                    # The first step is duplicated so that its difference
                    # is zero
                    class_diff = tf.concat([class_dist[:, :1, :],
                                            class_dist], axis=1)
                    class_diff = tf.squared_difference(
                        class_diff[:, 1:, :], class_diff[:, :-1, :])
                    class_diff = tf.reduce_sum(class_diff, axis=-1)
                    class_diff = class_diff * labels_mask_f
                    cons_loss = \
                        config.consistence_loss * tf.reduce_sum(class_diff)
                    cons_loss = tf.divide(cons_loss,
                                          tf.cast(num_frames, self._dtype),
                                          name="Consistence")
                    self._metrics["ConsistenceLoss"] = \
                        self._make_averaged_metric("ConsistenceLoss",
                                                   cons_loss, num_frames,
                                                   prev_num_frames,
                                                   total_frames)
                    batch_loss += cons_loss
                self._metrics["Loss"] = \
                    self._make_averaged_metric("Loss", batch_loss, num_frames,
                                               prev_num_frames, total_frames)
            with tf.variable_scope("JaccardScore"):
                # Only the update op (the accumulated confusion matrix) of
                # the mean IoU metric is used
                _, iou_confmat = tf.metrics.mean_iou(labels_valid, class_id,
                                                     num_classes,
                                                     weights=labels_mask_f)
                hits = tf.diag_part(iou_confmat)
                total = (tf.reduce_sum(iou_confmat, axis=0) +
                         tf.reduce_sum(iou_confmat, axis=1) - hits)
                if default_class_idx is not None:
                    # The default class does not count for the score
                    hits = tf.concat([hits[:default_class_idx],
                                      hits[default_class_idx + 1:]], axis=0)
                    total = tf.concat([total[:default_class_idx],
                                       total[default_class_idx + 1:]], axis=0)
                # Avoid dividing by zero for classes that never appear
                jaccard = hits / tf.maximum(total, 1)
                jaccard_score = tf.reduce_mean(jaccard, name="JaccardScore")
                self._metrics["JaccardScore"] = jaccard_score
            with tf.variable_scope("Accuracy"):
                # TODO Replace with tf.metrics
                label_hits = tf.equal(self._labels, class_id)
                accuracy_weights = labels_mask_f  # * label_confs_valid
                hits = tf.reduce_sum(tf.cast(label_hits, dtype=dtype) *
                                     accuracy_weights)
                batch_accuracy = tf.divide(
                    hits, tf.reduce_sum(accuracy_weights),
                    name="HitRate")
                batch_accuracy = tf.cast(batch_accuracy, dtype)
                self._metrics["Accuracy"] = self._make_averaged_metric(
                    "Accuracy", batch_accuracy, num_frames,
                    prev_num_frames, total_frames)
            # Update metrics _after_ batch loss is computed
            if self._stateful:
                with tf.variable_scope("MetricUpdate"):
                    batch_loss, *metric_updates = self._make_metrics_update(
                        [batch_loss] + list(self._metrics.values()))
                    for k, upd in zip(self._metrics.keys(), metric_updates):
                        self._metrics[k] = upd
            # Optimize
            optimizer = tf.train.AdamOptimizer(learning_rate=1e-4,
                                               epsilon=1e-3)
            # optimizer = tf.train.GradientDescentOptimizer(learning_rate=1e-1)
            # optimizer = tf.train.RMSPropOptimizer(learning_rate=1e-4,
            #                                       epsilon=1e-5)
            self._train_step = optimizer.minimize(batch_loss, name="TrainStep")
        # Initialization operation
        tf.add_to_collection(tf.GraphKeys.INIT_OP,
                             tf.variables_initializer(tf.global_variables(),
                                                      name="Init"))
        # Reset operation
        reset_op = tf.group(*(tf.get_collection(self.STATE_RESET_OP) +
                              tf.get_collection(self.METRIC_RESET_OP) +
                              [tf.local_variables_initializer()]),
                            name="Reset")
        tf.add_to_collection(self.RESET_OP, reset_op)
        # Restore weights operation: one placeholder and one assignment
        # for every weight variable (trainable variables and moving
        # averages)
        for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
            tf.add_to_collection(self.WEIGHT_VARIABLES, var)
        for var in tf.get_collection(self.CURRENT_MOVING_AVERAGE_VARIABLES):
            tf.add_to_collection(self.WEIGHT_VARIABLES, var)
        weights_assigns = []
        for var in tf.get_collection(self.WEIGHT_VARIABLES):
            name = var.op.name
            # Reset the name scope so that the new ops are named after the
            # full variable name
            with tf.name_scope(None):
                var_update = tf.placeholder(var.dtype, var.get_shape(),
                                            name="{}_Update".format(name))
                tf.add_to_collection(self.WEIGHT_UPDATES, var_update)
                weights_assigns.append(
                    tf.assign(var, var_update, name="{}_Assign".format(name)))
        tf.add_to_collection(self.WEIGHT_UPDATE_OP,
                             tf.group(*weights_assigns,
                                      name="WeightsUpdate"))
        # # Sensitivity = sum(abs(gradient))
        # # This is generally useful only with an input with 1 time step
        # with tf.variable_scope("Sensitivity"):
        #     # Class
        #     dist_class_example = tf.unstack(full_class_dist, axis=-1)
        #     class_grad = [tf.gradients(c, input_buffer)[0]
        #                   for c in dist_class_example]
        #     self._class_grad = tf.stack(class_grad, axis=-1,
        #                                 name="BufferGradientsClass")
        #     # Progress
        #     self._class_progress_grad = tf.gradients(
        #         class_progs, input_buffer,
        #         name="BufferGradientsProgress")[0]
        #     # Progress
        #     self._class_confidence_grad = tf.gradients(
        #         class_confs, input_buffer,
        #         name="BufferGradientsConfidence")[0]
        # Saver
        self._saver = tf.train.Saver()

    def _make_input_buffer(self, input_):
        """Prepend a buffer of past input frames to the given input.

        The buffer holds the extra frames that the locally recurrent,
        dilated and convolution stacks consume in addition to the current
        step. It is stored in a non-trainable variable (``_input_buffer``)
        registered in ``STATE_VARIABLES``; the tensor with its next value
        (the last frames of the concatenated data) is registered in
        ``STATE_UPDATES``.

        Args:
            input_: Input tensor; the buffer is concatenated to it along
                axis 1 and takes its data type and last dimension.

        Returns:
            The concatenation of the buffer and ``input_`` along axis 1.

        Raises:
            ValueError: If the configuration has no step size but uses
                strided convolution or pooling.
        """
        # Compute buffer size
        batch_size = self._config.batch_size
        step_size = self._config.step_size
        # A missing convolution stack is handled like an empty one, so
        # that it can be safely iterated even when there is no step size
        convolution_stack = self._config.convolution_stack or ()
        dilated_stack = self._config.dilated_stack
        local_recurrent_stack = self._config.local_recurrent_stack
        # With strided convolution/pooling step_size cannot be None
        if step_size is None:
            for conv, pool in convolution_stack:
                if conv is not None and conv.stride > 1:
                    raise ValueError("Must provide a step size "
                                     "to use strided convolution.")
                if pool is not None and pool.stride > 1:
                    raise ValueError("Must provide a step size "
                                     "to use strided pooling.")
            # Without strided convolution/pooling step_size does not matter
            step_size = 1
        # Start from the size of one step and grow it with the frames
        # needed by each stack, walking from the last stack to the first
        buffer_size = step_size
        # Buffer for local recurrent stack
        if (local_recurrent_stack is not None and
                local_recurrent_stack.depth > 0 and
                local_recurrent_stack.size > 0):
            buffer_size += local_recurrent_stack.depth
        # Buffer for dilated stack
        if dilated_stack is not None and dilated_stack.base > 0:
            num_levels = len(dilated_stack.channels)
            buffer_size += dilated_stack.base ** num_levels - 1
        # Buffer for convolution stack
        for conv, pool in reversed(convolution_stack):
            if pool is not None and pool.window_size > 0:
                buffer_size = buffer_size * pool.stride \
                    + (pool.window_size - 1) * pool.dilation_rate
            if conv is not None and conv.kernel_size > 0:
                buffer_size = buffer_size * conv.stride \
                    + (conv.kernel_size - 1) * conv.dilation_rate
        # Only the frames exceeding the current step are buffered
        buffer_size -= step_size
        # Create buffer
        self._input_buffer = tf.get_variable(
            "Buffer", (batch_size, buffer_size, input_.get_shape()[-1]),
            dtype=input_.dtype, initializer=tf.zeros_initializer(),
            trainable=False)
        # Prepend buffer to input
        y = tf.concat([self._input_buffer, input_], axis=1)
        # Buffer update: the last buffer_size frames of the new data
        buffer_update_start = tf.shape(y)[1] - buffer_size
        buffer_update = tf.slice(y, [0, buffer_update_start, 0], [-1, -1, -1],
                                 name="BufferUpdated")
        tf.add_to_collection(self.STATE_VARIABLES, self._input_buffer)
        tf.add_to_collection(self.STATE_UPDATES, buffer_update)
        return y

    def _normalize(self, input_, training, global_step=None):
        """Batch-normalize the input over its first two axes.

        The batch mean and variance are tracked by an exponential moving
        average (decay 0.99, zero-debiased); its update operation is
        registered in ``UPDATE_MOVING_AVERAGE_OP`` and the averaged
        statistics in ``CURRENT_MOVING_AVERAGE_VARIABLES``.

        Args:
            input_: Tensor to normalize; its last dimension must be
                statically known.
            training: Boolean tensor selecting the batch statistics (true)
                or the moving averages (false).
            global_step: Optional number of updates given to the moving
                average.

        Returns:
            The normalized tensor, scaled by "Gamma" and shifted by "Beta".
        """
        dtype = self._dtype
        num_features = input_.get_shape()[-1].value
        averager = tf.train.ExponentialMovingAverage(
            decay=0.99, num_updates=global_step, zero_debias=True)
        # Statistics of the current batch (reduced over axes 0 and 1)
        batch_mean, batch_variance = tf.nn.moments(
            input_, [0, 1], keepdims=True)
        tf.add_to_collection(self.UPDATE_MOVING_AVERAGE_OP,
                             averager.apply([batch_mean, batch_variance]))
        # Averaged statistics
        running_mean = averager.average(batch_mean)
        running_variance = averager.average(batch_variance)
        for running_stat in (running_mean, running_variance):
            tf.add_to_collection(self.CURRENT_MOVING_AVERAGE_VARIABLES,
                                 running_stat)
        # Learned offset and scale (only the scale is regularized)
        offset = tf.get_variable("Beta", shape=[num_features], dtype=dtype,
                                 initializer=tf.zeros_initializer())
        scale = tf.get_variable("Gamma", shape=[num_features], dtype=dtype,
                                initializer=tf.ones_initializer(),
                                regularizer=tf.nn.l2_loss)
        eps = tf.constant(1e-8, dtype=dtype, name="Epsilon")

        def normalize_with(mean, variance):
            # Both branches differ only in the statistics they use
            return tf.nn.batch_normalization(
                input_, mean, variance, offset, scale, eps)

        return tf.cond(
            training,
            lambda: normalize_with(batch_mean, batch_variance),
            lambda: normalize_with(running_mean, running_variance))

    def _make_convolution_stack(self, input_):
        """Apply the configured convolution / pooling layers to ``input_``.

        Every entry of ``config.convolution_stack`` is a pair
        ``(conv, pool)``.  A convolution is added when ``conv`` is given
        with positive kernel size and channels; a pooling is added when
        ``pool`` is given with a positive window size.  Both use "VALID"
        padding.  ``input_`` is returned untouched if the stack is empty.
        """
        stack = self._config.convolution_stack
        dtype = self._dtype
        if not stack:
            return input_
        output = input_
        in_channels = output.get_shape()[-1]
        if self._use_dropout:
            keep_prob = tf.constant(1, dtype=dtype) - self._dropout
        # Layers are only numbered when there is more than one
        numbered = len(stack) > 1
        for index, (conv, pool) in enumerate(stack, start=1):
            suffix = "_{}".format(index) if numbered else ""
            with tf.variable_scope("ConvolutionLayer" + suffix):
                use_conv = (conv is not None
                            and conv.kernel_size > 0 and conv.channels > 0)
                if use_conv:
                    kernel = tf.get_variable(
                        "Filter",
                        shape=(conv.kernel_size, in_channels, conv.channels),
                        dtype=dtype,
                        initializer=tf.contrib.layers.xavier_initializer(),
                        regularizer=tf.nn.l2_loss)
                    offset = tf.get_variable(
                        "Bias", shape=(1, 1, conv.channels), dtype=dtype,
                        initializer=tf.random_normal_initializer())
                    output = tf.nn.convolution(
                        output, kernel, "VALID",
                        strides=[conv.stride],
                        dilation_rate=[conv.dilation_rate],
                        name="Convolution")
                    output = output + offset
                    output = tf.nn.tanh(output, name="Activation")
                    in_channels = conv.channels
                    if self._use_dropout:
                        output = tf.nn.dropout(output, keep_prob=keep_prob)
                use_pool = pool is not None and pool.window_size > 0
                if use_pool:
                    # Pooling type is normalized to upper case ("MAX"/"AVG")
                    pooling_type = str(pool.type).strip().upper()
                    output = tf.nn.pool(
                        output, [pool.window_size], pooling_type, "VALID",
                        strides=[pool.stride],
                        dilation_rate=[pool.dilation_rate],
                        name="Pooling")
        return output

    def _make_dilated_stack(self, input_):
        """Apply the configured stack of dilated convolutions to ``input_``.

        One convolution with kernel size ``base`` is created per entry of
        ``config.dilated_stack.channels``, followed by a tanh activation
        and optional dropout.  ``input_`` is returned untouched when the
        stack is missing, has a non-positive base or has no channels.
        """
        config = self._config
        dtype = self._dtype
        stack = config.dilated_stack
        if not stack:
            return input_
        base = stack.base
        level_channels = stack.channels
        if base <= 0 or not level_channels:
            return input_
        if self._use_dropout:
            keep_prob = tf.constant(1, dtype=dtype) - self._dropout
        # NOTE(review): the dilated form is only used for a fixed step size
        # greater than 1; a step size of None takes the strided form as
        # well - confirm that this is intended for full sequences.
        use_dilation = config.step_size is not None and config.step_size > 1
        numbered = len(level_channels) > 1
        output = input_
        in_channels = output.get_shape()[-1]
        for level, out_channels in enumerate(level_channels):
            suffix = "_{}".format(level + 1) if numbered else ""
            with tf.variable_scope("DilatedConvolution" + suffix):
                kernel = tf.get_variable(
                    "Filter",
                    shape=(base, in_channels, out_channels),
                    dtype=dtype,
                    initializer=tf.contrib.layers.xavier_initializer(),
                    regularizer=tf.nn.l2_loss)
                offset = tf.get_variable(
                    "Bias", shape=(1, 1, out_channels), dtype=dtype,
                    initializer=tf.random_normal_initializer())
                if use_dilation:
                    # Dilation grows as a power of the base at each level
                    conv_options = {"dilation_rate": [base ** level]}
                else:
                    # Optimize to strided convolution for 1-step case
                    conv_options = {"strides": [base]}
                output = tf.nn.convolution(output, kernel, "VALID",
                                           name="Convolution", **conv_options)
                output = output + offset
                output = tf.nn.tanh(output, name="Activation")
                if self._use_dropout:
                    output = tf.nn.dropout(output, keep_prob=keep_prob)
                in_channels = out_channels
        return output

    def _make_dense_stack(self, input_):
        """Apply the configured fully connected layers to ``input_``.

        One sigmoid-activated layer is created per size listed in
        ``config.dense_stack``, each optionally followed by dropout.
        ``input_`` is returned untouched when no layer is configured.
        """
        layer_sizes = self._config.dense_stack
        dtype = self._dtype
        if not layer_sizes:
            return input_
        if self._use_dropout:
            keep_prob = tf.constant(1, dtype=dtype) - self._dropout
        # Layers are only numbered when there is more than one
        numbered = len(layer_sizes) > 1
        output = input_
        for index, layer_size in enumerate(layer_sizes, start=1):
            suffix = "_{}".format(index) if numbered else ""
            with tf.variable_scope("DenseLayer" + suffix):
                in_size = output.get_shape()[-1].value
                weights = tf.get_variable(
                    "W", (in_size, layer_size), dtype=dtype,
                    initializer=tf.contrib.layers.xavier_initializer(),
                    regularizer=tf.nn.l2_loss)
                biases = tf.get_variable(
                    "B", layer_size, dtype=dtype,
                    initializer=tf.random_normal_initializer())
                output = tf.sigmoid(
                    self._batch_matmul(output, weights, biases),
                    name="Activation")
                if self._use_dropout:
                    output = tf.nn.dropout(output, keep_prob=keep_prob)
        return output

    def _make_local_recurrent_stack(self, input_):
        """Apply a weight-shared gated recurrent cell over a local window.

        The same update gate, reset gate and output weights are applied
        ``depth`` times.  Layer ``i`` reads the input window starting at
        time offset ``i``, so the result has ``input_length - depth + 1``
        time steps and ``size`` features.  ``input_`` is returned untouched
        when the stack is missing or has a non-positive depth or size.
        """
        config = self._config
        dtype = self._dtype
        batch_size = config.batch_size
        stack = config.local_recurrent_stack
        if not stack:
            return input_
        depth = stack.depth
        cell_size = stack.size
        if depth <= 0 or cell_size <= 0:
            return input_
        input_feats = input_.get_shape()[-1].value
        input_length = tf.shape(input_)[1]
        weight_init = tf.contrib.layers.xavier_initializer()
        bias_init = tf.random_normal_initializer()

        def gate_variables(gate):
            # Input weights, state weights and bias of one gate
            input_weights = tf.get_variable(
                "W" + gate, (input_feats, cell_size), dtype=dtype,
                initializer=weight_init, regularizer=tf.nn.l2_loss)
            state_weights = tf.get_variable(
                "U" + gate, (cell_size, cell_size), dtype=dtype,
                initializer=weight_init, regularizer=tf.nn.l2_loss)
            bias = tf.get_variable(
                "B" + gate, cell_size, dtype=dtype, initializer=bias_init)
            return input_weights, state_weights, bias

        # Variables - Update gate
        wz, uz, bz = gate_variables("z")
        # Variables - Reset gate
        wr, ur, br = gate_variables("r")
        # Variables - Output
        wh, uh, bh = gate_variables("h")
        # Same dropout at every layer (https://arxiv.org/abs/1512.05287)
        if self._use_dropout:
            keep_prob = tf.constant(1, dtype=dtype) - self._dropout
            dropout_rand = tf.random_uniform((batch_size, 1, cell_size),
                                             dtype=dtype)
            dropout_mask = tf.cast(dropout_rand < keep_prob, dtype)
            dropout_weights = dropout_mask / keep_prob
        # Initial state
        state = tf.zeros((batch_size, input_length - depth + 1, cell_size),
                         dtype=dtype)
        # RNN layers
        numbered = depth > 1
        for layer in range(depth):
            suffix = "_{}".format(layer + 1) if numbered else ""
            with tf.name_scope("Layer" + suffix):
                # Window of the input seen by this layer
                window_end = input_length - (depth - layer - 1)
                window = input_[:, layer:window_end, :]

                def pre_activation(input_weights, state_weights, bias,
                                   prev_state):
                    return (self._batch_matmul(window, input_weights) +
                            self._batch_matmul(prev_state, state_weights) +
                            tf.reshape(bias, [1, 1, cell_size]))

                update = tf.sigmoid(pre_activation(wz, uz, bz, state))
                reset = tf.sigmoid(pre_activation(wr, ur, br, state))
                state = update * state + (1 - update) * tf.nn.tanh(
                    pre_activation(wh, uh, bh, reset * state))
                if self._use_dropout:
                    state *= dropout_weights
        return state

    def _make_recurrent_stack(self, input_, input_lengths):
        """Apply the configured stack of GRU cells to ``input_``.

        One GRU cell and one non-trainable state variable are created per
        size in ``config.recurrent_stack``.  State variables and their
        updated values are registered in the ``STATE_VARIABLES`` and
        ``STATE_UPDATES`` collections.  The cell outputs go through a
        sigmoid activation.  ``input_`` is returned untouched when no cell
        is configured.

        :param input_: Input tensor, indexed as ``input_[:, step]``.
        :param input_lengths: Valid length of each sequence in the batch.
        """
        config = self._config
        dtype = self._dtype
        cell_sizes = config.recurrent_stack
        if not cell_sizes:
            return input_
        # RNN layers
        keep_prob = None
        if self._use_dropout:
            keep_prob = tf.constant(1, dtype=dtype) - self._dropout
        cells = []
        state_vars = ()  # Must be tuple, not list
        numbered = len(cell_sizes) > 1
        for index, cell_size in enumerate(cell_sizes, start=1):
            suffix = "_{}".format(index) if numbered else ""
            with tf.variable_scope("Cell" + suffix):
                # Create cell and state variable
                cell = tf.contrib.rnn.GRUCell(num_units=cell_size)
                state_var = tf.get_variable(
                    "State", (config.batch_size, cell.state_size), dtype=dtype,
                    initializer=tf.zeros_initializer(), trainable=False)
                if self._use_dropout:
                    cell = tf.contrib.rnn.DropoutWrapper(
                        cell=cell, output_keep_prob=keep_prob)
                cells.append(cell)
                state_vars += (state_var,)
        if cells:
            stacked = len(cells) > 1
            if stacked:
                rnn_layer = tf.contrib.rnn.MultiRNNCell(cells)
            else:
                # A single cell takes its state directly, not as a tuple
                rnn_layer = cells[0]
                state_vars = state_vars[0]
            if self._fixed_step:
                # Unroll the cell manually over the fixed number of steps
                state_updates = state_vars
                step_outputs = []
                for step in range(config.step_size):
                    step_output, state_updates = rnn_layer(
                        input_[:, step], state_updates)
                    step_outputs.append(step_output)
                y = tf.stack(step_outputs, axis=1)
                # Zero the outputs past the end of each sequence
                mask = tf.sequence_mask(input_lengths, maxlen=config.step_size,
                                        dtype=dtype)
                y = y * tf.expand_dims(mask, axis=-1)
            else:
                y, state_updates = tf.nn.dynamic_rnn(
                    rnn_layer, input_, sequence_length=input_lengths,
                    initial_state=state_vars)
            if not stacked:
                # Back to sequences so that both cases are registered alike
                state_vars = (state_vars,)
                state_updates = [state_updates]
            y = tf.sigmoid(y, name="Activation")
            for state_var, state_update in zip(state_vars, state_updates):
                tf.add_to_collection(self.STATE_VARIABLES, state_var)
                tf.add_to_collection(self.STATE_UPDATES, state_update)
        return y

    def _make_class_logits(self, input_):
        """Create the linear layer producing the class logits.

        One logit is produced per class name, minus one when a default
        class is configured.
        """
        # NOTE(review): a falsy default class (e.g. 0) counts as not set
        # here - confirm this matches how default_class is specified.
        has_default_class = bool(self._config.default_class)
        num_logits = len(self._class_names)
        if has_default_class:
            num_logits -= 1
        num_inputs = input_.get_shape()[-1].value
        weights = tf.get_variable(
            "W", (num_inputs, num_logits), dtype=self._dtype,
            initializer=tf.contrib.layers.xavier_initializer(),
            regularizer=tf.nn.l2_loss)
        biases = tf.get_variable(
            "B", num_logits, dtype=self._dtype,
            initializer=tf.random_normal_initializer())
        return self._batch_matmul(input_, weights, biases)

    def _make_progress_logits(self, input_):
        """Create the linear layer producing the progress logits.

        The last axis of the result has ``config.progress_quantization``
        entries, or one when that value is below 1.  The leading axes are
        those of ``input_``.
        """
        dynamic_shape = tf.shape(input_)
        num_inputs = input_.get_shape()[-1].value
        num_outputs = max(int(self._config.progress_quantization), 1)
        # Initialize to zeros if not training progress
        trained = self._config.progress_loss > 0
        initializer = (tf.random_normal_initializer() if trained
                       else tf.zeros_initializer())
        weights = tf.get_variable(
            "W", (num_inputs, num_outputs), dtype=self._dtype,
            initializer=initializer, regularizer=tf.nn.l2_loss)
        bias = tf.get_variable(
            "B", (1, 1, num_outputs), dtype=self._dtype,
            initializer=initializer)
        # Flatten the leading axes, multiply, then restore them
        flat_input = tf.reshape(input_, [-1, num_inputs])
        flat_product = tf.matmul(flat_input, weights)
        product_shape = tf.concat([dynamic_shape[:-1], [num_outputs]], axis=0)
        product = tf.reshape(flat_product, product_shape)
        return product + bias

    def _make_confidence_logits(self, input_):
        """Create the linear layer producing the confidence logits.

        The last axis of the result has ``config.confidence_quantization``
        entries, or one when that value is below 1.  The leading axes are
        those of ``input_``.
        """
        dynamic_shape = tf.shape(input_)
        num_inputs = input_.get_shape()[-1].value
        num_outputs = max(int(self._config.confidence_quantization), 1)
        # Initialize to zeros if not training confidence
        trained = self._config.confidence_loss > 0
        initializer = (tf.random_normal_initializer() if trained
                       else tf.zeros_initializer())
        weights = tf.get_variable(
            "W", (num_inputs, num_outputs), dtype=self._dtype,
            initializer=initializer, regularizer=tf.nn.l2_loss)
        bias = tf.get_variable(
            "B", (1, 1, num_outputs), dtype=self._dtype,
            initializer=initializer)
        # Flatten the leading axes, multiply, then restore them
        flat_input = tf.reshape(input_, [-1, num_inputs])
        flat_product = tf.matmul(flat_input, weights)
        product_shape = tf.concat([dynamic_shape[:-1], [num_outputs]], axis=0)
        product = tf.reshape(flat_product, product_shape)
        return product + bias

    def _make_smoothed_output(self, output):
        """Average ``output`` with the outputs of the previous evaluations.

        A non-trainable buffer variable holds the last
        ``config.output_smoothing`` outputs.  The buffer and its updated
        value are registered in the ``STATE_VARIABLES`` and
        ``STATE_UPDATES`` collections; the assignment itself is not made
        here.  The result is the sum of the updated buffer divided by the
        smoothing size, so entries that are still zero after a reset pull
        the average towards zero.

        :param output: Tensor to smooth; its static shape sizes the buffer.
        :return: ``output`` itself when smoothing is disabled (missing or
            not greater than 1), the smoothed tensor otherwise.
        """
        smoothing = self._config.output_smoothing
        if not smoothing or smoothing <= 1:
            return output
        output_buffer_shape = ([smoothing] +
                               output.get_shape().as_list())
        output_buffer = tf.get_variable(
            "ClassDistributionBuffer",
            shape=output_buffer_shape,
            dtype=output.dtype,
            initializer=tf.zeros_initializer(),
            trainable=False)
        # Rolling window: drop the oldest entry (index 0) and append the
        # newest output.  Keeping [:-1] instead would overwrite only the
        # last slot and leave every other slot at its initial zero value.
        output_buffer_update = tf.concat(
            [output_buffer[1:, ...], output[tf.newaxis, ...]], axis=0,
            name="ClassDistributionBufferUpdated")
        tf.add_to_collection(self.STATE_VARIABLES, output_buffer)
        tf.add_to_collection(self.STATE_UPDATES, output_buffer_update)
        # Average the updated window so that the current output takes part
        # in its own smoothing instead of lagging one evaluation behind
        smoothed_output = tf.reduce_sum(output_buffer_update, axis=0)
        smoothed_output /= smoothing
        return smoothed_output

    @staticmethod
    def _batch_matmul(batch, weights, bias=None):
        """Multiply the last axis of ``batch`` by the ``weights`` matrix.

        ``batch`` is flattened to two dimensions for the product and the
        result is reshaped to ``(batch_size, -1, output_features)``, where
        ``batch_size`` is the static first dimension of ``batch``.  When
        given, ``bias`` is reshaped to ``(1, 1, output_features)`` and
        added.
        """
        batch_shape = batch.get_shape()
        batch_size = batch_shape[0].value
        in_features = batch_shape[-1].value
        out_features = weights.get_shape()[-1].value
        flat_batch = tf.reshape(batch, [-1, in_features])
        flat_product = tf.matmul(flat_batch, weights)
        product = tf.reshape(flat_product, [batch_size, -1, out_features])
        if bias is None:
            return product
        return product + tf.reshape(bias, [1, 1, out_features])

    def _make_state_update(self, input_):
        """Tie the update of the state variables to the given tensors.

        Creates an operation assigning every tensor in the
        ``STATE_UPDATES`` collection to the matching variable in
        ``STATE_VARIABLES`` and registers it in ``STATE_UPDATE_OP``.  Also
        registers an operation re-initializing the state variables in
        ``STATE_RESET_OP``.

        :param input_: A tensor, or a list or tuple of tensors.
        :return: Identity of the input that is only produced after the
            state update; a single tensor if a single tensor was given,
            a list otherwise.
        """
        unpack = not isinstance(input_, (list, tuple))
        # Always work on a new list: tf.get_collection returns lists, and
        # concatenating a tuple with a list raises TypeError
        input_ = [input_] if unpack else list(input_)
        # State step
        state_vars = tf.get_collection(self.STATE_VARIABLES)
        state_updates = tf.get_collection(self.STATE_UPDATES)
        # The inputs and the new state values must be computed before any
        # state variable is overwritten
        with tf.control_dependencies(input_ + state_updates):
            state_assigns = [tf.assign(var, update)
                             for var, update in zip(state_vars, state_updates)]
            step = tf.group(*state_assigns, name="Update")
            tf.add_to_collection(self.STATE_UPDATE_OP, step)
            with tf.control_dependencies([step]):
                output = [tf.identity(inp) for inp in input_]
        # State reset operation
        tf.add_to_collection(self.STATE_RESET_OP,
                             tf.variables_initializer(state_vars,
                                                      name="Reset"))
        if unpack:
            output = output.pop()
        return output

    def _make_averaged_metric(self, name, value, num_frames, prev_num_frames,
                              total_frames):
        """Create a frame-weighted running average of the metric ``value``.

        The result is ``(value * num_frames + previous * prev_num_frames)
        / total_frames``, where ``previous`` is a non-trainable variable
        registered in ``METRIC_VARIABLES``.  The averaged tensor is
        registered in ``METRIC_UPDATES``; the assignment is not made here.
        """
        with tf.variable_scope("{}Average".format(name)):
            dtype = value.dtype
            previous = tf.get_variable(
                "{}Previous".format(name),
                shape=value.get_shape(), dtype=dtype,
                initializer=tf.zeros_initializer(),
                trainable=False)
            tf.add_to_collection(self.METRIC_VARIABLES, previous)
            # Weight each value by the number of frames it was computed on
            current_part = value * tf.cast(num_frames, dtype=dtype)
            previous_part = previous * tf.cast(prev_num_frames, dtype=dtype)
            weighted_sum = tf.add(current_part, previous_part)
            frame_count = tf.cast(total_frames, dtype=dtype)
            average = tf.divide(weighted_sum, frame_count, name=name)
            average = tf.cast(average, dtype)
            tf.add_to_collection(self.METRIC_UPDATES, average)
            # Do not make summaries here
            # tf.summary.scalar(name, averaged_metric)
        return average

    def _make_metrics_update(self, input_):
        """Tie the update of the metric variables to the given tensors.

        Creates an operation assigning every tensor in the
        ``METRIC_UPDATES`` collection to the matching variable in
        ``METRIC_VARIABLES`` and registers it in ``METRIC_UPDATE_OP``.
        Also registers an operation re-initializing the metric variables
        in ``METRIC_RESET_OP``.

        :param input_: A tensor, or a list or tuple of tensors.
        :return: Identity of the input that is only produced after the
            metrics update; a single tensor if a single tensor was given,
            a list otherwise.
        """
        unpack = not isinstance(input_, (list, tuple))
        # Always work on a new list: tf.get_collection returns lists, and
        # concatenating a tuple with a list raises TypeError
        input_ = [input_] if unpack else list(input_)
        # Metrics step
        metric_vars = tf.get_collection(self.METRIC_VARIABLES)
        metric_updates = tf.get_collection(self.METRIC_UPDATES)
        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
        # Inputs, new metric values and summaries must be computed before
        # any metric variable is overwritten
        with tf.control_dependencies(input_ + metric_updates + summaries):
            metric_assigns = [tf.assign(var, update)
                              for var, update in zip(metric_vars,
                                                     metric_updates)]
            step = tf.group(*metric_assigns, name="Update")
            tf.add_to_collection(self.METRIC_UPDATE_OP, step)
            with tf.control_dependencies([step]):
                output = [tf.identity(inp) for inp in input_]
        # Metrics reset operation
        tf.add_to_collection(self.METRIC_RESET_OP,
                             tf.variables_initializer(metric_vars,
                                                      name="Reset"))
        if unpack:
            output = output.pop()
        return output

    def _make_aliases(self):
        """Create top-level named aliases for inputs, outputs and state.

        Identity operations are created outside of any name scope for the
        model input ("Input"), for every entry of ``self._outputs`` (under
        its own name) and for every state variable / state update pair
        ("StateInput" / "StateOutput", numbered when there are several).
        They are registered in ``INPUT_ALIASES`` and ``OUTPUT_ALIASES``.
        """
        with tf.name_scope(None):
            input_alias = tf.identity(self._input, name="Input")
            tf.add_to_collection(self.INPUT_ALIASES, input_alias)
            for output_name, output_value in self._outputs.items():
                output_alias = tf.identity(output_value, name=output_name)
                tf.add_to_collection(self.OUTPUT_ALIASES, output_alias)
            state_vars = tf.get_collection(self.STATE_VARIABLES)
            state_upds = tf.get_collection(self.STATE_UPDATES)
            numbered = len(state_vars) > 1
            state_pairs = zip(state_vars, state_upds)
            for index, (var, upd) in enumerate(state_pairs, start=1):
                suffix = "_{}".format(index) if numbered else ""
                # Alias the output of the variable operation itself
                state_in_alias = tf.identity(
                    var.op.outputs[0], name="StateInput" + suffix)
                tf.add_to_collection(self.INPUT_ALIASES, state_in_alias)
                state_out_alias = tf.identity(
                    upd, name="StateOutput" + suffix)
                tf.add_to_collection(self.OUTPUT_ALIASES, state_out_alias)

    def log_graph(self, logdir):
        """Write the model graph to a summary file under ``logdir``."""
        graph_writer = tf.summary.FileWriter(logdir=str(logdir),
                                             graph=self._graph)
        graph_writer.close()

    def reinitialize(self):
        """Run the initialization operations and clear the train history."""
        with self._graph.as_default():
            session = self.session
            init_ops = tf.get_collection(tf.GraphKeys.INIT_OP)
            session.run(init_ops)
        self.clear_train_history()

    def reset(self):
        """Run the operations registered in the ``RESET_OP`` collection."""
        with self._graph.as_default():
            session = self.session
            reset_ops = tf.get_collection(self.RESET_OP)
            session.run(reset_ops)

    def train(self, examples, labels, label_progs=None, label_confs=None,
              validation_data=None, test_data=None, max_epochs=200,
              patience=None, class_weights=None, evaluation_step=1,
              checkpoint_file=None, log_dir=None):
        """Train the model and return the final metrics of the best weights.

        Training resumes from the latest weights and epoch recorded in the
        train history.  The model is evaluated every ``evaluation_step``
        epochs; the weights with the lowest loss (validation loss when
        ``validation_data`` is given, training loss otherwise) are restored
        when training ends, stops early or is interrupted from the
        keyboard.

        :param examples: Training examples.
        :param labels: Training labels.
        :param label_progs: Progress targets; derived from ``labels`` when
            None.
        :param label_confs: Confidence targets; derived from ``labels``
            when None.
        :param validation_data: Optional sequence ``(examples, labels[,
            label_progs[, label_confs]])``; missing targets are derived
            from its labels.  The given sequence is not modified.
        :param test_data: Optional sequence, same layout as
            ``validation_data``.
        :param max_epochs: Maximum number of additional epochs.  The last
            epoch is rounded down to a multiple of ``evaluation_step``.
        :param patience: Number of epochs without improvement tolerated
            before stopping early; None or a negative value disables early
            stopping.
        :param class_weights: Optional per-class weights fed while
            training.
        :param evaluation_step: Number of epochs between evaluations.
        :param checkpoint_file: If given, a checkpoint is saved after each
            evaluation and once more after restoring the best weights.
        :param log_dir: If given, metric summaries are written to its
            "Train", "Validation" and "Test" subdirectories.
        :return: Series with the final train, validation and test metrics
            (NaN for datasets that were not given), indexed by the columns
            of the train history.
        """
        if patience is not None and patience < 0:
            patience = None
        if class_weights is not None:
            class_weights = np.asarray(class_weights)
        if checkpoint_file:
            checkpoint_file = str(checkpoint_file)
        if log_dir:
            log_dir = str(log_dir)
        np_dtype = self._dtype.as_numpy_dtype
        if label_progs is None:
            label_progs = self._make_label_progress(labels)
        if label_confs is None:
            label_confs = self._make_label_confidence(labels)
        # Complete the evaluation datasets on tuple copies, so that a list
        # passed by the caller is not extended in place
        if validation_data:
            validation_data = tuple(validation_data)
            if len(validation_data) < 3:
                validation_data += (
                    self._make_label_progress(validation_data[1]),)
            if len(validation_data) < 4:
                validation_data += (
                    self._make_label_confidence(validation_data[1]),)
        if test_data:
            test_data = tuple(test_data)
            if len(test_data) < 3:
                test_data += (self._make_label_progress(test_data[1]),)
            if len(test_data) < 4:
                test_data += (self._make_label_confidence(test_data[1]),)
        try:
            train_writer = None
            validation_writer = None
            test_writer = None
            if log_dir is not None:
                train_writer = tf.summary.FileWriter(
                    logdir=str(Path(log_dir, "Train")))
                if validation_data:
                    validation_writer = tf.summary.FileWriter(
                        logdir=str(Path(log_dir, "Validation")))
                if test_data:
                    test_writer = tf.summary.FileWriter(
                        logdir=str(Path(log_dir, "Test")))
            # Prepare weights
            if self._best_weights is None:
                self._current_weights = self._get_weights()
                self._best_weights = self._current_weights
            if self._latest_weights is None:
                self._latest_weights = self._best_weights
            # Prepare metrics
            num_metrics = len(self._metrics)
            if len(self._train_history) <= 1:
                self._log("Training started.")
            else:
                self._log("Training resumed from epoch {}."
                          .format(self._train_history.index[-1]))
            # Initial evaluation (only when there is no history yet)
            if len(self._train_history) <= 0:
                train_metrics = np.full((1, num_metrics), np.nan,
                                        dtype=np_dtype)
                validation_metrics = np.full((1, num_metrics), np.nan,
                                             dtype=np_dtype)
                test_metrics = np.full((1, num_metrics), np.nan,
                                       dtype=np_dtype)
                metrics = self._evaluate_dataset(examples, labels, label_progs,
                                                 label_confs, train_writer, 0)
                train_metrics[0] = list(metrics.values())
                self._best_loss = metrics["Loss"]
                if validation_data:
                    metrics = self._evaluate_dataset(*validation_data,
                                                     validation_writer, 0)
                    validation_metrics[0] = list(metrics.values())
                    # Validation loss takes precedence over training loss
                    self._best_loss = metrics["Loss"]
                if test_data:
                    metrics = self._evaluate_dataset(*test_data,
                                                     test_writer, 0)
                    test_metrics[0] = list(metrics.values())
                self._append_history(
                    [0], train_metrics, validation_metrics, test_metrics)
                for writer in [train_writer, validation_writer, test_writer]:
                    if writer is not None:
                        writer.flush()
            base_epochs = self._train_history.index[-1]
            # Allocate metrics arrays
            last_epoch = base_epochs + max_epochs
            # End on an evaluated epoch
            last_epoch -= last_epoch % evaluation_step
            global_steps = np.arange(base_epochs, last_epoch) + 1
            evaluation_steps = []
            num_evaluations = np.count_nonzero(
                (global_steps % evaluation_step) == 0)
            train_metrics = np.full((num_evaluations, num_metrics),
                                    np.nan, dtype=np_dtype)
            validation_metrics = np.full((num_evaluations, num_metrics),
                                         np.nan, dtype=np_dtype)
            test_metrics = np.full((num_evaluations, num_metrics),
                                   np.nan, dtype=np_dtype)
            # Back to the latest weights
            self._set_weights(self._latest_weights)
            self._current_weights = None
            self._latest_weights = None
            # Train
            remaining_patience = patience
            self._log(newline=True)
            try:
                global_step = 0
                for global_step in global_steps:
                    self._log("Epoch {}/{}: Training..."
                              .format(global_step, base_epochs + max_epochs))
                    if remaining_patience is not None:
                        remaining_patience -= 1
                    self.reset()
                    self._train_epoch(global_step, examples, labels,
                                      label_progs, label_confs, class_weights)
                    if global_step % evaluation_step == 0:
                        self._log("Epoch {}/{}: Evaluating..."
                                  .format(global_step,
                                          base_epochs + max_epochs))
                        i_evaluation = len(evaluation_steps)
                        metrics = self._evaluate_dataset(
                            examples, labels, label_progs, label_confs,
                            train_writer, global_step)
                        metrics_val = list(metrics.values())
                        train_metrics[i_evaluation] = metrics_val
                        epoch_loss = metrics["Loss"]
                        if validation_data:
                            metrics = self._evaluate_dataset(
                                *validation_data, validation_writer,
                                global_step)
                            metrics_val = list(metrics.values())
                            validation_metrics[i_evaluation] = metrics_val
                            # Validation loss takes precedence
                            epoch_loss = metrics["Loss"]
                        if test_data:
                            metrics = self._evaluate_dataset(
                                *test_data, test_writer, global_step)
                            metrics_val = list(metrics.values())
                            test_metrics[i_evaluation] = metrics_val
                        for writer in [train_writer,
                                       validation_writer,
                                       test_writer]:
                            if writer is not None:
                                writer.flush()
                        # Save weights
                        if epoch_loss < self._best_loss:
                            self._best_loss = epoch_loss
                            self._best_weights = self._get_weights()
                            remaining_patience = patience
                        self._latest_weights = self._get_weights()
                        if checkpoint_file:
                            self.save_checkpoint(checkpoint_file)
                        evaluation_steps.append(global_step)
                        early_stop = remaining_patience is not None \
                            and remaining_patience < 0
                        if early_stop:
                            self._log("Early stopping at epoch {}."
                                      .format(global_step))
                            break
                else:
                    self._log("Training completed at epoch {}."
                              .format(global_step))
            except KeyboardInterrupt:
                # An interruption ends training but still runs the final
                # evaluation on the best weights below
                self._log("Training interrupted at epoch {}."
                          .format(global_step))
        finally:
            if train_writer is not None:
                train_writer.close()
            if validation_writer is not None:
                validation_writer.close()
            if test_writer is not None:
                test_writer.close()
        # Save train history
        self._append_history(evaluation_steps,
                             train_metrics[:len(evaluation_steps)],
                             validation_metrics[:len(evaluation_steps)],
                             test_metrics[:len(evaluation_steps)])
        # Restore best weights and save
        self._set_weights(self._best_weights)
        self._current_weights = self._best_weights
        if checkpoint_file:
            self.save_checkpoint(checkpoint_file)
        self._log("Evaluating final metrics...", newline=True)
        final_metrics = np.full(3 * num_metrics, np.nan, dtype=np_dtype)
        # Use the same progress and confidence targets as during training,
        # so that targets supplied by the caller are not ignored here
        examples_eval = self.evaluate(examples, labels, label_progs,
                                      label_confs)
        final_metrics[:num_metrics] = list(examples_eval.values())
        if validation_data:
            validation_eval = self.evaluate(*validation_data)
            final_metrics[num_metrics:num_metrics * 2] = \
                list(validation_eval.values())
        if test_data:
            test_eval = self.evaluate(*test_data)
            final_metrics[num_metrics * 2:] = list(test_eval.values())
        self._log("Training completed.", newline=True)
        self._log(newline=True)
        return pd.Series(final_metrics, index=self._train_history.columns)

    def _train_epoch(self, global_step, examples, labels, label_progs,
                     label_confs, class_weights=None):
        """Run one shuffled pass over the training data.

        Every step of every batch is fed to the session, which runs the
        operations registered in the ``TRAIN_OP`` graph collection. The
        model is reset before the pass and again after each batch.

        :param global_step: Value fed to the global step input.
        :param examples: Training examples.
        :param labels: Labels matching the examples.
        :param label_progs: Label progress values matching the examples.
        :param label_confs: Label confidence values matching the examples.
        :param class_weights: Optional value fed to the class weights input;
            nothing is fed for it when ``None``.
        """
        self.reset()
        shuffled = self._batches(examples, labels, label_progs, label_confs,
                                 shuffle=True)
        for batch in shuffled:
            batch_examples, batch_labels, batch_progs, batch_confs = batch
            steps = self._batch_steps(batch_examples, batch_labels,
                                      batch_progs, batch_confs)
            for step in steps:
                step_examples, step_labels, step_progs, step_confs = step
                feed_dict = {
                    self._input: step_examples,
                    self._labels: step_labels,
                    self._label_progress: step_progs,
                    self._label_confidence: step_confs,
                    self._training: True,
                    self._global_step: global_step,
                }
                if class_weights is not None:
                    feed_dict[self._class_weights] = class_weights
                with self._graph.as_default():
                    train_ops = tf.get_collection(tf.GraphKeys.TRAIN_OP)
                    self.session.run(train_ops, feed_dict=feed_dict)
            # Reset after every batch, as the evaluation loops also do
            self.reset()

    def _evaluate_dataset(self, examples, labels, label_progs, label_confs,
                          writer, global_step):
        metrics = self.evaluate(examples, labels, label_progs, label_confs)
        if writer is not None:
            for name, value in metrics.items():
                summary = tf.Summary()
                summary.value.add(tag=name, simple_value=value)
                writer.add_summary(summary, global_step=global_step)
        return metrics

    def evaluate(self, examples, labels, label_progs=None, label_confs=None):
        """Compute the model metrics over a dataset.

        Each batch is fed step by step; the metric values fetched on the
        last step of a batch are taken as the values for that batch. Batch
        values are then averaged, weighted by the number of valid frames in
        the batch (frames that are not masked and whose label is a valid
        class index).

        :param examples: Examples to evaluate.
        :param labels: Labels matching the examples.
        :param label_progs: Label progress values; derived from ``labels``
            when ``None``.
        :param label_confs: Label confidence values; derived from ``labels``
            when ``None``.
        :return: ``OrderedDict`` mapping metric name to averaged value.
        """
        if label_progs is None:
            label_progs = self._make_label_progress(labels)
        if label_confs is None:
            label_confs = self._make_label_confidence(labels)
        weighted_sum = np.zeros(len(self._metrics),
                                dtype=self._dtype.as_numpy_dtype)
        frame_count = 0
        self.reset()
        for batch in self._batches(examples, labels, label_progs,
                                   label_confs, shuffle=False):
            batch_examples, batch_labels, batch_progs, batch_confs = batch
            steps = self._batch_steps(batch_examples, batch_labels,
                                      batch_progs, batch_confs)
            for step in steps:
                step_examples, step_labels, step_progs, step_confs = step
                feed = {
                    self._input: step_examples,
                    self._labels: step_labels,
                    self._label_progress: step_progs,
                    self._label_confidence: step_confs,
                    self._dropout: 0,
                }
                # Only the values of the last step survive the loop
                # (last result is averaged over whole sequence)
                values = self.session.run(list(self._metrics.values()),
                                          feed_dict=feed)
            # Count the frames that contribute to the batch weight
            if self._mask_value is None:
                valid = np.ones(batch_examples.shape[:-1], dtype=bool)
            else:
                valid = np.any(batch_examples != self._mask_value, axis=-1)
            valid &= (batch_labels >= 0) \
                & (batch_labels < len(self._class_names))
            batch_frames = np.count_nonzero(valid)
            weighted_sum += batch_frames * np.asarray(values)
            frame_count += batch_frames
            self.reset()
        averaged = weighted_sum / max(frame_count, 1)
        return OrderedDict(zip(self._metrics.keys(), averaged))

    def predict(self, examples, outputs=None):
        """Compute model outputs for the given examples.

        :param examples: Sequence of examples; each prediction has the same
            length as its example.
        :param outputs: Which outputs to compute. May be a single output
            name or integer index, a sequence of names and/or indices, or
            ``None`` (or empty) for every output. The sequence passed by
            the caller is never modified.
        :return: When a single scalar output was requested, the list of
            per-example prediction arrays for it; otherwise an
            ``OrderedDict`` mapping each output name to such a list.
        """
        unpack = bool(np.isscalar(outputs))
        if unpack:
            outputs = [outputs]
        # Work on a private list: the caller's sequence must not be altered,
        # and tuples (or other read-only sequences) must be accepted too
        if outputs is None or len(outputs) == 0:
            outputs = list(self._outputs.keys())
        else:
            outputs = list(outputs)
        # Transform numeric outputs to names
        outputs = [self._odict_at(self._outputs, output)
                   if isinstance(output, (int, np.integer)) else output
                   for output in outputs]
        output_vals = [self._outputs[output] for output in outputs]
        predictions = [[None] * len(examples) for _ in outputs]
        # Predict
        batch_base = 0
        self.reset()
        for examples_batch in self._batches(examples):
            step_base = 0
            for examples_step in self._batch_steps(examples_batch):
                outs = self.session.run(output_vals, feed_dict={
                    self._input: examples_step,
                    self._dropout: 0
                })
                # Copy predictions to output
                for pred, out in zip(predictions, outs):
                    # Rows past the last example are batch padding
                    pred_end = min(batch_base + len(out), len(examples))
                    for i, out_i in zip(range(batch_base, pred_end), out):
                        pred_i = pred[i]
                        if pred_i is None:
                            # Allocate lazily, once the output shape is known
                            pred_i_shape = (len(examples[i]),) + out.shape[2:]
                            pred_i = np.empty(pred_i_shape, dtype=out.dtype)
                            pred[i] = pred_i
                        # Drop the time padding at the end of the example
                        step_size = min(len(out_i), len(pred_i) - step_base)
                        if step_size > 0:
                            step_end = step_base + step_size
                            pred_i[step_base:step_end] = out_i[:step_size]
                step_base += examples_step.shape[1]
            batch_base += examples_batch.shape[0]
            self.reset()
        res = OrderedDict(zip(outputs, predictions))
        if unpack:
            res = res[outputs[0]]
        return res

    def input_sensitivity(self, examples, labels=None, label_progs=None,
                          label_confs=None):
        """Accumulate the magnitude of three fetched gradients per input.

        The examples are fed one frame at a time. On every step three
        tensors are fetched from the session, their absolute values are
        summed over the batch and accumulated, and the totals are finally
        normalised by the accumulated frame counts.

        A frame is considered valid for the class sensitivity when its
        label is a valid class index and, if a mask value is set, at least
        one of its features differs from the mask value (the same criterion
        :meth:`evaluate` uses). Progress and confidence sensitivities
        additionally require the frame's progress / confidence label to lie
        in ``[0, 1]``.

        :param examples: Examples to analyse.
        :param labels: Per-frame labels; all zeros when ``None``.
        :param label_progs: Per-frame progress labels; all zeros when
            ``None``.
        :param label_confs: Per-frame confidence labels; all ones when
            ``None``.
        :return: Tuple ``(class, progress, confidence)``. ``class`` is a
            list with one ``(receptive_field, input_dim)`` array per class;
            the other two are single arrays of that shape. All are reversed
            along the first axis.
        :raises ValueError: If the model uses a fixed step size other
            than one.
        """
        batch_size = self._config.batch_size
        step_size = self._config.step_size
        dtype_np = self._dtype.as_numpy_dtype
        if step_size is not None and self._fixed_step and step_size != 1:
            raise ValueError("Cannot compute compute sensitivity with "
                             "a fixed step size different to one.")
        if labels is None:
            labels = [np.zeros(len(e), dtype=np.int32) for e in examples]
        if label_progs is None:
            label_progs = [np.zeros(len(e), dtype=np.int32) for e in examples]
        if label_confs is None:
            label_confs = [np.ones(len(e), dtype=np.int32) for e in examples]
        buffer_shape = self._input_buffer.get_shape()
        # The buffer holds the past frames; the current frame adds one more
        receptive_field_size = buffer_shape[1].value + 1
        input_dim = buffer_shape[2].value
        num_classes = len(self._class_names)
        total_sens_class = np.zeros((receptive_field_size, input_dim,
                                     num_classes), dtype=dtype_np)
        total_class_frames = np.zeros(receptive_field_size)
        total_sens_prog = np.zeros((receptive_field_size, input_dim),
                                   dtype=dtype_np)
        total_prog_frames = np.zeros(receptive_field_size)
        total_sens_conf = np.zeros((receptive_field_size, input_dim),
                                   dtype=dtype_np)
        total_conf_frames = np.zeros(receptive_field_size)
        batches = self._batches(examples, labels, label_progs, label_confs,
                                step_size=1)
        for (examples_batch, labels_batch,
             label_progs_batch, label_confs_batch) in batches:
            # Sliding windows recording which of the latest frames were valid
            class_mask = np.full((batch_size, receptive_field_size), False)
            prog_mask = np.full((batch_size, receptive_field_size), False)
            conf_mask = np.full((batch_size, receptive_field_size), False)
            steps = self._batch_steps(examples_batch, labels_batch,
                                      label_progs_batch, label_confs_batch,
                                      step_size=1)
            for (examples_step, labels_step,
                 label_progs_step, label_confs_step) in steps:
                # NOTE(review): the third fetch is named
                # ``_class_progress_conf`` while the others end in ``_grad``;
                # confirm it is the intended confidence gradient tensor.
                grad_class, grad_prog, grad_conf = self.session.run(
                    [self._class_grad,
                     self._class_progress_grad,
                     self._class_progress_conf],
                    feed_dict={
                        self._input: examples_step,
                        self._dropout: 0
                    })
                # Mask invalids
                valid_labels = (labels_step >= 0) & (labels_step < num_classes)
                if self._mask_value is not None:
                    # A frame is masked only when every feature equals the
                    # mask value, so keep frames where any feature differs
                    valid_labels &= np.any(
                        examples_step != self._mask_value, axis=-1)
                grad_class *= np.reshape(valid_labels.astype(grad_class.dtype),
                                         valid_labels.shape + (1, 1))
                valid_progs = valid_labels.copy()
                valid_progs &= \
                    (label_progs_step >= 0) & (label_progs_step <= 1)
                grad_prog *= np.expand_dims(
                    valid_progs.astype(grad_prog.dtype), axis=-1)
                valid_confs = valid_labels.copy()
                # Confidence validity depends on the confidence labels
                valid_confs &= \
                    (label_confs_step >= 0) & (label_confs_step <= 1)
                grad_conf *= np.expand_dims(
                    valid_confs.astype(grad_conf.dtype), axis=-1)
                # Advance masks (steps hold one frame, so the validity
                # arrays have a single column)
                class_mask[:, :-1] = class_mask[:, 1:]
                class_mask[:, -1:] = valid_labels
                prog_mask[:, :-1] = prog_mask[:, 1:]
                prog_mask[:, -1:] = valid_progs
                conf_mask[:, :-1] = conf_mask[:, 1:]
                conf_mask[:, -1:] = valid_confs
                # Accumulate sensitivities
                sens_class = np.sum(np.abs(grad_class), axis=0)
                class_frames = np.sum(class_mask, axis=0)
                total_sens_class += sens_class * np.reshape(
                    class_frames, class_frames.shape + (1, 1))
                total_class_frames += class_frames
                sens_prog = np.sum(np.abs(grad_prog), axis=0)
                prog_frames = np.sum(prog_mask, axis=0)
                total_sens_prog += sens_prog * np.expand_dims(prog_frames,
                                                              axis=-1)
                total_prog_frames += prog_frames
                sens_conf = np.sum(np.abs(grad_conf), axis=0)
                conf_frames = np.sum(conf_mask, axis=0)
                total_sens_conf += sens_conf * np.expand_dims(conf_frames,
                                                              axis=-1)
                total_conf_frames += conf_frames
            self.reset()
        # Finish: normalise (guarding against zero counts) and reverse the
        # receptive field axis
        total_sens_class /= \
            np.maximum(np.reshape(total_class_frames,
                                  total_class_frames.shape + (1, 1)), 1)
        total_sens_class = total_sens_class[::-1, ...]
        # One (receptive_field, input_dim) array per class
        total_sens_class = np.split(total_sens_class,
                                    total_sens_class.shape[-1], axis=-1)
        total_sens_class = [s.squeeze(-1) for s in total_sens_class]
        total_sens_prog /= \
            np.maximum(np.reshape(total_prog_frames,
                                  total_prog_frames.shape + (1,)), 1)
        total_sens_prog = total_sens_prog[::-1, ...]
        total_sens_conf /= \
            np.maximum(np.reshape(total_conf_frames,
                                  total_conf_frames.shape + (1,)), 1)
        total_sens_conf = total_sens_conf[::-1, ...]
        return total_sens_class, total_sens_prog, total_sens_conf

    def _batches(self, examples, labels=None, label_progs=None,
                 label_confs=None, step_size=None, shuffle=False):
        config = self._config
        batch_size = config.batch_size
        dtype_np = self._dtype.as_numpy_dtype
        output_labels = labels is not None
        output_label_progs = label_progs is not None
        output_label_confs = label_confs is not None
        num_examples = len(examples)
        if output_labels and num_examples != len(labels):
            raise ValueError("There must be the same number "
                             "of examples and labels.")
        if output_label_progs and num_examples != len(label_progs):
            raise ValueError("There must be the same number "
                             "of examples and label progress.")
        if output_label_confs and num_examples != len(label_confs):
            raise ValueError("There must be the same number "
                             "of examples and label confidence.")
        if num_examples <= 0:
            raise StopIteration()
        examples = examples.copy()
        labels = labels.copy() if output_labels else [0] * len(examples)
        if output_label_progs:
            label_progs = label_progs.copy()
        else:
            label_progs = [0] * len(examples)
        if output_label_confs:
            label_confs = label_confs.copy()
        else:
            label_confs = [1] * len(examples)
        if shuffle:
            data = list(zip(examples, labels, label_progs, label_confs))
            random.shuffle(data)
            examples, labels, label_progs, label_confs = map(list, zip(*data))
        # Normalize labels
        data = zip(examples, labels, label_progs, label_confs)
        for i, (example, label, label_prog, label_conf) in enumerate(data):
            if np.isscalar(label):
                labels[i] = np.full(len(example), label)
            if np.isscalar(label_prog):
                label_progs[i] = np.full(len(example), label_prog)
            if np.isscalar(label_conf):
                label_confs[i] = np.full(len(example), label_conf)
        empty_value = 0
        if self._mask_value is not None:
            empty_value = self._mask_value
        while examples:
            # Take next batch
            examples_batch = examples[:batch_size]
            examples = examples[batch_size:]
            labels_batch = labels[:batch_size]
            labels = labels[batch_size:]
            label_progs_batch = label_progs[:batch_size]
            label_progs = label_progs[batch_size:]
            label_confs_batch = label_confs[:batch_size]
            label_confs = label_confs[batch_size:]
            length_batch = max(map(len, examples_batch))
            # Round to multiple of step size
            if step_size is not None:
                if config.step_size is not None:
                    if step_size > config.step_size:
                        raise ValueError("Step size is too big.")
                    if self._fixed_step and step_size != config.step_size:
                        raise ValueError("Cannot override fixed step size.")
                batch_step_size = step_size
            elif config.step_size is not None:
                batch_step_size = config.step_size
                if not self._fixed_step:
                    batch_step_size = min(batch_step_size, length_batch)
            else:
                batch_step_size = length_batch
            length_batch = \
                (((length_batch - 1) // batch_step_size) + 1) * batch_step_size
            # Put into arrays
            batch_shape = (batch_size, length_batch, self._data_dim)
            examples_arr = np.full(batch_shape, empty_value, dtype=dtype_np)
            labels_arr = np.full((batch_size, length_batch),
                                 empty_value, dtype=np.int32)
            labels_progress_arr = np.full((batch_size, length_batch),
                                          empty_value, dtype=dtype_np)
            labels_confidence_arr = np.full((batch_size, length_batch),
                                            empty_value, dtype=dtype_np)
            batch_data = zip(examples_batch, labels_batch,
                             label_progs_batch, label_confs_batch)
            for i, (example, label,
                    label_prog, label_conf) in enumerate(batch_data):
                examples_arr[i, :len(example), :] = example
                labels_arr[i, :len(label)] = label
                labels_progress_arr[i, :len(label_prog)] = label_prog
                labels_confidence_arr[i, :len(label_conf)] = label_conf
            res = examples_arr
            if output_labels or output_label_progs or output_label_confs:
                res = (examples_arr,)
                if output_labels:
                    res += (labels_arr,)
                if output_label_progs:
                    res += (labels_progress_arr,)
                if output_label_confs:
                    res += (labels_confidence_arr,)
            yield res

    def _batch_steps(self, examples_batch, labels_batch=None,
                     label_progs_batch=None, label_confs_batch=None,
                     step_size=None):
        output_labels = labels_batch is not None
        output_label_progs = label_progs_batch is not None
        output_label_confs = label_confs_batch is not None
        labels_shape = examples_batch.shape[:2]
        if output_labels and labels_shape != labels_batch.shape[:2]:
            raise ValueError("Invalid batch shapes.")
        if output_label_progs and labels_shape != label_progs_batch.shape[:2]:
            raise ValueError("Invalid batch shapes.")
        if output_label_confs and labels_shape != label_confs_batch.shape[:2]:
            raise ValueError("Invalid batch shapes.")
        config = self._config
        num_steps = examples_batch.shape[1]
        if step_size is not None:
            if config.step_size is not None:
                if step_size > config.step_size:
                    raise ValueError("Step size is too big.")
                if self._fixed_step and step_size != config.step_size:
                    raise ValueError("Cannot override fixed step size.")
        elif config.step_size is not None:
            step_size = config.step_size
            if not self._fixed_step:
                step_size = min(step_size, num_steps)
        else:
            step_size = num_steps
        if num_steps % step_size != 0:
            raise ValueError("The length of the batch must be "
                             "a multiple of the step size.")
        for begin in range(0, num_steps, step_size):
            end = begin + step_size
            step = examples_batch[:, begin:end, :]
            if output_labels or output_label_progs or output_label_confs:
                step = (step,)
                if output_labels:
                    step += (labels_batch[:, begin:end],)
                if output_label_progs:
                    step += (label_progs_batch[:, begin:end],)
                if output_label_confs:
                    step += (label_confs_batch[:, begin:end],)
            yield step

    def _make_label_progress(self, labels):
        label_progs = []
        for label in labels:
            label_diff = np.diff(label)
            label_start_idx = np.concatenate(
                [[0], np.where(label_diff != 0)[0] + 1], axis=0)
            label_length = np.diff(np.concatenate(
                [label_start_idx, [len(label)]], axis=0))
            label_prog = np.concatenate(
                [np.linspace(0, 1, l) for l in label_length])
            label_progs.append(label_prog)
        return label_progs

    def _make_label_confidence(self, labels):
        return [np.ones_like(label) for label in labels]

    def save_checkpoint(self, file_path):
        """Save the current session through the saver to ``file_path``.

        :param file_path: Checkpoint location; converted with ``str`` so
            path objects are accepted.
        """
        target = str(file_path)
        self._saver.save(self.session, target)

    def restore_checkpoint(self, file_path):
        """Restore the session from a checkpoint and reset training state.

        The train history is cleared (which also forgets the stored
        weights) and the restored weights become the current weights.

        :param file_path: Checkpoint location; converted with ``str``.
        """
        source = str(file_path)
        self._saver.restore(self.session, source)
        # Clearing must come first: it resets ``_current_weights`` to None
        self.clear_train_history()
        self._current_weights = self._get_weights()

    def _odict_at(self, odict, index):
        return list(odict.keys())[index]

    def _odict_index(self, odict, elem):
        return list(odict.keys()).index(elem)

    def _log(self, msg="", newline=False):
        if newline:
            if self._last_log_len > 0:
                sys.stdout.write("\n")
                sys.stdout.flush()
                self._last_log_len = 0
        elif self._last_log_len > 0:
            # sys.stdout.write("\b" * self._last_log_len)
            sys.stdout.write("\r")
        if len(msg) > 0:
            now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            msg = "{} - {}".format(now_str, msg)
        if len(msg) < self._last_log_len:
            msg += " " * max(self._last_log_len - len(msg), 0)
        sys.stdout.write(msg)
        self._last_log_len = len(msg)
        sys.stdout.flush()

    def close_session(self):
        """Close the session, if one is open, and forget it."""
        active = self._session
        if active:
            active.close()
        self._session = None

    @property
    def name(self):
        """Return the name stored in ``self._name``."""
        return self._name

    @property
    def graph(self):
        """Return the graph object this instance operates on."""
        return self._graph

    def get_collection(self, collection):
        """Return the contents of ``collection`` in this instance's graph.

        :param collection: Key of the graph collection to look up.
        """
        return self._graph.get_collection(collection)

    @property
    def session(self):
        """Return the session, creating it on first use.

        A new session is created within this instance's graph, the
        operations in the ``INIT_OP`` collection are run and, if current
        weights are stored, they are loaded into the new session.
        """
        if self._session:
            return self._session
        with self._graph.as_default():
            self._session = tf.Session()
            init_ops = tf.get_collection(tf.GraphKeys.INIT_OP)
            self._session.run(init_ops)
            if self._current_weights is not None:
                # Bring the fresh session up to the stored weights
                self._set_weights(self._current_weights)
        return self._session

    @property
    def data_dim(self):
        """Return the size of the last axis of the example batches."""
        return self._data_dim

    @property
    def class_names(self):
        """Return a copy of the class names.

        A copy is returned so callers cannot alter the stored collection.
        """
        return self._class_names.copy()

    @property
    def config(self):
        """Return the configuration object stored in ``self._config``."""
        return self._config

    @property
    def dtype(self):
        """Return the data type object stored in ``self._dtype``."""
        return self._dtype

    @property
    def use_dropout(self):
        """Return the value stored in ``self._use_dropout``."""
        return self._use_dropout

    @property
    def input_filter(self):
        """Return the value stored in ``self._input_filter``."""
        return self._input_filter

    @property
    def stateful(self):
        """Return the value stored in ``self._stateful``."""
        return self._stateful

    @property
    def fixed_step(self):
        """Return whether the step size is fixed.

        When true, batching refuses a step size override that differs from
        the configured one.
        """
        return self._fixed_step

    @property
    def mask_value(self):
        """Return the mask value, or ``None`` when masking is not used.

        The value pads batches, and frames whose features all equal it are
        not counted during evaluation.
        """
        return self._mask_value

    @property
    def metric_names(self):
        """Return a new list with the names of the metrics, in order."""
        return list(self._metrics.keys())

    @property
    def output_names(self):
        """Return a new list with the names of the outputs, in order."""
        return list(self._outputs.keys())

    @property
    def aliases(self):
        """Return the input and output alias collections of the graph.

        :return: Tuple ``(input_aliases, output_aliases)``.
        """
        return (self._graph.get_collection(self.INPUT_ALIASES),
                self._graph.get_collection(self.OUTPUT_ALIASES))

    @property
    def train_history(self):
        """Return the train history, creating an empty one if missing.

        Creating it goes through :meth:`clear_train_history`, which also
        resets the best loss and the stored weights.
        """
        if self._train_history is None:
            self.clear_train_history()
        return self._train_history

    def clear_train_history(self):
        """Reset the train history and the training bookkeeping.

        The history becomes an empty data frame indexed by "Epoch", with
        one column per (dataset, metric) pair for the "Train", "Validation"
        and "Test" datasets. The best loss is reset to infinity and the
        best, latest and current weights are forgotten.
        """
        datasets = ("Train", "Validation", "Test")
        columns = pd.MultiIndex.from_product(
            [datasets, self.metric_names], names=["Dataset", "Metric"])
        empty_history = pd.DataFrame(columns=columns,
                                     dtype=self._dtype.as_numpy_dtype,
                                     index=np.arange(0))
        empty_history.index.name = "Epoch"
        self._train_history = empty_history
        self._best_loss = np.inf
        self._best_weights = None
        self._latest_weights = None
        self._current_weights = None

    def _append_history(self, steps, train_metrics,
                        validation_metrics=None, test_metrics=None):
        if validation_metrics is None:
            validation_metrics = np.full(train_metrics.shape,
                                         np.nan, train_metrics.dtype)
        if test_metrics is None:
            test_metrics = np.full(train_metrics.shape,
                                   np.nan, train_metrics.dtype)
        if len(steps) != len(train_metrics):
            raise ValueError("Metrics shape mismatch")
        if train_metrics.shape != validation_metrics.shape:
            raise ValueError("Metrics shape mismatch")
        if train_metrics.shape != test_metrics.shape:
            raise ValueError("Metrics shape mismatch")
        history_idx = np.asarray(steps).astype(np.int32)
        history_data = [train_metrics, validation_metrics, test_metrics]
        history_data = np.concatenate(history_data, axis=1)
        history = pd.DataFrame(history_data,
                               columns=self._train_history.columns,
                               index=history_idx)
        history.index.name = self._train_history.index.name
        self._train_history = self._train_history.append(history)

    def _get_current_weights(self):
        """Return a copy of the current weights.

        The weights are read from the session and stored the first time
        they are needed; later calls return a copy of the stored value.
        """
        if self._current_weights is None:
            self._current_weights = self._get_weights()
        return self._current_weights.copy()

    def _get_weights(self):
        """Fetch the values of the ``WEIGHT_VARIABLES`` collection.

        :return: The result of running the collection in the session.
        """
        # tf.get_collection reads the default graph, hence the context
        with self._graph.as_default():
            weight_vars = tf.get_collection(self.WEIGHT_VARIABLES)
            return self.session.run(weight_vars)

    def _set_weights(self, weights):
        updates = self._graph.get_collection(self.WEIGHT_UPDATES)
        if len(weights) != len(updates):
            raise ValueError("The number of weigths must match "
                             "the number of weights variables.")
        self.session.run(self._graph.get_collection(self.WEIGHT_UPDATE_OP),
                         feed_dict=dict(zip(updates, weights)))
        self._current_weights = self._get_weights()

    # Current weights: reading returns a copy of the stored weights,
    # assigning loads the given weights into the session
    weights = property(_get_current_weights, _set_weights)

    @property
    def best_weights(self):
        """Return a copy of the stored best weights.

        When no best weights are stored, the weights currently in the
        session are fetched and stored as the best ones first.
        """
        if self._best_weights is None:
            self._best_weights = self._get_weights()
        return self._best_weights.copy()

    @property
    def latest_weights(self):
        """Return a copy of the stored latest weights.

        When no latest weights are stored, the weights currently in the
        session are fetched and stored as the latest ones first.
        """
        if self._latest_weights is None:
            self._latest_weights = self._get_weights()
        return self._latest_weights.copy()

    @property
    def num_parameters(self):
        """Return the total number of elements in the weight variables.

        The count is taken from the static shapes of the variables in the
        ``WEIGHT_VARIABLES`` collection; no session is needed.
        """
        # tf.get_collection reads the default graph, hence the context
        with self._graph.as_default():
            weight_vars = tf.get_collection(self.WEIGHT_VARIABLES)
        return sum(v.get_shape().num_elements() for v in weight_vars)

    def _get_state(self):
        """Fetch the values of the ``STATE_VARIABLES`` collection.

        :return: The result of running the collection in the session.
        """
        # tf.get_collection reads the default graph, hence the context
        with self._graph.as_default():
            return self.session.run(tf.get_collection(self.STATE_VARIABLES))

    def _set_state(self, state):
        updates = self._graph.get_collection(self.STATE_UPDATES)
        if len(state) != len(updates):
            raise ValueError("The number of state values must match "
                             "the number of state variables.")
        self.session.run(self._graph.get_collection(self.STATE_UPDATE_OP),
                         feed_dict=dict(zip(updates, state)))

    # State values: reading fetches them from the session, assigning
    # loads the given values into the session
    state = property(_get_state, _set_state)
