metatensor/io/
labels.rs

1use std::ffi::CString;
2
3use crate::c_api::mts_labels_t;
4use crate::errors::check_status;
5use crate::{Labels, Error};
6
7use super::realloc_vec;
8
9/// Load previously saved `Labels` from the file at the given path.
10pub fn load_labels(path: impl AsRef<std::path::Path>) -> Result<Labels, Error> {
11    let path = path.as_ref().as_os_str().to_str().expect("this path is not valid UTF8");
12    let path = CString::new(path).expect("this path contains a NULL byte");
13
14    let mut labels = mts_labels_t::null();
15
16    unsafe {
17        check_status(crate::c_api::mts_labels_load(path.as_ptr(), &mut labels))?;
18    }
19
20    return Ok(unsafe { Labels::from_raw(labels) });
21}
22
23/// Load previously saved `Labels` from an in-memory `buffer`.
24pub fn load_labels_buffer(buffer: &[u8]) -> Result<Labels, Error> {
25    let mut labels = mts_labels_t::null();
26
27    unsafe {
28        check_status(crate::c_api::mts_labels_load_buffer(
29            buffer.as_ptr(),
30            buffer.len(),
31            &mut labels
32        ))?;
33    }
34
35    return Ok(unsafe { Labels::from_raw(labels) });
36}
37
38/// Save the given `Labels` to a file.
39///
40/// If the file already exists, it is overwritten. The recomended file extension
41/// when saving data is `.mts`, to prevent confusion with generic `.npz`.
42pub fn save_labels(path: impl AsRef<std::path::Path>, labels: &Labels) -> Result<(), Error> {
43    let path = path.as_ref().as_os_str().to_str().expect("this path is not valid UTF8");
44    let path = CString::new(path).expect("this path contains a NULL byte");
45
46    unsafe {
47        check_status(crate::c_api::mts_labels_save(path.as_ptr(), labels.raw))
48    }
49}
50
51
52/// Save the given `labels` to an in-memory `buffer`.
53///
54/// This function will grow the buffer as required to fit the labels.
55pub fn save_labels_buffer(labels: &Labels, buffer: &mut Vec<u8>) -> Result<(), Error> {
56    let mut buffer_ptr = buffer.as_mut_ptr();
57    let mut buffer_count = buffer.len();
58
59    unsafe {
60        check_status(crate::c_api::mts_labels_save_buffer(
61            &mut buffer_ptr,
62            &mut buffer_count,
63            (buffer as *mut Vec<u8>).cast(),
64            Some(realloc_vec),
65            labels.raw,
66        ))?;
67    }
68
69    buffer.resize(buffer_count, 0);
70
71    Ok(())
72}