import warnings
from typing import List, Optional, Union
import numpy as np
import torch
from . import System
# ``ase`` is an optional dependency: record whether it could be imported so
# that code needing it can raise a clear error when it is actually used,
# instead of this module failing at import time.
try:
    import ase
except ImportError:
    HAS_ASE = False
else:
    HAS_ASE = True
class IntoSystem:
    """Marker type for objects convertible into a
    :py:class:`metatensor.torch.atomistic.System`.

    This class carries no behavior of its own; it is an abstract placeholder
    used in type annotations to designate "anything that can be turned into a
    :py:class:`System`". At the moment, :py:class:`ase.Atoms` is the only
    supported type.
    """
# Public API
def systems_to_torch(
    systems: Union[IntoSystem, List[IntoSystem]],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    positions_requires_grad: bool = False,
    cell_requires_grad: bool = False,
) -> Union[System, List[System]]:
    """Convert one system, or a list of systems, into
    ``metatensor.torch.atomistic.System`` object(s).

    A ``list`` input produces a list with one converted system per element, in
    the same order. Any other input is treated as a single system and a single
    converted system is returned.

    :param systems: The system or list of systems to convert.
    :param dtype: The dtype of the output tensors. If ``None``, the default
        dtype is used.
    :param device: The device of the output tensors. If ``None``, the default
        device is used.
    :param positions_requires_grad: Whether the positions tensors of the
        outputs should require gradients.
    :param cell_requires_grad: Whether the cell tensors of the outputs should
        require gradients.
    :return: The converted system or list of systems.
    """
    # The same conversion options are forwarded unchanged for every system.
    options = (dtype, device, positions_requires_grad, cell_requires_grad)

    if not isinstance(systems, list):
        return _system_to_torch(systems, *options)

    return [_system_to_torch(entry, *options) for entry in systems]
def _system_to_torch(
    system: IntoSystem,
    dtype: Optional[torch.dtype],
    device: Optional[torch.device],
    positions_requires_grad: bool,
    cell_requires_grad: bool,
) -> System:
    """Convert a single system into a ``metatensor.torch.atomistic.System``.

    :param system: The system to convert. Only ``ase.Atoms`` objects are
        accepted.
    :param dtype: The dtype of the output tensors. If ``None``, the default
        torch dtype is used.
    :param device: The device of the output tensors. If ``None``, the default
        device is used.
    :param positions_requires_grad: Whether the positions tensor of the output
        should require gradients.
    :param cell_requires_grad: Whether the cell tensor of the output should
        require gradients.
    :return: The converted system.
    :raises RuntimeError: If the ``ase`` package could not be imported.
    :raises ValueError: If ``system`` is not an ``ase.Atoms`` instance.
    """
    if not HAS_ASE:
        raise RuntimeError("The `ase` package is required to convert systems to torch.")

    if not isinstance(system, ase.Atoms):
        raise ValueError(
            "Only `ase.Atoms` objects can be converted to `System`s "
            f"for now; got {type(system)}."
        )

    if dtype is None:
        # this is necessary because creating torch tensors from numpy arrays
        # takes the dtype from the numpy array, which is not always the default
        # dtype
        dtype = torch.get_default_dtype()

    positions = torch.tensor(
        system.positions,
        requires_grad=positions_requires_grad,
        dtype=dtype,
        device=device,
    )

    # One boolean per cell vector (row): True when any component is non-zero.
    cell_vectors_are_not_zero = np.any(system.cell != 0, axis=1)
    # NOTE(review): this check fires on *any* mismatch between non-zero cell
    # vectors and `pbc`, including a periodic direction whose cell vector is
    # zero, while the message only describes the opposite situation.
    if not np.all(cell_vectors_are_not_zero == system.pbc):
        warnings.warn(
            "A conversion to `System` was requested for an `ase.Atoms` object "
            "with one or more non-zero cell vectors but where the corresponding "
            "boundary conditions are set to `False`. "
            "The corresponding cell vectors will be set to zero.",
            stacklevel=3,
        )

    # Start from an all-zero cell and copy over only the vectors flagged as
    # periodic; the other rows stay zero.
    cell = torch.zeros((3, 3), dtype=dtype, device=device)
    pbc = torch.tensor(system.pbc, dtype=torch.bool, device=device)
    cell[pbc] = torch.tensor(system.cell[system.pbc], dtype=dtype, device=device)
    # Honor `cell_requires_grad`. Gradient tracking is enabled only after the
    # in-place fill above, since in-place writes to a leaf tensor that already
    # requires gradients are rejected by autograd.
    cell.requires_grad_(cell_requires_grad)

    types = torch.tensor(system.numbers, device=device, dtype=torch.int32)

    return System(positions=positions, cell=cell, types=types, pbc=pbc)