import torch
import math
from torch_geometric.data import Dataset
from pathlib import Path

class GraphDataStressBuckling(Dataset):
    """
    Load graph samples that are stored pre-batched on disk.

    The folder ``graph_folder_path`` is expected to contain files named
    ``batch_0.pt``, ``batch_1.pt``, ... each holding one pickled batch object
    (presumably a PyTorch Geometric ``Batch``) of up to ``batch_size`` graphs.
    Sample ``idx`` is read from file ``idx // batch_size`` at position
    ``idx % batch_size``.

    It can be used for the following dataset:
    - PerfectShell_LinearFEA

    Args:
        graph_folder_path: Folder containing the batch files, given as a
            ``str`` or ``pathlib.Path``.
        num_samples: Total number of graphs across all batch files.
        batch_size: Number of graphs stored per batch file.
        transform, pre_transform, pre_filter: Forwarded unchanged to the
            base ``Dataset``.
    """
    def __init__(self, graph_folder_path, num_samples, batch_size, transform=None, pre_transform=None, pre_filter=None):
        # Assigned before super().__init__ so they already exist if the base
        # class queries any property (e.g. processed_file_names) during init.
        self.num_samples = num_samples
        self.batch_size = batch_size

        super().__init__(graph_folder_path, transform, pre_transform, pre_filter)

        # Coerce to Path: the file paths below are built with Path.joinpath,
        # which a plain str argument would not provide.
        self.root = Path(graph_folder_path)
        # Kept as a public attribute for callers; this class does not use it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _batch_file_path(self, b_num):
        """Return the path of the ``b_num``-th batch file inside ``self.root``."""
        return self.root.joinpath('batch_{:d}.pt'.format(b_num))

    @property
    def processed_file_names(self):
        """Return the full path of every batch file, ordered by batch number."""
        # ceil: a final, partially filled batch still gets its own file.
        num_batches = math.ceil(self.num_samples / self.batch_size)
        return [self._batch_file_path(b_num) for b_num in range(num_batches)]

    def len(self):
        """Return the total number of samples (``num_samples``)."""
        return self.num_samples

    def get(self, idx):
        """
        Load sample ``idx`` and return it with ``stress`` converted to kPa.

        Raises:
            IndexError: if the batch file holds fewer graphs than ``idx``
                requires.
            OSError: if the batch file cannot be opened.
        """
        file_ind = idx // self.batch_size
        graph_ind = idx % self.batch_size
        # Build only the one path needed instead of the whole file list.
        file_path = self._batch_file_path(file_ind)

        # NOTE(review): weights_only=False unpickles arbitrary objects, which
        # can execute code; only load batch files from a trusted source.
        with open(file_path, 'rb') as f:
            graph_batch = torch.load(f, weights_only=False)

        if graph_ind >= graph_batch.num_graphs:
            raise IndexError(
                'Sample {} maps to graph {} of {}, which holds only {} graphs'.format(
                    idx, graph_ind, file_path, graph_batch.num_graphs
                )
            )
        # Reconstruct only the requested graph rather than the whole batch.
        graph = graph_batch.get_example(graph_ind)

        # Convert stress from Pa to kPa (rebinds the attribute; the loaded
        # batch tensor itself is not modified in place).
        graph.stress = graph.stress * 0.001

        return graph

if __name__ == '__main__':

    # Demo: load one split of the PerfectShell_LinearFEA dataset and print a sample.
    # Set dataset_split to 'validation' or 'testing' to load the other splits.
    dataset_split = 'training'
    # Graphs per stored file. The data on disk is pre-batched, and the dataset
    # uses this value to map a sample index to its file and position.
    batch_size = 128
    # Total number of graphs in each split.
    samples_count = {
        'training': 12800,
        'validation': 3200,
        'testing': 4000
    }

    ROOT_PATH = Path.cwd() / './datasets/PerfectShell_LinearFEA/graphs'

    # Input and output graphs for a split live under <split>/<batch_size>.
    graph_folder_path = ROOT_PATH / f'./input_and_output/{dataset_split:s}/{batch_size:d}'

    dataset = GraphDataStressBuckling(
        graph_folder_path=graph_folder_path,
        num_samples=samples_count[dataset_split],
        batch_size=batch_size,
    )

    print('*** PerfectShell_LinearFEA dataset ***')
    print(f'Amount of samples in the {dataset_split:s} dataset = {len(dataset):d}')
    print(f'Example of a single sample: {dataset[0]}')