import torch
import torch.nn as nn
import torch.nn.init as init

class NeuralNetwork(nn.Module):
    """A multilayer perceptron (MLP) built from a config dictionary.

    The network is a stack of ``config['num_layers']`` ``nn.Linear`` layers.
    Every linear layer except the last is followed by the activation named by
    ``config['activation']`` (looked up on ``torch.nn``), an optional
    ``nn.BatchNorm1d`` (when ``config['batch_norm']`` is truthy) and an
    ``nn.Dropout(config['dropout'])``. The last linear layer maps to
    ``output_size`` and has no activation.

    Args:
        input_size: Number of input features.
        output_size: Number of output features.
        config: Dict with keys ``num_layers``, ``neurons_per_layer``,
            ``activation``, ``batch_norm``, ``dropout`` and ``initializer``
            (name of a ``torch.nn.init`` function applied to every linear
            layer's weight).
    """

    def __init__(self, input_size, output_size, config):
        super(NeuralNetwork, self).__init__()

        num_layers = config['num_layers']
        hidden = config['neurons_per_layer']

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_features = input_size if i == 0 else hidden
            if i == num_layers - 1:
                # Output layer (also the only layer when num_layers == 1):
                # a plain linear projection to output_size, no activation.
                self.layers.append(nn.Linear(in_features, output_size))
            else:
                self.layers.append(nn.Linear(in_features, hidden))
                self.layers.append(getattr(nn, config['activation'])())
                if config['batch_norm']:
                    self.layers.append(nn.BatchNorm1d(hidden))
                self.layers.append(nn.Dropout(config['dropout']))

        # Apply the configured initializer to the weights of every linear
        # layer; biases keep PyTorch's default initialization.
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                getattr(init, config['initializer'])(layer.weight)

    def forward(self, x):
        """Run ``x`` through every layer in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x

class CNN(nn.Module):
    """A convolutional network built from a config dictionary.

    Images pass through ``config['num_conv_layers']`` convolutional stages.
    They are then flattened and projected to ``config['conv_features']``
    features (followed by the activation), concatenated with
    ``additional_inputs``, and fed through a head of
    ``config['num_dense_layers']`` linear layers ending in ``output_dim``
    outputs.

    Each convolutional stage is::

        Conv2d(stride=1, padding=kernel_size // 2) -> activation
        -> [BatchNorm2d(momentum=0.05)] -> Dropout2d
        -> pooling(kernel_size=pooling_size, stride=1, padding=pooling_size // 2)

    Every dense layer except the last is followed by the activation, an
    optional ``BatchNorm1d`` and ``Dropout``. The last dense layer is a plain
    linear projection.

    Args:
        config: Dict with keys ``num_conv_layers``, ``kernels_per_layer``,
            ``kernel_size``, ``pooling`` (name of a ``torch.nn`` class taking
            ``kernel_size``/``stride``/``padding``), ``pooling_size``,
            ``conv_features``, ``num_dense_layers``, ``neurons_per_layer``,
            ``activation`` (name of a ``torch.nn`` module class),
            ``batch_norm``, ``dropout`` and ``initializer`` (name of a
            ``torch.nn.init`` function applied to dense-layer weights).
        input_res: Side length of the square input images.
        additional_inputs_dim: Number of extra features concatenated after
            the convolutional features.
        output_dim: Number of outputs of the final linear layer.
        in_channels: Number of channels in the input images (default 3).
    """

    def __init__(self, config, input_res, additional_inputs_dim, output_dim,
                 in_channels=3):
        super(CNN, self).__init__()

        self.input_res = input_res

        n_kernels = config['kernels_per_layer']
        kernel_size = config['kernel_size']
        pool_size = config['pooling_size']
        hidden = config['neurons_per_layer']

        res = input_res
        channels = in_channels
        self.conv_layers = nn.ModuleList()
        for _ in range(config['num_conv_layers']):
            self.conv_layers.append(
                nn.Conv2d(
                    channels,
                    n_kernels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=kernel_size // 2,
                )
            )
            channels = n_kernels
            self.conv_layers.append(getattr(nn, config['activation'])())
            if config['batch_norm']:
                self.conv_layers.append(nn.BatchNorm2d(n_kernels, momentum=0.05))
            self.conv_layers.append(nn.Dropout2d(p=config['dropout']))
            self.conv_layers.append(
                getattr(nn, config['pooling'])(
                    kernel_size=pool_size, stride=1, padding=pool_size // 2
                )
            )
            # Track the spatial size through the conv and the pooling layer.
            res = self._stride1_out_size(res, kernel_size)
            res = self._stride1_out_size(res, pool_size)

        # Flatten the feature maps and project them to conv_features. The
        # channel count is tracked so that zero conv layers still works.
        self.conv_layers.append(nn.Flatten())
        self.conv_layers.append(
            nn.Linear(channels * res * res, config['conv_features'])
        )
        self.conv_layers.append(getattr(nn, config['activation'])())

        # Dense head operating on [conv features, additional inputs].
        num_dense = config['num_dense_layers']
        self.dense_layers = nn.ModuleList()
        for i in range(num_dense):
            in_features = (
                additional_inputs_dim + config['conv_features'] if i == 0 else hidden
            )
            if i == num_dense - 1:
                # Output layer (also the only layer when num_dense == 1).
                self.dense_layers.append(nn.Linear(in_features, output_dim))
            else:
                self.dense_layers.append(nn.Linear(in_features, hidden))
                self.dense_layers.append(getattr(nn, config['activation'])())
                if config['batch_norm']:
                    self.dense_layers.append(nn.BatchNorm1d(hidden))
                self.dense_layers.append(nn.Dropout(config['dropout']))

        # NOTE: only the dense head receives the configured initializer; the
        # conv layers and the conv-feature projection keep PyTorch defaults.
        for layer in self.dense_layers:
            if isinstance(layer, nn.Linear):
                getattr(init, config['initializer'])(layer.weight)

    @staticmethod
    def _stride1_out_size(size, kernel):
        """Output side length of a stride-1 layer padded by ``kernel // 2``.

        This equals ``size`` for odd kernels and ``size + 1`` for even ones.
        """
        return size + 2 * (kernel // 2) - kernel + 1

    def forward(self, x, additional_inputs):
        """Apply the conv stack to ``x``, append ``additional_inputs`` along
        dim 1, and apply the dense head.

        ``x`` must be square images of side ``input_res``; ``additional_inputs``
        must have ``additional_inputs_dim`` columns.
        """
        for layer in self.conv_layers:
            x = layer(x)

        # Concatenate additional_inputs with the flattened conv features.
        x = torch.cat((x, additional_inputs), dim=1)

        for layer in self.dense_layers:
            x = layer(x)

        return x