Skip to main content

metatensor/data/
external.rs

1use std::sync::{Arc, RwLock, RwLockReadGuard};
2
3use ndarray::ArrayD;
4use dlpk::sys::DLDevice;
5
6use crate::c_api::{mts_array_t, mts_data_origin_t, mts_data_movement_t};
7
8use crate::Error;
9use crate::errors::check_status;
10
11use super::{ArrayRef, ArrayRefMut};
12use super::origin::get_data_origin;
13
14/// Wrapper around `mts_array_t` that provides a more convenient API to use it
15/// in Rust code, and in particular to access the underlying array as an `&dyn
16/// Any` instance where possible.
17pub struct MtsArray {
18    array: mts_array_t
19}
20
21impl Drop for MtsArray {
22    fn drop(&mut self) {
23        if let Some(destroy) = self.array.destroy {
24             unsafe { destroy(self.array.ptr) }
25        }
26    }
27}
28
29impl MtsArray {
30    /// Create a new `MtsArray` from a `mts_array_t`, taking ownership of the
31    /// data.
32    pub fn from_raw(array: mts_array_t) -> MtsArray {
33        MtsArray { array }
34    }
35
36    /// Get the underlying `mts_array_t`, transferring ownership of the data to
37    /// the caller.
38    pub fn into_raw(self) -> mts_array_t {
39        let array = self.array;
40        // since mts_array_t is Copy, we need to forget self to avoid calling
41        // Drop when this function returns
42        std::mem::forget(self);
43        array
44    }
45
46    /// Get the underlying array as an `&dyn Any` instance.
47    ///
48    /// This function panics if the array was not created though this crate and
49    /// the [`crate::Array`] trait.
50    #[inline]
51    pub fn as_any(&self) -> &dyn std::any::Any {
52        let origin = self.origin().unwrap_or(0);
53        assert_eq!(
54            origin, *super::array::RUST_DATA_ORIGIN,
55            "this array was not created as a rust Array (origin is '{}')",
56            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
57        );
58
59        let array = self.array.ptr.cast::<super::array::RustArray>();
60        unsafe {
61            return (*array).as_any();
62        }
63    }
64
65    #[inline]
66    fn as_lock<T>(&self) -> &Arc<RwLock<ArrayD<T>>> where T: 'static {
67        self.as_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
68    }
69
70    /// Get the data in this `ArrayRef` as a `ndarray::ArcArray`. This function
71    /// will panic if the data in this `mts_array_t` is not a `ndarray::ArcArray`.
72    #[inline]
73    pub fn as_ndarray<T>(&self) -> RwLockReadGuard<'_, ArrayD<T>> where T: 'static {
74        return self.as_lock().read().expect("lock was poisoned");
75    }
76
77    /// Get the underlying `mts_array_t`.
78    pub fn as_raw(&self) -> &mts_array_t {
79        &self.array
80    }
81
82    /// Get the underlying `mts_array_t` as a mutable reference.
83    pub fn as_raw_mut(&mut self) -> &mut mts_array_t {
84        &mut self.array
85    }
86
87    /// Get a reference to this array
88    pub fn as_ref(&'_ self) -> ArrayRef<'_> {
89        unsafe { ArrayRef::from_raw(self.array) }
90    }
91
92    /// Get a mutable reference to this array
93    pub fn as_mut(&'_ mut self) -> ArrayRefMut<'_> {
94        unsafe { ArrayRefMut::from_raw(self.array) }
95    }
96
97    /// Get the origin of this array.
98    ///
99    /// This corresponds to `mts_array_t.origin`, but with a more convenient API.
100    pub fn origin(&self) -> Result<mts_data_origin_t, Error> {
101        let function = self.array.origin.expect("mts_array_t.origin function is NULL");
102
103        let mut origin = 0;
104        unsafe {
105            check_status(function(self.array.ptr, &mut origin))?;
106        }
107
108        return Ok(origin);
109    }
110
111    /// Get the device of this array.
112    ///
113    /// This corresponds to `mts_array_t.device`, but with a more convenient API.
114    pub fn device(&self) -> Result<DLDevice, Error> {
115        let function = self.array.device.expect("mts_array_t.device function is NULL");
116
117        let mut device = DLDevice::cpu();
118        unsafe {
119            check_status(function(self.array.ptr, &mut device))?;
120        }
121
122        return Ok(device);
123    }
124
125    /// Get the dtype of this array.
126    ///
127    /// This corresponds to `mts_array_t.dtype`, but with a more convenient API.
128    pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
129        let function = self.array.dtype.expect("mts_array_t.dtype function is NULL");
130
131        let mut dtype = dlpk::sys::DLDataType { code: dlpk::sys::DLDataTypeCode::kDLFloat, bits: 0, lanes: 0 };
132        unsafe {
133            check_status(function(self.array.ptr, &mut dtype))?;
134        }
135
136        return Ok(dtype);
137    }
138
139    /// Get a [`dlpk::DLPackTensor`] from this array, if supported by the underlying data.
140    ///
141    /// This corresponds to `mts_array_t.as_dlpack`, but with a more convenient API.
142    pub fn as_dlpack(
143        &self,
144        device: DLDevice,
145        stream: Option<i64>,
146        max_version: dlpk::sys::DLPackVersion,
147    ) -> Result<dlpk::DLPackTensor, Error> {
148        let function = self.array.as_dlpack.expect("mts_array_t.as_dlpack function is NULL");
149
150        let mut tensor = std::ptr::null_mut();
151        let stream_c = stream.as_ref().map_or(std::ptr::null(), |s| s as *const i64);
152
153        unsafe {
154            check_status(function(self.array.ptr, &mut tensor, device, stream_c, max_version))?;
155        }
156
157        let tensor = unsafe {
158            dlpk::DLPackTensor::from_ptr(tensor)
159        };
160
161        return Ok(tensor);
162    }
163
164    pub fn from_dlpack(&self, dlpack_tensor: dlpk::DLPackTensor) -> Result<MtsArray, Error> {
165        let function = self.array.from_dlpack.expect("mts_array_t.from_dlpack function is NULL");
166
167        let mut new_array = mts_array_t::null();
168        unsafe {
169            check_status(function(self.array.ptr, dlpack_tensor.into_raw().as_ptr(), &mut new_array))?;
170        }
171
172        return Ok(MtsArray::from_raw(new_array));
173    }
174
175    /// Get the shape of this array.
176    ///
177    /// This corresponds to `mts_array_t.shape`, but with a more convenient API.
178    pub fn shape(&self) -> Result<&[usize], Error> {
179        let function = self.array.shape.expect("mts_array_t.shape function is NULL");
180
181        let mut shape = std::ptr::null();
182        let mut shape_count: usize = 0;
183
184        unsafe {
185            check_status(function(self.array.ptr, &mut shape, &mut shape_count))?;
186        }
187
188        if shape_count == 0 {
189            return Ok(&[]);
190        } else {
191            assert!(!shape.is_null());
192            let shape = unsafe {
193                std::slice::from_raw_parts(shape, shape_count)
194            };
195            return Ok(shape);
196        }
197    }
198
199    /// Reshape the data in this array, if supported by the underlying data.
200    ///
201    /// This corresponds to `mts_array_t.reshape`, but with a more convenient API.
202    pub fn reshape(&mut self, shape: &[usize]) -> Result<(), Error> {
203        let function = self.array.reshape.expect("mts_array_t.reshape function is NULL");
204
205        unsafe {
206            check_status(function(self.array.ptr, shape.as_ptr(), shape.len()))?;
207        }
208
209        return Ok(());
210    }
211
212    /// Swap two axes of the data in this array, if supported by the underlying data.
213    ///
214    /// This corresponds to `mts_array_t.swap_axes`, but with a more convenient API.
215    pub fn swap_axes(&mut self, axis_1: usize, axis_2: usize) -> Result<(), Error> {
216        let function = self.array.swap_axes.expect("mts_array_t.swap_axes function is NULL");
217
218        unsafe {
219            check_status(function(self.array.ptr, axis_1, axis_2))?;
220        }
221
222        return Ok(());
223    }
224
225    /// Create a new array with the same options as this one (dtype, device)
226    /// and the given shape, filled with zeros.
227    ///
228    /// This corresponds to `mts_array_t.create`, but with a more convenient API.
229    pub fn create(&self, shape: &[usize], fill_value: ArrayRef<'_>) -> Result<MtsArray, Error> {
230        let function = self.array.create.expect("mts_array_t.create function is NULL");
231
232        let mut new_array = mts_array_t::null();
233        unsafe {
234            check_status(function(
235                self.array.ptr,
236                shape.as_ptr(),
237                shape.len(),
238                *fill_value.as_raw(),
239                &mut new_array
240            ))?;
241        }
242
243        return Ok(MtsArray::from_raw(new_array));
244    }
245
246    /// Copy the data in this array, if supported by the underlying data.
247    ///
248    /// This corresponds to `mts_array_t.copy`, but with a more convenient API.
249    pub fn copy(&self, device: DLDevice) -> Result<MtsArray, Error> {
250        let function = self.array.copy.expect("mts_array_t.copy function is NULL");
251        let mut new_array = mts_array_t::null();
252        unsafe {
253            check_status(function(self.array.ptr, device, &mut new_array))?;
254        }
255
256        return Ok(MtsArray::from_raw(new_array));
257    }
258
259    /// Move the data in this array to another array, if supported by the underlying data.
260    ///
261    /// This corresponds to `mts_array_t.move_data`, but with a more convenient API.
262    pub fn move_data<'input>(
263        &mut self,
264        input: impl Into<ArrayRef<'input>>,
265        moves: &[mts_data_movement_t],
266    ) -> Result<(), Error> {
267        let function = self.array.move_data.expect("mts_array_t.move_data function is NULL");
268
269        let input = input.into();
270        unsafe {
271            check_status(function(
272                self.array.ptr,
273                input.as_raw().ptr,
274                moves.as_ptr(),
275                moves.len(),
276            ))?;
277        }
278
279        return Ok(());
280    }
281}
282
283impl<'a> From<&'a MtsArray> for ArrayRef<'a> {
284    fn from(array: &'a MtsArray) -> ArrayRef<'a> {
285        array.as_ref()
286    }
287}
288
289impl<'a> From<&'a mut MtsArray> for ArrayRefMut<'a> {
290    fn from(array: &'a mut MtsArray) -> ArrayRefMut<'a> {
291        array.as_mut()
292    }
293}