Source code for metatensor._data._extract

import ctypes
from typing import Any, NewType, Union

import numpy as np

from .._c_api import (
    DLDevice,
    DLDeviceType,
    DLManagedTensorVersioned,
    DLPackVersion,
    c_uintptr_t,
    mts_array_t,
    mts_data_origin_t,
    mts_status_t,
)
from .._status import MetatensorError, check_status
from ._array import (
    _KNOWN_ARRAY_WRAPPERS,
    _origin_numpy,
    _origin_pytorch,
    _register_origin,
)
from ._dlpack import DLPackArray, wrap_versioned_as_unversioned


# PyTorch is an optional dependency: record whether it can be imported so the
# rest of the module can adapt.
try:
    import torch

    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False


if HAS_TORCH:
    # This NewType is only used for typechecking and documentation purposes. If you are
    # trying to add support for new array types, see `data.array.ArrayWrapper` instead.
    Array = NewType("Array", Union[np.ndarray, torch.Tensor])
else:
    # torch could not be imported, so only numpy arrays are part of the type
    Array = NewType("Array", np.ndarray)

Array.__doc__ = """
An ``Array`` contains the actual data stored in a :py:class:`metatensor.TensorBlock`.

This data is manipulated by ``metatensor`` in a completely opaque way: this library does
not know what's inside the arrays apart from a small set of constraints:

- the array contains numeric data (:py:func:`metatensor.load` and
  :py:func:`metatensor.save` additionally require contiguous arrays of 64-bit IEEE-754
  floating point numbers);
- they are stored as row-major, n-dimensional arrays with at least 2 dimensions;
- it is possible to create new arrays and move data from one array to another.

The actual type of an ``Array`` depends on how the :py:class:`metatensor.TensorBlock`
was created. Currently, :py:class:`numpy.ndarray` and :py:class:`torch.Tensor` are
supported.
"""


# Registry of wrapper classes for data that was not allocated by Python. Keys are
# the values returned by `_register_origin` for an origin name, values are the
# wrapper classes passed to `register_external_data_wrapper`.
_ADDITIONAL_ORIGINS = {}


def register_external_data_wrapper(origin, klass):
    """
    Register a non-Python data origin and the corresponding class wrapper.

    The wrapper class constructor must take two arguments (raw ``mts_array`` and
    python ``parent`` object) and return a subclass of either
    :py:class:`numpy.ndarray` or :py:class:`torch.Tensor`, which keeps ``parent``
    alive.

    The :py:class:`metatensor.data.ExternalCpuArray` class should provide the right
    behavior for data living in CPU memory, and can serve as an example for more
    advanced custom arrays.

    :param origin: data origin name as a string, corresponding to the output of
        :c:func:`mts_array_t.origin`
    :param klass: wrapper class to use for this origin

    :raises ValueError: if ``origin`` is not a string
    """
    if not isinstance(origin, str):
        raise ValueError(f"origin must be a string, got {type(origin)}")

    # The registry is only mutated, never rebound, so no `global` statement is
    # needed. Registering the same origin twice replaces the previous wrapper.
    origin_id = _register_origin(origin)
    _ADDITIONAL_ORIGINS[origin_id] = klass
def mts_array_to_python_array(mts_array, parent=None):
    """
    Convert a raw mts_array to a Python ``Array``.

    Either the underlying array was allocated by Python, and the Python object is
    directly returned; or the underlying array was not allocated by Python, and
    additional origins are searched for a suitable Python wrapper class.

    :param mts_array: raw ``mts_array_t`` to convert
    :param parent: Python object forwarded as ``parent`` to the wrapper class when
        the data comes from a registered non-Python origin; not used otherwise
    :raises ValueError: if the origin is neither a Python origin nor one of the
        registered additional origins
    """
    origin = data_origin(mts_array)
    if _is_python_origin(origin):
        # look up the existing Python-side wrapper by the raw array pointer and
        # hand back the array it holds
        return _KNOWN_ARRAY_WRAPPERS[mts_array.ptr].array
    elif origin in _ADDITIONAL_ORIGINS:
        # build a new Python object with the wrapper class registered for this origin
        return _ADDITIONAL_ORIGINS[origin](mts_array, parent=parent)
    else:
        raise ValueError(
            f"unable to handle data coming from '{data_origin_name(origin)}', "
            "you should maybe register a new array wrapper with metatensor"
        )


def mts_array_was_allocated_by_python(mts_array):
    """Check if a given mts_array was allocated by Python"""
    return _is_python_origin(data_origin(mts_array))


def _is_python_origin(origin):
    """Check if ``origin`` is the numpy origin or the PyTorch origin."""
    return origin in [_origin_numpy(), _origin_pytorch()]


def data_origin(mts_array):
    """
    Get the data origin of an mts_array

    Calls the ``origin`` callback of ``mts_array``, passes the resulting status to
    ``check_status``, and returns the value of the filled ``mts_data_origin_t``.
    """
    origin = mts_data_origin_t()
    status = mts_array.origin(mts_array.ptr, origin)
    check_status(status)
    return origin.value


def data_origin_name(origin):
    """
    Get the name of the data origin of an mts_array

    :param origin: origin value, as returned by :py:func:`data_origin`
    :return: the name written by ``mts_get_data_origin``, decoded as UTF-8
    """
    # imported here rather than at the top of the module
    # NOTE(review): presumably to delay loading the shared library (or avoid an
    # import cycle) -- confirm
    from .._c_lib import _get_library

    lib = _get_library()

    return _call_with_growing_buffer(
        lambda buffer, bufflen: lib.mts_get_data_origin(origin, buffer, bufflen)
    )


# ============================================================================ #


class ExternalCpuArray(np.ndarray):
    """
    Small wrapper class around ``np.ndarray``, adding a reference to a parent Python
    object that actually owns the memory used inside the array. This prevents the
    parent from being garbage collected while the ndarray is still alive, thus
    preventing use-after-free.

    This is intended to be used to wrap Rust-owned memory inside a numpy's ndarray,
    while making sure the memory owner is kept around for long enough.
    """

    def __new__(cls, mts_array: mts_array_t, parent: Any):
        """
        :param mts_array: raw array to wrap in a Python-compatible class
        :param parent: owner of the raw array, we will keep a reference to this
                       python object
        """
        shape_ptr = ctypes.POINTER(c_uintptr_t)()
        shape_count = c_uintptr_t()
        status = mts_array.shape(mts_array.ptr, shape_ptr, shape_count)
        check_status(status)

        # NOTE(review): `shape` is not read again in this method, the array below
        # is built from the DLPack tensor only -- confirm whether this query is
        # still needed
        shape = []
        for i in range(shape_count.value):
            shape.append(shape_ptr[i])

        # Use as_dlpack to get data pointer and dtype
        dl_managed_ptr = ctypes.POINTER(DLManagedTensorVersioned)()
        device = DLDevice(device_type=DLDeviceType.kDLCPU, device_id=0)
        version = DLPackVersion(major=1, minor=0)
        status = mts_array.as_dlpack(
            mts_array.ptr,
            ctypes.byref(dl_managed_ptr),
            device,
            None,
            version,
        )
        check_status(status)

        array = np.from_dlpack(DLPackArray(dl_managed_ptr))
        # re-interpret the imported array as an instance of this class
        obj = array.view(cls)

        # keep a reference to the parent object (if any) to prevent it from
        # being garbage-collected too early.
        obj._parent = parent
        # prevent the DLPack tensor from being freed while we hold a view
        obj._dl_managed_ptr = dl_managed_ptr

        return obj

    def __array_finalize__(self, obj):
        """
        Copy the ``_parent`` and ``_dl_managed_ptr`` references from ``obj`` (or
        ``None`` when ``obj`` does not have them) onto the new array.
        """
        # keep the parent around when creating sub-views of this array
        self._parent = getattr(obj, "_parent", None)
        self._dl_managed_ptr = getattr(obj, "_dl_managed_ptr", None)

    def __array_wrap__(self, new, context=None, return_scalar=False):
        """
        Decide which type ``new`` should have: an ``ExternalCpuArray`` if its data
        pointer falls inside the memory range of ``self``, a plain array otherwise.
        """
        self_ptr = self.ctypes.data
        self_size = self.nbytes

        new_ptr = new.ctypes.data

        # NOTE(review): the upper bound is inclusive, so a data pointer located
        # exactly at the end of self's buffer also matches; `context` and
        # `return_scalar` are accepted but not forwarded -- confirm both are
        # intended
        if self_ptr <= new_ptr <= self_ptr + self_size:
            # if the new array is a view inside memory owned by self, wrap it in
            # a ExternalCpuArray
            return super().__array_wrap__(new)
        else:
            # return the ndarray straight away
            return np.asarray(new)


class ExternalCudaArray:
    """
    Factory that wraps non-Python data on a CUDA device as a ``torch.Tensor`` via
    DLPack, keeping a reference to a parent Python object to prevent
    use-after-free.

    This is the CUDA counterpart to :py:class:`ExternalCpuArray`, intended for data
    that lives in CUDA memory. Requires PyTorch.

    For CUDA data (``device_type=2``), we go through CuPy when available for true
    zero-copy import, then convert to a ``torch.Tensor``. If CuPy is not installed
    we fall back to ``torch.from_dlpack`` directly.
    """

    def __new__(
        cls,
        mts_array: mts_array_t,
        parent: Any,
        *,
        device_type: int = 2,
        device_id: int = 0,
    ):
        """
        :param mts_array: raw array to wrap in a Python-compatible class
        :param parent: owner of the raw array, we will keep a reference to this
                       python object
        :param device_type: DLPack device type (default: 2 = kDLCUDA)
        :param device_id: device index (default: 0)

        :return: the imported tensor (not an instance of ``ExternalCudaArray``),
            with ``parent`` stored in its ``_parent`` attribute
        :raises ImportError: if PyTorch can not be imported
        """
        try:
            import torch  # noqa: F811
        except ImportError as e:
            raise ImportError(
                "ExternalCudaArray requires PyTorch; "
                "install it with `pip install torch`"
            ) from e

        # request a versioned (1.0) DLPack tensor for the requested device
        dl_managed_ptr = ctypes.POINTER(DLManagedTensorVersioned)()
        device = DLDevice(device_type=device_type, device_id=device_id)
        version = DLPackVersion(major=1, minor=0)
        status = mts_array.as_dlpack(
            mts_array.ptr,
            ctypes.byref(dl_managed_ptr),
            device,
            None,
            version,
        )
        check_status(status)

        dlpack_array = DLPackArray(dl_managed_ptr)

        # For CUDA data, prefer CuPy for true zero-copy import of external
        # GPU memory, then convert to a torch.Tensor (also zero-copy via
        # __cuda_array_interface__).
        tensor = None
        if device_type == 2:  # kDLCUDA
            try:
                import cupy

                cupy_array = cupy.from_dlpack(dlpack_array)
                tensor = torch.as_tensor(cupy_array)
            except ImportError:
                # CuPy is optional: leave `tensor` unset and use the torch path below
                pass

        if tensor is None:
            try:
                tensor = torch.from_dlpack(dlpack_array)
            except RuntimeError:
                # Older PyTorch (< 2.4) doesn't understand versioned
                # DLPack capsules. Fall back to an unversioned capsule.
                unversioned = wrap_versioned_as_unversioned(dl_managed_ptr)
                tensor = torch.from_dlpack(DLPackArray(unversioned))

        # keep a reference to the parent object to prevent it from being
        # garbage-collected while the tensor is alive
        tensor._parent = parent

        return tensor


def _call_with_growing_buffer(callback, initial=1024):
    """
    Call ``callback(buffer, bufflen)`` with a ctypes string buffer, doubling the
    buffer size and retrying for as long as the callback raises a
    ``MetatensorError`` with the ``MTS_BUFFER_SIZE_ERROR`` status. Any other
    error is re-raised.

    :param callback: function taking the buffer and its length, and writing a
        string into the buffer
    :param initial: size in bytes of the first buffer to try
    :return: content of the buffer (up to the first NUL byte), decoded as UTF-8
    """
    # NOTE(review): `initial` has to be positive, doubling 0 never grows the buffer
    bufflen = initial

    while True:
        buffer = ctypes.create_string_buffer(bufflen)
        try:
            callback(buffer, bufflen)
            break
        except MetatensorError as e:
            if e.status == mts_status_t.MTS_BUFFER_SIZE_ERROR:
                # grow the buffer and retry
                bufflen *= 2
            else:
                raise

    return buffer.value.decode("utf8")