Adding a new loss function
This page describes the classes and files required to add a new loss function
to metatrain. Defining a new loss can be useful when extra data, beyond the
predictions and targets, is needed to compute the loss.
Loss functions in metatrain are implemented as subclasses of
metatrain.utils.loss.LossInterface. This interface defines the required method
compute(), which takes the model predictions and the ground truth values as
input and returns the computed loss value. In addition to predictions and
targets, the compute() method accepts an extra_data argument, which can be used
to pass any additional information needed for the loss computation.
from typing import Dict, Optional

import torch
from metatensor.torch import TensorMap

from metatrain.utils.loss import LossInterface


class NewLoss(LossInterface):
    def __init__(
        self,
        name: str,
        gradient: Optional[str],
        weight: float,
        reduction: str,
    ) -> None:
        # set up the loss from its configuration (and any extra parameters)
        ...

    def compute(
        self,
        predictions: Dict[str, TensorMap],
        targets: Dict[str, TensorMap],
        extra_data: Dict[str, TensorMap],
    ) -> torch.Tensor:
        # combine predictions, targets and extra_data into the loss value
        ...
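To make the role of extra_data more concrete, here is a minimal, dependency-free sketch of a loss that needs information beyond predictions and targets. It assumes plain Python lists in place of TensorMap objects and a float in place of a torch.Tensor. LossInterfaceSketch is a hypothetical stand-in that only mirrors the signatures shown above (it is not the real LossInterface), and SampleWeightedMSELoss and the "<target>_weights" key in extra_data are likewise hypothetical names chosen for illustration.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Optional


class LossInterfaceSketch(ABC):
    """Hypothetical stand-in for LossInterface: it only mirrors the
    constructor arguments and the compute() signature shown above."""

    def __init__(
        self, name: str, gradient: Optional[str], weight: float, reduction: str
    ) -> None:
        self.target = name  # key of the target this loss acts on
        self.gradient = gradient  # unused in this sketch
        self.weight = weight  # overall multiplicative factor
        self.reduction = reduction  # "mean" or "sum"

    @abstractmethod
    def compute(
        self,
        predictions: Dict[str, List[float]],
        targets: Dict[str, List[float]],
        extra_data: Dict[str, List[float]],
    ) -> float: ...


class SampleWeightedMSELoss(LossInterfaceSketch):
    """Squared error where each sample is weighted by a value in extra_data."""

    def compute(
        self,
        predictions: Dict[str, List[float]],
        targets: Dict[str, List[float]],
        extra_data: Dict[str, List[float]],
    ) -> float:
        pred = predictions[self.target]
        true = targets[self.target]
        # hypothetical convention: weights stored under "<target>_weights"
        sample_weights = extra_data[f"{self.target}_weights"]
        terms = [w * (p - t) ** 2 for p, t, w in zip(pred, true, sample_weights)]
        if self.reduction == "mean":
            value = sum(terms) / len(terms)  # mean of the weighted squared errors
        elif self.reduction == "sum":
            value = sum(terms)
        else:
            raise ValueError(f"unsupported reduction: {self.reduction}")
        return self.weight * value


loss = SampleWeightedMSELoss("energy", gradient=None, weight=1.0, reduction="mean")
value = loss.compute(
    predictions={"energy": [1.0, 2.0]},
    targets={"energy": [0.0, 0.0]},
    extra_data={"energy_weights": [1.0, 0.5]},
)
print(value)  # (1.0 * 1.0 + 0.5 * 4.0) / 2 = 1.5
```

Keeping the per-sample weights in extra_data rather than in targets lets the loss use information that is not itself a quantity the model is asked to predict. A real implementation would instead read the corresponding blocks of the TensorMap objects and return a torch.Tensor.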
Examples of loss functions already implemented in metatrain are
metatrain.utils.loss.TensorMapMSELoss and metatrain.utils.loss.TensorMapMAELoss.
Both inherit from the metatrain.utils.loss.BaseTensorMapLoss class, which
implements pointwise losses for metatensor.torch.TensorMap objects.
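The idea behind a shared base class for pointwise losses can be sketched without metatensor: the base class applies an elementwise function to matching prediction and target entries and then reduces the result, so an MSE or MAE loss only has to choose that function. This is an illustrative analogue under stated assumptions, not the actual implementation of BaseTensorMapLoss: the class names are hypothetical, and lists of floats stand in for TensorMap blocks.

```python
from typing import Callable, Dict, List, Optional


class PointwiseLossSketch:
    """Hypothetical analogue of a pointwise-loss base class: subclasses only
    pick the elementwise function, the base class applies and reduces it."""

    def __init__(
        self,
        name: str,
        pointwise: Callable[[float, float], float],
        reduction: str = "mean",
    ) -> None:
        self.target = name
        self.pointwise = pointwise
        self.reduction = reduction

    def compute(
        self,
        predictions: Dict[str, List[float]],
        targets: Dict[str, List[float]],
        extra_data: Optional[Dict[str, List[float]]] = None,  # unused here
    ) -> float:
        pred = predictions[self.target]
        true = targets[self.target]
        if len(pred) != len(true):
            raise ValueError("predictions and targets must have the same length")
        values = [self.pointwise(p, t) for p, t in zip(pred, true)]
        if self.reduction == "mean":
            return sum(values) / len(values)
        if self.reduction == "sum":
            return sum(values)
        raise ValueError(f"unsupported reduction: {self.reduction}")


class MSELossSketch(PointwiseLossSketch):
    def __init__(self, name: str, reduction: str = "mean") -> None:
        # squared error per entry
        super().__init__(name, lambda p, t: (p - t) ** 2, reduction)


class MAELossSketch(PointwiseLossSketch):
    def __init__(self, name: str, reduction: str = "mean") -> None:
        # absolute error per entry
        super().__init__(name, lambda p, t: abs(p - t), reduction)


preds = {"energy": [1.0, -2.0, 3.0]}
trues = {"energy": [0.0, 0.0, 3.0]}
mae = MAELossSketch("energy").compute(preds, trues)  # (1 + 2 + 0) / 3 = 1.0
mse_sum = MSELossSketch("energy", reduction="sum").compute(preds, trues)  # 1 + 4 + 0 = 5.0
print(mae, mse_sum)
```

With this split, a new pointwise loss (for example a Huber-style error) would only need to supply a different elementwise function, while the handling of targets and reductions stays in one place.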