# imports
import sys
from os.path import isfile, isdir, join, realpath, dirname

from matplotlib import pyplot as plt

script_root = dirname(realpath(__file__))
sys.path.insert(1, script_root)

## Functional
from argparse import ArgumentParser
from pathlib import Path
from os import listdir
import re
import copy
import numpy as np
import pandas as pd
import json

## data
from dependencies.trialPlotter import TrialVisualiser
from dependencies.Targets import Target

## Usability
from rich.table import Table
from rich.console import Console
from rich.progress import Progress

# Static
## Regex
# Participant folder names: five alphanumeric groups joined by '-', anywhere in the
# name (presumably UUID-style participant IDs — TODO confirm).
participant_regex = re.compile(r'^.*[A-Za-z0-9]+[-][A-Za-z0-9]+[-][A-Za-z0-9]+[-][A-Za-z0-9]+[-][A-Za-z0-9]+.*')
# Selection-prompt token: a single table index, e.g. "3".
int_regex = re.compile(r"^[0-9]+$")
# Selection-prompt token: an inclusive index range, e.g. "2-5".
range_regex = re.compile(r"^[0-9]+[-][0-9]+$")
# "!" followed by text. NOTE(review): not referenced in the visible code.
not_regex = re.compile(r"^[!].+$")
# Selection-prompt token: "a" / "all" (case-insensitive) selects every participant.
all_regex = re.compile(r"^[aA]([lL][lL])?$")
# Selection-prompt token: "y" / "yes" (case-insensitive) confirms the selection.
confirm_regex = re.compile(r"^[yY]([eE][sS])?$")
# "n" / "no" (case-insensitive). NOTE(review): not referenced in the visible code.
reject_regex = re.compile(r"^[nN][oO]?$")

# Shared rich console used to render the participant-selection table.
console = Console()

# Key under which each trial dict stores its hand-velocity-peak annotations
# (read when loading trials, written when saving overrides).
HVP_KEY = "Hand Velocity Peaks"

# Helpers
def generate_participant_table(participants):
    """Build a rich ``Table`` listing participants with their index and selection state.

    Args:
        participants: iterable whose items are either a bare participant ID, or a
            ``(participant_id, selected)`` tuple. Bare IDs are shown as not selected.

    Returns:
        A ``rich.table.Table`` with the columns "#", "Selected" and "Participant ID".
        The "Selected" cell is coloured green when selected and red otherwise.
    """
    participants_table = Table()
    for heading in ("#", "Selected", "Participant ID"):
        participants_table.add_column(heading)

    for idx, row in enumerate(participants):
        # isinstance (rather than ``type(row) is tuple``) so tuple subclasses such as
        # namedtuples are unpacked as (participant, selected) too.
        if isinstance(row, tuple):
            participant, selected = row[0], row[1]
        else:
            participant, selected = row, False

        colour = "green" if selected else "red"
        participants_table.add_row(
            str(idx),
            f"[{colour}]{selected}[/{colour}]",
            participant,
        )
    return participants_table

# load trial data and participant metadata
def load_participant_data(participant, input_dir):
    """Load a participant's annotation metadata and the trial data it references.

    Metadata is read from ``GestureAnnotations.json`` in the parent directory of
    *input_dir*; each trial's CSV is read from ``input_dir / trial["Data"]``.

    Args:
        participant: participant ID, matched against each metadata entry's "ID".
        input_dir: directory containing the trial CSV files.

    Returns:
        ``(pointing_trials, participant_metadata)``, or ``(None, None)`` when the
        metadata file is missing, unreadable or malformed, or does not contain exactly
        one entry for *participant*.

        * ``pointing_trials`` maps each condition, split on '-' into a tuple
          (e.g. "A-B" -> ("A", "B")), to a list of that condition's trial dicts.
          Each trial dict is augmented in place with "FULL" (a DataFrame, or None if
          the CSV is missing) and "TARGET" (a ``Target``).
        * ``participant_metadata`` is the metadata entry with "Conditions"
          deep-copied *before* the trials were augmented, so it holds only the
          JSON-loaded data.

    Raises:
        KeyError: if a trial lacks a required field ("Data", "TargetBox",
            "TargetLED", their "Row"/"Column"/"ID" entries, or HVP_KEY).
    """
    metadata = _load_participant_metadata(participant, input_dir)
    if metadata is None:
        return None, None

    # Snapshot the conditions before DataFrames/Target objects are attached below;
    # save_override() json-serialises this dict when no annotations file exists yet.
    participant_metadata = {k: v for k, v in metadata.items() if k != "Conditions"}
    participant_metadata["Conditions"] = copy.deepcopy(metadata["Conditions"])

    pointing_trials = {}
    for condition_data in metadata["Conditions"]:
        condition = condition_data["Condition"]
        print(f"Retrieving data for {participant} - {condition}")

        trials = condition_data["Trials"]
        for trial in trials:
            trial_path = join(input_dir, trial["Data"])
            if isfile(trial_path):
                trial["FULL"] = pd.read_csv(trial_path)
            else:
                print(f"Unable to find file at: {trial_path}")
                trial["FULL"] = None

            trial["TARGET"] = _build_target(trial)

            # Every trial must carry hand-velocity-peak annotations; fail fast if absent.
            if HVP_KEY not in trial:
                raise KeyError(HVP_KEY)

        pointing_trials[tuple(condition.split('-'))] = trials.copy()

    return pointing_trials, participant_metadata


def _load_participant_metadata(participant, input_dir):
    """Return the single GestureAnnotations.json entry whose "ID" is *participant*, else None.

    Every failure (missing file, unreadable/malformed JSON, zero or several matching
    entries) is reported on stdout and yields None, so callers never receive a
    partially-processed value.
    """
    metadata_path = join(dirname(input_dir), "GestureAnnotations.json")
    if not isfile(metadata_path):
        print(f"Unable to find metadata file ({metadata_path}) for participant: {participant}")
        return None

    try:
        with open(metadata_path, 'r') as metadata_file:
            entries = json.load(metadata_file)
        # KeyError: entry without "ID"; TypeError: JSON is not a list of objects.
        matches = [md for md in entries if md["ID"] == participant]
    except (OSError, ValueError, KeyError, TypeError) as err:
        print(f"Unable to load metadata file ({metadata_path}): {err}")
        return None

    if len(matches) != 1:
        print(f"Unable to find participant ({participant}) in metadata {len(matches)}")
        return None
    return matches[0]


def _build_target(trial):
    """Construct a Target from a trial's "TargetBox"/"TargetLED" fields; absent coordinates become NaN."""
    box = trial["TargetBox"]
    led = trial["TargetLED"]
    return Target(
        box["Row"], box["Column"], led["ID"],
        np.array([box.get(axis, np.nan) for axis in ("X", "Y", "Z")]),
        np.array([box.get(axis, np.nan) for axis in ("YAW", "PITCH", "ROLL")]),
        np.array([led.get(axis, np.nan) for axis in ("X", "Y", "Z")]),
    )

# Save Annotation Changes
def save_override(participant_id, condition, trial_idx, overrides, path, participant_metadata):
    """Write hand-velocity-peak overrides for one trial to ``<path>/<participant_id>Annotations.json``.

    On the first save the file is created from a copy of *participant_metadata*;
    afterwards the existing file is loaded, updated and rewritten. Nothing is written
    when no override contains a key other than "START"/"END".

    Args:
        participant_id: participant ID, used to name the annotations file.
        condition: the condition, either as a tuple of parts (as produced by
            splitting "A-B" into ("A", "B")) or as the already-joined string.
        trial_idx: index into that condition's "Trials" list in the annotations file.
        overrides: mapping whose values are dicts of annotation fields; stored
            under HVP_KEY for the trial.
        path: directory containing the annotations file.
        participant_metadata: template used when no annotations file exists yet.
            It is deep-copied, never modified.
    """
    boundary_keys = {"START", "END"}
    if not any(peaks.keys() - boundary_keys for peaks in overrides.values()):
        print(f"No Overrides, skipping. {participant_id}, {condition}, {trial_idx}, {overrides}")
        return

    manifest_path = join(path, f"{participant_id}Annotations.json")
    if isfile(manifest_path):
        with open(manifest_path, 'r') as manifest_file:
            manifest = json.load(manifest_file)
    else:
        # Copy so the caller's metadata dict is not mutated as a side effect.
        manifest = copy.deepcopy(participant_metadata)

    if manifest is None:
        return

    # Join every part (not just the first two) so conditions containing several '-' round-trip.
    if isinstance(condition, str):
        condition_name = condition
    else:
        condition_name = "-".join(str(part) for part in condition)

    con_idx = next(
        (i for i, v in enumerate(manifest["Conditions"]) if v["Condition"] == condition_name),
        None,
    )
    if con_idx is None:
        print(f"Unable to find condition ({condition_name}) in annotations for participant: {participant_id}")
        return

    manifest["Conditions"][con_idx]["Trials"][trial_idx][HVP_KEY] = overrides

    with open(manifest_path, 'w') as manifest_file:
        json.dump(manifest, manifest_file)
    
# Program Start
# Parse the command line, then discover participant folders directly under --dir.
parser = ArgumentParser(description="Visualise and Annotate Trials")
parser.add_argument("-d", "--dir", type=Path, nargs="?", default=None)

args = parser.parse_args()

if args.dir is None:
    parser.print_usage()
    parser.exit(1, "No directory was provided\n")
input_dir = args.dir
if not isdir(input_dir):
    parser.exit(1, f"The provided path is not a directory: {input_dir}\n")

# Sorted so the indices shown in the selection table are stable between runs
# (os.listdir order is arbitrary). The regex is applied once here.
participants = sorted(
    d for d in listdir(input_dir)
    if isdir(join(input_dir, d)) and participant_regex.search(d)
)
filtered_participants = list(participants)

if not filtered_participants:
    parser.exit(1, f"Unable to locate participants in the provided directory: {input_dir}\n")

def _toggle_selection(selection, participant_id):
    """Add *participant_id* to the *selection* set if absent, otherwise remove it (in place)."""
    if participant_id in selection:
        selection.remove(participant_id)
    else:
        selection.add(participant_id)


# Interactive selection: each whitespace-separated token is a confirmation ("y"/"yes"),
# "a"/"all", an index range "i-j", a single index, or otherwise a regex matched
# against participant IDs. Index/range/regex tokens toggle the selection.
selecting_participants = True
chosen_participants = set()
while selecting_participants:
    participants_table = generate_participant_table([(p, p in chosen_participants) for p in filtered_participants])
    console.print(participants_table)
    participant_input = input("Please select the participants that you wish to process, then confirm your selection (y/yes).\n > ")
    participant_input = participant_input.split()
    if len(participant_input) < 1:
        print("No new participants were provided.")
        continue

    unfound = []
    for participant in participant_input:
        if confirm_regex.search(participant) is not None:
            selecting_participants = False
        elif all_regex.search(participant) is not None:
            chosen_participants = set(filtered_participants)
        elif range_regex.search(participant):
            start, end = (int(s) for s in participant.split("-"))
            if 0 <= start < end < len(filtered_participants):
                for p_idx in range(start, end + 1):
                    _toggle_selection(chosen_participants, filtered_participants[p_idx])
            else:
                print(f"Unable to process input '{participant}', as range invalid.")
        elif int_regex.search(participant) is not None and int(participant) < len(filtered_participants):
            _toggle_selection(chosen_participants, filtered_participants[int(participant)])
        else:
            # Out-of-range integers also fall through to here and are tried as a pattern.
            try:
                pid_regex = re.compile(participant)
            except re.error:
                # A malformed pattern from the prompt should not crash the program.
                unfound.append(participant)
                continue
            matched = [p for p in filtered_participants if pid_regex.search(p) is not None]
            if len(matched) < 1:
                unfound.append(participant)
            if len(matched) > 1:
                print(f"Multiple participants matched {matched} for given string: '{participant}'")
            for p in matched:
                _toggle_selection(chosen_participants, p)

    if len(unfound) > 0:
        print(f"Unable to match participants from given inputs: {unfound}")
        if not selecting_participants:
            print("Please reconfirm if you're happy to use the selected participants.")
            # Actually require the reconfirmation that the message asks for.
            selecting_participants = True
print(f"Using the provided participants: {chosen_participants}")
        
# process participants
# Iterate in table order (not set-iteration order) so runs are reproducible.
participants = [(p, join(input_dir, p)) for p in filtered_participants if p in chosen_participants]

for participant, participant_path in participants:
    print(f"Loading trial data for participant: {participant}, {participant_path}")
    participant_data, participant_metadata = load_participant_data(participant, input_dir)
    if participant_data is None:
        print(f"Skipping participant {participant}: no metadata could be loaded.")
        continue

    # Only trials whose CSV was found can be visualised. The plotter is given these
    # filtered lists (so presumably reports indices into them), whereas save_override
    # indexes the full trial list of the annotations file. Record each visible
    # trial's original position so overrides land on the right trial.
    data_to_visualise = {}
    source_trial_indices = {}
    for condition, trials in participant_data.items():
        kept = [(idx, trial) for idx, trial in enumerate(trials) if trial["FULL"] is not None]
        data_to_visualise[condition] = [trial for _, trial in kept]
        source_trial_indices[condition] = [idx for idx, _ in kept]

    if not data_to_visualise:
        print(f"Skipping participant {participant}: no conditions found.")
        continue

    # Defaults bind this iteration's values, avoiding late-binding closure surprises.
    def on_override(participant_id, condition, trial_idx, overrides, path=participant_path,
                    *, metadata=participant_metadata, index_map=source_trial_indices):
        indices = index_map.get(condition)
        source_idx = indices[trial_idx] if indices is not None else trial_idx
        save_override(participant_id, condition, source_idx, overrides, path, metadata)

    plotter = TrialVisualiser(fig_size=(16, 9), match_participant_perspective=False)
    # Keep the returned controls referenced while the window is open
    # (NOTE(review): presumably widgets that must not be garbage-collected — confirm).
    plotter_controls = plotter.show(
        data_to_visualise, participant, next(iter(data_to_visualise)), 0, 0,
        hand_dominance=participant_metadata["Dominant Hand"],
        override_callback=on_override,
    )
    plt.show(block=True)