Loss¶
- class metatrain.utils.loss.LossParams[source]¶
Bases: TypedDict
- type: NotRequired[str] = 'mse'¶
- weight: NotRequired[float] = 1.0¶
- reduction: NotRequired[Literal['none', 'mean', 'sum']] = 'mean'¶
- class metatrain.utils.loss.LossSpecification[source]¶
Bases: TypedDict
- type: NotRequired[str] = 'mse'¶
- weight: NotRequired[float] = 1.0¶
- reduction: NotRequired[Literal['none', 'mean', 'sum']] = 'mean'¶
- gradients: NotRequired[dict[str, LossParams]] = {}¶
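As an illustration of how the two dictionaries nest, the sketch below mirrors the documented fields with local TypedDict definitions (it does not import metatrain) and builds a specification for one target with a single gradient entry. The gradient key "positions", the chosen weights, and the helper with_defaults are assumptions made for the example only.

```python
from typing import Dict, Literal, TypedDict


# Local mirrors of the documented fields. total=False makes every key
# optional, which plays the role of NotRequired in the signatures above.
class LossParams(TypedDict, total=False):
    type: str
    weight: float
    reduction: Literal["none", "mean", "sum"]


class LossSpecification(LossParams, total=False):
    gradients: Dict[str, LossParams]


# Defaults as listed in the class signatures above.
DEFAULTS = {"type": "mse", "weight": 1.0, "reduction": "mean"}


def with_defaults(params):
    """Hypothetical helper: return a copy of ``params`` with defaults filled in."""
    return {**DEFAULTS, **params}


# "positions" is a hypothetical gradient name used only for this example.
spec: LossSpecification = {
    "type": "mse",
    "weight": 1.0,
    "gradients": {"positions": {"type": "mae", "weight": 10.0}},
}

resolved_gradient = with_defaults(spec["gradients"]["positions"])
```

Keys omitted from an entry (here, reduction) fall back to the listed defaults in this sketch; how metatrain itself resolves omitted keys is not shown on this page.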
- class metatrain.utils.loss.LossInterface(name: str, gradient: str | None, weight: float, reduction: str)[source]¶
Bases: ABC
Abstract base for all loss functions.
Subclasses must implement the compute method.
- Parameters:
- abstractmethod compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Any | None = None) Tensor[source]¶
Compute the loss value.
- Parameters:
- Returns:
Value of the loss.
- Return type:
Tensor
- class metatrain.utils.loss.BaseTensorMapLoss(name: str, gradient: str | None, weight: float, reduction: str, *, loss_fn: _Loss)[source]¶
Bases: LossInterface
Backbone for pointwise losses on TensorMap entries.
Provides a compute_flattened() helper that extracts values or gradients, flattens them, applies an optional mask, and computes the torch loss.
- Parameters:
- compute_flattened(tensor_map_predictions_for_target: TensorMap, tensor_map_targets_for_target: TensorMap, tensor_map_mask_for_target: TensorMap | None = None) Tensor[source]¶
Flatten prediction and target blocks (and optional mask), then apply the torch loss.
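The flatten, mask, and reduce sequence described for compute_flattened can be sketched in plain Python. This is a dependency-free illustration on nested lists, not the library's implementation: the function names, the use of squared error as the pointwise loss, and the list-of-2D-blocks layout are all assumptions.

```python
from typing import List, Optional, Sequence

Block = Sequence[Sequence[float]]
MaskBlock = Sequence[Sequence[bool]]


def flatten(blocks: Sequence[Block]) -> List[float]:
    """Concatenate every value of every 2D block into one flat list."""
    return [value for block in blocks for row in block for value in row]


def flattened_mse(
    predictions: Sequence[Block],
    targets: Sequence[Block],
    mask: Optional[Sequence[MaskBlock]] = None,
    reduction: str = "mean",
):
    """Flatten both inputs, drop masked-out entries, then reduce squared errors."""
    flat_pred = flatten(predictions)
    flat_target = flatten(targets)
    if len(flat_pred) != len(flat_target):
        raise ValueError("predictions and targets must have the same size")

    if mask is not None:
        # The mask is flattened in the same order as the values.
        flat_mask = [bool(m) for block in mask for row in block for m in row]
        if len(flat_mask) != len(flat_pred):
            raise ValueError("mask must have the same size as the values")
        pairs = [(p, t) for p, t, m in zip(flat_pred, flat_target, flat_mask) if m]
    else:
        pairs = list(zip(flat_pred, flat_target))

    errors = [(p - t) ** 2 for p, t in pairs]
    if reduction == "none":
        return errors
    if reduction == "sum":
        return sum(errors)
    if reduction == "mean":
        return sum(errors) / len(errors)
    raise ValueError(f"unknown reduction: {reduction!r}")


predictions = [[[1.0, 2.0]], [[3.0]]]
targets = [[[1.0, 4.0]], [[0.0]]]
mask = [[[True, True]], [[False]]]

unmasked_sum = flattened_mse(predictions, targets, reduction="sum")
masked_mean = flattened_mse(predictions, targets, mask)
```

Flattening first means the reduction runs over all retained entries at once, so with reduction "mean" larger blocks contribute proportionally more than smaller ones.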
- class metatrain.utils.loss.MaskedTensorMapLoss(name: str, gradient: str | None, weight: float, reduction: str, *, loss_fn: _Loss)[source]¶
Bases: BaseTensorMapLoss
Pointwise masked loss on TensorMap entries.
Inherits flattening and torch-loss logic from BaseTensorMapLoss.
- compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Dict[str, TensorMap] | None = None) Tensor[source]¶
Gather and flatten target and prediction blocks, then compute loss.
- Parameters:
predictions (Dict[str, TensorMap]) – Mapping from target names to TensorMaps.
targets (Dict[str, TensorMap]) – Mapping from target names to TensorMaps.
extra_data (Dict[str, TensorMap] | None) – Additional data for loss computation. Assumes that, for the target name used in the constructor, there is a corresponding data field name + "_mask" that contains the tensor to be used for masking. It should have the same metadata as the target and prediction tensors.
- Returns:
Scalar loss tensor.
- Return type:
Tensor
- class metatrain.utils.loss.TensorMapMSELoss(name: str, gradient: str | None, weight: float, reduction: str)[source]¶
Bases: BaseTensorMapLoss
Unmasked mean-squared error on TensorMap entries.
- class metatrain.utils.loss.TensorMapMAELoss(name: str, gradient: str | None, weight: float, reduction: str)[source]¶
Bases: BaseTensorMapLoss
Unmasked mean-absolute error on TensorMap entries.
- class metatrain.utils.loss.TensorMapHuberLoss(name: str, gradient: str | None, weight: float, reduction: str, delta: float)[source]¶
Bases: BaseTensorMapLoss
Unmasked Huber loss on TensorMap entries.
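For reference, the Huber loss with threshold \(\delta\) (the delta argument) is conventionally defined as follows, where \(r\) is the residual between prediction and target. This is the standard definition, as used by torch.nn.HuberLoss; that this class follows it exactly is an assumption, since the page does not state the formula.

\[L_\delta(r) = \begin{cases} \frac{1}{2} r^2 & \text{if } |r| \le \delta, \\ \delta \left( |r| - \frac{1}{2} \delta \right) & \text{otherwise.} \end{cases}\]

The loss is quadratic for small residuals and linear for large ones, so delta sets the residual size beyond which outliers stop being penalized quadratically.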
- class metatrain.utils.loss.TensorMapMaskedMSELoss(name: str, gradient: str | None, weight: float, reduction: str)[source]¶
Bases: MaskedTensorMapLoss
Masked mean-squared error on TensorMap entries.
- class metatrain.utils.loss.TensorMapMaskedMAELoss(name: str, gradient: str | None, weight: float, reduction: str)[source]¶
Bases: MaskedTensorMapLoss
Masked mean-absolute error on TensorMap entries.
- class metatrain.utils.loss.TensorMapMaskedHuberLoss(name: str, gradient: str | None, weight: float, reduction: str, delta: float)[source]¶
Bases: MaskedTensorMapLoss
Masked Huber loss on TensorMap entries.
- class metatrain.utils.loss.MaskedDOSLoss(name: str, gradient: str | None, weight: float, grad_weight: float, int_weight: float, extra_targets: int, reduction: str)[source]¶
Bases: LossInterface
Masked DOS loss on TensorMap entries.
- Parameters:
name (str) – key for the dos in the prediction/target dictionary.
gradient (str | None) – optional gradient field name.
weight (float) – weight of the loss contribution in the final aggregation.
grad_weight (float) – Multiplier for the gradient of the unmasked DOS component.
int_weight (float) – Multiplier for the cumulative DOS component.
extra_targets (int) – Number of extra targets predicted by the model.
reduction (str) – reduction mode for torch loss.
- compute(model_predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Dict[str, TensorMap] | None = None) Tensor[source]¶
Gather and flatten target and prediction blocks, then compute loss.
- Parameters:
model_predictions (Dict[str, TensorMap]) – Mapping from target names to TensorMaps.
targets (Dict[str, TensorMap]) – Mapping from target names to TensorMaps.
extra_data (Dict[str, TensorMap] | None) – Additional data for loss computation. Assumes that, for the target name used in the constructor, there is a corresponding data field name + "_mask" that contains the tensor to be used for masking. It should have the same metadata as the target and prediction tensors.
- Returns:
Scalar loss tensor.
- Return type:
Tensor
- class metatrain.utils.loss.TensorMapEnsembleLoss(name: str, gradient: str | None, weight: float, reduction: str, loss_fn: Module)[source]¶
Bases: BaseTensorMapLoss
Loss for ensembles based on TensorMap entries. Assumes that the ensemble is the outermost dimension of TensorBlock properties.
- Parameters:
- compute_flattened(pred_mean: TensorMap, target: TensorMap, pred_var: TensorMap) Tensor[source]¶
Flatten the predicted mean, target, and predicted variance blocks, then apply the torch loss.
- class metatrain.utils.loss.GaussianCRPSLoss(reduction: str = 'mean', eps: float = 1e-12)[source]¶
Bases: Module
Gaussian CRPS loss.
This implements the closed-form expression for the CRPS of a Gaussian predictive distribution \(\mathcal{N}(\mu, \sigma^2)\) evaluated at a target value \(x\):
\[\text{CRPS}(x; \mu, \sigma) = \sigma \left[ z(2\Phi(z) - 1) + 2\phi(z) - \frac{1}{\sqrt{\pi}} \right]\]
where \(z = \frac{x - \mu}{\sigma}\), \(\Phi\) is the standard normal CDF, and \(\phi\) is the standard normal PDF.
- Parameters:
Initialize internal Module state, shared by both nn.Module and ScriptModule.
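The closed-form Gaussian CRPS expression above can be checked numerically with a scalar sketch that uses only the standard library. The function name gaussian_crps is hypothetical, and clamping sigma from below by eps is an assumed reading of the eps argument, whose role is not described on this page.

```python
import math


def gaussian_crps(x: float, mu: float, sigma: float, eps: float = 1e-12) -> float:
    """CRPS of N(mu, sigma^2) evaluated at the target value x (scalar sketch)."""
    # Assumption: eps guards against division by a vanishing sigma.
    sigma = max(sigma, eps)
    z = (x - mu) / sigma
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))          # Phi(z)
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)   # phi(z)
    return sigma * (z * (2.0 * cdf - 1.0) + 2.0 * pdf - 1.0 / math.sqrt(math.pi))


# At the mean of a unit Gaussian the bracket reduces to 2*phi(0) - 1/sqrt(pi).
crps_at_mean = gaussian_crps(0.0, 0.0, 1.0)

# For a very narrow distribution the CRPS approaches the absolute error.
crps_narrow = gaussian_crps(3.0, 1.0, 1e-9)
```

The second call illustrates a useful property: as sigma shrinks, the CRPS tends to |x - mu|, so the score generalizes the mean-absolute error to probabilistic predictions.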
- class metatrain.utils.loss.EmpiricalCRPSLoss(reduction: str = 'mean')[source]¶
Bases: Module
Empirical CRPS loss for ensemble predictions.
The ensemble predictions \(\{Y_i\}_{i=1}^M\) for each data point define an empirical predictive distribution:
\[F_M(y) = \frac{1}{M} \sum_{i=1}^M \mathbb{1}_{Y_i \le y}\]
The CRPS of this empirical distribution at observation \(z\) has the closed form:
\[\text{CRPS}(F_M, z) = \frac{1}{M} \sum_{i=1}^M |Y_i - z| - \frac{1}{2 M^2} \sum_{i,j} |Y_i - Y_j|\]
- Parameters:
reduction (str) – ‘none’, ‘mean’, or ‘sum’.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
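The empirical CRPS closed form above translates directly into code. The following is a scalar sketch for a single data point in plain Python; the function name empirical_crps is hypothetical and the example does not reflect how the module batches or reduces its inputs.

```python
from typing import Sequence


def empirical_crps(ensemble: Sequence[float], z: float) -> float:
    """CRPS of the empirical distribution of ``ensemble`` at observation z."""
    m = len(ensemble)
    if m == 0:
        raise ValueError("ensemble must contain at least one member")
    # First term: mean absolute error of the members against the observation.
    accuracy = sum(abs(y - z) for y in ensemble) / m
    # Second term: half the mean absolute difference between all member pairs.
    spread = sum(abs(yi - yj) for yi in ensemble for yj in ensemble) / (2.0 * m * m)
    return accuracy - spread


two_member = empirical_crps([1.0, 3.0], 2.0)
single_member = empirical_crps([1.0], 3.0)
```

With a single member the spread term vanishes and the score reduces to the absolute error; the double sum costs O(M^2) per data point, which is acceptable for small ensembles.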
- class metatrain.utils.loss.TensorMapGaussianNLLLoss(name: str, gradient: str | None, weight: float, reduction: str)[source]¶
Bases:
Bases: TensorMapEnsembleLoss
Gaussian negative log-likelihood loss for TensorMap entries.
- class metatrain.utils.loss.TensorMapGaussianCRPSLoss(name: str, gradient: str | None, weight: float, reduction: str)[source]¶
Bases:
Bases: TensorMapEnsembleLoss
Gaussian CRPS loss for TensorMap entries.
- class metatrain.utils.loss.TensorMapEmpiricalCRPSLoss(name: str, gradient: str | None, weight: float, reduction: str)[source]¶
Bases:
Bases: TensorMapEnsembleLoss
Empirical CRPS loss for TensorMap entries.
- Parameters:
- class metatrain.utils.loss.LossAggregator(targets: Dict[str, TargetInfo], config: Dict[str, LossSpecification])[source]¶
Bases:
Bases: LossInterface
Aggregate multiple LossInterface terms with scheduled weights and metadata.
- Parameters:
targets (Dict[str, TargetInfo]) – mapping from target names to TargetInfo.
config (Dict[str, LossSpecification]) – per-target configuration dict.
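One plausible reading of the aggregation is a weighted sum of per-target terms. The sketch below shows that idea on scalar values in plain Python. It assumes a simple weighted sum with fixed weights and does not model weight scheduling or metadata; WeightedSum, squared_error, and absolute_error are hypothetical names, and none of this is taken from the library's implementation.

```python
from typing import Callable, Dict, Tuple

Values = Dict[str, float]
LossFn = Callable[[Values, Values], float]


def squared_error(name: str) -> LossFn:
    """Build a term that compares one named entry with squared error."""
    def loss(predictions: Values, targets: Values) -> float:
        return (predictions[name] - targets[name]) ** 2
    return loss


def absolute_error(name: str) -> LossFn:
    """Build a term that compares one named entry with absolute error."""
    def loss(predictions: Values, targets: Values) -> float:
        return abs(predictions[name] - targets[name])
    return loss


class WeightedSum:
    """Hypothetical aggregator: total = sum_i weight_i * loss_i."""

    def __init__(self, terms: Dict[str, Tuple[float, LossFn]]) -> None:
        self.terms = terms

    def compute(self, predictions: Values, targets: Values) -> float:
        return sum(
            weight * fn(predictions, targets) for weight, fn in self.terms.values()
        )


aggregator = WeightedSum(
    {
        "energy": (1.0, squared_error("energy")),
        "forces": (10.0, absolute_error("forces")),
    }
)
total = aggregator.compute(
    {"energy": 1.0, "forces": 0.5},
    {"energy": 3.0, "forces": 0.0},
)
```

Here the energy term contributes 1.0 * 4.0 and the forces term 10.0 * 0.5, so the weights control the balance between targets with different scales.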
- class metatrain.utils.loss.LossType(*values)[source]¶
Bases: Enum
Enumeration of available loss types and their implementing classes.
- Parameters:
key – string key for the loss type.
cls – class implementing the loss type.
- MSE = ('mse', <class 'metatrain.utils.loss.TensorMapMSELoss'>)¶
- MAE = ('mae', <class 'metatrain.utils.loss.TensorMapMAELoss'>)¶
- HUBER = ('huber', <class 'metatrain.utils.loss.TensorMapHuberLoss'>)¶
- MASKED_MSE = ('masked_mse', <class 'metatrain.utils.loss.TensorMapMaskedMSELoss'>)¶
- MASKED_MAE = ('masked_mae', <class 'metatrain.utils.loss.TensorMapMaskedMAELoss'>)¶
- MASKED_HUBER = ('masked_huber', <class 'metatrain.utils.loss.TensorMapMaskedHuberLoss'>)¶
- POINTWISE = ('pointwise', <class 'metatrain.utils.loss.BaseTensorMapLoss'>)¶
- MASKED_POINTWISE = ('masked_pointwise', <class 'metatrain.utils.loss.MaskedTensorMapLoss'>)¶
- MASKED_DOS = ('masked_dos', <class 'metatrain.utils.loss.MaskedDOSLoss'>)¶
- GAUSSIAN_NLL = ('gaussian_nll_ensemble', <class 'metatrain.utils.loss.TensorMapGaussianNLLLoss'>)¶
- GAUSSIAN_CRPS = ('gaussian_crps_ensemble', <class 'metatrain.utils.loss.TensorMapGaussianCRPSLoss'>)¶
- EMPIRICAL_CRPS = ('empirical_crps_ensemble', <class 'metatrain.utils.loss.TensorMapEmpiricalCRPSLoss'>)¶
- property cls: Type[LossInterface]¶
Class implementing this loss type.
- metatrain.utils.loss.create_loss(loss_type: str, *, name: str, gradient: str | None, weight: float, reduction: str, **extra_kwargs: Any) LossInterface[source]¶
Factory to instantiate a concrete LossInterface given its string key.
- Parameters:
loss_type (str) – string key matching one of the members of LossType.
name (str) – target name for the loss.
gradient (str | None) – gradient name, if present.
weight (float) – weight of the loss contribution in the final aggregation.
reduction (str) – reduction mode for the torch loss.
**extra_kwargs (Any) – additional hyperparameters specific to the loss type.
- Returns:
instance of the selected loss.
- Return type:
LossInterface
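The pattern of an enumeration that pairs a string key with an implementing class, plus a factory that looks the key up and forwards extra keyword arguments, can be sketched as follows. All names here (BaseLoss, MSELoss, HuberLoss, LossKind, make_loss) are hypothetical stand-ins written for this example. It illustrates the pattern only and is not metatrain's code.

```python
from enum import Enum
from typing import Any, Optional, Type


class BaseLoss:
    """Stand-in base class holding the arguments common to every loss."""

    def __init__(self, name: str, gradient: Optional[str], weight: float, reduction: str):
        self.name = name
        self.gradient = gradient
        self.weight = weight
        self.reduction = reduction


class MSELoss(BaseLoss):
    pass


class HuberLoss(BaseLoss):
    def __init__(self, name, gradient, weight, reduction, delta: float):
        super().__init__(name, gradient, weight, reduction)
        self.delta = delta  # extra hyperparameter specific to this loss


class LossKind(Enum):
    """Each member's value is a (string key, implementing class) pair."""

    MSE = ("mse", MSELoss)
    HUBER = ("huber", HuberLoss)

    def __init__(self, key: str, cls: Type[BaseLoss]):
        self.key = key
        self.cls = cls


def make_loss(loss_type: str, *, name: str, gradient: Optional[str],
              weight: float, reduction: str, **extra_kwargs: Any) -> BaseLoss:
    """Instantiate the class registered under ``loss_type``."""
    for kind in LossKind:
        if kind.key == loss_type:
            return kind.cls(name=name, gradient=gradient, weight=weight,
                            reduction=reduction, **extra_kwargs)
    valid = ", ".join(kind.key for kind in LossKind)
    raise ValueError(f"unknown loss type {loss_type!r}; expected one of: {valid}")


loss = make_loss("huber", name="energy", gradient=None, weight=1.0,
                 reduction="mean", delta=0.1)
```

Forwarding extra keyword arguments lets loss-specific hyperparameters such as delta reach only the classes that accept them, while an unknown key fails early with the list of valid options.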