import torch
import torch.nn as nn
from torch.nn import Dropout
import torch.nn.init as init
from torch_geometric.nn import BatchNorm, MLP, GeneralConv, Linear
import torch_geometric.nn as nn_geo

class NeuralNetwork(nn.Module):
    """Multilayer perceptron assembled from a config dictionary.

    The network is ``config['num_layers']`` ``nn.Linear`` layers deep. Every
    layer except the last is followed by the configured activation, an
    optional ``nn.BatchNorm1d`` and an ``nn.Dropout``; the last layer is a
    bare ``nn.Linear`` producing ``output_size`` features.

    Args:
        input_size: Number of input features of the first linear layer.
        output_size: Number of output features of the last linear layer.
        config: Dictionary with the keys
            ``'num_layers'`` (number of linear layers),
            ``'neurons_per_layer'`` (width of every hidden layer),
            ``'activation'`` (name of a module class in ``torch.nn``),
            ``'batch_norm'`` (truthy to insert ``nn.BatchNorm1d``),
            ``'dropout'`` (value passed to ``nn.Dropout``) and
            ``'initializer'`` (name of a function in ``torch.nn.init`` that
            is applied to the weight of every linear layer).
    """

    def __init__(self, input_size, output_size, config):
        super().__init__()

        num_layers = config['num_layers']

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            # Only the first layer sees the raw input width.
            in_features = input_size if i == 0 else config['neurons_per_layer']

            # The last layer always maps to output_size and gets no
            # activation / normalisation / dropout. Checking this first makes
            # a single-layer network (where the first layer is also the last)
            # honour output_size.
            if i == num_layers - 1:
                self.layers.append(nn.Linear(in_features, output_size))
                continue

            self.layers.append(
                nn.Linear(in_features, config['neurons_per_layer'])
            )
            self.layers.append(getattr(nn, config['activation'])())
            if config['batch_norm']:
                self.layers.append(
                    nn.BatchNorm1d(config['neurons_per_layer'])
                )
            self.layers.append(nn.Dropout(config['dropout']))

        # Re-initialise the weights (biases keep their default init) once the
        # whole stack has been built.
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                getattr(init, config['initializer'])(layer.weight)

    def forward(self, x):
        """Apply all layers to ``x`` in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x

class CNN(nn.Module):
    """Convolutional network with a dense head, built from a config dict.

    The convolutional part stacks ``config['num_conv_layers']`` groups of
    ``Conv2d -> activation -> [BatchNorm2d] -> Dropout2d -> pooling`` (all
    with stride 1), then flattens and projects to ``config['conv_features']``
    features. Those features are concatenated with ``additional_inputs`` and
    passed through ``config['num_dense_layers']`` linear layers, the last of
    which produces ``output_size`` features.

    Args:
        config: Dictionary with the keys ``'num_conv_layers'``,
            ``'kernels_per_layer'``, ``'kernel_size'``, ``'pooling'`` (name
            of a pooling class in ``torch.nn``), ``'pooling_size'``,
            ``'conv_features'``, ``'num_dense_layers'``,
            ``'neurons_per_layer'``, ``'activation'`` (name of a module class
            in ``torch.nn``), ``'batch_norm'``, ``'dropout'`` and
            ``'initializer'`` (name of a function in ``torch.nn.init``).
        input_res: Side length of the input image. The flattened size is
            computed as ``kernels_per_layer * res * res``, i.e. the input is
            treated as square.
        output_size: Number of output features of the last dense layer.

    Note:
        The first convolution is created with one input channel, and the
        first dense layer expects exactly 4 features in
        ``additional_inputs``. The configured initializer is applied to the
        dense layers only; convolutional layers keep their default init.
    """

    def __init__(self, config, input_res, output_size):
        super().__init__()

        self.input_res = input_res

        # Spatial resolution is tracked through the stack so the flattened
        # feature count is known when the projection layer is created.
        res = input_res
        self.conv_layers = nn.ModuleList()
        for i in range(config['num_conv_layers']):
            kernel_size = config['kernel_size']
            conv_padding = int(kernel_size / 2)
            in_channels = 1 if i == 0 else config['kernels_per_layer']

            self.conv_layers.append(
                nn.Conv2d(
                    in_channels,
                    config['kernels_per_layer'],
                    kernel_size=kernel_size,
                    stride=1,
                    padding=conv_padding,
                )
            )
            self.conv_layers.append(getattr(nn, config['activation'])())
            # Output size of a stride-1 convolution.
            res = (res - kernel_size + 2 * conv_padding) + 1

            if config['batch_norm']:
                self.conv_layers.append(
                    nn.BatchNorm2d(config['kernels_per_layer'], momentum=0.05)
                )
            self.conv_layers.append(nn.Dropout2d(p=config['dropout']))

            pooling_size = config['pooling_size']
            pool_padding = int(pooling_size / 2)
            self.conv_layers.append(
                getattr(nn, config['pooling'])(
                    kernel_size=pooling_size,
                    stride=1,
                    padding=pool_padding,
                )
            )
            # Output size of a stride-1 pooling window.
            res = int((res - pooling_size + 2 * pool_padding) + 1)

        # Flatten the feature maps and project them to conv_features.
        self.conv_layers.append(nn.Flatten())
        self.conv_layers.append(
            nn.Linear(
                config['kernels_per_layer'] * res * res,
                config['conv_features'],
            )
        )
        self.conv_layers.append(getattr(nn, config['activation'])())

        # Dense head
        num_dense = config['num_dense_layers']
        self.dense_layers = nn.ModuleList()
        for i in range(num_dense):
            # The first dense layer also receives the 4 additional inputs
            # concatenated in forward().
            if i == 0:
                in_features = 4 + config['conv_features']
            else:
                in_features = config['neurons_per_layer']

            # The last layer always maps to output_size without activation /
            # normalisation / dropout. Checking this first makes a head with
            # a single dense layer honour output_size.
            if i == num_dense - 1:
                self.dense_layers.append(nn.Linear(in_features, output_size))
                continue

            self.dense_layers.append(
                nn.Linear(in_features, config['neurons_per_layer'])
            )
            self.dense_layers.append(getattr(nn, config['activation'])())
            if config['batch_norm']:
                self.dense_layers.append(
                    nn.BatchNorm1d(config['neurons_per_layer'])
                )
            self.dense_layers.append(nn.Dropout(config['dropout']))

        # Re-initialise the dense weights only (biases keep default init).
        for layer in self.dense_layers:
            if isinstance(layer, nn.Linear):
                getattr(init, config['initializer'])(layer.weight)

    def forward(self, x, additional_inputs):
        """Run the conv stack on ``x``, append ``additional_inputs``, run the head.

        Args:
            x: Input passed through ``conv_layers`` (which end in a flatten,
                a linear projection and an activation).
            additional_inputs: Tensor concatenated to the conv features along
                ``dim=1`` before the dense layers.

        Returns:
            The output of the last dense layer.
        """
        for layer in self.conv_layers:
            x = layer(x)

        # Join the learned image features with the extra per-sample inputs.
        x = torch.cat((x, additional_inputs), dim=1)

        for layer in self.dense_layers:
            x = layer(x)

        return x

class GNN(nn.Module):
    """Graph neural network assembled from a config dictionary.

    The model consists of, in order:

    1. ``config['num_preproc_layers']`` PyG ``MLP`` blocks applied to the
       node features,
    2. ``config['num_layers']`` ``GeneralConv`` message-passing layers, each
       followed by activation, optional ``BatchNorm`` and ``Dropout`` unless
       it is the output layer,
    3. ``config['num_postproc_layers']`` PyG ``MLP`` blocks,
    4. an optional global pooling if ``'pooling_method'`` is in the config.

    Whichever stage comes last maps to ``output_size``.

    Args:
        num_features: Number of input node features.
        output_size: Number of output features.
        config: Dictionary with the keys ``'num_preproc_layers'``,
            ``'num_layers'``, ``'num_postproc_layers'``, ``'hidden_size'``,
            ``'mlp_neurons'``, ``'mlp_layers'``, ``'activation'`` (name of a
            module class in ``torch.nn``; also passed as ``act`` to ``MLP``),
            ``'dropout'``, ``'batch_norm'``, ``'aggr'``, ``'attention'``
            (``'none'`` or a value passed as ``attention_type`` to
            ``GeneralConv``) and ``'layer_connectivity'``. The optional key
            ``'pooling_method'`` names a function in ``torch_geometric.nn``.

    ``'layer_connectivity'`` selects the skip connections:

    * ``'sum'``: each conv output is added to a linear projection of the
      conv input.
    * ``'cat'``: the activations after every conv layer are concatenated
      and projected by a single linear layer.
    * any other value: plain sequential stack.
    """

    def __init__(self, num_features, output_size, config):
        super().__init__()

        self.config = config
        connectivity = config['layer_connectivity']

        # Preprocessing MLP layers: the first consumes the raw features, the
        # rest map hidden_size -> hidden_size.
        self.preproc_layers = nn.ModuleList()
        for i in range(config['num_preproc_layers']):
            in_dim = num_features if i == 0 else config['hidden_size']
            self.preproc_layers.append(
                self._make_mlp(in_dim, config['hidden_size'])
            )

        # Message-passing layers
        self.gnn_layers = nn.ModuleList()
        self.skip_co = nn.ModuleList()
        skip_cat_dim = 0
        num_layers = config['num_layers']
        for i in range(num_layers):
            # Without preprocessing, the first conv sees the raw features.
            takes_raw_features = (
                i == 0 and config['num_preproc_layers'] == 0
            )
            # Without postprocessing (and without skip-cat, whose projection
            # produces the output instead) the last conv emits the output and
            # gets no activation / normalisation / dropout.
            emits_output = (
                i == num_layers - 1
                and config['num_postproc_layers'] == 0
                and connectivity != 'cat'
            )
            # The two conditions are independent, so a single conv layer that
            # is both first and last maps num_features -> output_size.
            in_dim = num_features if takes_raw_features else config['hidden_size']
            out_dim = output_size if emits_output else config['hidden_size']

            self.gnn_layers.append(self._make_conv(in_dim, out_dim))
            if not emits_output:
                self.gnn_layers.append(getattr(nn, config['activation'])())
                if config['batch_norm']:
                    self.gnn_layers.append(BatchNorm(config['hidden_size']))
                self.gnn_layers.append(Dropout(config['dropout']))

            if connectivity == 'sum':
                # Projection that lets the conv input be added to its output.
                self.skip_co.append(Linear(in_dim, out_dim))
            elif connectivity == 'cat':
                skip_cat_dim += config['hidden_size']

        # Skip-cat: one projection over the concatenated activations.
        if connectivity == 'cat':
            if config['num_postproc_layers'] == 0:
                self.skip_co = Linear(skip_cat_dim, output_size)
            else:
                self.skip_co = Linear(skip_cat_dim, config['hidden_size'])

        # Postprocessing MLP layers: the last one maps to output_size.
        self.postproc_layers = nn.ModuleList()
        num_postproc = config['num_postproc_layers']
        for i in range(num_postproc):
            if i == num_postproc - 1:
                out_dim = output_size
            else:
                out_dim = config['hidden_size']
            self.postproc_layers.append(
                self._make_mlp(config['hidden_size'], out_dim)
            )

    def _make_mlp(self, in_dim, out_dim):
        """Build one ``MLP``: in_dim -> mlp_layers x mlp_neurons -> out_dim."""
        cfg = self.config
        channels = (
            [in_dim]
            + [cfg['mlp_neurons']] * cfg['mlp_layers']
            + [out_dim]
        )
        return MLP(
            channels,
            act=cfg['activation'],
            dropout=cfg['dropout'],
            norm='batch_norm' if cfg['batch_norm'] else None,
        )

    def _make_conv(self, in_dim, out_dim):
        """Build one ``GeneralConv``, with attention unless it is ``'none'``."""
        cfg = self.config
        if cfg['attention'] == 'none':
            return GeneralConv(in_dim, out_dim, aggr=cfg['aggr'])
        return GeneralConv(
            in_dim,
            out_dim,
            aggr=cfg['aggr'],
            attention=True,
            attention_type=cfg['attention'],
        )

    def forward(self, data):
        """Run the model on a graph batch.

        Args:
            data: Object providing ``x`` and ``edge_index`` (and ``batch``
                when ``'pooling_method'`` is configured).

        Returns:
            The output after the last stage; pooled with
            ``config['pooling_method']`` when that key is present.
        """
        x, edge_index = data.x, data.edge_index
        connectivity = self.config['layer_connectivity']

        # Preprocessing MLP layers
        for layer in self.preproc_layers:
            x = layer(x)

        # Message-passing layers. Layers are classified with isinstance
        # rather than by parsing the dotted module path of their type, which
        # would depend on how the libraries lay out their packages.
        collect_skips = connectivity == 'cat'
        activation_cls = (
            getattr(nn, self.config['activation']) if collect_skips else None
        )
        skip_cats = []
        skip_idx = 0
        for layer in self.gnn_layers:
            if isinstance(layer, GeneralConv):
                conv_out = layer(x=x, edge_index=edge_index)
                if connectivity == 'sum':
                    # skip_co holds one projection per conv layer, in order.
                    conv_out = conv_out + self.skip_co[skip_idx](x)
                    skip_idx += 1
                x = conv_out
                continue

            x = layer(x)
            # For skip-cat, remember the activation output of every layer.
            if collect_skips and isinstance(layer, activation_cls):
                skip_cats.append(x)

        # Skip-cat: concatenate all collected activations and project them.
        if collect_skips:
            x = torch.cat(skip_cats, dim=1)
            x = self.skip_co(x)

        # Postprocessing MLP layers
        for layer in self.postproc_layers:
            x = layer(x)

        if 'pooling_method' in self.config:
            # Global pooling over the nodes of each graph in the batch.
            x = getattr(nn_geo, self.config['pooling_method'])(x, data.batch)

        return x