Skip to main content

metatensor/data/
array.rs

1use std::os::raw::c_void;
2
3use once_cell::sync::Lazy;
4
5use dlpk::sys::{DLDevice, DLManagedTensorVersioned, DLPackVersion, DLDataType};
6use dlpk::DLPackTensor;
7
8use crate::errors::Error;
9use crate::c_api::{mts_array_t, mts_data_origin_t, mts_data_movement_t, mts_status_t};
10
11use super::MtsArray;
12
13/// The Array trait is used by metatensor to manage different kind of data array
14/// with a single API. Metatensor only knows about `Box<dyn Array>`, and
15/// manipulate the data through the functions on this trait.
16///
17/// This corresponds to the `mts_array_t` struct in metatensor-core.
18pub trait Array: std::any::Any + Send + Sync {
19    /// Get the array as a `Any` reference
20    fn as_any(&self) -> &dyn std::any::Any;
21
22    /// Get the array as a mutable `Any` reference
23    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
24
25    /// Create a new array with the same array origin, data type, and device as
26    /// the current one, but with the requested `shape`.
27    ///
28    /// The new array should be filled with the scalar value from `fill_value`,
29    /// which must be an `MtsArray` with shape `(1,)` and the same dtype as this
30    /// array.
31    fn create(&self, shape: &[usize], fill_value: MtsArray) -> Box<dyn Array>;
32
33    /// Make a copy of this `array`
34    ///
35    /// The new array is expected to have the same array origin and data type,
36    /// but live on the given device.
37    fn copy(&self, device: DLDevice) -> Box<dyn Array>;
38
39    /// Get the shape of the array. This can be empty if the array has no shape
40    /// (e.g. a scalar).
41    fn shape(&self) -> Vec<usize>;
42
43    /// Change the shape of the array to the given `shape`
44    fn reshape(&mut self, shape: &[usize]);
45
46    /// Swap the axes `axis_1` and `axis_2` in this array
47    fn swap_axes(&mut self, axis_1: usize, axis_2: usize);
48
49    /// Set entries in `self` taking data from the `input` array.
50    ///
51    /// The `output` array is guaranteed to be created by calling
52    /// `mts_array_t::create` with one of the arrays in the same block or tensor
53    /// map as the `input`.
54    ///
55    /// The `movements` indicate where the data should be moved from `input` to
56    /// `output`.
57    ///
58    /// This function should copy data from `input[movements[i].sample_in, ...,
59    /// movements[i].properties_start_in + x]` to
60    /// `array[movements[i].sample_out, ..., movements[i].properties_start_out +
61    /// x]` for `i` up to `movements_count` and `x` up to
62    /// `movements[i].properties_length`. All indexes are 0-based.
63    fn move_data(
64        &mut self,
65        input: &dyn Array,
66        movements: &[mts_data_movement_t],
67    );
68
69    /// Get the device where this array's data resides.
70    ///
71    /// For CPU arrays this should return `DLDevice::cpu()`.
72    fn device(&self) -> DLDevice;
73
74    /// Get the data type of this array.
75    ///
76    /// This populates the `dtype` vtable slot for fast dtype queries.
77    /// Implementations should return the appropriate `DLDataType` for their
78    /// element type (e.g. float64 = `DLDataType { code: kDLFloat, bits: 64, lanes: 1 }`).
79    fn dtype(&self) -> DLDataType;
80
81    /// Convert the array to a `DLPack` tensor.
82    /// The returned pointer is owned by the caller (and cleaned up via its deleter).
83    fn as_dlpack(
84        &self,
85        device: DLDevice,
86        stream: Option<i64>,
87        max_version: DLPackVersion
88    ) -> Result<DLPackTensor, Error>;
89
90    /// Create a new array from a `DLPack` tensor, taking ownership of the
91    /// tensor's data.
92    #[allow(clippy::wrong_self_convention)]
93    fn from_dlpack(&self, dl_tensor: DLPackTensor) -> Result<Box<dyn Array>, Error>;
94}
95
96pub (super) struct RustArray {
97    impl_: Box<dyn Array>,
98    shape: Vec<usize>,
99}
100
101impl std::ops::Deref for RustArray {
102    type Target = dyn Array;
103
104    fn deref(&self) -> &Self::Target {
105        &*self.impl_
106    }
107}
108
109impl std::ops::DerefMut for RustArray {
110    fn deref_mut(&mut self) -> &mut Self::Target {
111        &mut *self.impl_
112    }
113}
114
115impl From<Box<dyn Array>> for MtsArray {
116    fn from(value: Box<dyn Array>) -> Self {
117        let shape = value.shape();
118        let array = RustArray {
119            impl_: value,
120            shape,
121        };
122
123        let raw = mts_array_t {
124            ptr: Box::into_raw(Box::new(array)).cast(),
125            origin: Some(rust_array_origin),
126            device: Some(rust_array_device),
127            dtype: Some(rust_array_dtype),
128            as_dlpack: Some(rust_array_as_dlpack),
129            from_dlpack: Some(rust_array_from_dlpack),
130            shape: Some(rust_array_shape),
131            reshape: Some(rust_array_reshape),
132            swap_axes: Some(rust_array_swap_axes),
133            create: Some(rust_array_create),
134            copy: Some(rust_array_copy),
135            destroy: Some(rust_array_destroy),
136            move_data: Some(rust_array_move_data),
137        };
138
139        return MtsArray::from_raw(raw);
140    }
141}
142
143impl<T> From<T> for MtsArray where T: Array + 'static {
144    fn from(value: T) -> Self {
145        let boxed = Box::new(value) as Box<dyn Array>;
146        return MtsArray::from(boxed);
147    }
148}
149
/// Panic if any of the given raw pointers is NULL.
///
/// Each argument must be an identifier naming a raw pointer. The panic message
/// contains that identifier and the file/line of the macro invocation. The
/// `extern "C"` functions in this file call it inside
/// `crate::errors::catch_unwind`, which returns an `mts_status_t`.
macro_rules! check_pointers {
    ($($pointer: ident),* $(,)?) => {
        $(
            if $pointer.is_null() {
                panic!(
                    "got invalid NULL pointer for {} at {}:{}",
                    stringify!($pointer), file!(), line!()
                );
            }
        )*
    };
}
163
164pub(super) static RUST_DATA_ORIGIN: Lazy<mts_data_origin_t> = Lazy::new(|| {
165    super::origin::register_data_origin("RustArray".into()).expect("failed to register a new origin")
166});
167
168/******************************************************************************/
169/// Implementation of `mts_array_t.origin` using `RustArray`
170unsafe extern "C" fn rust_array_origin(
171    array: *const c_void,
172    origin: *mut mts_data_origin_t
173) -> mts_status_t {
174    crate::errors::catch_unwind(|| {
175        check_pointers!(array, origin);
176        *origin = *RUST_DATA_ORIGIN;
177
178        Ok(())
179    })
180}
181
182/// Implementation of `mts_array_t.device` using `RustArray`
183unsafe extern "C" fn rust_array_device(
184    array: *const c_void,
185    device: *mut DLDevice,
186) -> mts_status_t {
187    crate::errors::catch_unwind(|| {
188        check_pointers!(array, device);
189        let array = array.cast::<RustArray>();
190        *device = (*array).impl_.device();
191
192        Ok(())
193    })
194}
195
196/// Implementation of `mts_array_t.dtype` using `RustArray`
197unsafe extern "C" fn rust_array_dtype(
198    array: *const c_void,
199    dtype: *mut DLDataType,
200) -> mts_status_t {
201    crate::errors::catch_unwind(|| {
202        check_pointers!(array, dtype);
203        let array = array.cast::<RustArray>();
204        *dtype = (*array).impl_.dtype();
205
206        Ok(())
207    })
208}
209
210/// Implementation of `mts_array_t.shape` using `RustArray`
211unsafe extern "C" fn rust_array_shape(
212    array: *const c_void,
213    shape: *mut *const usize,
214    shape_count: *mut usize,
215) -> mts_status_t {
216    crate::errors::catch_unwind(|| {
217        check_pointers!(array, shape, shape_count);
218        let array = array.cast::<RustArray>();
219        let rust_shape = &(*array).shape;
220
221        *shape = rust_shape.as_ptr();
222        *shape_count = rust_shape.len();
223
224        Ok(())
225    })
226}
227
228/// Implementation of `mts_array_t.reshape` using `RustArray`
229#[allow(clippy::cast_possible_truncation)]
230unsafe extern "C" fn rust_array_reshape(
231    array: *mut c_void,
232    shape: *const usize,
233    shape_count: usize,
234) -> mts_status_t {
235    crate::errors::catch_unwind(|| {
236        check_pointers!(array);
237        let array = array.cast::<RustArray>();
238
239        let shape = if shape_count == 0 {
240            &[]
241        } else {
242            check_pointers!(shape);
243            std::slice::from_raw_parts(shape, shape_count)
244        };
245
246        (*array).impl_.reshape(shape);
247        (*array).shape = shape.to_vec();
248
249        Ok(())
250    })
251}
252
253/// Implementation of `mts_array_t.swap_axes` using `RustArray`
254#[allow(clippy::cast_possible_truncation)]
255unsafe extern "C" fn rust_array_swap_axes(
256    array: *mut c_void,
257    axis_1: usize,
258    axis_2: usize,
259) -> mts_status_t {
260    crate::errors::catch_unwind(|| {
261        check_pointers!(array);
262        let array = array.cast::<RustArray>();
263        (*array).impl_.swap_axes(axis_1, axis_2);
264        (*array).shape.swap(axis_1, axis_2);
265
266        Ok(())
267    })
268}
269
270/// Implementation of `mts_array_t.create` using `RustArray`
271#[allow(clippy::cast_possible_truncation)]
272unsafe extern "C" fn rust_array_create(
273    array: *const c_void,
274    shape: *const usize,
275    shape_count: usize,
276    fill_value: mts_array_t,
277    array_storage: *mut mts_array_t,
278) -> mts_status_t {
279    crate::errors::catch_unwind(|| {
280        check_pointers!(array, array_storage);
281        let array = array.cast::<RustArray>();
282
283        let shape = if shape_count == 0 {
284            &[]
285        } else {
286            check_pointers!(shape);
287            std::slice::from_raw_parts(shape, shape_count)
288        };
289
290        let new_array = (*array).impl_.create(shape, MtsArray::from_raw(fill_value));
291        let new_array = MtsArray::from(new_array);
292
293        *array_storage = new_array.into_raw();
294
295        Ok(())
296    })
297}
298
299/// Implementation of `mts_array_t.copy` using `RustArray`
300unsafe extern "C" fn rust_array_copy(
301    array: *const c_void,
302    device: DLDevice,
303    new_array: *mut mts_array_t
304) -> mts_status_t {
305    crate::errors::catch_unwind(|| {
306        check_pointers!(array, new_array);
307        let array = array.cast::<RustArray>();
308
309        let copy = (*array).impl_.copy(device);
310        let copy = MtsArray::from(copy);
311        *new_array = copy.into_raw();
312
313        Ok(())
314    })
315}
316
317/// Implementation of `mts_array_t.destroy` for `RustArray`
318unsafe extern "C" fn rust_array_destroy(
319    array: *mut c_void,
320) {
321    if !array.is_null() {
322        let array = array.cast::<RustArray>();
323        let boxed = Box::from_raw(array);
324        std::mem::drop(boxed);
325    }
326}
327
328/// Implementation of `mts_array_t.move_sample` using `RustArray`
329#[allow(clippy::cast_possible_truncation)]
330unsafe extern "C" fn rust_array_move_data(
331    output: *mut c_void,
332    input: *const c_void,
333    movements: *const mts_data_movement_t,
334    movements_count: usize,
335) -> mts_status_t {
336    crate::errors::catch_unwind(|| {
337        check_pointers!(output, input);
338        let output = output.cast::<RustArray>();
339        let input = input.cast::<RustArray>();
340
341        let movements = if movements_count == 0 {
342            &[]
343        } else {
344            check_pointers!(movements);
345            std::slice::from_raw_parts(movements, movements_count)
346        };
347
348        (*output).impl_.move_data(&*(*input).impl_, movements);
349
350        Ok(())
351    })
352}
353
354/// Implementation of `mts_array_t.as_dlpack` using `RustArray`
355unsafe extern "C" fn rust_array_as_dlpack(
356    array: *mut c_void,
357    dl_tensor: *mut *mut DLManagedTensorVersioned,
358    device: DLDevice,
359    stream: *const i64,
360    max_version: DLPackVersion,
361) -> mts_status_t {
362    crate::errors::catch_unwind(|| {
363        check_pointers!(array, dl_tensor);
364        let array = array.cast::<RustArray>();
365        let stream_opt = stream.as_ref().copied();
366        let tensor = (*array).impl_.as_dlpack(device, stream_opt, max_version)?;
367
368        *dl_tensor = tensor.into_raw().as_ptr();
369        Ok(())
370    })
371}
372
373/// Implementation of `mts_array_t.from_dlpack` using `RustArray`
374unsafe extern "C" fn rust_array_from_dlpack(
375    array: *const c_void,
376    dl_tensor: *mut DLManagedTensorVersioned,
377    new_array: *mut mts_array_t,
378) -> mts_status_t {
379    crate::errors::catch_unwind(|| {
380        check_pointers!(array, dl_tensor, new_array);
381        let array = array.cast::<RustArray>();
382        let dl_tensor = DLPackTensor::from_ptr(dl_tensor);
383
384        let new_rust_array = (*array).impl_.from_dlpack(dl_tensor)?;
385
386        *new_array = MtsArray::from(new_rust_array).into_raw();
387
388        Ok(())
389    })
390}