Loss
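- class metatrain.utils.loss.LossInterface(name: str, gradient: str | None, weight: float, reduction: str)
Bases: ABC
Abstract base for all loss functions.
Subclasses must implement the compute method.
- Parameters:
name (str) – target name for the loss.
gradient (str | None) – gradient name, if present.
weight (float) – weight of the loss contribution in the final aggregation.
reduction (str) – reduction mode for the torch loss.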
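- abstractmethod compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Any | None = None) → Tensor
Compute the loss value.
- Parameters:
predictions (Dict[str, TensorMap]) – mapping from target names to predicted TensorMaps.
targets (Dict[str, TensorMap]) – mapping from target names to reference TensorMaps.
extra_data (Any | None) – additional data for loss computation, if any.
- Returns:
Value of the loss.
- Return type:
Tensor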
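The contract of LossInterface can be illustrated with a minimal, self-contained sketch. It does not import metatrain: plain lists of floats stand in for TensorMap, a float stands in for the returned Tensor, and the class names SketchLossInterface and SketchSumAbsError are hypothetical. It only shows that the constructor stores the four documented arguments and that a subclass becomes instantiable once it implements compute.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

# Hypothetical stand-in: a plain list of floats replaces TensorMap, a float replaces Tensor.
Values = List[float]


class SketchLossInterface(ABC):
    """Toy analogue of LossInterface; stores the documented constructor arguments."""

    def __init__(self, name: str, gradient: Optional[str], weight: float, reduction: str) -> None:
        self.name = name
        self.gradient = gradient
        self.weight = weight
        self.reduction = reduction

    @abstractmethod
    def compute(
        self,
        predictions: Dict[str, Values],
        targets: Dict[str, Values],
        extra_data: Optional[Any] = None,
    ) -> float:
        """Return the loss value for the target stored in self.name."""


class SketchSumAbsError(SketchLossInterface):
    """Hypothetical concrete subclass: summed absolute error for one target."""

    def compute(self, predictions, targets, extra_data=None):
        pred = predictions[self.name]
        true = targets[self.name]
        # reduction is stored by the base class but ignored in this toy subclass.
        return sum(abs(p - t) for p, t in zip(pred, true))


loss = SketchSumAbsError(name="energy", gradient=None, weight=1.0, reduction="sum")
value = loss.compute({"energy": [1.0, 2.0]}, {"energy": [1.5, 2.0]})
```

Instantiating SketchLossInterface directly raises TypeError, which is how ABC enforces the "subclasses must implement compute" rule.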
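- class metatrain.utils.loss.BaseTensorMapLoss(name: str, gradient: str | None, weight: float, reduction: str, *, loss_fn: _Loss)
Bases: LossInterface
Backbone for pointwise losses on TensorMap entries.
Provides a compute_flattened() helper that extracts values or gradients, flattens them, applies an optional mask, and computes the torch loss.
- Parameters:
name (str) – target name for the loss.
gradient (str | None) – gradient name, if present.
weight (float) – weight of the loss contribution in the final aggregation.
reduction (str) – reduction mode for the torch loss.
loss_fn (_Loss) – torch loss module applied to the flattened tensors.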
- compute_flattened(tensor_map_predictions_for_target: TensorMap, tensor_map_targets_for_target: TensorMap, tensor_map_mask_for_target: TensorMap | None = None) → Tensor
Flatten the prediction and target blocks (and the optional mask), then apply the torch loss.
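The flatten, mask, then reduce sequence described for compute_flattened() can be sketched in plain Python. This is an illustration under stated assumptions, not the library code: a TensorMap is modelled as a hypothetical list of blocks (each a 2D list of values), only values are handled (no gradients), and the names flatten_blocks and compute_flattened_sketch are invented for the example.

```python
from typing import Callable, Optional, Sequence

# Hypothetical stand-in for a TensorMap: a list of blocks, each block a 2D list of values.
Blocks = Sequence[Sequence[Sequence[float]]]


def flatten_blocks(blocks) -> list:
    """Concatenate the entries of every block into one flat list, block by block, row by row."""
    return [x for block in blocks for row in block for x in row]


def compute_flattened_sketch(
    predictions: Blocks,
    targets: Blocks,
    pointwise: Callable[[float, float], float],
    reduction: str = "mean",
    mask: Optional[Sequence[Sequence[Sequence[bool]]]] = None,
) -> float:
    pred = flatten_blocks(predictions)
    true = flatten_blocks(targets)
    if len(pred) != len(true):
        raise ValueError("predictions and targets must contain the same number of entries")
    if mask is not None:
        keep = flatten_blocks(mask)
        if len(keep) != len(pred):
            raise ValueError("the mask must have the same layout as the predictions")
        # Drop the entries where the mask is False.
        pred = [p for p, k in zip(pred, keep) if k]
        true = [t for t, k in zip(true, keep) if k]
    errors = [pointwise(p, t) for p, t in zip(pred, true)]
    if reduction == "sum":
        return sum(errors)
    if reduction == "mean":
        return sum(errors) / len(errors)
    raise ValueError(f"unsupported reduction: {reduction!r}")


def squared(p: float, t: float) -> float:
    return (p - t) ** 2


preds = [[[1.0, 2.0]], [[3.0]]]  # two blocks
refs = [[[1.0, 0.0]], [[5.0]]]
```

Flattening first means every block contributes entries to a single reduction, so the result does not depend on how the entries are split across blocks.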
- class metatrain.utils.loss.MaskedTensorMapLoss(name: str, gradient: str | None, weight: float, reduction: str, *, loss_fn: _Loss)
Bases: BaseTensorMapLoss
Pointwise masked loss on TensorMap entries.
Inherits the flattening and torch-loss logic from BaseTensorMapLoss.
- compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Dict[str, TensorMap] | None = None) → Tensor
Gather and flatten the target and prediction blocks, then compute the loss.
- Parameters:
predictions (Dict[str, TensorMap]) – Mapping from target names to TensorMaps.
targets (Dict[str, TensorMap]) – Mapping from target names to TensorMaps.
extra_data (Dict[str, TensorMap] | None) – Additional data for loss computation. Assumes that, for the target name used in the constructor, extra_data contains an entry under the key name + "_mask" holding the TensorMap to be used for masking. This mask should have the same metadata as the target and prediction TensorMaps.
- Returns:
Scalar loss tensor.
- Return type:
Tensor
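The mask lookup convention for extra_data can be shown with a small sketch. Plain lists replace TensorMaps, and the function name masked_mse_sketch is hypothetical; raising KeyError when the mask is missing is a choice made for this sketch, since the documentation above does not state how a missing mask is handled.

```python
from typing import Dict, List, Optional

# Hypothetical stand-ins: plain lists replace TensorMaps.
Values = List[float]


def masked_mse_sketch(
    name: str,
    predictions: Dict[str, Values],
    targets: Dict[str, Values],
    extra_data: Optional[Dict[str, List[bool]]] = None,
) -> float:
    mask_key = name + "_mask"  # naming convention described for the extra_data parameter
    if extra_data is None or mask_key not in extra_data:
        raise KeyError(f"expected a mask under the key {mask_key!r} in extra_data")
    mask = extra_data[mask_key]
    pred, true = predictions[name], targets[name]
    if not (len(pred) == len(true) == len(mask)):
        raise ValueError("mask, targets and predictions must have the same layout")
    # Keep only the masked-in entries, then average their squared errors.
    kept = [(p - t) ** 2 for p, t, m in zip(pred, true, mask) if m]
    return sum(kept) / len(kept)


value = masked_mse_sketch(
    "dos",
    {"dos": [1.0, 2.0, 10.0]},
    {"dos": [1.0, 4.0, 0.0]},
    {"dos_mask": [True, True, False]},
)
```

The third entry has a large error but is masked out, so it does not contribute to the result.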
- class metatrain.utils.loss.TensorMapMSELoss(name: str, gradient: str | None, weight: float, reduction: str)
Bases: BaseTensorMapLoss
Unmasked mean-squared error on TensorMap entries.
- class metatrain.utils.loss.TensorMapMAELoss(name: str, gradient: str | None, weight: float, reduction: str)
Bases: BaseTensorMapLoss
Unmasked mean-absolute error on TensorMap entries.
- class metatrain.utils.loss.TensorMapHuberLoss(name: str, gradient: str | None, weight: float, reduction: str, delta: float)
Bases: BaseTensorMapLoss
Unmasked Huber loss on TensorMap entries.
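The pointwise terms behind the three unmasked losses can be written out directly. This sketch assumes that delta plays the same role as in torch.nn.HuberLoss (quadratic for residuals below delta, linear above it); the function names are hypothetical and "mean" is used as the reduction.

```python
from typing import List


def squared_error(p: float, t: float) -> float:
    return (p - t) ** 2


def absolute_error(p: float, t: float) -> float:
    return abs(p - t)


def huber_error(p: float, t: float, delta: float) -> float:
    # Piecewise definition used by torch.nn.HuberLoss: quadratic near zero, linear beyond delta.
    r = abs(p - t)
    if r < delta:
        return 0.5 * r * r
    return delta * (r - 0.5 * delta)


def mean(values: List[float]) -> float:
    return sum(values) / len(values)


pred = [1.0, 2.0]
true = [0.0, 0.0]
mse = mean([squared_error(p, t) for p, t in zip(pred, true)])
mae = mean([absolute_error(p, t) for p, t in zip(pred, true)])
huber = mean([huber_error(p, t, delta=1.0) for p, t in zip(pred, true)])
```

At a residual equal to delta both branches give 0.5 * delta**2, so the Huber term is continuous; it behaves like MSE for small errors and like MAE for large ones, which makes it less sensitive to outliers than MSE.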
- class metatrain.utils.loss.TensorMapMaskedMSELoss(name: str, gradient: str | None, weight: float, reduction: str)
Bases: MaskedTensorMapLoss
Masked mean-squared error on TensorMap entries.
- class metatrain.utils.loss.TensorMapMaskedMAELoss(name: str, gradient: str | None, weight: float, reduction: str)
Bases: MaskedTensorMapLoss
Masked mean-absolute error on TensorMap entries.
- class metatrain.utils.loss.TensorMapMaskedHuberLoss(name: str, gradient: str | None, weight: float, reduction: str, delta: float)
Bases: MaskedTensorMapLoss
Masked Huber loss on TensorMap entries.
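- class metatrain.utils.loss.LossAggregator(targets: Dict[str, TargetInfo], config: Dict[str, Dict[str, Any]])
Bases: LossInterface
Aggregate multiple LossInterface terms with scheduled weights and metadata.
- Parameters:
targets (Dict[str, TargetInfo]) – mapping from target names to their TargetInfo.
config (Dict[str, Dict[str, Any]]) – loss configuration.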
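A weighted aggregation of several loss terms can be sketched as follows. The weight scheduling mentioned above is not described here, so this sketch keeps the weights fixed. SketchTerm, mse_for and aggregate are hypothetical names, and plain lists replace TensorMaps. The function returns the weighted total and the unweighted value of each term, which is convenient for logging.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-ins: plain lists replace TensorMaps.
Batch = Dict[str, List[float]]


@dataclass
class SketchTerm:
    key: str
    weight: float
    compute: Callable[[Batch, Batch], float]


def mse_for(name: str) -> Callable[[Batch, Batch], float]:
    """Build a mean-squared-error term for a single target name."""

    def compute(predictions: Batch, targets: Batch) -> float:
        pairs = list(zip(predictions[name], targets[name]))
        return sum((p - t) ** 2 for p, t in pairs) / len(pairs)

    return compute


def aggregate(
    terms: List[SketchTerm], predictions: Batch, targets: Batch
) -> Tuple[float, Dict[str, float]]:
    """Return the weighted total and the unweighted value of each term."""
    per_term: Dict[str, float] = {}
    total = 0.0
    for term in terms:
        value = term.compute(predictions, targets)
        per_term[term.key] = value
        total += term.weight * value
    return total, per_term


terms = [
    SketchTerm("energy", 1.0, mse_for("energy")),
    SketchTerm("forces", 10.0, mse_for("forces")),
]
total, parts = aggregate(
    terms,
    {"energy": [1.0], "forces": [0.5, 0.5]},
    {"energy": [0.0], "forces": [0.0, 0.0]},
)
```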
- class metatrain.utils.loss.LossType(*values)
Bases: Enum
Enumeration of available loss types and their implementing classes.
- Parameters:
key – string key for the loss type.
cls – class implementing the loss type.
- MSE = ('mse', <class 'metatrain.utils.loss.TensorMapMSELoss'>)
- MAE = ('mae', <class 'metatrain.utils.loss.TensorMapMAELoss'>)
- HUBER = ('huber', <class 'metatrain.utils.loss.TensorMapHuberLoss'>)
- MASKED_MSE = ('masked_mse', <class 'metatrain.utils.loss.TensorMapMaskedMSELoss'>)
- MASKED_MAE = ('masked_mae', <class 'metatrain.utils.loss.TensorMapMaskedMAELoss'>)
- MASKED_HUBER = ('masked_huber', <class 'metatrain.utils.loss.TensorMapMaskedHuberLoss'>)
- POINTWISE = ('pointwise', <class 'metatrain.utils.loss.BaseTensorMapLoss'>)
- MASKED_POINTWISE = ('masked_pointwise', <class 'metatrain.utils.loss.MaskedTensorMapLoss'>)
- property cls: Type[LossInterface]
Class implementing this loss type.
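An enum whose members carry a (key, class) pair, as listed above, can be built with the standard enum module. This is a sketch, not the library source: SketchLossType, SketchMSE, SketchHuber, the key attribute and the from_key helper are hypothetical, and only the cls property mirrors the documentation.

```python
from enum import Enum


class SketchMSE:
    """Hypothetical placeholder for an MSE loss class."""


class SketchHuber:
    """Hypothetical placeholder for a Huber loss class."""


class SketchLossType(Enum):
    # Each value is a (key, class) pair, like the members listed for LossType.
    MSE = ("mse", SketchMSE)
    HUBER = ("huber", SketchHuber)

    def __init__(self, key: str, impl: type) -> None:
        # Enum unpacks the tuple value into separate constructor arguments.
        self.key = key
        self._impl = impl

    @property
    def cls(self) -> type:
        """Class implementing this loss type."""
        return self._impl

    @classmethod
    def from_key(cls, key: str) -> "SketchLossType":
        """Find the member whose string key matches, or raise ValueError."""
        for member in cls:
            if member.key == key:
                return member
        valid = ", ".join(m.key for m in cls)
        raise ValueError(f"unknown loss type {key!r}; expected one of: {valid}")
```

Keeping the key and the class together in each member means a single enum serves both as the list of valid string keys and as the dispatch table.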
- metatrain.utils.loss.create_loss(loss_type: str, *, name: str, gradient: str | None, weight: float, reduction: str, **extra_kwargs: Any) → LossInterface
Factory to instantiate a concrete LossInterface given its string key.
- Parameters:
loss_type (str) – string key of one of the LossType members (for example "mse", "huber" or "masked_mae").
name (str) – target name for the loss.
gradient (str | None) – gradient name, if present.
weight (float) – weight of the loss contribution in the final aggregation.
reduction (str) – reduction mode for the torch loss.
**extra_kwargs (Any) – additional hyperparameters specific to the loss type.
- Returns:
instance of the selected loss.
- Return type:
LossInterface
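The factory pattern described for create_loss can be sketched without metatrain. The registry dictionary, the classes SketchLoss and SketchHuberLoss, and the function create_loss_sketch are hypothetical. The sketch mirrors the documented keyword-only arguments and forwards any extra keyword arguments, such as a Huber delta, to the selected class.

```python
from typing import Any, Dict, Optional, Type


class SketchLoss:
    """Hypothetical loss class that only stores the common arguments."""

    def __init__(self, name: str, gradient: Optional[str], weight: float, reduction: str) -> None:
        self.name = name
        self.gradient = gradient
        self.weight = weight
        self.reduction = reduction


class SketchHuberLoss(SketchLoss):
    """Hypothetical Huber variant that needs one extra hyperparameter."""

    def __init__(self, name, gradient, weight, reduction, delta: float) -> None:
        super().__init__(name, gradient, weight, reduction)
        self.delta = delta


# Hypothetical registry from string keys to classes.
REGISTRY: Dict[str, Type[SketchLoss]] = {"mse": SketchLoss, "huber": SketchHuberLoss}


def create_loss_sketch(
    loss_type: str,
    *,
    name: str,
    gradient: Optional[str],
    weight: float,
    reduction: str,
    **extra_kwargs: Any,
) -> SketchLoss:
    try:
        loss_cls = REGISTRY[loss_type]
    except KeyError:
        raise ValueError(
            f"unknown loss type {loss_type!r}; valid keys: {sorted(REGISTRY)}"
        ) from None
    # Loss-specific hyperparameters such as delta are forwarded unchanged.
    return loss_cls(
        name=name, gradient=gradient, weight=weight, reduction=reduction, **extra_kwargs
    )


huber = create_loss_sketch(
    "huber", name="energy", gradient=None, weight=1.0, reduction="mean", delta=0.1
)
```

Because extra keyword arguments are forwarded as-is, passing a hyperparameter that the selected class does not accept (for example delta with "mse") fails with a TypeError from the constructor.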