1use std::os::raw::c_void;
5
6use dlpk::sys::{DLDataType, DLDataTypeCode};
7
8use crate::c_api::{MTS_SUCCESS, mts_array_t, mts_status_t};
9use crate::MtsArray;
10
11mod tensor;
12pub use self::tensor::{load, save, load_buffer, save_buffer};
13
14mod block;
15pub use self::block::{load_block, load_block_buffer, save_block, save_block_buffer};
16
17mod labels;
18pub use self::labels::{load_labels, load_labels_buffer, save_labels, save_labels_buffer};
19
20
21unsafe extern "C" fn realloc_vec(user_data: *mut c_void, _ptr: *mut u8, new_size: usize) -> *mut u8 {
23 let mut result = std::ptr::null_mut();
24 let unwind_wrapper = std::panic::AssertUnwindSafe(&mut result);
25
26 let status = crate::errors::catch_unwind(move || {
27 let vector = &mut *user_data.cast::<Vec<u8>>();
28 vector.resize(new_size, 0);
29
30 let _ = &unwind_wrapper;
33 *(unwind_wrapper.0) = vector.as_mut_ptr();
34
35 Ok(())
36 });
37
38 if status != MTS_SUCCESS {
39 return std::ptr::null_mut();
40 }
41
42 return result;
43}
44
/// Expand to an expression creating a `MtsArray` from a new `ndarray::Array`
/// with element type `$T` and shape `$shape`, where every element is set to
/// `$T`'s `Default` value.
///
/// `$c_array` is part of the invocation syntax but is not used by the
/// expansion.
macro_rules! create_typed_array {
    ($shape:expr, $c_array:expr, $T:ty) => {{
        let filled = ndarray::Array::<$T, _>::from_elem($shape, <$T as Default>::default());
        let wrapped: MtsArray = filled.into();
        wrapped
    }};
}
52
53unsafe extern "C" fn create_ndarray(
55 shape_ptr: *const usize,
56 shape_count: usize,
57 dtype: DLDataType,
58 c_array: *mut mts_array_t,
59) -> mts_status_t {
60 crate::errors::catch_unwind(|| {
61 assert!(shape_count != 0);
62 let shape = std::slice::from_raw_parts(shape_ptr, shape_count);
63
64 if dtype.lanes != 1 {
65 return Err(crate::Error {
66 code: None,
67 message: format!(
68 "unsupported dtype in create_ndarray: lanes={} (expected 1)",
69 dtype.lanes
70 ),
71 });
72 }
73
74 let array = match (dtype.code, dtype.bits) {
75 (DLDataTypeCode::kDLFloat, 32) => create_typed_array!(shape, c_array, f32),
76 (DLDataTypeCode::kDLFloat, 64) => create_typed_array!(shape, c_array, f64),
77 (DLDataTypeCode::kDLInt, 8) => create_typed_array!(shape, c_array, i8),
78 (DLDataTypeCode::kDLInt, 16) => create_typed_array!(shape, c_array, i16),
79 (DLDataTypeCode::kDLInt, 32) => create_typed_array!(shape, c_array, i32),
80 (DLDataTypeCode::kDLInt, 64) => create_typed_array!(shape, c_array, i64),
81 (DLDataTypeCode::kDLUInt, 8) => create_typed_array!(shape, c_array, u8),
82 (DLDataTypeCode::kDLUInt, 16) => create_typed_array!(shape, c_array, u16),
83 (DLDataTypeCode::kDLUInt, 32) => create_typed_array!(shape, c_array, u32),
84 (DLDataTypeCode::kDLUInt, 64) => create_typed_array!(shape, c_array, u64),
85 (DLDataTypeCode::kDLBool, 8) => create_typed_array!(shape, c_array, bool),
86 (DLDataTypeCode::kDLFloat, 16) => create_typed_array!(shape, c_array, half::f16),
87 (DLDataTypeCode::kDLComplex, 64) => create_typed_array!(shape, c_array, [f32; 2]),
88 (DLDataTypeCode::kDLComplex, 128) => create_typed_array!(shape, c_array, [f64; 2]),
89 _ => {
90 return Err(crate::Error {
91 code: None,
92 message: format!(
93 "unsupported dtype in create_ndarray: code={:?} bits={}",
94 dtype.code, dtype.bits
95 ),
96 });
97 }
98 };
99
100 *c_array = array.into_raw();
101
102 Ok(())
103 })
104}