Working with arbitrary types and devices¶
Metatensor uses the DLPack standard for
data interchange. This means that data stored inside a TensorMap is not
restricted to 64-bit floating point numbers on CPU – it can be any numeric
type on any device, and can cross language boundaries without copies.
This tutorial shows how this works in practice: storing integer, float16, and complex data; moving between numpy and torch; and round-tripping through metatensor’s Rust serialization layer without losing type information.
Let’s start with the necessary imports.
import os
import pathlib
import tempfile
import numpy as np
import metatensor as mts
from metatensor import Labels, TensorBlock, TensorMap
Storing any numeric type¶
Metatensor can store any numeric type inside a TensorBlock — integers,
half-precision floats, booleans, or complex numbers — and preserves the type
faithfully through serialization and cross-language boundaries.
int32_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
block_i32 = TensorBlock(
    values=int32_data,
    samples=Labels(["sample"], np.array([[0], [1]], dtype=np.int32)),
    components=[],
    properties=Labels(["property"], np.array([[0], [1], [2]], dtype=np.int32)),
)
print("int32 block dtype:", block_i32.values.dtype) # int32
print("int32 block values:\n", block_i32.values)
int32 block dtype: int32
int32 block values:
[[1 2 3]
[4 5 6]]
The same works for float16 (half precision), which is common in ML inference:
f16_data = np.array([[0.5, 1.0], [1.5, 2.0]], dtype=np.float16)
block_f16 = TensorBlock(
    values=f16_data,
    samples=Labels.range("sample", 2),
    components=[],
    properties=Labels.range("property", 2),
)
print("float16 block dtype:", block_f16.values.dtype) # float16
print("float16 block values:\n", block_f16.values)
float16 block dtype: float16
float16 block values:
[[0.5 1. ]
[1.5 2. ]]
And for complex numbers, used in quantum chemistry and signal processing:
complex_data = np.array([[1 + 2j, 3 + 4j]], dtype=np.complex128)
block_complex = TensorBlock(
    values=complex_data,
    samples=Labels.range("sample", 1),
    components=[],
    properties=Labels.range("property", 2),
)
print("complex128 block dtype:", block_complex.values.dtype) # complex128
print("complex128 block values:\n", block_complex.values)
complex128 block dtype: complex128
complex128 block values:
[[1.+2.j 3.+4.j]]
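To see why keeping the original dtype matters, here is a minimal NumPy-only sketch of what would be lost if every array were forced into 64-bit floats. It does not use metatensor, and the variable names are purely illustrative.

```python
import numpy as np

# An int64 value just above 2**53 has no exact float64 representation,
# so a detour through float64 silently changes it.
big = np.array([2**53 + 1], dtype=np.int64)
through_float = big.astype(np.float64).astype(np.int64)
int_changed = bool(through_float[0] != big[0])

# The same values stored as float64 take four times the memory of float16.
half = np.zeros(1000, dtype=np.float16)
memory_ratio = half.astype(np.float64).nbytes // half.nbytes

# A real-valued container can only keep the real part of complex data.
z = np.array([1 + 2j], dtype=np.complex128)
imag_dropped = bool(z.real[0] != z[0])

print(int_changed, memory_ratio, imag_dropped)
```

Storing the array with its native dtype, as the blocks above do, avoids all three problems.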
Serialization preserves types¶
When you save a TensorMap to disk, the dtype of every block’s values is
preserved exactly. The path through the code is:
Python array → (DLPack) → C API → Rust → .npy inside a .mts file
and on loading, the reverse. At every boundary crossing, DLPack carries the type information, so nothing is lost.
keys = Labels(["type"], np.array([[0]], dtype=np.int32))
tensor_i32 = TensorMap(keys, [block_i32.copy()])
# use_numpy=False means: Python -> DLPack -> C API -> Rust -> .npy serialization
mts.save("int32_tensor.mts", tensor_i32, use_numpy=False)
# use_numpy=True means: .npy deserialization -> numpy (preserves dtype natively)
loaded = mts.load("int32_tensor.mts", use_numpy=True)
print("Saved dtype: ", int32_data.dtype)
print("Loaded dtype:", loaded.block(0).values.dtype)
print("Values match:", np.array_equal(loaded.block(0).values, int32_data))
Saved dtype: int32
Loaded dtype: int32
Values match: True
This works for every supported type. Here is a quick round-trip test for several dtypes:
dtypes = [
np.float16,
np.float32,
np.float64,
np.int8,
np.int32,
np.int64,
np.uint8,
np.bool_,
np.complex64,
np.complex128,
]
with tempfile.TemporaryDirectory() as tmp:
    for dtype in dtypes:
        data = np.arange(6).reshape(2, 3).astype(dtype)
        block = TensorBlock(
            values=data,
            samples=Labels.range("s", 2),
            components=[],
            properties=Labels.range("p", 3),
        )
        tensor = TensorMap(keys, [block])
        path = os.path.join(tmp, f"{dtype.__name__}.mts")
        mts.save(path, tensor, use_numpy=False)
        loaded = mts.load(path, use_numpy=True)
        assert loaded.block(0).values.dtype == dtype
        print(f" {dtype.__name__:>12s}: round-trip OK")
      float16: round-trip OK
      float32: round-trip OK
      float64: round-trip OK
         int8: round-trip OK
        int32: round-trip OK
        int64: round-trip OK
        uint8: round-trip OK
         bool: round-trip OK
    complex64: round-trip OK
   complex128: round-trip OK
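The last hop of the save path, the .npy format, records the dtype and shape in a small header in front of the raw bytes. The sketch below uses plain NumPy rather than metatensor's own writer, and `npy_header` is a hypothetical helper name. It writes an array to an in-memory .npy file and reads back only the header.

```python
import io

import numpy as np
from numpy.lib import format as npy_format


def npy_header(array):
    """Write `array` as .npy in memory and return (shape, dtype) from the header."""
    buffer = io.BytesIO()
    np.save(buffer, array)
    buffer.seek(0)

    # The magic string is followed by the format version, then the header dict.
    version = npy_format.read_magic(buffer)
    if version == (1, 0):
        shape, _fortran_order, dtype = npy_format.read_array_header_1_0(buffer)
    else:
        shape, _fortran_order, dtype = npy_format.read_array_header_2_0(buffer)
    return shape, dtype


for dtype in (np.float16, np.int32, np.bool_, np.complex64):
    shape, stored = npy_header(np.zeros((2, 3), dtype=dtype))
    print(shape, stored)
```

Because the dtype is part of the file itself, a reader does not need to guess or be told what type to expect.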
Numpy and Torch interoperability¶
If PyTorch is available, you can also create blocks from torch tensors. The DLPack layer handles the conversion transparently – the data is shared, not copied.
try:
    import torch
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False
if HAS_TORCH:
    # Create a torch tensor with float32 dtype
    torch_data = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)
    block_torch = TensorBlock(
        values=torch_data,
        samples=Labels.range("sample", 2),
        components=[],
        properties=Labels.range("property", 2),
    )
    print("Torch block dtype:", block_torch.values.dtype)
    print("Torch block values:\n", block_torch.values)
    print("Is a torch.Tensor:", isinstance(block_torch.values, torch.Tensor))
Torch block dtype: torch.float32
Torch block values:
tensor([[1., 2.],
[3., 4.]])
Is a torch.Tensor: True
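The "shared, not copied" behaviour can be observed with DLPack alone, independently of metatensor. The following is a minimal sketch using `np.from_dlpack` (available in NumPy 1.22 and later) on a NumPy array; any object implementing `__dlpack__`, such as a CPU torch tensor, could be passed in the same way.

```python
import numpy as np

source = np.arange(6, dtype=np.float32).reshape(2, 3)

# Import the array through the DLPack protocol instead of copying it.
view = np.from_dlpack(source)

# Both arrays point at the same buffer, and the dtype came along unchanged.
shares = np.shares_memory(source, view)

# A write through the original is visible through the DLPack view.
source[0, 0] = 42.0
seen_through_view = float(view[0, 0])

print(shares, view.dtype, seen_through_view)
```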
You can convert between numpy and torch backends using .to():
if HAS_TORCH:
    # torch float32 -> numpy float32
    as_numpy = block_torch.to(arrays="numpy")
    print("Converted to numpy:", type(as_numpy.values).__name__, as_numpy.values.dtype)
    # numpy float16 -> torch float16
    as_torch = block_f16.to(arrays="torch")
    print("Converted to torch:", type(as_torch.values).__name__, as_torch.values.dtype)
Converted to numpy: ndarray float32
Converted to torch: Tensor torch.float16
GPU tensors (if available)¶
With PyTorch, metatensor can hold data on GPU. The DLPack metadata carries the device information, so metatensor knows where the data lives and can handle it correctly.
if HAS_TORCH and torch.cuda.is_available():
    gpu_data = torch.randn(3, 4, dtype=torch.float32, device="cuda")
    block_gpu = TensorBlock(
        values=gpu_data,
        samples=Labels.range("sample", 3),
        components=[],
        properties=Labels.range("property", 4),
    )
    print("GPU block device:", block_gpu.device)
    print("GPU block dtype: ", block_gpu.dtype)
    # Move to CPU first, then convert to numpy (two chained .to() calls)
    block_cpu = block_gpu.to("cpu").to(arrays="numpy")
    print("After .to('cpu').to(arrays='numpy'):", type(block_cpu.values).__name__)
elif HAS_TORCH:
    print("(CUDA not available, skipping GPU example)")
else:
    print("(torch not available, skipping GPU example)")
(CUDA not available, skipping GPU example)
Labels and device placement¶
DLPack is not limited to block data. Labels also participate in the
DLPack ecosystem via the mts_array_t values array stored inside each
Labels object. When labels are moved or directly constructed on a
device (e.g. via Labels.to("cuda") in the torch backend), the
underlying values tensor stays on that device. The device() query
on the values array tells callers where the label data lives, and
as_dlpack() exports it without an implicit copy.
This matters for GPU workflows: label metadata (sample indices, property
indices) can be kept on the same device as the block data, avoiding
unnecessary CPU round-trips during operations like keys_to_samples or
set operations (union, intersection, difference).
How DLPack enables this¶
Under the hood, every data array in metatensor implements a single interface:
as_dlpack. This is part of the DLPack standard, which encodes:
A pointer to the raw data buffer
The dtype (integer, float, complex, bool) and bit width (8, 16, 32, 64, or 128 for complex128)
The device type and ID (CPU, CUDA, ROCm, Metal, etc.)
The shape and strides of the array
This metadata travels across every FFI boundary – Python to C, C to Rust, Rust to C++ – without losing information. The array data itself is never copied during these transitions; only the thin DLPack descriptor is passed.
This is what makes the following pipeline work for any type and device:
Python (numpy/torch)
│ ── DLPack ──► C API (mts_array_t.as_dlpack)
│ │ ── DLPack ──► Rust (serialization / operations)
│ │ │ ── DLPack ──► .npy files
│ │ ◄── DLPack ── Rust
│ ◄── DLPack ── C API
Python
Each in-memory arrow is a zero-copy handoff; the only hop that necessarily
moves bytes is writing to or reading from .npy files. The type and device
information is preserved at every step because DLPack’s DLDataType and
DLDevice structs are part of the tensor descriptor.
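To make the descriptor concrete, the sketch below mirrors the two small structs in Python with `ctypes`, using the field layout and enum values from the public `dlpack.h` header. It is for illustration only: the `DESCRIPTORS` table and the `describe` helper are hypothetical names, not part of metatensor or DLPack.

```python
import ctypes


class DLDataType(ctypes.Structure):
    """Mirror of DLPack's DLDataType: type code, bits per lane, lane count."""

    _fields_ = [
        ("code", ctypes.c_uint8),
        ("bits", ctypes.c_uint8),
        ("lanes", ctypes.c_uint16),
    ]


class DLDevice(ctypes.Structure):
    """Mirror of DLPack's DLDevice: device type and device index."""

    _fields_ = [
        ("device_type", ctypes.c_int32),
        ("device_id", ctypes.c_int32),
    ]


# Subset of the DLDataTypeCode and DLDeviceType enum values from dlpack.h
KDL_INT, KDL_UINT, KDL_FLOAT, KDL_COMPLEX, KDL_BOOL = 0, 1, 2, 5, 6
KDL_CPU, KDL_CUDA = 1, 2

# Illustrative table: how a few of the dtypes used above would be described
DESCRIPTORS = {
    "int32": DLDataType(KDL_INT, 32, 1),
    "uint8": DLDataType(KDL_UINT, 8, 1),
    "float16": DLDataType(KDL_FLOAT, 16, 1),
    "complex128": DLDataType(KDL_COMPLEX, 128, 1),
    "bool": DLDataType(KDL_BOOL, 8, 1),
}


def describe(name, device):
    """Return (code, bits, lanes, device_type, device_id) for a dtype name."""
    dtype = DESCRIPTORS[name]
    return (dtype.code, dtype.bits, dtype.lanes, device.device_type, device.device_id)


print(ctypes.sizeof(DLDataType), ctypes.sizeof(DLDevice))
print(describe("complex128", DLDevice(KDL_CPU, 0)))
```

The whole type-and-device description fits in twelve bytes, which is why passing it across an FFI boundary costs essentially nothing compared to copying the array.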
Cleanup
pathlib.Path("int32_tensor.mts").unlink(missing_ok=True)