.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/6-torchsim-batched.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_6-torchsim-batched.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_6-torchsim-batched.py:

.. _torchsim-batched:

Batched simulations with TorchSim
=================================

TorchSim supports batching multiple systems into a single ``SimState`` for
efficient parallel evaluation on GPU.
:py:class:`~metatomic_torchsim.MetatomicModel` handles this transparently.

.. GENERATED FROM PYTHON SOURCE LINES 14-20

Setup
-----

We reuse the same minimal model from :ref:`torchsim-getting-started`. The model
must produce differentiable energy so that forces/stress can be computed via
autograd.

.. GENERATED FROM PYTHON SOURCE LINES 21-80

.. code-block:: Python

    from typing import Dict, List, Optional

    import ase.build
    import matplotlib.pyplot as plt
    import torch
    import torch_sim as ts
    from metatensor.torch import Labels, TensorBlock, TensorMap

    import metatomic.torch as mta
    from metatomic_torchsim import MetatomicModel


    class HarmonicEnergy(torch.nn.Module):
        """Harmonic restraint: E = k * sum(positions^2)."""

        def __init__(self, k: float = 0.1):
            super().__init__()
            self.k = k

        def forward(
            self,
            systems: List[mta.System],
            outputs: Dict[str, mta.ModelOutput],
            selected_atoms: Optional[Labels] = None,
        ) -> Dict[str, TensorMap]:
            energies: List[torch.Tensor] = []
            for system in systems:
                e = self.k * torch.sum(system.positions**2)
                energies.append(e.reshape(1, 1))

            energy = torch.cat(energies, dim=0)
            block = TensorBlock(
                values=energy,
                samples=Labels("system", torch.arange(len(systems)).reshape(-1, 1)),
                components=[],
                properties=Labels("energy", torch.tensor([[0]])),
            )
            return {
                "energy": TensorMap(keys=Labels("_", torch.tensor([[0]])), blocks=[block])
            }


    capabilities = mta.ModelCapabilities(
        length_unit="Angstrom",
        atomic_types=[13, 29],  # Al, Cu
        interaction_range=0.0,
        outputs={"energy": mta.ModelOutput(quantity="energy", unit="eV")},
        supported_devices=["cpu"],
        dtype="float64",
    )

    atomistic_model = mta.AtomisticModel(
        HarmonicEnergy(0.1).eval(), mta.ModelMetadata(), capabilities
    )
    model = MetatomicModel(atomistic_model, device="cpu")

.. GENERATED FROM PYTHON SOURCE LINES 81-85

Creating a batched state
------------------------

Pass a list of ASE ``Atoms`` objects to ``initialize_state``:

.. GENERATED FROM PYTHON SOURCE LINES 86-96

.. code-block:: Python

    atoms_list = [
        ase.build.bulk("Cu", "fcc", a=3.6, cubic=True),
        ase.build.bulk("Cu", "fcc", a=3.65, cubic=True),
        ase.build.bulk("Al", "fcc", a=4.05, cubic=True),
    ]

    sim_state = ts.initialize_state(atoms_list, device=model.device, dtype=model.dtype)
    print("Total atoms in batch:", sim_state.n_atoms)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Total atoms in batch: 12

.. GENERATED FROM PYTHON SOURCE LINES 97-101

Evaluating the batch
--------------------

A single forward call evaluates all systems:

.. GENERATED FROM PYTHON SOURCE LINES 102-109

.. code-block:: Python

    results = model(sim_state)

    print("Energy shape:", results["energy"].shape)  # [n_systems]
    print("Forces shape:", results["forces"].shape)  # [n_total_atoms, 3]
    print("Stress shape:", results["stress"].shape)  # [n_systems, 3, 3]

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Energy shape: torch.Size([3])
    Forces shape: torch.Size([12, 3])
    Stress shape: torch.Size([3, 3, 3])

.. GENERATED FROM PYTHON SOURCE LINES 110-117

The output shapes reflect the batch:

- ``results["energy"]`` has shape ``[n_systems]`` -- one energy per system
- ``results["forces"]`` has shape ``[n_total_atoms, 3]`` -- all atoms concatenated
- ``results["stress"]`` has shape ``[n_systems, 3, 3]`` -- one 3x3 tensor per system

.. GENERATED FROM PYTHON SOURCE LINES 118-121

.. code-block:: Python

    print("Per-system energies:", results["energy"])
.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Per-system energies: tensor([1.9440, 1.9984, 2.4604], dtype=torch.float64)

.. GENERATED FROM PYTHON SOURCE LINES 122-128

How ``system_idx`` works
------------------------

``SimState`` tracks which atom belongs to which system via the ``system_idx``
tensor. For three 4-atom systems, ``system_idx`` looks like:

.. GENERATED FROM PYTHON SOURCE LINES 129-132

.. code-block:: Python

    print("system_idx:", sim_state.system_idx)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    system_idx: tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])

.. GENERATED FROM PYTHON SOURCE LINES 133-143

``MetatomicModel.forward`` uses this to split the batched positions and types
into per-system ``System`` objects before calling the underlying model.

Batch consistency
-----------------

Energies computed in a batch match those computed individually. This is
guaranteed because each system gets its own neighbor list and independent
evaluation:

.. GENERATED FROM PYTHON SOURCE LINES 144-164

.. code-block:: Python

    individual_energies = []
    for atoms in atoms_list:
        state = ts.initialize_state(atoms, device=model.device, dtype=model.dtype)
        res = model(state)
        individual_energies.append(res["energy"].item())

    print("Batched:   ", [e.item() for e in results["energy"]])
    print("Individual:", individual_energies)

    plt.scatter(individual_energies, results["energy"].cpu().numpy())
    plt.plot(
        [min(individual_energies), max(individual_energies)],
        [min(individual_energies), max(individual_energies)],
        "k--",
    )
    plt.xlabel("Individual energies")
    plt.ylabel("Batched energies")
    plt.show()

.. image-sg:: /examples/images/sphx_glr_6-torchsim-batched_001.png
    :alt: 6 torchsim batched
    :srcset: /examples/images/sphx_glr_6-torchsim-batched_001.png
    :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Batched:    [1.9440000000000002, 1.9983750000000002, 2.460375]
    Individual: [1.9440000000000002, 1.9983750000000002, 2.460375]
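The ``system_idx`` bookkeeping is also useful downstream of the model, for
example to recover per-system blocks from the concatenated
``results["forces"]`` tensor. A minimal sketch of the indexing pattern, using
stand-in tensors with the same shapes as in this example (the force values
below are placeholders, not the model's output):

.. code-block:: Python

    import torch

    # Stand-ins matching this example's shapes: 3 systems of 4 atoms each.
    # The values are placeholders, not the model's actual output.
    forces = torch.arange(36, dtype=torch.float64).reshape(12, 3)
    system_idx = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])

    # Boolean masking recovers the rows belonging to each system
    per_system = [forces[system_idx == i] for i in range(3)]
    print([tuple(f.shape) for f in per_system])  # [(4, 3), (4, 3), (4, 3)]

    # Equivalently, split by per-system atom counts
    counts = torch.bincount(system_idx).tolist()  # [4, 4, 4]
    chunks = torch.split(forces, counts)

Both forms give the same blocks because atoms of one system are stored
contiguously in the batched tensors.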
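Because batched and individual evaluations agree, a large set of systems can
also be processed in fixed-size chunks to bound peak memory. A minimal,
self-contained sketch of such a chunking loop; the helper name and the
placeholder systems/evaluator are hypothetical, with the real calls shown in
the comments:

.. code-block:: Python

    def evaluate_in_chunks(systems, evaluate, chunk_size=2):
        """Evaluate ``systems`` chunk by chunk, concatenating the results."""
        energies = []
        for start in range(0, len(systems), chunk_size):
            chunk = systems[start : start + chunk_size]
            # With the objects defined above this would be:
            #   state = ts.initialize_state(chunk, device=model.device, dtype=model.dtype)
            #   energies.extend(model(state)["energy"].tolist())
            energies.extend(evaluate(chunk))
        return energies

    # Placeholder systems and evaluator, to show that order is preserved
    print(evaluate_in_chunks(list(range(5)), lambda chunk: [2.0 * x for x in chunk]))
    # [0.0, 2.0, 4.0, 6.0, 8.0]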
.. GENERATED FROM PYTHON SOURCE LINES 165-176

Performance considerations
--------------------------

Batching is most beneficial on GPU, where the neighbor list computation and
model forward pass can run in parallel across systems. On CPU, the speedup
comes from reduced Python overhead (one call instead of N).

For very large systems or many small ones, adjust the batch size to fit in
GPU memory. TorchSim does not impose a maximum batch size, but each system
gets its own neighbor list, so memory scales with the sum of per-system
sizes.

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 0.134 seconds)

.. _sphx_glr_download_examples_6-torchsim-batched.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: 6-torchsim-batched.ipynb <6-torchsim-batched.ipynb>`

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: 6-torchsim-batched.py <6-torchsim-batched.py>`

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: 6-torchsim-batched.zip <6-torchsim-batched.zip>`

.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_