metatensor/io/
tensor.rs

1use std::ffi::CString;
2
3use crate::errors::{check_status, check_ptr};
4use crate::{TensorMap, Error};
5
6use super::{realloc_vec, create_ndarray};
7
8/// Load the serialized tensor map from the given path.
9///
10/// `TensorMap` are serialized using numpy's NPZ format, i.e. a ZIP file
11/// without compression (storage method is STORED), where each file is stored as
12/// a `.npy` array. Both the ZIP and NPY format are well documented:
13///
14/// - ZIP: <https://pkware.cachefly.net/webdocs/casestudies/APPNOTE.TXT>
15/// - NPY: <https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html>
16///
17/// We add other restriction on top of these formats when saving/loading data.
18/// First, `Labels` instances are saved as structured array, see the `labels`
19/// module for more information. Only 32-bit integers are supported for Labels,
20/// and only 64-bit floats are supported for data (values and gradients).
21///
22/// Second, the path of the files in the archive also carry meaning. The keys of
23/// the `TensorMap` are stored in `/keys.npy`, and then different blocks are
24/// stored as
25///
26/// ```bash
27/// /  blocks / <block_id>  / values / samples.npy
28///                         / values / components  / 0.npy
29///                                                / <...>.npy
30///                                                / <n_components>.npy
31///                         / values / properties.npy
32///                         / values / data.npy
33///
34///                         # optional sections for gradients, one by parameter
35///                         /   gradients / <parameter> / samples.npy
36///                                                     /   components  / 0.npy
37///                                                                     / <...>.npy
38///                                                                     / <n_components>.npy
39///                                                     /   data.npy
40/// ```
41pub fn load(path: impl AsRef<std::path::Path>) -> Result<TensorMap, Error> {
42    let path = path.as_ref().as_os_str().to_str().expect("this path is not valid UTF8");
43    let path = CString::new(path).expect("this path contains a NULL byte");
44
45    let ptr = unsafe {
46        crate::c_api::mts_tensormap_load(
47            path.as_ptr(),
48            Some(create_ndarray)
49        )
50    };
51
52    check_ptr(ptr)?;
53
54    return Ok(unsafe { TensorMap::from_raw(ptr) });
55}
56
57/// Load a serialized `TensorMap` from a `buffer`.
58///
59/// See the [`load`] function for more information on the data format.
60pub fn load_buffer(buffer: &[u8]) -> Result<TensorMap, Error> {
61    let ptr = unsafe {
62        crate::c_api::mts_tensormap_load_buffer(
63            buffer.as_ptr(),
64            buffer.len(),
65            Some(create_ndarray)
66        )
67    };
68
69    check_ptr(ptr)?;
70
71    return Ok(unsafe { TensorMap::from_raw(ptr) });
72}
73
74/// Save the given tensor to a file.
75///
76/// If the file already exists, it is overwritten. The recomended file extension
77/// when saving data is `.mts`, to prevent confusion with generic `.npz`.
78///
79/// The format used is documented in the [`load`] function, and consists of a
80/// zip archive containing NPY files.
81pub fn save(path: impl AsRef<std::path::Path>, tensor: &TensorMap) -> Result<(), Error> {
82    let path = path.as_ref().as_os_str().to_str().expect("this path is not valid UTF8");
83    let path = CString::new(path).expect("this path contains a NULL byte");
84
85    unsafe {
86        check_status(crate::c_api::mts_tensormap_save(path.as_ptr(), tensor.ptr))
87    }
88}
89
90
91/// Save the given `tensor` to an in-memory `buffer`.
92///
93/// This function will grow the buffer as required to fit the whole tensor.
94pub fn save_buffer(tensor: &TensorMap, buffer: &mut Vec<u8>) -> Result<(), Error> {
95    let mut buffer_ptr = buffer.as_mut_ptr();
96    let mut buffer_count = buffer.len();
97
98    unsafe {
99        check_status(crate::c_api::mts_tensormap_save_buffer(
100            &mut buffer_ptr,
101            &mut buffer_count,
102            (buffer as *mut Vec<u8>).cast(),
103            Some(realloc_vec),
104            tensor.ptr,
105        ))?;
106    }
107
108    buffer.resize(buffer_count, 0);
109
110    Ok(())
111}