Loss functions

metatrain supports a variety of loss functions, which can be configured in the loss subsection of the training section for each architecture in the options file. The loss functions are designed to be flexible and can be tailored to the specific needs of the dataset and of the predicted targets.

Loss function configurations in the options.yaml file

A common use case is the training of machine-learning interatomic potentials (MLIPs), where the training targets include energies, forces, and stress/virial.

The loss terms for energy, forces, and stress can be specified as:

loss:
  energy:
    type: mse
    weight: 1.0
    reduction: mean
    forces:
      type: mse
      weight: 1.0
      reduction: mean
    stress:
      type: mse
      weight: 1.0
      reduction: mean

Here, forces and stress refer to the gradients of the energy target with respect to atomic positions and strain, respectively, assuming these gradients have been requested in the training set configuration.

Another common scenario is when only the loss function type needs to be specified, while default values are acceptable for the other parameters. In that case, the configuration can be further simplified to:

loss:
  energy:
    type: mse
    forces: mae
    stress: huber

where, for example, different types of losses are requested for different targets. This is equivalent to the more detailed configuration:

loss:
  energy:
    type: mse
    weight: 1.0
    reduction: mean
    forces:
      type: mae
      weight: 1.0
      reduction: mean
    stress:
      type: huber
      weight: 1.0
      reduction: mean
      delta: 1.0

When all targets and their gradients should use the same loss function with equal weights and reductions, it is also possible to use the global shorthand

loss: mse

which sets the loss type to mean squared error (MSE) for all targets and, if present, for all their gradients.

This example assumes that the training set contains a target named energy, and that gradients with respect to both atomic positions (forces) and strain (stress/virial) have been requested. If the energy target has a custom name (e.g., mtt::etot), the loss configuration should use that name instead:

loss:
  mtt::etot:
    type: mse
    weight: 1.0
    reduction: mean
    forces:
      type: mse
      weight: 1.0
      reduction: mean
    stress:
      type: mse
      weight: 1.0
      reduction: mean
...
training_set:
  systems:
  ...
  targets:
    mtt::etot:
      quantity: energy
      forces: true  # or some other allowed configuration
      stress: true  # or some other allowed configuration
  ...

Note that, when the target name is not energy, the key quantity: energy must be present in the target definition to specify that this target corresponds to energies. This allows metatrain to associate the correct gradients (forces and stress/virial) when they are requested. Both the explicit MLIP configuration (with separate energy, forces, and stress entries) and the global shorthand loss: mse are thus mapped to the same internal representation, where loss terms are specified explicitly per target and per gradient.

Internal configuration format

The internal configuration used by metatrain during training is a more detailed version of the examples shown above, where each target has its own loss configuration and an optional gradients subsection.

The example above where the loss function is MSE for energy, forces, and stress is thus represented internally as:

loss:
  energy:
    type: mse
    weight: 1.0
    reduction: mean
    gradients:
      positions:
        type: mse
        weight: 1.0
        reduction: mean
      strain:
        type: mse
        weight: 1.0
        reduction: mean

This internal format is also available to users in the options file. It is useful for general targets and their “non-standard” gradients, i.e., gradients that are not simply forces or stress (for example, custom derivatives with respect to user-defined quantities).
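
For example, a loss term for a hypothetical custom target with a user-defined gradient could look as follows (both mtt::my_field and my_gradient are purely illustrative names; the gradient key is expected to match the name of the gradient defined for that target in the training set):

loss:
  mtt::my_field:
    type: mse
    weight: 1.0
    reduction: mean
    gradients:
      my_gradient:
        type: mse
        weight: 0.1
        reduction: mean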

Generally, each loss-function term accepts the following parameters:

param type:

This controls the type of loss to be used. The default value is mse; other standard options are mae and huber. These correspond to the PyTorch loss functions MSELoss, L1Loss, and HuberLoss, respectively. There are also “masked” versions of these losses, which are useful for padded targets whose padding values should be excluded from the loss computation. The masked losses are named masked_mse, masked_mae, and masked_huber.

param weight:

This controls the relative weighting of the different contributions to the loss (e.g., energy, forces, virial). The default value of 1.0 for all targets works well for most datasets, but it can be adjusted if required.

param reduction:

This controls how the loss is reduced over a batch. The default is to take the mean; sum is also supported.

Some losses, like huber, require additional parameters to be specified:

param delta:

This parameter is specific to the Huber loss functions (huber and masked_huber) and defines the threshold at which the loss function transitions from quadratic to linear behavior. The default value is 1.0.
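
To make the effect of delta concrete, here is a small PyTorch sketch, outside of metatrain, using the built-in losses that the unmasked loss types correspond to:

import torch

pred = torch.zeros(3)
target = torch.tensor([0.5, 1.0, 4.0])

mse = torch.nn.MSELoss(reduction="mean")                 # type: mse
mae = torch.nn.L1Loss(reduction="mean")                  # type: mae
huber = torch.nn.HuberLoss(reduction="mean", delta=1.0)  # type: huber

# For residuals below delta the Huber loss behaves quadratically (like MSE);
# above delta it grows only linearly (like MAE), damping the effect of outliers.
print(mse(pred, target), mae(pred, target), huber(pred, target))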

Masked loss functions

Masked loss functions are particularly useful when dealing with datasets that contain padded targets. In such cases, the loss function can be configured to ignore the padded values during the loss computation. This is done by using the masked_ prefix in the loss type. For example, if the target contains padded values, you can use masked_mse or masked_mae to ensure that the loss is computed only on the valid (non-padded) values. The values of the masks must be passed as extra_data in the training set, and the loss function will automatically apply the mask to the target values. An example configuration for a masked loss is as follows:

loss:
  energy:
    type: masked_mse
    weight: 1.0
    reduction: sum
    forces:
      type: masked_mae
      weight: 0.1
      reduction: sum
...

training_set:
  systems:
  ...
  targets:
    mtt::my_target:
      ...
  ...
  extra_data:
    mtt::my_target_mask:
      read_from: my_target_mask.mts
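
Conceptually, a masked loss only accumulates residuals where the mask is valid. A minimal sketch of the idea (not metatrain's actual implementation) is:

import torch

pred = torch.tensor([1.0, 2.0, 3.0, 0.0])       # last entry is padding
target = torch.tensor([1.1, 1.9, 3.2, 0.0])
mask = torch.tensor([True, True, True, False])  # False marks padded values

# Only the valid (non-padded) entries contribute to the loss.
masked_mse = ((pred - target)[mask] ** 2).mean()
print(masked_mse)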

DOS Loss Function

The masked DOS loss function is a specialized loss designed for training on the electronic density of states (DOS), typically represented on an energy grid. Structures in a dataset can (and usually do) have eigenvalues spanning different energy ranges, and DOS calculations do not share a common absolute energy reference. To handle this, the loss uses a user-specified number of extra predicted targets to dynamically shift the energy grid for each structure, aligning the predicted DOS with the reference DOS before computing the loss.

After this alignment step, the loss function consists of three components:

  • an integrated loss on the masked DOS values

masked_DOS_loss = torch.trapezoid((aligned_predictions - targets)**2 * mask, x = energy_grid)

  • an integrated loss on the gradient of the unmasked DOS values, to ensure that values outside the masked region are also learned smoothly

unmasked_gradient_loss = torch.trapezoid(aligned_predictions_gradient**2 * (~mask), x = energy_grid)

  • an integrated loss on the cumulative DOS values in the masked region

cumulative_aligned_predictions = torch.cumulative_trapezoid(aligned_predictions, x = energy_grid)
cumulative_targets = torch.cumulative_trapezoid(targets, x = energy_grid)
masked_cumulative_DOS_loss = torch.trapezoid((cumulative_aligned_predictions - cumulative_targets)**2 * mask[1:], x = energy_grid[1:])

Each component can be weighted independently to tailor the loss function to specific training needs.

loss = (masked_DOS_loss +
        grad_weight * unmasked_gradient_loss +
        int_weight * masked_cumulative_DOS_loss)
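
Putting the three terms together, a self-contained sketch with made-up tensors (ignoring the per-structure alignment step) might look as follows; the choice of grid, DOS values, and mask here is purely illustrative:

import torch

# Made-up data: an energy grid, a predicted and a reference DOS on that grid,
# and a boolean mask selecting the energy window to be fitted.
energy_grid = torch.linspace(-10.0, 5.0, 200)
aligned_predictions = torch.rand(200)
targets = torch.rand(200)
mask = energy_grid < 0.0

# Integrated squared error on the masked DOS values.
masked_DOS_loss = torch.trapezoid((aligned_predictions - targets)**2 * mask, x=energy_grid)

# Integrated squared gradient of the prediction outside the mask.
aligned_predictions_gradient = torch.gradient(aligned_predictions, spacing=(energy_grid,))[0]
unmasked_gradient_loss = torch.trapezoid(aligned_predictions_gradient**2 * (~mask), x=energy_grid)

# Integrated squared error on the cumulative DOS in the masked region.
cumulative_aligned_predictions = torch.cumulative_trapezoid(aligned_predictions, x=energy_grid)
cumulative_targets = torch.cumulative_trapezoid(targets, x=energy_grid)
masked_cumulative_DOS_loss = torch.trapezoid(
    (cumulative_aligned_predictions - cumulative_targets)**2 * mask[1:], x=energy_grid[1:]
)

grad_weight, int_weight = 1e-4, 2.0
loss = (masked_DOS_loss +
        grad_weight * unmasked_gradient_loss +
        int_weight * masked_cumulative_DOS_loss)
print(loss)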

To use this loss function, the loss section of your YAML options file can be configured as follows:

loss:
  mtt::dos:
    type: "masked_dos"
    grad_weight: 1e-4
    int_weight: 2.0
    extra_targets: 200
    reduction: "mean"

param name:

Key under which the DOS appears in the prediction/target dictionary (mtt::dos in this example).

param grad_weight:

Multiplier for the gradient of the unmasked DOS component.

param int_weight:

Multiplier for the cumulative DOS component.

param extra_targets:

Number of extra targets predicted by the model.

param reduction:

Reduction mode for the underlying torch loss. Options are “mean”, “sum”, or “none”.

The values used in the above example are the ones used for PETMADDOS training and can be a reasonable starting point for other applications.

Ensemble Loss Function

An LLPR ensemble can be further trained to improve its uncertainty quantification. This is done by using the metatrain.utils.loss.TensorMapEnsembleLoss function, which implements strictly proper scoring rules for probabilistic regression.

Two of the available losses assume a Gaussian predictive distribution and operate only on the ensemble-predicted mean \(\mu\) and standard deviation \(\sigma\). The third option, the empirical CRPS, uses the full ensemble of predictions and does not rely on any parametric assumption.

  • The Gaussian Negative Log-Likelihood (NLL) loss maximizes the likelihood of the observed data under a Gaussian predictive model. It encourages sharp predictions and is statistically optimal when the residual noise is well described by a Gaussian distribution. Internally, this option uses torch.nn.GaussianNLLLoss.

    YAML configuration:

    loss:
      mtt::target_name:
        type: gaussian_nll_ensemble
    
  • The analytical Gaussian Continuous Ranked Probability Score (CRPS) evaluates the integrated squared difference between the predicted and (assumed) Gaussian cumulative distribution functions. It is given by

    \[\mathrm{CRPS}(\mu, \sigma; y) = \sigma \left[ \frac{y - \mu}{\sigma} \left(2\Phi\left(\frac{y - \mu}{\sigma}\right) - 1\right) + 2\phi\left(\frac{y - \mu}{\sigma}\right) - \frac{1}{\sqrt{\pi}} \right],\]

    where \(\phi\) and \(\Phi\) denote the standard normal density and cumulative distribution functions.

    YAML configuration:

    loss:
      mtt::target_name:
        type: gaussian_crps_ensemble
    
  • The empirical Continuous Ranked Probability Score does not assume a Gaussian predictive distribution. Instead, it evaluates the CRPS directly from the ensemble predictions \(\{x_j\}_{j=1}^M\). For a target value \(y\), the empirical CRPS is

    \[\mathrm{CRPS}_{\mathrm{emp}}(\{x_j\}, y) = \frac{1}{M} \sum_{j=1}^M |x_j - y| - \frac{1}{2 M^2} \sum_{j=1}^M \sum_{k=1}^M |x_j - x_k|.\]

    This scoring rule is strictly proper for arbitrary predictive distributions and therefore leverages the full ensemble to learn non-Gaussian forms of uncertainty.

    YAML configuration:

    loss:
      mtt::target_name:
        type: empirical_crps_ensemble
    

In practice, all three scoring rules encourage calibrated uncertainty estimates, but with different characteristics. The Gaussian NLL is quadratic in the residual and therefore more sensitive to large deviations. The analytical Gaussian CRPS grows linearly with the residual and often yields smoother behaviour when the residual distribution departs from strict Gaussianity. The empirical CRPS is fully non-parametric and can in principle capture skewness, multimodality, or other non-Gaussian features present in the ensemble predictions.
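
As a concrete illustration of the two CRPS expressions above, here is a small self-contained sketch (not metatrain's implementation) that evaluates both for a toy ensemble:

import math
import torch

y = torch.tensor(0.3)                    # reference value
ensemble = 0.1 + 0.5 * torch.randn(128)  # toy ensemble of M member predictions

# Gaussian summary of the ensemble.
mu, sigma = ensemble.mean(), ensemble.std()
z = (y - mu) / sigma
normal = torch.distributions.Normal(0.0, 1.0)
phi = torch.exp(normal.log_prob(z))  # standard normal pdf at z
Phi = normal.cdf(z)                  # standard normal cdf at z

# Analytical CRPS for a Gaussian predictive distribution N(mu, sigma^2).
crps_gaussian = sigma * (z * (2 * Phi - 1) + 2 * phi - 1 / math.sqrt(math.pi))

# Empirical CRPS computed directly from the ensemble members.
M = ensemble.numel()
crps_empirical = (ensemble - y).abs().mean() - (
    (ensemble[:, None] - ensemble[None, :]).abs().sum() / (2 * M**2)
)

print(crps_gaussian.item(), crps_empirical.item())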