Module nujo.objective.quantitative

Expand source code
from nujo.autodiff.tensor import Tensor
from nujo.math.scalar import abs
from nujo.objective.loss import QuantitativeLoss

# Public API of this module: the two quantitative losses defined below.
__all__ = ['L1Loss', 'L2Loss']

# ====================================================================================================


class L1Loss(QuantitativeLoss):
    ''' L1 loss (or Absolute Error)

        |ŷ - y|

        Takes the element-wise absolute difference between the prediction
        ``input`` (ŷ) and ``target`` (y), then passes it to the reduction
        configured on the base class (``self.reduction_fn`` called with
        ``self.dim`` and ``self.keepdim``).

    '''
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        ''' Return the reduced absolute error between `input` and `target`.

            Parameters:
            -----------
            input : Tensor, the prediction (ŷ)
            target : Tensor, the expected values (y)

            Returns:
            --------
            Tensor : ``self.reduction_fn`` applied to ``|input - target|``
        '''
        reduce_fn = self.reduction_fn
        # `abs` is nujo.math.scalar.abs (imported at module level), not the
        # Python builtin.
        absolute_error = abs(input - target)
        return reduce_fn(absolute_error, dim=self.dim, keepdim=self.keepdim)


# ====================================================================================================


class L2Loss(QuantitativeLoss):
    ''' L2 loss (or Squared Error)

        (ŷ - y)^2

        Squares the element-wise difference between the prediction
        ``input`` (ŷ) and ``target`` (y), then passes it to the reduction
        configured on the base class (``self.reduction_fn`` called with
        ``self.dim`` and ``self.keepdim``).

    '''
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        ''' Return the reduced squared error between `input` and `target`.

            Parameters:
            -----------
            input : Tensor, the prediction (ŷ)
            target : Tensor, the expected values (y)

            Returns:
            --------
            Tensor : ``self.reduction_fn`` applied to ``(input - target)**2``
        '''
        reduce_fn = self.reduction_fn
        difference = input - target
        squared_error = difference**2
        return reduce_fn(squared_error, dim=self.dim, keepdim=self.keepdim)


# ====================================================================================================

Classes

class L1Loss (*args, **kwargs)

L1 loss (or Absolute Error)

|ŷ - y|, where ŷ is the prediction (`input`) and y is the `target`.

Expand source code
class L1Loss(QuantitativeLoss):
    ''' L1 loss (or Absolute Error)

        |ŷ - y|

        Takes the element-wise absolute difference between the prediction
        ``input`` (ŷ) and ``target`` (y), then passes it to the reduction
        configured on the base class (``self.reduction_fn`` called with
        ``self.dim`` and ``self.keepdim``).

    '''
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        ''' Return the reduced absolute error between `input` and `target`.

            Parameters:
            -----------
            input : Tensor, the prediction (ŷ)
            target : Tensor, the expected values (y)

            Returns:
            --------
            Tensor : ``self.reduction_fn`` applied to ``|input - target|``
        '''
        reduce_fn = self.reduction_fn
        # `abs` is nujo.math.scalar.abs (imported at module level), not the
        # Python builtin.
        absolute_error = abs(input - target)
        return reduce_fn(absolute_error, dim=self.dim, keepdim=self.keepdim)

Ancestors: QuantitativeLoss

Inherited members

class L2Loss (*args, **kwargs)

L2 loss (or Squared Error)

(ŷ - y)^2, where ŷ is the prediction (`input`) and y is the `target`.

Expand source code
class L2Loss(QuantitativeLoss):
    ''' L2 loss (or Squared Error)

        (ŷ - y)^2

        Squares the element-wise difference between the prediction
        ``input`` (ŷ) and ``target`` (y), then passes it to the reduction
        configured on the base class (``self.reduction_fn`` called with
        ``self.dim`` and ``self.keepdim``).

    '''
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        ''' Return the reduced squared error between `input` and `target`.

            Parameters:
            -----------
            input : Tensor, the prediction (ŷ)
            target : Tensor, the expected values (y)

            Returns:
            --------
            Tensor : ``self.reduction_fn`` applied to ``(input - target)**2``
        '''
        reduce_fn = self.reduction_fn
        difference = input - target
        squared_error = difference**2
        return reduce_fn(squared_error, dim=self.dim, keepdim=self.keepdim)

Ancestors: QuantitativeLoss

Inherited members