import os
import torch
import h5py
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset

class TabularDataStressBuckling(Dataset):
    """
    This class can be used to load tabular buckling and stress data.
    Therefore, it can be used for the following datasets:
    - PerfectShell_LinearFEA
    - ImperfectShell_LinearFEA

    Each item is a tuple ``(shell_specs, pts_data, buckling_factor,
    stress_data)`` of float32 tensors on ``self.device``:

    - ``shell_specs``: columns ``[:4]`` of the preprocessed input file
      (span, height, thickness, Young's modulus).
    - ``pts_data``: columns ``[4:]`` of the preprocessed input file
      (``'preprocessed'``) or of the matching raw input row (``'raw'``).
    - ``buckling_factor``: the sample's row of the buckling CSV.
    - ``stress_data``: columns ``[4:]`` of the preprocessed stress file
      (``'preprocessed'``) or columns ``[4:-1]`` of the matching raw
      stress row (``'raw'``).

    Args:
        preproc_input_path: HDF5 file with a ``transformed_data`` dataset.
            Always read, because the shell characteristics come from it.
        raw_input_paths: Sequence of HDF5 shards with a ``data`` dataset;
            shard ``k`` holds samples ``k * rows_per_file`` onwards.
        buckling_path: CSV file with one row per sample.
        preproc_stress_path: HDF5 file with a ``transformed_data`` dataset
            (only read for ``'preprocessed'``).
        raw_stress_paths: HDF5 stress shards, laid out like
            ``raw_input_paths``.
        data_format: ``'preprocessed'`` (everything loaded into memory up
            front) or ``'raw'`` (rows read from the shards on every access).
        rows_per_file: Number of samples per raw shard. Defaults to 1000,
            the value that used to be hard-coded.

    Raises:
        ValueError: If ``data_format`` is unknown or ``rows_per_file`` is
            not positive.
    """

    _DATA_FORMATS = ('preprocessed', 'raw')

    def __init__(self, preproc_input_path, raw_input_paths, buckling_path,
                 preproc_stress_path, raw_stress_paths, data_format,
                 rows_per_file=1000):
        # Validate before touching the filesystem; an unknown format used to
        # surface only later as an UnboundLocalError inside __getitem__.
        if data_format not in self._DATA_FORMATS:
            raise ValueError(
                "data_format must be one of {}, got {!r}".format(
                    self._DATA_FORMATS, data_format))
        if rows_per_file < 1:
            raise ValueError(
                "rows_per_file must be positive, got {!r}".format(rows_per_file))

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.preproc_input_path = preproc_input_path
        self.raw_input_paths = raw_input_paths
        self.buckling_path = buckling_path
        self.preproc_stress_path = preproc_stress_path
        self.raw_stress_paths = raw_stress_paths
        self.data_format = data_format
        self.rows_per_file = rows_per_file

        # Load shell characteristics (span, height, thickness, Young's modulus)
        self.shell_characteristics = self.load_shell_characteristics()

        # Load buckling data
        self.buckling_data = self.load_buckling_data()
        # Load preprocessed data
        if self.data_format == 'preprocessed':
            self.preproc_inputs, self.preproc_stresses = self.load_preproc_data()

    def __len__(self):
        """Return the number of samples (rows of the shell characteristics)."""
        return len(self.shell_characteristics)

    def __getitem__(self, idx):
        """Return ``(shell_specs, pts_data, buckling_factor, stress_data)``."""
        shell_specs = self.shell_characteristics[idx]
        buckling_factor = self.buckling_data[idx]

        if self.data_format == 'preprocessed':
            pts_data = self.preproc_inputs[idx]
            stress_data = self.preproc_stresses[idx]
        elif self.data_format == 'raw':
            pts_data = self._read_raw_row(self.raw_input_paths, idx, slice(4, None))
            stress_data = self._read_raw_row(self.raw_stress_paths, idx, slice(4, -1))
        else:
            # Only reachable if data_format was changed after construction.
            raise ValueError("Unsupported data_format {!r}".format(self.data_format))

        return shell_specs, pts_data, buckling_factor, stress_data

    def _read_raw_row(self, paths, idx, columns):
        """Read the ``columns`` of sample ``idx`` from the raw shards ``paths``.

        Sample ``idx`` lives in shard ``idx // rows_per_file`` at row
        ``idx % rows_per_file``. The row is returned as float32 on
        ``self.device`` so it sits on the same device as the other tensors of
        the sample (previously raw rows stayed on the CPU).
        """
        file_ind = idx // self.rows_per_file
        row_ind = idx % self.rows_per_file
        with h5py.File(paths[file_ind], 'r') as hf:
            row = hf['data'][row_ind, columns]
        return torch.from_numpy(row).to(torch.float32).to(self.device)

    def load_shell_characteristics(self):
        """Load columns ``[:4]`` of ``transformed_data`` as float32 on ``self.device``."""
        with h5py.File(self.preproc_input_path, 'r') as hf:
            shell_characteristics = torch.from_numpy(
                hf['transformed_data'][:, :4]
            ).to(torch.float32).to(self.device)

        return shell_characteristics

    def load_buckling_data(self):
        """Load the buckling CSV as float32 on ``self.device`` with an extra axis 1.

        For a single-column CSV the result has shape ``(n_rows, 1)``.
        ``ndmin=1`` keeps a one-row CSV from collapsing to a 0-d array, which
        ``unsqueeze(1)`` would reject.
        """
        buckling = np.loadtxt(self.buckling_path, delimiter=',', ndmin=1)
        return torch.from_numpy(buckling).to(torch.float32).to(self.device).unsqueeze(1)

    def load_preproc_data(self):
        """Load columns ``[4:]`` of the preprocessed input and stress files.

        Returns:
            ``(preproc_inputs, preproc_stresses)`` as float32 tensors on
            ``self.device``.
        """
        # Load inputs
        with h5py.File(self.preproc_input_path, 'r') as hf:
            preproc_inputs = torch.from_numpy(
                hf['transformed_data'][:, 4:]
            ).to(torch.float32).to(self.device)

        # Load stresses
        with h5py.File(self.preproc_stress_path, 'r') as hf:
            preproc_stresses = torch.from_numpy(
                hf['transformed_data'][:, 4:]
            ).to(torch.float32).to(self.device)

        return preproc_inputs, preproc_stresses

class TabularDataBucklingOnly(Dataset):
    """
    This class can be used to load buckling-only data.
    Therefore, it can be used for the following datasets:
    - PerfectShell_NonlinearFEA/linear
    - PerfectShell_NonlinearFEA/nonlinear

    Each item is ``(shell_specs, pts_data, buckling_factor)`` or, when
    ``mixed_fidelity`` is True, ``(shell_specs, pts_data,
    linear_buckling_factor, nonlinear_buckling_factor)``. All entries are
    float32 tensors on ``self.device``:

    - ``shell_specs``: columns ``[:4]`` of the preprocessed input file
      (span, height, thickness, Young's modulus).
    - ``pts_data``: columns ``[4:]`` of the preprocessed input file
      (``'preprocessed'``) or of the matching raw input row (``'raw'``).

    Args:
        preproc_input_path: HDF5 file with a ``transformed_data`` dataset.
            Always read, because the shell characteristics come from it.
        raw_input_paths: Sequence of HDF5 shards with a ``data`` dataset;
            shard ``k`` holds samples ``k * rows_per_file`` onwards.
        buckling_path: CSV file with one row per sample. With
            ``mixed_fidelity`` the linear factor is read from the
            third-to-last column and the nonlinear factor from the
            second-to-last column.
        data_format: ``'preprocessed'`` (inputs loaded into memory up front)
            or ``'raw'`` (rows read from the shards on every access).
        mixed_fidelity: Return both linear and nonlinear buckling factors.
        rows_per_file: Number of samples per raw shard. Defaults to 1000,
            the value that used to be hard-coded.

    Raises:
        ValueError: If ``data_format`` is unknown or ``rows_per_file`` is
            not positive.
    """

    _DATA_FORMATS = ('preprocessed', 'raw')

    def __init__(self, preproc_input_path, raw_input_paths, buckling_path,
                 data_format, mixed_fidelity=False, rows_per_file=1000):
        # Validate before touching the filesystem; an unknown format used to
        # surface only later as an UnboundLocalError inside __getitem__.
        if data_format not in self._DATA_FORMATS:
            raise ValueError(
                "data_format must be one of {}, got {!r}".format(
                    self._DATA_FORMATS, data_format))
        if rows_per_file < 1:
            raise ValueError(
                "rows_per_file must be positive, got {!r}".format(rows_per_file))

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.preproc_input_path = preproc_input_path
        self.raw_input_paths = raw_input_paths
        self.buckling_path = buckling_path
        self.data_format = data_format
        self.mixed_fidelity = mixed_fidelity
        self.rows_per_file = rows_per_file

        # Load shell characteristics (span, height, thickness, Young's modulus)
        self.shell_characteristics = self.load_shell_characteristics()
        # Load buckling data
        if mixed_fidelity:
            self.linear_buckling_data, self.nonlinear_buckling_data = self.load_mixed_buckling_data()
        else:
            self.buckling_data = self.load_buckling_data()
        # Load preprocessed data
        if self.data_format == 'preprocessed':
            self.preproc_inputs = self.load_preproc_data()

    def __len__(self):
        """Return the number of samples (rows of the shell characteristics)."""
        return len(self.shell_characteristics)

    def __getitem__(self, idx):
        """Return the sample tuple described in the class docstring."""
        shell_specs = self.shell_characteristics[idx]

        if self.data_format == 'preprocessed':
            pts_data = self.preproc_inputs[idx]
        elif self.data_format == 'raw':
            pts_data = self._read_raw_row(self.raw_input_paths, idx, slice(4, None))
        else:
            # Only reachable if data_format was changed after construction.
            raise ValueError("Unsupported data_format {!r}".format(self.data_format))

        if self.mixed_fidelity:
            linear_buckling_factor = self.linear_buckling_data[idx]
            nonlinear_buckling_factor = self.nonlinear_buckling_data[idx]
            return shell_specs, pts_data, linear_buckling_factor, nonlinear_buckling_factor

        buckling_factor = self.buckling_data[idx]
        return shell_specs, pts_data, buckling_factor

    def _read_raw_row(self, paths, idx, columns):
        """Read the ``columns`` of sample ``idx`` from the raw shards ``paths``.

        Sample ``idx`` lives in shard ``idx // rows_per_file`` at row
        ``idx % rows_per_file``. The row is returned as float32 on
        ``self.device`` so it sits on the same device as the other tensors of
        the sample (previously raw rows stayed on the CPU).
        """
        file_ind = idx // self.rows_per_file
        row_ind = idx % self.rows_per_file
        with h5py.File(paths[file_ind], 'r') as hf:
            row = hf['data'][row_ind, columns]
        return torch.from_numpy(row).to(torch.float32).to(self.device)

    def load_shell_characteristics(self):
        """Load columns ``[:4]`` of ``transformed_data`` as float32 on ``self.device``."""
        with h5py.File(self.preproc_input_path, 'r') as hf:
            shell_characteristics = torch.from_numpy(
                hf['transformed_data'][:, :4]
            ).to(torch.float32).to(self.device)

        return shell_characteristics

    def load_buckling_data(self):
        """Load the buckling CSV as float32 on ``self.device`` with an extra axis 1.

        For a single-column CSV the result has shape ``(n_rows, 1)``.
        ``ndmin=1`` keeps a one-row CSV from collapsing to a 0-d array, which
        ``unsqueeze(1)`` would reject.
        """
        buckling = np.loadtxt(self.buckling_path, delimiter=',', ndmin=1)
        return torch.from_numpy(buckling).to(torch.float32).to(self.device).unsqueeze(1)

    def load_mixed_buckling_data(self):
        """Load the linear and nonlinear buckling factors from the CSV.

        Returns:
            ``(linear, nonlinear)``, each a float32 ``(n_rows, 1)`` tensor on
            ``self.device``, taken from the third-to-last and second-to-last
            CSV columns respectively.
        """
        # ndmin=2 keeps a one-row CSV two-dimensional so the [:, col] indexing
        # below works (loadtxt would otherwise return a 1-D row).
        buckling = np.loadtxt(self.buckling_path, delimiter=',', ndmin=2)
        buckling_data = torch.from_numpy(buckling).to(torch.float32).to(self.device)

        linear_buckling_data = buckling_data[:, -3].unsqueeze(1)
        nonlinear_buckling_data = buckling_data[:, -2].unsqueeze(1)

        return linear_buckling_data, nonlinear_buckling_data

    def load_preproc_data(self):
        """Load columns ``[4:]`` of ``transformed_data`` as float32 on ``self.device``."""
        # Load inputs
        with h5py.File(self.preproc_input_path, 'r') as hf:
            preproc_inputs = torch.from_numpy(
                hf['transformed_data'][:, 4:]
            ).to(torch.float32).to(self.device)

        return preproc_inputs
    
def _numbered_h5_paths(directory, prefix):
    """Return ``[directory/prefix_0.h5, ..., directory/prefix_{n-1}.h5]``.

    ``n`` is the number of files in *directory* matching ``prefix_*.h5``.
    Counting only matching files (instead of every directory entry) keeps
    stray files such as ``.DS_Store`` from producing paths to nonexistent
    shards. A missing directory yields an empty list, so the 'preprocessed'
    format no longer requires the raw folders to exist.
    """
    directory = Path(directory)
    n_files = sum(1 for _ in directory.glob('{:s}_*.h5'.format(prefix)))
    return [
        directory.joinpath('{:s}_{:d}.h5'.format(prefix, i))
        for i in range(n_files)
    ]


def _input_paths(root, split):
    """Return ``(preproc_input_path, raw_input_paths)`` for *split* under *root*."""
    preproc_input_path = root.joinpath(
        './input/{:s}/pca_input/pca_{:s}.h5'.format(split, split)
    )
    raw_input_paths = _numbered_h5_paths(
        root.joinpath('./input/{:s}/raw_input'.format(split)), 'pts'
    )
    return preproc_input_path, raw_input_paths


def _buckling_path(root, split):
    """Return the buckling CSV path for *split* under *root*."""
    return root.joinpath(
        './output/{:s}/buckling_output/buckling_{:s}.csv'.format(split, split)
    )


def _stress_paths(root, split):
    """Return ``(preproc_stress_path, raw_stress_paths)`` for *split* under *root*."""
    preproc_stress_path = root.joinpath(
        './output/{:s}/pca_output/pca_{:s}.h5'.format(split, split)
    )
    raw_stress_paths = _numbered_h5_paths(
        root.joinpath('./output/{:s}/raw_output'.format(split)), 's_GQ'
    )
    return preproc_stress_path, raw_stress_paths


def _print_summary(title, split, dataset):
    """Print the dataset title, its size and its first sample."""
    print('*** {:s} dataset ***'.format(title))
    print('Amount of samples in the {:s} dataset = {:d}'.format(split, len(dataset)))
    print('Example of a single sample: {}'.format(dataset[0]))


if __name__ == '__main__':

    dataset_split = 'training'  # change to 'validation' or 'testing' to load the other splits
    data_format = 'preprocessed'  # change to 'raw' to use the raw data
    datasets_root = Path.cwd().joinpath('./datasets')

    # Stress + buckling datasets: PerfectShell_LinearFEA, ImperfectShell_LinearFEA
    for title in ('PerfectShell_LinearFEA', 'ImperfectShell_LinearFEA'):
        root = datasets_root.joinpath(title, 'tabular')
        preproc_input_path, raw_input_paths = _input_paths(root, dataset_split)
        preproc_stress_path, raw_stress_paths = _stress_paths(root, dataset_split)
        dataset = TabularDataStressBuckling(
            preproc_input_path, raw_input_paths, _buckling_path(root, dataset_split),
            preproc_stress_path, raw_stress_paths, data_format
        )
        _print_summary(title, dataset_split, dataset)

    # Buckling-only datasets: PerfectShell_NonlinearFEA/{linear,nonlinear};
    # the nonlinear variant returns both linear and nonlinear factors.
    for subset, mixed_fidelity in (('linear', False), ('nonlinear', True)):
        title = 'PerfectShell_NonlinearFEA/' + subset
        root = datasets_root.joinpath('PerfectShell_NonlinearFEA', 'tabular', subset)
        preproc_input_path, raw_input_paths = _input_paths(root, dataset_split)
        dataset = TabularDataBucklingOnly(
            preproc_input_path, raw_input_paths, _buckling_path(root, dataset_split),
            data_format, mixed_fidelity=mixed_fidelity
        )
        _print_summary(title, dataset_split, dataset)