1use std::os::raw::c_void;
2
3use once_cell::sync::Lazy;
4
5use dlpk::sys::{DLDevice, DLManagedTensorVersioned, DLPackVersion, DLDataType};
6use dlpk::DLPackTensor;
7
8use crate::errors::Error;
9use crate::c_api::{mts_array_t, mts_data_origin_t, mts_data_movement_t, mts_status_t};
10
11use super::MtsArray;
12
/// Rust-side interface for arrays stored in metatensor.
///
/// `From<Box<dyn Array>> for MtsArray` (in this file) exposes any implementation
/// of this trait through the C `mts_array_t` struct of function pointers, with
/// each `rust_array_*` callback forwarding to the matching method here.
pub trait Array: std::any::Any + Send + Sync {
    /// Get this array as `&dyn Any`, so callers can downcast it to the
    /// concrete implementation type.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Get this array as `&mut dyn Any`, so callers can downcast it to the
    /// concrete implementation type.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;

    /// Create a new array with the given `shape`, initialized from
    /// `fill_value`. NOTE(review): presumably every entry is set to the value
    /// stored in `fill_value`; confirm against the `mts_array_t.create` docs.
    ///
    /// `shape` is an empty slice when the C caller passes a shape count of 0.
    fn create(&self, shape: &[usize], fill_value: MtsArray) -> Box<dyn Array>;

    /// Create a copy of this array for the given `device`.
    fn copy(&self, device: DLDevice) -> Box<dyn Array>;

    /// Get the shape of this array.
    ///
    /// `RustArray` calls this once when wrapping the array and caches the
    /// result. The value returned here must therefore stay consistent with
    /// how `reshape` and `swap_axes` update that cache.
    fn shape(&self) -> Vec<usize>;

    /// Change the shape of this array to `shape`.
    ///
    /// When called through `mts_array_t`, the cached shape in `RustArray` is
    /// replaced by `shape` only after this method returns.
    fn reshape(&mut self, shape: &[usize]);

    /// Swap the axes `axis_1` and `axis_2` of this array.
    ///
    /// When called through `mts_array_t`, the cached shape in `RustArray` is
    /// updated with `Vec::swap(axis_1, axis_2)` after this method returns.
    fn swap_axes(&mut self, axis_1: usize, axis_2: usize);

    /// Move data from `input` into `self`, as described by `movements`.
    ///
    /// When called through `mts_array_t`, `input` is the wrapped
    /// implementation of another `RustArray`.
    fn move_data(
        &mut self,
        input: &dyn Array,
        movements: &[mts_data_movement_t],
    );

    /// Get the DLPack device of this array.
    fn device(&self) -> DLDevice;

    /// Get the DLPack data type of this array.
    fn dtype(&self) -> DLDataType;

    /// Export this array as a DLPack tensor for `device`.
    ///
    /// `stream` is `None` when the C caller passes a NULL stream pointer.
    /// `max_version` is forwarded unchanged from the C caller. Errors are
    /// reported back through the `mts_status_t` of the C callback.
    fn as_dlpack(
        &self,
        device: DLDevice,
        stream: Option<i64>,
        max_version: DLPackVersion
    ) -> Result<DLPackTensor, Error>;

    // `from_*` methods are expected by clippy not to take `self`. Here `self`
    // is the array on which `mts_array_t.from_dlpack` was invoked (presumably
    // used to select the concrete implementation), so the lint is silenced.
    #[allow(clippy::wrong_self_convention)]
    /// Create a new array from `dl_tensor`, consuming the tensor.
    fn from_dlpack(&self, dl_tensor: DLPackTensor) -> Result<Box<dyn Array>, Error>;
}
95
/// Heap-allocated wrapper around a boxed [`Array`], used as the `ptr` of the
/// `mts_array_t` built by `From<Box<dyn Array>> for MtsArray`.
pub (super) struct RustArray {
    /// The wrapped array implementation; every `rust_array_*` callback
    /// forwards to it.
    impl_: Box<dyn Array>,
    /// Cached copy of `impl_.shape()`.
    ///
    /// `Array::shape` returns a fresh `Vec`, but `rust_array_shape` has to hand
    /// a raw pointer back to C that stays valid after the call returns. The
    /// pointer therefore points into this vector. It is kept in sync by
    /// `rust_array_reshape` and `rust_array_swap_axes`.
    shape: Vec<usize>,
}
100
101impl std::ops::Deref for RustArray {
102 type Target = dyn Array;
103
104 fn deref(&self) -> &Self::Target {
105 &*self.impl_
106 }
107}
108
109impl std::ops::DerefMut for RustArray {
110 fn deref_mut(&mut self) -> &mut Self::Target {
111 &mut *self.impl_
112 }
113}
114
115impl From<Box<dyn Array>> for MtsArray {
116 fn from(value: Box<dyn Array>) -> Self {
117 let shape = value.shape();
118 let array = RustArray {
119 impl_: value,
120 shape,
121 };
122
123 let raw = mts_array_t {
124 ptr: Box::into_raw(Box::new(array)).cast(),
125 origin: Some(rust_array_origin),
126 device: Some(rust_array_device),
127 dtype: Some(rust_array_dtype),
128 as_dlpack: Some(rust_array_as_dlpack),
129 from_dlpack: Some(rust_array_from_dlpack),
130 shape: Some(rust_array_shape),
131 reshape: Some(rust_array_reshape),
132 swap_axes: Some(rust_array_swap_axes),
133 create: Some(rust_array_create),
134 copy: Some(rust_array_copy),
135 destroy: Some(rust_array_destroy),
136 move_data: Some(rust_array_move_data),
137 };
138
139 return MtsArray::from_raw(raw);
140 }
141}
142
143impl<T> From<T> for MtsArray where T: Array + 'static {
144 fn from(value: T) -> Self {
145 let boxed = Box::new(value) as Box<dyn Array>;
146 return MtsArray::from(boxed);
147 }
148}
149
150macro_rules! check_pointers {
151 ($pointer: ident) => {
152 if $pointer.is_null() {
153 panic!(
154 "got invalid NULL pointer for {} at {}:{}",
155 stringify!($pointer), file!(), line!()
156 );
157 }
158 };
159 ($($pointer: ident),* $(,)?) => {
160 $(check_pointers!($pointer);)*
161 }
162}
163
164pub(super) static RUST_DATA_ORIGIN: Lazy<mts_data_origin_t> = Lazy::new(|| {
165 super::origin::register_data_origin("RustArray".into()).expect("failed to register a new origin")
166});
167
168unsafe extern "C" fn rust_array_origin(
171 array: *const c_void,
172 origin: *mut mts_data_origin_t
173) -> mts_status_t {
174 crate::errors::catch_unwind(|| {
175 check_pointers!(array, origin);
176 *origin = *RUST_DATA_ORIGIN;
177
178 Ok(())
179 })
180}
181
182unsafe extern "C" fn rust_array_device(
184 array: *const c_void,
185 device: *mut DLDevice,
186) -> mts_status_t {
187 crate::errors::catch_unwind(|| {
188 check_pointers!(array, device);
189 let array = array.cast::<RustArray>();
190 *device = (*array).impl_.device();
191
192 Ok(())
193 })
194}
195
196unsafe extern "C" fn rust_array_dtype(
198 array: *const c_void,
199 dtype: *mut DLDataType,
200) -> mts_status_t {
201 crate::errors::catch_unwind(|| {
202 check_pointers!(array, dtype);
203 let array = array.cast::<RustArray>();
204 *dtype = (*array).impl_.dtype();
205
206 Ok(())
207 })
208}
209
210unsafe extern "C" fn rust_array_shape(
212 array: *const c_void,
213 shape: *mut *const usize,
214 shape_count: *mut usize,
215) -> mts_status_t {
216 crate::errors::catch_unwind(|| {
217 check_pointers!(array, shape, shape_count);
218 let array = array.cast::<RustArray>();
219 let rust_shape = &(*array).shape;
220
221 *shape = rust_shape.as_ptr();
222 *shape_count = rust_shape.len();
223
224 Ok(())
225 })
226}
227
228#[allow(clippy::cast_possible_truncation)]
230unsafe extern "C" fn rust_array_reshape(
231 array: *mut c_void,
232 shape: *const usize,
233 shape_count: usize,
234) -> mts_status_t {
235 crate::errors::catch_unwind(|| {
236 check_pointers!(array);
237 let array = array.cast::<RustArray>();
238
239 let shape = if shape_count == 0 {
240 &[]
241 } else {
242 check_pointers!(shape);
243 std::slice::from_raw_parts(shape, shape_count)
244 };
245
246 (*array).impl_.reshape(shape);
247 (*array).shape = shape.to_vec();
248
249 Ok(())
250 })
251}
252
253#[allow(clippy::cast_possible_truncation)]
255unsafe extern "C" fn rust_array_swap_axes(
256 array: *mut c_void,
257 axis_1: usize,
258 axis_2: usize,
259) -> mts_status_t {
260 crate::errors::catch_unwind(|| {
261 check_pointers!(array);
262 let array = array.cast::<RustArray>();
263 (*array).impl_.swap_axes(axis_1, axis_2);
264 (*array).shape.swap(axis_1, axis_2);
265
266 Ok(())
267 })
268}
269
270#[allow(clippy::cast_possible_truncation)]
272unsafe extern "C" fn rust_array_create(
273 array: *const c_void,
274 shape: *const usize,
275 shape_count: usize,
276 fill_value: mts_array_t,
277 array_storage: *mut mts_array_t,
278) -> mts_status_t {
279 crate::errors::catch_unwind(|| {
280 check_pointers!(array, array_storage);
281 let array = array.cast::<RustArray>();
282
283 let shape = if shape_count == 0 {
284 &[]
285 } else {
286 check_pointers!(shape);
287 std::slice::from_raw_parts(shape, shape_count)
288 };
289
290 let new_array = (*array).impl_.create(shape, MtsArray::from_raw(fill_value));
291 let new_array = MtsArray::from(new_array);
292
293 *array_storage = new_array.into_raw();
294
295 Ok(())
296 })
297}
298
299unsafe extern "C" fn rust_array_copy(
301 array: *const c_void,
302 device: DLDevice,
303 new_array: *mut mts_array_t
304) -> mts_status_t {
305 crate::errors::catch_unwind(|| {
306 check_pointers!(array, new_array);
307 let array = array.cast::<RustArray>();
308
309 let copy = (*array).impl_.copy(device);
310 let copy = MtsArray::from(copy);
311 *new_array = copy.into_raw();
312
313 Ok(())
314 })
315}
316
317unsafe extern "C" fn rust_array_destroy(
319 array: *mut c_void,
320) {
321 if !array.is_null() {
322 let array = array.cast::<RustArray>();
323 let boxed = Box::from_raw(array);
324 std::mem::drop(boxed);
325 }
326}
327
328#[allow(clippy::cast_possible_truncation)]
330unsafe extern "C" fn rust_array_move_data(
331 output: *mut c_void,
332 input: *const c_void,
333 movements: *const mts_data_movement_t,
334 movements_count: usize,
335) -> mts_status_t {
336 crate::errors::catch_unwind(|| {
337 check_pointers!(output, input);
338 let output = output.cast::<RustArray>();
339 let input = input.cast::<RustArray>();
340
341 let movements = if movements_count == 0 {
342 &[]
343 } else {
344 check_pointers!(movements);
345 std::slice::from_raw_parts(movements, movements_count)
346 };
347
348 (*output).impl_.move_data(&*(*input).impl_, movements);
349
350 Ok(())
351 })
352}
353
354unsafe extern "C" fn rust_array_as_dlpack(
356 array: *mut c_void,
357 dl_tensor: *mut *mut DLManagedTensorVersioned,
358 device: DLDevice,
359 stream: *const i64,
360 max_version: DLPackVersion,
361) -> mts_status_t {
362 crate::errors::catch_unwind(|| {
363 check_pointers!(array, dl_tensor);
364 let array = array.cast::<RustArray>();
365 let stream_opt = stream.as_ref().copied();
366 let tensor = (*array).impl_.as_dlpack(device, stream_opt, max_version)?;
367
368 *dl_tensor = tensor.into_raw().as_ptr();
369 Ok(())
370 })
371}
372
373unsafe extern "C" fn rust_array_from_dlpack(
375 array: *const c_void,
376 dl_tensor: *mut DLManagedTensorVersioned,
377 new_array: *mut mts_array_t,
378) -> mts_status_t {
379 crate::errors::catch_unwind(|| {
380 check_pointers!(array, dl_tensor, new_array);
381 let array = array.cast::<RustArray>();
382 let dl_tensor = DLPackTensor::from_ptr(dl_tensor);
383
384 let new_rust_array = (*array).impl_.from_dlpack(dl_tensor)?;
385
386 *new_array = MtsArray::from(new_rust_array).into_raw();
387
388 Ok(())
389 })
390}