import tqdm
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchinfo import summary
from torch.nn.utils import clip_grad_norm_
from torchmetrics.regression import MeanAbsolutePercentageError, MeanSquaredError, MeanAbsoluteError

# Custom imports
from tools import save_checkpoint, EarlyStopping
from metrics import RMSELoss, MAPEPercent, MAEMetric, RMSEMetric, AbsolutePercentageError
from datasets import TabularMFSequentialBuckling
from models import NeuralNetwork

def train_model(train_dataset, val_dataset, mlp_config, num_epochs):
    """Train an MLP, save a checkpoint of it, and plot the RMSE learning curves.

    Args:
        train_dataset: Indexable dataset whose items are ``(features, target)``
            pairs; the first item is used to size the network.
        val_dataset: Dataset evaluated after every epoch; its MSE drives
            early stopping.
        mlp_config: Model/optimiser settings. This function reads the keys
            ``'batch_size'``, ``'optimizer'`` (name of a ``torch.optim``
            class) and ``'learning_rate'``; the whole dict is also forwarded
            to ``NeuralNetwork``.
        num_epochs: Maximum number of epochs to run (must be >= 1).

    Raises:
        ValueError: If ``num_epochs`` is smaller than 1. Without this guard
            the loop variable would be unbound when the checkpoint is saved.
    """
    if num_epochs < 1:
        raise ValueError(
            'num_epochs must be at least 1, got {}'.format(num_epochs)
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create DataLoaders (only the training set is shuffled)
    batch_size = mlp_config['batch_size']
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    print('Number of training set samples = {:d}'.format(len(train_dataset)))
    print('Number of validation set samples = {:d}'.format(len(val_dataset)))

    # Define model; layer sizes come from the first training sample
    first_sample = train_dataset[0]
    input_size = first_sample[0].shape[0]
    output_size = first_sample[1].shape[0]
    model = NeuralNetwork(input_size, output_size, mlp_config)
    model.to(device)
    summary(model)

    # Define optimizer (looked up by name in torch.optim)
    optimizer_cls = getattr(optim, mlp_config['optimizer'])
    optimizer = optimizer_cls(model.parameters(), lr=mlp_config['learning_rate'])

    # Define loss; the other entries are kept so the criterion can be
    # switched by changing the key below
    loss_dict = {
        'MSE': nn.MSELoss(),
        'RMSE': RMSELoss(),
        'MAE': nn.L1Loss()
    }
    criterion = loss_dict['MSE']

    # Metrics to compute
    metrics = {
        'MSE': MeanSquaredError().to(device),
        'MAPE': MeanAbsolutePercentageError().to(device),
        'RMSE': RMSEMetric().to(device),
        'MAE': MeanAbsoluteError().to(device)
    }

    # Instantiate the EarlyStopping class
    early_stopping = EarlyStopping(patience=100, delta=0.01)

    # Per-metric history of the epoch-averaged train/val values
    losses_history = {key: {'train': [], 'val': []} for key in metrics}

    # Training loop
    epochs_completed = 0
    for epoch in tqdm.tqdm(range(num_epochs), total=num_epochs):
        epochs_completed += 1

        train_metrics = train_func(model, optimizer, criterion, metrics, train_dataloader)
        val_metrics = eval_func(model, metrics, val_dataset, val_dataloader)

        # Store losses
        for m in metrics:
            losses_history[m]['train'].append(train_metrics[m])
            losses_history[m]['val'].append(val_metrics[m])

        print(f'Epoch {epoch+1}/{num_epochs}')
        for key in train_metrics:
            print(f'    {key}: Train = {train_metrics[key]}, Val = {val_metrics[key]}')

        # Check if early stopping criteria are met
        if early_stopping(val_metrics['MSE']):
            print("Early stopping!")
            break

    # Save final model; `epoch` is the zero-based index of the last epoch run
    model_file = Path.cwd().joinpath('./models/full_dataset/mf_sequential_MLP_model.pth')
    # Make sure the target folder exists so the save cannot fail on a fresh checkout
    model_file.parent.mkdir(parents=True, exist_ok=True)
    save_checkpoint(model, optimizer, epoch, model_file)

    # Plot the learning curves (RMSE metric, not the MSE training criterion)
    epoch_axis = range(1, epochs_completed + 1)
    plt.plot(epoch_axis, losses_history['RMSE']['train'], label='Training Loss')
    plt.plot(epoch_axis, losses_history['RMSE']['val'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Curves')
    plt.legend()
    plt.show()

def eval_model(test_dataset, model_file, preds_file, mlp_config):
    """Load a saved MLP, predict the test set, and report its performance.

    The test-set predictions are written to disk by ``eval_func`` and then
    re-loaded from ``preds_file`` to compute the final metrics, draw the
    predicted-vs-target scatter plot, and print the Pearson correlations
    between the absolute percentage error and the design-space dimensions.

    Args:
        test_dataset: Indexable dataset whose items are ``(features, target)``
            pairs; the first item is used to size the network.
        model_file: Checkpoint file name inside ``./models/full_dataset/``.
        preds_file: Predictions file name inside ``./results/raw_predictions/``.
        mlp_config: Same configuration dict the model was trained with; the
            key ``'batch_size'`` is read here and the dict is forwarded to
            ``NeuralNetwork``.
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define test dataloader
    test_dataloader = DataLoader(
        test_dataset, batch_size=mlp_config['batch_size'], shuffle=False
    )

    # Load model; map_location lets a checkpoint saved on GPU load on a CPU-only host
    checkpoint_path = Path.cwd().joinpath(
        './models/full_dataset/{:s}'.format(model_file)
    )
    model_state = torch.load(checkpoint_path, map_location=device)
    # Size the network from the dataset passed in, not from a module-level global
    first_sample = test_dataset[0]
    input_size = first_sample[0].shape[0]
    output_size = first_sample[1].shape[0]
    model = NeuralNetwork(input_size, output_size, mlp_config)
    model.load_state_dict(model_state['model_state_dict'])
    model.to(device)

    # Predict test set and save predictions (called for its file-writing side effect)
    metrics = {
        'MSE': MeanSquaredError().to(device),
        'MAPE': MeanAbsolutePercentageError().to(device),
        'RMSE': RMSEMetric().to(device),
        'MAE': MeanAbsoluteError().to(device)
    }
    eval_func(model, metrics, test_dataset, test_dataloader, save_preds=True)

    # Load test predictions.
    # NOTE(review): eval_func always writes 'mf_sequential_MLP_preds.pth'; any
    # other preds_file reads a file produced by an earlier run - confirm intended.
    pred_data = torch.load(Path.cwd().joinpath('./results/raw_predictions/{:s}'.format(preds_file)))
    # Column layout: 4 feature columns, then target, then prediction
    features = pred_data[:, :4]
    y = pred_data[:, 4]
    preds = pred_data[:, 5]

    # Compute final MAE, RMSE, MAPE, MAE 5%, MAPE 5%
    mae_func = MeanAbsoluteError()
    rmse_func = RMSEMetric()
    mape_func = MeanAbsolutePercentageError()
    mae5_func = MAEMetric(0.05)
    mape5_func = MAPEPercent(0.05)
    mae = mae_func(preds, y).item()
    rmse = rmse_func(preds, y).item()
    mape = mape_func(preds, y).item()
    mae5 = mae5_func(preds, y).item()
    mape5 = mape5_func(preds, y).item()
    print('Final test accuracy:')
    print('    MAE = {:f}'.format(mae))
    print('    RMSE = {:f}'.format(rmse))
    print('    MAPE = {:f}'.format(mape * 100))
    print('    MAE 5% = {:f}'.format(mae5))
    print('    MAPE 5% = {:f}'.format(mape5))

    # Plot predicted buckling vs FE buckling, with the y = x line as reference
    max_val = max(preds.max().item(), y.max().item())
    plt.plot([0.,max_val], [0.,max_val], color='black', linestyle='-', linewidth=1)
    plt.scatter(preds, y, s=10, marker='x', color='red', alpha=0.7)
    plt.axis("equal")
    plt.xlim(0, max_val)
    plt.ylim(0, max_val)
    plt.gca().set_aspect('equal', adjustable='box')
    plt.show()

    # Determine Pearson correlations with design space dimensions
    ape_func = AbsolutePercentageError()
    ape = ape_func(preds, y).numpy()
    span = features[:, 0].numpy()
    dims = {
        'span': span,
        'height/span': features[:, 1].numpy() / span,
        'thickness/span': features[:, 2].numpy() / span,
        'material': features[:, 3].numpy()
    }
    print('Pearson correlations between design space dimensions and APEs')
    for dim, vals in dims.items():
        correlation = np.corrcoef(vals, ape)[0,1]
        print('    {:s} - APE correlation = {:f}'.format(dim, correlation))

def train_func(model, optimizer, criterion, metrics, train_dataloader):
    """Run one training epoch and return the batch-averaged metrics.

    Args:
        model: Network to train; it is switched to training mode.
        optimizer: Optimiser that updates ``model``'s parameters.
        criterion: Callable ``criterion(outputs, target)`` returning the loss
            that is back-propagated.
        metrics: Mapping of metric name to a callable
            ``metric(outputs, target)`` whose result exposes ``.item()``.
        train_dataloader: Sized iterable yielding ``(features, target)``
            batches.

    Returns:
        dict: Metric name -> mean of the per-batch metric values over the
        epoch (each batch weighs the same regardless of its size).
    """
    model.train()

    # Batches must live on the same device as the model's weights; the caller
    # moves the model but the dataloader yields tensors wherever the dataset
    # stores them.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    running_train_metrics = {key: 0. for key in metrics}

    for features, target in train_dataloader:
        if device is not None:
            features = features.to(device)
            target = target.to(device)

        optimizer.zero_grad()
        outputs = model(features)

        # Calculate the loss
        loss = criterion(outputs, target)

        loss.backward()

        # Clip gradients to prevent exploding gradients
        clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        # Metrics are for reporting only, so keep them off the autograd graph
        detached_outputs = outputs.detach()
        for key, metric_func in metrics.items():
            running_train_metrics[key] += metric_func(detached_outputs, target).item()

    # Final calc: average the accumulated per-batch values
    num_batches = len(train_dataloader)
    for key in metrics:
        running_train_metrics[key] /= num_batches

    return running_train_metrics

def eval_func(model, metrics, dataset, eval_dataloader, save_preds=False):
    """Evaluate the model on a dataloader and optionally save its predictions.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        metrics: Mapping of metric name to a callable
            ``metric(outputs, target)`` whose result exposes ``.item()``.
        dataset: Dataset behind ``eval_dataloader``. Only read when
            ``save_preds`` is true, for ``normalize_input``, ``input_std``
            and ``input_mean``.
        eval_dataloader: Sized iterable yielding ``(features, target)``
            batches.
        save_preds: When true, write a tensor whose columns are the first four
            (de-normalised) features, the targets and the predictions to
            ``./results/raw_predictions/mf_sequential_MLP_preds.pth`` under
            the current working directory.

    Returns:
        dict: Metric name -> mean of the per-batch metric values.
    """
    model.eval()

    # Batches must live on the same device as the model's weights
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    running_eval_metrics = {key: 0. for key in metrics}

    full_features = []
    full_targets = []
    full_preds = []
    with torch.no_grad():
        for features, target in eval_dataloader:
            if device is not None:
                model_input = features.to(device)
                target = target.to(device)
            else:
                model_input = features
            outputs = model(model_input)

            for key, metric_func in metrics.items():
                running_eval_metrics[key] += metric_func(outputs, target).item()

            if save_preds:
                # Undo the input normalisation on CPU so the dataset statistics
                # and the batch can never end up on different devices
                raw_features = features.cpu()
                if dataset.normalize_input:
                    input_std = torch.as_tensor(dataset.input_std, device=raw_features.device)
                    input_mean = torch.as_tensor(dataset.input_mean, device=raw_features.device)
                    raw_features = raw_features * input_std + input_mean
                full_features.append(raw_features[:, :4])
                full_targets.append(target.cpu())
                full_preds.append(outputs.cpu())

    if save_preds:
        full_features = torch.cat(full_features, dim=0)
        full_targets = torch.cat(full_targets, dim=0)
        full_preds = torch.cat(full_preds, dim=0)
        preds_array = torch.cat((full_features, full_targets, full_preds), dim=1)
        preds_file = Path.cwd().joinpath('./results/raw_predictions/mf_sequential_MLP_preds.pth')
        # Create the output folder on first use so the save cannot fail
        preds_file.parent.mkdir(parents=True, exist_ok=True)
        torch.save(preds_array, preds_file)

    # Final calc: average the accumulated per-batch values
    num_batches = len(eval_dataloader)
    for key in metrics:
        running_eval_metrics[key] /= num_batches

    return running_eval_metrics

if __name__ == '__main__':

    # User settings
    DATASET_ROOT = Path('F:/ConcreteShellFEA') # Change to path to ConcreteShellFEA
    num_epochs = 1000

    # Build the three dataset splits from their input (.h5) and buckling (.csv) files
    # -- training split
    train_input_path = DATASET_ROOT / './datasets/PerfectShell_NonlinearFEA/tabular/nonlinear/input/training/pca_input/pca_training.h5'
    train_buckling_path = DATASET_ROOT / './datasets/PerfectShell_NonlinearFEA/tabular/nonlinear/output/training/buckling_output/buckling_training.csv'
    train_dataset = TabularMFSequentialBuckling(train_input_path, train_buckling_path)
    # -- validation split
    val_input_path = DATASET_ROOT / './datasets/PerfectShell_NonlinearFEA/tabular/nonlinear/input/validation/pca_input/pca_validation.h5'
    val_buckling_path = DATASET_ROOT / './datasets/PerfectShell_NonlinearFEA/tabular/nonlinear/output/validation/buckling_output/buckling_validation.csv'
    val_dataset = TabularMFSequentialBuckling(val_input_path, val_buckling_path)
    # -- testing split
    test_input_path = DATASET_ROOT / './datasets/PerfectShell_NonlinearFEA/tabular/nonlinear/input/testing/pca_input/pca_testing.h5'
    test_buckling_path = DATASET_ROOT / './datasets/PerfectShell_NonlinearFEA/tabular/nonlinear/output/testing/buckling_output/buckling_testing.csv'
    test_dataset = TabularMFSequentialBuckling(test_input_path, test_buckling_path)

    # Normalisation statistics come from the training split only: the three
    # input groups are joined column-wise and reduced over the sample axis
    train_inputs = torch.cat(
        (
            train_dataset.shell_characteristics,
            train_dataset.preproc_inputs,
            train_dataset.linear_buckling_data
        ),
        dim=1
    )
    input_mean = train_inputs.mean(dim=0)
    input_std = train_inputs.std(dim=0)

    # Give every split the same statistics and switch input normalisation on
    for split_dataset in (train_dataset, val_dataset, test_dataset):
        split_dataset.input_mean = input_mean
        split_dataset.input_std = input_std
        split_dataset.normalize_input = True

    # Model configuration
    mlp_config = {
        'num_layers': 5,
        'neurons_per_layer': 32,
        'activation': 'Softplus',
        'dropout': 0.,
        'batch_norm': False,
        'initializer': 'xavier_uniform_',
        'optimizer': 'Adam',
        'learning_rate': 0.001,
        'batch_size': 8,
    }

    # Stage 1: train the model and save its checkpoint
    print('*** TRAINING MODEL ***')
    train_model(train_dataset, val_dataset, mlp_config, num_epochs)

    # Stage 2: evaluate accuracy on the test split
    # Use 'mf_sequential_MLP_best_model.pth' / 'mf_sequential_MLP_best_preds.pth'
    # to inspect the best model and its predictions instead.
    # 'mf_sequential_MLP_best_preds_2.pth' holds the predictions obtained when
    # lf preds are used for eigenvalue buckling.
    model_file = 'mf_sequential_MLP_model.pth'
    preds_file = 'mf_sequential_MLP_preds.pth'
    print('*** EVALUATE MODEL ***')
    eval_model(test_dataset, model_file, preds_file, mlp_config)