Skip to main content

metatensor/data/
array_ref.rs

1use std::ptr::NonNull;
2use std::sync::{Arc, RwLock};
3
4use ndarray::ArrayD;
5
6use dlpk::sys::DLDevice;
7
8use crate::c_api::{mts_array_t, mts_data_origin_t, mts_data_movement_t};
9use crate::Error;
10
11use super::external::{check_status_external, MtsArray};
12use super::origin::get_data_origin;
13
14/// Reference to a data array in metatensor-core
15///
16/// The data array can come from any origin, this struct provides facilities to
17/// access data that was created through the [`Array`] trait, and in particular
18/// as `ndarray::ArrayD` instances.
19#[derive(Debug, Clone, Copy)]
20pub struct ArrayRef<'a> {
21    array: mts_array_t,
22    /// `ArrayRef` should behave like `&'a mts_array_t`
23    marker: std::marker::PhantomData<&'a mts_array_t>,
24}
25
26impl<'a> ArrayRef<'a> {
27    /// Create a new `ArrayRef` from the given raw `mts_array_t`
28    ///
29    /// This is a **VERY** unsafe function, creating a lifetime out of thin air.
30    /// Make sure the lifetime is actually constrained by the lifetime of the
31    /// owner of this `mts_array_t`.
32    pub unsafe fn from_raw(array: mts_array_t) -> ArrayRef<'a> {
33        ArrayRef {
34            array: mts_array_t {
35                // remove the destructor if any, since we only have a reference
36                // to the array, and it should not be dropped when passed back
37                // to C through `as_raw()`.
38                destroy: None,
39                ..array
40            },
41            marker: std::marker::PhantomData,
42        }
43    }
44
45    /// Get the underlying array as an `&dyn Any` instance.
46    ///
47    /// This function panics if the array was not created though this crate and
48    /// the [`Array`] trait.
49    #[inline]
50    pub fn as_any(&self) -> &dyn std::any::Any {
51        let origin = self.origin().unwrap_or(0);
52        assert_eq!(
53            origin, *super::array::RUST_DATA_ORIGIN,
54            "this array was not created as a rust Array (origin is '{}')",
55            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
56        );
57
58        let array = self.array.ptr.cast::<super::array::RustArray>();
59        unsafe {
60            return (*array).as_any();
61        }
62    }
63
64    /// Get a reference to the underlying array as an `&dyn Any` instance,
65    /// re-using the same lifetime as the `ArrayRef`.
66    ///
67    /// This function panics if the array was not created though this crate and
68    /// the [`Array`] trait.
69    #[inline]
70    pub fn to_any(self) -> &'a dyn std::any::Any {
71        let origin = self.origin().unwrap_or(0);
72        assert_eq!(
73            origin, *super::array::RUST_DATA_ORIGIN,
74            "this array was not created as a rust Array (origin is '{}')",
75            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
76        );
77
78        let array = self.array.ptr.cast::<super::array::RustArray>();
79        unsafe {
80            return (*array).as_any();
81        }
82    }
83
84    /// Extract the `Arc<RwLock<ArrayD<T>>>` from this `ArrayRef`, if it
85    /// contains one.
86    ///
87    /// This function will panic if the data in the `mts_array_t` in this
88    /// `ArrayRef` is a different kind of array.
89    #[inline]
90    pub fn as_ndarray_lock<T>(&self) -> &Arc<RwLock<ArrayD<T>>> where T: 'static {
91        self.as_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
92    }
93
94    /// Extract the `Arc<RwLock<ArrayD<T>>>` from this `ArrayRef`, if it
95    /// contains one, keeping the initial lifetime of the `ArrayRef`.
96    ///
97    /// This function will panic if the data in the `mts_array_t` in this
98    /// `ArrayRef` is a different kind of array.
99    #[inline]
100    pub fn to_ndarray_lock<T>(self) -> &'a Arc<RwLock<ArrayD<T>>> where T: 'static {
101        self.to_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
102    }
103
104    /// Get the raw underlying `mts_array_t`
105    pub fn as_raw(&self) -> &mts_array_t {
106        &self.array
107    }
108
109    /// Get the origin of this array.
110    ///
111    /// This corresponds to `mts_array_t.origin`, but with a more convenient API.
112    pub fn origin(&self) -> Result<mts_data_origin_t, Error> {
113        let function = self.array.origin.expect("mts_array_t.origin function is NULL");
114
115        let mut origin = 0;
116        unsafe {
117            check_status_external(
118                function(self.array.ptr, &mut origin),
119                "mts_array_t.origin",
120            )?;
121        }
122
123        return Ok(origin);
124    }
125
126    /// Get the device of this array.
127    ///
128    /// This corresponds to `mts_array_t.device`, but with a more convenient API.
129    pub fn device(&self) -> Result<DLDevice, Error> {
130        let function = self.array.device.expect("mts_array_t.device function is NULL");
131
132        let mut device = DLDevice::cpu();
133        unsafe {
134            check_status_external(
135                function(self.array.ptr, &mut device),
136                "mts_array_t.device",
137            )?;
138        }
139
140        return Ok(device);
141    }
142
143    /// Get the dtype of this array.
144    ///
145    /// This corresponds to `mts_array_t.dtype`, but with a more convenient API.
146    pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
147        let function = self.array.dtype.expect("mts_array_t.dtype function is NULL");
148
149        let mut dtype = dlpk::sys::DLDataType { code: dlpk::sys::DLDataTypeCode::kDLFloat, bits: 0, lanes: 0 };
150        unsafe {
151            check_status_external(
152                function(self.array.ptr, &mut dtype),
153                "mts_array_t.dtype",
154            )?;
155        }
156
157        return Ok(dtype);
158    }
159
160    /// Get a [`dlpk::DLPackTensor`] from this array, if supported by the underlying data.
161    ///
162    /// This corresponds to `mts_array_t.as_dlpack`, but with a more convenient API.
163    pub fn as_dlpack(
164        &self,
165        device: DLDevice,
166        stream: Option<i64>,
167        max_version: dlpk::sys::DLPackVersion,
168    ) -> Result<dlpk::DLPackTensor, Error> {
169        let function = self.array.as_dlpack.expect("mts_array_t.as_dlpack function is NULL");
170
171        let mut tensor = std::ptr::null_mut();
172        let stream_c = stream.as_ref().map_or(std::ptr::null(), |s| s as *const i64);
173
174        unsafe {
175            check_status_external(
176                function(self.array.ptr, &mut tensor, device, stream_c, max_version),
177                "mts_array_t.as_dlpack",
178            )?;
179        }
180
181        let tensor = NonNull::new(tensor).expect("got a NULL DLManagedTensorVersioned from `as_dlpack`");
182        let tensor = unsafe {
183            dlpk::DLPackTensor::from_ptr(tensor)
184        };
185
186        return Ok(tensor);
187    }
188
189    /// Get the shape of this array.
190    ///
191    /// This corresponds to `mts_array_t.shape`, but with a more convenient API.
192    pub fn shape(&self) -> Result<&[usize], Error> {
193        let function = self.array.shape.expect("mts_array_t.shape function is NULL");
194
195        let mut shape = std::ptr::null();
196        let mut shape_count: usize = 0;
197
198        unsafe {
199            check_status_external(
200                function(self.array.ptr, &mut shape, &mut shape_count),
201                "mts_array_t.shape"
202            )?;
203        }
204
205        if shape_count == 0 {
206            return Ok(&[]);
207        } else {
208            assert!(!shape.is_null());
209            let shape = unsafe {
210                std::slice::from_raw_parts(shape, shape_count)
211            };
212            return Ok(shape);
213        }
214    }
215
216    /// Create a new array with the same options as this one (dtype, device)
217    /// and the given shape, filled with zeros.
218    ///
219    /// This corresponds to `mts_array_t.create`, but with a more convenient API.
220    pub fn create(&self, shape: &[usize], fill_value: ArrayRef<'_>) -> Result<MtsArray, Error> {
221        let function = self.array.create.expect("mts_array_t.create function is NULL");
222
223        let mut new_array = mts_array_t::null();
224        unsafe {
225            check_status_external(
226                function(
227                    self.array.ptr,
228                    shape.as_ptr(),
229                    shape.len(),
230                    *fill_value.as_raw(),
231                    &mut new_array
232                ),
233                "mts_array_t.create",
234            )?;
235        }
236
237        return Ok(MtsArray::from_raw(new_array));
238    }
239
240    /// Copy the data in this array, if supported by the underlying data.
241    ///
242    /// This corresponds to `mts_array_t.copy`, but with a more convenient API.
243    pub fn copy(&self) -> Result<MtsArray, Error> {
244        let function = self.array.copy.expect("mts_array_t.copy function is NULL");
245        let mut new_array = mts_array_t::null();
246        unsafe {
247            check_status_external(
248                function(self.array.ptr, &mut new_array),
249                "mts_array_t.copy",
250            )?;
251        }
252
253        return Ok(MtsArray::from_raw(new_array));
254    }
255}
256
257/// Mutable reference to a data array in metatensor-core
258///
259/// The data array can come from any origin, this struct provides facilities to
260/// access data that was created through the [`Array`] trait, and in particular
261/// as `ndarray::ArrayD` instances.
262#[derive(Debug)]
263pub struct ArrayRefMut<'a> {
264    array: mts_array_t,
265    /// `ArrayRefMut` should behave like `&'a mut mts_array_t`
266    marker: std::marker::PhantomData<&'a mut mts_array_t>,
267}
268
269impl<'a> ArrayRefMut<'a> {
270    /// Create a new `ArrayRefMut` from the given raw `mts_array_t`
271    ///
272    /// This is a **VERY** unsafe function, creating a lifetime out of thin air,
273    /// and allowing mutable access to the `mts_array_t`. Make sure the lifetime
274    /// is actually constrained by the lifetime of the owner of this
275    /// `mts_array_t`; and that the owner is mutably borrowed by this
276    /// `ArrayRefMut`.
277    #[inline]
278    pub unsafe fn from_raw(array: mts_array_t) -> ArrayRefMut<'a> {
279        ArrayRefMut {
280            array: mts_array_t {
281                // remove the destructor if any, since we only have a reference
282                // to the array, and it should not be dropped when passed back
283                // to C through `as_raw()`.
284                destroy: None,
285                ..array
286            },
287            marker: std::marker::PhantomData,
288        }
289    }
290
291    /// Get the underlying array as an `&dyn Any` instance.
292    ///
293    /// This function panics if the array was not created though this crate and
294    /// the [`Array`] trait.
295    #[inline]
296    pub fn as_any(&self) -> &dyn std::any::Any {
297        let origin = self.origin().unwrap_or(0);
298        assert_eq!(
299            origin, *super::array::RUST_DATA_ORIGIN,
300            "this array was not created as a rust Array (origin is '{}')",
301            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
302        );
303
304        let array = self.array.ptr.cast::<super::array::RustArray>();
305        unsafe {
306            return (*array).as_any();
307        }
308    }
309
310    /// Get a reference to the underlying array as an `&dyn Any` instance,
311    /// re-using the same lifetime as the `ArrayRefMut`.
312    ///
313    /// This function panics if the array was not created though this crate and
314    /// the [`Array`] trait.
315    #[inline]
316    pub fn to_any(self) -> &'a dyn std::any::Any {
317        let origin = self.origin().unwrap_or(0);
318        assert_eq!(
319            origin, *super::array::RUST_DATA_ORIGIN,
320            "this array was not created as a rust Array (origin is '{}')",
321            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
322        );
323
324        let array = self.array.ptr.cast::<super::array::RustArray>();
325        unsafe {
326            return (*array).as_any();
327        }
328    }
329
330    /// Get the underlying array as an `&mut dyn Any` instance.
331    ///
332    /// This function panics if the array was not created though this crate and
333    /// the [`Array`] trait.
334    #[inline]
335    pub fn to_any_mut(self) -> &'a mut dyn std::any::Any {
336        let origin = self.origin().unwrap_or(0);
337        assert_eq!(
338            origin, *super::array::RUST_DATA_ORIGIN,
339            "this array was not created as a rust Array (origin is '{}')",
340            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
341        );
342
343        let array = self.array.ptr.cast::<super::array::RustArray>();
344        unsafe {
345            return (*array).as_any_mut();
346        }
347    }
348
349    /// Extract the `Arc<RwLock<ArrayD<T>>>` from this `ArrayRef`, if it
350    /// contains one.
351    ///
352    /// This function will panic if the data in the `mts_array_t` in this
353    /// `ArrayRefMut` is a different kind of array.
354    #[inline]
355    pub fn as_ndarray_lock<T>(&self) -> &Arc<RwLock<ArrayD<T>>> where T: 'static {
356        self.as_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
357    }
358
359    /// Extract the `Arc<RwLock<ArrayD<T>>>` from this `ArrayRef`, if it
360    /// contains one, keeping the initial lifetime of the `ArrayRef`.
361    ///
362    /// This function will panic if the data in the `mts_array_t` in this
363    /// `ArrayRefMut` is a different kind of array.
364    #[inline]
365    pub fn to_ndarray_lock<T>(self) -> &'a Arc<RwLock<ArrayD<T>>> where T: 'static {
366        self.to_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
367    }
368
369    /// Get a mutable reference to the underlying array, consuming this
370    /// `ArrayRefMut`.
371    ///
372    /// Since this array is already guaranteed to be unique through the mutable
373    /// borrow, we do not need to lock the `RwLock` to get access to the
374    /// `ArrayD`.
375    ///
376    /// This function will panic if the data in the `mts_array_t` in this
377    /// `ArrayRefMut` does not contain an `Arc<RwLock<ArrayD<T>>>`, or if the
378    /// `Arc` already has multiple strong references.
379    #[inline]
380    pub fn get_ndarray_mut<T>(self) -> &'a mut ArrayD<T> where T: 'static {
381        let arc = self.to_any_mut().downcast_mut::<Arc<RwLock<ArrayD<T>>>>().expect("this is not an Arc<RwLock<ArrayD>>");
382        let lock = Arc::get_mut(arc).expect("the outer Arc already has multiple owners");
383        return lock.get_mut().expect("lock was poisoned");
384    }
385
386    /// Get the raw underlying `mts_array_t`
387    pub fn as_raw(&self) -> &mts_array_t {
388        &self.array
389    }
390
391    /// Get a mutable reference to the raw underlying `mts_array_t`
392    pub fn as_raw_mut(&mut self) -> &mut mts_array_t {
393        &mut self.array
394    }
395
396    /// Get the origin of this array.
397    ///
398    /// This corresponds to `mts_array_t.origin`, but with a more convenient API.
399    pub fn origin(&self) -> Result<mts_data_origin_t, Error> {
400        let function = self.array.origin.expect("mts_array_t.origin function is NULL");
401
402        let mut origin = 0;
403        unsafe {
404            check_status_external(
405                function(self.array.ptr, &mut origin),
406                "mts_array_t.origin",
407            )?;
408        }
409
410        return Ok(origin);
411    }
412
413    /// Get the device of this array.
414    ///
415    /// This corresponds to `mts_array_t.device`, but with a more convenient API.
416    pub fn device(&self) -> Result<DLDevice, Error> {
417        let function = self.array.device.expect("mts_array_t.device function is NULL");
418
419        let mut device = DLDevice::cpu();
420        unsafe {
421            check_status_external(
422                function(self.array.ptr, &mut device),
423                "mts_array_t.device",
424            )?;
425        }
426
427        return Ok(device);
428    }
429
430    /// Get the dtype of this array.
431    ///
432    /// This corresponds to `mts_array_t.dtype`, but with a more convenient API.
433    pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
434        let function = self.array.dtype.expect("mts_array_t.dtype function is NULL");
435
436        let mut dtype = dlpk::sys::DLDataType { code: dlpk::sys::DLDataTypeCode::kDLFloat, bits: 0, lanes: 0 };
437        unsafe {
438            check_status_external(
439                function(self.array.ptr, &mut dtype),
440                "mts_array_t.dtype",
441            )?;
442        }
443
444        return Ok(dtype);
445    }
446
447    /// Get a [`dlpk::DLPackTensor`] from this array, if supported by the underlying data.
448    ///
449    /// This corresponds to `mts_array_t.as_dlpack`, but with a more convenient API.
450    pub fn as_dlpack(
451        &self,
452        device: DLDevice,
453        stream: Option<i64>,
454        max_version: dlpk::sys::DLPackVersion,
455    ) -> Result<dlpk::DLPackTensor, Error> {
456        let function = self.array.as_dlpack.expect("mts_array_t.as_dlpack function is NULL");
457
458        let mut tensor = std::ptr::null_mut();
459        let stream_c = stream.as_ref().map_or(std::ptr::null(), |s| s as *const i64);
460
461        unsafe {
462            check_status_external(
463                function(self.array.ptr, &mut tensor, device, stream_c, max_version),
464                "mts_array_t.as_dlpack",
465            )?;
466        }
467
468        let tensor = NonNull::new(tensor).expect("got a NULL DLManagedTensorVersioned from `as_dlpack`");
469        let tensor = unsafe {
470            dlpk::DLPackTensor::from_ptr(tensor)
471        };
472
473        return Ok(tensor);
474    }
475
476    /// Get the shape of this array.
477    ///
478    /// This corresponds to `mts_array_t.shape`, but with a more convenient API.
479    pub fn shape(&self) -> Result<&[usize], Error> {
480        let function = self.array.shape.expect("mts_array_t.shape function is NULL");
481
482        let mut shape = std::ptr::null();
483        let mut shape_count: usize = 0;
484
485        unsafe {
486            check_status_external(
487                function(self.array.ptr, &mut shape, &mut shape_count),
488                "mts_array_t.shape"
489            )?;
490        }
491
492        if shape_count == 0 {
493            return Ok(&[]);
494        } else {
495            assert!(!shape.is_null());
496            let shape = unsafe {
497                std::slice::from_raw_parts(shape, shape_count)
498            };
499            return Ok(shape);
500        }
501    }
502
503    /// Reshape the data in this array, if supported by the underlying data.
504    ///
505    /// This corresponds to `mts_array_t.reshape`, but with a more convenient API.
506    pub fn reshape(&mut self, shape: &[usize]) -> Result<(), Error> {
507        let function = self.array.reshape.expect("mts_array_t.reshape function is NULL");
508
509        unsafe {
510            check_status_external(
511                function(self.array.ptr, shape.as_ptr(), shape.len()),
512                "mts_array_t.reshape",
513            )?;
514        }
515
516        return Ok(());
517    }
518
519    /// Swap two axes of the data in this array, if supported by the underlying data.
520    ///
521    /// This corresponds to `mts_array_t.swap_axes`, but with a more convenient API.
522    pub fn swap_axes(&mut self, axis_1: usize, axis_2: usize) -> Result<(), Error> {
523        let function = self.array.swap_axes.expect("mts_array_t.swap_axes function is NULL");
524
525        unsafe {
526            check_status_external(
527                function(self.array.ptr, axis_1, axis_2),
528                "mts_array_t.swap_axes",
529            )?;
530        }
531
532        return Ok(());
533    }
534
535    /// Create a new array with the same options as this one (dtype, device)
536    /// and the given shape, filled with zeros.
537    ///
538    /// This corresponds to `mts_array_t.create`, but with a more convenient API.
539    pub fn create(&self, shape: &[usize], fill_value: ArrayRef<'_>) -> Result<MtsArray, Error> {
540        let function = self.array.create.expect("mts_array_t.create function is NULL");
541
542        let mut new_array = mts_array_t::null();
543        unsafe {
544            check_status_external(
545                function(
546                    self.array.ptr,
547                    shape.as_ptr(),
548                    shape.len(),
549                    *fill_value.as_raw(),
550                    &mut new_array
551                ),
552                "mts_array_t.create",
553            )?;
554        }
555
556        return Ok(MtsArray::from_raw(new_array));
557    }
558
559    /// Copy the data in this array, if supported by the underlying data.
560    ///
561    /// This corresponds to `mts_array_t.copy`, but with a more convenient API.
562    pub fn copy(&self) -> Result<MtsArray, Error> {
563        let function = self.array.copy.expect("mts_array_t.copy function is NULL");
564        let mut new_array = mts_array_t::null();
565        unsafe {
566            check_status_external(
567                function(self.array.ptr, &mut new_array),
568                "mts_array_t.copy",
569            )?;
570        }
571
572        return Ok(MtsArray::from_raw(new_array));
573    }
574
575    /// Move the data in this array to another array, if supported by the underlying data.
576    ///
577    /// This corresponds to `mts_array_t.move_data`, but with a more convenient API.
578    pub fn move_data<'input>(
579        &mut self,
580        input: impl Into<ArrayRef<'input>>,
581        moves: &[mts_data_movement_t],
582    ) -> Result<(), Error> {
583        let function = self.array.move_data.expect("mts_array_t.move_data function is NULL");
584
585        let input = input.into();
586        unsafe {
587            check_status_external(
588                function(self.array.ptr, input.as_raw().ptr, moves.as_ptr(), moves.len()),
589                "mts_array_t.move_data",
590            )?;
591        }
592
593        return Ok(());
594    }
595}