import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import summary
from torch_geometric.loader import DataLoader
from torchmetrics.regression import MeanAbsolutePercentageError, MeanSquaredError, MeanAbsoluteError
from torch.nn.utils import clip_grad_norm_
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

# Custom imports
from tools import save_checkpoint, EarlyStopping
from metrics import RMSELoss, MAPEPeakMetric, RMSEMetric, APEPeakSampleMetric
from datasets import GraphDataStress
from models import GNN

def train_model(train_dataset, val_dataset, gnn_config, num_epochs):
    """Train a GNN stress model, save the final checkpoint and plot RMSE learning curves.

    Args:
        train_dataset: Training graph dataset. ``train_dataset[0].num_features`` sets
            the model's input width.
        val_dataset: Validation graph dataset, evaluated after every epoch.
        gnn_config: Model/optimiser configuration dict. Keys read here:
            ``'batch_size'``, ``'optimizer'`` (name of a ``torch.optim`` class, e.g.
            ``'Adam'``), ``'learning_rate'`` and, optionally, ``'loss'`` (one of
            ``'MSE'``, ``'RMSE'``, ``'MAE'``; defaults to ``'MSE'``). The whole dict
            is also passed to ``GNN`` as its ``config``.
        num_epochs: Maximum number of epochs. Training stops earlier when
            ``EarlyStopping`` (patience 60, delta 0.01) triggers on the validation MSE.

    Raises:
        ValueError: If ``gnn_config['loss']`` names an unsupported loss.

    The final model is saved to ``./models/GNN_stress_model.pth`` (relative to the
    current working directory) via ``save_checkpoint``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create DataLoaders for training and validation sets
    train_dataloader = DataLoader(
        train_dataset, batch_size=gnn_config['batch_size'], shuffle=True
    )
    val_dataloader = DataLoader(
        val_dataset, batch_size=gnn_config['batch_size'], shuffle=False
    )

    print('Number of training set samples = {:d}'.format(len(train_dataset)))
    print('Number of validation set samples = {:d}'.format(len(val_dataset)))

    # Instantiate the model (12 output values per node)
    model = GNN(num_features=train_dataset[0].num_features, output_size=12, config=gnn_config)
    model.to(device)

    # Print a layer summary using one sample moved to the model's device.
    # NOTE(review): this assigns device tensors onto the object returned by
    # train_dataset[0]; if the dataset caches items, the cached sample is modified.
    sample = train_dataset[0]
    sample.x = sample.x.to(device)
    sample.edge_index = sample.edge_index.to(torch.int64).to(device)
    sample.stress = sample.stress.to(device)
    print(summary(model, sample))
    del sample

    # Optimizer class is looked up by name in torch.optim
    optimizer = getattr(optim, gnn_config['optimizer'])(
        model.parameters(), lr=gnn_config['learning_rate']
    )

    # Training loss, selectable through the optional 'loss' config key
    loss_dict = {
        'MSE': nn.MSELoss(),
        'RMSE': RMSELoss(),
        'MAE': nn.L1Loss()
    }
    loss_name = gnn_config.get('loss', 'MSE')
    if loss_name not in loss_dict:
        raise ValueError(
            'Unsupported loss {!r}; expected one of {}'.format(loss_name, sorted(loss_dict))
        )
    criterion = loss_dict[loss_name]

    # Metrics reported for both training and validation sets
    metrics = {
        'MSE': MeanSquaredError().to(device),
        'MAPE': MeanAbsolutePercentageError().to(device),
        'RMSE': RMSEMetric().to(device),
        'MAE': MeanAbsoluteError().to(device)
    }

    # Early stopping is driven by the validation MSE (see the loop below)
    early_stopping = EarlyStopping(patience=60, delta=0.01)

    # Per-metric history of epoch values: {metric: {'train': [...], 'val': [...]}}
    losses_history = {key: {'train': [], 'val': []} for key in metrics}

    epochs_completed = 0
    for epoch in tqdm.tqdm(range(num_epochs), total=num_epochs):
        epochs_completed += 1

        train_metrics = train_func(model, optimizer, criterion, metrics, train_dataloader, device)
        val_metrics = eval_func(model, metrics, val_dataset, val_dataloader, device)

        # Store losses
        for m in metrics:
            losses_history[m]['train'].append(train_metrics[m])
            losses_history[m]['val'].append(val_metrics[m])

        print(f'Epoch {epoch+1}/{num_epochs}')
        for key in train_metrics:
            print(f'    {key}: Train = {train_metrics[key]}, Val = {val_metrics[key]}')

        # Check if early stopping criteria are met
        if early_stopping(val_metrics['MSE']):
            print("Early stopping!")
            break

    # Save the final model. `epochs_completed - 1` equals the index of the last epoch
    # run, and unlike the loop variable it is still defined when num_epochs == 0.
    model_file = Path.cwd().joinpath('./models/GNN_stress_model.pth')
    model_file.parent.mkdir(parents=True, exist_ok=True)
    save_checkpoint(model, optimizer, epochs_completed - 1, model_file)

    # Plot the RMSE learning curves (the plotted metric is RMSE, not the training loss)
    epoch_axis = range(1, epochs_completed + 1)
    plt.plot(epoch_axis, losses_history['RMSE']['train'], label='Training RMSE')
    plt.plot(epoch_axis, losses_history['RMSE']['val'], label='Validation RMSE')
    plt.xlabel('Epoch')
    plt.ylabel('RMSE')
    plt.title('Training and Validation Curves')
    plt.legend()
    plt.show()

def eval_model(test_dataset, model_file, preds_file, gnn_config):
    """Load a trained GNN, predict the test set and report its accuracy.

    Steps:
      1. Load ``./models/<model_file>`` into a fresh ``GNN`` built from ``gnn_config``.
      2. Run ``eval_func`` on ``test_dataset`` with ``save_preds=True``, which writes
         ``./results/raw_predictions/GNN_stress_preds.pth``, and print the batch-averaged
         test metrics.
      3. Load ``./results/raw_predictions/<preds_file>``. This may be a different file
         from the one just written, e.g. a previously saved "best" run. For each
         principal stress S1/S2/S3, print MAE, RMSE and MAPEpeak, plot predicted vs FE
         peak absolute stress, and print Pearson correlations between design-space
         dimensions and the per-sample APEpeak.

    Args:
        test_dataset: Test graph dataset.
        model_file: Checkpoint file name inside ``./models``. The loaded object must
            contain a ``'model_state_dict'`` entry.
        preds_file: Prediction file name inside ``./results/raw_predictions``.
        gnn_config: Model configuration. ``'batch_size'`` sets the loader batch size.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define test dataloader
    test_dataloader = DataLoader(
        test_dataset, batch_size=gnn_config['batch_size'], shuffle=False
    )

    # Load model. map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only machine (without it torch.load tries to restore onto CUDA).
    model_state = torch.load(
        Path.cwd().joinpath('./models/{:s}'.format(model_file)),
        map_location=device,
    )
    model = GNN(num_features=test_dataset[0].num_features, output_size=12, config=gnn_config)
    model.load_state_dict(model_state['model_state_dict'])
    model.to(device)

    # Predict test set and save predictions
    metrics = {
        'MSE': MeanSquaredError().to(device),
        'MAPE': MeanAbsolutePercentageError().to(device),
        'RMSE': RMSEMetric().to(device),
        'MAE': MeanAbsoluteError().to(device)
    }
    test_metrics = eval_func(
        model, metrics, test_dataset, test_dataloader, device, save_preds=True,
        gnn_config=gnn_config
    )
    print('Test set metrics (mean over batches):')
    for key, value in test_metrics.items():
        print('    {:s} = {:f}'.format(key, value))

    # Load test predictions and split them per principal stress
    pred_data = torch.load(Path.cwd().joinpath('./results/raw_predictions/{:s}'.format(preds_file)))
    features, sorted_preds, sorted_y = _split_principal_stresses(pred_data)

    _print_error_metrics(sorted_preds, sorted_y)
    _plot_peak_stresses(sorted_preds, sorted_y)
    _print_ape_correlations(features, sorted_preds, sorted_y)


def _split_principal_stresses(pred_data, num_features=4):
    """Split a saved prediction array into features and per-stress targets/predictions.

    Expected layout (one row per sample, as written by ``eval_func``):
    ``num_features`` design features, then the flattened targets, then the flattened
    predictions (same length as the targets). Targets and predictions each consist of
    three equal-length blocks: S1, S2, S3.

    Returns:
        ``(features, sorted_preds, sorted_y)``. The last two are lists
        ``[S1, S2, S3]`` of 2-D slices of ``pred_data``.

    Raises:
        ValueError: If the value columns cannot be split into 2 x 3 equal blocks.
    """
    per_stress, remainder = divmod(pred_data.size(1) - num_features, 6)
    if remainder or per_stress <= 0:
        raise ValueError(
            'Prediction array width {:d} does not match the expected '
            'features/targets/predictions layout'.format(pred_data.size(1))
        )
    num_values = 3 * per_stress

    features = pred_data[:, :num_features]
    y = pred_data[:, num_features:num_features + num_values]
    preds = pred_data[:, num_features + num_values:]

    sorted_y = [y[:, k * per_stress:(k + 1) * per_stress] for k in range(3)]
    sorted_preds = [preds[:, k * per_stress:(k + 1) * per_stress] for k in range(3)]
    return features, sorted_preds, sorted_y


def _print_error_metrics(sorted_preds, sorted_y):
    """Print MAE, RMSE and MAPEpeak for each principal stress (S1, S2, S3)."""
    mae_func = MeanAbsoluteError()
    rmse_func = RMSEMetric()
    mape_peak_func = MAPEPeakMetric()
    print('Final test accuracy:')
    for i, (preds, target) in enumerate(zip(sorted_preds, sorted_y), start=1):
        print('    - S{:d}'.format(i))
        mae = mae_func(preds, target).item()
        rmse = rmse_func(preds, target).item()
        mape_peak = mape_peak_func(preds, target).item()
        print('        MAE = {:f}'.format(mae))
        print('        RMSE = {:f}'.format(rmse))
        print('        MAPEpeak = {:f}'.format(mape_peak))


def _plot_peak_stresses(sorted_preds, sorted_y):
    """Scatter per-sample peak |predicted| vs peak |FE| stress, one figure per stress."""
    print('Predicted vs FE peak stress')
    for i, (preds, target) in enumerate(zip(sorted_preds, sorted_y), start=1):
        print('    - S{:d}'.format(i))
        peak_pred, _ = torch.max(torch.absolute(preds), dim=1)
        peak_target, _ = torch.max(torch.absolute(target), dim=1)
        max_val = max(peak_pred.max().item(), peak_target.max().item())
        # y = x reference line: points on it are perfect predictions
        plt.plot([0., max_val], [0., max_val], color='black', linestyle='-', linewidth=1)
        plt.scatter(peak_pred, peak_target, s=10, marker='x', color='red', alpha=0.7)
        plt.xlabel('Predicted peak |S{:d}|'.format(i))
        plt.ylabel('FE peak |S{:d}|'.format(i))
        plt.axis("equal")
        plt.xlim(0, max_val)
        plt.ylim(0, max_val)
        plt.gca().set_aspect('equal', adjustable='box')
        plt.show()


def _print_ape_correlations(features, sorted_preds, sorted_y):
    """Print Pearson correlations between design-space dimensions and per-sample APEpeak.

    Per the printed labels, ``features`` columns 0-3 are span, height, thickness and
    material. Height and thickness are normalised by span.
    """
    # The design dimensions do not depend on the stress component, so build them once
    span = features[:, 0].numpy()
    dims = {
        'span': span,
        'height/span': features[:, 1].numpy() / span,
        'thickness/span': features[:, 2].numpy() / span,
        'material': features[:, 3].numpy()
    }
    ape_peak_sample_func = APEPeakSampleMetric()
    print('Pearson correlations between design space dimensions and APEpeak')
    for i, (preds, target) in enumerate(zip(sorted_preds, sorted_y), start=1):
        print('    - S{:d}'.format(i))
        ape_peak = ape_peak_sample_func(preds, target).numpy()
        for dim, vals in dims.items():
            correlation = np.corrcoef(vals, ape_peak)[0, 1]
            print('        {:s} - APE correlation = {:f}'.format(dim, correlation))

def train_func(model, optimizer, criterion, metrics, train_dataloader, device):
    """Run one training epoch and return the batch-averaged metrics.

    For every batch:
      1. Move it to ``device``.
      2. Run a forward pass and back-propagate ``criterion(output, batch.stress)``.
      3. Clip the global gradient norm to 1.0 and take an optimiser step.

    Every entry of ``metrics`` is evaluated on that same (pre-step) output. The
    returned dict maps each metric name to the mean of its per-batch values.
    """
    model.train()
    totals = dict.fromkeys(metrics, 0.)

    for data in train_dataloader:
        # Move the tensors used by the model and the loss onto the target device
        data.x = data.x.to(device)
        data.edge_index = data.edge_index.to(torch.int64).to(device)
        data.stress = data.stress.to(device)
        data.batch = data.batch.to(device)

        optimizer.zero_grad()
        prediction = model(data)
        criterion(prediction, data.stress).backward()
        clip_grad_norm_(model.parameters(), 1.0)  # guard against exploding gradients
        optimizer.step()

        for name, fn in metrics.items():
            totals[name] += fn(prediction, data.stress).item()

    num_batches = len(train_dataloader)
    return {name: total / num_batches for name, total in totals.items()}

def eval_func(model, metrics, dataset, eval_dataloader, device, save_preds=False,
              gnn_config=None, nodes_per_graph=12881):
    """Evaluate ``model`` on every batch of ``eval_dataloader``; optionally save predictions.

    Args:
        model: Model called as ``model(batch)``. Its output is compared with ``batch.stress``.
        metrics: Mapping name -> callable(output, target) returning a single-element tensor.
        dataset: Source dataset. Only read when ``save_preds`` is True, for its
            ``input_mean``/``input_std``, which are used to undo input normalisation.
        eval_dataloader: Sized iterable of batches exposing ``x``, ``edge_index``,
            ``stress`` and ``batch``.
        device: Device the batch tensors are moved to.
        save_preds: If True, write features/targets/predictions to
            ``./results/raw_predictions/GNN_stress_preds.pth`` (the directory is created
            if missing).
        gnn_config: Unused in this function; accepted for call compatibility.
        nodes_per_graph: Node count of every graph in a batch. All graphs are assumed to
            be the same size, laid out consecutively. Only used when ``save_preds``.

    Returns:
        Dict mapping each metric name to the mean of its per-batch values.

    Saved layout (one row per graph):
      - the 4 denormalised input features from columns 3:7 of the graph's first node;
      - the 12 target columns, flattened column by column in the order
        S1 [0, 3, 6, 9], S2 [1, 4, 7, 10], S3 [2, 5, 8, 11];
      - the predictions in the same layout as the targets.
    """
    model.eval()  # Set the model to evaluation mode
    running_eval_metrics = dict.fromkeys(metrics, 0.)

    # Output columns regrouped so that each principal stress block is contiguous
    s_inds = [0, 3, 6, 9] + [1, 4, 7, 10] + [2, 5, 8, 11]

    full_features = []
    full_targets = []
    full_preds = []
    if save_preds:
        # Hoisted out of the loop: used on every batch to undo input normalisation
        input_std = dataset.input_std.to(device)
        input_mean = dataset.input_mean.to(device)

    with torch.no_grad():
        for batch in eval_dataloader:

            # send data to device
            batch.x = batch.x.to(device)
            batch.edge_index = batch.edge_index.to(torch.int64).to(device)
            batch.stress = batch.stress.to(device)
            batch.batch = batch.batch.to(device)

            output = model(batch)

            for key, metric_func in metrics.items():
                running_eval_metrics[key] += metric_func(output, batch.stress).item()

            if save_preds:
                features = batch.x * input_std + input_mean
                # Index of the first node of every graph in the batch
                first_indices = torch.arange(
                    0, features.size(0), step=nodes_per_graph, device=features.device
                )
                full_features.append(features[first_indices, 3:7].cpu())

                # Use the batch's actual graph count (not a fixed 16) so that feature,
                # target and prediction rows always line up, e.g. on a short last batch.
                for i in range(first_indices.numel()):
                    rows = slice(i * nodes_per_graph, (i + 1) * nodes_per_graph)
                    full_targets.append(batch.stress[rows, s_inds].T.reshape(-1).cpu().unsqueeze(0))
                    full_preds.append(output[rows, s_inds].T.reshape(-1).cpu().unsqueeze(0))

    if save_preds:
        preds_array = torch.cat(
            (torch.cat(full_features, dim=0),
             torch.cat(full_targets, dim=0),
             torch.cat(full_preds, dim=0)),
            dim=1,
        )
        preds_file = Path.cwd().joinpath('./results/raw_predictions/GNN_stress_preds.pth')
        preds_file.parent.mkdir(parents=True, exist_ok=True)
        torch.save(preds_array, preds_file)

    # Final calc: average of per-batch metric values
    num_batches = len(eval_dataloader)
    for key in running_eval_metrics:
        running_eval_metrics[key] /= num_batches

    return running_eval_metrics

if __name__ == '__main__':
    
    # Input vals
    DATASET_ROOT = Path('F:/ConcreteShellFEA') # change to local path to ConcreteShellFEA
    num_epochs = 1000
    batch_size = 16

    # Prep datasets
    # Training
    train_graph_folder_path = DATASET_ROOT.joinpath(
        './datasets/PerfectShell_LinearFEA/graphs/input_and_output/training/{:d}'.format(batch_size)
    )
    train_dataset = GraphDataStress(
        train_graph_folder_path, 12800, batch_size
    )
    # Validation
    val_graph_folder_path = DATASET_ROOT.joinpath(
        './datasets/PerfectShell_LinearFEA/graphs/input_and_output/validation/{:d}'.format(batch_size)
    )
    val_dataset = GraphDataStress(
        val_graph_folder_path, 3200, batch_size
    )
    # Testing
    test_graph_folder_path = DATASET_ROOT.joinpath(
        './datasets/PerfectShell_LinearFEA/graphs/input_and_output/testing/{:d}'.format(batch_size)
    )
    test_dataset = GraphDataStress(
        test_graph_folder_path, 4000, batch_size
    )

    # Calculate input mean and std for normalisation based on training and validation datasets
    mean = torch.zeros(train_dataset.num_node_x)
    count = 0
    for file_path in train_dataset.processed_file_names:
        with open(file_path, 'rb') as f:
            data = torch.load(f)
            mean += data.x.sum(dim=0)
            count += data.x.size(0)
    mean /= count
    std = torch.zeros(train_dataset.num_node_x)
    for file_path in train_dataset.processed_file_names:
        with open(file_path, 'rb') as f:
            data = torch.load(f)
            std += ((data.x - mean) ** 2).sum(dim=0)
    std = torch.sqrt(std / count)

    # Update datasets with calculated mean and std
    train_dataset.input_mean = mean
    train_dataset.input_std = std
    train_dataset.normalize_input = True
    val_dataset.input_mean = mean
    val_dataset.input_std = std
    val_dataset.normalize_input = True
    test_dataset.input_mean = mean
    test_dataset.input_std = std
    test_dataset.normalize_input = True

    # Model configuration
    gnn_config = {
        'num_preproc_layers': 1,
        'num_postproc_layers': 0,
        'mlp_layers': 6,
        'mlp_neurons': 32,
        'num_layers': 3,
        'layer_connectivity': 'cat',
        'hidden_size': 9,
        'aggr': 'add',
        'activation': 'ReLU',
        'attention': 'none',
        'dropout': 0.,
        'batch_norm': False,
        'optimizer': 'Adam',
        'learning_rate': 0.01,
        'batch_size': batch_size
    }

    # Train model and save test predictions
    print('*** TRAINING MODEL ***')
    train_model(train_dataset, val_dataset, gnn_config, num_epochs)

    # Evaluate accuracy
    model_file = 'GNN_stress_model.pth' #  change to 'GNN_stress_best_model.pth' to see best model
    preds_file = 'GNN_stress_preds.pth' #  change to 'GNN_stress_best_preds.pth' to see best preds
    print('*** EVALUATE MODEL ***')
    eval_model(
        test_dataset, model_file, preds_file, gnn_config
    )