import torch
import h5py
import math
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from torch_geometric.data import Dataset as GraphDataset

class TabularDataBuckling(Dataset):
    """Tabular inputs and buckling targets from the ConcreteShellFEA dataset.

    Inputs come from the ``transformed_data`` array of the HDF5 file at
    ``preproc_input_path``: its first four columns are the shell
    characteristics (span, height, thickness, Young's modulus) and the
    remaining columns are the PCA-transformed inputs. Targets are read from
    the comma-delimited text file at ``buckling_path``.

    All tensors are loaded eagerly as ``float32`` and placed on CUDA when it
    is available, otherwise on the CPU.

    Normalization is off by default. To enable it, set ``input_mean`` and
    ``input_std`` on the instance and then set ``normalize_input = True``;
    this class does not compute those statistics itself.
    """

    def __init__(self, preproc_input_path, buckling_path):
        """Load every input and target tensor into memory.

        Args:
            preproc_input_path: HDF5 file containing ``transformed_data``.
            buckling_path: Comma-delimited text file with the buckling targets.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.preproc_input_path = preproc_input_path
        self.buckling_path = buckling_path
        self.normalize_input = False

        # Load shell characteristics (span, height, thickness, Young's modulus)
        self.shell_characteristics = self.load_shell_characteristics()
        # Load buckling data
        self.buckling_data = self.load_buckling_data()
        # Load PCA-transformed inputs
        self.preproc_inputs = self.load_preproc_data()

    def __len__(self):
        """Return the number of rows of shell characteristics."""
        return len(self.shell_characteristics)

    def __getitem__(self, idx):
        """Return ``(features, targets)`` for ``idx``.

        ``features`` is the shell characteristics followed by the
        PCA-transformed inputs, joined along the last dimension.
        """
        # Select the requested rows first and only then concatenate them;
        # concatenating the full tables on every access would copy the whole
        # dataset just to read a single sample.
        features = torch.cat(
            (self.shell_characteristics[idx], self.preproc_inputs[idx]),
            dim=-1,
        )
        targets = self.buckling_data[idx]

        # Normalize if needed (input_mean / input_std are set by the caller)
        if self.normalize_input:
            features = (features - self.input_mean) / self.input_std

        return features, targets

    def load_shell_characteristics(self):
        """Return the first four columns of ``transformed_data`` as float32."""
        with h5py.File(self.preproc_input_path, 'r') as hf:
            columns = hf['transformed_data'][:, :4]
        return torch.from_numpy(columns).to(torch.float32).to(self.device)

    def load_buckling_data(self):
        """Return the buckling targets as float32 with a new axis at dim 1.

        ``ndmin=1`` keeps a file that holds a single value from being parsed
        into a 0-d array, on which ``unsqueeze(1)`` would raise.
        """
        values = np.loadtxt(self.buckling_path, delimiter=',', ndmin=1)
        return torch.from_numpy(values).to(torch.float32).to(self.device).unsqueeze(1)

    def load_preproc_data(self):
        """Return columns 4 onward of ``transformed_data`` as float32."""
        with h5py.File(self.preproc_input_path, 'r') as hf:
            columns = hf['transformed_data'][:, 4:]
        return torch.from_numpy(columns).to(torch.float32).to(self.device)

class TabularDataStressPCA(Dataset):
    """Tabular inputs and PCA stress targets from the ConcreteShellFEA dataset.

    Inputs come from the ``transformed_data`` array of the HDF5 file at
    ``preproc_input_path``: its first four columns are the shell
    characteristics (span, height, thickness, Young's modulus) and the
    remaining columns are the PCA-transformed inputs. Targets are columns 4
    onward of ``transformed_data`` in the HDF5 file at ``pca_stress_path``.

    All tensors are loaded eagerly as ``float32`` and placed on CUDA when it
    is available, otherwise on the CPU.

    Normalization is off by default. To enable it, set ``input_mean`` and
    ``input_std`` on the instance and then set ``normalize_input = True``;
    this class does not compute those statistics itself.
    """

    def __init__(self, preproc_input_path, pca_stress_path):
        """Load every input and target tensor into memory.

        Args:
            preproc_input_path: HDF5 file containing the input ``transformed_data``.
            pca_stress_path: HDF5 file containing the stress ``transformed_data``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.preproc_input_path = preproc_input_path
        self.pca_stress_path = pca_stress_path
        self.normalize_input = False

        # Load shell characteristics (span, height, thickness, Young's modulus)
        self.shell_characteristics = self.load_shell_characteristics()
        # Load PCA-transformed inputs and outputs
        self.preproc_inputs, self.pca_stresses = self.load_preproc_data()

    def __len__(self):
        """Return the number of rows of shell characteristics."""
        return len(self.shell_characteristics)

    def __getitem__(self, idx):
        """Return ``(features, targets)`` for ``idx``.

        ``features`` is the shell characteristics followed by the
        PCA-transformed inputs, joined along the last dimension.
        """
        # Select the requested rows first and only then concatenate them;
        # concatenating the full tables on every access would copy the whole
        # dataset just to read a single sample.
        features = torch.cat(
            (self.shell_characteristics[idx], self.preproc_inputs[idx]),
            dim=-1,
        )
        targets = self.pca_stresses[idx]

        # Normalize if needed (input_mean / input_std are set by the caller)
        if self.normalize_input:
            features = (features - self.input_mean) / self.input_std

        return features, targets

    def load_shell_characteristics(self):
        """Return the first four columns of the input ``transformed_data`` as float32."""
        with h5py.File(self.preproc_input_path, 'r') as hf:
            columns = hf['transformed_data'][:, :4]
        return torch.from_numpy(columns).to(torch.float32).to(self.device)

    def load_preproc_data(self):
        """Return ``(preproc_inputs, pca_stresses)`` as float32 tensors.

        Both are columns 4 onward of ``transformed_data`` in their respective
        files. NOTE(review): the first four columns of the stress file are
        skipped as well -- presumably they repeat the shell characteristics;
        confirm against the preprocessing script.
        """
        # Load inputs
        with h5py.File(self.preproc_input_path, 'r') as hf:
            input_columns = hf['transformed_data'][:, 4:]
        preproc_inputs = torch.from_numpy(input_columns).to(torch.float32).to(self.device)
        # Load stresses
        with h5py.File(self.pca_stress_path, 'r') as hf:
            stress_columns = hf['transformed_data'][:, 4:]
        pca_stresses = torch.from_numpy(stress_columns).to(torch.float32).to(self.device)
        return preproc_inputs, pca_stresses

class ImageDataBuckling(Dataset):
    """Image inputs and buckling targets from the ConcreteShellFEA dataset.

    Each sample is one HDF5 image file plus four scalar inputs. The CSV at
    ``scalars_path`` (no header row) lists the sample name in column 0 and
    the scalars in columns 1-4. Images are read lazily from
    ``<image_folder_path>/<sample name>.h5`` on every access, while the
    scalars and buckling targets are loaded eagerly as ``float32``.

    Tensors are placed on CUDA when it is available, otherwise on the CPU.

    Scalar normalization is off by default. To enable it, set ``input_mean``
    and ``input_std`` on the instance and then set
    ``normalize_scalars = True``; this class does not compute them itself.
    """

    def __init__(self, image_folder_path, scalars_path, buckling_path):
        """Read the sample list, scalar inputs and buckling targets.

        Args:
            image_folder_path: Folder with one ``<sample>.h5`` file per
                sample. Must provide ``joinpath`` (e.g. ``pathlib.Path``).
            scalars_path: Header-less CSV with sample names and scalars.
            buckling_path: Comma-delimited text file with the buckling targets.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.image_folder_path = image_folder_path
        self.scalars_path = scalars_path
        self.buckling_path = buckling_path
        # Column 0 holds the sample names, which double as image file stems.
        self.samples = pd.read_csv(scalars_path, header=None, usecols=[0]).iloc[:, 0].astype(str).tolist()
        self.normalize_scalars = False

        # Load scalar inputs
        self.scalars = self.load_scalars()
        # Load buckling data
        self.buckling_data = self.load_buckling_data()

    def __len__(self):
        """Return the number of sample names listed in the scalars CSV."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, scalars, buckling_factor)`` for sample ``idx``."""
        image_file = self.image_folder_path.joinpath(f'{self.samples[idx]}.h5')

        # Load image from HDF5
        with h5py.File(image_file, 'r') as hf:
            image = hf['image_data'][()]
        # Add a leading channel dimension; the indexing expects a 2-D image
        image = image[None, :, :]
        image = torch.from_numpy(image).float().to(self.device)

        # Load additional scalar values
        scalars = self.scalars[idx]
        buckling_factor = self.buckling_data[idx]

        # Normalize scalar input if needed (input_mean / input_std are set by the caller)
        if self.normalize_scalars:
            scalars = (scalars - self.input_mean) / self.input_std

        return image, scalars, buckling_factor

    def load_scalars(self):
        """Return columns 1-4 of the scalars CSV as a float32 tensor."""
        table = pd.read_csv(self.scalars_path, header=None)
        values = table[[1, 2, 3, 4]].to_numpy()
        return torch.from_numpy(values).to(torch.float32).to(self.device)

    def load_buckling_data(self):
        """Return the buckling targets as float32 with a new axis at dim 1.

        ``ndmin=1`` keeps a file that holds a single value from being parsed
        into a 0-d array, on which ``unsqueeze(1)`` would raise.
        """
        values = np.loadtxt(self.buckling_path, delimiter=',', ndmin=1)
        return torch.from_numpy(values).to(torch.float32).to(self.device).unsqueeze(1)

class ImageDataStress(Dataset):
    """Image inputs and PCA stress targets from the ConcreteShellFEA dataset.

    The header-less CSV at ``scalars_path`` lists the sample name in column 0
    and four scalar inputs in columns 1-4. Images are read on demand from
    ``<image_folder_path>/<sample name>.h5``; scalars and stress targets are
    loaded up front as ``float32`` on CUDA when available, else on the CPU.

    Scalar normalization only happens when ``normalize_scalars`` is true, and
    it relies on ``input_mean`` / ``input_std`` having been set on the
    instance by the caller.
    """

    def __init__(self, image_folder_path, scalars_path, preproc_stress_path):
        """Store the paths, read the sample names and load scalars and targets."""
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.image_folder_path = image_folder_path
        self.scalars_path = scalars_path
        self.preproc_stress_path = preproc_stress_path

        # Column 0 of the CSV names each sample (and its image file stem)
        name_column = pd.read_csv(scalars_path, header=None, usecols=[0]).iloc[:, 0]
        self.samples = name_column.astype(str).tolist()
        self.normalize_scalars = False

        # Scalar inputs first, then the preprocessed stress targets
        self.scalars = self.load_scalars()
        self.pca_stresses = self.load_preproc_data()

    def __len__(self):
        """Return how many sample names the scalars CSV lists."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, scalars, stress_data)`` for sample ``idx``."""
        sample_name = self.samples[idx]
        h5_path = self.image_folder_path.joinpath(f'{sample_name}.h5')

        # Read the stored image, then prepend the channel axis
        with h5py.File(h5_path, 'r') as h5_file:
            pixels = h5_file['image_data'][()]
        image = torch.from_numpy(pixels[None, :, :]).float().to(self.device)

        # Scalars and targets were loaded up front; just pick the row
        scalars = self.scalars[idx]
        stress_data = self.pca_stresses[idx]

        # Optional normalization of the scalar inputs
        if self.normalize_scalars:
            scalars = (scalars - self.input_mean) / self.input_std

        return image, scalars, stress_data

    def load_scalars(self):
        """Return columns 1-4 of the scalars CSV as a float32 tensor."""
        table = pd.read_csv(self.scalars_path, header=None)
        values = table[[1, 2, 3, 4]].to_numpy()
        return torch.from_numpy(values).to(torch.float32).to(self.device)

    def load_preproc_data(self):
        """Return columns 4 onward of the stress ``transformed_data`` as float32."""
        with h5py.File(self.preproc_stress_path, 'r') as h5_file:
            stress_columns = h5_file['transformed_data'][:, 4:]
        return torch.from_numpy(stress_columns).to(torch.float32).to(self.device)

class GraphDataBuckling(GraphDataset):
    """Graph buckling data from the ConcreteShellFEA dataset.

    Graphs are stored in pre-batched files ``batch_<k>.pt`` inside
    ``graph_folder_path``. Each file is expected to hold ``batch_size``
    graphs (the last one may hold fewer), so sample ``idx`` is graph
    ``idx % batch_size`` of file ``idx // batch_size``.

    Node-feature normalization is off by default. To enable it, set
    ``input_mean`` and ``input_std`` on the instance and then set
    ``normalize_input = True``; this class does not compute them itself.
    """

    def __init__(self, graph_folder_path, num_samples, batch_size, transform=None, pre_transform=None, pre_filter=None):
        """Configure the dataset; no graph files are read here.

        Args:
            graph_folder_path: Folder with the ``batch_<k>.pt`` files. Must
                provide ``joinpath`` (e.g. ``pathlib.Path``).
            num_samples: Total number of graphs across all batch files.
            batch_size: Number of graphs stored per batch file.
            transform, pre_transform, pre_filter: Forwarded unchanged to the
                base dataset class.
        """
        # These are assigned before the base constructor runs, as in the
        # original ordering, so they already exist if it consults them.
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.normalize_input = False
        # NOTE(review): presumably the number of node features -- confirm.
        self.num_node_x = 9

        super().__init__(graph_folder_path, transform, pre_transform, pre_filter)

        # Re-assign root to the caller's object after the base constructor.
        self.root = graph_folder_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @property
    def processed_file_names(self):
        """Paths of all batch files, ``ceil(num_samples / batch_size)`` in total."""
        num_batches = math.ceil(self.num_samples / self.batch_size)
        return [
            self.root.joinpath(f'batch_{b_num:d}.pt')
            for b_num in range(num_batches)
        ]

    def len(self):
        """Return the total number of graphs."""
        return self.num_samples

    def get(self, idx):
        """Load and return the graph at position ``idx``."""
        file_ind = idx // self.batch_size
        graph_ind = idx % self.batch_size
        file_path = self.processed_file_names[file_ind]
        # SECURITY: weights_only=False performs full unpickling, which can
        # execute arbitrary code. Only load batch files from a trusted source.
        with open(file_path, 'rb') as f:
            graph_batch = torch.load(f, weights_only=False)
        # Extract only the requested graph; to_data_list() would rebuild
        # every graph in the batch just to keep one of them.
        graph = graph_batch.get_example(graph_ind)

        # Normalize node embeddings using the precomputed mean and std
        if self.normalize_input:
            graph.x = (graph.x - self.input_mean) / self.input_std

        return graph

class GraphDataStress(GraphDataset):
    """Graph stress data from the ConcreteShellFEA dataset.

    Graphs are stored in pre-batched files ``batch_<k>.pt`` inside
    ``graph_folder_path``. Each file is expected to hold ``batch_size``
    graphs (the last one may hold fewer), so sample ``idx`` is graph
    ``idx % batch_size`` of file ``idx // batch_size``. The ``stress``
    attribute of every returned graph is scaled by 0.001 (Pa to kPa).

    Node-feature normalization is off by default. To enable it, set
    ``input_mean`` and ``input_std`` on the instance and then set
    ``normalize_input = True``; this class does not compute them itself.
    """

    def __init__(self, graph_folder_path, num_samples, batch_size, transform=None, pre_transform=None, pre_filter=None):
        """Configure the dataset; no graph files are read here.

        Args:
            graph_folder_path: Folder with the ``batch_<k>.pt`` files. Must
                provide ``joinpath`` (e.g. ``pathlib.Path``).
            num_samples: Total number of graphs across all batch files.
            batch_size: Number of graphs stored per batch file.
            transform, pre_transform, pre_filter: Forwarded unchanged to the
                base dataset class.
        """
        # These are assigned before the base constructor runs, as in the
        # original ordering, so they already exist if it consults them.
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.normalize_input = False
        # NOTE(review): presumably the number of node features -- confirm.
        self.num_node_x = 9

        super().__init__(graph_folder_path, transform, pre_transform, pre_filter)

        # Re-assign root to the caller's object after the base constructor.
        self.root = graph_folder_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @property
    def processed_file_names(self):
        """Paths of all batch files, ``ceil(num_samples / batch_size)`` in total."""
        num_batches = math.ceil(self.num_samples / self.batch_size)
        return [
            self.root.joinpath(f'batch_{b_num:d}.pt')
            for b_num in range(num_batches)
        ]

    def len(self):
        """Return the total number of graphs."""
        return self.num_samples

    def get(self, idx):
        """Load and return the graph at position ``idx`` with stress in kPa."""
        file_ind = idx // self.batch_size
        graph_ind = idx % self.batch_size
        file_path = self.processed_file_names[file_ind]
        # SECURITY: weights_only=False performs full unpickling, which can
        # execute arbitrary code. Only load batch files from a trusted source.
        with open(file_path, 'rb') as f:
            graph_batch = torch.load(f, weights_only=False)
        # Extract only the requested graph; to_data_list() would rebuild
        # every graph in the batch just to keep one of them.
        graph = graph_batch.get_example(graph_ind)

        # Normalize node embeddings using the precomputed mean and std
        if self.normalize_input:
            graph.x = (graph.x - self.input_mean) / self.input_std

        # Convert stress from Pa to kPa. The graph is freshly loaded on each
        # call, so the factor is applied exactly once per returned object.
        graph.stress = graph.stress * 0.001

        return graph