Loss

class metatrain.utils.loss.LossInterface(name: str, gradient: str | None, weight: float, reduction: str)

Bases: ABC

Abstract base for all loss functions.

Subclasses must implement the compute method.

Parameters:
  • name (str) – key in the predictions/targets dict to select the TensorMap.

  • gradient (str | None) – optional name of a gradient field to extract.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch losses (“mean”, “sum”, etc.).

target: str
gradient: str | None
weight: float
reduction: str
loss_kwargs: Dict[str, Any]
abstractmethod compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Any | None = None) → Tensor

Compute the loss value.

Parameters:
  • predictions (Dict[str, TensorMap]) – mapping from target names to the predictions for those targets.

  • targets (Dict[str, TensorMap]) – mapping from target names to the reference targets.

  • extra_data (Any | None) – Any extra data needed for the loss computation.

Returns:

Value of the loss.

Return type:

Tensor

classmethod from_config(cfg: Dict[str, Any]) → LossInterface

Instantiate a loss from a config dict.

Parameters:

cfg (Dict[str, Any]) – keyword args matching the loss constructor.

Returns:

instance of a LossInterface subclass.

Return type:

LossInterface
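
The pattern described above can be sketched in plain Python. This is only an illustration: ToyLossInterface and ToyAbsoluteError are hypothetical stand-ins, not the metatrain classes, plain floats replace TensorMap, and the sketch assumes that from_config simply forwards the dict as keyword arguments to the constructor.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional


class ToyLossInterface(ABC):
    """Hypothetical stand-in for an abstract loss base (not the metatrain class)."""

    def __init__(self, name: str, gradient: Optional[str], weight: float, reduction: str):
        self.target = name  # key used to select the entry to compare
        self.gradient = gradient
        self.weight = weight
        self.reduction = reduction

    @abstractmethod
    def compute(
        self,
        predictions: Dict[str, float],
        targets: Dict[str, float],
        extra_data: Optional[Any] = None,
    ) -> float:
        """Return the loss value; subclasses must implement this."""

    @classmethod
    def from_config(cls, cfg: Dict[str, Any]) -> "ToyLossInterface":
        # Assumption: the config keys match the constructor keyword arguments.
        return cls(**cfg)


class ToyAbsoluteError(ToyLossInterface):
    """Concrete subclass: absolute error on the selected entry."""

    def compute(self, predictions, targets, extra_data=None):
        return abs(predictions[self.target] - targets[self.target])


loss = ToyAbsoluteError.from_config(
    {"name": "energy", "gradient": None, "weight": 1.0, "reduction": "mean"}
)
value = loss.compute({"energy": 1.5}, {"energy": 1.0})
```

Because compute is abstract, instantiating ToyLossInterface directly raises a TypeError; only subclasses that implement compute can be built from a config.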

class metatrain.utils.loss.BaseTensorMapLoss(name: str, gradient: str | None, weight: float, reduction: str, *, loss_fn: _Loss)

Bases: LossInterface

Backbone for pointwise losses on TensorMap entries.

Provides a compute_flattened() helper that extracts values or gradients, flattens them, applies an optional mask, and computes the torch loss.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – not used by this class; the actual weighting is applied in ScheduledLoss.

  • reduction (str) – reduction mode for torch loss.

  • loss_fn (_Loss) – pre-instantiated torch.nn loss (e.g. MSELoss).

compute_flattened(tensor_map_predictions_for_target: TensorMap, tensor_map_targets_for_target: TensorMap, tensor_map_mask_for_target: TensorMap | None = None) → Tensor

Flatten prediction and target blocks (and optional mask), then apply the torch loss.

Parameters:
  • tensor_map_predictions_for_target (TensorMap) – predicted TensorMap.

  • tensor_map_targets_for_target (TensorMap) – target TensorMap.

  • tensor_map_mask_for_target (TensorMap | None) – optional mask TensorMap.

Returns:

scalar torch.Tensor of the computed loss.

Return type:

Tensor
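
The flatten, mask, then reduce sequence described above can be illustrated with a minimal sketch. It uses nested Python lists in place of TensorMap blocks and a hand-written squared error in place of the torch loss; the function names are hypothetical, and concatenating all blocks before reducing is an assumption of this sketch rather than a statement about the metatrain implementation.

```python
from typing import List, Optional, Sequence


def flatten(block: Sequence[Sequence[float]]) -> List[float]:
    """Flatten a 2D block (a list of rows) into a flat list."""
    return [x for row in block for x in row]


def flattened_squared_error(
    pred_blocks: Sequence[Sequence[Sequence[float]]],
    target_blocks: Sequence[Sequence[Sequence[float]]],
    mask_blocks: Optional[Sequence[Sequence[Sequence[bool]]]] = None,
    reduction: str = "mean",
) -> float:
    preds: List[float] = []
    refs: List[float] = []
    for i, (p_block, t_block) in enumerate(zip(pred_blocks, target_blocks)):
        p, t = flatten(p_block), flatten(t_block)
        if mask_blocks is not None:
            # Keep only the entries whose mask value is True.
            m = flatten(mask_blocks[i])
            p = [x for x, keep in zip(p, m) if keep]
            t = [x for x, keep in zip(t, m) if keep]
        preds.extend(p)
        refs.extend(t)
    errors = [(p - t) ** 2 for p, t in zip(preds, refs)]
    if reduction == "sum":
        return sum(errors)
    if reduction == "mean":
        return sum(errors) / len(errors)
    raise ValueError(f"unsupported reduction: {reduction!r}")


pred = [[[1.0, 2.0], [3.0, 4.0]]]    # one block with two rows
ref = [[[1.0, 0.0], [3.0, 2.0]]]
mask = [[[False, True], [False, True]]]

unmasked_mean = flattened_squared_error(pred, ref)
unmasked_sum = flattened_squared_error(pred, ref, reduction="sum")
masked_mean = flattened_squared_error(pred, ref, mask)
```

In this example the mask removes the two entries with zero error, so the masked mean is larger than the unmasked one.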

compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Any | None = None) → Tensor

Compute the unmasked pointwise loss.

Parameters:
  • predictions (Dict[str, TensorMap]) – mapping of names to TensorMap.

  • targets (Dict[str, TensorMap]) – mapping of names to TensorMap.

  • extra_data (Any | None) – ignored for unmasked losses.

Returns:

scalar torch.Tensor loss.

Return type:

Tensor

class metatrain.utils.loss.MaskedTensorMapLoss(name: str, gradient: str | None, weight: float, reduction: str, *, loss_fn: _Loss)

Bases: BaseTensorMapLoss

Pointwise masked loss on TensorMap entries.

Inherits flattening and torch-loss logic from BaseTensorMapLoss.

Parameters:
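  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

  • loss_fn (_Loss) – pre-instantiated torch.nn loss (e.g. MSELoss).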

compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Dict[str, TensorMap] | None = None) → Tensor

Gather and flatten the target and prediction blocks, apply the mask taken from extra_data, then compute the loss.

Parameters:
  • predictions (Dict[str, TensorMap]) – Mapping from target names to TensorMaps.

  • targets (Dict[str, TensorMap]) – Mapping from target names to TensorMaps.

  • extra_data (Dict[str, TensorMap] | None) – Additional data for the loss computation. For the target name passed to the constructor, it is assumed to contain an entry keyed name + "_mask" that holds the TensorMap used for masking. The mask should have the same metadata as the target and prediction TensorMaps.

Returns:

Scalar loss tensor.

Return type:

Tensor
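
The name + "_mask" lookup convention can be sketched as follows. This is a plain-Python illustration with flat lists instead of TensorMaps; the function name is hypothetical, and raising KeyError when the mask entry is missing is an assumption of this sketch, not documented behaviour.

```python
from typing import Dict, List, Optional


def masked_mean_absolute_error(
    name: str,
    predictions: Dict[str, List[float]],
    targets: Dict[str, List[float]],
    extra_data: Optional[Dict[str, List[bool]]],
) -> float:
    # The mask for target "x" is expected under the key "x_mask".
    mask_key = name + "_mask"
    if extra_data is None or mask_key not in extra_data:
        # Assumption: a missing mask is treated as an error.
        raise KeyError(f"expected a mask under {mask_key!r} in extra_data")
    mask = extra_data[mask_key]
    pred, ref = predictions[name], targets[name]
    if not (len(mask) == len(pred) == len(ref)):
        raise ValueError("mask, predictions and targets must have the same layout")
    kept = [abs(p - r) for p, r, keep in zip(pred, ref, mask) if keep]
    return sum(kept) / len(kept)


masked_value = masked_mean_absolute_error(
    "density",
    predictions={"density": [1.0, 2.0, 5.0]},
    targets={"density": [1.0, 4.0, 0.0]},
    extra_data={"density_mask": [True, True, False]},
)
```

The last entry, which has the largest error, is excluded by the mask and does not contribute to the result.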

class metatrain.utils.loss.TensorMapMSELoss(name: str, gradient: str | None, weight: float, reduction: str)

Bases: BaseTensorMapLoss

Unmasked mean-squared error on TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

class metatrain.utils.loss.TensorMapMAELoss(name: str, gradient: str | None, weight: float, reduction: str)

Bases: BaseTensorMapLoss

Unmasked mean-absolute error on TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

class metatrain.utils.loss.TensorMapHuberLoss(name: str, gradient: str | None, weight: float, reduction: str, delta: float)

Bases: BaseTensorMapLoss

Unmasked Huber loss on TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

  • delta (float) – threshold parameter for HuberLoss.
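
To make the role of delta concrete, here is a minimal sketch of the standard Huber loss written in plain Python (the same definition used by torch.nn.HuberLoss). The helper names are hypothetical and the sketch works on flat lists rather than TensorMap entries: errors no larger than delta are penalised quadratically, larger ones linearly.

```python
from typing import Sequence


def huber(error: float, delta: float) -> float:
    """Huber penalty for a single error value."""
    abs_error = abs(error)
    if abs_error <= delta:
        return 0.5 * error * error            # quadratic region
    return delta * (abs_error - 0.5 * delta)  # linear region


def huber_loss(
    predictions: Sequence[float],
    targets: Sequence[float],
    delta: float,
    reduction: str = "mean",
) -> float:
    values = [huber(p - t, delta) for p, t in zip(predictions, targets)]
    if reduction == "sum":
        return sum(values)
    if reduction == "mean":
        return sum(values) / len(values)
    raise ValueError(f"unsupported reduction: {reduction!r}")


small = huber(0.5, delta=1.0)   # quadratic: 0.5 * 0.5**2
large = huber(3.0, delta=1.0)   # linear: 1.0 * (3.0 - 0.5)
mean_loss = huber_loss([0.5, 3.0], [0.0, 0.0], delta=1.0)
```

A smaller delta makes the loss behave more like an absolute error and reduces the influence of outliers; a larger delta makes it behave more like a squared error.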

class metatrain.utils.loss.TensorMapMaskedMSELoss(name: str, gradient: str | None, weight: float, reduction: str)

Bases: MaskedTensorMapLoss

Masked mean-squared error on TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

class metatrain.utils.loss.TensorMapMaskedMAELoss(name: str, gradient: str | None, weight: float, reduction: str)

Bases: MaskedTensorMapLoss

Masked mean-absolute error on TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

class metatrain.utils.loss.TensorMapMaskedHuberLoss(name: str, gradient: str | None, weight: float, reduction: str, delta: float)

Bases: MaskedTensorMapLoss

Masked Huber loss on TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

  • delta (float) – threshold parameter for HuberLoss.

class metatrain.utils.loss.LossAggregator(targets: Dict[str, TargetInfo], config: Dict[str, Dict[str, Any]])

Bases: LossInterface

Aggregate multiple LossInterface terms with scheduled weights and metadata.

Parameters:
  • targets (Dict[str, TargetInfo]) – mapping from target names to TargetInfo.

  • config (Dict[str, Dict[str, Any]]) – per-target configuration dict.

compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Any | None = None) → Tensor

Sum over all scheduled losses present in the predictions.

Parameters:
  • predictions (Dict[str, TensorMap]) – mapping from target names to TensorMap.

  • targets (Dict[str, TensorMap]) – mapping from target names to TensorMap.

  • extra_data (Any | None) – Any extra data needed for the loss computation.

Returns:

scalar torch.Tensor with the total loss.

Return type:

Tensor
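
The aggregation described above (a weighted sum over the loss terms whose target appears in the predictions) can be sketched in plain Python. The names below are hypothetical, flat lists replace TensorMaps, and fixed weights stand in for the scheduled weights, whose update rule is not described here.

```python
from typing import Callable, Dict, List, Tuple

LossFn = Callable[[List[float], List[float]], float]


def mse(pred: List[float], ref: List[float]) -> float:
    return sum((p - r) ** 2 for p, r in zip(pred, ref)) / len(pred)


def mae(pred: List[float], ref: List[float]) -> float:
    return sum(abs(p - r) for p, r in zip(pred, ref)) / len(pred)


def aggregate(
    terms: Dict[str, Tuple[float, LossFn]],
    predictions: Dict[str, List[float]],
    targets: Dict[str, List[float]],
) -> float:
    total = 0.0
    for name, (weight, loss_fn) in terms.items():
        if name not in predictions:
            # Terms whose target was not predicted are skipped.
            continue
        total += weight * loss_fn(predictions[name], targets[name])
    return total


terms = {"energy": (1.0, mse), "forces": (0.5, mae)}
targets = {"energy": [0.0, 1.0], "forces": [0.0]}

energy_only = aggregate(terms, {"energy": [1.0, 3.0]}, targets)
both = aggregate(terms, {"energy": [1.0, 3.0], "forces": [2.0]}, targets)
```

Skipping absent targets lets the same set of terms be used when only some targets are predicted in a given step.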

class metatrain.utils.loss.LossType(*values)

Bases: Enum

Enumeration of available loss types and their implementing classes.

Parameters:
  • key – string key for the loss type.

  • cls – class implementing the loss type.

MSE = ('mse', <class 'metatrain.utils.loss.TensorMapMSELoss'>)
MAE = ('mae', <class 'metatrain.utils.loss.TensorMapMAELoss'>)
HUBER = ('huber', <class 'metatrain.utils.loss.TensorMapHuberLoss'>)
MASKED_MSE = ('masked_mse', <class 'metatrain.utils.loss.TensorMapMaskedMSELoss'>)
MASKED_MAE = ('masked_mae', <class 'metatrain.utils.loss.TensorMapMaskedMAELoss'>)
MASKED_HUBER = ('masked_huber', <class 'metatrain.utils.loss.TensorMapMaskedHuberLoss'>)
POINTWISE = ('pointwise', <class 'metatrain.utils.loss.BaseTensorMapLoss'>)
MASKED_POINTWISE = ('masked_pointwise', <class 'metatrain.utils.loss.MaskedTensorMapLoss'>)
property key: str

String key for this loss type.

property cls: Type[LossInterface]

Class implementing this loss type.

classmethod from_key(key: str) → LossType

Look up a LossType by its string key.

Parameters:

key (str) – key that identifies the loss type.

Raises:

ValueError – if the key is not valid.

Returns:

the matching LossType enum member.

Return type:

LossType

metatrain.utils.loss.create_loss(loss_type: str, *, name: str, gradient: str | None, weight: float, reduction: str, **extra_kwargs: Any) → LossInterface

Factory to instantiate a concrete LossInterface given its string key.

Parameters:
  • loss_type (str) – string key matching one of the members of LossType.

  • name (str) – target name for the loss.

  • gradient (str | None) – gradient name, if present.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for the torch loss.

  • **extra_kwargs (Any) – additional hyperparameters specific to the loss type.

Returns:

instance of the selected loss.

Return type:

LossInterface
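
The combination of an enum of (key, class) pairs with a string-keyed factory can be sketched as follows. Every name here (ToyMSE, ToyHuber, ToyLossType, toy_create_loss) is a hypothetical stand-in written for illustration; the sketch mirrors the documented signatures but is not the metatrain implementation.

```python
from enum import Enum
from typing import Any, Optional


class ToyMSE:
    """Placeholder loss class taking only the common arguments."""

    def __init__(self, name: str, gradient: Optional[str], weight: float, reduction: str):
        self.name, self.gradient = name, gradient
        self.weight, self.reduction = weight, reduction


class ToyHuber(ToyMSE):
    """Placeholder loss class with one extra hyperparameter."""

    def __init__(self, name, gradient, weight, reduction, delta: float):
        super().__init__(name, gradient, weight, reduction)
        self.delta = delta


class ToyLossType(Enum):
    # Each member stores a (string key, implementing class) pair.
    MSE = ("mse", ToyMSE)
    HUBER = ("huber", ToyHuber)

    @property
    def key(self) -> str:
        return self.value[0]

    @property
    def cls(self) -> type:
        return self.value[1]

    @classmethod
    def from_key(cls, key: str) -> "ToyLossType":
        for member in cls:
            if member.key == key:
                return member
        valid = [member.key for member in cls]
        raise ValueError(f"unknown loss type {key!r}; valid keys: {valid}")


def toy_create_loss(
    loss_type: str,
    *,
    name: str,
    gradient: Optional[str],
    weight: float,
    reduction: str,
    **extra_kwargs: Any,
):
    # Resolve the class from its key, then forward all arguments to it.
    loss_cls = ToyLossType.from_key(loss_type).cls
    return loss_cls(
        name=name, gradient=gradient, weight=weight, reduction=reduction, **extra_kwargs
    )


huber = toy_create_loss(
    "huber", name="energy", gradient=None, weight=1.0, reduction="mean", delta=0.1
)
```

Extra keyword arguments such as delta are passed through unchanged, so each loss type only needs to declare its own hyperparameters in its constructor.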