import torch
import h5py
import numpy as np
from torch.utils.data import Dataset

class TabularLFBuckling(Dataset):
    """Low-fidelity tabular buckling data from the ConcreteShellFEA dataset.

    Each sample is a ``(features, targets)`` pair:

    * ``features`` -- the shell characteristics (the first four columns of the
      HDF5 ``transformed_data`` dataset) concatenated with the remaining
      preprocessed (PCA-transformed) columns.
    * ``targets`` -- the matching row of the buckling text file.

    All tensors are loaded eagerly as ``float32`` onto CUDA when it is
    available, otherwise onto the CPU.

    Args:
        preproc_input_path: Path to an HDF5 file with a 2-D
            ``transformed_data`` dataset.
        buckling_path: Path to a comma-delimited text file of buckling values.

    Attributes:
        normalize_input: If True, ``__getitem__`` standardizes the features
            with ``input_mean`` / ``input_std``, which the caller must set.
    """

    # Leading columns of ``transformed_data`` holding the shell
    # characteristics (span, height, thickness, Young's modulus).
    _N_SHELL_COLUMNS = 4

    def __init__(self, preproc_input_path, buckling_path):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.preproc_input_path = preproc_input_path
        self.buckling_path = buckling_path
        self.normalize_input = False
        # Normalization statistics; only used when ``normalize_input`` is True.
        self.input_mean = None
        self.input_std = None

        # Load shell characteristics (span, height, thickness, Young's modulus)
        self.shell_characteristics = self.load_shell_characteristics()
        # Load buckling data
        self.buckling_data = self.load_buckling_data()
        # Load PCA-transformed inputs
        self.preproc_inputs = self.load_preproc_data()

    def __len__(self):
        return len(self.shell_characteristics)

    def __getitem__(self, idx):
        """Return ``(features, targets)`` for ``idx`` (int, slice or index sequence)."""
        # Select the requested rows *before* concatenating: concatenating the
        # full tensors first copied the whole feature matrix on every call.
        # ``dim=-1`` is the feature axis for both 1-D (int idx) and 2-D
        # (slice / sequence idx) results.
        features = torch.cat(
            (self.shell_characteristics[idx], self.preproc_inputs[idx]),
            dim=-1,
        )
        targets = self.buckling_data[idx]

        if self.normalize_input:
            mean = getattr(self, "input_mean", None)
            std = getattr(self, "input_std", None)
            if mean is None or std is None:
                raise RuntimeError(
                    "normalize_input is True but input_mean/input_std are not set"
                )
            features = (features - mean) / std

        return features, targets

    def load_shell_characteristics(self):
        """Return the shell-characteristic columns of ``transformed_data``."""
        return self._load_transformed_columns(slice(None, self._N_SHELL_COLUMNS))

    def load_buckling_data(self):
        """Return the buckling file as ``float32`` with a new axis inserted at dim 1."""
        # ndmin=1 keeps a file holding a single value from becoming a 0-d
        # array, which ``unsqueeze(1)`` would reject.
        # NOTE(review): a single-column file yields shape (N, 1); a
        # multi-column file would give (N, 1, C) -- confirm the LF file layout.
        raw = np.loadtxt(self.buckling_path, delimiter=',', ndmin=1)
        return torch.from_numpy(raw).to(device=self.device, dtype=torch.float32).unsqueeze(1)

    def load_preproc_data(self):
        """Return the preprocessed (PCA-transformed) columns of ``transformed_data``."""
        return self._load_transformed_columns(slice(self._N_SHELL_COLUMNS, None))

    def _load_transformed_columns(self, columns):
        """Read ``transformed_data[:, columns]`` as ``float32`` on ``self.device``."""
        with h5py.File(self.preproc_input_path, 'r') as hf:
            data = hf['transformed_data'][:, columns]
        return torch.from_numpy(data).to(device=self.device, dtype=torch.float32)

class TabularHFBuckling(Dataset):
    """High-fidelity tabular buckling data from the ConcreteShellFEA dataset.

    Each sample is a ``(features, targets)`` pair:

    * ``features`` -- the shell characteristics (the first four columns of the
      HDF5 ``transformed_data`` dataset) concatenated with the remaining
      preprocessed (PCA-transformed) columns.
    * ``targets`` -- the second-to-last column of the buckling file, as a
      length-1 vector per sample.

    All tensors are loaded eagerly as ``float32`` onto CUDA when it is
    available, otherwise onto the CPU.

    Args:
        preproc_input_path: Path to an HDF5 file with a 2-D
            ``transformed_data`` dataset.
        buckling_path: Path to a comma-delimited text file of buckling values.

    Attributes:
        normalize_input: If True, ``__getitem__`` standardizes the features
            with ``input_mean`` / ``input_std``, which the caller must set.
    """

    # Leading columns of ``transformed_data`` holding the shell
    # characteristics (span, height, thickness, Young's modulus).
    _N_SHELL_COLUMNS = 4

    def __init__(self, preproc_input_path, buckling_path):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.preproc_input_path = preproc_input_path
        self.buckling_path = buckling_path
        self.normalize_input = False
        # Normalization statistics; only used when ``normalize_input`` is True.
        self.input_mean = None
        self.input_std = None

        # Load shell characteristics (span, height, thickness, Young's modulus)
        self.shell_characteristics = self.load_shell_characteristics()
        # Load buckling data
        self.buckling_data = self.load_buckling_data()
        # Load PCA-transformed inputs
        self.preproc_inputs = self.load_preproc_data()

    def __len__(self):
        return len(self.shell_characteristics)

    def __getitem__(self, idx):
        """Return ``(features, targets)`` for ``idx`` (int, slice or index sequence)."""
        # Select the requested rows *before* concatenating: concatenating the
        # full tensors first copied the whole feature matrix on every call.
        # ``dim=-1`` is the feature axis for both 1-D (int idx) and 2-D
        # (slice / sequence idx) results.
        features = torch.cat(
            (self.shell_characteristics[idx], self.preproc_inputs[idx]),
            dim=-1,
        )
        targets = self.buckling_data[idx]

        if self.normalize_input:
            mean = getattr(self, "input_mean", None)
            std = getattr(self, "input_std", None)
            if mean is None or std is None:
                raise RuntimeError(
                    "normalize_input is True but input_mean/input_std are not set"
                )
            features = (features - mean) / std

        return features, targets

    def load_shell_characteristics(self):
        """Return the shell-characteristic columns of ``transformed_data``."""
        return self._load_transformed_columns(slice(None, self._N_SHELL_COLUMNS))

    def load_buckling_data(self):
        """Return the second-to-last buckling column as a ``(N, 1)`` float32 tensor."""
        # ndmin=2 keeps a single-row file 2-D so the column index still works.
        raw = np.loadtxt(self.buckling_path, delimiter=',', ndmin=2)
        column = torch.from_numpy(raw)[:, -2]
        return column.to(device=self.device, dtype=torch.float32).unsqueeze(1)

    def load_preproc_data(self):
        """Return the preprocessed (PCA-transformed) columns of ``transformed_data``."""
        return self._load_transformed_columns(slice(self._N_SHELL_COLUMNS, None))

    def _load_transformed_columns(self, columns):
        """Read ``transformed_data[:, columns]`` as ``float32`` on ``self.device``."""
        with h5py.File(self.preproc_input_path, 'r') as hf:
            data = hf['transformed_data'][:, columns]
        return torch.from_numpy(data).to(device=self.device, dtype=torch.float32)

class TabularMFSequentialBuckling(Dataset):
    """Mixed-fidelity tabular buckling data from the ConcreteShellFEA dataset
    for sequential MLP applications.

    Each sample is a ``(features, targets)`` pair:

    * ``features`` -- the shell characteristics (the first four columns of the
      HDF5 ``transformed_data`` dataset), the remaining preprocessed
      (PCA-transformed) columns, and the linear buckling value (third-to-last
      column of the buckling file), concatenated in that order.
    * ``targets`` -- the nonlinear buckling value (second-to-last column of
      the buckling file), as a length-1 vector per sample.

    All tensors are loaded eagerly as ``float32`` onto CUDA when it is
    available, otherwise onto the CPU.

    Args:
        preproc_input_path: Path to an HDF5 file with a 2-D
            ``transformed_data`` dataset.
        buckling_path: Path to a comma-delimited text file of buckling values.

    Attributes:
        normalize_input: If True, ``__getitem__`` standardizes the features
            with ``input_mean`` / ``input_std``, which the caller must set.
    """

    # Leading columns of ``transformed_data`` holding the shell
    # characteristics (span, height, thickness, Young's modulus).
    _N_SHELL_COLUMNS = 4

    def __init__(self, preproc_input_path, buckling_path):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.preproc_input_path = preproc_input_path
        self.buckling_path = buckling_path
        self.normalize_input = False
        # Normalization statistics; only used when ``normalize_input`` is True.
        self.input_mean = None
        self.input_std = None

        # Load shell characteristics (span, height, thickness, Young's modulus)
        self.shell_characteristics = self.load_shell_characteristics()
        # Load buckling data
        self.linear_buckling_data, self.nonlinear_buckling_data = self.load_buckling_data()
        # Load PCA-transformed inputs
        self.preproc_inputs = self.load_preproc_data()

    def __len__(self):
        return len(self.shell_characteristics)

    def __getitem__(self, idx):
        """Return ``(features, targets)`` for ``idx`` (int, slice or index sequence)."""
        # Select the requested rows *before* concatenating: concatenating the
        # full tensors first copied the whole feature matrix on every call.
        # ``dim=-1`` is the feature axis for both 1-D (int idx) and 2-D
        # (slice / sequence idx) results.
        features = torch.cat(
            (
                self.shell_characteristics[idx],
                self.preproc_inputs[idx],
                self.linear_buckling_data[idx],
            ),
            dim=-1,
        )
        targets = self.nonlinear_buckling_data[idx]

        if self.normalize_input:
            mean = getattr(self, "input_mean", None)
            std = getattr(self, "input_std", None)
            if mean is None or std is None:
                raise RuntimeError(
                    "normalize_input is True but input_mean/input_std are not set"
                )
            features = (features - mean) / std

        return features, targets

    def load_shell_characteristics(self):
        """Return the shell-characteristic columns of ``transformed_data``."""
        return self._load_transformed_columns(slice(None, self._N_SHELL_COLUMNS))

    def load_buckling_data(self):
        """Return ``(linear, nonlinear)`` buckling columns, each ``(N, 1)`` float32.

        ``linear`` is the third-to-last column of the buckling file and
        ``nonlinear`` the second-to-last.
        """
        # ndmin=2 keeps a single-row file 2-D so the column indices still work.
        raw = np.loadtxt(self.buckling_path, delimiter=',', ndmin=2)
        buckling_data = torch.from_numpy(raw).to(device=self.device, dtype=torch.float32)
        # Already float32 on self.device; slicing preserves both.
        linear_buckling_data = buckling_data[:, -3].unsqueeze(1)
        nonlinear_buckling_data = buckling_data[:, -2].unsqueeze(1)

        return linear_buckling_data, nonlinear_buckling_data

    def load_preproc_data(self):
        """Return the preprocessed (PCA-transformed) columns of ``transformed_data``."""
        return self._load_transformed_columns(slice(self._N_SHELL_COLUMNS, None))

    def _load_transformed_columns(self, columns):
        """Read ``transformed_data[:, columns]`` as ``float32`` on ``self.device``."""
        with h5py.File(self.preproc_input_path, 'r') as hf:
            data = hf['transformed_data'][:, columns]
        return torch.from_numpy(data).to(device=self.device, dtype=torch.float32)

class TabularMFTransferBuckling(Dataset):
    """Mixed-fidelity tabular buckling data from the ConcreteShellFEA dataset
    for transfer-learning applications.

    Each sample is a ``(features, targets)`` pair:

    * ``features`` -- the shell characteristics (the first four columns of the
      HDF5 ``transformed_data`` dataset) concatenated with the remaining
      preprocessed (PCA-transformed) columns.
    * ``targets`` -- the second-to-last column of the buckling file, as a
      length-1 vector per sample.

    All tensors are loaded eagerly as ``float32`` onto CUDA when it is
    available, otherwise onto the CPU.

    Args:
        preproc_input_path: Path to an HDF5 file with a 2-D
            ``transformed_data`` dataset.
        buckling_path: Path to a comma-delimited text file of buckling values.

    Attributes:
        normalize_input: If True, ``__getitem__`` standardizes the features
            with ``input_mean`` / ``input_std``, which the caller must set.
    """

    # Leading columns of ``transformed_data`` holding the shell
    # characteristics (span, height, thickness, Young's modulus).
    _N_SHELL_COLUMNS = 4

    def __init__(self, preproc_input_path, buckling_path):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.preproc_input_path = preproc_input_path
        self.buckling_path = buckling_path
        self.normalize_input = False
        # Normalization statistics; only used when ``normalize_input`` is True.
        self.input_mean = None
        self.input_std = None

        # Load shell characteristics (span, height, thickness, Young's modulus)
        self.shell_characteristics = self.load_shell_characteristics()
        # Load buckling data
        self.buckling_data = self.load_buckling_data()
        # Load PCA-transformed inputs
        self.preproc_inputs = self.load_preproc_data()

    def __len__(self):
        return len(self.shell_characteristics)

    def __getitem__(self, idx):
        """Return ``(features, targets)`` for ``idx`` (int, slice or index sequence)."""
        # Select the requested rows *before* concatenating: concatenating the
        # full tensors first copied the whole feature matrix on every call.
        # ``dim=-1`` is the feature axis for both 1-D (int idx) and 2-D
        # (slice / sequence idx) results.
        features = torch.cat(
            (self.shell_characteristics[idx], self.preproc_inputs[idx]),
            dim=-1,
        )
        targets = self.buckling_data[idx]

        if self.normalize_input:
            mean = getattr(self, "input_mean", None)
            std = getattr(self, "input_std", None)
            if mean is None or std is None:
                raise RuntimeError(
                    "normalize_input is True but input_mean/input_std are not set"
                )
            features = (features - mean) / std

        return features, targets

    def load_shell_characteristics(self):
        """Return the shell-characteristic columns of ``transformed_data``."""
        return self._load_transformed_columns(slice(None, self._N_SHELL_COLUMNS))

    def load_buckling_data(self):
        """Return the second-to-last buckling column as a ``(N, 1)`` float32 tensor."""
        # ndmin=2 keeps a single-row file 2-D so the column index still works.
        raw = np.loadtxt(self.buckling_path, delimiter=',', ndmin=2)
        column = torch.from_numpy(raw)[:, -2]
        return column.to(device=self.device, dtype=torch.float32).unsqueeze(1)

    def load_preproc_data(self):
        """Return the preprocessed (PCA-transformed) columns of ``transformed_data``."""
        return self._load_transformed_columns(slice(self._N_SHELL_COLUMNS, None))

    def _load_transformed_columns(self, columns):
        """Read ``transformed_data[:, columns]`` as ``float32`` on ``self.device``."""
        with h5py.File(self.preproc_input_path, 'r') as hf:
            data = hf['transformed_data'][:, columns]
        return torch.from_numpy(data).to(device=self.device, dtype=torch.float32)