Source code for metatensor.io._labels

import ctypes
import pathlib
import warnings
from typing import BinaryIO, Union

import numpy as np

from .._c_lib import _get_library
from .._labels import Labels
from .._status import check_pointer
from ._utils import _save_buffer_raw


def load_labels(file: Union[str, pathlib.Path, BinaryIO]) -> Labels:
    """
    Read back :py:class:`Labels` that were previously saved to ``file``.

    :param file: where to read the data from. This is either the path to the
        file, given as a string or a :py:class:`pathlib.Path`, or a file-like
        object opened in binary mode.
    """
    if not isinstance(file, (str, pathlib.Path)):
        # anything that is not a path is assumed to be a file-like object:
        # read everything and go through the in-memory loader
        content = file.read()
        assert isinstance(content, bytes)
        return load_labels_buffer(content)

    lib = _get_library()

    # the path is handed to the library as bytes
    path = file.encode("utf8") if isinstance(file, str) else bytes(file)

    ptr = lib.mts_labels_load(path)
    check_pointer(ptr)

    return Labels._from_mts_labels_t(ptr)
def load_labels_buffer(buffer: Union[bytes, bytearray, memoryview]) -> Labels:
    """
    Read back :py:class:`Labels` that were previously saved to memory.

    :param buffer: in-memory buffer holding the saved :py:class:`Labels`
    """
    lib = _get_library()

    # view the buffer as an array of bytes, so that both a pointer to the
    # data and its size in bytes can be passed to the library
    raw = np.frombuffer(buffer, dtype=np.uint8)
    data_ptr = raw.ctypes.data_as(ctypes.c_char_p)

    ptr = lib.mts_labels_load_buffer(data_ptr, raw.nbytes)
    check_pointer(ptr)

    return Labels._from_mts_labels_t(ptr)
def _save_labels(
    file: Union[str, pathlib.Path, BinaryIO],
    labels: Labels,
):
    """
    Write :py:class:`Labels` to the given file.

    When ``file`` is a path (string or :py:class:`pathlib.Path`) that does not
    end with ``.mts``, the extension is appended and a warning is emitted.

    :param file: where to write the data. This is either the path to the
        file, given as a string or a :py:class:`pathlib.Path`, or a file-like
        object that should be opened in binary mode.
    :param labels: Labels to save
    """
    assert isinstance(labels, Labels)

    lib = _get_library()

    if isinstance(file, str):
        if not file.endswith(".mts"):
            file = file + ".mts"
            warnings.warn(
                message=f"adding '.mts' extension, the file will be saved at '{file}'",
                stacklevel=1,
            )

        lib.mts_labels_save(file.encode("utf8"), labels._ptr)
        return

    if isinstance(file, pathlib.Path):
        filename = file.name
        if not filename.endswith(".mts"):
            file = file.with_name(filename + ".mts")
            warnings.warn(
                message="adding '.mts' extension,"
                f" the file will be saved at '{file.name}'",
                stacklevel=1,
            )

        lib.mts_labels_save(bytes(file), labels._ptr)
        return

    # neither a string nor a Path: assume we have a file-like object, and
    # write the serialized bytes to it
    file.write(_save_labels_buffer_raw(labels).raw)


def _save_labels_buffer_raw(labels: Labels) -> ctypes.Array:
    """
    Serialize ``labels`` to memory, and return the resulting data as a ctypes
    array of ``ctypes.c_char``.
    """
    save_function = _get_library().mts_labels_save_buffer
    return _save_buffer_raw(save_function, labels._ptr)


def _labels_from_mts(data):
    """
    Create :py:class:`Labels` from ``data``, an array with a structured dtype.

    The field names of the dtype are used as the names of the Labels. The
    values are ``data`` reinterpreted as ``int32`` and reshaped to have one
    column per name.
    """
    names = data.dtype.names
    values = data.view(dtype=np.int32).reshape(-1, len(names))
    return Labels(names=names, values=values)


def _labels_to_mts(labels):
    """
    Convert ``labels`` to a one-dimensional array with a structured dtype.

    The dtype contains one ``int32`` field per name in ``labels.names``, and
    the array contains one entry per row of ``labels.values``. The returned
    array is a copy.
    """
    fields = [(name, np.int32) for name in labels.names]
    values = np.asarray(labels.values)

    # reinterpret every row of values as a single structured entry
    structured = values.view(dtype=fields)
    return structured.reshape((values.shape[0],)).copy()