import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSELoss(nn.Module):
    """A loss module computing the root-mean-squared error (RMSE).

    Args:
        eps (float): Non-negative constant added to the mean squared error
            before taking the square root. The derivative of ``sqrt`` is
            infinite at 0, so when predictions match targets exactly the
            backward pass produces NaN gradients; a small positive value
            (e.g. ``1e-8``) avoids that. Defaults to ``0.0``, which gives the
            plain RMSE.
    """

    def __init__(self, eps=0.0):
        super(RMSELoss, self).__init__()
        self.eps = eps

    def forward(self, prediction, target):
        """
        Compute ``sqrt(mean((prediction - target) ** 2) + eps)``.

        Args:
            prediction (torch.Tensor): Predicted values.
            target (torch.Tensor): Target values.

        Returns:
            torch.Tensor: 0-dim tensor holding the RMSE.
        """
        # Functional form avoids constructing an nn.MSELoss module per call;
        # its default 'mean' reduction matches nn.MSELoss().
        mse_loss = F.mse_loss(prediction, target)
        rmse_loss = torch.sqrt(mse_loss + self.eps)
        return rmse_loss

class MAEMetric(nn.Module):
    """Mean absolute error restricted to the elements with the smallest targets.

    Args:
        proportion (float): Fraction of elements, taken in ascending order of
            target value, that enter the metric. Defaults to 0.05 (the lowest 5%).
    """

    def __init__(self, proportion=0.05):
        # nn.Module.__init__ must run before any attribute assignment so the
        # module's internal registries exist.
        super(MAEMetric, self).__init__()
        self.proportion = proportion

    def forward(self, predicted, target):
        """
        Compute the MAE between predicted and target over the lowest
        ``proportion`` fraction of values in target.

        Both tensors are flattened first, so the selection is made over all
        elements regardless of the input shape.

        Args:
            predicted (torch.Tensor): Tensor containing predicted values.
            target (torch.Tensor): Tensor containing target values; must have
                the same shape as ``predicted``.

        Returns:
            torch.Tensor: 0-dim tensor with the mean absolute error over the
            selected elements. At least one element is selected whenever the
            inputs are non-empty.

        Raises:
            AssertionError: If the shapes of ``predicted`` and ``target`` differ.
        """
        # Ensure both tensors have the same shape
        assert predicted.shape == target.shape, "Shapes of predicted and target tensors must match."

        # Flatten: for N-d inputs torch.sort works along the last dim while
        # predicted[indices] would gather along dim 0, misaligning the pairs.
        flat_target = target.reshape(-1)
        flat_predicted = predicted.reshape(-1)

        # Sort targets ascending and reorder predictions to stay paired.
        sorted_target, indices = torch.sort(flat_target)
        sorted_predicted = flat_predicted[indices]

        # int() truncates, so small inputs would otherwise select zero
        # elements and the mean would be NaN; keep at least one.
        num_lowest_values = max(1, int(self.proportion * sorted_target.numel()))

        lowest_target = sorted_target[:num_lowest_values]
        lowest_predicted = sorted_predicted[:num_lowest_values]

        # Mean absolute error over the selected lowest-target elements.
        absolute_error = torch.abs(lowest_target - lowest_predicted)
        mae_lowest = torch.mean(absolute_error)

        return mae_lowest

    def reset(self):
        """
        Reset the metric to its initial state.
        """
        pass  # Since there are no accumulated state variables, no need to reset anything


class MAPEPercent(nn.Module):
    """Mean absolute percentage error (in percent) over the smallest targets.

    Args:
        proportion (float): Fraction of elements, taken in ascending order of
            target value, that enter the metric. Defaults to 0.05 (the lowest 5%).
        eps (float): Lower bound applied to ``|target|`` in the denominator.
            The lowest targets are the ones most likely to be zero or near
            zero, which would give inf/NaN. Defaults to ``0.0`` (no clamping),
            which reproduces the plain MAPE.
    """

    def __init__(self, proportion=0.05, eps=0.0):
        # nn.Module.__init__ must run before any attribute assignment so the
        # module's internal registries exist.
        super(MAPEPercent, self).__init__()
        self.proportion = proportion
        self.eps = eps

    def forward(self, predicted, target):
        """
        Compute the MAPE (in percent) between predicted and target over the
        lowest ``proportion`` fraction of values in target.

        Both tensors are flattened first, so the selection is made over all
        elements regardless of the input shape.

        Args:
            predicted (torch.Tensor): Tensor containing predicted values.
            target (torch.Tensor): Tensor containing target values; must have
                the same shape as ``predicted``.

        Returns:
            torch.Tensor: 0-dim tensor with the MAPE, scaled by 100, over the
            selected elements. At least one element is selected whenever the
            inputs are non-empty. With ``eps == 0`` a zero target yields inf
            (or NaN if the prediction is also zero).

        Raises:
            AssertionError: If the shapes of ``predicted`` and ``target`` differ.
        """
        # Ensure both tensors have the same shape
        assert predicted.shape == target.shape, "Shapes of predicted and target tensors must match."

        # Flatten: for N-d inputs torch.sort works along the last dim while
        # predicted[indices] would gather along dim 0, misaligning the pairs.
        flat_target = target.reshape(-1)
        flat_predicted = predicted.reshape(-1)

        # Sort targets ascending and reorder predictions to stay paired.
        sorted_target, indices = torch.sort(flat_target)
        sorted_predicted = flat_predicted[indices]

        # int() truncates, so small inputs would otherwise select zero
        # elements and the mean would be NaN; keep at least one.
        num_lowest_values = max(1, int(self.proportion * sorted_target.numel()))

        lowest_target = sorted_target[:num_lowest_values]
        lowest_predicted = sorted_predicted[:num_lowest_values]

        # |t - p| / |t| equals |(t - p) / t|; clamping |t| at eps (a no-op for
        # eps == 0) guards against division by zero.
        denominator = torch.clamp(lowest_target.abs(), min=self.eps)
        absolute_percentage_error = torch.abs(lowest_target - lowest_predicted) / denominator
        mape_lowest = 100 * torch.mean(absolute_percentage_error)

        return mape_lowest

    def reset(self):
        """
        Reset the metric to its initial state.
        """
        pass  # Since there are no accumulated state variables, no need to reset anything

class RMSEMetric(nn.Module):
    """Root-mean-squared-error metric; it keeps no state between calls."""

    def __init__(self):
        super().__init__()

    def forward(self, predicted, target):
        """
        Return the square root of the mean squared difference of the inputs.

        Args:
            predicted (torch.Tensor): Predicted values.
            target (torch.Tensor): Ground-truth values, same shape as ``predicted``.

        Returns:
            torch.Tensor: 0-dim tensor holding the RMSE.
        """
        shapes_match = predicted.shape == target.shape
        assert shapes_match, "Shapes of predicted and target tensors must match."

        mean_squared = F.mse_loss(predicted, target)
        return mean_squared.sqrt()

    def reset(self):
        """
        Nothing is accumulated across calls, so resetting is a no-op.
        """
        return None

class AbsolutePercentageError(nn.Module):
    """Element-wise absolute percentage error, ``|(predicted - target) / target| * 100``."""

    def __init__(self):
        super().__init__()

    def forward(self, predicted, target):
        """
        Return the unreduced absolute percentage error for every element.

        Args:
            predicted (torch.Tensor): Predicted values.
            target (torch.Tensor): Reference values, same shape as ``predicted``.
                Zero entries lead to a division by zero (inf, or NaN when the
                matching prediction is also zero).

        Returns:
            torch.Tensor: Tensor of the same shape as the inputs, in percent.
        """
        shapes_match = predicted.shape == target.shape
        assert shapes_match, "Shapes of predicted and target tensors must match."

        relative_error = (predicted - target) / target
        return relative_error.abs() * 100

    def reset(self):
        """
        Nothing is accumulated across calls, so resetting is a no-op.
        """
        return None