import os
from pathlib import Path

import numpy as np
import pandas as pd
import quaternion
# Configure TensorFlow through the environment before importing it.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf

import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as mplanim


ROOT = 'root'
HEAD = 'Head'
# Skeleton joints; every per-joint feature vector in this module follows this order.
JOINTS = [
    'Hips',
    'Spine',
    'Spine1',
    'Neck',
    'Head',
    'Head_end',
    'LeftShoulder',
    'LeftArm',
    'LeftForeArm',
    'LeftHand',
    'LeftHand_end',
    'RightShoulder',
    'RightArm',
    'RightForeArm',
    'RightHand',
    'RightHand_end',
    'LeftUpLeg',
    'LeftLeg',
    'LeftFoot',
    'LeftFoot_end',
    'RightUpLeg',
    'RightLeg',
    'RightFoot',
    'RightFoot_end',
    'Tail',
    'Tail1',
    'Tail1_end',
]
# Parent of every joint. 'root' is the top of the hierarchy and is not itself in JOINTS.
# (A duplicated 'RightForeArm' entry was removed; it had the same value.)
JOINT_PARENTS = {
    'root': None,
    'Hips': 'root',
    'Spine': 'Hips',
    'Spine1': 'Spine',
    'Neck': 'Spine1',
    'Head': 'Neck',
    'Head_end': 'Head',
    'LeftShoulder': 'Spine1',
    'LeftArm': 'LeftShoulder',
    'LeftForeArm': 'LeftArm',
    'LeftHand': 'LeftForeArm',
    'LeftHand_end': 'LeftHand',
    'RightShoulder': 'Spine1',
    'RightArm': 'RightShoulder',
    'RightForeArm': 'RightArm',
    'RightHand': 'RightForeArm',
    'RightHand_end': 'RightHand',
    'LeftUpLeg': 'Hips',
    'LeftLeg': 'LeftUpLeg',
    'LeftFoot': 'LeftLeg',
    'LeftFoot_end': 'LeftFoot',
    'RightUpLeg': 'Hips',
    'RightLeg': 'RightUpLeg',
    'RightFoot': 'RightLeg',
    'RightFoot_end': 'RightFoot',
    'Tail': 'Hips',
    'Tail1': 'Tail',
    'Tail1_end': 'Tail1',
}
# Indices into JOINTS of the four limb end effectors, presumably in PHASES order
# (front-left, front-right, back-left, back-right). Derived from the names so it cannot
# drift from JOINTS; the former literal [10, 15, 19, 15] repeated RightHand_end (15)
# where RightFoot_end (23) belongs.
FOOT_IDX = [JOINTS.index(j) for j in ('LeftHand_end', 'RightHand_end', 'LeftFoot_end', 'RightFoot_end')]
COG = 'Hips'
# NOTE(review): get_styles / make_input_data reference STYLES, which only exists in this comment.
# STYLES = ['Stand', 'Walk', 'Sit', 'Down', 'Up', 'Jump']
GAITS = ['Stand', 'Walk', 'Trot', 'Gallop']
STANDING_PHASE_MIN_CYCLE = 6  # 1/4 s @ 24 fps
PAST_TRAJECTORY_WINDOW = 6
FUTURE_TRAJECTORY_WINDOW = 6
TRAJECTORY_STEP_SIZE = 4  # 1/6 s @ 24 fps
WITH_TERRAIN = False
WITH_CONTACT = False
STANDING_SPEED_THRESHOLD = 0.1
# Blend weight between the direct joint-position prediction and the velocity-integrated one (see simulate).
JOINT_SMOOTH = 0.5
ERROR_BONE_SATURATION = 1_000
ERROR_FRAME_SATURATION = 10_000

# Column names read from exported animation data frames.
TIME_COLUMN = ':time'
STANDING_COLUMN = ':notify:PfnnStanding'
STYLE_COLUMN = ':notify:PfnnStyle:Name'
SWORD_BLOCK = ':notify:SwordBlock'

STEP_NOTIFY = 'DogStep'
PHASES = ['FrontLeft', 'FrontRight', 'BackLeft', 'BackRight']
NUM_PHASES = len(PHASES)
STANDING = 'PfnnStanding'
STYLE_NOTIFY = 'DogStyle:Name'

# Model input (x) layout. Each *_start/*_end pair is a half-open [start, end) slice of
# the input vector, and every segment begins where the previous one ends.
# Trajectory: 2 values per trajectory sample for past/future positions and directions
# (presumably the ground-plane components — confirm against the exporter).
x_trajectory_start = 0
x_past_pos_start = x_trajectory_start
x_past_pos_end = x_past_pos_start + 2 * PAST_TRAJECTORY_WINDOW
x_future_pos_start = x_past_pos_end
x_future_pos_end = x_future_pos_start + 2 * FUTURE_TRAJECTORY_WINDOW
x_past_dir_start = x_future_pos_end
x_past_dir_end = x_past_dir_start + 2 * PAST_TRAJECTORY_WINDOW
x_future_dir_start = x_past_dir_end
x_future_dir_end = x_future_dir_start + 2 * FUTURE_TRAJECTORY_WINDOW
# Gait weights: one value per gait in GAITS for every past and future sample.
x_style_start = x_future_dir_end
x_style_end = x_style_start + len(GAITS) * (PAST_TRAJECTORY_WINDOW + FUTURE_TRAJECTORY_WINDOW)
x_trajectory_end = x_style_end
# Joints: 3 position values per joint, then 3 velocity values per joint.
x_joint_start = x_trajectory_end
x_joint_pos_start = x_joint_start
x_joint_pos_end = x_joint_pos_start + 3 * len(JOINTS)
x_joint_vel_start = x_joint_pos_end
x_joint_vel_end = x_joint_vel_start + 3 * len(JOINTS)
x_joint_end = x_joint_vel_end
# Terrain: 3 values per trajectory sample, present only when WITH_TERRAIN is set
# (zero-width otherwise).
x_terr_start = x_joint_end
x_past_terr_start = x_terr_start
x_past_terr_end = x_past_terr_start + (3 * PAST_TRAJECTORY_WINDOW if WITH_TERRAIN else 0)
x_future_terr_start = x_past_terr_end
x_future_terr_end = x_future_terr_start + (3 * FUTURE_TRAJECTORY_WINDOW if WITH_TERRAIN else 0)
x_terr_end = x_future_terr_end
x_end = x_terr_end
# Total length of the input vector.
input_size = x_end

# Length of the phase (Coords) input: one value per entry in PHASES.
phase_size = NUM_PHASES

# Model output (y) layout, with the same half-open [start, end) convention as the input.
# Root motion: 2 velocity values followed by 1 angular-velocity value.
y_root_vel_start = 0
y_root_vel_end = y_root_vel_start + 2
y_root_angvel_start = y_root_vel_end
y_root_angvel_end = y_root_angvel_start + 1
# Phase: NUM_PHASES values (run_model adds these to the phase input between frames).
y_phase_start = y_root_angvel_end
y_phase_end = y_phase_start + NUM_PHASES
# Contacts: 4 values, present only when WITH_CONTACT is set.
y_contact_start = y_phase_end
y_contact_end = y_contact_start + (4 if WITH_CONTACT else 0)
# Future trajectory: 2 values per future sample for positions, then for directions.
y_trajectory_start = y_contact_end
y_future_pos_start = y_trajectory_start
y_future_pos_end = y_future_pos_start + 2 * FUTURE_TRAJECTORY_WINDOW
y_future_dir_start = y_future_pos_end
y_future_dir_end = y_future_dir_start + 2 * FUTURE_TRAJECTORY_WINDOW
y_trajectory_end = y_future_dir_end
# Joints: 3 position and 3 velocity values per joint, then 6 more per joint
# (presumably two 3-D axis vectors as consumed by make_joint_transform — confirm).
y_joint_start = y_trajectory_end
y_joint_pos_start = y_joint_start
y_joint_pos_end = y_joint_pos_start + 3 * len(JOINTS)
y_joint_vel_start = y_joint_pos_end
y_joint_vel_end = y_joint_vel_start + 3 * len(JOINTS)
y_joint_xy_start = y_joint_vel_end
y_joint_xy_end = y_joint_xy_start + 6 * len(JOINTS)
y_joint_end = y_joint_xy_end
y_end = y_joint_end
# Total length of the output vector.
output_size = y_end


def get_data(record):
    """Load every example from a TFRecord file into NumPy arrays.

    Each record is decoded with parse_example, and the batches are read in a fresh graph
    and session.

    Args:
        record: path to the TFRecord file (converted with ``str``).

    Returns:
        A list with one concatenated array per parsed field (x, phase, y); empty when
        the file holds no examples.
    """
    batches = []
    with tf.Graph().as_default(), tf.Session() as sess:
        dataset = (tf.data.TFRecordDataset([str(record)])
                   .map(parse_example, num_parallel_calls=os.cpu_count())
                   .batch(10000)
                   .prefetch(1))
        next_batch = dataset.make_one_shot_iterator().get_next()
        while True:
            try:
                batches.append(sess.run(next_batch))
            except tf.errors.OutOfRangeError:
                break
    return [np.concatenate(field) for field in zip(*batches)]

def parse_example(example_proto):
    """Decode one serialized example into ``(x, phase, y)`` float32 tensors.

    The feature lengths come from the module-level input_size, phase_size and
    output_size.
    """
    features = {'x': tf.FixedLenFeature((input_size,), tf.float32),
                # 1-tuple like the other shapes; `(phase_size)` was only a parenthesised int.
                'phase': tf.FixedLenFeature((phase_size,), tf.float32),
                'y': tf.FixedLenFeature((output_size,), tf.float32)}
    parsed = tf.parse_single_example(example_proto, features)
    return parsed['x'], parsed['phase'], parsed['y']

def run_model(model, x, phase, sync=1):
    """Run a frozen graph over recorded inputs, optionally feeding predictions back.

    Frames are processed in chunks of ``sync``. The first frame of each chunk uses the
    recorded ``x[i]`` and ``phase[i]``. Each of the remaining ``sync - 1`` frames uses
    ``x[i]`` with its joint position/velocity sections replaced by the previous
    prediction, and a phase advanced by the predicted phase slice.

    Args:
        model: path to a serialized GraphDef exposing 'Input:0', 'Coords:0' and
            'OutputCubic:0'.
        x: per-frame input vectors (input layout, x_* offsets).
        phase: per-frame phase vectors.
        sync: chunk length; 1 means every frame uses recorded inputs.

    Returns:
        ``np.stack`` of the per-frame outputs (output layout, y_* offsets).
    """
    with tf.Graph().as_default(), tf.Session() as sess:
        gd = tf.GraphDef()
        with open(model, 'rb') as f:
            gd.MergeFromString(f.read())
        inp, coords, out = tf.import_graph_def(
            gd, name='', return_elements=['Input:0', 'Coords:0', 'OutputCubic:0'])
        frames = iter(range(len(x)))
        res = []
        xi = np.zeros_like(x[0])
        pi = np.zeros_like(phase[0])
        try:
            while True:
                i = next(frames)
                pi[:] = phase[i]
                res.append(sess.run(out, feed_dict={inp: x[i], coords: pi}))
                for _ in range(sync - 1):
                    i = next(frames)
                    xi[:] = x[i]
                    # The prediction is in the OUTPUT layout: read it with y_* offsets and
                    # write it into the INPUT layout at x_* offsets (the two differ).
                    xi[x_joint_pos_start:x_joint_pos_end] = res[-1][y_joint_pos_start:y_joint_pos_end]
                    xi[x_joint_vel_start:x_joint_vel_end] = res[-1][y_joint_vel_start:y_joint_vel_end]
                    pi += res[-1][y_phase_start:y_phase_end]
                    res.append(sess.run(out, feed_dict={inp: xi, coords: pi}))
        except StopIteration:
            pass
        return np.stack(res)

def split_gaits(x, phase, y):
    """Partition samples by the dominant gait at trajectory sample PAST_TRAJECTORY_WINDOW.

    Returns a list with one ``(x, phase, y)`` tuple per entry of GAITS, holding the
    rows whose gait weight is largest for that gait.
    """
    window = PAST_TRAJECTORY_WINDOW + FUTURE_TRAJECTORY_WINDOW
    gait_weights = x[:, x_style_start:x_style_end].reshape(-1, len(GAITS), window)
    dominant = gait_weights[:, :, PAST_TRAJECTORY_WINDOW].argmax(axis=1)
    masks = (dominant == gait for gait in range(len(GAITS)))
    return [(x[mask], phase[mask], y[mask]) for mask in masks]

def phase_pairs(phase, name):
    """Pair-plot the phase channels (columns named after PHASES) under a title.

    NOTE(review): `sns` (presumably seaborn) is not imported in the visible part of
    this module.
    """
    frame = pd.DataFrame(phase, columns=PHASES)
    grid = sns.pairplot(frame, plot_kws={'s': 5})
    figure = grid.fig
    figure.subplots_adjust(top=.97)
    figure.suptitle(name, y=.99)
    return grid

# Transforms are multiplied FROM THE LEFT: Transform * Object
# Rightmost transform is applied first: Transform3 * Transform2 * Transform1
def make_transform_relative(tx, root):
    """Express transform(s) ``tx`` in the frame of ``root`` as ``inv(root) @ tx``, normalized.

    NOTE(review): an identical definition appears later in this module and is the one
    in effect after import.
    """
    root_inverse = np.linalg.inv(root)
    return normalize_transform(np.matmul(root_inverse, tx))

def normalize_transform(tx):
    """Divide homogeneous 4x4 transform(s) by their bottom-right element.

    NOTE(review): an identical definition appears later in this module and is the one
    in effect after import.
    """
    matrix = np.asarray(tx)
    scale = matrix[..., 3:, 3:]
    return matrix / scale

def rotate_vector(v, tx):
    """Apply only the 3x3 rotation block of normalized transform(s) ``tx`` to vector(s) ``v``.

    NOTE(review): an identical definition appears later in this module and is the one
    in effect after import.
    """
    rotation = normalize_transform(tx)[..., :3, :3]
    column = np.asarray(v)[..., np.newaxis]
    return np.matmul(rotation, column)[..., 0]

def transform_vector(v, tx):
    """Map 3-D point(s) ``v`` through homogeneous transform(s) ``tx``, then divide by w.

    NOTE(review): an identical definition appears later in this module and is the one
    in effect after import.
    """
    points = np.asarray(v)
    homogeneous = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1)
    mapped = np.matmul(tx, homogeneous[..., np.newaxis])[..., 0]
    return mapped[..., :3] / mapped[..., 3:]

def make_joint_transform(pos, x, y):
    """Build homogeneous joint transform(s) from a position and two axis hints.

    ``x`` is normalized and becomes the first axis. ``z = x × y`` (normalized) is the
    third axis. The second axis is recomputed as ``z × x``, so the columns are
    orthonormal even when ``y`` is not exactly perpendicular to ``x``. Norms are clamped
    to 1e-8 to avoid division by zero.

    NOTE(review): an identical definition appears later in this module and is the one
    in effect after import.

    Args:
        pos: (..., 3) translations.
        x, y: (..., 3) axis vectors with the same leading shape. They are not modified.

    Returns:
        (..., 4, 4) array with columns [x, y, z, pos] and bottom row [0, 0, 0, 1].
    """
    pos = np.asarray(pos)
    s = np.shape(x)
    d = s[-1]
    # Out-of-place division: np.reshape can return a view of the caller's array, so the
    # former `/=` normalized the caller's x/y in place (and failed for integer input).
    xs = np.reshape(x, (-1, d))
    xs = xs / np.maximum(np.linalg.norm(xs, axis=-1, keepdims=True), 1e-8)
    ys = np.reshape(y, (-1, d))
    ys = ys / np.maximum(np.linalg.norm(ys, axis=-1, keepdims=True), 1e-8)
    zs = np.cross(xs, ys, axis=-1)
    zs = zs / np.maximum(np.linalg.norm(zs, axis=-1, keepdims=True), 1e-8)
    ys = np.cross(zs, xs, axis=-1)
    x = np.reshape(xs, s)
    y = np.reshape(ys, s)
    z = np.reshape(zs, s)
    tx_h = np.stack([x, y, z, pos], axis=-1)
    h = np.tile([[0, 0, 0, 1]], s[:-1] + (1, 1))
    return np.concatenate([tx_h, h], axis=-2)

def triangles_area(triangles):
    """Area of each triangle in a (..., 3, 3) array of vertices (half the cross-product norm).

    NOTE(review): an identical definition appears later in this module and is the one
    in effect after import.
    """
    tri = np.asarray(triangles)
    a, b, c = tri[..., 0, :], tri[..., 1, :], tri[..., 2, :]
    return 0.5 * np.linalg.norm(np.cross(b - a, c - a), axis=-1)

def segments_error(segments1, segments2):
    """Area-based distance between corresponding pairs of 3-D segments.

    For each pair, four triangles fan out from the mean of the four endpoints: one over
    each segment, one joining the two start points and one joining the two end points.
    Their areas are summed.

    NOTE(review): an identical definition appears later in this module and is the one
    in effect after import.

    Args:
        segments1, segments2: (..., 2, 3) arrays of segment endpoints.

    Returns:
        (...) array of summed triangle areas.
    """
    segments1 = np.asarray(segments1)
    segments2 = np.asarray(segments2)
    # Endpoints in order: s1 start, s1 end, s2 start, s2 end.
    endpoints = np.concatenate([segments1, segments2], axis=-2)
    centre = endpoints.mean(axis=-2)
    first = endpoints[..., [0, 0, 2, 1], :]
    second = endpoints[..., [1, 2, 3, 3], :]
    apex = np.broadcast_to(centre[..., np.newaxis, :], first.shape)
    triangles = np.stack([first, second, apex], axis=-2).astype(segments1.dtype)
    return triangles_area(triangles).sum(axis=-1)

def make_segments(pos):
    """Turn per-joint positions into (child, parent) bone segments.

    Args:
        pos: (frames, len(JOINTS), 3) array of joint positions.

    Returns:
        (frames, n_bones, 2, 3) array. Endpoint 0 is the joint and endpoint 1 is its
        parent; only joints whose parent is also in JOINTS are included.
    """
    bones = [(child, JOINTS.index(JOINT_PARENTS[name]))
             for child, name in enumerate(JOINTS)
             if JOINT_PARENTS[name] in JOINTS]
    child_idx = [child for child, _ in bones]
    parent_idx = [parent for _, parent in bones]
    segments = np.empty((len(pos), len(bones), 2, 3), pos.dtype)
    segments[:, :, 0] = pos[:, child_idx]
    segments[:, :, 1] = pos[:, parent_idx]
    return segments

def model_score(model, datas, sync=1):
    """Score a frozen model by bone-segment error against recorded targets.

    Args:
        model: path passed to run_model.
        datas: iterable of ``(x, phase, y)`` triples (as returned by get_data). Empty
            entries are skipped.
        sync: forwarded to run_model.

    Returns:
        1-D array with the summed segment error of every frame, concatenated over all
        datasets. It is empty when there is nothing to score; previously this case made
        ``np.concatenate([])`` raise.
    """
    scores = []
    for data in datas:
        if not data:
            continue
        x, p, y = data
        y_segs = make_segments(y[:, y_joint_pos_start:y_joint_pos_end].reshape(len(y), -1, 3))
        y2 = run_model(model, x, p, sync)
        y2_segs = make_segments(y2[:, y_joint_pos_start:y_joint_pos_end].reshape(len(y2), -1, 3))
        scores.append(segments_error(y_segs, y2_segs).sum(axis=1))
    if not scores:
        return np.empty(0)
    return np.concatenate(scores)

def compare_models(models, records, sync=1):
    """Score several frozen models on the same records.

    Each model is named after its parent directory (``Path(model).parent.stem``).

    Returns:
        DataFrame with one column of per-frame scores per model name.
    """
    print('Reading data...')
    datasets = [get_data(record) for record in records]
    scores = {}
    for model_path in models:
        name = Path(model_path).parent.stem
        print(f'Evaluating model {name}...')
        scores[name] = model_score(model_path, datasets, sync)
    return pd.DataFrame(scores)

def plot_scores(scores, ax=None):
    """Boxen-plot per-frame scores, one box per column (e.g. compare_models output).

    Args:
        scores: DataFrame whose columns are model names.
        ax: Axes to draw into. A new figure is created only when this is omitted;
            previously a stray empty figure was created even when ``ax`` was given.

    Returns:
        The Axes drawn into.

    NOTE(review): `sns` (presumably seaborn) is not imported in the visible part of
    this module.
    """
    if ax is None:
        plt.figure()
    df = scores.unstack().reset_index(0)
    df.columns = ['model', 'value']
    ax = sns.boxenplot(x='model', y='value', data=df, ax=ax)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
    return ax


# m = (np.core.defchararray.find(walk_pat, ' ') == -1)

# Frame-distance threshold for grouping near-simultaneous foot contacts
# (presumably meant as the `thres` argument of walk_pattern — not referenced in visible code).
WALK_PATTERN_THRESHOLD = 2
# Backward-compatible alias for the original misspelled name.
WALK_PATTERN_TRESHOLD = WALK_PATTERN_THRESHOLD

def walk_pattern(phases, thres=0):
    """Label each frame with the footfall pattern in effect at that frame.

    A foot registers a contact at frame ``i`` when its phase decreases between frames
    ``i`` and ``i + 1``. Whenever any contact happens, the pattern is recomputed with
    make_walk_pattern. Frames before the first contact are all spaces.

    Args:
        phases: 2-D array-like, frames along axis 0 and feet along axis 1.
        thres: forwarded to make_walk_pattern.

    Returns:
        1-D array of strings (dtype ``U<n_foot>``), one per frame. It is empty for
        empty input; previously empty input raised IndexError.
    """
    phases = np.asarray(phases)
    n_foot = phases.shape[1]
    walk_pat = np.zeros(len(phases), dtype='U{}'.format(n_foot))
    if len(phases) == 0:
        return walk_pat
    contacts = np.diff(phases, axis=0) < 0
    current = ' ' * n_foot
    last_contact = [-1] * n_foot
    for i, c in enumerate(contacts):
        if c.any():
            for i_foot, touched in enumerate(c):
                if touched:
                    last_contact[i_foot] = i
            current, last_contact = make_walk_pattern(last_contact, thres)
        walk_pat[i] = current
    # `contacts` has one row fewer than `phases`; the final frame keeps the last pattern.
    walk_pat[-1] = current
    return walk_pat

def make_walk_pattern(last_contact, thres=0):
    """Derive the footfall order from each foot's most recent contact frame.

    Feet whose last contacts lie within ``thres`` frames of each other, directly or
    through a chain of such feet, form one group and count as touching down together.
    Groups are ordered by their latest contact frame. The order is then rotated so the
    group containing the lowest-indexed contacting foot comes first.

    Args:
        last_contact: sequence with one frame index per foot; negative means the foot
            has not touched down yet.
        thres: maximum frame difference for two contacts to count as simultaneous.

    Returns:
        ``(pattern, new_contact)``.
        - ``pattern`` is a string of foot indices in footfall order, space-padded to
          one character per foot.
        - ``new_contact`` is a list copy of ``last_contact`` with every grouped foot
          set to its group's latest contact frame.
    """
    n_foot = len(last_contact)
    active = [i for i, f in enumerate(last_contact) if f >= 0]
    if not active:
        # Previously `' ' * last_contact` (str * list -> TypeError). Return the same
        # (pattern, contacts) pair as the normal path.
        return ' ' * n_foot, list(last_contact)
    # Undirected adjacency. With forward-only edges a foot could end up in two groups
    # and appear twice in the pattern.
    neighbours = [[] for _ in range(n_foot)]
    for k1, i1 in enumerate(active):
        for i2 in active[k1 + 1:]:
            if abs(last_contact[i1] - last_contact[i2]) <= thres:
                neighbours[i1].append(i2)
                neighbours[i2].append(i1)
    # Connected components: every contacting foot lands in exactly one group.
    groups = []
    grouped = set()
    for start in active:
        if start in grouped:
            continue
        grouped.add(start)
        members, stack = [], [start]
        while stack:
            foot = stack.pop()
            members.append(foot)
            for other in neighbours[foot]:
                if other not in grouped:
                    grouped.add(other)
                    stack.append(other)
        groups.append((members, max(last_contact[f] for f in members)))
    groups.sort(key=lambda gr: gr[1])
    first_foot = active[0]
    g0_idx = next(k for k, (members, _) in enumerate(groups) if first_foot in members)
    pattern = []
    new_contact = list(last_contact)
    for k in range(len(groups)):
        members, group_f = groups[(g0_idx + k) % len(groups)]
        for foot in sorted(members):
            pattern.append(foot)
            new_contact[foot] = group_f
    pad = ' ' * (n_foot - len(pattern))
    return ''.join(map(str, pattern)) + pad, new_contact

def find_longest_sequence(a):
    """Return ``(start, end)`` of the longest run of truthy values in 1-D ``a`` (end exclusive).

    Ties go to the earliest run. ``(0, 0)`` is returned when no value is truthy.
    """
    flags = np.pad(np.asarray(a, dtype=bool), (1, 1), 'constant')
    edges = np.flatnonzero(np.diff(flags))
    if edges.size == 0:
        return (0, 0)
    starts, ends = edges[::2], edges[1::2]
    longest = np.argmax(ends - starts)
    return (starts[longest], ends[longest])

def window_max(a, window):
    """Centered running maximum, padded with the edge values.

    Output element ``i`` is the maximum of ``a[i - window // 2 : i - window // 2 + window]``.
    Near the borders the first/last full-window maximum is repeated, so the output has
    the same length as the flattened input.

    Args:
        a: array-like; it is flattened. Lists are accepted as well as arrays.
        window: span length, with ``1 <= window <= a.size``.

    Returns:
        1-D array with the dtype of ``a``. Previously integer input became float64
        whenever one pad width was 0.
    """
    a = np.asarray(a).ravel()
    s, = a.strides
    assert window >= 1 and len(a) >= window
    wins = np.lib.stride_tricks.as_strided(a, [len(a) - window + 1, window], [s, s], writeable=False)
    m = wins.max(1)
    return np.pad(m, (window // 2, (window - 1) // 2), mode='edge')

# parents = [JOINTS.index(JOINT_PARENTS.get(j, None)) if JOINT_PARENTS.get(j, None) in JOINTS else -1 for j in JOINTS]

def plot_skel(joint_pos, parents, ax=None, lines=None):
    """Draw or update a 3-D stick figure.

    Args:
        joint_pos: (n_joints, 3) positions.
        parents: parent index per joint; negative entries get no bone.
        ax: 3-D Axes. When omitted, a new figure is created with limits spanning the
            pose plus a 10% margin.
        lines: one Line3D per joint from a previous call, to update in place.

    Returns:
        ``(ax, lines)``.
    """
    if ax is None:
        plt.figure()
        ax = plt.subplot(111, projection='3d')
        centre = np.mean(joint_pos, axis=0)
        spread = joint_pos - centre
        low = centre + spread.min() * 1.1
        high = centre + spread.max() * 1.1
        ax.set_xlim(low[0], high[0])
        ax.set_ylim(low[1], high[1])
        ax.set_zlim(low[2], high[2])
        ax.set_aspect('equal')
    if lines is None:
        lines = [ax.plot3D([], [], [], c='k', lw=1)[0] for _ in parents]
    for child, (line, parent) in enumerate(zip(lines, parents)):
        if parent >= 0:
            bone = [parent, child]
            line.set_data(joint_pos[bone, 0], joint_pos[bone, 1])
            line.set_3d_properties(joint_pos[bone, 2])
    return ax, lines

def make_animation_skel(joint_pos, parents, fps=24, ax=None):
    """Animate a (frames, n_joints, 3) sequence of stick-figure poses with plot_skel."""
    ax, lines = plot_skel(joint_pos[0], parents, ax=ax)

    def update(frame):
        return plot_skel(joint_pos[frame], parents, ax=ax, lines=lines)[1]

    n_frames = len(joint_pos)
    return mplanim.FuncAnimation(ax.figure, update, range(n_frames),
                                 interval=1000. / fps, repeat=False, blit=True)

# for i, k in enumerate(walk_kinds):
#     make_animation(walk_kinds_longest[i], parents, fps=24).save(f'{k}_example.mp4', writer=mplanim.writers['ffmpeg'](fps=24), dpi=100)
#     plt.close()


def plot_traj(traj_points, ax=None, lines=None):
    """Draw or update a 2-D trajectory polyline from columns 0 and 1 of ``traj_points``.

    Args:
        traj_points: (n, >=2) array of points.
        ax: Axes. When omitted, a new figure is created with limits spanning the points
            plus a 10% margin.
        lines: single-element list holding the Line2D from a previous call.

    Returns:
        ``(ax, lines)``.
    """
    if ax is None:
        plt.figure()
        ax = plt.subplot(111)
        mean = np.mean(traj_points, axis=0)
        d = traj_points - mean
        low, high = mean + d.min() * 1.1, mean + d.max() * 1.1
        ax.set_xlim(low[0], high[0])
        ax.set_ylim(low[1], high[1])
        ax.set_aspect('equal')
    if lines is None:
        # Only x and y: a third positional `[]` was parsed as an extra, empty line.
        lines = [ax.plot([], [], c='k', lw=1)[0]]
    lines[0].set_data(traj_points[:, 0], traj_points[:, 1])
    return ax, lines

def make_animation_traj(traj_points, fps=24, ax=None):
    """Animate a (frames, n, 2) sequence of trajectories, with limits fitted to all frames."""
    ax, lines = plot_traj(traj_points[0], ax=ax)
    centre = np.mean(traj_points, axis=(0, 1))
    spread = traj_points - centre
    low = centre + spread.min() * 1.1
    high = centre + spread.max() * 1.1
    ax.set_xlim(low[0], high[0])
    ax.set_ylim(low[1], high[1])
    ax.set_aspect('equal')

    def update(frame):
        return plot_traj(traj_points[frame], ax=ax, lines=lines)[1]

    return mplanim.FuncAnimation(ax.figure, update, range(len(traj_points)),
                                 interval=1000. / fps, repeat=False, blit=True)

def rotation_matrix_from_yaw(yaw):
    """3x3 rotation(s) about the third (z) axis by angle(s) ``yaw``; shape ``yaw.shape + (3, 3)``."""
    sin_y, cos_y = np.sin(yaw), np.cos(yaw)
    one, zero = np.ones_like(yaw), np.zeros_like(yaw)
    rows = (
        (cos_y, -sin_y, zero),
        (sin_y, cos_y, zero),
        (zero, zero, one),
    )
    return np.stack([np.stack(row, axis=-1) for row in rows], axis=-2)

def get_joint_speeds(joint_pos, root_vel, root_angvel):
    """Per-frame, per-joint displacement after compensating for root motion.

    The previous step's root motion (yaw rotation plus xy translation; zero before
    frame 0) is applied to the current joint positions. The result is compared with
    the previous frame's positions, and frame 0 is compared with itself.

    Args:
        joint_pos: (frames, joints, 3) positions.
        root_vel: (frames, 2) root velocities.
        root_angvel: (frames,) root yaw rates.

    Returns:
        (frames, joints) Euclidean distances.
    """
    n_frames = len(root_vel)
    prev_pos = np.roll(joint_pos, 1, axis=0)
    prev_pos[0] = prev_pos[1]
    step_vel = np.roll(root_vel, 1, axis=0)
    step_vel[0] = 0
    step_yaw = np.roll(root_angvel, 1, axis=0)
    step_yaw[0] = 0
    step_tx = np.zeros((n_frames, 4, 4))
    step_tx[:, 3, 3] = 1
    step_tx[:, :2, 3] = step_vel
    step_tx[:, :3, :3] = rotation_matrix_from_yaw(step_yaw)
    pos_h = np.concatenate([joint_pos, np.ones_like(joint_pos[..., :1])], axis=-1)
    moved = np.matmul(step_tx[:, np.newaxis], pos_h[..., np.newaxis])[..., :-1, 0]
    return np.linalg.norm(moved - prev_pos, axis=-1)

def peaks_cover(a):
    """Piecewise-linear envelope through the strict local maxima of ``a``.

    A peak is an interior sample where the signal rises and then falls. Values are
    linearly interpolated between peaks and held constant beyond the first and last
    peak.
    """
    arr = np.asarray(a)
    slope = np.diff(arr.ravel())
    rising, falling = slope[:-1], slope[1:]
    peak_idx = np.flatnonzero((rising * falling < 0) & (rising > 0)) + 1
    return np.interp(np.arange(len(arr)), peak_idx, arr[peak_idx])


# Transforms are multiplied FROM THE LEFT: Transform * Object
# Rightmost transform is applied first: Transform3 * Transform2 * Transform1
def make_transform_relative(tx, root):
    """Express transform(s) ``tx`` in the frame of ``root`` as ``inv(root) @ tx``, normalized.

    NOTE(review): duplicate of an identical earlier definition in this module; this one
    is in effect after import.
    """
    root_inverse = np.linalg.inv(root)
    return normalize_transform(np.matmul(root_inverse, tx))

def normalize_transform(tx):
    """Divide homogeneous 4x4 transform(s) by their bottom-right element.

    NOTE(review): duplicate of an identical earlier definition in this module; this one
    is in effect after import.
    """
    matrix = np.asarray(tx)
    scale = matrix[..., 3:, 3:]
    return matrix / scale

def rotate_vector(v, tx):
    """Apply only the 3x3 rotation block of normalized transform(s) ``tx`` to vector(s) ``v``.

    NOTE(review): duplicate of an identical earlier definition in this module; this one
    is in effect after import.
    """
    rotation = normalize_transform(tx)[..., :3, :3]
    column = np.asarray(v)[..., np.newaxis]
    return np.matmul(rotation, column)[..., 0]

def transform_vector(v, tx):
    """Map 3-D point(s) ``v`` through homogeneous transform(s) ``tx``, then divide by w.

    NOTE(review): duplicate of an identical earlier definition in this module; this one
    is in effect after import.
    """
    points = np.asarray(v)
    homogeneous = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1)
    mapped = np.matmul(tx, homogeneous[..., np.newaxis])[..., 0]
    return mapped[..., :3] / mapped[..., 3:]

def triangles_area(triangles):
    """Area of each triangle in a (..., 3, 3) array of vertices (half the cross-product norm).

    NOTE(review): duplicate of an identical earlier definition in this module; this one
    is in effect after import.
    """
    tri = np.asarray(triangles)
    a, b, c = tri[..., 0, :], tri[..., 1, :], tri[..., 2, :]
    return 0.5 * np.linalg.norm(np.cross(b - a, c - a), axis=-1)

def segments_error(segments1, segments2):
    """Area-based distance between corresponding pairs of 3-D segments.

    For each pair, four triangles fan out from the mean of the four endpoints: one over
    each segment, one joining the two start points and one joining the two end points.
    Their areas are summed.

    NOTE(review): duplicate of an identical earlier definition in this module; this one
    is in effect after import.

    Args:
        segments1, segments2: (..., 2, 3) arrays of segment endpoints.

    Returns:
        (...) array of summed triangle areas.
    """
    segments1 = np.asarray(segments1)
    segments2 = np.asarray(segments2)
    # Endpoints in order: s1 start, s1 end, s2 start, s2 end.
    endpoints = np.concatenate([segments1, segments2], axis=-2)
    centre = endpoints.mean(axis=-2)
    first = endpoints[..., [0, 0, 2, 1], :]
    second = endpoints[..., [1, 2, 3, 3], :]
    apex = np.broadcast_to(centre[..., np.newaxis, :], first.shape)
    triangles = np.stack([first, second, apex], axis=-2).astype(segments1.dtype)
    return triangles_area(triangles).sum(axis=-1)

def make_joint_transform(pos, x, y):
    """Build homogeneous joint transform(s) from a position and two axis hints.

    ``x`` is normalized and becomes the first axis. ``z = x × y`` (normalized) is the
    third axis. The second axis is recomputed as ``z × x``, so the columns are
    orthonormal even when ``y`` is not exactly perpendicular to ``x``. Norms are clamped
    to 1e-8 to avoid division by zero.

    NOTE(review): duplicate of an identical earlier definition in this module; this one
    is in effect after import.

    Args:
        pos: (..., 3) translations.
        x, y: (..., 3) axis vectors with the same leading shape. They are not modified.

    Returns:
        (..., 4, 4) array with columns [x, y, z, pos] and bottom row [0, 0, 0, 1].
    """
    pos = np.asarray(pos)
    s = np.shape(x)
    d = s[-1]
    # Out-of-place division: np.reshape can return a view of the caller's array (e.g.
    # slices of the model output in simulate), so the former `/=` normalized the
    # caller's x/y in place (and failed for integer input).
    xs = np.reshape(x, (-1, d))
    xs = xs / np.maximum(np.linalg.norm(xs, axis=-1, keepdims=True), 1e-8)
    ys = np.reshape(y, (-1, d))
    ys = ys / np.maximum(np.linalg.norm(ys, axis=-1, keepdims=True), 1e-8)
    zs = np.cross(xs, ys, axis=-1)
    zs = zs / np.maximum(np.linalg.norm(zs, axis=-1, keepdims=True), 1e-8)
    ys = np.cross(zs, xs, axis=-1)
    x = np.reshape(xs, s)
    y = np.reshape(ys, s)
    z = np.reshape(zs, s)
    tx_h = np.stack([x, y, z, pos], axis=-1)
    h = np.tile([[0, 0, 0, 1]], s[:-1] + (1, 1))
    return np.concatenate([tx_h, h], axis=-2)

def get_transforms(df, joints, dtype=None):
    """Build 4x4 homogeneous transforms for joints from component-space columns.

    Reads ``<joint>:cs:position:{x,y,z}`` and ``<joint>:cs:quaternion:{w,x,y,z}``
    columns of ``df``.

    Args:
        df: frame data with one row per frame.
        joints: a single joint name, a list of names, or a falsy value to use every
            joint that has a ``:cs:position:x`` column.
        dtype: optional dtype for the returned arrays. ``None`` keeps the dtype produced
            by the computation. This parameter was previously ignored.

    Returns:
        A (len(df), 4, 4) array for a single joint name, otherwise a list of them.
    """
    if not joints:
        joints = [col.split(':')[0] for col in df if col.endswith(':cs:position:x')]
    unpack = np.isscalar(joints)
    if unpack:
        joints = [joints]
    transforms = []
    for joint in joints:
        pos = df[[f'{joint}:cs:position:x',
                  f'{joint}:cs:position:y',
                  f'{joint}:cs:position:z']].values
        quat = df[[f'{joint}:cs:quaternion:w',
                   f'{joint}:cs:quaternion:x',
                   f'{joint}:cs:quaternion:y',
                   f'{joint}:cs:quaternion:z']].values
        rot = quaternion.as_rotation_matrix(quaternion.as_quat_array(quat))
        tx = np.concatenate([rot, pos[..., np.newaxis]], axis=2)
        tx = np.concatenate([tx, np.tile([0, 0, 0, 1], (len(tx), 1, 1))], axis=1)
        if dtype is not None:
            tx = tx.astype(dtype, copy=False)
        transforms.append(tx)
    return transforms[0] if unpack else transforms

def fix_maya_rotation(tx):
    """Return normalized copies of ``tx`` with the rotation block re-expressed from Maya to UE axes.

    The 3x3 block becomes ``R @ axes_ue @ inv(axes_maya)``. Each basis matrix stores
    its axes as columns.
    """
    fixed = normalize_transform(tx)
    ue_axes = np.array([[1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1]], dtype=fixed.dtype)
    maya_axes = np.array([[0, 0, 1],
                          [-1, 0, 0],
                          [0, -1, 0]], dtype=fixed.dtype)
    rotation = fixed[..., :3, :3]
    fixed[..., :3, :3] = rotation @ ue_axes @ np.linalg.inv(maya_axes)
    return fixed

def get_root_transforms(df):
    """Root-joint transforms from ``df`` with fix_maya_rotation applied."""
    root_tx = get_transforms(df, ROOT)
    return fix_maya_rotation(root_tx)

def get_npc_joint_transforms(df):
    """Transforms of all JOINTS stacked as (frames, len(JOINTS), 4, 4).

    If a sword-tip joint is part of JOINTS, its translation is shifted by the matching
    offset rotated into that joint's frame.

    NOTE(review): PLAYER_SWORD_TIP, NPC_SWORD_TIP, PLAYER_SWORD_OFFSET and
    NPC_SWORD_OFFSET are not defined in the visible part of this module.
    """
    per_joint = get_transforms(df, JOINTS)
    if PLAYER_SWORD_TIP in JOINTS:
        tip = per_joint[JOINTS.index(PLAYER_SWORD_TIP)]
        tip[..., :3, 3] += rotate_vector(PLAYER_SWORD_OFFSET, tip)
    if NPC_SWORD_TIP in JOINTS:
        tip = per_joint[JOINTS.index(NPC_SWORD_TIP)]
        tip[..., :3, 3] += rotate_vector(NPC_SWORD_OFFSET, tip)
    return np.stack(per_joint, axis=1)

def get_sword_transforms(df):
    """Player and NPC sword-tip transforms, each translated by its offset rotated into the tip frame.

    NOTE(review): the *_SWORD_TIP / *_SWORD_OFFSET names are not defined in the visible
    part of this module.
    """
    player_tx = get_transforms(df, PLAYER_SWORD_TIP)
    npc_tx = get_transforms(df, NPC_SWORD_TIP)
    for sword_tx, offset in ((player_tx, PLAYER_SWORD_OFFSET), (npc_tx, NPC_SWORD_OFFSET)):
        sword_tx[..., :3, 3] += rotate_vector(offset, sword_tx)
    return player_tx, npc_tx

def get_styles(df, dtype=None):
    """Build a per-frame style weight matrix from notify columns.

    - If 'stand' is a style and STANDING_COLUMN exists, that column is copied into the
      'stand' channel.
    - STYLE_COLUMN holds ':'-separated style tags; each tag sets its channel to 1.
    - Where two styles overlap, each overlapping run is replaced by a linear blend from
      the frame before the run to the frame after it.

    NOTE(review): STYLES exists only as a comment in the visible part of this module.

    Args:
        df: frame data with one row per frame.
        dtype: output dtype, np.float32 when None.

    Returns:
        (len(df), len(STYLES)) array.

    Raises:
        ValueError: if more than two styles overlap on any frame.
    """
    # `dtype or np.float32` could also replace a valid but falsy dtype argument (a plain
    # np.dtype instance can evaluate as False); test for None explicitly.
    dtype = np.float32 if dtype is None else dtype
    style_data = np.zeros((len(df), len(STYLES)), dtype)
    if 'stand' in STYLES and STANDING_COLUMN in df:
        style_data[:, STYLES.index('stand')] = df[STANDING_COLUMN].values
    if STYLE_COLUMN not in df:
        return style_data
    style_values = df[STYLE_COLUMN].fillna('').values
    for i, style in enumerate(style_values):
        if not style:
            continue
        style_tags = style.split(':')
        style_data[i, [STYLES.index(tag) for tag in style_tags]] = 1
    # Overlapping areas
    overlap = np.sum(style_data, axis=1)
    if np.all(overlap <= 1):
        # No overlapping
        return style_data
    if np.any(overlap > 2):
        raise ValueError('no more than two styles can overlap.')
    # Find overlap start and end points
    overlap, = np.nonzero(overlap > 1)
    overlap_border = np.diff(overlap) > 1
    overlap_begin = overlap[np.insert(overlap_border, 0, True)]
    overlap_end = overlap[np.append(overlap_border, True)] + 1
    assert len(overlap_begin) == len(overlap_end)
    # Linear blend in overlapping areas
    for begin, end in zip(overlap_begin, overlap_end):
        alpha = np.linspace(0, 1, end - begin)[:, np.newaxis]
        style_data[begin:end] = style_data[max(begin - 1, 0)] * (1 - alpha) + style_data[min(end, len(df) - 1)] * alpha
    return style_data

def joint_parent_index():
    """Index into JOINTS of each joint's parent, or -1 when the parent is missing or not in JOINTS."""
    def parent_of(joint):
        parent = JOINT_PARENTS.get(joint, None)
        return JOINTS.index(parent) if parent and parent in JOINTS else -1

    return [parent_of(joint) for joint in JOINTS]

def simulate(player_sword_tx, styles, npc_pose, root_tx, model_file, mode):
    """Roll a frozen model forward frame by frame, starting from an initial NPC pose.

    The graph in ``model_file`` is loaded, and the tensors 'Input:0', 'Coords:0' and
    ``f'Output{mode.title()}:0'`` are used. For each frame except the last:
    - the sword trajectory buffers are shifted,
    - an input is built with make_input_data,
    - the model is run,
    - the next NPC pose is rebuilt from the predicted joint positions, velocities and
      axes.

    NOTE(review): TRAJECTORY_WINDOW, NPC_SWORD_TIP (and the names used by
    make_input_data) are not defined in the visible part of this module.

    Returns:
        ``(npc_simulation, root_simulation, coords_simulation)``, one entry per frame.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
        gd = tf.GraphDef()
        with Path(model_file).open('rb') as f:
            gd.MergeFromString(f.read())
        model_in, model_coords, model_out = tf.import_graph_def(
            gd, name='', return_elements=['Input:0', 'Coords:0', f'Output{mode.title()}:0'])
        del gd
        num_frames = len(player_sword_tx)
        # Scalar Coords tensor -> coords_size None (single value per frame).
        coords_size = model_coords.shape[-1].value if model_coords.shape.ndims > 0 else None
        # Buffer long enough to sample TRAJECTORY_WINDOW points every TRAJECTORY_STEP_SIZE frames.
        trajectory_size = TRAJECTORY_STEP_SIZE * (TRAJECTORY_WINDOW - 1) + 1
        player_sword_trajectory = np.tile(player_sword_tx[0], (trajectory_size, 1, 1))
        npc_sword_trajectory = np.tile(npc_pose[JOINTS.index(NPC_SWORD_TIP)], (trajectory_size, 1, 1))
        # Pre-fill outputs with the initial state; entries 1.. are overwritten in the loop.
        npc_simulation = np.tile(npc_pose, (num_frames, 1, 1, 1))
        root_simulation = np.tile(root_tx, (num_frames, 1, 1))
        coords_simulation = np.zeros((num_frames, coords_size) if coords_size is not None else num_frames,
                                     dtype=player_sword_tx.dtype)
        prev_root_tx = root_tx
        npc_joint_vel = np.zeros_like(npc_pose[:, :3, 3])
        for i, (current_player_sword_tx, current_style) in enumerate(zip(player_sword_tx[:-1], styles[:-1])):
            # Update trajectories
            player_sword_trajectory[:-1] = player_sword_trajectory[1:]
            player_sword_trajectory[-1] = current_player_sword_tx
            npc_sword_trajectory[:-1] = npc_sword_trajectory[1:]
            npc_sword_trajectory[-1] = npc_pose[JOINTS.index(NPC_SWORD_TIP)]
            # Make input
            in_vector, in_coords = make_input_data(player_sword_trajectory, npc_sword_trajectory,
                                                   current_style, npc_pose, npc_joint_vel, root_tx,
                                                   coords_size)
            coords_simulation[i] = in_coords
            # Run model
            out_vector = sess.run(model_out, feed_dict={model_in: in_vector, model_coords: in_coords})
            # TODO New root tx
            # NOTE(review): root_vel / root_angvel are read but unused; the root never moves.
            root_vel = out_vector[0:2]
            root_angvel = out_vector[2]
            root_tx, prev_root_tx =  root_tx, root_tx
            root_simulation[i + 1] = root_tx
            # New pose
            # NOTE(review): this hard-coded output layout (positions at offset 3) differs
            # from the module-level y_* offsets (phase at 3..7) — confirm which the model uses.
            npc_pos_out = out_vector[3:3 + 3 * len(JOINTS)].reshape((-1, 3))
            npc_vel_out = out_vector[3 + 3 * len(JOINTS):3 + 6 * len(JOINTS)].reshape((-1, 3))
            npc_prev_pos_rel = make_transform_relative(npc_pose, prev_root_tx)[:, :3, 3]
            # Blend the velocity-integrated position with the direct prediction (JOINT_SMOOTH).
            npc_pos_out = (1 - JOINT_SMOOTH) * (npc_prev_pos_rel + npc_vel_out) + JOINT_SMOOTH * npc_pos_out
            npc_xy_out = out_vector[3 + 6 * len(JOINTS):3 + 12 * len(JOINTS)].reshape((-1, 2, 3))
            npc_x_out = npc_xy_out[:, 0]
            npc_y_out = npc_xy_out[:, 1]
            npc_tx_out = make_joint_transform(npc_pos_out, npc_x_out, npc_y_out)
            # Back from root-relative to world space.
            npc_pose = root_tx @ npc_tx_out
            npc_joint_vel = rotate_vector(npc_vel_out, root_tx)
            npc_simulation[i + 1] = npc_pose
    # The loop stops one frame early, so the last coords entry copies the previous one.
    if len(coords_simulation) > 1:
        coords_simulation[-1] = coords_simulation[-2]
    return npc_simulation, root_simulation, coords_simulation

def make_input_data(player_sword_trajectory, npc_sword_trajectory, style, npc_pose, npc_joint_vel, root_tx, coords_size):
    """Assemble the model input vector and grid coordinates for one simulation step.

    Sword trajectories are sampled every TRAJECTORY_STEP_SIZE frames. They are expressed
    relative to ``root_tx``, with its height (translation row 2) replaced by the COG
    joint's height when COG is set.

    The vector concatenates, in order:
    - NPC and player sword-tip positions and unit directions (3 values per sampled
      frame each);
    - the closest points between the two swords in the latest frame and their
      distance (3 + 3 + 1);
    - four elevation/azimuth angles of the player's sword tip and hilt;
    - ``style``;
    - root-relative joint positions and velocities (3 + 3 per joint).

    Args:
        coords_size: None, 1, 2 or 3; selects how the grid coordinates are built.

    Returns:
        ``(vector, coords)``.

    Raises:
        ValueError: if coords_size is not None, 1, 2 or 3.

    NOTE(review): TRAJECTORY_WINDOW, STYLES, PLAYER_SWORD_VECTOR, NPC_SWORD_VECTOR,
    SegmentsClosestPoints, COORD_WIDTH, COORD_HEIGHT and COORD_DEPTH_RANGE are not
    defined in the visible part of this module.
    """
    expected_size = (4 * 3 * TRAJECTORY_WINDOW + 7 + 4 + len(STYLES) + 2 * 3 * len(JOINTS))
    # Sword relative to COG
    cog_tx = root_tx
    if COG:
        cog_tx = cog_tx.copy()
        cog_tx[2, 3] = npc_pose[JOINTS.index(COG)][2, 3]
    player_sword_rel_tx = make_transform_relative(player_sword_trajectory[::TRAJECTORY_STEP_SIZE], cog_tx)
    npc_sword_rel_tx = make_transform_relative(npc_sword_trajectory[::TRAJECTORY_STEP_SIZE], cog_tx)
    # Sword position
    player_sword_pos = player_sword_rel_tx[:, :3, 3]
    player_hilt_pos = transform_vector(PLAYER_SWORD_VECTOR, player_sword_rel_tx)
    npc_sword_pos = npc_sword_rel_tx[:, :3, 3]
    npc_hilt_pos = transform_vector(NPC_SWORD_VECTOR, npc_sword_rel_tx)
    # Sword direction (unit vector from tip towards hilt)
    player_sword_dir = player_hilt_pos - player_sword_pos
    player_sword_dir /= np.linalg.norm(player_sword_dir, axis=-1, keepdims=True)
    npc_sword_dir = npc_hilt_pos - npc_sword_pos
    npc_sword_dir /= np.linalg.norm(npc_sword_dir, axis=-1, keepdims=True)
    # Sword close points (latest sampled frame only)
    player_sword_close, npc_sword_close = SegmentsClosestPoints(player_sword_pos[-1],
                                                                player_hilt_pos[-1],
                                                                npc_sword_pos[-1],
                                                                npc_hilt_pos[-1])
    sword_close_dist = np.linalg.norm(player_sword_close - npc_sword_close, axis=-1)
    # Sword angles
    player_sword_elevation = np.arctan2(player_sword_pos[..., -1, 2], player_sword_pos[..., -1, 0])
    player_sword_azimuth = np.arctan2(player_sword_pos[..., -1, 1], player_sword_pos[..., -1, 0])
    player_hilt_elevation = np.arctan2(player_hilt_pos[..., -1, 2], player_hilt_pos[..., -1, 0])
    player_hilt_azimuth = np.arctan2(player_hilt_pos[..., -1, 1], player_hilt_pos[..., -1, 0])
    # Joint position
    joint_pos = make_transform_relative(npc_pose, root_tx)[:, :3, 3]
    # Joint velocity
    joint_vel = rotate_vector(npc_joint_vel, np.linalg.inv(root_tx))
    # Make vector
    vector = np.r_[
        npc_sword_pos.ravel(),
        npc_sword_dir.ravel(),
        player_sword_pos.ravel(),
        player_sword_dir.ravel(),
        npc_sword_close,
        player_sword_close,
        sword_close_dist,
        player_sword_elevation,
        player_sword_azimuth,
        player_hilt_elevation,
        player_hilt_azimuth,
        style,
        joint_pos.ravel(),
        joint_vel.ravel(),
    ]
    assert len(vector) == expected_size, f'Expected len(vector) = {expected_size}, was {len(vector)}.'
    # Grid coordinates. These must form a single if/elif chain: with a separate `if`
    # after the None case, coords_size=None fell through to `else` and raised ValueError.
    if coords_size is None:
        coords = (np.arctan2(player_sword_pos[..., -1, 1], player_sword_pos[..., -1, 2]) + 2 * np.pi) % (2 * np.pi)
    elif coords_size == 1:
        coords = np.r_[
            (np.arctan2(player_sword_pos[..., -1, 1], player_sword_pos[..., -1, 2]) + 2 * np.pi) % (2 * np.pi)
        ]
    elif coords_size == 2:
        coords = np.r_[
            player_sword_pos[..., -1, 1] / COORD_WIDTH + 0.5,
            player_sword_pos[..., -1, 2] / COORD_HEIGHT + 0.5,
        ]
    elif coords_size == 3:
        coords = np.r_[
            (player_sword_pos[..., -1, 0] - COORD_DEPTH_RANGE[0]) / (COORD_DEPTH_RANGE[1] - COORD_DEPTH_RANGE[0]),
            player_sword_pos[..., -1, 1] / COORD_WIDTH + 0.5,
            player_sword_pos[..., -1, 2] / COORD_HEIGHT + 0.5,
        ]
    else:
        raise ValueError('invalid coordinate size {}.'.format(coords_size))
    return vector, coords

def simulation_error(npc_joint_tx, root_tx, npc_simulation_tx, root_simulation_tx):
    """Per-frame, per-joint segment error between recorded and simulated poses.

    Each pose is first expressed relative to its own root with make_transform_relative,
    i.e. ``inv(root) @ joint``. That is the left-multiplication convention used for
    joint positions in simulate and make_input_data. The previous code computed
    ``joint @ inv(root)``, whose translation column is not the joint's position in the
    root frame.

    Args:
        npc_joint_tx, npc_simulation_tx: (frames, len(JOINTS), 4, 4) joint transforms.
        root_tx, root_simulation_tx: (frames, 4, 4) root transforms.

    Returns:
        (frames, len(JOINTS)) segments_error of each joint→parent segment. Joints
        without a parent in JOINTS use a zero-length segment.
    """
    npc_joint_rel = make_transform_relative(npc_joint_tx, root_tx[:, np.newaxis])
    npc_simulation_rel = make_transform_relative(npc_simulation_tx, root_simulation_tx[:, np.newaxis])
    joint_parents = joint_parent_index()
    segment_idx = [(i_joint, i_parent if i_parent >= 0 else i_joint) for i_joint, i_parent in enumerate(joint_parents)]
    segment_idx = np.asarray(segment_idx)
    segments1 = np.stack([npc_joint_rel[:, segment_idx[:, 0]][..., :3, 3],
                          npc_joint_rel[:, segment_idx[:, 1]][..., :3, 3]], axis=2)
    segments2 = np.stack([npc_simulation_rel[:, segment_idx[:, 0]][..., :3, 3],
                          npc_simulation_rel[:, segment_idx[:, 1]][..., :3, 3]], axis=2)
    return segments_error(segments1, segments2)

def make_anim_function(ax, player_sword_tx, npc_joint_tx, root_tx, temperature=None):
    """Prepare a 3D axes with a skeleton and the player's sword; return a frame updater.

    All transforms are drawn relative to ``root_tx``. Each bone (a joint and
    its parent) is coloured by linear interpolation between a green "good"
    colour (temperature 0) and a red "bad" colour (temperature 1). The sword
    is a blue segment. One end is the translation of the root-relative sword
    transform (the "tip"). The other is ``PLAYER_SWORD_VECTOR`` mapped through
    that transform (the "hilt").

    Args:
        ax: matplotlib 3D axes to draw into.
        player_sword_tx: per-frame sword transforms.
        npc_joint_tx: per-frame joint transforms, presumably shaped
            (frames, joints, 4, 4). TODO confirm.
        root_tx: per-frame root transforms.
        temperature: optional (frames, joints) array, expected in [0, 1].
            Defaults to all zeros, so every bone is drawn in the "good" colour.

    Returns:
        ``anim_frame(frame)``. It updates the line artists for ``frame`` and
        returns them, as ``FuncAnimation(..., blit=True)`` expects.
    """
    # Only animate frames available in every input sequence
    num_frames = min(len(player_sword_tx), len(npc_joint_tx), len(root_tx))
    if temperature is None:
        temperature = np.zeros((num_frames, len(JOINTS)))
    color_good = mpl.colors.to_rgb('#42BF42')
    color_bad = mpl.colors.to_rgb('#BC4242')
    color_sword = mpl.colors.to_rgb('#4343BA')
    # One line artist per joint (its segment to the parent) plus one for the sword
    joint_plt = [ax.plot3D([], [], [], c=color_sword, lw=1)[0] for _ in JOINTS]
    player_sword_plt = ax.plot3D([], [], [], c=color_sword, lw=1)[0]
    # Express sword and joints relative to the per-frame root transform
    player_sword_rel_tx = make_transform_relative(player_sword_tx[:num_frames], root_tx[:num_frames])
    player_sword_tip = player_sword_rel_tx[:, :3, 3]
    player_sword_hilt = transform_vector(PLAYER_SWORD_VECTOR, player_sword_rel_tx)
    npc_joint_rel_tx = make_transform_relative(npc_joint_tx[:num_frames], root_tx[:num_frames, np.newaxis])
    joint_pos = npc_joint_rel_tx[:, :, :3, 3]
    joint_parents = joint_parent_index()
    # Per-frame, per-joint RGB: temperature 0 -> good colour, 1 -> bad colour
    temperature = temperature[:, :, np.newaxis]
    color = np.array(color_good) * (1 - temperature) + np.array(color_bad) * temperature

    def anim_frame(frame):
        # Joints (joints without a parent have no bone to draw)
        for i_joint, (i_plt, i_parent) in enumerate(zip(joint_plt, joint_parents)):
            if i_parent < 0:
                continue
            i_plt.set_data(joint_pos[frame, [i_parent, i_joint], 0], joint_pos[frame, [i_parent, i_joint], 1])
            i_plt.set_3d_properties(joint_pos[frame, [i_parent, i_joint], 2])
            i_plt.set_color(color[frame, i_joint])
        # Sword
        player_sword_plt.set_data([player_sword_tip[frame, 0], player_sword_hilt[frame, 0]],
                                  [player_sword_tip[frame, 1], player_sword_hilt[frame, 1]])
        player_sword_plt.set_3d_properties([player_sword_tip[frame, 2], player_sword_hilt[frame, 2]])
        # Updated artists, needed for blitting
        return joint_plt + [player_sword_plt]

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_xlim(-100, 100)
    ax.set_ylim(-100, 100)
    ax.set_zlim(0, 200)
    ax.invert_yaxis()
    ax.set_aspect('equal')
    ax.view_init(elev=25, azim=-40)

    return anim_frame

def merge_anim_functions(anim_funcs):
    """Combine several per-frame animation callbacks into one.

    The returned callable invokes every function in ``anim_funcs`` with the
    same frame, in order, and returns a single flat list of all the artists
    they return.
    """
    def merged(frame):
        artists = []
        for anim_func in anim_funcs:
            artists.extend(anim_func(frame))
        return artists
    return merged

def make_empty_dataframe(columns, dtypes, index=None):
    """Create a DataFrame of unfilled values with typed columns.

    Args:
        columns: column names.
        dtypes: one dtype per column, paired positionally with ``columns``.
            As with ``zip``, surplus entries in the longer sequence are ignored.
        index: optional row index. When omitted, the frame has no rows.

    Returns:
        A DataFrame with one column per ``(column, dtype)`` pair. With no
        columns at all, returns an empty frame that still carries ``index``,
        because ``pd.concat`` refuses an empty list.
    """
    series = [pd.Series(name=col, dtype=dt, index=index) for col, dt in zip(columns, dtypes)]
    if not series:
        return pd.DataFrame(index=index)
    return pd.concat(series, axis=1)

# Model analysis

def num_nn_params(input_size, output_size, hidden_layers, use_bias=True):
    """Count the weights (and optional biases) of a fully-connected network.

    The network maps ``input_size`` through every width in ``hidden_layers``
    to ``output_size``. Each layer contributes ``fan_in * fan_out`` weights,
    plus ``fan_out`` biases when ``use_bias`` is true.
    """
    widths = [input_size, *hidden_layers, output_size]
    total = 0
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        total += fan_in * fan_out
        if use_bias:
            total += fan_out
    return total

def num_gfnn_params(spline_points, input_size, output_size, hidden_layers, use_bias=True):
    """Count the parameters of a network replicated over a spline grid.

    Returns the parameter count of one network (see ``num_nn_params``)
    multiplied by the product of ``spline_points``. Presumably this means one
    copy of the weights per spline control point.
    """
    per_network = num_nn_params(input_size, output_size, hidden_layers, use_bias)
    control_points = 1
    for count in spline_points:
        control_points = control_points * count
    return per_network * control_points

def num_gfnn_encdec_params(spline_points, input_size, output_size,
                           encoder_layers, hidden_layers, decoder_layers,
                           use_bias=True, out_spline=True):
    """Count the parameters of an encoder / gridded core / decoder network.

    The encoder and decoder stacks are counted once. The hidden stack is
    multiplied by the product of ``spline_points``. The final output layer is
    multiplied as well when ``out_spline`` is true; otherwise it is counted
    once. Empty stacks are skipped, and each non-empty stack ends at its last
    listed width.
    """
    shared = 0   # counted once
    gridded = 0  # multiplied by the spline grid size
    width = input_size
    stages = ((encoder_layers, False), (hidden_layers, True), (decoder_layers, False))
    for stage_layers, on_grid in stages:
        if not stage_layers:
            continue
        count = num_nn_params(width, stage_layers[-1], stage_layers[:-1], use_bias)
        if on_grid:
            gridded += count
        else:
            shared += count
        width = stage_layers[-1]
    head = num_nn_params(width, output_size, [], use_bias)
    if out_spline:
        gridded += head
    else:
        shared += head
    grid_size = 1
    for points in spline_points:
        grid_size = grid_size * points
    return shared + gridded * grid_size

def flops_spline(spline_points, nn_params):
    """Estimate the FLOPs of cubic-spline interpolation of ``nn_params`` weights.

    Each dimension is capped at 4 points and dimensions are ordered from
    largest to smallest. Only dimensions with more than one point are counted.

    - Weight computation: 17 FLOPs per counted dimension, plus ``7 * p`` for a
      dimension with 2 or 3 points (where weights are aggregated).
    - Interpolation, per parameter and per counted dimension ``i``: with
      ``a = prod(points[i:])``, add ``a`` multiplications plus
      ``(a // p) * (p - 1)`` additions.
    """
    dims = sorted((min(p, 4) for p in spline_points), reverse=True)
    active = [(i, p) for i, p in enumerate(dims) if p > 1]
    weight_flops = 0
    for _, p in active:
        weight_flops += 17
        if p < 4:
            weight_flops += 4 * p + 3 * p
    per_param = 0
    for i, p in active:
        span = np.prod(dims[i:])
        per_param += span + (span // p) * (p - 1)
    return weight_flops + per_param * nn_params

def flops_linear(spline_points, nn_params):
    """Estimate the FLOPs of multilinear interpolation of ``nn_params`` weights.

    With ``d`` dimensions having more than one point, the cost is ``3 * d``
    plus ``2 ** d * nn_params``. The second term is presumably one weighted
    term per corner of the interpolation cell.
    """
    active_dims = len([d for d in spline_points if d > 1])
    corners = 2 ** active_dims
    return 3 * active_dims + corners * nn_params

def flops_constant(spline_points):
    """Estimate the FLOPs of constant (nearest) interpolation.

    The cost is one FLOP per dimension with more than one point and does not
    depend on the network size.
    """
    return len([d for d in spline_points if d > 1])

def flops_nn(input_size, output_size, hidden_layers, use_bias=True):
    """Estimate the FLOPs of one forward pass of a fully-connected network.

    Each layer costs ``2 * fan_in * fan_out`` (a multiply and an add per
    weight), plus ``fan_out`` for the bias additions when ``use_bias`` is true.
    """
    widths = [input_size, *hidden_layers, output_size]
    total = 0
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        total += 2 * (fan_in * fan_out)
        if use_bias:
            total += fan_out
    return total

def test_run_time(model_file, input_size, spline_dim, mode):
    """Time a model with the external ``test_model`` benchmark executable.

    Runs ``test_model/test_model.exe``, located next to this script, passing
    the resolved model path, input size, spline dimensionality and mode as
    command-line arguments. Standard output is parsed as a float.

    Returns ``math.nan`` in any of these cases:
    - the benchmark cannot be started (missing or not executable);
    - it exits with a non-zero status;
    - it prints something that is not a single number.

    NOTE: the ``test_`` prefix means pytest would collect this function if the
    module were ever imported from a test file.
    """
    test_exe = Path(Path(__file__).parent, 'test_model', 'test_model.exe').resolve()
    model_path = Path(model_file).resolve()
    args = [str(test_exe), str(model_path), str(input_size), str(spline_dim), str(mode)]
    try:
        p = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
    except OSError:
        # Missing/non-executable benchmark: report "no timing" like any other failure
        return math.nan
    if p.returncode != 0:
        return math.nan
    try:
        return float(p.stdout)
    except ValueError:
        return math.nan

def get_model_metadata(model_files):
    """Collect size, cost and timing information for each exported model.

    Args:
        model_files: sequence of ``(model_name, model_file, model_mode)``
            triples. ``model_file`` is a serialized TensorFlow ``GraphDef``.
            ``model_mode`` is stripped and upper-cased, and must then be
            ``CUBIC``, ``LINEAR`` or ``CONSTANT``.

    Returns:
        DataFrame indexed by model name (index name ``'model'``). It holds:
        - the graph's ``Meta/*`` values;
        - derived parameter counts;
        - FLOP estimates (``flops_spline`` / ``flops_linear`` /
          ``flops_constant`` plus ``flops_nn``);
        - the run time reported by ``test_run_time``.
        List-valued fields are stored as space-separated strings.

    Raises:
        ValueError: if a model mode is not one of the three known modes.
    """
    model_names = [model_name for model_name, _, _ in model_files]
    # Column names and dtypes, unzipped into make_empty_dataframe(columns, dtypes)
    df = make_empty_dataframe(*zip(('input_size', int),
                                   ('output_size', int),
                                   ('hidden_layers', object),
                                   ('use_bias', bool),
                                   ('parameters_nn', int),
                                   ('spline_points', object),
                                   ('parameters_spline', int),
                                   ('mode', object),
                                   ('grid_evaluations', object),
                                   ('parameters_grid', object),
                                   ('flops_interp', int),
                                   ('flops_nn', int),
                                   ('flops_total', int),
                                   ('run_time', float)),
                              index=model_names)
    df.index.name = 'model'
    for i, (model_name, model_file, model_mode) in enumerate(model_files):
        model_mode = model_mode.strip().upper()
        # Fresh graph and session per model, so each import starts from an empty graph (TF1 API)
        with tf.Graph().as_default(), tf.Session() as sess:
            # Load graph
            gd = tf.GraphDef()
            with Path(model_file).open('rb') as f:
                gd.MergeFromString(f.read())
            # Mode-specific tensors carry a title-cased suffix, e.g. 'OutputCubic:0'.
            # Only the Meta/* tensors (collected in `metadata`) are evaluated below;
            # the order of these names must match the unpacking of sess.run(metadata).
            inp, coords, out, *metadata = tf.import_graph_def(gd, name='', return_elements=[
                'Input:0', 'Coords:0', f'Output{model_mode.title()}:0',
                'Meta/NumParameters:0', 'Meta/SplinePoints:0', f'Meta/GridEvaluations{model_mode.title()}:0',
                'Meta/HiddenLayers:0', 'Meta/UseBias:0', 'Meta/InputSize:0', 'Meta/OutputSize:0'
            ])
            # Get metadata
            params_spline, spline_points, grid_evals, hidden_layers, use_bias, input_size, output_size = sess.run(metadata)
            # Single-network parameter count. Assumes Meta/NumParameters counts one copy
            # per spline control point (cf. num_gfnn_params). TODO confirm.
            params_nn = params_spline // np.prod(spline_points)
            # One network copy per grid evaluation point (interpretation inferred from the name; confirm)
            params_grid = params_nn * np.prod(grid_evals)
            # Flops estimation
            if model_mode == 'CUBIC':
                flops_interp = flops_spline(spline_points, params_nn)
            elif model_mode == 'LINEAR':
                flops_interp = flops_linear(spline_points, params_nn)
            elif model_mode == 'CONSTANT':
                # Constant-mode interpolation cost does not depend on the network size
                flops_interp = flops_constant(spline_points)
            else:
                raise ValueError(f'unknown model mode "{model_mode}"')
            flops_net = flops_nn(input_size, output_size, hidden_layers, use_bias)
            flops_total = flops_interp + flops_net
            # test_run_time returns NaN when the benchmark exits non-zero or prints a non-number
            run_time = test_run_time(model_file, input_size, len(spline_points), model_mode.title())
            # Row values follow the column order declared above
            df.loc[model_name, :] = [input_size, output_size, ' '.join(map(str, hidden_layers)), use_bias, params_nn,
                                     ' '.join(map(str, spline_points)), params_spline,
                                     model_mode, ' '.join(map(str, grid_evals)), params_grid,
                                     flops_interp, flops_net, flops_total, run_time]
    return df

def get_error_statistics(errors):
    """Summarise ``errors`` as a single-row DataFrame.

    Columns: ``mean``, ``std`` (population, via ``np.std``), ``min``, ``max``,
    and the 25th/50th/75th percentiles (``p25``, ``p50``, ``p75``). All are
    reduced along axis 0.
    """
    low, median, high = np.percentile(errors, [25, 50, 75], axis=0)
    summary = {
        'mean': np.mean(errors, axis=0),
        'std': np.std(errors, axis=0),
        'min': np.min(errors, axis=0),
        'max': np.max(errors, axis=0),
        'p25': low,
        'p50': median,
        'p75': high,
    }
    return pd.DataFrame([list(summary.values())], columns=list(summary))

def num_rows_cols(n):
    """Pick a near-square grid layout with exactly ``n`` cells.

    Finds the largest divisor ``a`` of ``n`` that does not exceed ``sqrt(n)``.
    Returns ``(min(a, n // a), max(a, n // a))``, so ``rows <= cols`` and
    ``rows * cols == n``. A prime ``n`` yields ``(1, n)``.

    Raises:
        ValueError: if ``n`` is less than 1. Previously this surfaced as an
            ``UnboundLocalError`` (``n == 0``) or a math domain error (``n < 0``).
    """
    if n < 1:
        raise ValueError(f'number of cells must be at least 1, got {n}.')
    a = 1
    for candidate in range(int(math.sqrt(n)), 0, -1):
        if n % candidate == 0:
            a = candidate
            break
    b = n // a
    return min(a, b), max(a, b)

def main(data_files, model_files, out_file, per_data, anim_dir, anim_layout, anim_layout_vert, anim_dpi, error_bone):
    """Evaluate models on recorded data; optionally export statistics and animations.

    For every recording in ``data_files``, each model in ``model_files`` is run
    with ``simulate``. The simulation is seeded with the recording's
    first-frame joint and root transforms and driven by its player sword
    transforms and styles. The simulated poses are scored against the
    recording with ``simulation_error``.

    Args:
        data_files: sequence of ``(data_name, csv_file)`` pairs.
        model_files: sequence of ``(model_name, model_file, model_mode)`` triples.
        out_file: Excel workbook to write results to; falsy to skip.
        per_data: if true, also write one statistics sheet per recording.
        anim_dir: directory for ``<data_name>.mp4`` animations; falsy to skip.
        anim_layout: ``(rows, cols)`` subplot grid. When falsy, a grid with
            exactly one cell per model plus one for the recording is chosen
            by ``num_rows_cols``.
        anim_layout_vert: fill the subplot grid column by column instead of
            row by row.
        anim_dpi: resolution of the animation figures and videos.
        error_bone: if true, colour each bone by its own error (saturating at
            ``ERROR_BONE_SATURATION``); otherwise colour by the frame's total
            error (saturating at ``ERROR_FRAME_SATURATION``).
    """
    data_names = [data_name for data_name, _ in data_files]
    model_names = [model_name for model_name, _, _ in model_files]
    data_df = make_empty_dataframe(['num_frames'], [int], index=data_names)
    data_df.index.name = 'data'
    model_df = get_model_metadata(model_files)
    data_fps = []
    data_frame_errors = []
    data_anim = []
    for data_name, data_file in data_files:
        print(f'Reading data {data_name}...')
        df = pd.read_csv(data_file, index_col=0)
        num_frames = len(df)
        data_df.loc[data_name, :] = num_frames
        npc_joint_tx = get_npc_joint_transforms(df)
        root_tx = get_root_transforms(df)
        player_sword_tx, npc_sword_tx = get_sword_transforms(df)
        styles = get_styles(df)
        # Playback rate from the mean interval between recorded timestamps
        anim_fps = 1. / df[TIME_COLUMN].diff().mean()
        data_fps.append(anim_fps)
        # One column of per-frame error per model
        error_df = make_empty_dataframe(model_names, [float] * len(model_files), index=range(num_frames))
        npc_simulation_tx = []
        root_simulation_tx = []
        sim_error = []
        for model_name, model_file, model_mode in model_files:
            print(f'Simulating {model_name} on {data_name}...')
            npc_model_tx, root_model_tx, _coords_model = simulate(
                player_sword_tx, styles, npc_joint_tx[0], root_tx[0], model_file, model_mode)
            model_error = simulation_error(npc_joint_tx, root_tx, npc_model_tx, root_model_tx)
            # Per-frame error: sum over axis 1 (presumably the joints)
            model_frame_error = model_error.sum(1)
            error_df[model_name] = model_frame_error
            npc_simulation_tx.append(npc_model_tx)
            root_simulation_tx.append(root_model_tx)
            sim_error.append(model_error)
        # Long format: one row per (model, frame), tagged with the recording name
        data_frame_errors.append(error_df.melt(var_name='model', value_name='error'))
        data_frame_errors[-1].insert(0, 'data', data_name)
        # Animation
        if anim_dir:
            animation_funcs = []
            if anim_layout:
                rows, cols = anim_layout
            else:
                rows, cols = num_rows_cols(len(model_files) + 1)
            fig = plt.figure(figsize=(5 * cols, 5 * rows), dpi=anim_dpi)
            # Grid cell 0 (subplot 1) shows the recording itself
            ax = fig.add_subplot(rows, cols, 1, projection='3d')
            ax.set_title(data_name)
            animation_funcs.append(make_anim_function(ax, player_sword_tx, npc_joint_tx, root_tx))
            for i, (model_name, model_file, _) in enumerate(model_files):
                # Model i occupies grid cell i + 1 (0-based), right after the recording
                cell = i + 1
                if anim_layout_vert:
                    # Column-major placement: fill each column top to bottom before moving right.
                    # Using `cell` (not `i`) keeps the first model off the recording's subplot.
                    plot_idx = (cell % rows) * cols + (cell // rows) + 1
                else:
                    plot_idx = cell + 1
                simulation_ax = fig.add_subplot(rows, cols, plot_idx, projection='3d')
                simulation_ax.set_title(model_name)
                if error_bone:
                    anim_simulation_temperature = np.clip(sim_error[i] / ERROR_BONE_SATURATION, 0, 1)
                else:
                    # Every bone of a frame gets that frame's total error
                    anim_simulation_temperature = sim_error[i].copy()
                    anim_simulation_temperature[:] = anim_simulation_temperature.sum(1, keepdims=True)
                    anim_simulation_temperature = np.clip(anim_simulation_temperature / ERROR_FRAME_SATURATION, 0, 1)
                animation_funcs.append(make_anim_function(
                    simulation_ax, player_sword_tx, npc_simulation_tx[i], root_simulation_tx[i], anim_simulation_temperature))
            anim_interval = 1000 / anim_fps
            anim_all_func = merge_anim_functions(animation_funcs)
            fig.tight_layout(pad=1.5)
            data_anim.append(mplanim.FuncAnimation(fig, anim_all_func, range(num_frames), interval=anim_interval, repeat=False, blit=True))
    # Save results
    if out_file:
        print('Saving results...')
        with pd.ExcelWriter(str(out_file)) as writer:
            if model_files:
                model_df.to_excel(writer, sheet_name='Models')
            if data_files:
                data_df.to_excel(writer, sheet_name='Data')
            if data_files and model_files:
                error_df = pd.concat(data_frame_errors, ignore_index=True)
                overall_df = (error_df.groupby(['model'])['error']
                              .apply(get_error_statistics)
                              .reset_index(level=-1, drop=True)
                              .loc[model_names, :])
                overall_df.to_excel(writer, sheet_name='Overall')
                if per_data:
                    data_stats = (error_df.groupby(['data', 'model'])['error']
                                  .apply(get_error_statistics)
                                  .reset_index(level=-1, drop=True))
                    for data_name, _ in data_files:
                        # Separate name so the per-recording summary frame (data_df) is not shadowed
                        sheet_df = data_stats.loc[data_name].loc[model_names]
                        sheet_df.to_excel(writer, sheet_name=data_name)
                error_df.to_excel(writer, index=False, sheet_name='Error')
    # Write animations
    if anim_dir:
        print('Animating...')
        # NOTE(review): matplotlib is documented as not thread-safe; saving several
        # animations concurrently may be fragile. Consider saving sequentially.
        with ThreadPool() as pool:
            pool.starmap(lambda data, fps, anim: anim.save(str(Path(anim_dir, f'{data[0]}.mp4')),
                                                      writer=mplanim.writers['ffmpeg'](fps=fps),
                                                      dpi=anim_dpi),
                         zip(data_files, data_fps, data_anim))
    print('Done.')


# if __name__ == '__main__':
#     main(*ARGS)
