import torch

def save_checkpoint(model, optimizer, epoch, file_path):
    """Persist a training snapshot to ``file_path`` via ``torch.save``.

    The saved object is a dict with three entries, in this order:
    ``'model_state_dict'`` (``model.state_dict()``),
    ``'optimizer_state_dict'`` (``optimizer.state_dict()``) and
    ``'epoch'`` (``epoch``, stored verbatim).

    Args:
        model: Object exposing ``state_dict()``, e.g. a ``torch.nn.Module``.
        optimizer: Object exposing ``state_dict()``, e.g. a
            ``torch.optim.Optimizer``.
        epoch: Value stored unchanged under the ``'epoch'`` key.
        file_path: Destination handed straight to ``torch.save`` (a path or
            a writable file-like object).
    """
    # Keyword arguments are evaluated left to right, so the model's state is
    # captured before the optimizer's and the key order matches the docstring.
    torch.save(
        dict(
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            epoch=epoch,
        ),
        file_path,
    )

class EarlyStopping:
    """Signal when a monitored validation metric has stopped improving.

    Call the instance once per evaluation with the latest metric value. A
    value counts as an improvement only if it beats the best value seen so
    far by more than ``delta``. After ``patience`` consecutive calls without
    an improvement, ``early_stop`` becomes ``True``. It then stays ``True``
    for the lifetime of the instance: a later improvement resets ``counter``
    but does not clear the flag.

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            stopping is signalled.
        delta: Margin by which a new value must beat ``best_score`` to count
            as an improvement.
        mode: ``'min'`` (default) when lower is better, e.g. a loss;
            ``'max'`` when higher is better, e.g. an accuracy.

    Attributes:
        counter: Consecutive non-improving calls since the last improvement.
        early_stop: Whether stopping has been signalled.
        best_score: Best metric value seen so far, as a Python ``float``.
            It starts at ``+inf`` for ``'min'`` and ``-inf`` for ``'max'``.

    Raises:
        ValueError: If ``mode`` is neither ``'min'`` nor ``'max'``.
    """

    def __init__(self, patience=5, delta=0, *, mode='min'):
        if mode not in ('min', 'max'):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.patience = patience
        self.delta = delta
        self.mode = mode
        self.counter = 0
        self.early_stop = False
        # Start from the worst possible value so the first call always improves
        # (unless it is NaN, which never compares as an improvement).
        self.best_score = float('inf') if mode == 'min' else float('-inf')

    def _is_improvement(self, value):
        """Return True if *value* beats ``best_score`` by more than ``delta``."""
        if self.mode == 'min':
            return value + self.delta < self.best_score
        return value - self.delta > self.best_score

    def __call__(self, val_loss):
        """Record one metric value and return whether training should stop.

        Args:
            val_loss: The latest validation metric. It may be a Python number
                or a single-element tensor/array. It is converted with
                ``float()`` so that no tensor, and hence no autograd graph or
                device memory, is kept alive in ``best_score``.

        Returns:
            The current value of ``early_stop``.
        """
        value = float(val_loss)
        if self._is_improvement(value):
            self.best_score = value
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

        return self.early_stop