Source code for metatensor.torch.atomistic.io

import io
import os
import warnings
import zipfile
from pathlib import Path
from typing import Union

import numpy as np
import torch

import metatensor.torch

from . import NeighborListOptions, System


[docs] def save(file: Union[str, Path, io.BytesIO], system: System) -> None: """Save a System object to a file. The provided System must contain float64 data and be on the CPU device. The saved file will be a zip archive containing the following files: - ``types.npy``, containing the atomic types in numpy's NPY format; - ``positions.npy``, containing the systems' positions in numpy's NPY format; - ``cell.npy``, containing the systems' cell in numpy's NPY format; - ``pbc.npy``, containing the periodic boundary conditions in numpy's NPY format; For each neighbor list in the System object, the following files will be saved (where ``{nl_idx}`` is the index of the neighbor list): - ``pairs/{nl_idx}/options.json``: the ``NeighborListOptions`` object converted to a JSON string. - ``pairs/{nl_idx}/data.mts``: the neighbor list ``TensorBlock`` object For each extra data in the System object, the following file will be saved (where ``{name}`` is the name of the extra data): - ``data/{name}.mts``: The extra data ``TensorMap`` :param file: The path (or file-like object) to save the System to. :param system: The System object to save. """ if isinstance(file, (str, Path)): if not file.endswith(".mta"): raise ValueError("The provided path must have the `.mta` extension.") if _is_system(system): _save_system(file, system) else: raise ValueError("`system` must be a System object.")
[docs] def load_system(file: Union[str, Path, io.BytesIO]) -> System: """Load a System object from a file. The loaded System object will be on the CPU device and contain float64 data. :param file: The path (or file-like object) to load the System object from. :return: The System object. """ if isinstance(file, (str, Path)): if not os.path.exists(file): raise FileNotFoundError(f"File not found: {file}") if not file.endswith(".mta"): raise ValueError("The provided path must have the `.mta` extension.") if _is_system_mta(file): return _load_system(file) else: raise ValueError(f"File does not contain a valid System object: {file}")
def _is_system_mta(path: Union[str, Path, io.BytesIO]) -> bool: zipf = zipfile.ZipFile(path, "r") all_zip_files = zipf.namelist() required_files = ["positions.npy", "cell.npy", "types.npy", "pbc.npy"] for file in required_files: if file not in all_zip_files: return False return True def _load_system(path: Union[str, Path, io.BytesIO]) -> System: # we filter a warning related to the fact that numpy arrays from buffers # are not writable, while torch would like arrays to be writable when # converting them to tensors; this is ok because we then clone the tensor warnings.filterwarnings( "ignore", category=UserWarning, message="The given NumPy array is not writable" ) with zipfile.ZipFile(path, "r") as zipf: positions = torch.from_numpy(np.load(zipf.open("positions.npy"))) cell = torch.from_numpy(np.load(zipf.open("cell.npy"))) types = torch.from_numpy(np.load(zipf.open("types.npy"))) pbc = torch.from_numpy(np.load(zipf.open("pbc.npy"))) neighbor_list_options_list = [] neighbor_lists = [] for path in zipf.namelist(): if path.startswith("pairs/") and path.endswith("/options.json"): nl_options = NeighborListOptions(0.0, False, False) nl_options._get_method("__setstate__")(zipf.read(path)) neighbor_list_options_list.append(nl_options) data_path = path[:-12] + "data.mts" numpy_buffer = np.frombuffer(zipf.read(data_path), dtype=np.uint8) tensor_buffer = torch.from_numpy(numpy_buffer) neighbor_lists.append(metatensor.torch.load_block_buffer(tensor_buffer)) extra_data_dict = {} for path in zipf.namelist(): if path.startswith("data/"): name = os.path.basename(path).replace(".mts", "") numpy_buffer = np.frombuffer(zipf.read(path), dtype=np.uint8) tensor_buffer = torch.from_numpy(numpy_buffer) extra_data_dict[name] = metatensor.torch.load_buffer(tensor_buffer) system = System( positions=positions, cell=cell, types=types, pbc=pbc, ) for options, neighbors in zip(neighbor_list_options_list, neighbor_lists): system.add_neighbor_list(options, neighbors) for key, value in extra_data_dict.items(): system.add_data(key, value) return system def _is_system(data: torch.ScriptObject) -> bool: if not isinstance(data, torch.ScriptObject): return False try: data.positions data.cell data.types data.pbc return True except AttributeError: return False def _save_system(path: Union[str, Path], system: System) -> None: with zipfile.ZipFile(path, "w") as zipf: for nl_idx, nl_options in enumerate(system.known_neighbor_lists()): zipf.writestr(f"pairs/{nl_idx}/options.json", nl_options.__getstate__()[0]) nl = system.get_neighbor_list(nl_options) tensor_buffer = metatensor.torch.save_buffer(nl) zipf.writestr(f"pairs/{nl_idx}/data.mts", tensor_buffer.numpy().tobytes()) for key in system.known_data(): data = system.get_data(key) tensor_buffer = metatensor.torch.save_buffer(data) zipf.writestr(f"data/{key}.mts", tensor_buffer.numpy().tobytes()) for tensor_name in ["positions", "cell", "types", "pbc"]: tensor = getattr(system, tensor_name) numpy_array = tensor.numpy() with zipf.open(f"{tensor_name}.npy", "w") as fd: np.save(fd, numpy_array)