import tensorflow as tf
import numpy as np
from scipy.spatial import KDTree
from scipy.stats import gaussian_kde

# Half-open column range [start, end) of the input feature vector `x` that
# get_angvel() slices out and reshapes into a sequence of 2-D future
# positions: (end - start) // 2 points with two values per point.
x_future_pos_start = 12
x_future_pos_end = 24

def get_data(record):
    """Load every example matching a TFRecord file pattern into NumPy arrays.

    The matching files are parsed with `parse_example`, read in batches of
    10000 examples inside a private graph and session, and the batches are
    concatenated along the first axis.

    Args:
        record: File pattern (glob) of the TFRecord files to read; it is
            converted with `str()`, so path-like objects are accepted.

    Returns:
        A list with one NumPy array per component produced by
        `parse_example`, in the same order. The list is empty if the dataset
        yields no batch at all.
    """
    # `os` is not imported at the top of this file; import it here so that
    # os.cpu_count() below does not raise NameError.
    import os

    with tf.Graph().as_default(), tf.Session() as sess:
        # list_files shuffles the file names by default. Disable that so the
        # order of the returned examples is the same on every run; otherwise
        # index-based sampling with a fixed random seed is not reproducible.
        ds = (tf.data.Dataset.list_files(str(record), shuffle=False)
              .flat_map(tf.data.TFRecordDataset)
              .map(parse_example, num_parallel_calls=os.cpu_count())
              .batch(10000)
              .prefetch(1))
        it = ds.make_one_shot_iterator()
        it_next = it.get_next()
        data = []
        while True:
            try:
                data.append(sess.run(it_next))
            except tf.errors.OutOfRangeError:
                # Raised once the iterator has been exhausted.
                break
        # `data` is a list of per-batch tuples; transpose it into one
        # sequence of batches per component and join each along axis 0.
        return [np.concatenate(x) for x in zip(*data)]

def get_angvel(x):
    """Compute an angle and a path length from the future positions in `x`.

    Columns `x_future_pos_start:x_future_pos_end` of the last axis of `x` are
    interpreted as a flattened sequence of 2-D points; all leading axes are
    collapsed into a single example axis.

    Args:
        x: Array whose last axis contains the future position columns.

    Returns:
        A tuple `(ang, vel)` of 1-D arrays with one entry per example:
        `ang` is `arctan2(second, first)` of the two values of the last
        point, and `vel` is the sum of the Euclidean distances between
        consecutive points.
    """
    n_points = (x_future_pos_end - x_future_pos_start) // 2
    flat_positions = x[..., x_future_pos_start:x_future_pos_end]
    trajectory = np.reshape(flat_positions, (-1, n_points, 2))

    # Angle of the final trajectory point, measured from the origin.
    last_point = trajectory[..., -1, :]
    ang = np.arctan2(last_point[..., 1], last_point[..., 0])

    # Total length of the polyline through all trajectory points.
    segments = np.diff(trajectory, axis=-2)
    vel = np.linalg.norm(segments, axis=-1).sum(axis=-1)
    return ang, vel

def get_coords(x, max_vel=250):
    """Map each example to a 2-D coordinate in the unit square.

    The first coordinate is the angle returned by `get_angvel`, rescaled
    linearly from [-pi, pi] to [0, 1]. The second is the path length returned
    by `get_angvel` divided by `max_vel`. Both are clipped to [0, 1].

    Args:
        x: Array of input feature vectors, as accepted by `get_angvel`.
        max_vel: Path length that maps to coordinate 1; larger values are
            clipped. Defaults to 250, the value previously hard-coded here.

    Returns:
        Array of shape `(n_examples, 2)` holding `(coord_x, coord_y)`.
    """
    ang, vel = get_angvel(x)
    # Use the exact constant rather than a rounded literal so that the
    # extreme angles -pi and pi map precisely onto 0 and 1.
    coord_x = (ang / (2 * np.pi) + 0.5).clip(0, 1)
    coord_y = (vel / max_vel).clip(0, 1)
    return np.stack([coord_x, coord_y], axis=-1)

def parse_example(example_proto):
    """Parse one serialized `tf.train.Example` into its four float tensors.

    NOTE(review): `input_size`, `phase_size` and `output_size` are read from
    module scope but are not defined in the visible part of this file;
    confirm where they are set before calling this function.

    Args:
        example_proto: Scalar string tensor holding a serialized example.

    Returns:
        A tuple `(x, phase, y, weight)` of float32 tensors. `weight` is a
        scalar that defaults to 1.0 when the record does not contain it.
    """
    # All vector shapes are written as one-element tuples; `(phase_size)`
    # without the trailing comma is just a parenthesised int.
    features = {'x': tf.FixedLenFeature((input_size,), tf.float32),
                'phase': tf.FixedLenFeature((phase_size,), tf.float32),
                'y': tf.FixedLenFeature((output_size,), tf.float32),
                'weight': tf.FixedLenFeature((), tf.float32, default_value=1.)}
    parsed_features = tf.parse_single_example(example_proto, features)
    return parsed_features['x'], parsed_features['phase'], parsed_features['y'], parsed_features['weight']

def write_records(features, output_path):
    """Write parallel feature sequences to a single TFRecord file.

    Example k of the output is built from element k of every sequence in
    `features`; each element is encoded as a float list via `float_feature`.
    Because the sequences are combined with `zip`, writing stops at the end
    of the shortest one.

    Args:
        features: Mapping from feature name to a sequence of per-example
            values. Must contain at least one entry.
        output_path: Path of the TFRecord file to create. Missing parent
            directories are created.

    Raises:
        ValueError: If the parent of `output_path` exists but is not a
            directory.
    """
    import tensorflow as tf
    # `Path` is not imported at the top of this file; import it here so the
    # function does not raise NameError.
    from pathlib import Path

    output_path = Path(output_path)
    output_dir = output_path.parent
    if output_dir.exists() and not output_dir.is_dir():
        raise ValueError(f'Invalid directory {output_dir}.')
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_path.resolve()
    # Split the mapping into parallel tuples of names and value sequences.
    names, values = zip(*features.items())
    # Write records
    with tf.io.TFRecordWriter(str(output_path)) as writer:
        # Each iteration receives one value per feature for a single example.
        for example in zip(*values):
            example = tf.train.Example(
                features=tf.train.Features(
                    feature=dict(zip(names, map(float_feature, example)))))
            writer.write(example.SerializeToString())

def float_feature(value):
    """Wrap a scalar or array of floats in a `tf.train.Feature`.

    Args:
        value: A number or array-like of numbers.

    Returns:
        A `tf.train.Feature` whose `float_list` holds the given value(s).
    """
    import tensorflow as tf
    # Promote scalars to one-element arrays so a sequence is always passed on.
    values = np.atleast_1d(value)
    float_list = tf.train.FloatList(value=values)
    return tf.train.Feature(float_list=float_list)

# Load all training examples. The four arrays are the components returned by
# parse_example, in order: 'x', 'phase', 'y' and 'weight'.
xs, ps, ys, ws = get_data(r'data\records\train\*.tfrecord')
# One (coord_x, coord_y) point per example, each coordinate clipped to [0, 1].
coords = get_coords(xs)
# Number of examples to draw when resampling below.
n = 30000

## Method 1 - Nearest neighbours across domain
#kdtree = KDTree(coords)
#np.random.seed(0)
#r = np.random.rand(n, 2)
#_, idx = kdtree.query(r)
#xs2 = xs[idx]
#ps2 = ps[idx]
#ys2 = ys[idx]
#ws2 = ws[idx]
#coords2 = get_coords(xs2)

# Method 2 - Sample with inverse of KDE probability density
# Examples lying in sparsely populated regions of the coordinate space get a
# proportionally higher chance of being drawn.
np.random.seed(0)
# gaussian_kde expects data of shape (n_dims, n_points), hence the transpose.
kde = gaussian_kde(coords.T)
p = kde(coords.T)
# Invert the estimated density and normalize it into sampling probabilities.
p1 = (1 / p)
p1 /= p1.sum()
idx = np.random.choice(len(coords), size=n, p=p1, replace=True)
xs3 = xs[idx]
ps3 = ps[idx]
ys3 = ys[idx]
ws3 = ws[idx]
coords3 = get_coords(xs3)
# Report how often each source example was drawn.
idx_uniq, idx_counts = np.unique(idx, return_counts=True)
# Log-spaced integer bin edges from 1 to max + 1. Rounding maps several edges
# to the same integer when the maximum count is small; drop the duplicates so
# the histogram contains no empty zero-width bins.
bins = np.unique(np.round(np.logspace(0, np.log10(idx_counts.max() + 1), 10)))
hist, bins = np.histogram(idx_counts, bins, density=False)
print(f'Sampled {len(idx)} examples.')
print(f'Used {len(idx_uniq)} out of {len(coords)} examples ({100 * (len(idx_uniq) / len(coords)):.2f}%).')
print(f'Max repetition: {idx_counts.max()}')
print(f'Min repetition: {idx_counts.min()}')
for i, h in enumerate(hist):
    print(f'    [{bins[i]:3.0f}, {bins[i + 1]:3.0f}): {h:5d}')
# Sampled 30000 examples.
# Used 10029 out of 28422 examples (35.29%).
# Max repetition: 217
# Min repetition: 1
#     [  1,   2):  6570
#     [  2,   3):  1622
#     [  3,   6):   963
#     [  6,  11):   383
#     [ 11,  20):   235
#     [ 20,  36):   152
#     [ 36,  66):    79
#     [ 66, 120):    16
#     [120, 218):     9


# Write the resampled examples to consecutively numbered TFRecord files
# holding at most `s` examples each.
s = 1000
for i, offset in enumerate(range(0, n, s)):
    chunk = slice(offset, offset + s)
    path = fr'data\records\train_restructured\data{i:03d}.tfrecord'
    shard = {
        'x': xs3[chunk],
        'phase': ps3[chunk],
        'y': ys3[chunk],
        'weight': ws3[chunk],
    }
    write_records(shard, path)
