Skip to main content

metatensor/io/
mod.rs

//! Input/Output facilities for saving and loading [`crate::TensorMap`],
//! individual blocks and [`crate::Labels`], either on disk or through
//! in-memory buffers
3
4use std::os::raw::c_void;
5
6use dlpk::sys::{DLDataType, DLDataTypeCode};
7
8use crate::c_api::{MTS_SUCCESS, mts_array_t, mts_status_t};
9use crate::MtsArray;
10
11mod tensor;
12pub use self::tensor::{load, save, load_buffer, save_buffer};
13
14mod block;
15pub use self::block::{load_block, load_block_buffer, save_block, save_block_buffer};
16
17mod labels;
18pub use self::labels::{load_labels, load_labels_buffer, save_labels, save_labels_buffer};
19
20
21/// Implementation of realloc for `Vec<u8>`, used in `save_buffer`
22unsafe extern "C" fn realloc_vec(user_data: *mut c_void, _ptr: *mut u8, new_size: usize) -> *mut u8 {
23    let mut result = std::ptr::null_mut();
24    let unwind_wrapper = std::panic::AssertUnwindSafe(&mut result);
25
26    let status = crate::errors::catch_unwind(move || {
27        let vector = &mut *user_data.cast::<Vec<u8>>();
28        vector.resize(new_size, 0);
29
30        // force the closure to capture the full unwind_wrapper, not just
31        // unwind_wrapper.0
32        let _ = &unwind_wrapper;
33        *(unwind_wrapper.0) = vector.as_mut_ptr();
34
35        Ok(())
36    });
37
38    if status != MTS_SUCCESS {
39        return std::ptr::null_mut();
40    }
41
42    return result;
43}
44
/// Build an `ndarray::Array` with element type `$T` and shape `$shape`, filled
/// with `<$T>::default()`, and convert it into an `MtsArray`.
///
/// The `$c_array` argument is accepted but not used by the expansion.
macro_rules! create_typed_array {
    ($shape:expr, $c_array:expr, $T:ty) => {{
        let typed = ndarray::Array::<$T, _>::from_elem($shape, <$T as Default>::default());
        let converted: MtsArray = typed.into();
        converted
    }};
}
52
53/// callback used to create `ndarray::ArcArray` when loading a `TensorMap`
54unsafe extern "C" fn create_ndarray(
55    shape_ptr: *const usize,
56    shape_count: usize,
57    dtype: DLDataType,
58    c_array: *mut mts_array_t,
59) -> mts_status_t {
60    crate::errors::catch_unwind(|| {
61        assert!(shape_count != 0);
62        let shape = std::slice::from_raw_parts(shape_ptr, shape_count);
63
64        if dtype.lanes != 1 {
65            return Err(crate::Error {
66                code: None,
67                message: format!(
68                    "unsupported dtype in create_ndarray: lanes={} (expected 1)",
69                    dtype.lanes
70                ),
71            });
72        }
73
74        let array = match (dtype.code, dtype.bits) {
75            (DLDataTypeCode::kDLFloat, 32) => create_typed_array!(shape, c_array, f32),
76            (DLDataTypeCode::kDLFloat, 64) => create_typed_array!(shape, c_array, f64),
77            (DLDataTypeCode::kDLInt, 8) => create_typed_array!(shape, c_array, i8),
78            (DLDataTypeCode::kDLInt, 16) => create_typed_array!(shape, c_array, i16),
79            (DLDataTypeCode::kDLInt, 32) => create_typed_array!(shape, c_array, i32),
80            (DLDataTypeCode::kDLInt, 64) => create_typed_array!(shape, c_array, i64),
81            (DLDataTypeCode::kDLUInt, 8) => create_typed_array!(shape, c_array, u8),
82            (DLDataTypeCode::kDLUInt, 16) => create_typed_array!(shape, c_array, u16),
83            (DLDataTypeCode::kDLUInt, 32) => create_typed_array!(shape, c_array, u32),
84            (DLDataTypeCode::kDLUInt, 64) => create_typed_array!(shape, c_array, u64),
85            (DLDataTypeCode::kDLBool, 8) => create_typed_array!(shape, c_array, bool),
86            (DLDataTypeCode::kDLFloat, 16) => create_typed_array!(shape, c_array, half::f16),
87            (DLDataTypeCode::kDLComplex, 64) => create_typed_array!(shape, c_array, [f32; 2]),
88            (DLDataTypeCode::kDLComplex, 128) => create_typed_array!(shape, c_array, [f64; 2]),
89            _ => {
90                return Err(crate::Error {
91                    code: None,
92                    message: format!(
93                        "unsupported dtype in create_ndarray: code={:?} bits={}",
94                        dtype.code, dtype.bits
95                    ),
96                });
97            }
98        };
99
100        *c_array = array.into_raw();
101
102        Ok(())
103    })
104}