FlashMD

FlashMD is a method for directly predicting the positions and momenta of atoms in a molecular dynamics (MD) simulation, presented in [1]. Unlike traditional MD integrators, which advance the system in small time steps, it predicts the positions and momenta of atoms after a long time interval, allowing much larger time steps. As a result, it achieves a significant speedup (10-30x) compared to MD driven by machine-learning interatomic potentials (MLIPs). The FlashMD architecture implemented in metatrain is based on the PET architecture.
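To make the idea of direct prediction concrete, the following minimal sketch (plain Python, not metatrain code) treats a model as a map from the positions and momenta at time t to those at time t + Δt, and rolls it out the way an MD engine would. The exact propagator of a one-dimensional harmonic oscillator stands in for a trained FlashMD model; the names `harmonic_propagator` and `rollout` are hypothetical.

```python
import math
from typing import Callable, List, Tuple

# (position, momentum) of a single particle in one dimension
State = Tuple[float, float]


def harmonic_propagator(omega: float, mass: float, dt: float) -> Callable[[State], State]:
    """Exact map (q, p) at time t -> (q, p) at time t + dt for a 1D harmonic oscillator.

    Stand-in for a trained model that maps the current state directly to a future one.
    """
    c, s = math.cos(omega * dt), math.sin(omega * dt)

    def step(state: State) -> State:
        q, p = state
        q_new = q * c + p / (mass * omega) * s
        p_new = -mass * omega * q * s + p * c
        return q_new, p_new

    return step


def rollout(model: Callable[[State], State], state: State, n_steps: int) -> List[State]:
    """Apply a direct state-to-state predictor repeatedly, as an MD engine would."""
    trajectory = [state]
    for _ in range(n_steps):
        state = model(state)
        trajectory.append(state)
    return trajectory


# One large step of 0.5 time units per model call, four calls in total.
model = harmonic_propagator(omega=1.0, mass=1.0, dt=0.5)
trajectory = rollout(model, (1.0, 0.0), n_steps=4)
q_final, p_final = trajectory[-1]  # analytically: q = cos(2), p = -sin(2)
```

Because the stand-in map is exact, any step size works here; a learned model only approximates this map for the single time interval it was trained on.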

Installation

To install this architecture along with the metatrain package, run:

pip install metatrain[flashmd]

where the square brackets indicate that you want to install the optional dependencies required for flashmd.

Default Hyperparameters

All the hyperparameters used by FlashMD are described further down this page. Here, we provide a YAML file containing all the default hyperparameters, which may be a convenient starting point for creating your own hyperparameter files:

architecture:
  name: experimental.flashmd
  model:
    predict_momenta_as_difference: false
    cutoff: 4.5
    cutoff_width: 0.2
    d_pet: 128
    d_head: 128
    d_node: 512
    d_feedforward: 256
    num_heads: 8
    num_attention_layers: 2
    num_gnn_layers: 3
    normalization: RMSNorm
    activation: SwiGLU
    transformer_type: PreLN
    featurizer_type: feedforward
    long_range:
      enable: false
      use_ewald: false
      smearing: 1.4
      kspace_resolution: 1.33
      interpolation_nodes: 5
  training:
    timestep: null
    masses: {}
    distributed: false
    distributed_port: 39591
    batch_size: 16
    num_epochs: 1000
    warmup_fraction: 0.01
    learning_rate: 0.0003
    weight_decay: null
    log_interval: 1
    checkpoint_interval: 100
    scale_targets: true
    fixed_composition_weights: {}
    fixed_scaling_weights: {}
    per_structure_targets: []
    num_workers: null
    log_mae: false
    log_separate_blocks: false
    best_model_metric: rmse_prod
    grad_clip_norm: 1.0
    loss: mse
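Since `timestep` has no usable default, a practical hyperparameter file usually changes only a few entries. The sketch below assumes that options omitted from your file fall back to the defaults above, and illustrates that layering with a hypothetical `merge_hypers` helper (not part of metatrain's API). The 16 fs timestep and the deuterium mass are example values only.

```python
import copy
from typing import Any, Dict


def merge_hypers(defaults: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively overlay user options on top of the defaults (hypothetical helper)."""
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_hypers(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# A small excerpt of the defaults above, written as a Python dict.
defaults = {
    "architecture": {
        "name": "experimental.flashmd",
        "model": {"cutoff": 4.5, "d_pet": 128},
        "training": {"timestep": None, "masses": {}, "batch_size": 16},
    }
}

# A user file only needs the keys that change; timestep must always be set.
# Illustrative values: a 16 fs interval and deuterium (Z=1) at 2.014 u.
user_options = {
    "architecture": {"training": {"timestep": 16.0, "masses": {1: 2.014}}}
}

options = merge_hypers(defaults, user_options)
```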

Tuning hyperparameters

Most of the parameters of FlashMD are inherited from the PET architecture, although they may have different default values.

  • FlashMD-specific parameters for the model:

    ModelHypers.predict_momenta_as_difference: bool = False

    This parameter controls whether the model predicts the future momenta directly or as the difference between the future and the current momenta. Setting it to true helps when the predicted time interval is small compared to the momentum autocorrelation time, while setting it to false is beneficial when the predicted time interval is large.

  • FlashMD-specific parameters for the trainer:

    TrainerHypers.timestep: float | None = None

    The time interval (in fs) between the current and the future positions and momenta that the model must predict. This value is not used during training; it is stored in the model and used to check that the timestep used for inference in MD engines matches the one used during training. This hyperparameter must be provided by the user.

    TrainerHypers.masses: dict[int, float] = {}

    A dictionary mapping atomic species (atomic numbers) to their masses (in atomic mass units).

    Note that FlashMD models, as implemented in metatrain, are not transferable across different isotopes. The masses are not used during training, but they are stored in the model and used during inference to check that the masses used in MD engines match those used during training. If not provided, masses from the ase.data module are used; these are averaged over the natural isotopic abundance of each element.

Model hyperparameters

The parameters that go under the architecture.model section of the config file are the following:

ModelHypers.predict_momenta_as_difference: bool = False

This parameter controls whether the model predicts the future momenta directly or as the difference between the future and the current momenta. Setting it to true helps when the predicted time interval is small compared to the momentum autocorrelation time, while setting it to false is beneficial when the predicted time interval is large.
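The two settings differ only in what the network is asked to output. A minimal sketch, assuming plain lists of momentum components (the function names are hypothetical, not metatrain's API): when Δt is short compared to the momentum autocorrelation time, p(t + Δt) stays close to p(t), so the difference target is small.

```python
from typing import List

Vector = List[float]  # flattened momentum components of all atoms


def momentum_target(p_now: Vector, p_future: Vector, as_difference: bool) -> Vector:
    """What the network is trained to output for the momenta."""
    if as_difference:
        return [pf - pn for pf, pn in zip(p_future, p_now)]
    return list(p_future)


def reconstruct_momenta(p_now: Vector, prediction: Vector, as_difference: bool) -> Vector:
    """Convert a network output back into the predicted future momenta."""
    if as_difference:
        return [pn + d for pn, d in zip(p_now, prediction)]
    return list(prediction)


p_now = [1.0, -2.0, 0.5]
p_future = [1.5, -2.0, 0.0]
target = momentum_target(p_now, p_future, as_difference=True)  # [0.5, 0.0, -0.5]
```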

ModelHypers.cutoff: float = 4.5

Cutoff radius for neighbor search.

This should be set to a value beyond which most of the interactions between atoms are expected to be negligible. A lower cutoff will lead to faster models.
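The cutoff determines which atom pairs become edges of the graph the network operates on. The brute-force search below is an O(N²) illustration without periodic boundary conditions, not metatrain's neighbor-list implementation; it shows why a smaller cutoff yields fewer pairs. Pairs are ordered (i, j) because features flowing from i to j need not equal those flowing from j to i.

```python
import math
from typing import List, Tuple

Position = Tuple[float, float, float]


def neighbor_pairs(positions: List[Position], cutoff: float) -> List[Tuple[int, int, float]]:
    """All ordered pairs (i, j), i != j, closer than cutoff, with their distance.

    Brute force and without periodic boundary conditions, for illustration only.
    """
    pairs = []
    for i, ri in enumerate(positions):
        for j, rj in enumerate(positions):
            if i == j:
                continue
            d = math.dist(ri, rj)
            if d < cutoff:
                pairs.append((i, j, d))
    return pairs


# Three atoms on a line, 2 length units apart.
positions = [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0), (4.0, 0.0, 0.0)]
n_default = len(neighbor_pairs(positions, cutoff=4.5))  # 6 ordered pairs
n_short = len(neighbor_pairs(positions, cutoff=3.0))  # 4: the outer atoms drop out
```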

ModelHypers.cutoff_width: float = 0.2

Width of the smoothing function at the cutoff.
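The width sets how gradually a pair's contribution is switched off as its distance approaches the cutoff, so that the model output stays smooth when atoms cross the cutoff sphere. The exact functional form used by PET is not stated here; the sketch below assumes a common cosine switching function purely to illustrate what `cutoff_width` controls.

```python
import math


def smooth_cutoff(r: float, cutoff: float = 4.5, width: float = 0.2) -> float:
    """Pair weight: 1 below cutoff - width, 0 at and beyond cutoff, cosine ramp between.

    The cosine form is an assumed example, not necessarily PET's exact function.
    """
    if r <= cutoff - width:
        return 1.0
    if r >= cutoff:
        return 0.0
    x = (r - (cutoff - width)) / width  # 0 at the start of the ramp, 1 at the cutoff
    return 0.5 * (1.0 + math.cos(math.pi * x))
```

With this form, a larger width spreads the decay over a thicker shell below the cutoff.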

ModelHypers.d_pet: int = 128

Dimension of the edge features.

This hyperparameter controls the width of the neural network. In general, increasing it might lead to better accuracy, especially on larger datasets, at the cost of increased training and evaluation time.

ModelHypers.d_head: int = 128

Dimension of the attention heads.

ModelHypers.d_node: int = 512

Dimension of the node features.

Increasing this hyperparameter might lead to better accuracy, with a relatively small increase in inference time.

ModelHypers.d_feedforward: int = 256

Dimension of the feedforward network in the attention layer.

ModelHypers.num_heads: int = 8

Attention heads per attention layer.

ModelHypers.num_attention_layers: int = 2

The number of attention layers in each layer of the graph neural network. Depending on the dataset, increasing this hyperparameter might lead to better accuracy, at the cost of increased training and evaluation time.

ModelHypers.num_gnn_layers: int = 3

The number of graph neural network layers.

In general, decreasing this hyperparameter to 1 will lead to much faster models, at the expense of accuracy. Increasing it may or may not lead to better accuracy, depending on the dataset, at the cost of increased training and evaluation time.

ModelHypers.normalization: Literal['RMSNorm', 'LayerNorm'] = 'RMSNorm'

Layer normalization type.

ModelHypers.activation: Literal['SiLU', 'SwiGLU'] = 'SwiGLU'

Activation function.
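SiLU is applied elementwise, whereas SwiGLU is a gated unit that needs two linear projections of its input. The standard definitions are sketched below in plain Python; how metatrain wires SwiGLU into PET (projection sizes, biases) is not described on this page, so the matrices are illustrative assumptions.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def silu(x: float) -> float:
    """SiLU (also called swish): x * sigmoid(x), applied elementwise."""
    return x * sigmoid(x)


def matvec(w: Matrix, x: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a vector of length cols."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def swiglu(x: Vector, w_gate: Matrix, w_value: Matrix) -> Vector:
    """SwiGLU: SiLU(W_gate x) * (W_value x), elementwise; biases omitted."""
    gate = [silu(g) for g in matvec(w_gate, x)]
    value = matvec(w_value, x)
    return [g * v for g, v in zip(gate, value)]


x = [1.0, 0.0]
w_gate = [[1.0, 0.0], [0.0, 1.0]]
w_value = [[2.0, 0.0], [0.0, 3.0]]
out = swiglu(x, w_gate, w_value)
```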

ModelHypers.transformer_type: Literal['PreLN', 'PostLN'] = 'PreLN'

The order in which the layer normalization and attention are applied in a transformer block. Available options are PreLN (normalization before attention) and PostLN (normalization after attention).
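The difference between the two options is where normalization sits relative to the residual connection. The sketch below uses textbook RMSNorm and LayerNorm (without learnable gain or bias) and a generic `sublayer` callable standing in for attention; it illustrates the ordering only and is not PET's actual code.

```python
import math
from typing import Callable, List

Vector = List[float]
Sublayer = Callable[[Vector], Vector]
Norm = Callable[[Vector], Vector]


def rms_norm(x: Vector, eps: float = 1e-6) -> Vector:
    """RMSNorm without a learnable gain: divide by the root mean square."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def layer_norm(x: Vector, eps: float = 1e-6) -> Vector:
    """LayerNorm without learnable gain/bias: subtract the mean, divide by the std."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def pre_ln_block(x: Vector, sublayer: Sublayer, norm: Norm) -> Vector:
    """PreLN: normalize first, apply the sublayer, then add the residual."""
    return [a + b for a, b in zip(x, sublayer(norm(x)))]


def post_ln_block(x: Vector, sublayer: Sublayer, norm: Norm) -> Vector:
    """PostLN: apply the sublayer, add the residual, then normalize."""
    return norm([a + b for a, b in zip(x, sublayer(x))])


def double(v: Vector) -> Vector:
    """Toy sublayer standing in for attention."""
    return [2.0 * a for a in v]


x = [3.0, 4.0]
pre = pre_ln_block(x, double, rms_norm)    # residual stream is left unnormalized
post = post_ln_block(x, double, rms_norm)  # output has (approximately) unit RMS
```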

ModelHypers.featurizer_type: Literal['residual', 'feedforward'] = 'feedforward'

Implementation of the model's featurizer. Available options are residual (the original featurizer from the PET paper, which uses residual connections at each GNN layer for readout) and feedforward (a modern version that uses the last representation after all GNN iterations for readout). Additionally, the feedforward version uses bidirectional feature flow during the message-passing iterations, which encourages the features flowing from atom i to atom j to differ from those flowing from atom j to atom i.

ModelHypers.long_range: LongRangeHypers = {'enable': False, 'interpolation_nodes': 5, 'kspace_resolution': 1.33, 'smearing': 1.4, 'use_ewald': False}

Long-range Coulomb interactions parameters.

Trainer hyperparameters

The parameters that go under the architecture.training section of the config file are the following:

TrainerHypers.timestep: float | None = None

The time interval (in fs) between the current and the future positions and momenta that the model must predict. This value is not used during training; it is stored in the model and used to check that the timestep used for inference in MD engines matches the one used during training. This hyperparameter must be provided by the user.
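The check described above amounts to comparing two numbers expressed in the same unit. A minimal sketch, assuming a hypothetical `check_timestep` helper and an engine that reports its step in picoseconds; the actual check is performed by the MD-engine integration and may be implemented differently.

```python
import math


def check_timestep(model_timestep_fs: float, engine_timestep_fs: float, rel_tol: float = 1e-6) -> None:
    """Raise if the MD engine's step differs from the interval the model was trained on."""
    if not math.isclose(model_timestep_fs, engine_timestep_fs, rel_tol=rel_tol):
        raise ValueError(
            f"model was trained for a {model_timestep_fs} fs step, "
            f"but the MD engine requests {engine_timestep_fs} fs"
        )


# Hypothetical engine that reports its step in picoseconds: convert to fs before comparing.
engine_step_ps = 0.016
check_timestep(16.0, engine_step_ps * 1000.0)  # passes: 0.016 ps == 16 fs
```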

TrainerHypers.masses: dict[int, float] = {}

A dictionary mapping atomic species (atomic numbers) to their masses (in atomic mass units).

Note that FlashMD models, as implemented in metatrain, are not transferable across different isotopes. The masses are not used during training, but they are stored in the model and used during inference to check that the masses used in MD engines match those used during training. If not provided, masses from the ase.data module are used; these are averaged over the natural isotopic abundance of each element.
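Why isotopes matter can be seen from a small example with heavy water. The sketch below hard-codes approximate natural-abundance masses instead of importing ase.data.atomic_masses, and assumes that species missing from `masses` fall back to those averages individually; both helper names are hypothetical.

```python
from typing import Dict, List

# Approximate natural-abundance average masses in u, hard-coded here as a
# stand-in for ase.data.atomic_masses (indexed by atomic number).
AVERAGE_MASSES: Dict[int, float] = {1: 1.008, 6: 12.011, 8: 15.999}


def resolve_masses(user_masses: Dict[int, float], species: List[int]) -> Dict[int, float]:
    """Masses to store in the model: user values win, other species use the averages.

    The per-species fallback is an assumption made for this sketch.
    """
    return {z: user_masses.get(z, AVERAGE_MASSES[z]) for z in species}


def mismatched_species(
    model_masses: Dict[int, float], engine_masses: Dict[int, float], tol: float = 1e-3
) -> List[int]:
    """Atomic numbers whose engine mass is missing or differs from the model's."""
    return [
        z
        for z, m in model_masses.items()
        if z not in engine_masses or abs(engine_masses[z] - m) > tol
    ]


# A model trained on heavy water (D2O): deuterium replaces natural hydrogen.
model_masses = resolve_masses({1: 2.014}, species=[1, 8])
bad = mismatched_species(model_masses, {1: 1.008, 8: 15.999})  # [1]: H mass != D mass
```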

TrainerHypers.distributed: bool = False

Whether to use distributed training.

TrainerHypers.distributed_port: int = 39591

Port used for communication in distributed data-parallel (DDP) training.

TrainerHypers.batch_size: int = 16

The number of samples to use in each batch of training. This hyperparameter controls the tradeoff between training speed and memory usage. In general, larger batch sizes will lead to faster training, but might require more memory.

TrainerHypers.num_epochs: int = 1000

Number of epochs.

TrainerHypers.warmup_fraction: float = 0.01

Fraction of training steps used for learning rate warmup.
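The fraction translates into a number of optimizer steps once the dataset size and batch size are known. The sketch below assumes a linear warmup and simply holds the rate constant afterwards, since the post-warmup schedule is not described on this page; the dataset size of 8000 structures is hypothetical.

```python
def warmup_steps(num_epochs: int, steps_per_epoch: int, warmup_fraction: float) -> int:
    """Number of optimizer steps spent warming up the learning rate."""
    return round(warmup_fraction * num_epochs * steps_per_epoch)


def lr_at_step(step: int, base_lr: float, n_warmup: int) -> float:
    """Linear warmup from near zero to base_lr (assumed shape).

    After warmup the rate is simply held at base_lr here; the schedule metatrain
    applies afterwards is not described on this page.
    """
    if step >= n_warmup:
        return base_lr
    return base_lr * ((step + 1) / n_warmup)


# Defaults (1000 epochs, lr 3e-4, fraction 0.01) with a hypothetical dataset of
# 8000 structures and batch_size 16, i.e. 500 steps per epoch.
n_warmup = warmup_steps(num_epochs=1000, steps_per_epoch=8000 // 16, warmup_fraction=0.01)  # 5000
```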

TrainerHypers.learning_rate: float = 0.0003

Learning rate.

TrainerHypers.weight_decay: float | None = None

TrainerHypers.log_interval: int = 1

Interval to log metrics.

TrainerHypers.checkpoint_interval: int = 100

Interval to save checkpoints.

TrainerHypers.scale_targets: bool = True

Normalize targets to unit std during training.

TrainerHypers.fixed_composition_weights: dict[str, dict[int, float]] = {}

Weights for atomic contributions.

This is passed to the fixed_weights argument of CompositionModel.train_model; see its documentation to understand exactly what to pass here.

TrainerHypers.fixed_scaling_weights: dict[str, float | dict[int, float]] = {}

Weights for target scaling.

This is passed to the fixed_weights argument of Scaler.train_model; see its documentation to understand exactly what to pass here.

TrainerHypers.per_structure_targets: list[str] = []

Targets to calculate per-structure losses.

TrainerHypers.num_workers: int | None = None

Number of workers for data loading. If not provided, it is set automatically.

TrainerHypers.log_mae: bool = False

Log MAE alongside RMSE.

TrainerHypers.log_separate_blocks: bool = False

Log per-block error.

TrainerHypers.best_model_metric: Literal['rmse_prod', 'mae_prod', 'loss'] = 'rmse_prod'

Metric used to select the best checkpoint (e.g., rmse_prod).

TrainerHypers.grad_clip_norm: float = 1.0

Maximum gradient norm used for gradient clipping. Setting it to inf disables clipping.
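Clipping by norm rescales all gradients together rather than capping each component. The sketch below mirrors the global-norm clipping performed by PyTorch's torch.nn.utils.clip_grad_norm_ (which additionally adds a small epsilon to the norm); whether metatrain calls that exact function is an assumption, and `clip_grad_norm` here is a hypothetical pure-Python helper.

```python
import math
from typing import List


def clip_grad_norm(grads: List[float], max_norm: float) -> List[float]:
    """Rescale all gradients together so their global L2 norm is at most max_norm.

    With max_norm = inf the condition below always holds, so nothing is clipped.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads)
    scale = max_norm / total_norm
    return [g * scale for g in grads]


grads = [3.0, 4.0]  # global norm 5
clipped = clip_grad_norm(grads, max_norm=1.0)  # rescaled to norm 1, approximately [0.6, 0.8]
unclipped = clip_grad_norm(grads, max_norm=math.inf)  # unchanged
```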

TrainerHypers.loss: str | dict[str, LossSpecification | str] = 'mse'

The loss function to be used. See the Loss functions page for more details.

References