metatensor/io/tensor.rs
1use std::ffi::CString;
2
3use crate::errors::{check_status, check_ptr};
4use crate::{TensorMap, Error};
5
6use super::{realloc_vec, create_ndarray};
7
8/// Load the serialized tensor map from the given path.
9///
10/// `TensorMap` are serialized using numpy's NPZ format, i.e. a ZIP file
11/// without compression (storage method is STORED), where each file is stored as
12/// a `.npy` array. Both the ZIP and NPY format are well documented:
13///
14/// - ZIP: <https://pkware.cachefly.net/webdocs/casestudies/APPNOTE.TXT>
15/// - NPY: <https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html>
16///
17/// We add other restriction on top of these formats when saving/loading data.
18/// First, `Labels` instances are saved as structured array, see the `labels`
19/// module for more information. Only 32-bit integers are supported for Labels,
20/// and only 64-bit floats are supported for data (values and gradients).
21///
22/// Second, the path of the files in the archive also carry meaning. The keys of
23/// the `TensorMap` are stored in `/keys.npy`, and then different blocks are
24/// stored as
25///
26/// ```bash
27/// / blocks / <block_id> / values / samples.npy
28/// / values / components / 0.npy
29/// / <...>.npy
30/// / <n_components>.npy
31/// / values / properties.npy
32/// / values / data.npy
33///
34/// # optional sections for gradients, one by parameter
35/// / gradients / <parameter> / samples.npy
36/// / components / 0.npy
37/// / <...>.npy
38/// / <n_components>.npy
39/// / data.npy
40/// ```
41pub fn load(path: impl AsRef<std::path::Path>) -> Result<TensorMap, Error> {
42 let path = path.as_ref().as_os_str().to_str().expect("this path is not valid UTF8");
43 let path = CString::new(path).expect("this path contains a NULL byte");
44
45 let ptr = unsafe {
46 crate::c_api::mts_tensormap_load(
47 path.as_ptr(),
48 Some(create_ndarray)
49 )
50 };
51
52 check_ptr(ptr)?;
53
54 return Ok(unsafe { TensorMap::from_raw(ptr) });
55}
56
57/// Load a serialized `TensorMap` from a `buffer`.
58///
59/// See the [`load`] function for more information on the data format.
60pub fn load_buffer(buffer: &[u8]) -> Result<TensorMap, Error> {
61 let ptr = unsafe {
62 crate::c_api::mts_tensormap_load_buffer(
63 buffer.as_ptr(),
64 buffer.len(),
65 Some(create_ndarray)
66 )
67 };
68
69 check_ptr(ptr)?;
70
71 return Ok(unsafe { TensorMap::from_raw(ptr) });
72}
73
74/// Save the given tensor to a file.
75///
76/// If the file already exists, it is overwritten. The recomended file extension
77/// when saving data is `.mts`, to prevent confusion with generic `.npz`.
78///
79/// The format used is documented in the [`load`] function, and consists of a
80/// zip archive containing NPY files.
81pub fn save(path: impl AsRef<std::path::Path>, tensor: &TensorMap) -> Result<(), Error> {
82 let path = path.as_ref().as_os_str().to_str().expect("this path is not valid UTF8");
83 let path = CString::new(path).expect("this path contains a NULL byte");
84
85 unsafe {
86 check_status(crate::c_api::mts_tensormap_save(path.as_ptr(), tensor.ptr))
87 }
88}
89
90
91/// Save the given `tensor` to an in-memory `buffer`.
92///
93/// This function will grow the buffer as required to fit the whole tensor.
94pub fn save_buffer(tensor: &TensorMap, buffer: &mut Vec<u8>) -> Result<(), Error> {
95 let mut buffer_ptr = buffer.as_mut_ptr();
96 let mut buffer_count = buffer.len();
97
98 unsafe {
99 check_status(crate::c_api::mts_tensormap_save_buffer(
100 &mut buffer_ptr,
101 &mut buffer_count,
102 (buffer as *mut Vec<u8>).cast(),
103 Some(realloc_vec),
104 tensor.ptr,
105 ))?;
106 }
107
108 buffer.resize(buffer_count, 0);
109
110 Ok(())
111}