Adding a new loss function

This page describes the classes and files required to add a new loss function to metatrain. Defining a new loss can be useful when extra data is needed to compute it.

Loss functions in metatrain are implemented as subclasses of metatrain.utils.loss.LossInterface. This interface defines the required method compute(), which takes the model predictions and the ground-truth values as input and returns the computed loss value. In addition to predictions and targets, compute() accepts a third argument, extra_data, which can carry any extra information needed for the loss computation. The skeleton of a new loss class looks as follows:

from typing import Dict, Optional
import torch
from metatrain.utils.loss import LossInterface
from metatensor.torch import TensorMap

class NewLoss(LossInterface):
    def __init__(
        self,
        name: str,
        gradient: Optional[str],
        weight: float,
        reduction: str,
    ) -> None:
        ...

    def compute(
        self,
        predictions: Dict[str, TensorMap],
        targets: Dict[str, TensorMap],
        extra_data: Dict[str, TensorMap],
    ) -> torch.Tensor:
        ...
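
To make the role of extra_data more concrete, the following is a minimal, dependency-free sketch of the same pattern. It is not metatrain code: StandInLossInterface is a hypothetical stand-in that only mirrors the constructor and compute() signatures shown above, plain lists of floats stand in for TensorMap objects, and the "<name>_mask" key is an invented convention for this illustration. A real loss would subclass metatrain.utils.loss.LossInterface and return a torch.Tensor.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Optional


class StandInLossInterface(ABC):
    """Hypothetical stand-in for metatrain.utils.loss.LossInterface.

    It only mirrors the signatures shown above so the example runs
    without metatrain, metatensor or torch installed.
    """

    def __init__(
        self,
        name: str,
        gradient: Optional[str],
        weight: float,
        reduction: str,
    ) -> None:
        self.name = name
        self.gradient = gradient
        # Stored only; this sketch does not apply the weight in compute().
        self.weight = weight
        self.reduction = reduction

    @abstractmethod
    def compute(self, predictions, targets, extra_data) -> float:
        ...


class MaskedSquaredErrorLoss(StandInLossInterface):
    """Squared error that skips entries masked out through extra_data."""

    def compute(
        self,
        predictions: Dict[str, List[float]],
        targets: Dict[str, List[float]],
        extra_data: Dict[str, List[float]],
    ) -> float:
        pred = predictions[self.name]
        true = targets[self.name]
        # Assumed convention for this sketch: mask stored under "<name>_mask",
        # with 1.0 for entries to keep and 0.0 for entries to ignore.
        mask = extra_data[f"{self.name}_mask"]
        errors = [m * (p - t) ** 2 for p, t, m in zip(pred, true, mask)]
        if self.reduction == "sum":
            return sum(errors)
        if self.reduction == "mean":
            kept = sum(mask)
            # Average over the kept entries only; empty selection gives 0.
            return sum(errors) / kept if kept else 0.0
        raise ValueError(f"unsupported reduction: {self.reduction!r}")


loss = MaskedSquaredErrorLoss(
    name="energy", gradient=None, weight=1.0, reduction="mean"
)
value = loss.compute(
    predictions={"energy": [1.0, 2.0, 5.0]},
    targets={"energy": [1.5, 2.0, 0.0]},
    extra_data={"energy_mask": [1.0, 1.0, 0.0]},
)
print(value)
```

The third entry is masked out, so only the first two contribute to the mean. The design point is that everything the loss needs beyond predictions and targets arrives through extra_data, which keeps the compute() signature identical across losses.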

Examples of loss functions already implemented in metatrain are metatrain.utils.loss.TensorMapMSELoss and metatrain.utils.loss.TensorMapMAELoss. They both inherit from the metatrain.utils.loss.BaseTensorMapLoss class, which implements pointwise losses for metatensor.torch.TensorMap objects.