Training a FlashMD model

This tutorial shows how to train a FlashMD model for the direct prediction of molecular dynamics. FlashMD models can run MD simulations 10 to 30 times faster than machine-learned interatomic potentials (MLIPs) (https://arxiv.org/abs/2505.19350).

import copy
import subprocess

import ase
import ase.build
import ase.io
import ase.units
from ase.calculators.emt import EMT
from ase.md import VelocityVerlet
from ase.md.langevin import Langevin
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution

Data generation

FlashMD models train on molecular dynamics trajectories in the NVE ensemble (most often generated with the velocity Verlet integrator). These trajectories can be produced with almost any MD code (i-PI, LAMMPS, etc.). Here, for simplicity, we will use ASE and its built-in EMT potential. In practice, you might want to use a more accurate baseline such as ab initio MD or a machine-learned interatomic potential (MLIP).

# We start by creating a simple system (a small box of aluminum).
atoms = ase.build.bulk("Al", "fcc", cubic=True) * (2, 2, 2)

# We first equilibrate the system at 300 K using a Langevin thermostat.
MaxwellBoltzmannDistribution(atoms, temperature_K=300)
atoms.calc = EMT()
dyn = Langevin(
    atoms, 2 * ase.units.fs, temperature_K=300, friction=1 / (100 * ase.units.fs)
)
dyn.run(1000)  # 2 ps equilibration (around 10 ps is better in practice)

# Then, we run a production simulation in the NVE ensemble.
trajectory = []


def store_trajectory():
    trajectory.append(copy.deepcopy(atoms))


dyn = VelocityVerlet(atoms, 1 * ase.units.fs)
dyn.attach(store_trajectory, interval=1)
dyn.run(2000)  # 2 ps NVE run
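
As a side illustration of what the NVE integrator does, here is a minimal pure-Python sketch of the velocity Verlet scheme on a toy 1D harmonic oscillator. It is not ASE code: the function names and the parameters (`k`, `mass`, `dt`) are hypothetical and in arbitrary units. It only illustrates that the total energy stays nearly constant along an NVE trajectory.

```python
def velocity_verlet(x, p, force, mass, dt, n_steps):
    """Integrate one degree of freedom with the velocity Verlet scheme."""
    f = force(x)
    for _ in range(n_steps):
        p += 0.5 * dt * f   # first half kick
        x += dt * p / mass  # drift
        f = force(x)        # force at the new position
        p += 0.5 * dt * f   # second half kick
    return x, p


# Toy 1D harmonic oscillator (hypothetical parameters, arbitrary units).
k, mass, dt = 1.0, 1.0, 0.01


def force(x):
    return -k * x


def energy(x, p):
    return 0.5 * p * p / mass + 0.5 * k * x * x


x0, p0 = 1.0, 0.0
x1, p1 = velocity_verlet(x0, p0, force, mass, dt, n_steps=10_000)
relative_drift = abs(energy(x1, p1) - energy(x0, p0)) / energy(x0, p0)
print(f"relative energy drift after 10000 steps: {relative_drift:.2e}")
```

A similar check on the total energy of a real reference trajectory is a quick way to spot a time step that is too large.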

Data preparation

Now, we need to generate the training data from the trajectory. FlashMD models require future positions and momenta as targets. We will save them in an .xyz file under the future_positions and future_momenta keys.
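
The pairing of each starting frame with its future frame can be sketched on plain frame indices. The helper `make_pairs` below is hypothetical (it is not part of ASE or metatrain), and the values assume a trajectory of 2000 frames, a time lag of 30 steps, and starting frames 200 steps apart.

```python
def make_pairs(n_frames, time_lag, spacing):
    """Return (start, target) frame indices for every training sample."""
    return [(i, i + time_lag) for i in range(0, n_frames - time_lag, spacing)]


pairs = make_pairs(n_frames=2000, time_lag=30, spacing=200)
print(len(pairs), pairs[0], pairs[-1])  # 10 (0, 30) (1800, 1830)
```

With these assumed values, a 2 ps trajectory yields only 10 training pairs, which is why production datasets need much longer reference simulations.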

# The FlashMD model will be trained to predict 30 steps into the future, i.e., 30 fs
# since we ran the reference simulation with a time step of 1 fs. For this type
# of system, FlashMD is expected to perform well up to around 60-80 fs.
time_lag = 30

# We pick starting structures that are 200 steps apart. To avoid wasting training
# structures, the spacing should be close to the expected velocity-velocity
# autocorrelation time of the system, i.e., the time scale over which the system
# forgets its original velocities.
spacing = 200


def get_structure_for_dataset(frame_now, frame_ahead):
    s = copy.deepcopy(frame_now)
    s.arrays["future_positions"] = frame_ahead.get_positions()
    s.arrays["future_momenta"] = frame_ahead.get_momenta()
    return s


structures_for_dataset = []
for i in range(0, len(trajectory) - time_lag, spacing):
    frame_now = trajectory[i]
    frame_ahead = trajectory[i + time_lag]
    s = get_structure_for_dataset(frame_now, frame_ahead)
    structures_for_dataset.append(s)

    # Here, we also add the time-reversed pair (optional)
    frame_now_trev = copy.deepcopy(frame_now)
    frame_ahead_trev = copy.deepcopy(frame_ahead)
    frame_now_trev.set_momenta(-frame_now_trev.get_momenta())
    frame_ahead_trev.set_momenta(-frame_ahead_trev.get_momenta())
    s = get_structure_for_dataset(frame_ahead_trev, frame_now_trev)
    structures_for_dataset.append(s)

# Write the structures to an xyz file
ase.io.write("flashmd.xyz", structures_for_dataset)
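
The time-reversed pairs rely on the time-reversal symmetry of Hamiltonian dynamics: starting from the later frame with flipped momenta and propagating for the same time leads back to the earlier frame with flipped momenta. The sketch below checks this on a toy 1D anharmonic oscillator integrated with velocity Verlet in pure Python. The potential, the parameters, and the function names are hypothetical and unrelated to the aluminum system above.

```python
def force(x):
    # Toy anharmonic force derived from V(x) = x**2 / 2 + x**4 / 4.
    return -x - x**3


def propagate(x, p, n_steps, dt=0.01, mass=1.0):
    """Velocity Verlet propagation of a single degree of freedom."""
    f = force(x)
    for _ in range(n_steps):
        p += 0.5 * dt * f
        x += dt * p / mass
        f = force(x)
        p += 0.5 * dt * f
    return x, p


time_lag = 30
x_now, p_now = 0.3, 0.7

# Forward pair: (x_now, p_now) -> (x_ahead, p_ahead).
x_ahead, p_ahead = propagate(x_now, p_now, time_lag)

# Time-reversed pair: start from the later frame with flipped momentum.
x_back, p_back = propagate(x_ahead, -p_ahead, time_lag)

print(abs(x_back - x_now), abs(p_back + p_now))  # both at round-off level
```

Because the reversed pair is itself a valid trajectory segment, adding it doubles the number of training samples without extra reference MD.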

Training the model

The dataset is now ready and can be passed to metatrain to train your FlashMD model.

For example, you can use the following options file:

seed: 42

architecture:
  name: experimental.flashmd
  training:
    timestep: 30  # in this case 30 (time lag) * 1 fs (timestep of reference MD)
    batch_size: 2  # to be increased in a production scenario
    num_epochs: 5  # to be increased (at least 1000-10000) in a production scenario
    log_interval: 1
    loss:
      positions:
        type: mse
        weight: 1.0
        reduction: mean
      momenta:
        type: mse
        weight: 1.0
        reduction: mean

training_set:
  systems:
    read_from: flashmd.xyz
    length_unit: A
  targets:
    positions:
      key: future_positions
      quantity: length
      unit: A
      type:
        cartesian:
          rank: 1
      per_atom: true
    momenta:
      key: future_momenta
      quantity: momentum
      unit: (eV*u)^1/2
      type:
        cartesian:
          rank: 1
      per_atom: true

validation_set: 0.1
test_set: 0.1

With these options saved as options.yaml, training can be launched from Python:

subprocess.run(["mtt", "train", "options.yaml"])
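
A slightly more defensive way to launch the same command is sketched below. The wrapper `train` and its `dry_run` flag are hypothetical conveniences, not part of metatrain. The only assumption about the CLI is the `mtt train options.yaml` invocation already used in this tutorial.

```python
import shutil
import subprocess


def train(options_file="options.yaml", executable="mtt", dry_run=False):
    """Build the training command and, unless dry_run is set, execute it."""
    command = [executable, "train", options_file]
    if dry_run:
        return command
    if shutil.which(executable) is None:
        raise FileNotFoundError(f"{executable!r} was not found on PATH")
    # check=True raises CalledProcessError if training exits with an error.
    subprocess.run(command, check=True)
    return command


print(train(dry_run=True))  # ['mtt', 'train', 'options.yaml']
```

Using `check=True` makes a failed training run raise an exception instead of silently returning a non-zero exit code.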
