import os
import torch
import h5py
import pandas as pd
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset

class ImageDataStressBuckling(Dataset):
    """
    Map-style dataset pairing one image and a set of scalar inputs with
    tabular buckling and stress outputs.

    It can be used for the following datasets:
    - PerfectShell_LinearFEA
    - ImperfectShell_LinearFEA

    Each item is the tuple ``(image, scalars, buckling_factor, stress_data)``.
    All returned tensors are float32 and live on ``self.device`` (CUDA when
    available, otherwise CPU).

    Parameters
    ----------
    image_folder_path : str or pathlib.Path
        Folder holding one ``<sample name>.h5`` file per sample; the image is
        read from the ``'image_data'`` entry of that file.
    scalars_path : str or pathlib.Path
        Header-less CSV file. Column 0 holds the sample names, columns 1-4
        hold the scalar inputs.
    buckling_path : str or pathlib.Path
        Comma-separated text file with the buckling output, read with
        ``numpy.loadtxt``.
    preproc_stress_path : str or pathlib.Path
        HDF5 file whose ``'transformed_data'`` entry holds the preprocessed
        stresses. Only opened when ``data_format == 'preprocessed'``.
    raw_stress_paths : sequence of str or pathlib.Path
        HDF5 files whose ``'data'`` entry holds the raw stresses. File ``k``
        is expected to hold the rows of samples
        ``k * samples_per_raw_file`` ... ``(k + 1) * samples_per_raw_file - 1``.
        Only opened when ``data_format == 'raw'``.
    data_format : {'preprocessed', 'raw'}
        Which stress representation ``__getitem__`` returns.
    samples_per_raw_file : int, optional
        Number of samples stored in each raw stress file (default 1000).

    Raises
    ------
    ValueError
        If ``data_format`` is not supported or ``samples_per_raw_file`` is
        not positive.
    """

    # Stress representations understood by __getitem__.
    SUPPORTED_DATA_FORMATS = ('preprocessed', 'raw')

    def __init__(self, image_folder_path, scalars_path, buckling_path,
                 preproc_stress_path, raw_stress_paths, data_format,
                 samples_per_raw_file=1000):
        super().__init__()

        # Fail fast, before any file is read: an unknown format would
        # otherwise only surface as an UnboundLocalError on first item access.
        if data_format not in self.SUPPORTED_DATA_FORMATS:
            raise ValueError(
                'Unsupported data_format {!r}; expected one of {}'.format(
                    data_format, self.SUPPORTED_DATA_FORMATS
                )
            )
        if samples_per_raw_file <= 0:
            raise ValueError(
                'samples_per_raw_file must be positive, got {!r}'.format(
                    samples_per_raw_file
                )
            )

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.image_folder_path = image_folder_path
        self.scalars_path = scalars_path
        self.buckling_path = buckling_path
        self.preproc_stress_path = preproc_stress_path
        self.raw_stress_paths = raw_stress_paths
        self.data_format = data_format
        self.samples_per_raw_file = samples_per_raw_file

        # Sample names (first CSV column); they double as image file stems.
        self.samples = pd.read_csv(
            scalars_path, header=None, usecols=[0]
        ).iloc[:, 0].astype(str).tolist()

        # Load scalar inputs
        self.scalars = self.load_scalars()
        # Load buckling data
        self.buckling_data = self.load_buckling_data()
        # Preprocessed stresses are held in memory; raw stresses are read
        # lazily per item in __getitem__.
        if self.data_format == 'preprocessed':
            self.preproc_stresses = self.load_preproc_data()

    def __len__(self):
        """Return the number of samples listed in the scalars CSV file."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return ``(image, scalars, buckling_factor, stress_data)`` for ``idx``.

        Negative indices count from the end, as for a list.

        Raises
        ------
        IndexError
            If ``idx`` is outside ``[-len(self), len(self) - 1]``.
        ValueError
            If ``self.data_format`` is not a supported format.
        """
        n_samples = len(self.samples)
        # Normalise negative indices up front so that the raw file/row
        # arithmetic below also works when the last raw file is not full.
        # (No in-place add: idx may be a tensor owned by the caller.)
        if idx < 0:
            idx = idx + n_samples
        if not 0 <= idx < n_samples:
            raise IndexError(
                'index out of range for dataset with {} samples'.format(n_samples)
            )

        # Path() accepts both str and Path folder arguments.
        image_file = Path(self.image_folder_path).joinpath(
            f'{self.samples[idx]}.h5'
        )

        # Load image from HDF5
        with h5py.File(image_file, 'r') as hf:
            image = hf['image_data'][()]
        # Add a leading channel dimension to the image
        image = image[None, :, :]
        image = torch.from_numpy(image).float().to(self.device)

        # Scalar inputs and buckling output were loaded up front
        scalars = self.scalars[idx]
        buckling_factor = self.buckling_data[idx]

        if self.data_format == 'preprocessed':
            stress_data = self.preproc_stresses[idx]
        elif self.data_format == 'raw':
            # Raw stresses are chunked over several files of fixed size.
            file_ind, row_ind = divmod(idx, self.samples_per_raw_file)

            with h5py.File(self.raw_stress_paths[file_ind], 'r') as hf:
                # The first four columns and the last column are dropped.
                raw_row = hf['data'][row_ind, 4:-1]
            # Moved to self.device so that every returned tensor lives on
            # the same device, as in the preprocessed branch.
            stress_data = torch.from_numpy(raw_row).to(torch.float32).to(self.device)
        else:
            raise ValueError(
                'Unsupported data_format {!r}; expected one of {}'.format(
                    self.data_format, self.SUPPORTED_DATA_FORMATS
                )
            )

        return image, scalars, buckling_factor, stress_data

    def load_scalars(self):
        """
        Load the scalar inputs (columns 1-4 of the scalars CSV file).

        Returns a float32 tensor on ``self.device`` with one row per sample
        and four columns.
        """
        scalars = torch.from_numpy(
            pd.read_csv(self.scalars_path, header=None)[[1, 2, 3, 4]].to_numpy()
        ).to(torch.float32).to(self.device)

        return scalars

    def load_buckling_data(self):
        """
        Load the buckling output from ``self.buckling_path``.

        Returns a float32 tensor on ``self.device`` with an extra dimension
        inserted at position 1 (presumably one value per sample, giving
        shape ``(n_samples, 1)`` -- confirm against the data files).
        """
        buckling_data = torch.from_numpy(
            np.loadtxt(
                self.buckling_path, delimiter=','
            )
        ).to(torch.float32).to(self.device).unsqueeze(1)

        return buckling_data

    def load_preproc_data(self):
        """
        Load the preprocessed stresses from ``self.preproc_stress_path``.

        Reads the ``'transformed_data'`` entry, drops its first four columns
        and returns the rest as a float32 tensor on ``self.device``.
        """
        # Load stresses
        with h5py.File(self.preproc_stress_path, 'r') as hf:
            preproc_stresses = torch.from_numpy(
                hf['transformed_data'][:, 4:]
            ).to(torch.float32).to(self.device)

        return preproc_stresses
    
if __name__ == '__main__':

    def _load_demo_dataset(dataset_name, image_subfolder, dataset_split, data_format):
        """
        Build an ImageDataStressBuckling instance for one dataset and split.

        Paths are resolved relative to the current working directory:
        inputs under ``datasets/<dataset_name>/images/input/<split>`` and
        outputs under ``datasets/<dataset_name>/tabular/output/<split>``.
        """
        dataset_root = Path.cwd() / 'datasets' / dataset_name
        input_root = dataset_root / 'images' / 'input' / dataset_split
        output_root = dataset_root / 'tabular' / 'output' / dataset_split

        # Raw stress files are named s_GQ_0.h5, s_GQ_1.h5, ... Only files
        # matching that pattern are counted, so stray directory entries do
        # not produce paths to files that do not exist. The paths are built
        # from the index because a lexicographic sort of the names would put
        # s_GQ_10 before s_GQ_2.
        raw_dir = output_root / 'raw_output'
        n_raw_files = len(list(raw_dir.glob('s_GQ_*.h5')))
        raw_stress_paths = [
            raw_dir / 's_GQ_{:d}.h5'.format(i) for i in range(n_raw_files)
        ]
        if data_format == 'raw' and not raw_stress_paths:
            raise FileNotFoundError(
                'No raw stress files (s_GQ_*.h5) found in {}'.format(raw_dir)
            )

        return ImageDataStressBuckling(
            input_root / image_subfolder,                                   # images
            input_root / '{:s}_scalars.csv'.format(dataset_split),          # scalars
            output_root / 'buckling_output' / 'buckling_{:s}.csv'.format(dataset_split),
            output_root / 'pca_output' / 'pca_{:s}.h5'.format(dataset_split),
            raw_stress_paths,
            data_format,
        )

    dataset_split = 'training' # change to 'validation' or 'testing' to load the other splits
    data_format = 'preprocessed' # change to 'raw' to use the raw data

    # (dataset name, name of the image sub-folder inside the input split)
    demo_datasets = (
        ('PerfectShell_LinearFEA', 'height_map_99x99'),
        ('ImperfectShell_LinearFEA', 'imperfection_map_99x99'),
    )

    for dataset_name, image_subfolder in demo_datasets:
        dataset = _load_demo_dataset(
            dataset_name, image_subfolder, dataset_split, data_format
        )

        print('*** {} dataset ***'.format(dataset_name))
        print('Amount of samples in the {:s} dataset = {:d}'.format(dataset_split, len(dataset)))
        print('Example of a single sample: {}'.format(dataset[0]))
    