import sys
from glob import glob
from pathlib import Path
from itertools import chain
import os
import logging

import numpy as np
import tensorflow as tf

from SplineFunctionedVariables import SplineFunctionedVariables

# Module-level logger; every Gfnn instance logs through a child of it.
logger = logging.getLogger(name=__name__)
class Gfnn(object):
    """Neural network whose layer parameters come from spline-functioned variables.

    Weights and biases of every layer are requested from a
    ``SplineFunctionedVariables`` instance, which is fed per-example
    coordinates (``coords``) in addition to the regular network input.

    The constructor builds and finalizes a TensorFlow 1.x training graph that
    reads TFRecord files through a feedable iterator, optionally splits each
    batch across several devices, averages the per-device gradients and
    applies them with Adam. Checkpoints, summaries and logs are written under
    ``logdir``, and serialized standalone inference graphs (see
    ``make_runtime_graph``) are exported next to them.
    """

    def __init__(self, input_size, output_size, hidden_layers,
                 spline_points, spline_close, logdir, dtype=tf.float32,
                 use_bias=True, regularization=None, dropout=None,
                 input_mean=None, input_scale=None,
                 output_mean=None, output_scale=None,
                 preprocess_fn=None, preprocess_scaled_fn=None, postprocess_fn=None,
                 parse_example_fn=None, loss_fn=None, metrics_fn=None,
                 hidden_activation_fn=tf.nn.relu, output_activation_fn=None,
                 devices=None, seed=None):
        """Build and finalize the training graph.

        Args:
            input_size: number of input features per example.
            output_size: number of outputs per example.
            hidden_layers: sequence with the width of each hidden layer.
            spline_points: one entry per coordinate dimension; passed to
                ``SplineFunctionedVariables``.
            spline_close: broadcast to the shape of ``spline_points`` and passed
                to ``SplineFunctionedVariables`` (presumably whether each
                dimension is closed/periodic -- confirm).
            logdir: directory for checkpoints, summaries, logs and exported
                graphs; created if it does not exist.
            dtype: floating point type of the model.
            use_bias: whether layers have a bias term.
            regularization: L2 factor; when > 0 the term
                ``regularization / (2 * batch) * sum(w ** 2)`` is added to the loss.
            dropout: drop rate used during training, clipped to [0, 1].
            input_mean, input_scale: inputs are shifted by the mean and divided
                by the scale (after ``preprocess_fn``).
            output_mean, output_scale: raw network outputs are multiplied by the
                scale and shifted by the mean; labels are normalized inversely.
            preprocess_fn: ``f(net, x, coords) -> (x, coords)`` applied to raw inputs.
            preprocess_scaled_fn: ``f(net, x) -> x`` applied after normalization.
            postprocess_fn: ``f(net, y) -> y`` applied to de-normalized outputs.
            parse_example_fn: ``f(net, example_proto) -> (x, coords, y, weight)``.
            loss_fn: ``f(net, label, output, logits) -> per-example loss`` in
                normalized space; defaults to the summed squared error.
            metrics_fn: ``f(net, label, label_normalized, output,
                output_normalized, logits, weight) -> {name: (value, update_op)}``.
            hidden_activation_fn, output_activation_fn: activations; ``None``
                means identity.
            devices: optional list of device names; each batch is split evenly
                across them (``tf.split`` requires the batch size to be divisible
                by the number of devices).
            seed: optional graph-level random seed.
        """
        self._graph = tf.Graph()
        self._dtype = dtype
        logdir = Path(logdir).resolve()
        self._logdir = str(logdir)
        # One child logger per model, named after the last component of logdir.
        self._logger = logger.getChild(logdir.parts[-1])
        self._use_bias = use_bias
        self._regularization = regularization
        self._input_size = input_size
        self._output_size = output_size
        # Configuration arrays are stored read-only so they cannot drift from the built graph.
        self._hidden_layers = self._immutable_numpy_array(hidden_layers)
        self._spline_points = self._immutable_numpy_array(spline_points)
        self._spline_close = self._immutable_numpy_array(np.broadcast_to(spline_close, self._spline_points.shape))
        self._input_mean = self._immutable_numpy_array(input_mean)
        self._input_scale = self._immutable_numpy_array(input_scale)
        self._output_mean = self._immutable_numpy_array(output_mean)
        self._output_scale = self._immutable_numpy_array(output_scale)
        # User hooks receive the model as first argument; missing hooks default to identity.
        self._preprocess_fn = ((lambda x, coords: preprocess_fn(self, x, coords))
                               if preprocess_fn else (lambda x, coords: (x, coords)))
        self._preprocess_scaled_fn = ((lambda x: preprocess_scaled_fn(self, x))
                                      if preprocess_scaled_fn else (lambda x: x))
        self._postprocess_fn = ((lambda y: postprocess_fn(self, y))
                                if postprocess_fn else (lambda y: y))
        self._parse_example_fn = ((lambda example_proto: parse_example_fn(self, example_proto))
                                  if parse_example_fn else self._default_parse_example_fn)
        self._loss_fn = self._default_loss_fn
        if loss_fn:
            self._loss_fn = (lambda label_unscaled, out_unscaled, out_logits:
                             loss_fn(self, label_unscaled, out_unscaled, out_logits))
        self._metrics_fn = lambda lbl, lbl_unscaled, out, out_unscaled, out_logits, weight: {}
        if metrics_fn:
            self._metrics_fn = (lambda lbl, lbl_unscaled, out, out_unscaled, out_logits, weight:
                                metrics_fn(self, lbl, lbl_unscaled, out, out_unscaled, out_logits, weight))
        self._hidden_activation_fn = hidden_activation_fn or (lambda x: x)
        self._output_activation_fn = output_activation_fn or (lambda x: x)
        self._devices = list(devices) if devices else None
        if not logdir.exists():
            logdir.mkdir(parents=True)
        self._configure_logger()
        # Build training graph
        with self._graph.as_default():
            if seed is not None:
                tf.random.set_random_seed(seed)
            # Input
            input_, coords, label, weight = self._make_input_pipeline(input_size, output_size)
            # Keep the raw (pre-preprocessing) tensors so predict() can feed them directly.
            self._input = input_
            self._coords = coords
            input_, coords = self._preprocess_fn(input_, coords)
            self._label = label
            self._weight = weight
            self._spline_vars = SplineFunctionedVariables(self._spline_points, self._spline_close, dtype=self._dtype)
            self._dropout_value = np.clip(dropout or 0, 0, 1)
            # Dropout defaults to 0 so evaluation runs need not feed it.
            self._dropout = tf.placeholder_with_default(
                tf.zeros([], self._dtype), [], name='Dropout')
            # Optimizer
            self._learning_rate = tf.placeholder(tf.float32, (), name='LearningRate')
            # The learning rate is not rescaled by the number of devices: train()
            # feeds this placeholder directly, and per-device gradients are
            # averaged (not summed) in _make_train.
            self._optimizer = tf.train.AdamOptimizer(learning_rate=self._learning_rate)
            # Training variables
            self._global_step = tf.get_variable('GlobalStep', [], tf.int32,
                                                initializer=tf.zeros_initializer(),
                                                trainable=False)
            tf.add_to_collection(tf.GraphKeys.GLOBAL_STEP, self._global_step)
            # Best evaluation loss so far; stored in the graph so it survives checkpoints.
            self._best_loss = tf.get_variable('BestLoss', [], self._dtype,
                                              initializer=tf.constant_initializer(np.inf, self._dtype),
                                              trainable=False)
            # Make model for each device
            devices = self._devices if self._devices else [None]
            batch_devices = zip(
                tf.split(input_, len(devices)),
                tf.split(coords, len(devices)),
                tf.split(self._label, len(devices)),
                tf.split(self._weight, len(devices)))
            grad_devices = []
            output_devices = []
            output_devices_unscaled = []
            output_devices_logits = []
            loss_devices = []
            batch_size_devices = []
            regularize = False if regularization is None else (regularization > 0)
            reuse = False
            for i_device, (device, batch_device) in enumerate(zip(devices, batch_devices)):
                # Variables are created by the first tower and reused by the others.
                with tf.name_scope('Device_{}'.format(i_device)), tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
                    input_device, coords_device, label_device, weight_device = batch_device
                    with tf.device(device), self._spline_vars.use_coordinates(coords_device):
                        output_device, output_device_unscaled, output_device_logits = self._make_net(
                            input_device, self._dropout, regularize)
                        loss_device, loss_agg_device = self._make_loss(label_device, output_device_unscaled,
                                                                       output_device_logits, weight_device)
                        grad_devices.append(self._optimizer.compute_gradients(loss_agg_device))
                        output_devices.append(output_device)
                        output_devices_unscaled.append(output_device_unscaled)
                        output_devices_logits.append(output_device_logits)
                        loss_devices.append(loss_device)
                        batch_size_devices.append(tf.shape(output_device)[0])
                reuse = True
            # Concatenate
            self._output = tf.concat(output_devices, axis=0, name='Output')
            output_unscaled = tf.concat(output_devices_unscaled, axis=0, name='OutputUnscaled')
            output_logits = tf.concat(output_devices_logits, axis=0, name='OutputLogits')
            with tf.name_scope('Loss'):
                loss = tf.concat(loss_devices, axis=0)
                batch_loss = tf.reduce_mean(loss, name='BatchLoss')
            # Metrics
            with tf.variable_scope('Metrics'):
                self._metric_values, self._metric_summaries = self._make_metrics(
                    loss, self._label, self._output, output_unscaled, output_logits, self._weight)
                self._reset_metrics = tf.variables_initializer(tf.get_collection(tf.GraphKeys.METRIC_VARIABLES))
            # Training
            with tf.variable_scope('Train'):
                self._train_op, grad_agg = self._make_train(grad_devices)
                grad_var_hists = []
                for grad, var in grad_agg:
                    grad_var_hists.append(tf.summary.histogram('value', var, family=var.name))
                    grad_var_hists.append(tf.summary.histogram('gradient', grad, family=var.name))
                self._grad_var_hist = tf.summary.merge(grad_var_hists)
            self._saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=1)
            self._session_manager = tf.train.SessionManager(graph=self._graph)
            # Initialization
            self._init_op = tf.global_variables_initializer()
            tf.add_to_collection(tf.GraphKeys.INIT_OP, self._init_op)
            self._graph.finalize()

    # Read-only views of the construction parameters.
    @property
    def logdir(self): return self._logdir
    @property
    def graph(self): return self._graph
    @property
    def input_size(self): return self._input_size
    @property
    def output_size(self): return self._output_size
    @property
    def spline_points(self): return self._spline_points
    @property
    def spline_close(self): return self._spline_close
    @property
    def hidden_layers(self): return self._hidden_layers
    @property
    def use_bias(self): return self._use_bias
    @property
    def input_mean(self): return self._input_mean
    @property
    def input_scale(self): return self._input_scale
    @property
    def output_mean(self): return self._output_mean
    @property
    def output_scale(self): return self._output_scale

    @property
    def num_params(self):
        """Total number of elements in the trainable variables of the training graph."""
        return sum(v.shape.num_elements() for v in self._graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))

    def train(self, training_record_file_pattern, max_steps, batch_size, shuffle_buffer_size,
              learning_rate=0.001, validation_record_file_pattern=None, test_record_file_pattern=None,
              evaluation_steps=10, patience=None,
              session=None):
        """Train the model, evaluating and checkpointing every ``evaluation_steps`` steps.

        Training resumes from the global step stored in the session. When it
        starts at step 0, an initial evaluation is run and saved as the first
        "best" model. At each evaluation step the validation loss (or the
        training loss when no validation set is given) is compared with the
        best loss so far, and improvements are saved to the "best" checkpoint.
        A regular checkpoint is written at every evaluation step and once more
        when training ends, including when it ends with an exception. The
        serialized runtime graphs ``gfnn_checkpoint.tf`` and ``gfnn_best.tf``
        are written to ``logdir``.

        Args:
            training_record_file_pattern: glob pattern of training TFRecord files.
            max_steps: last global step to train up to (inclusive).
            batch_size: examples per batch, split across devices.
            shuffle_buffer_size: size of the training shuffle buffer.
            learning_rate: Adam learning rate.
            validation_record_file_pattern: optional pattern; its loss drives
                model selection and early stopping.
            test_record_file_pattern: optional pattern, evaluated for logging only.
            evaluation_steps: evaluate every this many steps (and at ``max_steps``).
            patience: number of consecutive evaluations without improvement
                tolerated before stopping early; ``None`` or negative disables
                early stopping.
            session: optional existing session; otherwise one is restored from
                the regular checkpoint and closed at the end.

        Raises:
            ValueError: if ``evaluation_steps`` is not positive.
        """
        if evaluation_steps <= 0:
            raise ValueError('evaluation_steps must be a positive integer.')
        training_record_file_pattern = str(training_record_file_pattern)
        if patience is None or patience < 0:
            patience = None
        base_patience = patience
        train_writer = tf.summary.FileWriter(
            str(Path(self._logdir, 'log', 'train')))
        validation_writer = None
        if validation_record_file_pattern:
            validation_record_file_pattern = str(validation_record_file_pattern)
            validation_writer = tf.summary.FileWriter(
                str(Path(self._logdir, 'log', 'validation')))
        test_writer = None
        if test_record_file_pattern:
            test_record_file_pattern = str(test_record_file_pattern)
            test_writer = tf.summary.FileWriter(
                str(Path(self._logdir, 'log', 'test')))
        close_session = False
        if session is None:
            session = self.new_session(self.checkpoint_save_path())
            close_session = True
        try:
            with self._graph.as_default():
                base_step = session.run(self._global_step)
                current_best_loss = session.run(self._best_loss)
                if base_step == 0:
                    self._logger.info('Starting training.')
                    self._logger.info('Initial evaluation...')
                    current_best_loss = self._evaluate_dataset(
                        training_record_file_pattern, batch_size, session,
                        train_writer, base_step)
                    if validation_record_file_pattern:
                        # Validation loss takes precedence for model selection.
                        current_best_loss = self._evaluate_dataset(
                            validation_record_file_pattern, batch_size, session,
                            validation_writer, base_step)
                    if test_record_file_pattern:
                        self._evaluate_dataset(
                            test_record_file_pattern, batch_size, session,
                            test_writer, base_step)
                    self._best_loss.load(current_best_loss, session=session)
                    self._save_checkpoint(session)
                    self._save_best(session)
                else:
                    self._logger.info('Resuming training from step {}.'.format(base_step))
                # Training iterator handle
                iter_handle = session.run(self._train_iterator_handle)
                # Initialize training iterator
                session.run(self._train_iterator_init_op, feed_dict={
                    self._file_pattern: training_record_file_pattern,
                    self._shuffle_buffer_size: shuffle_buffer_size,
                    self._batch_size: batch_size,
                })
                # Train loop
                for i_step in range(base_step + 1, max_steps + 1):
                    is_evaluation_step = (i_step % evaluation_steps == 0 or i_step == max_steps)
                    # Train step
                    feed_dict = {
                        self._dropout: self._dropout_value,
                        self._input_iterator_handle: iter_handle,
                        self._learning_rate: learning_rate,
                    }
                    if is_evaluation_step:
                        _, grad_var_hist = session.run([self._train_op, self._grad_var_hist], feed_dict=feed_dict)
                    else:
                        session.run(self._train_op, feed_dict=feed_dict)
                    # Update step count
                    self._global_step.load(i_step, session=session)
                    if not is_evaluation_step:
                        continue
                    self._logger.info('Step {}/{} - Evaluating...'.format(i_step, max_steps))
                    # Write histograms
                    train_writer.add_summary(grad_var_hist, i_step)
                    # Evaluate
                    current_loss = self._evaluate_dataset(
                        training_record_file_pattern, batch_size, session,
                        train_writer, i_step)
                    if validation_record_file_pattern:
                        current_loss = self._evaluate_dataset(
                            validation_record_file_pattern, batch_size, session,
                            validation_writer, i_step)
                    if test_record_file_pattern:
                        self._evaluate_dataset(
                            test_record_file_pattern, batch_size, session,
                            test_writer, i_step)
                    # Checkpoint. Decide "improved" before updating the best loss,
                    # so an improvement resets patience instead of consuming it.
                    improved = current_loss < current_best_loss
                    if improved:
                        current_best_loss = current_loss
                        self._best_loss.load(current_best_loss, session=session)
                        self._save_best(session)
                        patience = base_patience
                    self._save_checkpoint(session)
                    if not improved and patience is not None:
                        patience -= 1
                        if patience < 0:
                            self._logger.info('Early stopping at step {}.'.format(i_step))
                            break
        finally:
            final_step = session.run(self._global_step)
            # Final checkpoint
            self._save_checkpoint(session)
            if close_session:
                session.close()
            train_writer.close()
            if validation_writer is not None:
                validation_writer.close()
            if test_writer is not None:
                test_writer.close()
            # Final checkpointed and best graphs
            with self.new_session(self.checkpoint_save_path()) as sess:
                rt_graph = self.make_runtime_graph(session=sess)
                tf.train.write_graph(rt_graph, self._logdir, 'gfnn_checkpoint.tf', as_text=False)
            with self.new_session(self.best_save_path()) as sess:
                rt_graph = self.make_runtime_graph(session=sess)
                tf.train.write_graph(rt_graph, self._logdir, 'gfnn_best.tf', as_text=False)
            self._logger.info('Training finished after {} steps.\n'.format(final_step))

    def evaluate(self, record_file_pattern, batch_size, return_summary=False, session=None):
        """Compute the metrics over every example of a TFRecord dataset.

        Args:
            record_file_pattern: glob pattern of the TFRecord files to evaluate.
            batch_size: evaluation batch size (must be a positive integer).
            return_summary: also return the merged metric summary.
            session: optional session; otherwise one is restored from the best
                checkpoint and closed afterwards.

        Returns:
            ``(loss, metrics)`` or ``(loss, metrics, summary)``, where ``metrics``
            maps metric names (including ``'Loss'``) to their final values.

        Raises:
            ValueError: if ``batch_size`` is falsy or the dataset has no examples.
        """
        record_file_pattern = str(record_file_pattern)
        if not batch_size:
            raise ValueError('batch_size must be a positive integer.')
        with self._graph.as_default():
            close_session = False
            if session is None:
                session = self.new_session(self.best_save_path())
                close_session = True
            try:
                iter_handle = session.run(self._test_iterator_handle)
                # Initialize iterator
                session.run(self._test_iterator_init_op, feed_dict={
                    self._file_pattern: record_file_pattern,
                    self._batch_size: batch_size
                })
                # Reset metric variables
                session.run(self._reset_metrics)
                # Metric values are streaming aggregates, so the values from the
                # last batch are the values over the whole dataset.
                metrics = summary = None
                while True:
                    try:
                        metrics, summary = session.run(
                            [self._metric_values, self._metric_summaries], feed_dict={
                                self._dropout: 0,
                                self._input_iterator_handle: iter_handle,
                            })
                    except tf.errors.OutOfRangeError:
                        break
                session.run(self._reset_metrics)
                if metrics is None:
                    raise ValueError('no examples found for "{}".'.format(record_file_pattern))
                loss = metrics['Loss']
                return (loss, metrics, summary) if return_summary else (loss, metrics)
            finally:
                if close_session:
                    session.close()

    def predict(self, examples, coords, session=None, batch_size=None):
        """Run the network on in-memory examples.

        Args:
            examples: sliceable array of raw inputs (before ``preprocess_fn``).
            coords: sliceable array of coordinates, aligned with ``examples``.
            session: optional session; otherwise one is restored from the best
                checkpoint and closed afterwards.
            batch_size: examples per run; defaults to all at once.

        Returns:
            The outputs of all batches concatenated along axis 0.
        """
        with self._graph.as_default():
            batch_size = batch_size or len(examples)
            close_session = False
            if session is None:
                session = self.new_session(self.best_save_path())
                close_session = True
            try:
                outputs = []
                for i_example in range(0, len(examples), batch_size):
                    examples_batch = examples[i_example:i_example + batch_size]
                    coords_batch = coords[i_example:i_example + batch_size]
                    # Feeding the iterator outputs directly bypasses the input pipeline.
                    output_batch = session.run(self._output, feed_dict={
                        self._input: examples_batch,
                        self._coords: coords_batch,
                        self._dropout: 0
                    })
                    outputs.append(output_batch)
                return np.concatenate(outputs, axis=0)
            finally:
                if close_session:
                    session.close()

    def new_session(self, checkpoint=None):
        """Create a session for the training graph.

        Variables are restored from ``checkpoint`` (default: the best
        checkpoint) when files matching ``checkpoint + '*'`` exist; otherwise
        they are initialized with the graph's init op.
        """
        checkpoint = checkpoint or self.best_save_path()
        if len(glob(checkpoint + '*')) <= 0:
            checkpoint = None
        with self._graph.as_default():
            return self._session_manager.prepare_session(
                master='',
                init_op=self._init_op,
                saver=self._saver,
                checkpoint_filename_with_path=checkpoint,
                config=tf.ConfigProto(allow_soft_placement=True))

    def checkpoint_save_path(self):
        """Path prefix of the regular (latest) checkpoint."""
        return str(Path(self._logdir, 'checkpoint', 'model.ckpt'))

    def best_save_path(self):
        """Path prefix of the best-so-far checkpoint."""
        return str(Path(self._logdir, 'best', 'model.ckpt'))

    def make_runtime_graph(self, full=True, linear=False, constant=False, session=None):
        """Build a standalone, finalized inference graph from the values in ``session``.

        The graph evaluates a single example fed through the placeholders
        ``Input`` (feature vector) and ``Coords`` (one value per coordinate
        dimension). Depending on the flags it exposes ``Output``/``OutputCubic``
        (``full``), ``OutputLinear`` (``linear``) and ``OutputConstant``
        (``constant``). Model metadata constants are stored under ``Meta/``.

        Args:
            full: build the variant using ``SplineFunctionedVariables.frozen``.
            linear: falsy to skip, otherwise a scalar or per-dimension values
                passed to ``frozen_linear`` (presumably grid resolutions --
                confirm); capped for dimensions with at most 2 spline points.
            constant: like ``linear`` but for ``frozen_closest``; capped for
                dimensions with at most 1 spline point.
            session: session holding the variable values; defaults to one
                restored from the best checkpoint (closed afterwards).

        Returns:
            The finalized ``tf.Graph``.
        """
        # Make sure there are no unnecessary linear and constant evaluations
        if linear:
            linear = np.broadcast_to(linear, self.spline_points.shape)
            linear = [l if p > 2 else min(l, p) for l, p in zip(linear, self.spline_points)]
        if constant:
            constant = np.broadcast_to(constant, self.spline_points.shape)
            constant = [l if p > 1 else min(l, p) for l, p in zip(constant, self.spline_points)]
        runtime_graph = tf.Graph()
        close_session = False
        if session is None:
            session = self.new_session(self.best_save_path())
            close_session = True
        try:
            with runtime_graph.as_default():
                with tf.name_scope('Meta'):
                    self._make_metadata(full, linear, constant)
                input_rt = tf.placeholder(self._input.dtype,
                                          self._input.shape[1].value,
                                          name='Input')
                coords_rt = tf.placeholder(self._coords.dtype,
                                           [len(self._spline_points)],
                                           name='Coords')
                # Add a batch dimension of 1 so the training-time code paths apply.
                input_rt = tf.expand_dims(input_rt, 0)
                coords_rt = tf.expand_dims(coords_rt, 0)
                input_rt, coords_rt = self._preprocess_fn(input_rt, coords_rt)
                if full:
                    with tf.name_scope('GFNN'), self._spline_vars.frozen(coords_rt, session):
                        output, _, _ = self._make_net(input_rt)
                        output = output[0]
                    output = tf.identity(output, name='Output')
                    output = tf.identity(output, name='OutputCubic')
                if linear:
                    with tf.name_scope('GFNNLinear'), self._spline_vars.frozen_linear(coords_rt, session, linear):
                        output_linear, _, _ = self._make_net(input_rt)
                        output_linear = output_linear[0]
                    output_linear = tf.identity(output_linear, name='OutputLinear')
                if constant:
                    with tf.name_scope('GFNNConstant'), self._spline_vars.frozen_closest(coords_rt, session, constant):
                        output_constant, _, _ = self._make_net(input_rt)
                        output_constant = output_constant[0]
                    output_constant = tf.identity(output_constant, name='OutputConstant')
                runtime_graph.finalize()
        finally:
            if close_session:
                session.close()
        return runtime_graph

    def _configure_logger(self):
        """Attach a console handler and a ``training.log`` file handler to this model's logger."""
        self._logger.setLevel(logging.DEBUG)
        formatter = logging.Formatter('%(asctime)s - %(message)s')
        # Handlers go on this model's own logger (not the shared module logger),
        # so several Gfnn instances neither duplicate console output nor write
        # into each other's training.log.
        ch = OneLineLogHandler(sys.stdout)
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(formatter)
        self._logger.addHandler(ch)
        ch = logging.FileHandler(str(Path(self._logdir, 'training.log')))
        ch.setLevel(logging.INFO)
        ch.setFormatter(formatter)
        self._logger.addHandler(ch)

    @staticmethod
    def _immutable_numpy_array(array):
        """Return a read-only numpy copy of ``array``, or ``None`` if it is ``None``."""
        if array is not None:
            array = np.array(array)
            array.flags.writeable = False
        return array

    def _make_input_pipeline(self, input_size, output_size):
        """Create the TFRecord input pipeline and return the next-batch tensors.

        Creates the ``FilePattern``, ``BatchSize`` and ``ShuffleBufferSize``
        placeholders, a repeating shuffled training dataset and a single-pass
        test dataset (each with an initializable iterator), and a feedable
        iterator selected through the ``IteratorHandle`` placeholder.
        ``input_size`` and ``output_size`` are not used here.

        Returns:
            The ``get_next()`` output of the feedable iterator, i.e. the
            structure produced by the parse function.
        """
        # Input pipeline parameters
        self._file_pattern = tf.placeholder(tf.string, shape=(), name='FilePattern')
        self._batch_size = tf.placeholder(tf.int64, shape=(), name='BatchSize')
        self._shuffle_buffer_size = tf.placeholder(tf.int64, shape=(), name='ShuffleBufferSize')
        concurrency = os.cpu_count()
        # Records of all matching files, read in parallel with a deterministic interleave.
        dataset = (tf.data.Dataset.list_files(self._file_pattern)
                   .shuffle(1000)
                   .apply(tf.data.experimental.parallel_interleave(
                        tf.data.TFRecordDataset, cycle_length=concurrency, sloppy=False)))
        # Train dataset
        train_dataset = (dataset
                         .repeat()
                         .apply(tf.data.experimental.shuffle_and_repeat(self._shuffle_buffer_size))
                         .map(self._parse_example_fn, num_parallel_calls=concurrency)
                         .batch(self._batch_size)
                         .prefetch(1))
        train_iterator = train_dataset.make_initializable_iterator()
        self._train_iterator_init_op = train_iterator.initializer
        self._train_iterator_handle = train_iterator.string_handle()
        # Test iterator
        test_dataset = (dataset
                        .map(self._parse_example_fn, num_parallel_calls=concurrency)
                        .batch(self._batch_size)
                        .prefetch(1))
        test_iterator = test_dataset.make_initializable_iterator()
        self._test_iterator_init_op = test_iterator.initializer
        self._test_iterator_handle = test_iterator.string_handle()
        # Feedable iterator
        self._input_iterator_handle = tf.placeholder(tf.string, shape=(), name='IteratorHandle')
        input_iterator = tf.data.Iterator.from_string_handle(
            self._input_iterator_handle, train_dataset.output_types, train_dataset.output_shapes)
        return input_iterator.get_next()

    def _default_parse_example_fn(self, example_proto):
        """Parse a serialized example with float features ``x``, ``coords``, ``y`` and optional ``weight`` (default 1)."""
        features = {'x': tf.FixedLenFeature((self.input_size,), tf.float32),
                    'coords': tf.FixedLenFeature((len(self.spline_points),), tf.float32),
                    'y': tf.FixedLenFeature((self.output_size,), tf.float32),
                    'weight': tf.FixedLenFeature((), tf.float32, default_value=1.)}
        parsed_features = tf.parse_single_example(example_proto, features)
        return parsed_features['x'], parsed_features['coords'], parsed_features['y'], parsed_features['weight']

    def _make_net(self, input_, dropout=None, regularize=False):
        """Build the network on an already preprocessed input.

        Args:
            input_: batch of preprocessed inputs (before normalization).
            dropout: optional scalar drop-rate tensor; ``None`` builds no dropout ops.
            regularize: not consulted; the L2 regularizer is always attached and
                only counted by ``_make_loss`` when ``regularization > 0``.

        Returns:
            ``(output, output_unscaled, logits)``: the de-normalized,
            post-processed output; the activated output before de-normalization;
            and the pre-activation output.
        """
        # Input is already preprocessed
        if self._input_mean is not None:
            input_ -= self._input_mean
        if self._input_scale is not None:
            input_ /= self._input_scale
        input_ = self._preprocess_scaled_fn(input_)
        keep_prob = 1
        if dropout is not None:
            keep_prob = 1 - dropout
        # Insert a middle axis so each example is a 1 x features row vector,
        # to be multiplied with its own interpolated weights in _make_layer.
        y = tf.expand_dims(input_, 1)
        if dropout is not None:
            y = tf.nn.dropout(y, keep_prob)
        regularizer = lambda w: tf.reduce_sum(tf.square(w))
        size = input_.shape[-1].value
        for i, layer in enumerate(self._hidden_layers):
            with tf.variable_scope('Layer_{}'.format(i + 1)):
                w = self._spline_vars.get_variable(
                    'W', (size, layer), dtype=self._dtype,
                    initializer=tf.glorot_uniform_initializer(), regularizer=regularizer)
                w = tf.identity(w, name='W/Interpolated')
                b = None
                if self._use_bias:
                    b = self._spline_vars.get_variable('B', (1, layer), dtype=self._dtype,
                                                       initializer=tf.zeros_initializer())
                    b = tf.identity(b, name='B/Interpolated')
                y = self._make_layer(y, w, b)
                y = self._hidden_activation_fn(y)
                if dropout is not None:
                    y = tf.nn.dropout(y, keep_prob)
            size = layer
        # NOTE(review): tf.name_scope (unlike the variable_scope above) does not
        # prefix tf.get_variable names; if SplineFunctionedVariables relies on
        # variable scopes, the output W/B live at the enclosing scope. Kept as-is
        # for compatibility with existing checkpoints.
        with tf.name_scope('Output'):
            w = self._spline_vars.get_variable(
                'W', (size, self._output_size), dtype=self._dtype,
                initializer=tf.glorot_uniform_initializer(), regularizer=regularizer)
            w = tf.identity(w, name='W/Interpolated')
            b = None
            if self._use_bias:
                b = self._spline_vars.get_variable('B', (1, self._output_size), dtype=self._dtype,
                                                   initializer=tf.zeros_initializer())
                b = tf.identity(b, name='B/Interpolated')
            y = self._make_layer(y, w, b)
            y = tf.squeeze(y, 1)
            y_logits = y
            y = self._output_activation_fn(y)
            y_unscaled = y
            if self._output_scale is not None:
                y *= self._output_scale
            if self._output_mean is not None:
                y += self._output_mean
            y = self._postprocess_fn(y)
        return y, y_unscaled, y_logits

    def _make_layer(self, x, w, b=None):
        """Return ``x @ w (+ b)``.

        When the static leading dimension of ``x`` is 1 (as in the runtime
        graph, which expands a single example to a batch of 1), that dimension
        is squeezed from ``x``, ``w`` and ``b`` so the product is a plain 2-D
        matmul, and restored afterwards. This assumes ``w``/``b`` then also have
        a leading dimension of size 1.
        """
        squeeze = x.shape[0].value == 1
        if squeeze:
            x = tf.squeeze(x, 0)
            w = tf.squeeze(w, 0)
            b = tf.squeeze(b, 0) if b is not None else b
        y = x @ w
        if b is not None:
            y = y + b
        if squeeze:
            y = tf.expand_dims(y, 0)
        return y

    def _make_metadata(self, full, linear, constant):
        """Add constants describing the model configuration to the current graph."""
        tf.constant(self.spline_points, name='SplinePoints', dtype=tf.int32)
        tf.constant(self.spline_close, name='SplineClose', dtype=tf.bool)
        tf.constant(self.input_size, name='InputSize', dtype=tf.int32)
        tf.constant(self.output_size, name='OutputSize', dtype=tf.int32)
        tf.constant(self.hidden_layers, name='HiddenLayers', dtype=tf.int32)
        tf.constant(self.use_bias, name='UseBias', dtype=tf.bool)
        tf.constant(self.num_params, name='NumParameters', dtype=tf.int32)
        if full:
            tf.constant(self.spline_points, name='GridEvaluationsCubic', dtype=tf.int32)
        if linear:
            linear_grid_shape = np.broadcast_to(linear, self.spline_points.shape)
            tf.constant(linear_grid_shape, name='GridEvaluationsLinear', dtype=tf.int32)
        if constant:
            constant_grid_shape = np.broadcast_to(constant, self.spline_points.shape)
            tf.constant(constant_grid_shape, name='GridEvaluationsConstant', dtype=tf.int32)

    def _make_loss(self, label, output_unscaled, output_logits, weight):
        """Build the per-example and aggregated losses for one device.

        The label is normalized with ``output_mean``/``output_scale`` so that it
        is comparable with the network output before de-normalization.

        Returns:
            ``(loss, loss_agg)``: the per-example loss multiplied by ``weight``,
            and its batch mean plus the optional L2 regularization term.
        """
        if self._output_mean is not None:
            label -= self._output_mean
        if self._output_scale is not None:
            label /= self._output_scale
        loss = tf.identity(self._loss_fn(label, output_unscaled, output_logits), name='Loss')
        loss = tf.multiply(loss, weight, name='WeightedLoss')
        loss_agg = tf.reduce_mean(loss, name='BatchLoss')
        if self._regularization is not None and self._regularization > 0:
            # lambda / (2 * batch) * sum of squared weights.
            reg_factor = self._regularization / (2 * tf.cast(tf.shape(output_logits)[0], self._dtype))
            loss_agg = tf.add(loss_agg, reg_factor * tf.losses.get_regularization_loss(), name='TotalLoss')
        return loss, loss_agg

    def _make_metrics(self, loss, label, output, output_unscaled, output_logits, weight):
        """Build streaming metrics and their merged summary.

        Args:
            loss: per-example loss, already multiplied by the example weights.

        Returns:
            ``(metric_values, summary)``: a dict of metric value tensors that run
            their update ops first, and the merged scalar summary.

        Raises:
            ValueError: if a custom metric uses the reserved name ``'Loss'``.
        """
        label_unscaled = label
        if self._output_mean is not None:
            label_unscaled -= self._output_mean
        if self._output_scale is not None:
            label_unscaled /= self._output_scale
        # ``loss`` is already weighted (see _make_loss); weighting it again here
        # would count each weight twice. A plain mean matches the training loss.
        metrics = {'Loss': tf.metrics.mean(loss)}
        custom_metrics = self._metrics_fn(label, label_unscaled, output, output_unscaled, output_logits, weight)
        for k in custom_metrics:
            if k in metrics:
                raise ValueError('cannot use the name "{}" for metrics.'.format(k))
        metrics.update(custom_metrics)
        metric_values = {}
        metric_updates = []
        metric_summaries = []
        for name, (value, update) in metrics.items():
            metric_values[name] = value
            metric_updates.append(update)
            metric_summaries.append(tf.summary.scalar(name, value))
        # Reading a metric value (or the summary) first runs all update ops.
        with tf.control_dependencies([tf.group(metric_updates)]):
            for name in metric_values:
                metric_values[name] = tf.identity(metric_values[name])
            all_summaries = tf.summary.merge(metric_summaries)
        return metric_values, all_summaries

    def _make_train(self, grad_devices):
        """Average per-device gradients and build the Adam update op.

        Returns:
            ``(train_op, grads_and_vars)`` with the (averaged) gradient/variable pairs.
        """
        with tf.device(self._devices[0] if self._devices else None):
            if len(grad_devices) > 1:
                grads_avg = []
                for var_grads in zip(*grad_devices):
                    grads, variables = zip(*var_grads)
                    grads_avg.append((tf.reduce_mean(tf.stack(grads, axis=0), axis=0), variables[0]))
            else:
                grads_avg = grad_devices[0]
            return self._optimizer.apply_gradients(grads_avg, global_step=self._global_step), grads_avg

    def _default_loss_fn(self, label_unscaled, output_unscaled, output_logits):
        """Per-example sum of squared errors over the output dimension."""
        return tf.reduce_sum(tf.squared_difference(label_unscaled, output_unscaled), axis=1)

    def _evaluate_dataset(self, record_files, batch_size, session,
                          writer=None, global_step=None):
        """Evaluate a dataset, write its summary to ``writer`` (if given) and return the loss."""
        loss, _, summary = self.evaluate(record_files, batch_size, return_summary=True, session=session)
        if writer is not None:
            writer.add_summary(summary, global_step=global_step)
        return loss

    def _save_checkpoint(self, session):
        """Save the regular checkpoint and export ``gfnn_checkpoint.tf``."""
        with session.graph.as_default():
            self._saver.save(session, self.checkpoint_save_path(), write_meta_graph=False, write_state=True)
            # Save runtime model
            rt_graph = self.make_runtime_graph(session=session)
            tf.train.write_graph(rt_graph, self._logdir, 'gfnn_checkpoint.tf', as_text=False)

    def _save_best(self, session):
        """Save the best checkpoint and export ``gfnn_best.tf``."""
        with session.graph.as_default():
            self._saver.save(session, self.best_save_path(), write_meta_graph=False, write_state=False)
            # Save runtime model
            rt_graph = self.make_runtime_graph(session=session)
            tf.train.write_graph(rt_graph, self._logdir, 'gfnn_best.tf', as_text=False)

class OneLineLogHandler(logging.Handler):
    """Logging handler that keeps rewriting a single console line.

    Each record is written after a carriage return so that it overwrites the
    previous (last) line; a shorter message is padded with spaces to erase the
    leftovers of a longer previous line. If the formatted text contains
    newlines, every line after the first starts on a new line, and only the
    final one is overwritten by the next record. A message ending in ``'\\n'``
    therefore keeps its text on screen.
    """

    def __init__(self, stream=None, level=logging.NOTSET):
        """Create the handler.

        Args:
            stream: writable text stream; defaults to ``sys.stderr``.
            level: minimum level handled by this handler.
        """
        super().__init__(level)
        self._stream = stream or sys.stderr
        # Length of the text currently on the last (overwritable) line.
        self._last_line_size = 0

    def emit(self, record):
        """Write ``record``, overwriting the previous last line."""
        try:
            msg, *lines = self.format(record).split('\n')
            if self._last_line_size > 0:
                # Return to the start of the current line instead of starting a new one.
                self._stream.write('\r')
            if len(msg) < self._last_line_size:
                msg += " " * (self._last_line_size - len(msg))
            self._stream.write(msg)
            self._last_line_size = len(msg)
            for line in lines:
                self._stream.write('\n')
                self._stream.write(line)
                self._last_line_size = len(line)
            self._stream.flush()
        except RecursionError:
            raise
        except Exception:
            # logging.Handler contract: a failure to log must not propagate
            # into the code that emitted the record.
            self.handleError(record)
