from contextlib import contextmanager
import tensorflow as tf
import numpy as np

class SplineFunctionedVariables:

    def __init__(self, spline_points, closed, dtype=tf.float32):
        """Store the grid description used by every evaluation mode.

        Args:
            spline_points: Number of control points per axis (scalar or
                sequence). Converted to a 1-D ``int32`` array.
            closed: Flag, or per-axis flags, broadcast against
                ``spline_points`` and converted to booleans. Axes flagged as
                closed are indexed with wrap-around (modulo) by the weight
                helpers of this class; the others are clamped.
            dtype: Anything accepted by ``tf.dtypes.as_dtype``; used for the
                interpolation weights.

        Raises:
            ValueError: If any entry of ``spline_points`` is zero or negative.
        """
        all_points = np.atleast_1d(spline_points).astype(np.int32)
        all_points.flags.writeable = False
        self._spline_points_orig = all_points
        if np.any(all_points <= 0):
            raise ValueError('spline points must be positive.')

        # Axes with exactly one point need no interpolation; they are masked
        # out and only the remaining ("active") axes are kept below.
        active_mask = all_points != 1
        active_mask.flags.writeable = False
        self._dims_mask = active_mask

        active_points = all_points[active_mask]
        active_points.flags.writeable = False
        self._spline_points = active_points

        active_axes = np.arange(len(all_points))[active_mask]
        active_axes.flags.writeable = False
        self._spline_axes = active_axes

        closed_all = np.broadcast_to(closed, all_points.shape).astype(np.bool_)
        active_closed = closed_all[active_mask]
        active_closed.flags.writeable = False
        self._closed = active_closed

        self._spline_ndims = len(active_points)
        self._dtype = tf.dtypes.as_dtype(dtype)
        self._axis_weight_dims = None
        # Stack of evaluation callables pushed by the context managers; the
        # last entry is the active one.
        self._evaluation_mode = []

    @contextmanager
    def use_coordinates(self, coordinates):
        """Activate in-graph cubic spline evaluation at ``coordinates``.

        While the context is active, ``get_variable`` and ``get_value``
        interpolate their spline values with weights derived from
        ``coordinates``. Variables are created with ``reuse=False`` in the
        graph that owns ``coordinates``.

        Args:
            coordinates: Tensor of rank at most 2. After ``_ensure_2d`` its
                last axis is indexed with the original axis numbers of the
                non-singleton spline axes.

        Yields:
            None. The evaluation mode is removed again when the block exits,
            including when it exits through an exception.
        """
        coordinates = self._ensure_2d(coordinates)
        coordinates = tf.gather(coordinates, self._spline_axes, axis=-1)
        weights, batch_size = self._spline_weights(coordinates)

        def eval_spline(spline_value):
            with coordinates.graph.as_default():
                return self._evaluate_spline(spline_value, weights, batch_size)

        # get_variable() reads these attributes from the active mode.
        eval_spline.graph = coordinates.graph
        eval_spline.reuse = False
        self._evaluation_mode.append(eval_spline)
        try:
            yield
        finally:
            # Always unwind the stack; otherwise an exception raised inside
            # the with-block would leave this mode active for later calls.
            self._evaluation_mode.pop()

    @contextmanager
    def frozen(self, coordinates, session, feed_dict=None):
        """Activate cubic evaluation of spline values fetched from ``session``.

        While the context is active, each spline value is first evaluated with
        ``session.run`` and the resulting array is then interpolated, in the
        graph that owns ``coordinates``, with the cubic weights for
        ``coordinates``. Variables are looked up with ``reuse=True`` in
        ``session.graph``.

        Args:
            coordinates: Tensor of rank at most 2 (see ``use_coordinates``).
            session: Session used to fetch the current spline values.
            feed_dict: Optional feed dictionary passed to ``session.run``.

        Yields:
            None. The evaluation mode is removed again when the block exits,
            including when it exits through an exception.
        """
        feed_dict = feed_dict or {}
        coordinates = self._ensure_2d(coordinates)
        coordinates = tf.gather(coordinates, self._spline_axes, axis=-1)
        weights, batch_size = self._spline_weights(coordinates)

        def eval_cubic_value(spline_value):
            value = session.run(spline_value, feed_dict=feed_dict)
            with coordinates.graph.as_default():
                return self._evaluate_spline(value, weights, batch_size)

        # get_variable() reads these attributes from the active mode.
        eval_cubic_value.graph = session.graph
        eval_cubic_value.reuse = True
        self._evaluation_mode.append(eval_cubic_value)
        try:
            yield
        finally:
            # Always unwind the stack; otherwise an exception raised inside
            # the with-block would leave this mode active for later calls.
            self._evaluation_mode.pop()

    @contextmanager
    def frozen_linear(self, coordinates, session, evaluations, feed_dict=None):
        """Activate linear interpolation over a pre-sampled spline grid.

        While the context is active, each spline value is fetched with
        ``session.run``, sampled with the cubic spline on a regular grid of
        ``evaluations`` points per axis (in a private graph and session), and
        the samples are then linearly interpolated at ``coordinates`` in the
        graph that owns ``coordinates``. Variables are looked up with
        ``reuse=True`` in ``session.graph``.

        Args:
            coordinates: Tensor of rank at most 2.
            session: Session used to fetch the current spline values.
            evaluations: Number of grid samples per axis; broadcast against
                the original spline point counts.
            feed_dict: Optional feed dictionary passed to ``session.run``.

        Yields:
            None. On exit the evaluation mode is removed and the private
            session is closed, including when the block raises.
        """
        feed_dict = feed_dict or {}
        coordinates = self._ensure_2d(coordinates)
        evaluations = np.broadcast_to(evaluations, self._spline_points_orig.shape)
        evaluations = evaluations[self._dims_mask]
        with coordinates.graph.as_default():
            k_coords, t_rel_dims = self._linear_weights(coordinates, evaluations)
        # Sampling ops live in their own graph so they do not pollute the
        # caller's graph.
        eval_graph = tf.Graph()
        eval_session = tf.Session(graph=eval_graph)
        try:
            with eval_graph.as_default():
                eval_coordinates = self._make_evaluation_coordinates(evaluations)
                eval_weights, batch_size = self._spline_weights(eval_coordinates)

            def eval_linear_value(spline_value):
                value = session.run(spline_value, feed_dict=feed_dict)
                with eval_graph.as_default():
                    linear_value = eval_session.run(self._evaluate_spline(value, eval_weights, batch_size))
                if evaluations.size > 0:
                    # Unflatten the sample axis back into one axis per spline axis.
                    new_shape = np.r_[evaluations, linear_value.shape[1:]].astype(int)
                    linear_value = np.reshape(linear_value, new_shape)
                with coordinates.graph.as_default():
                    return self._evaluate_linear(linear_value, k_coords, t_rel_dims)

            # get_variable() reads these attributes from the active mode.
            eval_linear_value.graph = session.graph
            eval_linear_value.reuse = True
            self._evaluation_mode.append(eval_linear_value)
            try:
                yield
            finally:
                # Unwind the stack even if the with-block raised.
                self._evaluation_mode.pop()
        finally:
            eval_session.close()

    @contextmanager
    def frozen_closest(self, coordinates, session, evaluations, feed_dict=None):
        """Activate nearest-sample lookup over a pre-sampled spline grid.

        While the context is active, each spline value is fetched with
        ``session.run``, sampled with the cubic spline on a regular grid of
        ``evaluations`` points per axis (in a private graph and session), and
        the sample closest to each coordinate is then gathered in the graph
        that owns ``coordinates``. Variables are looked up with ``reuse=True``
        in ``session.graph``.

        Args:
            coordinates: Tensor of rank at most 2.
            session: Session used to fetch the current spline values.
            evaluations: Number of grid samples per axis; broadcast against
                the original spline point counts.
            feed_dict: Optional feed dictionary passed to ``session.run``.

        Yields:
            None. On exit the evaluation mode is removed and the private
            session is closed, including when the block raises.
        """
        feed_dict = feed_dict or {}
        coordinates = self._ensure_2d(coordinates)
        evaluations = np.broadcast_to(evaluations, self._spline_points_orig.shape)
        evaluations = evaluations[self._dims_mask]
        with coordinates.graph.as_default():
            index = self._closest_weights(coordinates, evaluations)
        # Sampling ops live in their own graph so they do not pollute the
        # caller's graph.
        eval_graph = tf.Graph()
        eval_session = tf.Session(graph=eval_graph)
        try:
            with eval_graph.as_default():
                eval_coordinates = self._make_evaluation_coordinates(evaluations)
                eval_weights, batch_size = self._spline_weights(eval_coordinates)

            def eval_closest_value(spline_value):
                value = session.run(spline_value, feed_dict=feed_dict)
                with eval_graph.as_default():
                    closest_value = eval_session.run(self._evaluate_spline(value, eval_weights, batch_size))
                if evaluations.size > 0:
                    # Unflatten the sample axis back into one axis per spline axis.
                    new_shape = np.r_[evaluations, closest_value.shape[1:]].astype(int)
                    closest_value = np.reshape(closest_value, new_shape)
                with coordinates.graph.as_default():
                    return self._evaluate_closest(closest_value, index)

            # get_variable() reads these attributes from the active mode.
            eval_closest_value.graph = session.graph
            eval_closest_value.reuse = True
            self._evaluation_mode.append(eval_closest_value)
            try:
                yield
            finally:
                # Unwind the stack even if the with-block raised.
                self._evaluation_mode.pop()
        finally:
            eval_session.close()

    def get_variable(self, name, shape, dtype=None, initializer=None, regularizer=None, **kwargs):
        """Create or fetch a spline-parameterised variable and evaluate it.

        The underlying variable has the active spline point counts prepended
        to ``shape``. It is created in the graph of the active evaluation
        mode, under the current variable scope, using the mode's ``reuse``
        flag, and is then passed through the active evaluation callable.

        Args:
            name: Variable name forwarded to ``tf.get_variable``.
            shape: Shape of one evaluated value, without the spline axes.
            dtype: Forwarded to ``tf.get_variable``.
            initializer: Forwarded to ``tf.get_variable``.
            regularizer: Optional callable applied to the evaluated value.
                Its result, divided by the size of the first axis of the
                evaluated value, is added to the regularization losses. Only
                used when the active mode does not reuse variables.
            **kwargs: Extra keyword arguments for ``tf.get_variable``.

        Returns:
            The result of the active evaluation mode applied to the variable.

        Raises:
            RuntimeError: If no evaluation mode is active.
        """
        if not self._evaluation_mode:
            raise RuntimeError('no active evaluation mode.')

        full_shape = list(self._spline_points)
        full_shape.extend([int(extent) for extent in shape])
        mode = self._evaluation_mode[-1]
        scope = tf.get_variable_scope()

        with mode.graph.as_default(), tf.variable_scope(scope, reuse=mode.reuse):
            raw_variable = tf.get_variable(
                name,
                shape=full_shape,
                dtype=dtype,
                initializer=initializer,
                **kwargs)
        evaluated = mode(raw_variable)

        wants_penalty = regularizer is not None and not mode.reuse
        if wants_penalty:
            with mode.graph.as_default(), tf.variable_scope(scope):
                penalty = regularizer(evaluated)
                # Normalise by the size of the leading axis of the evaluation.
                leading = tf.cast(tf.shape(evaluated)[0], evaluated.dtype)
                tf.losses.add_loss(penalty / leading,
                                   tf.GraphKeys.REGULARIZATION_LOSSES)
        return evaluated

    def get_value(self, x):
        """Evaluate an explicit spline value with the active evaluation mode.

        Args:
            x: Value convertible to a tensor whose leading axes follow the
                original spline point counts, singleton axes included.

        Returns:
            The result of the active evaluation mode applied to ``x`` after
            its singleton spline axes have been squeezed out.

        Raises:
            RuntimeError: If no evaluation mode is active.
        """
        if not self._evaluation_mode:
            raise RuntimeError('no active evaluation mode.')
        x = tf.convert_to_tensor(x)
        # Walk the axes back to front so that removing one axis does not
        # shift the indices of the axes still to be visited.
        for axis in reversed(range(len(self._dims_mask))):
            if self._dims_mask[axis]:
                continue
            x = tf.squeeze(x, axis=axis)
        active_mode = self._evaluation_mode[-1]
        return active_mode(x)

    def _spline_weights(self, coordinates):
        """Build per-axis cubic (Catmull-Rom) interpolation weights.

        Args:
            coordinates: Value convertible to a 2-D tensor with one column per
                active spline axis. Each column is scaled by ``n`` (closed
                axis) or ``n - 1`` (open axis) to obtain a grid position, so
                it is presumably normalised to [0, 1] -- confirm with callers.

        Returns:
            A pair ``(axis_weight_dims, batch_size)``.
            ``axis_weight_dims`` holds one ``(idx, [ws])`` entry per axis:
            ``idx`` is either ``None`` or a vector of four control point
            indices to gather, and ``ws`` is the weight tensor reshaped so
            that it broadcasts along that axis. ``batch_size`` is the dynamic
            size of the first coordinate axis, or ``None`` when that size is
            statically 1.
        """
        axis_weight_dims = []
        dtype = self._dtype
        coordinates = tf.convert_to_tensor(coordinates)
        batch_size = tf.shape(coordinates)[0] if coordinates.shape[0].value != 1 else None
        for i_axis, (n, close) in enumerate(zip(self._spline_points, self._closed)):
            if n == 1:
                axis_weight_dims.append((None, None))
                continue
            coord = coordinates[:, i_axis]
            # Find indices of interpolated points.
            # The cell index is taken from floor(q), the same quantity used
            # for the fractional part below. A plain float->int cast truncates
            # toward zero, which for negative positions would pair the
            # fraction with the wrong cell.
            if close:
                q = tf.cast(n, dtype) * coord
                q_floor = tf.floor(q)
                k = tf.cast(q_floor, tf.int32)
                # Closed axis: neighbours wrap around.
                k1 = k % n
                k0 = (k - 1) % n
                k2 = (k + 1) % n
                k3 = (k + 2) % n
            else:
                q = tf.cast(n - 1, dtype) * coord
                q_floor = tf.floor(q)
                k = tf.cast(q_floor, tf.int32)
                # Open axis: neighbours are clamped to the valid index range.
                k1 = tf.minimum(tf.maximum(    k, 0), n - 1)
                k0 = tf.minimum(tf.maximum(k - 1, 0), n - 1)
                k2 = tf.minimum(tf.maximum(k + 1, 0), n - 1)
                k3 = tf.minimum(tf.maximum(k + 2, 0), n - 1)
            ks = tf.stack([k0, k1, k2, k3], axis=-1)
            # Compute weights (Catmull-Rom polynomials in the fraction alpha1)
            alpha1 = q - q_floor
            alpha0 = tf.ones_like(alpha1)
            alpha2 = tf.square(alpha1)
            alpha3 = alpha2 * alpha1
            alpha1_05 = 0.5 * alpha1
            alpha2_05 = 0.5 * alpha2
            alpha2_20 = alpha2 + alpha2
            alpha2_25 = 2.5 * alpha2
            alpha3_05 = 0.5 * alpha3
            alpha3_15 = 1.5 * alpha3
            w0 = alpha2 - alpha1_05 - alpha3_05
            w1 = alpha0 - alpha2_25 + alpha3_15
            w2 = alpha1_05 + alpha2_20 - alpha3_15
            w3 = alpha3_05 - alpha2_05
            ws = tf.stack([w0, w1, w2, w3], axis=-1)
            if ks.shape[0].value == 1 and n >= 4:
                # Optimization for big grids at runtime
                # Use indices to select subset of parameters
                idx = tf.squeeze(ks, 0)
            else:
                # Default case
                # All parameters are multiplied by their corresponding weight
                idx = None
                ks_1h = tf.one_hot(ks, n, dtype=dtype)
                ws = tf.reduce_sum(tf.expand_dims(ws, axis=-1) * ks_1h, axis=1)
            # Add additional dimensions at beginning and end
            ws = tf.reshape(ws, [-1] + [1] * i_axis + [ws.shape[1].value] + [1] * (self._spline_ndims - i_axis - 1))
            # The list is extended lazily by _evaluate_spline with copies that
            # carry extra trailing unit axes.
            axis_weight_dims.append((idx, [ws]))
        return axis_weight_dims, batch_size

    def _evaluate_spline(self, value, axis_weight_dims, batch_size):
        """Contract the leading spline axes of ``value`` with cubic weights.

        Args:
            value: Spline parameters whose leading axes are the active spline
                axes, followed by any number of payload axes.
            axis_weight_dims: Per-axis ``(idx, weights)`` entries as returned
                by ``_spline_weights``. The weight lists are extended in
                place with broadcast-ready copies.
            batch_size: Scalar tensor with the requested batch size, or
                ``None`` when no tiling is needed.

        Returns:
            Tensor with a leading batch axis followed by the payload axes.
        """
        # Axes that come with an index vector are narrowed to those entries.
        for axis, entry in enumerate(axis_weight_dims):
            selector = entry[0]
            if selector is None:
                continue
            value = tf.gather(value, selector, axis=axis)

        # Axes with more points are reduced first; ties keep their order.
        reduction_order = sorted(
            enumerate(axis_weight_dims),
            key=lambda item: self._spline_points[item[0]],
            reverse=True)

        # Leading axis of size one that broadcasts against the weight batch.
        value = tf.expand_dims(value, 0)
        trailing_rank = value.shape.ndims - self._spline_ndims
        for axis, (_, weight_cache) in reduction_order:
            if weight_cache is None:
                continue
            # weight_cache[j] is the base weight tensor with j extra trailing
            # unit axes; grow the cache until the needed rank is available.
            while len(weight_cache) < trailing_rank:
                weight_cache.append(tf.expand_dims(weight_cache[-1], axis=-1))
            weighted = value * weight_cache[trailing_rank - 1]
            value = tf.reduce_sum(weighted, axis=1 + axis, keepdims=True)

        # Drop the (now size one) spline axes, keeping batch and payload axes.
        dyn_shape = tf.shape(value)
        batch_part = [dyn_shape[0]]
        payload_part = dyn_shape[1 + self._spline_ndims:]
        value = tf.reshape(value, tf.concat([batch_part, payload_part], axis=0))

        # Repeat a single result across the batch when one was requested.
        if value.shape[0].value == 1 and batch_size is not None:
            unit_multiples = tf.ones(tf.rank(value) - 1, dtype=batch_size.dtype)
            value = tf.tile(value, tf.concat([[batch_size], unit_multiples], axis=0))
        return value

    def _linear_weights(self, coordinates, linear_points):
        """Build corner indices and weights for multilinear interpolation.

        Args:
            coordinates: 2-D tensor whose last axis is indexed with the
                original spline axis numbers.
            linear_points: 1-D NumPy integer array with the number of grid
                samples per active spline axis.

        Returns:
            When there are no active spline axes: ``(batch_size, None)``,
            where ``batch_size`` is the dynamic batch size or ``None`` if it
            is statically 1. Otherwise ``(k_coords, [t_rel])``: ``k_coords``
            holds, per batch entry, the grid indices of the ``2 ** m`` corners
            of the enclosing cell (``m`` = number of axes with more than one
            sample), and ``t_rel`` holds the matching corner weights.
        """
        if self._spline_ndims <= 0:
            batch_size = tf.shape(coordinates)[0] if coordinates.shape[0].value != 1 else None
            return batch_size, None
        b = tf.shape(coordinates)[0]
        # "Filter" singleton dimensions
        idx = [i for i, d in enumerate(linear_points) if d > 1]
        n = linear_points[idx]
        close = self._closed[idx]
        coords = tf.gather(coordinates, self._spline_axes[idx], axis=-1)
        # Find interpolation indices
        t = tf.where(close, n, n - 1)
        t = tf.cast(t, self._dtype) * coords
        # Use floor rather than a plain cast: the cast truncates toward zero,
        # which for negative positions disagrees with the floor-based
        # fraction `t % 1` below and selects the wrong cell.
        k = tf.cast(tf.floor(t), tf.int32)
        close_k = tf.tile(tf.expand_dims(close, 0), [b, 1])
        # Closed axes wrap around, open axes are clamped to the valid range.
        k_ini = tf.where(close_k, k % n, tf.minimum(tf.maximum(k, 0), n - 1))
        k_end = tf.where(close_k, (k + 1) % n, tf.minimum(tf.maximum(k + 1, 0), n - 1))
        # Interpolation factor
        t_rel_end = t % 1
        t_rel_ini = 1 - t_rel_end
        # Interpolation hypercube indices and factors
        # Bit j of the corner number decides whether axis j uses the start
        # (bit set) or the end (bit clear) of the cell.
        r = 1 << len(n)
        d_pow = tf.bitwise.left_shift(1, tf.range(len(n)))
        ind_ini = tf.bitwise.bitwise_and(tf.expand_dims(tf.range(r), 1), d_pow)
        ind_ini = tf.cast(ind_ini, tf.bool)
        ind_ini = tf.tile(tf.expand_dims(ind_ini, 0), (b, 1, 1))
        k_coords = tf.where(ind_ini,
                            tf.tile(tf.expand_dims(k_ini, 1), (1, r, 1)),
                            tf.tile(tf.expand_dims(k_end, 1), (1, r, 1)))
        t_rel = tf.where(ind_ini,
                         tf.tile(tf.expand_dims(t_rel_ini, 1), (1, r, 1)),
                         tf.tile(tf.expand_dims(t_rel_end, 1), (1, r, 1)))
        # The weight of a corner is the product of its per-axis factors.
        t_rel = tf.reduce_prod(t_rel, axis=-1)
        return k_coords, [t_rel]

    def _evaluate_linear(self, value, k_coords, t_rel_dims):
        """Linearly interpolate sampled values with precomputed weights.

        Args:
            value: NumPy array or tensor of grid samples whose leading axes
                are the active spline axes.
            k_coords: Corner indices from ``_linear_weights``; when there are
                no active spline axes this slot carries the batch size (a
                scalar tensor or ``None``) instead.
            t_rel_dims: List holding the corner weights from
                ``_linear_weights``; extended in place with copies that have
                extra trailing unit axes.

        Returns:
            Tensor of interpolated values with a leading batch axis. With no
            active spline axes and no batch size, ``value`` is returned as is.
        """
        if self._spline_ndims <= 0:
            batch_size = k_coords
            if batch_size is None:
                return value
            # NumPy arrays expose `shape` as a plain tuple without `.ndims`,
            # so the rank has to be read differently for the two input kinds.
            value_rank = value.ndim if isinstance(value, np.ndarray) else value.shape.ndims
            # NOTE(review): values coming from frozen_linear appear to already
            # have a leading axis of length 1, in which case this result has
            # one more axis than the un-tiled branch above -- confirm intent.
            return tf.tile(tf.expand_dims(value, 0), [batch_size] + ([1] * value_rank))
        # Squeeze singleton dimensions
        for i in reversed(range(self._spline_ndims)):
            if value.shape[i] is not None and int(value.shape[i]) == 1:
                if isinstance(value, np.ndarray):
                    value = np.squeeze(value, axis=i)
                else:
                    value = tf.squeeze(value, axis=i)
        # Look up every cell corner for every batch entry.
        value_k = tf.gather_nd(value, k_coords)
        # Number of payload axes after the batch and corner axes.
        value_ndim = value_k.shape.ndims - 2
        # t_rel_dims[j] carries j trailing unit axes so it broadcasts against
        # values with j payload axes.
        while value_ndim >= len(t_rel_dims):
            t_rel_dims.append(tf.expand_dims(t_rel_dims[-1], -1))
        return tf.reduce_sum(t_rel_dims[value_ndim] * value_k, axis=1)

    def _closest_weights(self, coordinates, closest_points):
        """Compute the index of the grid sample nearest to each coordinate.

        Args:
            coordinates: 2-D tensor whose last axis is indexed with the
                original spline axis numbers.
            closest_points: 1-D NumPy integer array with the number of grid
                samples per active spline axis.

        Returns:
            When there are no active spline axes: the dynamic batch size, or
            ``None`` if it is statically 1. Otherwise an int32 tensor with one
            row of sample indices per batch entry, covering only the axes
            that have more than one sample.
        """
        if self._spline_ndims <= 0:
            if coordinates.shape[0].value == 1:
                return None
            return tf.shape(coordinates)[0]

        batch = tf.shape(coordinates)[0]
        # Axes sampled at a single point contribute no index.
        kept = np.flatnonzero(closest_points > 1)
        counts = closest_points[kept]
        wraps = self._closed[kept]
        picked = tf.gather(coordinates, self._spline_axes[kept], axis=-1)

        # Scale to grid units (n cells when closed, n - 1 when open) and
        # round to the nearest sample.
        scale = tf.where(wraps, counts, counts - 1)
        position = tf.cast(scale, self._dtype) * picked
        nearest = tf.cast(tf.round(position), tf.int32)

        # Closed axes wrap around, open axes are clamped to the valid range.
        wraps_per_row = tf.tile(tf.expand_dims(wraps, 0), [batch, 1])
        wrapped = nearest % counts
        clamped = tf.minimum(tf.maximum(nearest, 0), counts - 1)
        return tf.where(wraps_per_row, wrapped, clamped)

    def _evaluate_closest(self, value, index):
        """Pick the sampled value nearest to each coordinate.

        Args:
            value: NumPy array or tensor of grid samples whose leading axes
                are the active spline axes.
            index: Sample indices from ``_closest_weights``; when there are
                no active spline axes this slot carries the batch size (a
                scalar tensor or ``None``) instead.

        Returns:
            Tensor of gathered values with a leading batch axis. With no
            active spline axes and no batch size, ``value`` is returned as is.
        """
        if self._spline_ndims <= 0:
            bs = index
            if bs is None:
                return value
            # NumPy arrays expose `shape` as a plain tuple without `.ndims`,
            # so the rank has to be read differently for the two input kinds.
            value_rank = value.ndim if isinstance(value, np.ndarray) else value.shape.ndims
            # NOTE(review): values coming from frozen_closest appear to already
            # have a leading axis of length 1, in which case this result has
            # one more axis than the un-tiled branch above -- confirm intent.
            return tf.tile(tf.expand_dims(value, 0), [bs] + ([1] * value_rank))
        # Squeeze singleton dimensions
        for i in reversed(range(self._spline_ndims)):
            if value.shape[i] is not None and int(value.shape[i]) == 1:
                if isinstance(value, np.ndarray):
                    value = np.squeeze(value, axis=i)
                else:
                    value = tf.squeeze(value, axis=i)
        return tf.gather_nd(value, index)

    def _make_evaluation_coordinates(self, evaluations):
        """Build the flattened regular sampling grid for the frozen modes.

        Args:
            evaluations: Number of samples per active spline axis.

        Returns:
            2-D NumPy array with one row per grid point (axes vary in 'ij'
            order) and one column per axis. Closed axes are sampled at
            ``i / n`` for ``i`` in ``range(n)``; open axes at ``n`` evenly
            spaced points from 0 to 1 inclusive. With no axes an empty
            ``int32`` array of shape ``(1, 0)`` is returned.
        """
        np_dtype = self._dtype.as_numpy_dtype

        def axis_samples(count, wraps):
            # A closed axis omits the end point, which coincides with 0.
            if wraps:
                return np.arange(count, dtype=np_dtype) / count
            return np.asarray(np.linspace(0, 1, count), dtype=np_dtype)

        per_axis = [axis_samples(count, wraps)
                    for count, wraps in zip(evaluations, self._closed)]
        mesh = np.meshgrid(*per_axis, indexing='ij')
        if not mesh:
            return np.empty([1, 0], np.int32)
        columns = [component.ravel() for component in mesh]
        return np.stack(columns, axis=1)

    @staticmethod
    def _ensure_2d(coordinates):
        """Return ``coordinates`` expanded to rank 2.

        Args:
            coordinates: Tensor with statically known rank 0, 1 or 2.

        Returns:
            The input with trailing unit axes appended until it has rank 2:
            a scalar becomes shape ``[1, 1]``, a vector of length ``B``
            becomes ``[B, 1]``, and a rank-2 tensor is returned unchanged.

        Raises:
            ValueError: If the rank is unknown or greater than 2.
        """
        d = coordinates.shape.ndims
        if d is None or d < 0 or d > 2:
            raise ValueError('coordinates must be broadcastable to rank 2.')
        for _ in range(d, 2):
            # Append at the end: axis -1 is valid for every rank, whereas
            # axis 1 is out of range for a scalar. For rank 1 both give the
            # same [B, 1] result.
            coordinates = tf.expand_dims(coordinates, -1)
        return coordinates
