Loss functions¶
metatrain supports a variety of loss functions, which can be configured in the loss subsection of the training section for each architecture in the options file.
The loss functions are designed to be flexible and can be tailored to the specific needs of the dataset and of the predicted targets.
Loss function configurations in the options.yaml file¶
A common use case is the training of machine-learning interatomic potentials (MLIPs), where the training targets include energies, forces, and stress/virial.
The loss terms for energy, forces, and stress can be specified as:
loss:
  energy:
    type: mse
    weight: 1.0
    reduction: mean
  forces:
    type: mse
    weight: 1.0
    reduction: mean
  stress:
    type: mse
    weight: 1.0
    reduction: mean
Here, forces and stress refer to the gradients of the energy target with respect to atomic positions and strain, respectively, assuming these gradients have been requested in the training set configuration.
Another common scenario is when only the loss function type needs to be specified, while default values are acceptable for the other parameters. In that case, the configuration can be further simplified to:
loss:
  energy:
    type: mse
  forces: mae
  stress: huber
where, for example, different types of losses are requested for different targets. This is equivalent to the more detailed configuration:
loss:
  energy:
    type: mse
    weight: 1.0
    reduction: mean
  forces:
    type: mae
    weight: 1.0
    reduction: mean
  stress:
    type: huber
    weight: 1.0
    reduction: mean
    delta: 1.0
When all targets and their gradients should use the same loss function with equal weights and reductions, it is also possible to use the global shorthand
loss: mse
which sets the loss type to mean squared error (MSE) for all targets and, if present, for all their gradients.
This example assumes that the training set contains a target named energy, and that gradients with respect to both atomic positions (forces) and strain (stress/virial) have been requested.
If the energy target has a custom name (e.g., mtt::etot), the loss configuration should use that name instead:
loss:
  mtt::etot:
    type: mse
    weight: 1.0
    reduction: mean
  forces:
    type: mse
    weight: 1.0
    reduction: mean
  stress:
    type: mse
    weight: 1.0
    reduction: mean
...
training_set:
  systems:
    ...
  targets:
    mtt::etot:
      quantity: energy
      forces: true # or some other allowed configuration
      stress: true # or some other allowed configuration
      ...
Note that, when the target name is not energy, the key quantity: energy must be present in the target definition to specify that this target corresponds to energies.
This allows metatrain to associate the correct gradients (forces and stress/virial) when requested.
Both the explicit MLIP configuration (with separate energy, forces, and stress entries) and the global shorthand loss: mse are thus mapped to the same internal representation, where loss terms are specified explicitly per target and per gradient.
Internal configuration format¶
The internal configuration used by metatrain during training is a more detailed version of the examples shown above, where each target has its own loss configuration and an optional gradients subsection.
The example above where the loss function is MSE for energy, forces, and stress is thus represented internally as:
loss:
  energy:
    type: mse
    weight: 1.0
    reduction: mean
    gradients:
      positions:
        type: mse
        weight: 1.0
        reduction: mean
      strain:
        type: mse
        weight: 1.0
        reduction: mean
This internal format is also available to users in the options file. It can be used to handle general targets and their “non-standard” gradients, those that are not simply forces or stress (for example, custom derivatives with respect to user-defined quantities).
Generally, each loss-function term accepts the following parameters:
- param type:
This controls the type of loss to be used. The default value is mse, and other standard options are mae and huber, which implement the equivalent PyTorch loss functions MSELoss, L1Loss, and HuberLoss, respectively. There are also “masked” versions of these losses, which are useful when using padded targets with values that should be masked before computing the loss. The masked losses are named masked_mse, masked_mae, and masked_huber.
- param weight:
This controls the weighting of different contributions to the loss (e.g., energy, forces, virial, etc.). The default value of 1.0 for all targets works well for most datasets, but can be adjusted if required.
- param reduction:
This controls how the overall loss is computed across batches. The default is to use the mean of the batch losses; the sum function is also supported.
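To see how these parameters fit together, here is a minimal sketch (not metatrain's actual implementation; the target names and tensors are placeholders) of how per-target loss types, weights, and reductions combine into a single training loss:

import torch

# One criterion and one weight per target; the weighted terms are summed.
criteria = {
    "energy": (torch.nn.MSELoss(reduction="mean"), 1.0),
    "forces": (torch.nn.L1Loss(reduction="mean"), 1.0),
    "stress": (torch.nn.HuberLoss(reduction="mean", delta=1.0), 1.0),
}

def total_loss(predictions, targets):
    loss = torch.zeros(())
    for name, (criterion, weight) in criteria.items():
        loss = loss + weight * criterion(predictions[name], targets[name])
    return loss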
Some losses, like huber, require additional parameters to be specified:
- param delta:
This parameter is specific to the Huber loss functions (huber and masked_huber) and defines the threshold at which the loss function transitions from quadratic to linear behavior. The default value is 1.0.
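For intuition on the delta parameter, the following short example (plain PyTorch, not metatrain code) shows where the Huber loss switches from quadratic to linear growth:

import torch

huber = torch.nn.HuberLoss(reduction="none", delta=1.0)
target = torch.zeros(4)
prediction = torch.tensor([0.5, 1.0, 2.0, 4.0])
# Residuals below delta are penalized quadratically (0.5 * r**2),
# residuals above delta only linearly (delta * (|r| - 0.5 * delta)).
print(huber(prediction, target))  # tensor([0.1250, 0.5000, 1.5000, 3.5000])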
Masked loss functions¶
Masked loss functions are particularly useful when dealing with datasets that contain padded targets. In such cases, the loss function can be configured to ignore the padded values during the loss computation.
This is done by using the masked_ prefix in the loss type. For example, if the target contains padded values, you can use masked_mse or masked_mae to ensure that the loss is computed only on the valid (non-padded) values.
The values of the masks must be passed as extra_data in the training set, and the loss function will automatically apply the mask to the target values. An example configuration for a masked loss is as follows:
loss:
  energy:
    type: masked_mse
    weight: 1.0
    reduction: sum
  forces:
    type: masked_mae
    weight: 0.1
    reduction: sum
...
training_set:
  systems:
    ...
  targets:
    mtt::my_target:
      ...
  ...
  extra_data:
    mtt::my_target_mask:
      read_from: my_target_mask.mts
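The idea behind the masked losses can be illustrated with the following sketch (an illustration only, not the metatrain implementation): the errors of padded entries are excluded from the loss by the mask.

import torch

def masked_mse(predictions, targets, mask):
    # mask is a boolean tensor: True for valid entries, False for padding.
    squared_errors = (predictions - targets) ** 2
    return squared_errors[mask].mean()

predictions = torch.tensor([1.0, 2.0, 0.0, 0.0])
targets = torch.tensor([1.5, 2.5, 0.0, 0.0])  # last two entries are padding
mask = torch.tensor([True, True, False, False])
print(masked_mse(predictions, targets, mask))  # 0.25, padding is ignored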
DOS Loss Function¶
The masked DOS loss function is a specialized loss designed for training on the electronic density of states (DOS), typically represented on an energy grid. Structures in a dataset can (and usually do) have eigenvalues spanning different energy ranges, and DOS calculations do not share a common absolute energy reference. To handle this, the loss uses a user-specified number of extra predicted targets to dynamically shift the energy grid for each structure, aligning the predicted DOS with the reference DOS before computing the loss.
After this alignment step, the loss function consists of three components:
an integrated loss on the masked DOS values
masked_DOS_loss = torch.trapezoid((aligned_predictions - targets)**2 * mask, x = energy_grid)
an integrated loss on the gradient of the unmasked DOS values, to ensure that values outside the masked region are also learned smoothly
unmasked_gradient_loss = torch.trapezoid(aligned_predictions_gradient**2 * (~mask), x = energy_grid)
an integrated loss on the cumulative DOS values in the masked region
cumulative_aligned_predictions = torch.cumulative_trapezoid(aligned_predictions, x = energy_grid)
cumulative_targets = torch.cumulative_trapezoid(targets, x = energy_grid)
masked_cumulative_DOS_loss = torch.trapezoid((cumulative_aligned_predictions - cumulative_targets)**2 * mask, x = energy_grid[1:])
Each component can be weighted independently to tailor the loss function to specific training needs.
loss = (masked_DOS_loss +
grad_weight * unmasked_gradient_loss +
int_weight * masked_cumulative_DOS_loss)
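Putting the three components together, a self-contained sketch of this loss could look as follows (illustrative only: the alignment step is assumed to have already been applied, the DOS gradient is approximated here with torch.gradient, and the mask is trimmed to match the cumulative arrays):

import torch

def dos_loss(aligned_predictions, targets, mask, energy_grid,
             grad_weight=1e-4, int_weight=2.0):
    # Integrated squared error on the masked DOS values.
    masked_dos = torch.trapezoid(
        (aligned_predictions - targets) ** 2 * mask, x=energy_grid)
    # Smoothness term on the DOS gradient outside the masked region.
    gradient = torch.gradient(aligned_predictions, spacing=(energy_grid,))[0]
    unmasked_gradient = torch.trapezoid(gradient ** 2 * (~mask), x=energy_grid)
    # Integrated squared error on the cumulative DOS in the masked region.
    cumulative_predictions = torch.cumulative_trapezoid(aligned_predictions, x=energy_grid)
    cumulative_targets = torch.cumulative_trapezoid(targets, x=energy_grid)
    masked_cumulative = torch.trapezoid(
        (cumulative_predictions - cumulative_targets) ** 2 * mask[1:],
        x=energy_grid[1:])
    return masked_dos + grad_weight * unmasked_gradient + int_weight * masked_cumulative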
To use this loss function, you can refer to this code snippet for the loss section in your YAML configuration file:
loss:
  mtt::dos:
    type: "masked_dos"
    grad_weight: 1e-4
    int_weight: 2.0
    extra_targets: 200
    reduction: "mean"
- param name:
Key for the DOS in the prediction/target dictionary (mtt::dos in this case).
- param grad_weight:
Multiplier for the gradient of the unmasked DOS component.
- param int_weight:
Multiplier for the cumulative DOS component.
- param extra_targets:
Number of extra targets predicted by the model.
- param reduction:
Reduction mode for the underlying torch loss. Options are “mean”, “sum”, or “none”.
The values used in the above example are the ones used for PET-MAD DOS training and can be a reasonable starting point for other applications.
Ensemble Loss Function¶
An LLPR ensemble can be further trained to improve its uncertainty quantification.
This is done by using the metatrain.utils.loss.TensorMapEnsembleLoss function, which implements strictly proper scoring rules for probabilistic regression.
Two of the available losses assume a Gaussian predictive distribution and operate only on the ensemble-predicted mean \(\mu\) and standard deviation \(\sigma\). The third option, the empirical CRPS, uses the full ensemble of predictions and does not rely on any parametric assumption.
The Gaussian Negative Log-Likelihood (NLL) loss maximizes the likelihood of the observed data under a Gaussian predictive model. It encourages sharp predictions and is statistically optimal when the residual noise is well described by a Gaussian distribution. Internally, this option uses
torch.nn.GaussianNLLLoss.
YAML configuration:
loss:
  mtt::target_name:
    type: gaussian_nll_ensemble
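As a rough illustration (assuming an ensemble of scalar predictions per sample; this is not metatrain's internal tensor handling), this option amounts to applying torch.nn.GaussianNLLLoss to the ensemble mean and variance:

import torch

ensemble = torch.randn(8, 32)  # 8 samples, 32 ensemble members
reference = torch.randn(8)     # reference values
mean = ensemble.mean(dim=1)
var = ensemble.var(dim=1)
nll = torch.nn.GaussianNLLLoss(reduction="mean")
print(nll(mean, reference, var))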
The analytical Gaussian Continuous Ranked Probability Score (CRPS) evaluates the integrated squared difference between the predicted and (assumed) Gaussian cumulative distribution functions. It is given by
\[\mathrm{CRPS}(\mu, \sigma; y) = \sigma \left[ \frac{y - \mu}{\sigma} \left(2\Phi\left(\frac{y - \mu}{\sigma}\right) - 1\right) + 2\phi\left(\frac{y - \mu}{\sigma}\right) - \frac{1}{\sqrt{\pi}} \right],\]
where \(\phi\) and \(\Phi\) denote the standard normal density and cumulative distribution functions.
YAML configuration:
loss:
  mtt::target_name:
    type: gaussian_crps_ensemble
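A minimal sketch of the analytical formula above in code (illustrative only; mean, std, and y are placeholder tensors of matching shape) is:

import torch
from torch.distributions import Normal

def gaussian_crps(mean, std, y):
    # Analytical CRPS of a Gaussian N(mean, std**2) evaluated at y.
    standard_normal = Normal(torch.zeros_like(mean), torch.ones_like(std))
    z = (y - mean) / std
    pdf = torch.exp(standard_normal.log_prob(z))
    cdf = standard_normal.cdf(z)
    return std * (z * (2.0 * cdf - 1.0) + 2.0 * pdf - 1.0 / torch.pi ** 0.5)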
The empirical Continuous Ranked Probability Score does not assume a Gaussian predictive distribution. Instead, it evaluates the CRPS directly from the ensemble predictions \(\{x_j\}_{j=1}^M\). For a target value \(y\), the empirical CRPS is
\[\mathrm{CRPS}_{\mathrm{emp}}(\{x_j\}, y) = \frac{1}{M} \sum_{j=1}^M |x_j - y| - \frac{1}{2 M^2} \sum_{j=1}^M \sum_{k=1}^M |x_j - x_k|.\]
This scoring rule is strictly proper for arbitrary predictive distributions and therefore leverages the full ensemble to learn non-Gaussian forms of uncertainty.
YAML configuration:
loss:
  mtt::target_name:
    type: empirical_crps_ensemble
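The empirical CRPS can likewise be written down in a few lines (a sketch under the assumption that the ensemble has shape (n_samples, M) and the references have shape (n_samples,), which may differ from metatrain's internal layout):

import torch

def empirical_crps(ensemble, reference):
    # First term: mean absolute error of the ensemble members w.r.t. the reference.
    # Second term: mean pairwise spread of the ensemble, halved.
    m = ensemble.shape[-1]
    error_term = (ensemble - reference.unsqueeze(-1)).abs().mean(dim=-1)
    spread = (ensemble.unsqueeze(-1) - ensemble.unsqueeze(-2)).abs().sum(dim=(-1, -2))
    return error_term - spread / (2 * m**2)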
In practice, all three scoring rules encourage calibrated uncertainty estimates, but with different characteristics. The Gaussian NLL is quadratic in the residual and therefore more sensitive to large deviations. The analytical Gaussian CRPS grows linearly with the residual and often yields smoother behaviour when the residual distribution departs from strict Gaussianity. The empirical CRPS is fully non-parametric and can in principle capture skewness, multimodality, or other non-Gaussian features present in the ensemble predictions.