Skip to main content

metatensor/data/
array_ref.rs

1use std::sync::{Arc, RwLock};
2
3use ndarray::ArrayD;
4
5use dlpk::sys::DLDevice;
6
7use crate::c_api::{mts_array_t, mts_data_origin_t, mts_data_movement_t};
8use crate::Error;
9use crate::errors::check_status;
10
11use super::external::MtsArray;
12use super::origin::get_data_origin;
13
14/// Reference to a data array in metatensor-core
15///
16/// The data array can come from any origin, this struct provides facilities to
17/// access data that was created through the [`crate::Array`] trait, and in particular
18/// as `ndarray::ArrayD` instances.
19#[derive(Debug, Clone, Copy)]
20pub struct ArrayRef<'a> {
21    array: mts_array_t,
22    /// `ArrayRef` should behave like `&'a mts_array_t`
23    marker: std::marker::PhantomData<&'a mts_array_t>,
24}
25
26impl<'a> ArrayRef<'a> {
27    /// Create a new `ArrayRef` from the given raw `mts_array_t`
28    ///
29    /// This is a **VERY** unsafe function, creating a lifetime out of thin air.
30    /// Make sure the lifetime is actually constrained by the lifetime of the
31    /// owner of this `mts_array_t`.
32    pub unsafe fn from_raw(array: mts_array_t) -> ArrayRef<'a> {
33        ArrayRef {
34            array: mts_array_t {
35                // remove the destructor if any, since we only have a reference
36                // to the array, and it should not be dropped when passed back
37                // to C through `as_raw()`.
38                destroy: None,
39                ..array
40            },
41            marker: std::marker::PhantomData,
42        }
43    }
44
45    /// Get the underlying array as an `&dyn Any` instance.
46    ///
47    /// This function panics if the array was not created though this crate and
48    /// the [`crate::Array`] trait.
49    #[inline]
50    pub fn as_any(&self) -> &dyn std::any::Any {
51        let origin = self.origin().unwrap_or(0);
52        assert_eq!(
53            origin, *super::array::RUST_DATA_ORIGIN,
54            "this array was not created as a rust Array (origin is '{}')",
55            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
56        );
57
58        let array = self.array.ptr.cast::<super::array::RustArray>();
59        unsafe {
60            return (*array).as_any();
61        }
62    }
63
64    /// Get a reference to the underlying array as an `&dyn Any` instance,
65    /// re-using the same lifetime as the `ArrayRef`.
66    ///
67    /// This function panics if the array was not created though this crate and
68    /// the [`crate::Array`] trait.
69    #[inline]
70    pub fn to_any(self) -> &'a dyn std::any::Any {
71        let origin = self.origin().unwrap_or(0);
72        assert_eq!(
73            origin, *super::array::RUST_DATA_ORIGIN,
74            "this array was not created as a rust Array (origin is '{}')",
75            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
76        );
77
78        let array = self.array.ptr.cast::<super::array::RustArray>();
79        unsafe {
80            return (*array).as_any();
81        }
82    }
83
84    /// Extract the `Arc<RwLock<ArrayD<T>>>` from this `ArrayRef`, if it
85    /// contains one.
86    ///
87    /// This function will panic if the data in the `mts_array_t` in this
88    /// `ArrayRef` is a different kind of array.
89    #[inline]
90    pub fn as_ndarray_lock<T>(&self) -> &Arc<RwLock<ArrayD<T>>> where T: 'static {
91        self.as_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
92    }
93
94    /// Extract the `Arc<RwLock<ArrayD<T>>>` from this `ArrayRef`, if it
95    /// contains one, keeping the initial lifetime of the `ArrayRef`.
96    ///
97    /// This function will panic if the data in the `mts_array_t` in this
98    /// `ArrayRef` is a different kind of array.
99    #[inline]
100    pub fn to_ndarray_lock<T>(self) -> &'a Arc<RwLock<ArrayD<T>>> where T: 'static {
101        self.to_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
102    }
103
104    /// Get the raw underlying `mts_array_t`
105    pub fn as_raw(&self) -> &mts_array_t {
106        &self.array
107    }
108
109    /// Get the origin of this array.
110    ///
111    /// This corresponds to `mts_array_t.origin`, but with a more convenient API.
112    pub fn origin(&self) -> Result<mts_data_origin_t, Error> {
113        let function = self.array.origin.expect("mts_array_t.origin function is NULL");
114
115        let mut origin = 0;
116        unsafe {
117            check_status(function(self.array.ptr, &mut origin))?;
118        }
119
120        return Ok(origin);
121    }
122
123    /// Get the device of this array.
124    ///
125    /// This corresponds to `mts_array_t.device`, but with a more convenient API.
126    pub fn device(&self) -> Result<DLDevice, Error> {
127        let function = self.array.device.expect("mts_array_t.device function is NULL");
128
129        let mut device = DLDevice::cpu();
130        unsafe {
131            check_status(function(self.array.ptr, &mut device))?;
132        }
133
134        return Ok(device);
135    }
136
137    /// Get the dtype of this array.
138    ///
139    /// This corresponds to `mts_array_t.dtype`, but with a more convenient API.
140    pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
141        let function = self.array.dtype.expect("mts_array_t.dtype function is NULL");
142
143        let mut dtype = dlpk::sys::DLDataType { code: dlpk::sys::DLDataTypeCode::kDLFloat, bits: 0, lanes: 0 };
144        unsafe {
145            check_status(function(self.array.ptr, &mut dtype))?;
146        }
147
148        return Ok(dtype);
149    }
150
151    /// Get a [`dlpk::DLPackTensor`] from this array, if supported by the underlying data.
152    ///
153    /// This corresponds to `mts_array_t.as_dlpack`, but with a more convenient API.
154    pub fn as_dlpack(
155        &self,
156        device: DLDevice,
157        stream: Option<i64>,
158        max_version: dlpk::sys::DLPackVersion,
159    ) -> Result<dlpk::DLPackTensor, Error> {
160        let function = self.array.as_dlpack.expect("mts_array_t.as_dlpack function is NULL");
161
162        let mut tensor = std::ptr::null_mut();
163        let stream_c = stream.as_ref().map_or(std::ptr::null(), |s| s as *const i64);
164
165        unsafe {
166            check_status(function(self.array.ptr, &mut tensor, device, stream_c, max_version))?;
167        }
168
169        let tensor = unsafe {
170            dlpk::DLPackTensor::from_ptr(tensor)
171        };
172
173        return Ok(tensor);
174    }
175
176    /// Get the shape of this array.
177    ///
178    /// This corresponds to `mts_array_t.shape`, but with a more convenient API.
179    pub fn shape(&self) -> Result<&[usize], Error> {
180        let function = self.array.shape.expect("mts_array_t.shape function is NULL");
181
182        let mut shape = std::ptr::null();
183        let mut shape_count: usize = 0;
184
185        unsafe {
186            check_status(function(self.array.ptr, &mut shape, &mut shape_count))?;
187        }
188
189        if shape_count == 0 {
190            return Ok(&[]);
191        } else {
192            assert!(!shape.is_null());
193            let shape = unsafe {
194                std::slice::from_raw_parts(shape, shape_count)
195            };
196            return Ok(shape);
197        }
198    }
199
200    /// Create a new array with the same options as this one (dtype, device)
201    /// and the given shape, filled with zeros.
202    ///
203    /// This corresponds to `mts_array_t.create`, but with a more convenient API.
204    pub fn create(&self, shape: &[usize], fill_value: ArrayRef<'_>) -> Result<MtsArray, Error> {
205        let function = self.array.create.expect("mts_array_t.create function is NULL");
206
207        let mut new_array = mts_array_t::null();
208        unsafe {
209            check_status(function(
210                self.array.ptr,
211                shape.as_ptr(),
212                shape.len(),
213                *fill_value.as_raw(),
214                &mut new_array
215            ))?;
216        }
217
218        return Ok(MtsArray::from_raw(new_array));
219    }
220
221    /// Copy the data in this array, if supported by the underlying data.
222    ///
223    /// This corresponds to `mts_array_t.copy`, but with a more convenient API.
224    pub fn copy(&self, device: DLDevice) -> Result<MtsArray, Error> {
225        let function = self.array.copy.expect("mts_array_t.copy function is NULL");
226        let mut new_array = mts_array_t::null();
227        unsafe {
228            check_status(function(self.array.ptr, device, &mut new_array))?;
229        }
230
231        return Ok(MtsArray::from_raw(new_array));
232    }
233}
234
235/// Mutable reference to a data array in metatensor-core
236///
237/// The data array can come from any origin, this struct provides facilities to
238/// access data that was created through the [`crate::Array`] trait, and in
239/// particular as `ndarray::ArrayD` instances.
240#[derive(Debug)]
241pub struct ArrayRefMut<'a> {
242    array: mts_array_t,
243    /// `ArrayRefMut` should behave like `&'a mut mts_array_t`
244    marker: std::marker::PhantomData<&'a mut mts_array_t>,
245}
246
247impl<'a> ArrayRefMut<'a> {
248    /// Create a new `ArrayRefMut` from the given raw `mts_array_t`
249    ///
250    /// This is a **VERY** unsafe function, creating a lifetime out of thin air,
251    /// and allowing mutable access to the `mts_array_t`. Make sure the lifetime
252    /// is actually constrained by the lifetime of the owner of this
253    /// `mts_array_t`; and that the owner is mutably borrowed by this
254    /// `ArrayRefMut`.
255    #[inline]
256    pub unsafe fn from_raw(array: mts_array_t) -> ArrayRefMut<'a> {
257        ArrayRefMut {
258            array: mts_array_t {
259                // remove the destructor if any, since we only have a reference
260                // to the array, and it should not be dropped when passed back
261                // to C through `as_raw()`.
262                destroy: None,
263                ..array
264            },
265            marker: std::marker::PhantomData,
266        }
267    }
268
269    /// Get the underlying array as an `&dyn Any` instance.
270    ///
271    /// This function panics if the array was not created though this crate and
272    /// the [`crate::Array`] trait.
273    #[inline]
274    pub fn as_any(&self) -> &dyn std::any::Any {
275        let origin = self.origin().unwrap_or(0);
276        assert_eq!(
277            origin, *super::array::RUST_DATA_ORIGIN,
278            "this array was not created as a rust Array (origin is '{}')",
279            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
280        );
281
282        let array = self.array.ptr.cast::<super::array::RustArray>();
283        unsafe {
284            return (*array).as_any();
285        }
286    }
287
288    /// Get a reference to the underlying array as an `&dyn Any` instance,
289    /// re-using the same lifetime as the `ArrayRefMut`.
290    ///
291    /// This function panics if the array was not created though this crate and
292    /// the [`crate::Array`] trait.
293    #[inline]
294    pub fn to_any(self) -> &'a dyn std::any::Any {
295        let origin = self.origin().unwrap_or(0);
296        assert_eq!(
297            origin, *super::array::RUST_DATA_ORIGIN,
298            "this array was not created as a rust Array (origin is '{}')",
299            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
300        );
301
302        let array = self.array.ptr.cast::<super::array::RustArray>();
303        unsafe {
304            return (*array).as_any();
305        }
306    }
307
308    /// Get the underlying array as an `&mut dyn Any` instance.
309    ///
310    /// This function panics if the array was not created though this crate and
311    /// the [`crate::Array`] trait.
312    #[inline]
313    pub fn to_any_mut(self) -> &'a mut dyn std::any::Any {
314        let origin = self.origin().unwrap_or(0);
315        assert_eq!(
316            origin, *super::array::RUST_DATA_ORIGIN,
317            "this array was not created as a rust Array (origin is '{}')",
318            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
319        );
320
321        let array = self.array.ptr.cast::<super::array::RustArray>();
322        unsafe {
323            return (*array).as_any_mut();
324        }
325    }
326
327    /// Extract the `Arc<RwLock<ArrayD<T>>>` from this `ArrayRef`, if it
328    /// contains one.
329    ///
330    /// This function will panic if the data in the `mts_array_t` in this
331    /// `ArrayRefMut` is a different kind of array.
332    #[inline]
333    pub fn as_ndarray_lock<T>(&self) -> &Arc<RwLock<ArrayD<T>>> where T: 'static {
334        self.as_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
335    }
336
337    /// Extract the `Arc<RwLock<ArrayD<T>>>` from this `ArrayRef`, if it
338    /// contains one, keeping the initial lifetime of the `ArrayRef`.
339    ///
340    /// This function will panic if the data in the `mts_array_t` in this
341    /// `ArrayRefMut` is a different kind of array.
342    #[inline]
343    pub fn to_ndarray_lock<T>(self) -> &'a Arc<RwLock<ArrayD<T>>> where T: 'static {
344        self.to_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
345    }
346
347    /// Get a mutable reference to the underlying array, consuming this
348    /// `ArrayRefMut`.
349    ///
350    /// Since this array is already guaranteed to be unique through the mutable
351    /// borrow, we do not need to lock the `RwLock` to get access to the
352    /// `ArrayD`.
353    ///
354    /// This function will panic if the data in the `mts_array_t` in this
355    /// `ArrayRefMut` does not contain an `Arc<RwLock<ArrayD<T>>>`, or if the
356    /// `Arc` already has multiple strong references.
357    #[inline]
358    pub fn get_ndarray_mut<T>(self) -> &'a mut ArrayD<T> where T: 'static {
359        let arc = self.to_any_mut().downcast_mut::<Arc<RwLock<ArrayD<T>>>>().expect("this is not an Arc<RwLock<ArrayD>>");
360        let lock = Arc::get_mut(arc).expect("the outer Arc already has multiple owners");
361        return lock.get_mut().expect("lock was poisoned");
362    }
363
364    /// Get the raw underlying `mts_array_t`
365    pub fn as_raw(&self) -> &mts_array_t {
366        &self.array
367    }
368
369    /// Get a mutable reference to the raw underlying `mts_array_t`
370    pub fn as_raw_mut(&mut self) -> &mut mts_array_t {
371        &mut self.array
372    }
373
374    /// Get the origin of this array.
375    ///
376    /// This corresponds to `mts_array_t.origin`, but with a more convenient API.
377    pub fn origin(&self) -> Result<mts_data_origin_t, Error> {
378        let function = self.array.origin.expect("mts_array_t.origin function is NULL");
379
380        let mut origin = 0;
381        unsafe {
382            check_status(function(self.array.ptr, &mut origin))?;
383        }
384
385        return Ok(origin);
386    }
387
388    /// Get the device of this array.
389    ///
390    /// This corresponds to `mts_array_t.device`, but with a more convenient API.
391    pub fn device(&self) -> Result<DLDevice, Error> {
392        let function = self.array.device.expect("mts_array_t.device function is NULL");
393
394        let mut device = DLDevice::cpu();
395        unsafe {
396            check_status(function(self.array.ptr, &mut device))?;
397        }
398
399        return Ok(device);
400    }
401
402    /// Get the dtype of this array.
403    ///
404    /// This corresponds to `mts_array_t.dtype`, but with a more convenient API.
405    pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
406        let function = self.array.dtype.expect("mts_array_t.dtype function is NULL");
407
408        let mut dtype = dlpk::sys::DLDataType { code: dlpk::sys::DLDataTypeCode::kDLFloat, bits: 0, lanes: 0 };
409        unsafe {
410            check_status(function(self.array.ptr, &mut dtype))?;
411        }
412
413        return Ok(dtype);
414    }
415
416    /// Get a [`dlpk::DLPackTensor`] from this array, if supported by the underlying data.
417    ///
418    /// This corresponds to `mts_array_t.as_dlpack`, but with a more convenient API.
419    pub fn as_dlpack(
420        &self,
421        device: DLDevice,
422        stream: Option<i64>,
423        max_version: dlpk::sys::DLPackVersion,
424    ) -> Result<dlpk::DLPackTensor, Error> {
425        let function = self.array.as_dlpack.expect("mts_array_t.as_dlpack function is NULL");
426
427        let mut tensor = std::ptr::null_mut();
428        let stream_c = stream.as_ref().map_or(std::ptr::null(), |s| s as *const i64);
429
430        unsafe {
431            check_status(function(
432                self.array.ptr,
433                &mut tensor,
434                device,
435                stream_c,
436                max_version
437            ))?;
438        }
439
440        let tensor = unsafe {
441            dlpk::DLPackTensor::from_ptr(tensor)
442        };
443
444        return Ok(tensor);
445    }
446
447    /// Get the shape of this array.
448    ///
449    /// This corresponds to `mts_array_t.shape`, but with a more convenient API.
450    pub fn shape(&self) -> Result<&[usize], Error> {
451        let function = self.array.shape.expect("mts_array_t.shape function is NULL");
452
453        let mut shape = std::ptr::null();
454        let mut shape_count: usize = 0;
455
456        unsafe {
457            check_status(function(self.array.ptr, &mut shape, &mut shape_count))?;
458        }
459
460        if shape_count == 0 {
461            return Ok(&[]);
462        } else {
463            assert!(!shape.is_null());
464            let shape = unsafe {
465                std::slice::from_raw_parts(shape, shape_count)
466            };
467            return Ok(shape);
468        }
469    }
470
471    /// Reshape the data in this array, if supported by the underlying data.
472    ///
473    /// This corresponds to `mts_array_t.reshape`, but with a more convenient API.
474    pub fn reshape(&mut self, shape: &[usize]) -> Result<(), Error> {
475        let function = self.array.reshape.expect("mts_array_t.reshape function is NULL");
476
477        unsafe {
478            check_status(function(self.array.ptr, shape.as_ptr(), shape.len()))?;
479        }
480
481        return Ok(());
482    }
483
484    /// Swap two axes of the data in this array, if supported by the underlying data.
485    ///
486    /// This corresponds to `mts_array_t.swap_axes`, but with a more convenient API.
487    pub fn swap_axes(&mut self, axis_1: usize, axis_2: usize) -> Result<(), Error> {
488        let function = self.array.swap_axes.expect("mts_array_t.swap_axes function is NULL");
489
490        unsafe {
491            check_status(function(self.array.ptr, axis_1, axis_2))?;
492        }
493
494        return Ok(());
495    }
496
497    /// Create a new array with the same options as this one (dtype, device)
498    /// and the given shape, filled with zeros.
499    ///
500    /// This corresponds to `mts_array_t.create`, but with a more convenient API.
501    pub fn create(&self, shape: &[usize], fill_value: ArrayRef<'_>) -> Result<MtsArray, Error> {
502        let function = self.array.create.expect("mts_array_t.create function is NULL");
503
504        let mut new_array = mts_array_t::null();
505        unsafe {
506            check_status(function(
507                self.array.ptr,
508                shape.as_ptr(),
509                shape.len(),
510                *fill_value.as_raw(),
511                &mut new_array
512            ))?;
513        }
514
515        return Ok(MtsArray::from_raw(new_array));
516    }
517
518    /// Copy the data in this array, if supported by the underlying data.
519    ///
520    /// This corresponds to `mts_array_t.copy`, but with a more convenient API.
521    pub fn copy(&self, device: DLDevice) -> Result<MtsArray, Error> {
522        let function = self.array.copy.expect("mts_array_t.copy function is NULL");
523        let mut new_array = mts_array_t::null();
524        unsafe {
525            check_status(function(self.array.ptr, device, &mut new_array))?;
526        }
527
528        return Ok(MtsArray::from_raw(new_array));
529    }
530
531    /// Move the data in this array to another array, if supported by the underlying data.
532    ///
533    /// This corresponds to `mts_array_t.move_data`, but with a more convenient API.
534    pub fn move_data<'input>(
535        &mut self,
536        input: impl Into<ArrayRef<'input>>,
537        moves: &[mts_data_movement_t],
538    ) -> Result<(), Error> {
539        let function = self.array.move_data.expect("mts_array_t.move_data function is NULL");
540
541        let input = input.into();
542        unsafe {
543            check_status(function(
544                self.array.ptr,
545                input.as_raw().ptr,
546                moves.as_ptr(),
547                moves.len(),
548            ))?;
549        }
550
551        return Ok(());
552    }
553}