metatensor/io/
block.rs

1use std::ffi::CString;
2
3use crate::errors::{check_ptr, check_status};
4use crate::{Error, TensorBlock, TensorBlockRef};
5
6use super::{realloc_vec, create_ndarray};
7
8/// Load previously saved `TensorBlock` from the file at the given path.
9pub fn load_block(path: impl AsRef<std::path::Path>) -> Result<TensorBlock, Error> {
10    let path = path.as_ref().as_os_str().to_str().expect("this path is not valid UTF8");
11    let path = CString::new(path).expect("this path contains a NULL byte");
12
13    let ptr = unsafe {
14        crate::c_api::mts_block_load(
15            path.as_ptr(),
16            Some(create_ndarray)
17        )
18    };
19
20    check_ptr(ptr)?;
21
22    return Ok(unsafe { TensorBlock::from_raw(ptr) });
23}
24
25/// Load a serialized `TensorBlock` from a `buffer`.
26pub fn load_block_buffer(buffer: &[u8]) -> Result<TensorBlock, Error> {
27    let ptr = unsafe {
28        crate::c_api::mts_block_load_buffer(
29            buffer.as_ptr(),
30            buffer.len(),
31            Some(create_ndarray)
32        )
33    };
34
35    check_ptr(ptr)?;
36
37    return Ok(unsafe { TensorBlock::from_raw(ptr) });
38}
39
40/// Save the given `block` to a file.
41///
42/// If the file already exists, it is overwritten. The recomended file extension
43/// when saving data is `.mts`, to prevent confusion with generic `.npz`.
44pub fn save_block(path: impl AsRef<std::path::Path>, block: TensorBlockRef) -> Result<(), Error> {
45    let path = path.as_ref().as_os_str().to_str().expect("this path is not valid UTF8");
46    let path = CString::new(path).expect("this path contains a NULL byte");
47
48    unsafe {
49        check_status(crate::c_api::mts_block_save(path.as_ptr(), block.as_ptr()))
50    }
51}
52
53
54/// Save the given `block` to an in-memory `buffer`.
55///
56/// This function will grow the buffer as required to fit the data.
57pub fn save_block_buffer(block: TensorBlockRef, buffer: &mut Vec<u8>) -> Result<(), Error> {
58    let mut buffer_ptr = buffer.as_mut_ptr();
59    let mut buffer_count = buffer.len();
60
61    unsafe {
62        check_status(crate::c_api::mts_block_save_buffer(
63            &mut buffer_ptr,
64            &mut buffer_count,
65            (buffer as *mut Vec<u8>).cast(),
66            Some(realloc_vec),
67            block.as_ptr(),
68        ))?;
69    }
70
71    buffer.resize(buffer_count, 0);
72
73    Ok(())
74}