Skip to main content

metatensor/io/
labels.rs

1use std::ffi::CString;
2
3use crate::errors::{check_status, check_ptr};
4use crate::{Labels, Error};
5
6use super::realloc_vec;
7
8/// Load previously saved `Labels` from the file at the given path.
9pub fn load_labels(path: impl AsRef<std::path::Path>) -> Result<Labels, Error> {
10    let path = path.as_ref().as_os_str().to_str().expect("this path is not valid UTF8");
11    let path = CString::new(path).expect("this path contains a NULL byte");
12
13    let ptr = unsafe { crate::c_api::mts_labels_load(path.as_ptr()) };
14    check_ptr(ptr)?;
15
16    return Ok(unsafe { Labels::from_raw(ptr) });
17}
18
19/// Load previously saved `Labels` from an in-memory `buffer`.
20pub fn load_labels_buffer(buffer: &[u8]) -> Result<Labels, Error> {
21    let ptr = unsafe {
22        crate::c_api::mts_labels_load_buffer(buffer.as_ptr(), buffer.len())
23    };
24    check_ptr(ptr)?;
25
26    return Ok(unsafe { Labels::from_raw(ptr) });
27}
28
29/// Save the given `Labels` to a file.
30///
31/// If the file already exists, it is overwritten. The recommended file extension
32/// when saving data is `.mts`, to prevent confusion with generic `.npz`.
33pub fn save_labels(path: impl AsRef<std::path::Path>, labels: &Labels) -> Result<(), Error> {
34    let path = path.as_ref().as_os_str().to_str().expect("this path is not valid UTF8");
35    let path = CString::new(path).expect("this path contains a NULL byte");
36
37    unsafe {
38        check_status(crate::c_api::mts_labels_save(path.as_ptr(), labels.as_mts_labels_t()))
39    }
40}
41
42
43/// Save the given `labels` to an in-memory `buffer`.
44///
45/// This function will grow the buffer as required to fit the labels.
46pub fn save_labels_buffer(labels: &Labels, buffer: &mut Vec<u8>) -> Result<(), Error> {
47    let mut buffer_ptr = buffer.as_mut_ptr();
48    let mut buffer_count = buffer.len();
49
50    unsafe {
51        check_status(crate::c_api::mts_labels_save_buffer(
52            &mut buffer_ptr,
53            &mut buffer_count,
54            (buffer as *mut Vec<u8>).cast(),
55            Some(realloc_vec),
56            labels.as_mts_labels_t(),
57        ))?;
58    }
59
60    buffer.resize(buffer_count, 0);
61
62    Ok(())
63}