metatensor/data/
external.rs1use std::sync::{Arc, RwLock, RwLockReadGuard};
2
3use ndarray::ArrayD;
4use dlpk::sys::DLDevice;
5
6use crate::c_api::{mts_array_t, mts_data_origin_t, mts_data_movement_t};
7
8use crate::Error;
9use crate::errors::check_status;
10
11use super::{ArrayRef, ArrayRefMut};
12use super::origin::get_data_origin;
13
14pub struct MtsArray {
18 array: mts_array_t
19}
20
21impl Drop for MtsArray {
22 fn drop(&mut self) {
23 if let Some(destroy) = self.array.destroy {
24 unsafe { destroy(self.array.ptr) }
25 }
26 }
27}
28
29impl MtsArray {
30 pub fn from_raw(array: mts_array_t) -> MtsArray {
33 MtsArray { array }
34 }
35
36 pub fn into_raw(self) -> mts_array_t {
39 let array = self.array;
40 std::mem::forget(self);
43 array
44 }
45
46 #[inline]
51 pub fn as_any(&self) -> &dyn std::any::Any {
52 let origin = self.origin().unwrap_or(0);
53 assert_eq!(
54 origin, *super::array::RUST_DATA_ORIGIN,
55 "this array was not created as a rust Array (origin is '{}')",
56 get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
57 );
58
59 let array = self.array.ptr.cast::<super::array::RustArray>();
60 unsafe {
61 return (*array).as_any();
62 }
63 }
64
65 #[inline]
66 fn as_lock<T>(&self) -> &Arc<RwLock<ArrayD<T>>> where T: 'static {
67 self.as_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
68 }
69
70 #[inline]
73 pub fn as_ndarray<T>(&self) -> RwLockReadGuard<'_, ArrayD<T>> where T: 'static {
74 return self.as_lock().read().expect("lock was poisoned");
75 }
76
77 pub fn as_raw(&self) -> &mts_array_t {
79 &self.array
80 }
81
82 pub fn as_raw_mut(&mut self) -> &mut mts_array_t {
84 &mut self.array
85 }
86
87 pub fn as_ref(&'_ self) -> ArrayRef<'_> {
89 unsafe { ArrayRef::from_raw(self.array) }
90 }
91
92 pub fn as_mut(&'_ mut self) -> ArrayRefMut<'_> {
94 unsafe { ArrayRefMut::from_raw(self.array) }
95 }
96
97 pub fn origin(&self) -> Result<mts_data_origin_t, Error> {
101 let function = self.array.origin.expect("mts_array_t.origin function is NULL");
102
103 let mut origin = 0;
104 unsafe {
105 check_status(function(self.array.ptr, &mut origin))?;
106 }
107
108 return Ok(origin);
109 }
110
111 pub fn device(&self) -> Result<DLDevice, Error> {
115 let function = self.array.device.expect("mts_array_t.device function is NULL");
116
117 let mut device = DLDevice::cpu();
118 unsafe {
119 check_status(function(self.array.ptr, &mut device))?;
120 }
121
122 return Ok(device);
123 }
124
125 pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
129 let function = self.array.dtype.expect("mts_array_t.dtype function is NULL");
130
131 let mut dtype = dlpk::sys::DLDataType { code: dlpk::sys::DLDataTypeCode::kDLFloat, bits: 0, lanes: 0 };
132 unsafe {
133 check_status(function(self.array.ptr, &mut dtype))?;
134 }
135
136 return Ok(dtype);
137 }
138
139 pub fn as_dlpack(
143 &self,
144 device: DLDevice,
145 stream: Option<i64>,
146 max_version: dlpk::sys::DLPackVersion,
147 ) -> Result<dlpk::DLPackTensor, Error> {
148 let function = self.array.as_dlpack.expect("mts_array_t.as_dlpack function is NULL");
149
150 let mut tensor = std::ptr::null_mut();
151 let stream_c = stream.as_ref().map_or(std::ptr::null(), |s| s as *const i64);
152
153 unsafe {
154 check_status(function(self.array.ptr, &mut tensor, device, stream_c, max_version))?;
155 }
156
157 let tensor = unsafe {
158 dlpk::DLPackTensor::from_ptr(tensor)
159 };
160
161 return Ok(tensor);
162 }
163
164 pub fn from_dlpack(&self, dlpack_tensor: dlpk::DLPackTensor) -> Result<MtsArray, Error> {
165 let function = self.array.from_dlpack.expect("mts_array_t.from_dlpack function is NULL");
166
167 let mut new_array = mts_array_t::null();
168 unsafe {
169 check_status(function(self.array.ptr, dlpack_tensor.into_raw().as_ptr(), &mut new_array))?;
170 }
171
172 return Ok(MtsArray::from_raw(new_array));
173 }
174
175 pub fn shape(&self) -> Result<&[usize], Error> {
179 let function = self.array.shape.expect("mts_array_t.shape function is NULL");
180
181 let mut shape = std::ptr::null();
182 let mut shape_count: usize = 0;
183
184 unsafe {
185 check_status(function(self.array.ptr, &mut shape, &mut shape_count))?;
186 }
187
188 if shape_count == 0 {
189 return Ok(&[]);
190 } else {
191 assert!(!shape.is_null());
192 let shape = unsafe {
193 std::slice::from_raw_parts(shape, shape_count)
194 };
195 return Ok(shape);
196 }
197 }
198
199 pub fn reshape(&mut self, shape: &[usize]) -> Result<(), Error> {
203 let function = self.array.reshape.expect("mts_array_t.reshape function is NULL");
204
205 unsafe {
206 check_status(function(self.array.ptr, shape.as_ptr(), shape.len()))?;
207 }
208
209 return Ok(());
210 }
211
212 pub fn swap_axes(&mut self, axis_1: usize, axis_2: usize) -> Result<(), Error> {
216 let function = self.array.swap_axes.expect("mts_array_t.swap_axes function is NULL");
217
218 unsafe {
219 check_status(function(self.array.ptr, axis_1, axis_2))?;
220 }
221
222 return Ok(());
223 }
224
225 pub fn create(&self, shape: &[usize], fill_value: ArrayRef<'_>) -> Result<MtsArray, Error> {
230 let function = self.array.create.expect("mts_array_t.create function is NULL");
231
232 let mut new_array = mts_array_t::null();
233 unsafe {
234 check_status(function(
235 self.array.ptr,
236 shape.as_ptr(),
237 shape.len(),
238 *fill_value.as_raw(),
239 &mut new_array
240 ))?;
241 }
242
243 return Ok(MtsArray::from_raw(new_array));
244 }
245
246 pub fn copy(&self, device: DLDevice) -> Result<MtsArray, Error> {
250 let function = self.array.copy.expect("mts_array_t.copy function is NULL");
251 let mut new_array = mts_array_t::null();
252 unsafe {
253 check_status(function(self.array.ptr, device, &mut new_array))?;
254 }
255
256 return Ok(MtsArray::from_raw(new_array));
257 }
258
259 pub fn move_data<'input>(
263 &mut self,
264 input: impl Into<ArrayRef<'input>>,
265 moves: &[mts_data_movement_t],
266 ) -> Result<(), Error> {
267 let function = self.array.move_data.expect("mts_array_t.move_data function is NULL");
268
269 let input = input.into();
270 unsafe {
271 check_status(function(
272 self.array.ptr,
273 input.as_raw().ptr,
274 moves.as_ptr(),
275 moves.len(),
276 ))?;
277 }
278
279 return Ok(());
280 }
281}
282
283impl<'a> From<&'a MtsArray> for ArrayRef<'a> {
284 fn from(array: &'a MtsArray) -> ArrayRef<'a> {
285 array.as_ref()
286 }
287}
288
289impl<'a> From<&'a mut MtsArray> for ArrayRefMut<'a> {
290 fn from(array: &'a mut MtsArray) -> ArrayRefMut<'a> {
291 array.as_mut()
292 }
293}