from pathlib import Path

import h5py
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

class TabularDataBuckling(Dataset):
    """Tabular buckling data from the ConcreteShellFEA dataset.

    Each sample is a ``(features, targets)`` pair. ``features`` is the
    concatenation, in this order, of the shell characteristics, the perfect
    inputs, the imperfect inputs and the perfect buckling data for that row;
    ``targets`` is the matching row of the imperfect buckling data.

    Everything is loaded eagerly in ``__init__`` and kept as float32 tensors
    on ``self.device`` (CUDA when available, otherwise CPU).
    NOTE(review): holding CUDA tensors inside a Dataset may not play well
    with multi-worker DataLoaders -- confirm how the loaders are configured.

    Normalization is disabled by default. To enable it, assign ``input_mean``
    and ``input_std`` on the instance and set ``normalize_input = True``.

    Args:
        perfect_input_path: HDF5 file holding a ``transformed_data`` dataset.
            Its first four columns are read as shell characteristics, the
            remaining columns as the perfect inputs.
        imperfect_input_path: HDF5 file with the same layout; only the
            columns after the first four are read.
        perfect_buckling_path: Comma-delimited text file readable by
            ``np.loadtxt`` with the perfect buckling data.
        imperfect_buckling_path: Comma-delimited text file readable by
            ``np.loadtxt`` with the imperfect buckling data (the targets).

    Raises:
        ValueError: If the tensors that make up the features do not all have
            the same number of rows.
    """

    # Number of leading columns of ``transformed_data`` holding the shell
    # characteristics (span, height, thickness, Young's modulus).
    _N_SHELL_COLUMNS = 4

    def __init__(self, perfect_input_path, imperfect_input_path, perfect_buckling_path, imperfect_buckling_path):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.perfect_input_path = perfect_input_path
        self.imperfect_input_path = imperfect_input_path
        self.perfect_buckling_path = perfect_buckling_path
        self.imperfect_buckling_path = imperfect_buckling_path
        self.normalize_input = False

        # Load shell characteristics (span, height, thickness, Young's modulus)
        self.shell_characteristics = self.load_shell_characteristics()
        # Load buckling data
        self.perfect_buckling_data, self.imperfect_buckling_data = self.load_buckling_data()
        # Load PCA-transformed inputs
        self.perfect_inputs, self.imperfect_inputs = self.load_preproc_data()

        # __getitem__ concatenates per-row slices, so row alignment has to be
        # verified once here instead of implicitly on every access.
        self._check_row_counts(
            shell_characteristics=self.shell_characteristics,
            perfect_inputs=self.perfect_inputs,
            imperfect_inputs=self.imperfect_inputs,
            perfect_buckling_data=self.perfect_buckling_data,
        )

    def __len__(self):
        """Return the number of rows of the shell characteristics tensor."""
        return len(self.shell_characteristics)

    def __getitem__(self, idx):
        """Return ``(features, targets)`` for ``idx``.

        ``features`` is normalized as ``(features - input_mean) / input_std``
        when ``normalize_input`` is true.

        Raises:
            AttributeError: If normalization is enabled but ``input_mean`` or
                ``input_std`` has not been assigned.
        """
        # Select the requested row(s) first and concatenate only those;
        # concatenating the full tensors would copy the whole dataset on
        # every single access.
        features = torch.cat(
            (
                self.shell_characteristics[idx],
                self.perfect_inputs[idx],
                self.imperfect_inputs[idx],
                self.perfect_buckling_data[idx],
            ),
            dim=-1,
        )
        targets = self.imperfect_buckling_data[idx]

        if self.normalize_input:
            mean, std = self._normalization_stats()
            features = (features - mean) / std

        return features, targets

    def load_shell_characteristics(self):
        """Return the first four columns of the perfect input file."""
        return self._load_transformed_columns(
            self.perfect_input_path, slice(None, self._N_SHELL_COLUMNS)
        )

    def load_buckling_data(self):
        """Return ``(perfect, imperfect)`` buckling data, each with a trailing axis added."""
        perfect_buckling_data = self._load_buckling_file(self.perfect_buckling_path)
        imperfect_buckling_data = self._load_buckling_file(self.imperfect_buckling_path)
        return perfect_buckling_data, imperfect_buckling_data

    def load_preproc_data(self):
        """Return ``(perfect, imperfect)`` inputs: every column after the first four."""
        remaining = slice(self._N_SHELL_COLUMNS, None)
        perfect_inputs = self._load_transformed_columns(self.perfect_input_path, remaining)
        imperfect_inputs = self._load_transformed_columns(self.imperfect_input_path, remaining)
        return perfect_inputs, imperfect_inputs

    def _load_transformed_columns(self, path, columns):
        """Read a column slice of ``transformed_data`` as a float32 tensor on ``self.device``."""
        with h5py.File(path, 'r') as hf:
            data = hf['transformed_data'][:, columns]
        return torch.from_numpy(data).to(torch.float32).to(self.device)

    def _load_buckling_file(self, path):
        """Read a comma-delimited file as a float32 tensor on ``self.device`` and unsqueeze axis 1."""
        # ndmin=1 keeps a single-value file one-dimensional; without it
        # loadtxt returns a 0-d array and unsqueeze(1) fails.
        values = np.loadtxt(path, delimiter=',', ndmin=1)
        return torch.from_numpy(values).to(torch.float32).to(self.device).unsqueeze(1)

    def _normalization_stats(self):
        """Return ``(input_mean, input_std)``, failing clearly when they are unset."""
        try:
            return self.input_mean, self.input_std
        except AttributeError as err:
            raise AttributeError(
                "normalization is enabled but `input_mean` / `input_std` "
                "have not been set on the dataset"
            ) from err

    @staticmethod
    def _check_row_counts(**tensors):
        """Raise ``ValueError`` unless all given tensors have the same number of rows."""
        counts = {name: len(tensor) for name, tensor in tensors.items()}
        if len(set(counts.values())) > 1:
            raise ValueError(
                f"feature tensors must have the same number of rows, got {counts}"
            )

class TabularDataStressPCA(Dataset):
    """Tabular stress data from the ConcreteShellFEA dataset.

    Each sample is a ``(features, targets)`` pair. ``features`` is the
    concatenation, in this order, of the shell characteristics, the perfect
    inputs, the imperfect inputs and the perfect PCA stresses for that row;
    ``targets`` is the matching row of the imperfect PCA stresses.

    Everything is loaded eagerly in ``__init__`` and kept as float32 tensors
    on ``self.device`` (CUDA when available, otherwise CPU).
    NOTE(review): holding CUDA tensors inside a Dataset may not play well
    with multi-worker DataLoaders -- confirm how the loaders are configured.

    Normalization is disabled by default. To enable it, assign ``input_mean``
    and ``input_std`` on the instance and set ``normalize_input = True``.

    Args:
        perfect_input_path: HDF5 file holding a ``transformed_data`` dataset.
            Its first four columns are read as shell characteristics, the
            remaining columns as the perfect inputs.
        imperfect_input_path: HDF5 file with the same layout; only the
            columns after the first four are read.
        perfect_pca_stress_path: HDF5 file holding a ``transformed_data``
            dataset; the columns after the first four are read as stresses.
        imperfect_pca_stress_path: HDF5 file with the same layout, read the
            same way; provides the targets.

    Raises:
        ValueError: If the tensors that make up the features do not all have
            the same number of rows.
    """

    # Number of leading columns of ``transformed_data`` holding the shell
    # characteristics (span, height, thickness, Young's modulus).
    _N_SHELL_COLUMNS = 4

    def __init__(self, perfect_input_path, imperfect_input_path,
            perfect_pca_stress_path, imperfect_pca_stress_path):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.perfect_input_path = perfect_input_path
        self.imperfect_input_path = imperfect_input_path
        self.perfect_pca_stress_path = perfect_pca_stress_path
        self.imperfect_pca_stress_path = imperfect_pca_stress_path
        self.normalize_input = False

        # Load shell characteristics (span, height, thickness, Young's modulus)
        self.shell_characteristics = self.load_shell_characteristics()
        # Load PCA-transformed inputs
        self.perfect_inputs, self.imperfect_inputs = self.load_preproc_data()
        # Load PCA-transformed stresses
        self.perfect_pca_stresses, self.imperfect_pca_stresses = self.load_stresses()

        # __getitem__ concatenates per-row slices, so row alignment has to be
        # verified once here instead of implicitly on every access.
        self._check_row_counts(
            shell_characteristics=self.shell_characteristics,
            perfect_inputs=self.perfect_inputs,
            imperfect_inputs=self.imperfect_inputs,
            perfect_pca_stresses=self.perfect_pca_stresses,
        )

    def __len__(self):
        """Return the number of rows of the shell characteristics tensor."""
        return len(self.shell_characteristics)

    def __getitem__(self, idx):
        """Return ``(features, targets)`` for ``idx``.

        ``features`` is normalized as ``(features - input_mean) / input_std``
        when ``normalize_input`` is true.

        Raises:
            AttributeError: If normalization is enabled but ``input_mean`` or
                ``input_std`` has not been assigned.
        """
        # Select the requested row(s) first and concatenate only those;
        # concatenating the full tensors would copy the whole dataset on
        # every single access.
        features = torch.cat(
            (
                self.shell_characteristics[idx],
                self.perfect_inputs[idx],
                self.imperfect_inputs[idx],
                self.perfect_pca_stresses[idx],
            ),
            dim=-1,
        )
        targets = self.imperfect_pca_stresses[idx]

        if self.normalize_input:
            mean, std = self._normalization_stats()
            features = (features - mean) / std

        return features, targets

    def load_shell_characteristics(self):
        """Return the first four columns of the perfect input file."""
        return self._load_transformed_columns(
            self.perfect_input_path, slice(None, self._N_SHELL_COLUMNS)
        )

    def load_preproc_data(self):
        """Return ``(perfect, imperfect)`` inputs: every column after the first four."""
        remaining = slice(self._N_SHELL_COLUMNS, None)
        perfect_inputs = self._load_transformed_columns(self.perfect_input_path, remaining)
        imperfect_inputs = self._load_transformed_columns(self.imperfect_input_path, remaining)
        return perfect_inputs, imperfect_inputs

    def load_stresses(self):
        """Return ``(perfect, imperfect)`` PCA stresses: every column after the first four."""
        # Presumably the stress files repeat the four shell characteristics
        # in their leading columns, like the input files -- TODO confirm.
        remaining = slice(self._N_SHELL_COLUMNS, None)
        perfect_pca_stresses = self._load_transformed_columns(
            self.perfect_pca_stress_path, remaining
        )
        imperfect_pca_stresses = self._load_transformed_columns(
            self.imperfect_pca_stress_path, remaining
        )
        return perfect_pca_stresses, imperfect_pca_stresses

    def _load_transformed_columns(self, path, columns):
        """Read a column slice of ``transformed_data`` as a float32 tensor on ``self.device``."""
        with h5py.File(path, 'r') as hf:
            data = hf['transformed_data'][:, columns]
        return torch.from_numpy(data).to(torch.float32).to(self.device)

    def _normalization_stats(self):
        """Return ``(input_mean, input_std)``, failing clearly when they are unset."""
        try:
            return self.input_mean, self.input_std
        except AttributeError as err:
            raise AttributeError(
                "normalization is enabled but `input_mean` / `input_std` "
                "have not been set on the dataset"
            ) from err

    @staticmethod
    def _check_row_counts(**tensors):
        """Raise ``ValueError`` unless all given tensors have the same number of rows."""
        counts = {name: len(tensor) for name, tensor in tensors.items()}
        if len(set(counts.values())) > 1:
            raise ValueError(
                f"feature tensors must have the same number of rows, got {counts}"
            )

class ImageDataBuckling(Dataset):
    """Image buckling data from the ConcreteShellFEA dataset.

    Each sample is an ``(image, scalars, imperfect_buckling_factor)`` triple.
    The image is read on access from ``<image_folder_path>/<sample>.h5``
    (dataset ``image_data``); ``scalars`` is the concatenation of the perfect
    scalars and the perfect buckling data for that row; the last element is
    the matching row of the imperfect buckling data.

    Scalar and buckling data are loaded eagerly in ``__init__`` and kept as
    float32 tensors on ``self.device`` (CUDA when available, otherwise CPU).
    NOTE(review): holding CUDA tensors inside a Dataset may not play well
    with multi-worker DataLoaders -- confirm how the loaders are configured.

    Normalization is disabled by default. To enable it, assign ``input_mean``
    and ``input_std`` on the instance and set ``normalize_scalars = True``.

    Args:
        image_folder_path: Directory containing one ``<sample>.h5`` file per
            sample; a ``str`` or any path-like object is accepted.
        perfect_scalars_path: Header-less CSV; columns 1-4 are read as the
            scalar inputs.
        imperfect_scalars_path: Header-less CSV; column 0 is read as the list
            of sample names.
        perfect_buckling_path: Comma-delimited text file readable by
            ``np.loadtxt`` with the perfect buckling data.
        imperfect_buckling_path: Comma-delimited text file readable by
            ``np.loadtxt`` with the imperfect buckling data (the targets).

    Raises:
        ValueError: If the perfect scalars and the perfect buckling data do
            not have the same number of rows.
    """

    # Columns of the perfect-scalars CSV used as scalar inputs.
    _SCALAR_COLUMNS = [1, 2, 3, 4]

    def __init__(self, image_folder_path, perfect_scalars_path, imperfect_scalars_path,
                 perfect_buckling_path, imperfect_buckling_path):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.image_folder_path = image_folder_path
        self.perfect_scalars_path = perfect_scalars_path
        self.imperfect_scalars_path = imperfect_scalars_path
        self.perfect_buckling_path = perfect_buckling_path
        self.imperfect_buckling_path = imperfect_buckling_path
        # Sample names come from the first column of the imperfect scalars file.
        self.samples = pd.read_csv(imperfect_scalars_path, header=None, usecols=[0]).iloc[:, 0].astype(str).tolist()
        self.normalize_scalars = False

        # Load scalar inputs
        self.perfect_scalars = self.load_scalars()
        # Load buckling data
        self.perfect_buckling_data, self.imperfect_buckling_data = self.load_buckling_data()

        # __getitem__ concatenates per-row slices, so row alignment has to be
        # verified once here instead of implicitly on every access.
        self._check_row_counts(
            perfect_scalars=self.perfect_scalars,
            perfect_buckling_data=self.perfect_buckling_data,
        )

    def __len__(self):
        """Return the number of sample names."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, scalars, imperfect_buckling_factor)`` for ``idx``.

        ``scalars`` is normalized as ``(scalars - input_mean) / input_std``
        when ``normalize_scalars`` is true.

        Raises:
            AttributeError: If normalization is enabled but ``input_mean`` or
                ``input_std`` has not been assigned.
        """
        # Path(...) lets the folder be given as a plain string as well.
        image_file = Path(self.image_folder_path) / f'{self.samples[idx]}.h5'

        # Load image from HDF5
        with h5py.File(image_file, 'r') as hf:
            image = hf['image_data'][()]
        image = torch.from_numpy(image).float().to(self.device)

        # Select the requested row first and concatenate only that row;
        # concatenating the full tensors would copy all scalar data on every
        # single access.
        scalars = torch.cat(
            (self.perfect_scalars[idx], self.perfect_buckling_data[idx]),
            dim=-1,
        )
        imperfect_buckling_factor = self.imperfect_buckling_data[idx]

        if self.normalize_scalars:
            mean, std = self._normalization_stats()
            scalars = (scalars - mean) / std

        return image, scalars, imperfect_buckling_factor

    def load_scalars(self):
        """Return columns 1-4 of the perfect scalars CSV as a float32 tensor on ``self.device``."""
        # usecols avoids parsing columns that are discarded anyway.
        frame = pd.read_csv(
            self.perfect_scalars_path, header=None, usecols=self._SCALAR_COLUMNS
        )
        return torch.from_numpy(frame.to_numpy()).to(torch.float32).to(self.device)

    def load_buckling_data(self):
        """Return ``(perfect, imperfect)`` buckling data, each with a trailing axis added."""
        perfect_buckling_data = self._load_buckling_file(self.perfect_buckling_path)
        imperfect_buckling_data = self._load_buckling_file(self.imperfect_buckling_path)
        return perfect_buckling_data, imperfect_buckling_data

    def _load_buckling_file(self, path):
        """Read a comma-delimited file as a float32 tensor on ``self.device`` and unsqueeze axis 1."""
        # ndmin=1 keeps a single-value file one-dimensional; without it
        # loadtxt returns a 0-d array and unsqueeze(1) fails.
        values = np.loadtxt(path, delimiter=',', ndmin=1)
        return torch.from_numpy(values).to(torch.float32).to(self.device).unsqueeze(1)

    def _normalization_stats(self):
        """Return ``(input_mean, input_std)``, failing clearly when they are unset."""
        try:
            return self.input_mean, self.input_std
        except AttributeError as err:
            raise AttributeError(
                "normalization is enabled but `input_mean` / `input_std` "
                "have not been set on the dataset"
            ) from err

    @staticmethod
    def _check_row_counts(**tensors):
        """Raise ``ValueError`` unless all given tensors have the same number of rows."""
        counts = {name: len(tensor) for name, tensor in tensors.items()}
        if len(set(counts.values())) > 1:
            raise ValueError(
                f"scalar tensors must have the same number of rows, got {counts}"
            )

class ImageDataStressPCA(Dataset):
    """Image stress data from the ConcreteShellFEA dataset.

    Each sample is an ``(image, scalars, imperfect_pca_stresses)`` triple.
    The image is read on access from ``<image_folder_path>/<sample>.h5``
    (dataset ``image_data``); ``scalars`` is the concatenation of the perfect
    scalars and the perfect PCA stresses for that row; the last element is
    the matching row of the imperfect PCA stresses.

    Scalar and stress data are loaded eagerly in ``__init__`` and kept as
    float32 tensors on ``self.device`` (CUDA when available, otherwise CPU).
    NOTE(review): holding CUDA tensors inside a Dataset may not play well
    with multi-worker DataLoaders -- confirm how the loaders are configured.

    Normalization is disabled by default. To enable it, assign ``input_mean``
    and ``input_std`` on the instance and set ``normalize_scalars = True``.

    Args:
        image_folder_path: Directory containing one ``<sample>.h5`` file per
            sample; a ``str`` or any path-like object is accepted.
        perfect_scalars_path: Header-less CSV; columns 1-4 are read as the
            scalar inputs.
        imperfect_scalars_path: Header-less CSV; column 0 is read as the list
            of sample names.
        perfect_pca_stress_path: HDF5 file holding a ``transformed_data``
            dataset; the columns after the first four are read as stresses.
        imperfect_pca_stress_path: HDF5 file with the same layout, read the
            same way; provides the targets.

    Raises:
        ValueError: If the perfect scalars and the perfect PCA stresses do
            not have the same number of rows.
    """

    # Columns of the perfect-scalars CSV used as scalar inputs.
    _SCALAR_COLUMNS = [1, 2, 3, 4]
    # Number of leading ``transformed_data`` columns skipped in stress files.
    # Presumably these hold the shell characteristics -- TODO confirm.
    _N_LEADING_COLUMNS = 4

    def __init__(self, image_folder_path, perfect_scalars_path, imperfect_scalars_path,
                 perfect_pca_stress_path, imperfect_pca_stress_path):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.image_folder_path = image_folder_path
        self.perfect_scalars_path = perfect_scalars_path
        self.imperfect_scalars_path = imperfect_scalars_path
        self.perfect_pca_stress_path = perfect_pca_stress_path
        self.imperfect_pca_stress_path = imperfect_pca_stress_path
        # Sample names come from the first column of the imperfect scalars file.
        self.samples = pd.read_csv(imperfect_scalars_path, header=None, usecols=[0]).iloc[:, 0].astype(str).tolist()
        self.normalize_scalars = False

        # Load scalar inputs
        self.perfect_scalars = self.load_scalars()
        # Load PCA-transformed stresses
        self.perfect_pca_stresses, self.imperfect_pca_stresses = self.load_stresses()

        # __getitem__ concatenates per-row slices, so row alignment has to be
        # verified once here instead of implicitly on every access.
        self._check_row_counts(
            perfect_scalars=self.perfect_scalars,
            perfect_pca_stresses=self.perfect_pca_stresses,
        )

    def __len__(self):
        """Return the number of sample names."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, scalars, imperfect_pca_stresses)`` for ``idx``.

        ``scalars`` is normalized as ``(scalars - input_mean) / input_std``
        when ``normalize_scalars`` is true.

        Raises:
            AttributeError: If normalization is enabled but ``input_mean`` or
                ``input_std`` has not been assigned.
        """
        # Path(...) lets the folder be given as a plain string as well.
        image_file = Path(self.image_folder_path) / f'{self.samples[idx]}.h5'

        # Load image from HDF5
        with h5py.File(image_file, 'r') as hf:
            image = hf['image_data'][()]
        image = torch.from_numpy(image).float().to(self.device)

        # Select the requested row first and concatenate only that row;
        # concatenating the full tensors would copy all scalar data on every
        # single access.
        scalars = torch.cat(
            (self.perfect_scalars[idx], self.perfect_pca_stresses[idx]),
            dim=-1,
        )
        imperfect_pca_stresses = self.imperfect_pca_stresses[idx]

        if self.normalize_scalars:
            mean, std = self._normalization_stats()
            scalars = (scalars - mean) / std

        return image, scalars, imperfect_pca_stresses

    def load_scalars(self):
        """Return columns 1-4 of the perfect scalars CSV as a float32 tensor on ``self.device``."""
        # usecols avoids parsing columns that are discarded anyway.
        frame = pd.read_csv(
            self.perfect_scalars_path, header=None, usecols=self._SCALAR_COLUMNS
        )
        return torch.from_numpy(frame.to_numpy()).to(torch.float32).to(self.device)

    def load_stresses(self):
        """Return ``(perfect, imperfect)`` PCA stresses: every column after the first four."""
        perfect_pca_stresses = self._load_stress_file(self.perfect_pca_stress_path)
        imperfect_pca_stresses = self._load_stress_file(self.imperfect_pca_stress_path)
        return perfect_pca_stresses, imperfect_pca_stresses

    def _load_stress_file(self, path):
        """Read ``transformed_data[:, 4:]`` from ``path`` as a float32 tensor on ``self.device``."""
        with h5py.File(path, 'r') as hf:
            data = hf['transformed_data'][:, self._N_LEADING_COLUMNS:]
        return torch.from_numpy(data).to(torch.float32).to(self.device)

    def _normalization_stats(self):
        """Return ``(input_mean, input_std)``, failing clearly when they are unset."""
        try:
            return self.input_mean, self.input_std
        except AttributeError as err:
            raise AttributeError(
                "normalization is enabled but `input_mean` / `input_std` "
                "have not been set on the dataset"
            ) from err

    @staticmethod
    def _check_row_counts(**tensors):
        """Raise ``ValueError`` unless all given tensors have the same number of rows."""
        counts = {name: len(tensor) for name, tensor in tensors.items()}
        if len(set(counts.values())) > 1:
            raise ValueError(
                f"scalar tensors must have the same number of rows, got {counts}"
            )