import os
import joblib
import h5py
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchmetrics.regression import MeanAbsolutePercentageError, MeanSquaredError, MeanAbsoluteError
from torchinfo import summary
from torch.nn.utils import clip_grad_norm_

# Custom imports
from tools import save_checkpoint, EarlyStopping, IncrementalPCAonGPU
from metrics import RMSELoss, RMSEMetric, MAPEPeakMetric, APEPeakSampleMetric
from datasets import ImageDataStressPCA
from models import CNN

def train_model(train_dataset, val_dataset, cnn_config, num_epochs, test_dataset=None):
    """Train a CNN, save a checkpoint of it and plot the RMSE learning curves.

    Args:
        train_dataset: Indexable dataset whose items are
            (image, additional_inputs, target) triples. The first item is
            used to size the model's scalar input and output.
        val_dataset: Dataset with the same item layout, evaluated after every
            training epoch and used for early stopping.
        cnn_config: Model/optimiser configuration. This function reads
            'batch_size', 'optimizer' (name of a class in ``torch.optim``) and
            'learning_rate'; the whole dict is also forwarded to ``CNN``.
        num_epochs: Maximum number of epochs to run.
        test_dataset: Optional testing dataset. It is only used to report its
            size; nothing is printed about it when omitted.

    Side effects:
        Writes ``./models/CNN_stress_model.pth`` relative to the current
        working directory and opens a matplotlib window.
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create data loaders
    batch_size = cnn_config['batch_size']
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    print('Number of training set samples = {:d}'.format(len(train_dataset)))
    print('Number of validation set samples = {:d}'.format(len(val_dataset)))
    # The testing set is an explicit, optional argument: relying on a
    # module-level global made this function fail with NameError whenever it
    # was called from anywhere but this script's __main__ section.
    if test_dataset is not None:
        print('Number of testing set samples = {:d}'.format(len(test_dataset)))

    # Instantiate the model; the scalar-input and output widths are taken
    # from the first training sample.
    first_sample = train_dataset[0]
    num_scalar_inputs = first_sample[1].shape[0]
    num_outputs = first_sample[2].shape[0]
    model = CNN(cnn_config, 99, num_scalar_inputs, num_outputs)
    model.to(device)

    summary(
        model,
        input_size=[
            (batch_size, 3, 99, 99),
            (batch_size, num_scalar_inputs)
        ]
    )

    # Define optimizer (looked up by name in torch.optim)
    optimizer = getattr(
        optim, cnn_config['optimizer']
    )(model.parameters(), lr=cnn_config['learning_rate'])

    # Define loss
    loss_dict = {
        'MSE': nn.MSELoss(),
        'RMSE': RMSELoss(),
        'MAE': nn.L1Loss()
    }
    criterion = loss_dict['MSE']

    # Metrics to compute
    metrics = {
        'MSE': MeanSquaredError().to(device),
        'MAPE': MeanAbsolutePercentageError().to(device),
        'RMSE': RMSEMetric().to(device),
        'MAE': MeanAbsoluteError().to(device)
    }

    # Instantiate the EarlyStopping class
    early_stopping = EarlyStopping(patience=100, delta=0.01)

    # Per-metric history of the epoch-averaged training/validation values
    losses_history = {key: {'train': [], 'val': []} for key in metrics}

    # Training loop
    epochs_completed = 0
    for epoch in tqdm.tqdm(range(num_epochs), total=num_epochs):
        epochs_completed += 1

        train_metrics = train_func(model, optimizer, criterion, metrics, train_dataloader)
        val_metrics = eval_func(model, metrics, val_dataset, val_dataloader)

        # Store metrics
        for m in metrics:
            losses_history[m]['train'].append(train_metrics[m])
            losses_history[m]['val'].append(val_metrics[m])

        print(f'Epoch {epoch+1}/{num_epochs}')
        for key in train_metrics:
            print(f'    {key}: Train = {train_metrics[key]}, Val = {val_metrics[key]}')

        # Check if early stopping criteria are met (monitored on validation MSE)
        if early_stopping(val_metrics['MSE']):
            print("Early stopping!")
            break

    # Zero-based index of the last epoch that ran. Derived from the counter
    # rather than the loop variable, which is unbound when num_epochs == 0.
    last_epoch = epochs_completed - 1

    model_file = Path.cwd().joinpath('./models/CNN_stress_model.pth')
    # Make sure the target directory exists before writing the checkpoint.
    model_file.parent.mkdir(parents=True, exist_ok=True)
    save_checkpoint(model, optimizer, last_epoch, model_file)

    # Plot the learning curves
    epoch_axis = range(1, epochs_completed+1)
    plt.plot(epoch_axis, losses_history['RMSE']['train'], label='Training Loss')
    plt.plot(epoch_axis, losses_history['RMSE']['val'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Curves')
    plt.legend()
    plt.show()

def eval_model(test_dataset, model_file, preds_file, raw_stress_paths, pca_path, scaler_path, cnn_config):
    """Load a trained CNN, predict the testing set and assess its performance.

    The predictions are made in PCA space, written to disk by ``eval_func``,
    re-loaded from ``preds_file``, back-transformed with the PCA model and
    scaler, and compared against the raw stresses.

    Args:
        test_dataset: Testing dataset; items are
            (image, additional_inputs, target) triples.
        model_file: File name (str) of the checkpoint inside ``./models``.
        preds_file: File name (str) of the predictions tensor inside
            ``./results/raw_predictions``. NOTE: ``eval_func`` always writes
            ``CNN_stress_preds.pth``; any other name must already exist.
        raw_stress_paths: Iterable of HDF5 files holding the raw stresses,
            concatenated in the given order.
        pca_path: Path to the joblib-serialised PCA model.
        scaler_path: Path to a ``torch.save``-d mapping with 'mean' and 'std'.
        cnn_config: Model configuration ('batch_size' is read here; the whole
            dict is forwarded to ``CNN``).

    Raises:
        ValueError: If the predictions tensor cannot be split into 4 feature
            columns followed by equally wide target and prediction blocks.
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define test dataloader
    test_dataloader = DataLoader(
        test_dataset, batch_size=cnn_config['batch_size'], shuffle=False
    )

    # Load model. map_location lets a checkpoint saved on a CUDA machine be
    # restored on a CPU-only one.
    model_state = torch.load(
        Path.cwd().joinpath(
            './models/{:s}'.format(model_file)
        ),
        map_location=device
    )
    model = CNN(cnn_config, 99, test_dataset[0][1].shape[0], test_dataset[0][2].shape[0])
    model.load_state_dict(model_state['model_state_dict'])
    model.to(device)

    # Predict test set and save predictions
    metrics = {
        'MSE': MeanSquaredError().to(device),
        'MAPE': MeanAbsolutePercentageError().to(device),
        'RMSE': RMSEMetric().to(device),
        'MAE': MeanAbsoluteError().to(device)
    }
    # Called for its side effect of writing the predictions file; the
    # PCA-space metrics it returns are not used below.
    eval_func(model, metrics, test_dataset, test_dataloader, save_preds=True)

    # Load test predictions
    print('preds')
    pred_data = torch.load(
        Path.cwd().joinpath('./results/raw_predictions/{:s}'.format(preds_file)),
        map_location='cpu'
    )
    # Column layout written by eval_func: 4 scalar features, then the PCA
    # targets, then the PCA predictions (targets and predictions are equally
    # wide). The split point is derived from the tensor width instead of
    # being hard-coded.
    num_feature_cols = 4
    num_components, leftover = divmod(pred_data.shape[1] - num_feature_cols, 2)
    if num_components <= 0 or leftover:
        raise ValueError(
            'Unexpected predictions layout: {:d} columns cannot be split into '
            '{:d} features plus equal target/prediction blocks'.format(
                pred_data.shape[1], num_feature_cols
            )
        )
    features = pred_data[:, :num_feature_cols]
    pca_preds = pred_data[:, num_feature_cols + num_components:]

    # Load raw stresses
    raw_stress_tensors = []
    for path in raw_stress_paths:
        with h5py.File(path, 'r') as hf:
            raw_stress_data = torch.from_numpy(hf['df']['block0_values'][:]).to(torch.float32)
            # Drop the first four columns and the last one
            raw_stress_tensors.append(raw_stress_data[:, 4:-1])
    y = torch.cat(raw_stress_tensors)
    y *= 0.001

    # Back-transform PCA preds: undo the PCA, then undo the standardisation
    pca_model = joblib.load(pca_path)
    scaler = torch.load(scaler_path, map_location='cpu')
    pca_preds_scaled = pca_model.inverse_transform(pca_preds).cpu()
    backtransformed_preds = (pca_preds_scaled * scaler['std'].cpu()) + scaler['mean'].cpu()
    backtransformed_preds *= 0.001

    # Separate data into principal stresses: three consecutive column blocks,
    # the first two of fixed width and the third taking the remainder.
    block_width = 134724
    first_cut, second_cut = block_width, 2 * block_width
    sorted_preds = [
        backtransformed_preds[:, :first_cut],
        backtransformed_preds[:, first_cut:second_cut],
        backtransformed_preds[:, second_cut:]
    ]
    sorted_y = [
        y[:, :first_cut],
        y[:, first_cut:second_cut],
        y[:, second_cut:]
    ]

    # Compute final MAE, RMSE, MAPEpeak
    mae_func = MeanAbsoluteError()
    rmse_func = RMSEMetric()
    mape_peak_func = MAPEPeakMetric()
    print('Final test accuracy:')
    for i, (preds_i, y_i) in enumerate(zip(sorted_preds, sorted_y), start=1):
        print('    - S{:d}'.format(i))
        mae = mae_func(preds_i, y_i).item()
        rmse = rmse_func(preds_i, y_i).item()
        mape_peak = mape_peak_func(preds_i, y_i).item()
        print('        MAE = {:f}'.format(mae))
        print('        RMSE = {:f}'.format(rmse))
        print('        MAPEpeak = {:f}'.format(mape_peak))

    # Plot predicted peak stress vs FE peak stress
    print('Predicted vs FE peak stress')
    for i, (preds_i, y_i) in enumerate(zip(sorted_preds, sorted_y), start=1):
        print('    - S{:d}'.format(i))
        # Peak = largest absolute value per sample (row)
        peak_pred, _ = torch.max(torch.absolute(preds_i), dim=1)
        peak_target, _ = torch.max(torch.absolute(y_i), dim=1)
        max_val = max(peak_pred.max().item(), peak_target.max().item())
        # Identity line: points on it are perfect predictions
        plt.plot([0.,max_val], [0.,max_val], color='black', linestyle='-', linewidth=1)
        plt.scatter(peak_pred, peak_target, s=10, marker='x', color='red', alpha=0.7)
        plt.axis("equal")
        plt.xlim(0, max_val)
        plt.ylim(0, max_val)
        plt.gca().set_aspect('equal', adjustable='box')
        plt.show()

    # Determine Pearson correlations with design space dimensions.
    # The design-space columns do not depend on the stress component, so they
    # are built once outside the loop.
    span = features[:, 0].numpy()
    dims = {
        'span': span,
        'height/span': features[:, 1].numpy() / span,
        'thickness/span': features[:, 2].numpy() / span,
        'material': features[:, 3].numpy()
    }
    ape_peak_sample_func = APEPeakSampleMetric()
    print('Pearson correlations between design space dimensions and APEpeak')
    for i, (preds_i, y_i) in enumerate(zip(sorted_preds, sorted_y), start=1):
        print('    - S{:d}'.format(i))
        ape_peak = ape_peak_sample_func(preds_i, y_i).numpy()
        for dim, vals in dims.items():
            correlation = np.corrcoef(vals, ape_peak)[0,1]
            print('        {:s} - APE correlation = {:f}'.format(dim, correlation))

def train_func(model, optimizer, criterion, metrics, train_dataloader):
    """Run one training epoch and return the batch-averaged metrics.

    Args:
        model: Module called as ``model(images, additional_inputs)``.
        optimizer: Optimiser stepping ``model``'s parameters.
        criterion: Callable ``criterion(outputs, targets)`` returning the
            scalar loss that is back-propagated.
        metrics: Mapping of metric name to a callable
            ``metric(outputs, targets)`` returning a one-element tensor.
        train_dataloader: Iterable yielding
            (images, additional_inputs, targets) batches.

    Returns:
        dict mapping each metric name to the mean of its per-batch values
        (every batch weighs the same, whatever its size).

    Raises:
        ValueError: If ``train_dataloader`` yields no batches.
    """
    model.train()  # Set the model to training mode
    running_train_metrics = {key: 0. for key in metrics}

    # Device the model lives on; batches are moved there before the forward
    # pass (a no-op when they are already on that device).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    num_batches = 0
    for images, additional_inputs, targets in train_dataloader:
        if device is not None:
            images = images.to(device)
            additional_inputs = additional_inputs.to(device)
            targets = targets.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(images, additional_inputs)

        # Compute the loss
        loss = criterion(outputs, targets)

        # Backpropagate gradients
        loss.backward()

        # Clip gradients to prevent exploding gradients
        clip_grad_norm_(model.parameters(), 1.0)

        # Update weights
        optimizer.step()

        # Calculate metrics on detached outputs so that no autograd graph is
        # built or retained for values that are only logged.
        with torch.no_grad():
            detached_outputs = outputs.detach()
            for key, metric_func in metrics.items():
                running_train_metrics[key] += metric_func(detached_outputs, targets).item()

        num_batches += 1

    if num_batches == 0:
        raise ValueError('train_dataloader yielded no batches')

    # Final calc: average over the batches actually seen
    for key in running_train_metrics:
        running_train_metrics[key] /= num_batches

    return running_train_metrics

def eval_func(model, metrics, dataset, eval_dataloader, save_preds=False):
    """Evaluate the model over a dataloader and optionally save its predictions.

    Args:
        model: Module called as ``model(images, additional_inputs)``.
        metrics: Mapping of metric name to a callable
            ``metric(outputs, targets)`` returning a one-element tensor.
        dataset: Dataset behind ``eval_dataloader``. Only accessed when
            ``save_preds`` is true, for ``normalize_scalars``, ``input_std``
            and ``input_mean``.
        eval_dataloader: Iterable yielding
            (images, additional_inputs, targets) batches.
        save_preds: When true, writes a 2-D tensor to
            ``./results/raw_predictions/CNN_stress_preds.pth`` (relative to
            the current working directory) whose columns are the first four
            scalar inputs, the targets and the predictions, in that order.

    Returns:
        dict mapping each metric name to the mean of its per-batch values.

    Raises:
        ValueError: If ``eval_dataloader`` yields no batches.
    """
    model.eval()  # Set the model to evaluation mode
    running_eval_metrics = {key: 0. for key in metrics}

    # Device the model lives on; batches are moved there before the forward
    # pass (a no-op when they are already on that device).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    full_features = []
    full_targets = []
    full_preds = []
    num_batches = 0
    with torch.no_grad():
        for images, additional_inputs, targets in eval_dataloader:
            # Keep the batch exactly as the loader produced it: the
            # de-normalisation below combines it with statistics stored on
            # the dataset, which sit on the dataset's device.
            loader_inputs = additional_inputs
            if device is not None:
                images = images.to(device)
                additional_inputs = additional_inputs.to(device)
                targets = targets.to(device)

            outputs = model(images, additional_inputs)
            for key, metric_func in metrics.items():
                running_eval_metrics[key] += metric_func(outputs, targets).item()

            if save_preds:
                if dataset.normalize_scalars:
                    # Undo the scalar-input normalisation before saving
                    loader_inputs = loader_inputs * dataset.input_std + dataset.input_mean
                full_features.append(loader_inputs[:, :4].cpu())
                full_targets.append(targets.cpu())
                full_preds.append(outputs.cpu())

            num_batches += 1

    if num_batches == 0:
        raise ValueError('eval_dataloader yielded no batches')

    if save_preds:
        preds_array = torch.cat(
            (
                torch.cat(full_features, dim=0),
                torch.cat(full_targets, dim=0),
                torch.cat(full_preds, dim=0)
            ),
            dim=1
        )
        preds_file = Path.cwd().joinpath('./results/raw_predictions/CNN_stress_preds.pth')
        # torch.save does not create missing directories
        preds_file.parent.mkdir(parents=True, exist_ok=True)
        torch.save(preds_array, preds_file)

    # Final calc: average over the batches actually seen
    for key in running_eval_metrics:
        running_eval_metrics[key] /= num_batches

    return running_eval_metrics

def _load_split(dataset_root, split):
    """Build the ImageDataStressPCA dataset of one split.

    ``split`` is the split's folder name ('training', 'validation' or
    'testing'); it also appears in the scalar CSV and PCA file names.
    """
    perfect_dir = dataset_root.joinpath('./datasets/PerfectShell_LinearFEA')
    imperfect_dir = dataset_root.joinpath('./datasets/ImperfectShell_LinearFEA')
    scalars_file = f'images/input/{split}/{split}_scalars.csv'
    pca_file = f'tabular/output/{split}/pca_output/pca_{split}.h5'
    return ImageDataStressPCA(
        imperfect_dir.joinpath(f'images/input/{split}/imperfection_map_99x99'),
        perfect_dir.joinpath(scalars_file),
        imperfect_dir.joinpath(scalars_file),
        perfect_dir.joinpath(pca_file),
        imperfect_dir.joinpath(pca_file)
    )


def _raw_stress_paths(dataset_root, shell_folder, split):
    """List the raw stress files s_GQ_0.h5 ... s_GQ_{n-1}.h5 of one split.

    ``n`` is the number of ``s_GQ_*.h5`` files in the folder, so unrelated
    directory entries do not produce paths to files that do not exist. The
    files are assumed to be numbered contiguously from 0.
    """
    raw_dir = dataset_root.joinpath(
        f'./datasets/{shell_folder}/tabular/output/{split}/raw_output'
    )
    num_files = sum(1 for _ in raw_dir.glob('s_GQ_*.h5'))
    return [raw_dir.joinpath('s_GQ_{:d}.h5'.format(i)) for i in range(num_files)]


if __name__ == '__main__':

    # Input vals
    DATASET_ROOT = Path('F:/ConcreteShellFEA') # change to local path to ConcreteShellFEA
    num_epochs = 1000

    # Prep datasets
    train_dataset = _load_split(DATASET_ROOT, 'training')
    val_dataset = _load_split(DATASET_ROOT, 'validation')
    test_dataset = _load_split(DATASET_ROOT, 'testing')

    # Only the imperfect-shell testing raw stresses are needed (by eval_model)
    test_imperfect_raw_stress_paths = _raw_stress_paths(
        DATASET_ROOT, 'ImperfectShell_LinearFEA', 'testing'
    )

    # Calculate scalar input mean and std for normalisation based on training and validation datasets.
    # Each row is a sample's perfect-shell scalars followed by its perfect-shell PCA stresses.
    train_val_inputs = torch.cat(
        [
            torch.cat((ds.perfect_scalars, ds.perfect_pca_stresses), dim=1)
            for ds in (train_dataset, val_dataset)
        ],
        dim=0
    )
    # Re-order samples as they were in the paper
    ordering_file = Path.cwd().joinpath('ordering_indices.csv')
    ordering_inds = np.loadtxt(ordering_file).astype(np.int64).tolist()
    train_val_inputs = train_val_inputs[ordering_inds]
    input_mean = torch.mean(train_val_inputs, dim=0)
    input_std = torch.std(train_val_inputs, dim=0)

    # Update datasets with calculated mean and std (the same statistics are
    # shared by all three splits)
    for ds in (train_dataset, val_dataset, test_dataset):
        ds.input_mean = input_mean
        ds.input_std = input_std
        ds.normalize_scalars = True

    # Model configuration
    cnn_config = {
        'num_conv_layers': 5,
        'kernels_per_layer': 16,
        'kernel_size': 9,
        'pooling_size': 7,
        'activation': 'ReLU',
        'pooling': 'AvgPool2d',
        'conv_features': 99,
        'num_dense_layers': 4,
        'neurons_per_layer': 1024,
        'dropout': 0.09,
        'batch_norm': True,
        'initializer': 'normal_',
        'optimizer': 'Adamax',
        'learning_rate': 0.001,
        'batch_size': 256
    }

    # Train model and save test predictions
    print('*** TRAINING MODEL ***')
    train_model(train_dataset, val_dataset, cnn_config, num_epochs)

    # Evaluate accuracy
    pca_path = DATASET_ROOT.joinpath(
        './datasets/ImperfectShell_LinearFEA/tabular/output/pca_stress/pca_model_GQ.joblib'
    )
    scaler_path = DATASET_ROOT.joinpath(
        './datasets/ImperfectShell_LinearFEA/tabular/output/pca_stress/scaler_params_GQ.pth'
    )
    model_file = 'CNN_stress_model.pth' #  change to 'CNN_stress_best_model.pth' to see best model
    preds_file = 'CNN_stress_preds.pth' #  change to 'CNN_stress_best_preds.pth' to see best preds
    print('*** EVALUATE MODEL ***')
    eval_model(
        test_dataset, model_file,
        preds_file, test_imperfect_raw_stress_paths,
        pca_path, scaler_path, cnn_config
    )