import os
import sys
import math
import argparse
from pathlib import Path
from multiprocessing.pool import ThreadPool
import subprocess
import pickle

# Parse arguments before importing libraries
def parse_args(args=None):
    """Parse the command line of the model evaluation script.

    Args:
        args: argument list to parse; ``None`` lets argparse use ``sys.argv[1:]``.

    Returns:
        ``(data, models, per_data, write_error, out)``: ``data`` is a list of
        ``(name, Path)`` pairs, ``models`` a list of ``(name, Path, MODE)``
        triples, the two flags are booleans and ``out`` is the results directory.

    Raises:
        ValueError: if two data files or two models share the same name.
    """
    def ensure_unique(entries, message):
        # Names are used as keys of the result tables, so they must not clash.
        names = [entry[0] for entry in entries]
        if len(names) != len(set(names)):
            raise ValueError(message)

    parser = argparse.ArgumentParser(argument_default=None)
    parser.add_argument('--data', help='record data files', required=True, nargs='+',
                        metavar='[NAME:]PATH', type=named_path)
    parser.add_argument('--models', help='model files', required=True, nargs='+',
                        metavar='[NAME:]PATH[?MODE]', type=named_path_mode('CUBIC'))
    parser.add_argument('--per-data', help='write results for each data file', type=str2bool, default='true')
    parser.add_argument('--write-error', help='write detailed error data', type=str2bool, default='true')
    parser.add_argument('--out', help='results directory', required=True, type=Path)
    namespace = parser.parse_args(args=args)

    data = namespace.data or []
    ensure_unique(data, 'data file names must be unique.')
    models = namespace.models or []
    ensure_unique(models, 'model names must be unique.')
    return data, models, namespace.per_data, namespace.write_error, namespace.out

def named_path(v):
    """Parse an argparse value of the form ``[NAME:]PATH``.

    When ``NAME`` is omitted it defaults to the file stem of ``PATH``. Only the
    first ``:`` separates the name from the path.

    Args:
        v: the raw command-line string.

    Returns:
        ``(name, Path)`` tuple.

    Raises:
        argparse.ArgumentTypeError: if the path is empty or no name can be derived.
    """
    try:
        name, sep, path = v.partition(':')
        if not sep:
            name, path = '', v
        if not path:
            raise ValueError('empty path')
        path = Path(path)
        name = name or path.stem
        if not name:
            raise ValueError('empty name')
        return name, path
    except (AttributeError, TypeError, ValueError) as e:
        # Narrow catch: the original BaseException also swallowed Ctrl-C / SystemExit.
        raise argparse.ArgumentTypeError('value must have the format [NAME:]PATH.') from e

def named_path_mode(default_mode):
    """Build an argparse ``type`` callable for values ``[NAME:]PATH[?MODE]``.

    Args:
        default_mode: mode used when the value has no ``?MODE`` suffix.

    Returns:
        A function mapping the raw string to ``(name, Path, MODE)`` with the mode
        upper-cased; name/path are parsed by ``named_path``.
    """
    def f(v):
        try:
            name_path, sep, mode = v.partition('?')
            if not sep:
                mode = default_mode
            if not name_path:
                raise ValueError('empty path')
            name, path = named_path(name_path)
            return name, path, mode.upper()
        except (argparse.ArgumentTypeError, AttributeError, TypeError, ValueError) as e:
            # ArgumentTypeError from named_path is re-raised with this format's message.
            raise argparse.ArgumentTypeError('value must have the format [NAME:]PATH[?MODE].') from e
    return f

def layout_shape(v):
    """Parse an argparse value ``ROWS,COLS`` into a tuple of two positive ints.

    Raises:
        argparse.ArgumentTypeError: on malformed input or non-positive values.
    """
    try:
        rows, cols = map(int, v.split(','))
    except (AttributeError, TypeError, ValueError) as e:
        raise argparse.ArgumentTypeError('value must have the format ROWS,COLS.') from e
    if rows <= 0 or cols <= 0:
        raise argparse.ArgumentTypeError('value must have the format ROWS,COLS.')
    return rows, cols

def str2bool(v):
    """Convert a yes/no style token (case- and whitespace-insensitive) to bool.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    token = str(v).strip().lower()
    truthy = {'yes', 'true', 't', 'y', '1'}
    falsy = {'no', 'false', 'f', 'n', '0'}
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError('boolean value expected.')

if __name__ == '__main__':
    # Parse the command line before the heavy imports below (numpy, pandas,
    # tensorflow, matplotlib) so argument errors and --help are reported first.
    # ARGS is only defined when this file is executed as a script.
    ARGS = parse_args()

import numpy as np
import pandas as pd
import quaternion
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf

import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as mplanim
import seaborn as sns


# Name of the root transform in the recorded data (read by get_root_transforms).
ROOT = 'root'
HEAD = 'Head'
# Skeleton joints in the order used by the joint sections of the x/y feature
# vectors (3 values per joint, see make_segments).
JOINTS = [
    'Hips',
    'Spine',
    'Spine1',
    'Neck',
    'Head',
    'Head_end',
    'LeftShoulder',
    'LeftArm',
    'LeftForeArm',
    'LeftHand',
    'LeftHand_end',
    'RightShoulder',
    'RightArm',
    'RightForeArm',
    'RightHand',
    'RightHand_end',
    'LeftUpLeg',
    'LeftLeg',
    'LeftFoot',
    'LeftFoot_end',
    'RightUpLeg',
    'RightLeg',
    'RightFoot',
    'RightFoot_end',
    'Tail',
    'Tail1',
    'Tail1_end',
]
# Parent of every joint; 'root' has no parent. Each joint appears exactly once
# (a duplicated 'RightForeArm' entry was removed).
JOINT_PARENTS = {
    'root': None,
    'Hips': 'root',
    'Spine': 'Hips',
    'Spine1': 'Spine',
    'Neck': 'Spine1',
    'Head': 'Neck',
    'Head_end': 'Head',
    'LeftShoulder': 'Spine1',
    'LeftArm': 'LeftShoulder',
    'LeftForeArm': 'LeftArm',
    'LeftHand': 'LeftForeArm',
    'LeftHand_end': 'LeftHand',
    'RightShoulder': 'Spine1',
    'RightArm': 'RightShoulder',
    'RightForeArm': 'RightArm',
    'RightHand': 'RightForeArm',
    'RightHand_end': 'RightHand',
    'LeftUpLeg': 'Hips',
    'LeftLeg': 'LeftUpLeg',
    'LeftFoot': 'LeftLeg',
    'LeftFoot_end': 'LeftFoot',
    'RightUpLeg': 'Hips',
    'RightLeg': 'RightUpLeg',
    'RightFoot': 'RightLeg',
    'RightFoot_end': 'RightFoot',
    'Tail': 'Hips',
    'Tail1': 'Tail',
    'Tail1_end': 'Tail1',
}
# Indices into JOINTS of the four limb end joints: LeftHand_end (10),
# RightHand_end (15), LeftFoot_end (19), RightFoot_end (23) — presumably the
# paws. Keep in sync if JOINTS is reordered.
FOOT_IDX = [10, 15, 19, 23]
# Presumably the joint treated as centre of gravity (not referenced in this view).
COG = 'Hips'
# STYLES = ['Stand', 'Walk', 'Sit', 'Down', 'Up', 'Jump']
# STYLES = ['Stand', 'Walk', 'Gallop']
# Style labels; the input vector has len(STYLES) values per trajectory sample
# (see x_style_start/x_style_end).
STYLES = ['Stand', 'Walk', 'Trot', 'Gallop']
STANDING_PHASE_MIN_CYCLE = 6  # 1/4 s @ 24 fps
# Number of past / future trajectory samples encoded in the feature vectors.
PAST_TRAJECTORY_WINDOW = 6
FUTURE_TRAJECTORY_WINDOW = 6
TRAJECTORY_STEP_SIZE = 4  # 1/6 s @ 24 fps
# Feature toggles: when False, the terrain (x) / contact (y) sections have zero width.
WITH_TERRAIN = False
WITH_CONTACT = False
# NOTE(review): usage of the following thresholds is not visible in this chunk.
STANDING_SPEED_THRESHOLD = 0.1
JOINT_SMOOTH = 0.5
ERROR_BONE_SATURATION = 1_000
ERROR_FRAME_SATURATION = 10_000

# Column names in the recorded data (presumably; usage not visible here).
TIME_COLUMN = ':time'
STANDING_COLUMN = ':notify:PfnnStanding'
STYLE_COLUMN = ':notify:PfnnStyle:Name'
SWORD_BLOCK = ':notify:SwordBlock'

STEP_NOTIFY = 'DogStep'
# Size of the phase vector fed to the models and of y's phase section.
NUM_PHASES = 4
STANDING = 'PfnnStanding'
STYLE_NOTIFY = 'DogStyle:Name'

# Feature-vector layouts. Each *_start/*_end pair is a half-open [start, end)
# column slice of the model input `x` (x_*) or output `y` (y_*); every section
# begins where the previous one ends, so the order of these lines is the layout.

# --- Input x: trajectory section ---
# Past/future positions and directions: 2 values per trajectory sample
# (get_angvel reshapes the future positions to (..., samples, 2)).
x_trajectory_start = 0
x_past_pos_start = x_trajectory_start
x_past_pos_end = x_past_pos_start + 2 * PAST_TRAJECTORY_WINDOW
x_future_pos_start = x_past_pos_end
x_future_pos_end = x_future_pos_start + 2 * FUTURE_TRAJECTORY_WINDOW
x_past_dir_start = x_future_pos_end
x_past_dir_end = x_past_dir_start + 2 * PAST_TRAJECTORY_WINDOW
x_future_dir_start = x_past_dir_end
x_future_dir_end = x_future_dir_start + 2 * FUTURE_TRAJECTORY_WINDOW
# len(STYLES) values per trajectory sample (presumably style weights — TODO confirm).
x_style_start = x_future_dir_end
x_style_end = x_style_start + len(STYLES) * (PAST_TRAJECTORY_WINDOW + FUTURE_TRAJECTORY_WINDOW)
x_trajectory_end = x_style_end
# --- Input x: joint section (3 values per joint, in JOINTS order) ---
x_joint_start = x_trajectory_end
x_joint_pos_start = x_joint_start
x_joint_pos_end = x_joint_pos_start + 3 * len(JOINTS)
x_joint_vel_start = x_joint_pos_end
x_joint_vel_end = x_joint_vel_start + 3 * len(JOINTS)
x_joint_end = x_joint_vel_end
# --- Input x: terrain section (3 values per sample, zero-width unless WITH_TERRAIN) ---
x_terr_start = x_joint_end
x_past_terr_start = x_terr_start
x_past_terr_end = x_past_terr_start + (3 * PAST_TRAJECTORY_WINDOW if WITH_TERRAIN else 0)
x_future_terr_start = x_past_terr_end
x_future_terr_end = x_future_terr_start + (3 * FUTURE_TRAJECTORY_WINDOW if WITH_TERRAIN else 0)
x_terr_end = x_future_terr_end
x_end = x_terr_end
input_size = x_end

# Length of the separate phase vector fed alongside x.
phase_size = NUM_PHASES

# --- Output y ---
# Root velocity (2 values) and root angular velocity (1 value).
y_root_vel_start = 0
y_root_vel_end = y_root_vel_start + 2
y_root_angvel_start = y_root_vel_end
y_root_angvel_end = y_root_angvel_start + 1
# Phase section; run_model adds it to the current phase when running autoregressively.
y_phase_start = y_root_angvel_end
y_phase_end = y_phase_start + NUM_PHASES
# Foot contacts (4 values, zero-width unless WITH_CONTACT).
y_contact_start = y_phase_end
y_contact_end = y_contact_start + (4 if WITH_CONTACT else 0)
# Future trajectory positions and directions (2 values per sample).
y_trajectory_start = y_contact_end
y_future_pos_start = y_trajectory_start
y_future_pos_end = y_future_pos_start + 2 * FUTURE_TRAJECTORY_WINDOW
y_future_dir_start = y_future_pos_end
y_future_dir_end = y_future_dir_start + 2 * FUTURE_TRAJECTORY_WINDOW
y_trajectory_end = y_future_dir_end
# Joint positions / velocities (3 values per joint, same order as x's joint section).
y_joint_start = y_trajectory_end
y_joint_pos_start = y_joint_start
y_joint_pos_end = y_joint_pos_start + 3 * len(JOINTS)
y_joint_vel_start = y_joint_pos_end
y_joint_vel_end = y_joint_vel_start + 3 * len(JOINTS)
# 6 values per joint — presumably two 3D axis vectors per joint (cf.
# make_joint_transform(pos, x, y)); TODO confirm.
y_joint_xy_start = y_joint_vel_end
y_joint_xy_end = y_joint_xy_start + 6 * len(JOINTS)
y_joint_end = y_joint_xy_end
y_end = y_joint_end
output_size = y_end


# Mask of complete walk patterns (no padding blanks), kept for reference:
# m = (np.core.defchararray.find(walk_pat, ' ') == -1)

# Frame tolerance for treating foot contacts as simultaneous; presumably passed as
# `thres` to walk_pattern elsewhere. The misspelled name is kept for compatibility.
WALK_PATTERN_TRESHOLD = 2

def walk_pattern(phases, thres=0):
    """Label each frame with the foot-contact ordering pattern in effect.

    A foot is considered to make contact at frame ``i`` when its phase value
    decreases from frame ``i`` to ``i + 1``. Whenever some foot makes contact,
    the pattern is recomputed with ``make_walk_pattern``.

    Args:
        phases: array of shape (n_frames, n_feet) with per-foot phase values.
        thres: frame tolerance for grouping near-simultaneous contacts.

    Returns:
        Unicode array of length n_frames with strings of width n_feet; frames
        before any contact hold only blanks. Empty input yields an empty array.
    """
    phases = np.asarray(phases)
    n_foot = phases.shape[1]
    walk_pat = np.zeros(len(phases), dtype=f'U{n_foot}')
    if len(phases) == 0:
        # walk_pat[-1] below would raise IndexError on an empty array.
        return walk_pat
    contacts = np.diff(phases, axis=0) < 0
    current = ' ' * n_foot
    last_contact = [-1] * n_foot
    for i, contact in enumerate(contacts):
        if contact.any():
            for i_foot in np.flatnonzero(contact):
                last_contact[i_foot] = i
            current, last_contact = make_walk_pattern(last_contact, thres)
        walk_pat[i] = current
    # The last frame has no following frame to diff against; it keeps the last pattern.
    walk_pat[-1] = current
    return walk_pat

def make_walk_pattern(last_contact, thres=0):
    """Build a foot-ordering pattern from the last contact frame of each foot.

    Feet that have made contact (frame >= 0) and whose contact frames are within
    ``thres`` of each other are connected. Connected components form groups of
    simultaneous feet. Each group is snapped to its latest contact frame, and
    groups are ordered by that frame. The order is then rotated so that the group
    containing the lowest-indexed contacted foot comes first. Feet inside a group
    are listed in index order.

    Args:
        last_contact: sequence of per-foot contact frames; negative means the
            foot has not made contact yet.
        thres: maximum frame distance for two contacts to count as simultaneous.

    Returns:
        ``(pattern, new_contact)``. ``pattern`` is the concatenated foot indices,
        blank-padded to ``len(last_contact)``. ``new_contact`` is a new list in
        which grouped feet share their group's latest frame.
    """
    n = len(last_contact)
    if not any(f >= 0 for f in last_contact):
        # No contacts yet: all-blank pattern, contacts unchanged.
        return ' ' * n, list(last_contact)
    # Undirected adjacency. With one-directional edges, a foot could end up in
    # several groups and the pattern would repeat feet.
    adjacency = [[] for _ in range(n)]
    for i1 in range(n):
        f1 = last_contact[i1]
        if f1 < 0:
            continue
        for i2 in range(i1 + 1, n):
            f2 = last_contact[i2]
            if f2 >= 0 and abs(f1 - f2) <= thres:
                adjacency[i1].append(i2)
                adjacency[i2].append(i1)
    groups = []
    assigned = set()
    first_idx = None
    for start in range(n):
        if last_contact[start] < 0 or start in assigned:
            continue
        if first_idx is None:
            first_idx = start
        members = []
        stack = [start]
        assigned.add(start)
        while stack:
            node = stack.pop()
            members.append(node)
            for neighbour in adjacency[node]:
                if neighbour not in assigned:
                    assigned.add(neighbour)
                    stack.append(neighbour)
        groups.append((sorted(members), max(last_contact[k] for k in members)))
    groups.sort(key=lambda group: group[1])
    g0_idx = next(k for k, (members, _) in enumerate(groups) if first_idx in members)
    pattern = []
    new_contact = list(last_contact)
    for offset in range(len(groups)):
        members, group_f = groups[(g0_idx + offset) % len(groups)]
        for foot in members:
            pattern.append(foot)
            new_contact[foot] = group_f
    pad = ' ' * (n - len(pattern))
    return ''.join(map(str, pattern)) + pad, new_contact

def find_longest_sequence(a):
    """Return ``(start, end)`` of the longest run of truthy values in ``a``.

    ``end`` is exclusive. On ties the earliest run wins; ``(0, 0)`` is returned
    when ``a`` contains no truthy value.
    """
    framed = np.pad(np.asarray(a, dtype=bool), (1, 1), 'constant')
    # Change points come in (rise, fall) pairs thanks to the False padding.
    bounds = np.nonzero(np.diff(framed) != 0)[0].reshape(-1, 2)
    if not len(bounds):
        return (0, 0)
    run_lengths = np.diff(bounds, axis=1)
    return tuple(bounds[run_lengths.argmax()])

def window_max(a, window):
    """Sliding-window maximum of the flattened array ``a``, same length as ``a``.

    Each window is centred (left-biased for even sizes); the edges are filled
    by repeating the first/last window maximum.
    """
    flat = a.ravel()
    (step,) = flat.strides
    assert window >= 1 and len(flat) >= window
    n_windows = len(flat) - window + 1
    # Read-only strided view: row k is flat[k:k + window], no copy made.
    views = np.lib.stride_tricks.as_strided(flat, [n_windows, window], [step, step], writeable=False)
    maxima = views.max(1)
    left = [maxima[0]] * (window // 2)
    right = [maxima[-1]] * ((window - 1) // 2)
    return np.r_[left, maxima, right]

# parents = [JOINTS.index(JOINT_PARENTS.get(j, None)) if JOINT_PARENTS.get(j, None) in JOINTS else -1 for j in JOINTS]

def plot_skel(joint_pos, parents, ax=None, lines=None):
    """Draw (or update) a stick-figure skeleton on a 3D matplotlib axis.

    Args:
        joint_pos: (n_joints, 3) joint positions.
        parents: parent index per joint; negative means no bone is drawn.
        ax: 3D axis to draw into; a new figure with fitted limits if None.
        lines: Line3D artists from a previous call, reused for updating.

    Returns:
        ``(ax, lines)``.
    """
    if ax is None:
        plt.figure()
        ax = plt.subplot(111, projection='3d')
        centre = np.mean(joint_pos, axis=0)
        offsets = joint_pos - centre
        lo = centre + offsets.min() * 1.1
        hi = centre + offsets.max() * 1.1
        for set_lim, k in ((ax.set_xlim, 0), (ax.set_ylim, 1), (ax.set_zlim, 2)):
            set_lim(lo[k], hi[k])
        ax.set_aspect('equal')
    if lines is None:
        lines = [ax.plot3D([], [], [], c='k', lw=1)[0] for _ in parents]
    for child, (line, parent) in enumerate(zip(lines, parents)):
        if parent >= 0:
            bone = [parent, child]
            line.set_data(joint_pos[bone, 0], joint_pos[bone, 1])
            line.set_3d_properties(joint_pos[bone, 2])
    return ax, lines

def make_animation_skel(joint_pos, parents, fps=24, ax=None):
    """Animate a skeleton over the frames of ``joint_pos`` (frame, joint, xyz).

    Returns:
        A non-repeating, blitted matplotlib FuncAnimation.
    """
    ax, lines = plot_skel(joint_pos[0], parents, ax=ax)

    def draw(frame):
        return plot_skel(joint_pos[frame], parents, ax=ax, lines=lines)[1]

    return mplanim.FuncAnimation(ax.figure, draw, range(len(joint_pos)),
                                 interval=1000. / fps, repeat=False, blit=True)

# for i, k in enumerate(walk_kinds):
#     make_animation(walk_kinds_longest[i], parents, fps=24).save(f'{k}_example.mp4', writer=mplanim.writers['ffmpeg'](fps=24), dpi=100)
#     plt.close()


def plot_traj(traj_points, ax=None, lines=None):
    """Draw (or update) a 2D trajectory polyline.

    Args:
        traj_points: (n_points, >=2) array; columns 0 and 1 are plotted.
        ax: 2D axis to draw into; a new figure with fitted limits if None.
        lines: one-element list of a Line2D from a previous call, reused.

    Returns:
        ``(ax, lines)``.
    """
    if ax is None:
        plt.figure()
        ax = plt.subplot(111)
        mean = np.mean(traj_points, axis=0)
        d = traj_points - mean
        low, high = mean + d.min() * 1.1, mean + d.max() * 1.1
        ax.set_xlim(low[0], high[0])
        ax.set_ylim(low[1], high[1])
        ax.set_aspect('equal')
    if lines is None:
        # Only (x, y): an extra positional list would be parsed by matplotlib as a
        # second y-only data group and create a stray empty Line2D.
        lines = [ax.plot([], [], c='k', lw=1)[0]]
    lines[0].set_data(traj_points[:, 0], traj_points[:, 1])
    return ax, lines

def make_animation_traj(traj_points, fps=24, ax=None):
    """Animate 2D trajectories, one frame per element of ``traj_points``.

    Args:
        traj_points: array indexed (frame, point, coord); axis limits are fitted
            to all frames at once.
        fps: playback rate (frame interval is 1000 / fps ms).
        ax: optional 2D axis to draw into.

    Returns:
        A non-repeating, blitted matplotlib FuncAnimation.
    """
    ax, lines = plot_traj(traj_points[0], ax=ax)
    centre = np.mean(traj_points, axis=(0, 1))
    spread = traj_points - centre
    lo = centre + spread.min() * 1.1
    hi = centre + spread.max() * 1.1
    ax.set_xlim(lo[0], hi[0])
    ax.set_ylim(lo[1], hi[1])
    ax.set_aspect('equal')

    def draw(frame):
        return plot_traj(traj_points[frame], ax=ax, lines=lines)[1]

    return mplanim.FuncAnimation(ax.figure, draw, range(len(traj_points)),
                                 interval=1000. / fps, repeat=False, blit=True)

def rotation_matrix_from_yaw(yaw):
    """Rotation matrices about the z axis for angle(s) ``yaw`` in radians.

    Returns:
        Array of shape ``np.shape(yaw) + (3, 3)``.
    """
    cos_y, sin_y = np.cos(yaw), np.sin(yaw)
    one, zero = np.ones_like(yaw), np.zeros_like(yaw)
    rows = (
        (cos_y, -sin_y, zero),
        (sin_y, cos_y, zero),
        (zero, zero, one),
    )
    return np.stack([np.stack(row, axis=-1) for row in rows], axis=-2)

def get_joint_speeds(joint_pos, root_vel, root_angvel):
    """Per-frame joint displacement magnitudes, compensating for root motion.

    Each frame's joint positions are transformed by a planar rigid transform
    built from the PREVIOUS frame's root velocity (x/y translation) and root
    angular velocity (rotation about z). The result is compared with the
    previous frame's joint positions. Frame 0 uses zero root motion and is
    compared with itself.

    Args:
        joint_pos: joint positions, presumably (n, n_joints, 3) — TODO confirm.
        root_vel: (n, 2) root velocity per frame.
        root_angvel: (n,) root yaw change per frame, in radians (it is fed to
            sin/cos via rotation_matrix_from_yaw).

    Returns:
        (n, n_joints) Euclidean distances (not divided by the frame time).
    """
    n = len(root_vel)
    # Previous-frame positions; frame 0 uses itself (roll wraps the last frame in).
    joint_pos_prev = np.roll(joint_pos, 1, axis=0)
    joint_pos_prev[0] = joint_pos_prev[1]
    # Previous-frame root motion; none before frame 0.
    root_vel_prev = np.roll(root_vel, 1, axis=0)
    root_vel_prev[0] = 0
    root_angvel = np.roll(root_angvel, 1, axis=0)
    root_angvel[0] = 0
    # One 4x4 homogeneous transform per frame: yaw rotation + x/y translation.
    rel_roots = np.zeros((n, 4, 4))
    rel_roots[:, -1, -1] = 1
    rel_roots[:, :2, 3] = root_vel_prev
    rel_roots[:, :3, :3] = rotation_matrix_from_yaw(root_angvel)
    joint_pos_h = np.concatenate([joint_pos, np.ones_like(joint_pos[..., :1])], axis=-1)
    # Broadcast each frame's transform over its joints and drop the w component.
    joint_pos_rel = (rel_roots[:, np.newaxis] @ joint_pos_h[..., np.newaxis])[..., :-1, 0]
    return np.linalg.norm(joint_pos_rel - joint_pos_prev, axis=-1)

def peaks_cover(a):
    """Piecewise-linear envelope through the strict local maxima of ``a``.

    Peaks are samples where the signal strictly rises then strictly falls;
    outside the first/last peak the envelope is held constant (np.interp).
    """
    arr = np.asarray(a)
    steps = np.diff(arr.ravel())
    rising, falling = steps[:-1], steps[1:]
    peak_idx = np.nonzero((rising * falling < 0) & (rising > 0))[0] + 1
    return np.interp(np.arange(len(arr)), peak_idx, arr[peak_idx])


# Transforms are multiplied FROM THE LEFT: Transform * Object
# Rightmost transform is applied first: Transform3 * Transform2 * Transform1
def make_transform_relative(tx, root):
    """Express transform(s) ``tx`` relative to ``root``: ``inv(root) @ tx``, normalized."""
    root_inverse = np.linalg.inv(root)
    return normalize_transform(root_inverse @ tx)

def normalize_transform(tx):
    """Divide homogeneous 4x4 transform(s) by their bottom-right element."""
    matrices = np.asarray(tx)
    scale = matrices[..., 3:, 3:]
    return matrices / scale

def rotate_vector(v, tx):
    """Apply only the 3x3 linear part of (normalized) transform(s) ``tx`` to vector(s) ``v``."""
    rotation = normalize_transform(tx)[..., :3, :3]
    column = np.asarray(v)[..., np.newaxis]
    return (rotation @ column)[..., 0]

def transform_vector(v, tx):
    """Apply homogeneous 4x4 transform(s) ``tx`` to 3D point(s) ``v``.

    Points are lifted to homogeneous coordinates with w = 1 and projected back
    by dividing by the resulting w.
    """
    points = np.asarray(v)
    homogeneous = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1)
    transformed = (tx @ homogeneous[..., np.newaxis])[..., 0]
    return transformed[..., :3] / transformed[..., 3:]

def triangles_area(triangles):
    """Area of 3D triangle(s) given as (..., 3 vertices, 3 coords)."""
    tri = np.asarray(triangles)
    origin = tri[..., 0, :]
    edge_b = tri[..., 1, :] - origin
    edge_c = tri[..., 2, :] - origin
    # Half the norm of the cross product of two edges.
    return np.linalg.norm(np.cross(edge_b, edge_c), axis=-1) / 2

def segments_error(segments1, segments2):
    """Area-based distance between matching 3D segments.

    The quadrilateral ``s1[0], s1[1], s2[1], s2[0]`` is fanned into four
    triangles around the centroid of its four endpoints, and their areas are
    summed.

    Args:
        segments1, segments2: arrays of shape (..., 2, 3).

    Returns:
        Array of shape ``segments1.shape[:-2]``.
    """
    seg_a = np.asarray(segments1)
    seg_b = np.asarray(segments2)
    centre = np.mean(np.concatenate([seg_a, seg_b], axis=-2), axis=-2)
    edges = (
        (seg_a[..., 0, :], seg_a[..., 1, :]),
        (seg_a[..., 0, :], seg_b[..., 0, :]),
        (seg_b[..., 0, :], seg_b[..., 1, :]),
        (seg_a[..., 1, :], seg_b[..., 1, :]),
    )
    # Filled in place so the result keeps segments1's dtype.
    tri = np.empty(seg_a.shape[:-2] + (4, 3, 3), dtype=seg_a.dtype)
    for k, (p, q) in enumerate(edges):
        tri[..., k, 0, :] = p
        tri[..., k, 1, :] = q
        tri[..., k, 2, :] = centre
    return triangles_area(tri).sum(axis=-1)

def make_segments(pos):
    """Turn joint positions (frames, joints, 3) into bones (frames, bones, 2, 3).

    A bone exists for every joint whose parent is itself in JOINTS. Element 0
    is the joint and element 1 its parent.
    """
    pairs = [(child, JOINTS.index(JOINT_PARENTS[name]))
             for child, name in enumerate(JOINTS)
             if JOINT_PARENTS.get(name, -1) in JOINTS]
    children = [child for child, _ in pairs]
    parents = [parent for _, parent in pairs]
    seg = np.empty((len(pos), len(pairs), 2, 3), pos.dtype)
    seg[:, :, 0] = pos[:, children]
    seg[:, :, 1] = pos[:, parents]
    return seg

def model_score(model, mode, data, sync=1):
    """Per-frame, per-bone error of a model against ground truth.

    Args:
        model, mode, sync: forwarded to run_model.
        data: ``(x, phase, y)`` arrays as returned by get_data.

    Returns:
        (frames, bones) array from segments_error on the predicted vs. true
        joint positions.
    """
    x, p, y = data

    def bones(out):
        joints = out[:, y_joint_pos_start:y_joint_pos_end].reshape(len(out), -1, 3)
        return make_segments(joints)

    truth = bones(y)
    predicted = bones(run_model(model, mode, x, p, sync))
    return segments_error(truth, predicted)

def compare_models(datas, models, sync=1):
    """Evaluate every model on every data file.

    Args:
        datas: list of ``(name, record_path)`` pairs.
        models: list of ``(name, model_path, mode)`` triples.
        sync: ground-truth resynchronisation period passed to run_model.

    Returns:
        ``(all_df, all_data_df)``:

        * ``all_df`` is indexed by (Data, Frame) with (Model, Bone) columns and
          holds per-bone errors.
        * ``all_data_df`` is indexed by (Data, Frame) with Angle/Velocity columns
          from get_angvel.

        Data files without examples contribute no rows to ``all_df``.
    """
    js = [j for j in JOINTS if JOINT_PARENTS.get(j, -1) in JOINTS]
    data_names = [data_name for data_name, _ in datas]
    dfs = []
    # Names of the data files that actually produced a score frame. pd.concat
    # zips keys with objects, so passing all data_names would shift the labels
    # whenever an empty data file is skipped.
    score_names = []
    data_dfs = []
    for data_name, data_file in datas:
        data = get_data(data_file)
        num_frames = len(data[0]) if data else 0
        data_df = make_empty_dataframe(['Angle', 'Velocity'], [float, float], index=range(num_frames))
        data_dfs.append(data_df)
        if not data:
            continue
        ang, vel = get_angvel(data[0])
        data_df['Angle'] = ang
        data_df['Velocity'] = vel
        scores = {}
        for model_name, model_file, model_mode in models:
            print(f'Evaluating model {model_name} on {data_name}...')
            s = model_score(model_file, model_mode, data, sync)
            for i, j in enumerate(js):
                scores[(model_name, j)] = s[:, i]
        dfs.append(pd.DataFrame(scores))
        score_names.append(data_name)
    all_df = pd.concat(dfs, keys=score_names, names=['Data'])
    all_df.index.set_names(['Data', 'Frame'], inplace=True)
    all_df.columns.set_names(['Model', 'Bone'], inplace=True)
    all_data_df = pd.concat(data_dfs, keys=data_names, names=['Data'])
    all_data_df.index.set_names(['Data', 'Frame'], inplace=True)
    return all_df, all_data_df

def plot_scores(scores, ax=None):
    """Boxen-plot the distribution of values in each column of ``scores``.

    The columns are unstacked into long form. The first index level after
    unstacking (the outermost column level) becomes the x-axis category.

    Args:
        scores: DataFrame of scores (presumably model columns — TODO confirm).
        ax: axis to draw into; a new figure is created only when None.

    Returns:
        The axis drawn into.
    """
    if ax is None:
        # Only open a new figure when the caller did not supply an axis;
        # previously a blank figure was leaked whenever `ax` was given.
        plt.figure()
    df = scores.unstack().reset_index(0)
    df.columns = ['model', 'value']
    ax = sns.boxenplot(x='model', y='value', data=df, ax=ax)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
    return ax

def get_data(record):
    """Load every example of the TFRecord file(s) matching ``record``.

    Uses a TF1 graph/session pipeline with batches of 10000 examples.

    Args:
        record: path or glob pattern (converted with ``str``).

    Returns:
        ``[x, phase, y]`` NumPy arrays (see parse_example) concatenated over all
        batches, or ``[]`` if no example was read.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
        # shuffle=False: list_files shuffles matching files by default. That would
        # make frame order — and hence autoregressive runs in run_model and the
        # Frame index of the results — differ between runs.
        ds = (tf.data.Dataset.list_files(str(record), shuffle=False)
              .flat_map(tf.data.TFRecordDataset)
              .map(parse_example, num_parallel_calls=os.cpu_count())
              .batch(10000)
              .prefetch(1))
        it = ds.make_one_shot_iterator()
        it_next = it.get_next()
        data = []
        while True:
            try:
                data.append(sess.run(it_next))
            except tf.errors.OutOfRangeError:
                break
        # Regroup per tensor: [(x, p, y), ...] -> [x_all, p_all, y_all].
        return [np.concatenate(x) for x in zip(*data)]

def get_angvel(x):
    """Heading angle and path length of the future trajectory in each input row.

    Returns:
        ``(angle, length)``: the angle in degrees of the last future trajectory
        point (arctan2 of its second and first coordinates), and the summed
        distance between consecutive future trajectory points.
    """
    n_points = (x_future_pos_end - x_future_pos_start) // 2
    future = np.reshape(x[:, x_future_pos_start:x_future_pos_end], (-1, n_points, 2))
    last = future[:, -1]
    angle = np.rad2deg(np.arctan2(last[:, 1], last[:, 0]))
    hops = np.diff(future, axis=1)
    length = np.sum(np.linalg.norm(hops, axis=-1), axis=-1)
    return angle, length

def parse_example(example_proto):
    """Decode one serialized example into float32 ``(x, phase, y)`` tensors.

    Shapes are (input_size,), (phase_size,) and (output_size,).
    """
    # All shapes are written as 1-tuples; `(phase_size)` was a bare int.
    features = {'x': tf.FixedLenFeature((input_size,), tf.float32),
                'phase': tf.FixedLenFeature((phase_size,), tf.float32),
                'y': tf.FixedLenFeature((output_size,), tf.float32)}
    parsed_features = tf.parse_single_example(example_proto, features)
    return parsed_features['x'], parsed_features['phase'], parsed_features['y']

def run_model(model, mode, x, phase, sync=1):
    """Run a frozen model over a sequence of frames.

    Frames ``0, sync, 2*sync, ...`` are fed their ground-truth input and phase.
    The ``sync - 1`` frames after each are fed autoregressively: their input
    joint positions/velocities come from the previous prediction, and their
    phase is the previous phase plus the prediction's phase section.

    Args:
        model: path to a serialized GraphDef.
        mode: interpolation mode. 'MANN' graphs expose ``Input``/``Output``;
            other modes expose ``Input``, ``Coords`` and ``Output<Mode>``.
        x: (N, input_size) input frames.
        phase: (N, phase_size) phase vectors.
        sync: resynchronisation period; values <= 1 use ground truth on every frame.

    Returns:
        Array stacking the N model outputs.
    """
    mode = mode.strip().upper()
    is_mann = mode == 'MANN'
    with tf.Graph().as_default(), tf.Session() as sess:
        gd = tf.GraphDef()
        with open(model, 'rb') as f:
            gd.MergeFromString(f.read())
        if not is_mann:
            inp, coords, out = tf.import_graph_def(
                gd, name='', return_elements=['Input:0', 'Coords:0', f'Output{mode.title()}:0'])
        else:
            inp, out = tf.import_graph_def(
                gd, name='', return_elements=['Input:0', 'Output:0'])
            coords = None

        def evaluate(x_frame, p_frame):
            # MANN graphs have no Coords input, so the phase is not fed to them.
            feed = {inp: x_frame} if is_mann else {inp: x_frame, coords: p_frame}
            return sess.run(out, feed_dict=feed)

        n = len(x)
        period = max(sync, 1)
        res = []
        xi = np.zeros_like(x[0])
        pi = np.zeros_like(phase[0])
        for start in range(0, n, period):
            pi[:] = phase[start]
            res.append(evaluate(x[start], pi))
            for i in range(start + 1, min(start + period, n)):
                prev = res[-1]
                xi[:] = x[i]
                # The prediction uses the OUTPUT (y) layout, so read y slices and
                # write them into the x slices of the same size.
                xi[x_joint_pos_start:x_joint_pos_end] = prev[y_joint_pos_start:y_joint_pos_end]
                xi[x_joint_vel_start:x_joint_vel_end] = prev[y_joint_vel_start:y_joint_vel_end]
                pi += prev[y_phase_start:y_phase_end]
                res.append(evaluate(xi, pi))
        return np.stack(res)

def make_joint_transform(pos, x, y):
    """Build homogeneous 4x4 joint transforms from a position and two axis hints.

    ``x`` is normalized and used as the X axis. Z is the normalized ``x × y``,
    and the final Y axis is ``z × x``, giving a right-handed orthonormal frame
    (Gram-Schmidt style).

    Args:
        pos: positions, shape (..., 3).
        x, y: axis hint vectors, same shape as ``pos``. They are NOT modified.

    Returns:
        Array (..., 4, 4) with columns [x, y, z, pos] over a [0, 0, 0, 1] row.
    """
    pos = np.asarray(pos)
    s = np.shape(x)
    d = s[-1]
    xs = np.reshape(x, (-1, d))
    ys = np.reshape(y, (-1, d))
    # Out-of-place division: np.reshape may return a view of the caller's
    # array, so in-place `/=` would mutate the inputs (and fails for ints).
    xs = xs / np.maximum(np.linalg.norm(xs, axis=-1, keepdims=True), 1e-8)
    ys = ys / np.maximum(np.linalg.norm(ys, axis=-1, keepdims=True), 1e-8)
    zs = np.cross(xs, ys, axis=-1)
    zs /= np.maximum(np.linalg.norm(zs, axis=-1, keepdims=True), 1e-8)
    ys = np.cross(zs, xs, axis=-1)
    x = np.reshape(xs, s)
    y = np.reshape(ys, s)
    z = np.reshape(zs, s)
    tx_h = np.stack([x, y, z, pos], axis=-1)
    h = np.tile([[0, 0, 0, 1]], s[:-1] + (1, 1))
    tx = np.concatenate([tx_h, h], axis=-2)
    return tx

def get_transforms(df, joints, dtype=None):
    """Build homogeneous 4x4 component-space transforms per row of ``df``.

    For each joint, the columns ``{joint}:cs:position:{x,y,z}`` and
    ``{joint}:cs:quaternion:{w,x,y,z}`` are read.

    Args:
        df: DataFrame with one row per frame.
        joints: a joint name, a list of names, or a falsy value to use every
            joint that has a ``:cs:position:x`` column.
        dtype: optional dtype to cast the transforms to. It was previously
            ignored; ``None`` (default) keeps the computed dtype as before.

    Returns:
        A (len(df), 4, 4) array for a single joint name, else a list of them.
    """
    if not joints:
        joints = [col.split(':')[0] for col in df if col.endswith(':cs:position:x')]
    unpack = np.isscalar(joints)
    if unpack:
        joints = [joints]
    transforms = []
    for joint in joints:
        pos = df[[f'{joint}:cs:position:x',
                  f'{joint}:cs:position:y',
                  f'{joint}:cs:position:z']].values
        quat = df[[f'{joint}:cs:quaternion:w',
                   f'{joint}:cs:quaternion:x',
                   f'{joint}:cs:quaternion:y',
                   f'{joint}:cs:quaternion:z']].values
        quat = quaternion.as_quat_array(quat)
        rot = quaternion.as_rotation_matrix(quat)
        tx = np.concatenate([rot, pos[..., np.newaxis]], axis=2)
        tx = np.concatenate([tx, np.tile([0, 0, 0, 1], (len(tx), 1, 1))], axis=1)
        if dtype is not None:
            tx = tx.astype(dtype)
        transforms.append(tx)
    return transforms[0] if unpack else transforms

def fix_maya_rotation(tx):
    """Right-multiply the rotation part of ``tx`` by the Maya-to-UE axis change.

    Returns a new normalized array; the input is not modified, because
    normalize_transform returns a new array.
    """
    tx = normalize_transform(tx)
    dt = tx.dtype
    # Each matrix holds one convention's basis axes in its columns.
    ue_axes = np.array([[1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1]], dtype=dt)
    maya_axes = np.array([[0, 0, 1],
                          [-1, 0, 0],
                          [0, -1, 0]], dtype=dt)
    rotation = tx[..., :3, :3]
    tx[..., :3, :3] = rotation @ ue_axes @ np.linalg.inv(maya_axes)
    return tx

def get_root_transforms(df):
    """Per-frame transforms of the ROOT joint with the Maya axis fix applied."""
    root_tx = get_transforms(df, ROOT)
    return fix_maya_rotation(root_tx)

def joint_parent_index():
    """Index in JOINTS of each joint's parent, or -1 if the parent is not a JOINT."""
    def parent_of(joint):
        parent = JOINT_PARENTS.get(joint, None)
        return JOINTS.index(parent) if parent and parent in JOINTS else -1

    return [parent_of(joint) for joint in JOINTS]

def merge_anim_functions(anim_funcs):
    """Combine animation update callbacks into one that concatenates their artists."""
    def merged(frame):
        artists = []
        for anim_func in anim_funcs:
            artists.extend(anim_func(frame))
        return artists

    return merged

def make_empty_dataframe(columns, dtypes, index=None):
    """DataFrame of missing values with the given column names/dtypes and index."""
    series = [pd.Series(name=column, dtype=dtype, index=index)
              for column, dtype in zip(columns, dtypes)]
    return pd.concat(series, axis=1)

# Model analysis

def num_nn_params(input_size, output_size, hidden_layers, use_bias=True):
    """Number of weights (plus biases) of a fully connected network."""
    sizes = [input_size, *hidden_layers, output_size]
    total = 0
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        total += fan_in * fan_out + (fan_out if use_bias else 0)
    return total

def num_gfnn_params(spline_points, input_size, output_size, hidden_layers, use_bias=True):
    """Parameters of a grid of networks: one network per spline control point."""
    per_network = num_nn_params(input_size, output_size, hidden_layers, use_bias)
    grid_size = 1
    for points in spline_points:
        grid_size = grid_size * points
    return per_network * grid_size

def flops_spline(spline_points, nn_params):
    """Estimated FLOPs of cubic spline weight interpolation.

    Dimensions larger than 4 are clamped to 4 (the cubic support) and sorted in
    decreasing size. The cost is a fixed per-dimension weight computation plus,
    per parameter, a multiply-and-reduce over the remaining grid.
    """
    dims = sorted((min(p, 4) for p in spline_points), reverse=True)
    # Weight computation: 17 per active dimension, plus 4p + 3p aggregation
    # when the dimension has 2 or 3 points.
    cost = 0
    for p in dims:
        if p <= 1:
            continue
        cost += 17
        if p < 4:
            cost += 4 * p + 3 * p
    # Interpolation: multiply/accumulate over the suffix grid for each dimension.
    per_param = 0
    for i, p in enumerate(dims):
        if p <= 1:
            continue
        block = np.prod(dims[i:])
        per_param += block + (block // p) * (p - 1)
    return cost + per_param * nn_params

def flops_linear(spline_points, nn_params):
    """Estimated FLOPs of multilinear interpolation over the active dimensions."""
    active = [d for d in spline_points if d > 1]
    # 2**n corners are blended per parameter; 3 FLOPs per dimension for weights.
    return 3 * len(active) + 2 ** len(active) * nn_params

def flops_constant(spline_points):
    """Estimated FLOPs of nearest-point selection: one per active dimension."""
    return sum(1 for points in spline_points if points > 1)

def flops_mann(num_gated, nn_params):
    """Estimated FLOPs of blending ``num_gated`` expert weights (n mults, n-1 adds)."""
    return (2 * num_gated - 1) * nn_params

def flops_nn(input_size, output_size, hidden_layers, use_bias=True):
    """Estimated FLOPs of a fully connected forward pass (2 per MAC, +1 per bias)."""
    sizes = [input_size, *hidden_layers, output_size]
    total = 0
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        total += 2 * (fan_in * fan_out) + (fan_out if use_bias else 0)
    return total

def test_run_time(model_file, input_size, spline_dim, mode):
    """Benchmark a model with the external ``test_model/test_model_gen.exe`` tool.

    Args:
        model_file: path to the model GraphDef.
        input_size: size of the ``Input`` tensor.
        spline_dim: size of the ``Coords`` tensor; 0 omits it (MANN models).
        mode: output-tensor suffix, title-cased into ``Output<Mode>``.

    Returns:
        The float printed by the tool on stdout, or NaN when the tool cannot be
        started, exits non-zero, or prints something unparsable.
    """
    test_exe = Path(Path(__file__).parent, 'test_model', 'test_model_gen.exe').resolve()
    model_path = Path(model_file).resolve()
    args = [str(test_exe), str(model_path), 'Input', str(input_size)]
    if spline_dim > 0:
        args += ['Coords', str(spline_dim)]
    args.append(f'Output{mode.strip().title()}')
    try:
        p = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
    except OSError:
        # Benchmark binary missing or not executable on this platform: report
        # NaN like the other failure modes instead of aborting the whole report.
        return math.nan
    if p.returncode != 0:
        return math.nan
    try:
        return float(p.stdout)
    except ValueError:
        return math.nan

def get_model_metadata(model_files):
    """Collect architecture, parameter, FLOP and run-time figures for each model.

    Each model's GraphDef is loaded and its ``Meta/*`` tensors are read.
    Parameter and FLOP counts are derived with the num_* / flops_* helpers, and
    the run time is measured with test_run_time.

    Args:
        model_files: iterable of ``(name, path, mode)`` triples (see
            named_path_mode); mode is CUBIC, LINEAR, CONSTANT or MANN.

    Returns:
        DataFrame indexed by model name (index name 'model'), one row per model.

    Raises:
        ValueError: if a model mode is not recognised.
    """
    # Materialized because it is traversed twice (names, then rows).
    model_files = list(model_files)
    model_names = [model_name for model_name, _, _ in model_files]
    df = make_empty_dataframe(*zip(('input_size', int),
                                   ('output_size', int),
                                   ('actual_input_size', int),
                                   ('actual_input_gate_size', int),
                                   ('actual_output_size', int),
                                   ('encoder_layers', object),
                                   ('hidden_layers', object),
                                   ('decoder_layers', object),
                                   ('gate_hidden_layers', object),
                                   ('functioned_output', bool),
                                   ('use_bias', bool),
                                   ('parameters_trainable', int),
                                   ('parameters_actual', int),
                                   ('parameters_interp', int),
                                   ('num_experts', object),
                                   ('parameters_experts', int),
                                   ('mode', object),
                                   ('actual_num_experts', object),
                                   ('parameters_actual_experts', object),
                                   ('parameters_gate', object),
                                   ('flops_gate', int),
                                   ('flops_interp', int),
                                   ('flops_nn', int),
                                   ('flops_total', int),
                                   ('run_time', float)),
                              index=model_names)
    df.index.name = 'model'
    for model_name, model_file, model_mode in model_files:
        model_mode = model_mode.strip().upper()
        with tf.Graph().as_default(), tf.Session() as sess:
            # Load graph
            gd = tf.GraphDef()
            with Path(model_file).open('rb') as f:
                gd.MergeFromString(f.read())
            if model_mode != 'MANN':
                metadata = tf.import_graph_def(gd, name='', return_elements=[
                    'Meta/NumParameters:0', 'Meta/SplinePoints:0', f'Meta/GridEvaluations{model_mode.title()}:0',
                    'Meta/EncoderLayers:0', 'Meta/DecoderLayers:0',
                    'Meta/HiddenLayers:0',
                    'Meta/FunctionedOutput:0', 'Meta/UseBias:0',
                    'Meta/InputSize:0', 'Meta/OutputSize:0',
                    'Meta/ActualInputSize:0', 'Meta/ActualOutputSize:0',
                ])
                (
                    params_train, num_experts, grid_evals,
                    encoder_layers, decoder_layers, hidden_layers,
                    functioned_output, use_bias,
                    input_size, output_size,
                    actual_input_size, actual_output_size
                ) = sess.run(metadata)
                actual_input_gate_size = 0
                gate_hidden_layers = []
            else:
                metadata = tf.import_graph_def(gd, name='', return_elements=[
                    'Meta/NumParameters:0', 'Meta/NumGated:0',
                    'Meta/EncoderLayers:0', 'Meta/DecoderLayers:0',
                    'Meta/HiddenLayers:0', 'Meta/GatingHiddenLayers:0',
                    'Meta/FunctionedOutput:0', 'Meta/UseBias:0',
                    'Meta/InputSize:0', 'Meta/OutputSize:0',
                    'Meta/ActualInputSize:0', 'Meta/ActualInputGateSize:0', 'Meta/ActualOutputSize:0',
                ])
                (
                    params_train, num_experts,
                    encoder_layers, decoder_layers,
                    hidden_layers, gate_hidden_layers,
                    functioned_output, use_bias,
                    input_size, output_size,
                    actual_input_size, actual_input_gate_size, actual_output_size
                ) = sess.run(metadata)
                num_experts = [num_experts]
                grid_evals = num_experts
            # Parameters: remove the non-interpolated encoder/decoder/gate networks
            # from the trainable count to get the expert (interpolated) part.
            params_experts = params_train
            if len(encoder_layers) > 0:
                params_experts -= num_nn_params(actual_input_size, encoder_layers[-1], encoder_layers[:-1], use_bias)
            if len(decoder_layers) > 0:
                decoder_in = (hidden_layers[-1] if len(hidden_layers) > 0 else
                              (encoder_layers[-1] if len(encoder_layers) > 0 else actual_input_size))
                if functioned_output:
                    params_experts -= num_nn_params(decoder_in, decoder_layers[-1], decoder_layers[:-1], use_bias)
                else:
                    params_experts -= num_nn_params(decoder_in, actual_output_size, decoder_layers, use_bias)
            if actual_input_gate_size:
                params_gate = num_nn_params(actual_input_gate_size, num_experts[0], gate_hidden_layers, use_bias)
            else:
                params_gate = 0
            params_experts -= params_gate
            params_interp = params_experts // np.prod(num_experts)
            params_grid = params_interp * np.prod(grid_evals)
            params_actual = params_train - params_experts + params_grid
            # FLOP estimation
            flops_gate = 0
            if model_mode == 'CUBIC':
                flops_interp = flops_spline(num_experts, params_interp)
            elif model_mode == 'LINEAR':
                flops_interp = flops_linear(num_experts, params_interp)
            elif model_mode == 'CONSTANT':
                flops_interp = flops_constant(num_experts)
            elif model_mode == 'MANN':
                flops_interp = flops_mann(num_experts[0], params_interp)
                flops_gate = flops_nn(actual_input_gate_size, num_experts[0], gate_hidden_layers, use_bias)
            else:
                raise ValueError(f'unknown model mode "{model_mode}"')
            all_layers = [*encoder_layers, *hidden_layers, *decoder_layers]
            flops_net = flops_nn(actual_input_size, actual_output_size, all_layers, use_bias)
            # The gating network (MANN only, 0 otherwise) is part of the total cost;
            # it was previously computed but left out of the total.
            flops_total = flops_gate + flops_interp + flops_net
            if model_mode != 'MANN':
                run_time = test_run_time(model_file, input_size, len(num_experts), model_mode.title())
            else:
                run_time = test_run_time(model_file, input_size, 0, '')
            df.loc[model_name, :] = [input_size, output_size,
                                     actual_input_size, actual_input_gate_size, actual_output_size,
                                     ' '.join(map(str, encoder_layers)),
                                     ' '.join(map(str, hidden_layers)),
                                     ' '.join(map(str, decoder_layers)),
                                     ' '.join(map(str, gate_hidden_layers)),
                                     functioned_output, use_bias,
                                     params_train, params_actual, params_interp,
                                     ' '.join(map(str, num_experts)), params_experts,
                                     model_mode, ' '.join(map(str, grid_evals)), params_grid, params_gate,
                                     flops_gate, flops_interp, flops_net, flops_total, run_time]
    return df

def get_error_statistics(errors):
    """Summarise error values as a single-row DataFrame of statistics.

    Args:
        errors: array-like of error values (e.g. a pandas Series of per-frame
            errors); every statistic is reduced along axis 0.

    Returns:
        pandas.DataFrame with exactly one row and the columns
        ``mean, std, min, max, p25, p50, p75`` (``std`` uses numpy's default
        ``ddof=0``).
    """
    # NOTE: a disabled variant dropped the first sample before computing the
    # statistics (``np.asarray(errors)[1:]``, "Skip first common frame").
    quartiles = [np.percentile(errors, q, axis=0) for q in (25, 50, 75)]
    row = [
        np.mean(errors, axis=0),
        np.std(errors, axis=0),
        np.min(errors, axis=0),
        np.max(errors, axis=0),
        *quartiles,
    ]
    column_names = ['mean', 'std', 'min', 'max', 'p25', 'p50', 'p75']
    return pd.DataFrame([row], columns=column_names)

def num_rows_cols(n, perfect=False):
    """Choose a subplot grid shape able to hold ``n`` panels.

    Args:
        n: number of panels (non-negative integer).
        perfect: if true, return an exact factorisation ``rows * cols == n``
            built from the largest divisor of ``n`` not exceeding
            ``sqrt(n)`` (a prime ``n`` gives a ``1 x n`` grid). Otherwise use
            ``floor(sqrt(n))`` rows and just enough columns to fit ``n``
            panels, which may leave some cells empty.

    Returns:
        ``(rows, cols)`` with ``rows <= cols``; ``(0, 0)`` when ``n == 0``.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f'n must be non-negative, got {n}')
    if n == 0:
        # Zero panels need no grid. Without this guard the divisor search
        # below never binds a value and the ceil-division divides by zero.
        return 0, 0
    s = int(math.sqrt(n))
    if perfect:
        # Largest divisor of n that is <= sqrt(n); 1 always divides n, so
        # the search always succeeds.
        a = next(d for d in range(s, 0, -1) if n % d == 0)
        b = n // a
    else:
        a = s
        b = (n + s - 1) // s  # ceil(n / s)
    return min(a, b), max(a, b)

def plot_errors(error_df, save_path=None):
    """Draw boxen plots of per-model error, one subplot per data file.

    Error values are first summed over the second column level (bones) for
    each model, so each box shows the distribution of that summed error over
    the rows of one data file.

    Args:
        error_df: DataFrame whose columns are a two-level MultiIndex with the
            first level named 'Model', and whose first row-index level holds
            the data name used for subplot titles.
        save_path: optional output file. When given, its parent directory is
            created if needed and the figure is saved there.

    Returns:
        The matplotlib figure.
    """
    # Sum the errors of each model across its columns. ``groupby(..., axis=1)``
    # is deprecated since pandas 2.1, so group the transposed frame by its
    # first index level and transpose back; the result is the same.
    per_model = error_df.T.groupby(level=0).sum().T
    error_long = per_model.stack().rename('Error')
    # Move the model level out of the index so seaborn can use x='Model'.
    error_long = error_long.reset_index(2)
    datas = error_long.index.levels[0]
    nrows, ncols = num_rows_cols(len(datas))
    fig = plt.figure(figsize=(5 * ncols, 5 * nrows))
    ax = None
    for i, data in enumerate(datas, 1):
        # Every subplot shares axes with the previous one.
        ax = fig.add_subplot(nrows, ncols, i, sharex=ax, sharey=ax)
        sns.boxenplot(x='Model', y='Error', data=error_long.loc[data], ax=ax)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
        ax.set_title(data)
    fig.tight_layout()
    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(str(save_path))
    return fig

def plot_bone_errors(error_df, save_dir=None, ext='png'):
    """Create one figure per model with per-bone error box plots.

    Each figure holds one subplot per data file (first row-index level).
    All subplots in all figures share the same y range: from -5% to +105% of
    the largest value in ``error_df``.

    Args:
        error_df: DataFrame whose columns are a MultiIndex of (model, bone),
            with the second level named 'Bone'.
        save_dir: optional directory (created if missing). Each figure is
            saved there as ``<model>.<ext>``.
        ext: file extension used for saved figures.

    Returns:
        List of the created matplotlib figures, in the order of
        ``error_df.columns.levels[0]``.
    """
    if save_dir:
        Path(save_dir).mkdir(parents=True, exist_ok=True)
    peak = error_df.values.max()
    lower, upper = -peak * .05, peak * 1.05
    figures = []
    for model_name in error_df.columns.levels[0]:
        per_bone = error_df[model_name].stack().rename('Error')
        per_bone = per_bone.reset_index('Bone')
        data_names = per_bone.index.levels[0]
        n_rows, n_cols = num_rows_cols(len(data_names))
        fig = plt.figure(figsize=(7 * n_cols, 5 * n_rows))
        panel = None
        for position, data_name in enumerate(data_names, start=1):
            # Each new panel shares its axes with the previously created one.
            panel = fig.add_subplot(n_rows, n_cols, position, sharex=panel, sharey=panel)
            sns.boxplot(x='Bone', y='Error', data=per_bone.loc[data_name], ax=panel)
            panel.set_ylim(lower, upper)
            panel.set_xticklabels(panel.get_xticklabels(), rotation=45, horizontalalignment='right')
            panel.set_title(data_name)
        fig.suptitle(model_name, y=.99)
        # Leave a strip at the top for the suptitle.
        fig.tight_layout(rect=(0, 0, 1, .99))
        if save_dir:
            fig.savefig(str(Path(save_dir, f'{model_name}.{ext}')))
        figures.append(fig)
    return figures

def main(data_files, model_files, per_data, write_error, out_dir):
    """Compare the models on the data files and write every result to ``out_dir``.

    Outputs (all inside ``out_dir``, which is created if missing):
      * ``models.df``, ``data.df``, ``error.df``: pickled pandas objects;
      * ``results.xlsx``: model metadata, data info/features, overall and
        (optionally) per-data error statistics, and optionally the full
        error table;
      * ``error.png`` and ``bone_error/<model>.png`` plots.

    Args:
        data_files: list of ``(name, path)`` pairs.
        model_files: list of ``(name, path, mode)`` triples.
        per_data: if true, add one statistics sheet per data file.
        write_error: if true, add an 'Error' sheet with the long error table.
        out_dir: ``pathlib.Path`` of the results directory.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    models_df = get_model_metadata(model_files)
    with Path(out_dir, 'models.df').open('wb') as f:
        pickle.dump(models_df, f)
    error_df, data_df = compare_models(data_files, model_files)
    # Save results
    print('Saving results...')
    with Path(out_dir, 'data.df').open('wb') as f:
        pickle.dump(data_df, f)
    # Long format: one error value per (Data, Model, Frame, Bone).
    error_long = (error_df
        .stack([0, 1])
        .rename('Error')
        .reorder_levels([0, 2, 1, 3])
        .sort_index(level=[0, 1, 2], kind='heapsort'))
    error_long.index.set_names(['Data', 'Model', 'Frame', 'Bone'], inplace=True)
    with Path(out_dir, 'error.df').open('wb') as f:
        pickle.dump(error_long, f)
    with pd.ExcelWriter(Path(out_dir, 'results.xlsx')) as writer:
        if model_files:
            models_df.to_excel(writer, sheet_name='Models')
        if data_files:
            # Each data_df row counts as one frame; summing per 'Data' gives frame counts.
            data_info_df = pd.DataFrame({'NumFrames': 1}, index=data_df.index).groupby('Data').sum()
            data_info_df.to_excel(writer, sheet_name='DataInfo')
            data_df.reset_index().to_excel(writer, sheet_name='DataFeatures', index=False)
        if data_files and model_files:
            model_names = [name for name, _, _ in model_files]
            # Total error per frame (summed over bones).
            error_frame_df = error_long.groupby(['Data', 'Model', 'Frame']).sum()
            # Drop the per-group 0 index added by get_error_statistics and
            # restore the command-line model order.
            overall_df = (error_frame_df.groupby(['Model'])
                          .apply(get_error_statistics)
                          .reset_index(level=-1, drop=True)
                          .loc[model_names, :])
            overall_df.to_excel(writer, sheet_name='Overall')
            if per_data:
                data_stats = (error_frame_df.groupby(['Data', 'Model'])
                              .apply(get_error_statistics)
                              .reset_index(level=-1, drop=True))
                for data_name, _ in data_files:
                    # Separate name so data_df (the feature table) is not shadowed.
                    data_stats_df = data_stats.loc[data_name].loc[model_names]
                    data_stats_df.to_excel(writer, sheet_name=data_name)
        if write_error:
            # sheet_name passed by keyword: positional use is deprecated in pandas 2.x.
            error_long.reset_index().to_excel(writer, sheet_name='Error', index=False)
    plot_errors(error_df, Path(out_dir, 'error.png'))
    plot_bone_errors(error_df, Path(out_dir, 'bone_error'), ext='png')
    plt.close('all')
    print('Done.')

if __name__ == '__main__':
    # Script entry point: run the whole comparison pipeline.
    # NOTE(review): ARGS is defined outside this part of the file; presumably
    # it is the (data, models, per_data, write_error, out) tuple returned by
    # parse_args() — confirm.
    main(*ARGS)
