import torch.nn as nn
import torch.nn.init as init

class NeuralNetwork(nn.Module):
    """A multilayer perceptron built from a config dictionary.

    Every layer except the last is a hidden block
    ``Linear -> activation -> [BatchNorm1d] -> Dropout``; the last layer is a
    bare ``Linear`` that projects to ``output_size``.

    Args:
        input_size: Number of input features of the first ``Linear``.
        output_size: Number of output features of the final ``Linear``.
        config: Dictionary with the keys
            ``num_layers`` -- total number of ``Linear`` layers, including the
            output layer (``<= 0`` builds no layers, so ``forward`` returns
            its input unchanged);
            ``neurons_per_layer`` -- width of every hidden layer;
            ``activation`` -- name of a class in ``torch.nn`` (e.g. ``'ReLU'``),
            instantiated with no arguments;
            ``batch_norm`` -- truthy to insert ``BatchNorm1d`` after each
            hidden activation;
            ``dropout`` -- probability passed to ``nn.Dropout``;
            ``initializer`` -- name of a function in ``torch.nn.init``
            (e.g. ``'xavier_uniform_'``) applied to every ``Linear`` weight.
            Biases are left at PyTorch's default initialization.
    """

    def __init__(self, input_size, output_size, config):
        super().__init__()

        num_layers = config['num_layers']
        width = config['neurons_per_layer']

        self.layers = nn.ModuleList()
        in_features = input_size
        # All layers but the last are hidden blocks of constant width.
        for _ in range(num_layers - 1):
            self._append_hidden_block(in_features, width, config)
            in_features = width
        if num_layers >= 1:
            # The final layer always maps to output_size, including the case
            # where it is the only layer.
            self.layers.append(nn.Linear(in_features, output_size))

        for layer in self.layers:
            # Use isinstance rather than parsing the module path: nn.Identity
            # also lives in torch.nn.modules.linear but has no ``weight``.
            if isinstance(layer, nn.Linear):
                getattr(init, config['initializer'])(layer.weight)

    def _append_hidden_block(self, in_features, out_features, config):
        """Append ``Linear -> activation -> [BatchNorm1d] -> Dropout`` to ``self.layers``."""
        self.layers.append(nn.Linear(in_features, out_features))
        self.layers.append(getattr(nn, config['activation'])())
        if config['batch_norm']:
            self.layers.append(nn.BatchNorm1d(out_features))
        self.layers.append(nn.Dropout(config['dropout']))

    def forward(self, x):
        """Apply every layer in ``self.layers`` in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x