import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import summary
from torch_geometric.loader import DataLoader
from torchmetrics.regression import MeanAbsolutePercentageError, MeanSquaredError, MeanAbsoluteError
from torch.nn.utils import clip_grad_norm_
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

# Custom imports
from tools import save_checkpoint, EarlyStopping
from metrics import RMSELoss, MAPEPercent, RMSEMetric, AbsolutePercentageError
from datasets import GraphDataBuckling
from models import GNN

def train_model(train_dataset, val_dataset, gnn_config, num_epochs):
    """ Train a GNN, save the final checkpoint and plot the learning curves.

    Args:
        train_dataset: training graphs; ``train_dataset[0].num_features`` sets the
            model input size.
        val_dataset: validation graphs, evaluated after every epoch.
        gnn_config: configuration dict, passed whole to ``GNN``. Keys read here:
            ``'batch_size'``, ``'optimizer'`` (name of a ``torch.optim`` class),
            ``'learning_rate'`` and the optional ``'loss'`` (one of ``'MSE'``
            (default), ``'RMSE'``, ``'MAE'``) selecting the training criterion.
        num_epochs: maximum number of epochs; early stopping on the validation
            MSE may end training sooner.

    Side effects:
        Saves the final model to ``./models/GNN_buckling_model.pth`` (relative to
        the current working directory) and shows a matplotlib figure.

    Raises:
        ValueError: if ``num_epochs`` < 1 (there would be no epoch to checkpoint).
    """
    if num_epochs < 1:
        raise ValueError('num_epochs must be at least 1, got {}'.format(num_epochs))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create DataLoaders for training and validation sets
    train_dataloader = DataLoader(
        train_dataset, batch_size=gnn_config['batch_size'], shuffle=True
    )
    val_dataloader = DataLoader(
        val_dataset, batch_size=gnn_config['batch_size'], shuffle=False
    )

    print('Number of training set samples = {:d}'.format(len(train_dataset)))
    print('Number of validation set samples = {:d}'.format(len(val_dataset)))

    # Instantiate the model
    model = GNN(num_features=train_dataset[0].num_features, output_size=1, config=gnn_config)
    model.to(device)

    # Print a layer summary using a single training sample moved to the device.
    # NOTE(review): this mutates the object returned by train_dataset[0]; harmless if
    # the dataset builds a fresh object on every access - confirm.
    sample = train_dataset[0]
    sample.x = sample.x.to(device)
    sample.edge_index = sample.edge_index.to(torch.int64).to(device)
    sample.buckling = sample.buckling.to(device)
    print(summary(model, sample))
    del sample

    # Optimizer class is looked up by name in torch.optim (e.g. 'Adamax')
    optimizer = getattr(optim, gnn_config['optimizer'])(
        model.parameters(), lr=gnn_config['learning_rate']
    )

    # Training criterion, selectable through gnn_config['loss'] (defaults to MSE)
    loss_dict = {
        'MSE': nn.MSELoss(),
        'RMSE': RMSELoss(),
        'MAE': nn.L1Loss()
    }
    criterion = loss_dict[gnn_config.get('loss', 'MSE')]

    # Metrics reported every epoch (monitoring only; they do not drive the gradients)
    metrics = {
        'MSE': MeanSquaredError().to(device),
        'MAPE': MeanAbsolutePercentageError().to(device),
        'RMSE': RMSEMetric().to(device),
        'MAE': MeanAbsoluteError().to(device)
    }

    early_stopping = EarlyStopping(patience=20, delta=0.1)

    # Per-metric history of epoch values for the train and validation sets
    losses_history = {key: {'train': [], 'val': []} for key in metrics}

    epochs_completed = 0
    for epoch in tqdm.tqdm(range(num_epochs), total=num_epochs):
        epochs_completed += 1

        train_metrics = train_func(model, optimizer, criterion, metrics, train_dataloader, device)
        val_metrics = eval_func(model, metrics, val_dataset, val_dataloader, device)

        for m in metrics:
            losses_history[m]['train'].append(train_metrics[m])
            losses_history[m]['val'].append(val_metrics[m])

        print(f'Epoch {epoch+1}/{num_epochs}')
        for key in train_metrics:
            print(f'    {key}: Train = {train_metrics[key]}, Val = {val_metrics[key]}')

        # Early stopping is driven by the validation MSE
        if early_stopping(val_metrics['MSE']):
            print("Early stopping!")
            break

    # Save the final model; make sure the target folder exists first.
    # `epoch` is always bound here because num_epochs >= 1 was checked above.
    model_file = Path.cwd().joinpath('./models/GNN_buckling_model.pth')
    model_file.parent.mkdir(parents=True, exist_ok=True)
    save_checkpoint(model, optimizer, epoch, model_file)

    # Plot the learning curves (RMSE is plotted, so label it as such)
    epochs_axis = range(1, epochs_completed + 1)
    plt.plot(epochs_axis, losses_history['RMSE']['train'], label='Training RMSE')
    plt.plot(epochs_axis, losses_history['RMSE']['val'], label='Validation RMSE')
    plt.xlabel('Epoch')
    plt.ylabel('RMSE')
    plt.title('Training and Validation Curves')
    plt.legend()
    plt.show()

def eval_model(test_dataset, model_file, preds_file, gnn_config):
    """ Load a trained GNN, predict the testing set, save the predictions and
    report the model's performance.

    Args:
        test_dataset: testing graphs; ``test_dataset[0].num_features`` sets the
            model input size.
        model_file: checkpoint file name inside ``./models`` (a dict holding
            ``'model_state_dict'``).
        preds_file: predictions file name inside ``./results/raw_predictions``
            that is read back for the final metrics, the plot and the correlations.
        gnn_config: configuration dict passed to ``GNN``; ``'batch_size'`` is used
            for the test DataLoader.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define test dataloader
    test_dataloader = DataLoader(
        test_dataset, batch_size=gnn_config['batch_size'], shuffle=False
    )

    # Load model. map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only machine (and vice versa).
    model_path = Path.cwd().joinpath('./models/{:s}'.format(model_file))
    model_state = torch.load(model_path, map_location=device)
    model = GNN(num_features=test_dataset[0].num_features, output_size=1, config=gnn_config)
    model.load_state_dict(model_state['model_state_dict'])
    model.to(device)

    # Predict test set and save predictions
    metrics = {
        'MSE': MeanSquaredError().to(device),
        'MAPE': MeanAbsolutePercentageError().to(device),
        'RMSE': RMSEMetric().to(device),
        'MAE': MeanAbsoluteError().to(device)
    }
    # NOTE: eval_func always writes to results/raw_predictions/GNN_buckling_preds.pth,
    # whereas the file read back below is `preds_file`. If they differ, the metrics
    # below describe that other (pre-existing) file, not the model loaded above.
    eval_func(model, metrics, test_dataset, test_dataloader, device, save_preds=True)

    # Load test predictions: columns 0-3 are design features, 4 the target, 5 the
    # prediction. Forced to CPU because .numpy() is called on them below.
    pred_data = torch.load(
        Path.cwd().joinpath('./results/raw_predictions/{:s}'.format(preds_file)),
        map_location='cpu'
    )
    features = pred_data[:, :4]
    y = pred_data[:, 4]
    preds = pred_data[:, 5]

    # Compute final MAE, RMSE, MAPE, MAPE 5%
    mae_func = MeanAbsoluteError()
    rmse_func = RMSEMetric()
    mape_func = MeanAbsolutePercentageError()
    mape5_func = MAPEPercent(0.05)
    mae = mae_func(preds, y).item()
    rmse = rmse_func(preds, y).item()
    mape = mape_func(preds, y).item()
    mape5 = mape5_func(preds, y).item()
    print('Final test accuracy:')
    print('    MAE = {:f}'.format(mae))
    print('    RMSE = {:f}'.format(rmse))
    print('    MAPE = {:f}'.format(mape * 100))  # torchmetrics MAPE is a fraction
    print('    MAPE 5% = {:f}'.format(mape5))

    # Plot predicted buckling (x) vs FE buckling (y) with the identity line
    max_val = max(preds.max().item(), y.max().item())
    plt.plot([0., max_val], [0., max_val], color='black', linestyle='-', linewidth=1)
    plt.scatter(preds, y, s=10, marker='x', color='red', alpha=0.7)
    plt.xlabel('Predicted buckling')
    plt.ylabel('FE buckling')
    plt.axis("equal")
    plt.xlim(0, max_val)
    plt.ylim(0, max_val)
    plt.gca().set_aspect('equal', adjustable='box')
    plt.show()

    # Determine Pearson correlations between design space dimensions and the
    # per-sample absolute percentage errors
    ape_func = AbsolutePercentageError()
    ape = ape_func(preds, y).numpy()
    span = features[:, 0].numpy()
    dims = {
        'span': span,
        'height/span': features[:, 1].numpy() / span,
        'thickness/span': features[:, 2].numpy() / span,
        'material': features[:, 3].numpy()
    }
    print('Pearson correlations between design space dimensions and APEs')
    for dim, vals in dims.items():
        correlation = np.corrcoef(vals, ape)[0, 1]
        print('    {:s} - APE correlation = {:f}'.format(dim, correlation))

def train_func(model, optimizer, criterion, metrics, train_dataloader, device, max_grad_norm=1.0):
    """ Run one training epoch: forward every batch, back-propagate the loss and
    update the model's weights and biases.

    Args:
        model: the network, called as ``model(batch)``.
        optimizer: optimizer over ``model.parameters()``.
        criterion: loss callable ``criterion(output, target)`` that is back-propagated.
        metrics: dict mapping a name to a callable ``metric(output, target)`` that
            returns a scalar tensor; evaluated on each batch for monitoring only.
        train_dataloader: iterable of batches exposing ``x``, ``edge_index``,
            ``buckling`` and ``batch`` tensors.
        device: device the batch tensors are moved to.
        max_grad_norm: gradient-norm clipping threshold (default 1.0, the value
            previously hard-coded).

    Returns:
        dict mapping each metric name to its per-sample mean over the epoch
        (per-batch values weighted by batch size, so a smaller last batch does not
        count as much as a full one).

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.train()  # Set the model to training mode
    running_train_metrics = {key: 0. for key in metrics}
    num_samples = 0

    for batch in train_dataloader:

        # send data to device
        batch.x = batch.x.to(device)
        batch.edge_index = batch.edge_index.to(torch.int64).to(device)
        batch.buckling = batch.buckling.to(device)
        batch.batch = batch.batch.to(device)

        optimizer.zero_grad()
        output = model(batch)
        loss = criterion(output, batch.buckling)
        loss.backward()
        # Clip gradients to prevent exploding gradients
        clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        # Accumulate metrics weighted by the number of targets in this batch
        batch_size = batch.buckling.size(0)
        num_samples += batch_size
        for key, metric_func in metrics.items():
            running_train_metrics[key] += metric_func(output, batch.buckling).item() * batch_size

    if num_samples == 0:
        raise ValueError('train_dataloader yielded no samples')

    for key in running_train_metrics:
        running_train_metrics[key] /= num_samples

    return running_train_metrics

def eval_func(model, metrics, dataset, eval_dataloader, device, save_preds=False,
              nodes_per_graph=12881):
    """ Evaluate the model on a dataset without gradient tracking and optionally
    save the predictions.

    Args:
        model: the network, called as ``model(batch)``; put in eval mode here.
        metrics: dict mapping a name to a callable ``metric(output, target)`` that
            returns a scalar tensor.
        dataset: the evaluated dataset; only read when ``save_preds`` is True, for its
            ``input_mean`` / ``input_std`` normalisation statistics.
        eval_dataloader: iterable of batches exposing ``x``, ``edge_index``,
            ``buckling`` and ``batch`` tensors.
        device: device the batch tensors are moved to.
        save_preds: if True, save an ``(N, 6)`` tensor to
            ``./results/raw_predictions/GNN_buckling_preds.pth`` whose columns are the
            de-normalised node features 3:7 of each graph's first node, the target and
            the prediction.
        nodes_per_graph: number of nodes in every graph (default 12881, the value
            previously hard-coded); used to pick each graph's first node when saving.

    Returns:
        dict mapping each metric name to its per-sample mean over the dataset
        (per-batch values weighted by batch size).

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.eval()  # Set the model to evaluation mode
    running_eval_metrics = {key: 0. for key in metrics}
    num_samples = 0

    full_features = []
    full_targets = []
    full_preds = []
    if save_preds:
        # Normalisation statistics of the saved feature columns, hoisted out of the loop
        feat_std = dataset.input_std.to(device)[3:7]
        feat_mean = dataset.input_mean.to(device)[3:7]

    with torch.no_grad():
        for batch in eval_dataloader:

            # send data to device
            batch.x = batch.x.to(device)
            batch.edge_index = batch.edge_index.to(torch.int64).to(device)
            batch.buckling = batch.buckling.to(device)
            batch.batch = batch.batch.to(device)

            output = model(batch)

            # Accumulate metrics weighted by the number of targets in this batch
            batch_size = batch.buckling.size(0)
            num_samples += batch_size
            for key, metric_func in metrics.items():
                running_eval_metrics[key] += metric_func(output, batch.buckling).item() * batch_size

            if save_preds:
                # Keep the first node of every graph. This assumes each graph has exactly
                # `nodes_per_graph` nodes - TODO confirm when using a different mesh.
                first_indices = torch.arange(
                    0, batch.x.size(0), step=nodes_per_graph, device=batch.x.device
                )
                # Undo the input normalisation (x_norm * std + mean) on the kept rows only
                features = batch.x[first_indices, 3:7] * feat_std + feat_mean
                full_features.append(features.cpu())
                full_targets.append(batch.buckling.cpu())
                full_preds.append(output.cpu())

    if num_samples == 0:
        raise ValueError('eval_dataloader yielded no samples')

    if save_preds:
        preds_array = torch.cat(
            (torch.cat(full_features, dim=0),
             torch.cat(full_targets, dim=0),
             torch.cat(full_preds, dim=0)),
            dim=1
        )
        preds_file = Path.cwd().joinpath('./results/raw_predictions/GNN_buckling_preds.pth')
        preds_file.parent.mkdir(parents=True, exist_ok=True)
        torch.save(preds_array, preds_file)

    for key in running_eval_metrics:
        running_eval_metrics[key] /= num_samples

    return running_eval_metrics

if __name__ == '__main__':

    # Input vals
    DATASET_ROOT = Path('F:/ConcreteShellFEA') # change to local path to ConcreteShellFEA
    num_epochs = 1000
    batch_size = 128

    # Prep datasets. Graphs of each split live in
    # <DATASET_ROOT>/datasets/PerfectShell_LinearFEA/graphs/input_and_output/<split>/<batch_size>
    # Training
    train_graph_folder_path = DATASET_ROOT.joinpath(
        './datasets/PerfectShell_LinearFEA/graphs/input_and_output/training/{:d}'.format(batch_size)
    )
    train_dataset = GraphDataBuckling(
        train_graph_folder_path, 12800, batch_size
    )
    # Validation
    val_graph_folder_path = DATASET_ROOT.joinpath(
        './datasets/PerfectShell_LinearFEA/graphs/input_and_output/validation/{:d}'.format(batch_size)
    )
    val_dataset = GraphDataBuckling(
        val_graph_folder_path, 3200, batch_size
    )
    # Testing
    test_graph_folder_path = DATASET_ROOT.joinpath(
        './datasets/PerfectShell_LinearFEA/graphs/input_and_output/testing/{:d}'.format(batch_size)
    )
    test_dataset = GraphDataBuckling(
        test_graph_folder_path, 4000, batch_size
    )

    # Calculate the per-feature input mean and std for normalisation from the
    # training set only, so no validation/test information leaks into preprocessing.
    # Two passes (mean first, then squared deviations) for numerical stability.
    mean = torch.zeros(train_dataset.num_node_x)
    count = 0
    for file_path in train_dataset.processed_file_names:
        with open(file_path, 'rb') as f:
            data = torch.load(f)
        mean += data.x.sum(dim=0)
        count += data.x.size(0)
    if count == 0:
        raise RuntimeError('No training nodes found; cannot compute normalisation statistics')
    mean /= count
    std = torch.zeros(train_dataset.num_node_x)
    for file_path in train_dataset.processed_file_names:
        with open(file_path, 'rb') as f:
            data = torch.load(f)
        std += ((data.x - mean) ** 2).sum(dim=0)
    std = torch.sqrt(std / count)
    # A feature that is constant over the training set has std 0; dividing by it
    # (the inverse, x * std + mean, is applied in eval_func) would give inf/NaN,
    # so such columns are left unscaled.
    std = torch.where(std > 0, std, torch.ones_like(std))
    # Uncomment to use mean and std calculated based on another training/validation split, as in paper
    # mean = torch.tensor(
    #     [-1.3438e-10, -3.7690e-11,  3.2949e-01,  6.0032e+00, 
    #      6.2542e-01, 7.9875e-02,  3.5410e+01,  7.7237e-05,  7.7634e-03]
    # )
    # std = torch.tensor(
    #     [2.3117e+00, 2.3117e+00, 2.5303e-01, 2.3063e+00,
    #      2.5276e-01, 3.9403e-02, 4.9165e+00, 6.7407e-05, 8.7767e-02]
    # )

    # Apply the same training-set statistics to every split
    for split_dataset in (train_dataset, val_dataset, test_dataset):
        split_dataset.input_mean = mean
        split_dataset.input_std = std
        split_dataset.normalize_input = True

    # Model configuration
    gnn_config = {
        'num_preproc_layers': 1,
        'num_postproc_layers': 1,
        'mlp_layers': 2,
        'mlp_neurons': 16,
        'num_layers': 2,
        'layer_connectivity': 'cat',
        'hidden_size': 5,
        'aggr': 'add',
        'activation': 'Tanh',
        'attention': 'none',
        'dropout': 0.,
        'batch_norm': False,
        'pooling_method': 'global_mean_pool',
        'optimizer': 'Adamax',
        'learning_rate': 0.01,
        'batch_size': batch_size
    }

    # Train model (the final checkpoint is saved to ./models/GNN_buckling_model.pth)
    print('*** TRAINING MODEL ***')
    train_model(train_dataset, val_dataset, gnn_config, num_epochs)

    # Evaluate on the test set (this also saves the test predictions)
    model_file = 'GNN_buckling_model.pth' #  change to 'GNN_buckling_best_model.pth' to see best model
    preds_file = 'GNN_buckling_preds.pth' #  change to 'GNN_buckling_best_preds.pth' to see best preds
    print('*** EVALUATE MODEL ***')
    eval_model(
        test_dataset, model_file, preds_file, gnn_config
    )