Skip to main content

metatensor/data/
ndarray_array.rs

1use std::sync::{Arc, RwLock, TryLockError};
2
3use dlpk::sys::{DLDevice, DLPackVersion, DLDataType};
4use dlpk::{DLDataTypeCode, DLPackPointerCast, DLPackTensor, GetDLPackDataType, ReadOnly};
5
6use crate::errors::Error;
7use crate::c_api::mts_data_movement_t;
8
9use super::{Array, MtsArray};
10
11impl<T, D> From<ndarray::Array<T, D>> for MtsArray where
12    T: 'static + Clone + Send + Default + Sync + GetDLPackDataType + DLPackPointerCast,
13    D: ndarray::Dimension
14{
15    fn from(value: ndarray::Array<T, D>) -> Self {
16        let array = Arc::new(RwLock::new(value.into_dyn()));
17        let boxed: Box<dyn Array> = Box::new(array);
18        return MtsArray::from(boxed);
19    }
20}
21
22impl<T> Array for Arc<RwLock<ndarray::ArrayD<T>>>
23where
24    T: 'static + Send + Sync + Clone + Default + GetDLPackDataType + DLPackPointerCast,
25{
26    fn as_any(&self) -> &dyn std::any::Any {
27        self
28    }
29
30    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
31        self
32    }
33
34    fn create(&self, shape: &[usize], fill_value: MtsArray) -> Box<dyn Array> {
35        let cpu_device = DLDevice::cpu();
36        let max_version = DLPackVersion::current();
37        let fill_value_dlpack = fill_value.as_dlpack(cpu_device, None, max_version).expect("failed to extract fill_value as DLPack");
38
39        // Validate fill_value shape from the DLPack tensor directly
40        assert_eq!(fill_value_dlpack.shape(), &[], "fill_value must be a single scalar");
41        assert_eq!(fill_value_dlpack.device(), cpu_device, "fill_value must be on CPU");
42
43        let fill_value_ptr = fill_value_dlpack.data_ptr::<T>().expect("dtype mismatch between array and fill_value");
44        let fill_value_scalar = unsafe { std::ptr::read(fill_value_ptr) };
45
46        let array = ndarray::Array::from_elem(shape, fill_value_scalar);
47        return Box::new(Arc::new(RwLock::new(array)));
48    }
49
50    fn copy(&self, device: DLDevice) -> Box<dyn Array> {
51        assert_eq!(device, DLDevice::cpu(), "Rust ndarray data can only be copied to CPU device");
52        return Box::new(self.clone());
53    }
54
55    fn shape(&self) -> Vec<usize> {
56        match self.try_read() {
57            Ok(lock) => lock.shape().to_vec(),
58            Err(TryLockError::Poisoned(_)) => panic!("array lock is poisoned"),
59            Err(TryLockError::WouldBlock) => panic!("array is already locked"),
60        }
61    }
62
63    fn reshape(&mut self, shape: &[usize]) {
64        let mut lock = match self.try_write() {
65            Ok(lock) => lock,
66            Err(TryLockError::Poisoned(_)) => panic!("array lock is poisoned"),
67            Err(TryLockError::WouldBlock) => panic!("array is already locked"),
68        };
69        let array = std::mem::take(&mut *lock);
70        let array = array.into_shape_clone(shape).expect("invalid shape");
71        let _ = std::mem::replace(&mut *lock, array);
72    }
73
74    fn swap_axes(&mut self, axis_1: usize, axis_2: usize) {
75        let mut lock = match self.try_write() {
76            Ok(lock) => lock,
77            Err(TryLockError::Poisoned(_)) => panic!("array lock is poisoned"),
78            Err(TryLockError::WouldBlock) => panic!("array is already locked"),
79        };
80        lock.swap_axes(axis_1, axis_2);
81    }
82
83    fn move_data(
84        &mut self,
85        input: &dyn Array,
86        movements: &[mts_data_movement_t],
87    ) {
88        use ndarray::{Axis, Slice};
89
90        let input = input.as_any().downcast_ref::<Self>().expect("input must be a ndarray of the same type");
91        let input = match input.try_read() {
92            Ok(lock) => lock,
93            Err(TryLockError::Poisoned(_)) => panic!("input array lock is poisoned"),
94            Err(TryLockError::WouldBlock) => panic!("input array is already locked"),
95        };
96
97        let mut output = match self.try_write() {
98            Ok(lock) => lock,
99            Err(TryLockError::Poisoned(_)) => panic!("output array lock is poisoned"),
100            Err(TryLockError::WouldBlock) => panic!("output array is already locked"),
101        };
102
103        if movements.is_empty() {
104            return;
105        }
106
107        // Check if we can use the optimized path (all moves have same property structure)
108        let first_prop_start_in = movements[0].properties_start_in;
109        let first_prop_start_out = movements[0].properties_start_out;
110        let first_prop_len = movements[0].properties_length;
111
112        let mut constant_properties = true;
113        let mut contiguous_input_samples = true;
114        let mut contiguous_output_samples = true;
115
116        for w in movements.windows(2) {
117            if w[0].properties_start_in != first_prop_start_in ||
118               w[0].properties_start_out != first_prop_start_out ||
119               w[0].properties_length != first_prop_len {
120                constant_properties = false;
121                break;
122            }
123
124            if w[1].sample_in != w[0].sample_in + 1 {
125                contiguous_input_samples = false;
126            }
127
128            if w[1].sample_out != w[0].sample_out + 1 {
129                contiguous_output_samples = false;
130            }
131        }
132
133        if constant_properties {
134            let last = movements.last().unwrap();
135            if last.properties_start_in != first_prop_start_in ||
136               last.properties_start_out != first_prop_start_out ||
137               last.properties_length != first_prop_len {
138                constant_properties = false;
139            }
140        }
141
142        let property_axis = output.shape().len() - 1;
143
144        if constant_properties {
145            let input_slice_info = Slice::from(first_prop_start_in..(first_prop_start_in + first_prop_len));
146            let output_slice_info = Slice::from(first_prop_start_out..(first_prop_start_out + first_prop_len));
147
148            if contiguous_input_samples && contiguous_output_samples {
149                let sample_start_in = movements[0].sample_in;
150                let sample_start_out = movements[0].sample_out;
151                let sample_count = movements.len();
152
153                let input_samples = input.slice_axis(
154                    Axis(0),
155                    Slice::from(sample_start_in..(sample_start_in + sample_count))
156                );
157                let mut output_samples = output.slice_axis_mut(
158                    Axis(0),
159                    Slice::from(sample_start_out..(sample_start_out + sample_count))
160                );
161
162                let value = input_samples.slice_axis(Axis(property_axis), input_slice_info);
163                let mut output_location = output_samples.slice_axis_mut(Axis(property_axis), output_slice_info);
164
165                output_location.assign(&value);
166            } else {
167                for move_item in movements {
168                    let input_sample = input.index_axis(Axis(0), move_item.sample_in);
169                    let mut output_sample = output.index_axis_mut(Axis(0), move_item.sample_out);
170
171                    let value = input_sample.slice_axis(
172                        // property_axis - 1 because we are slicing the sample
173                        // axis out, so the property axis is now one less
174                        Axis(property_axis - 1),
175                        input_slice_info
176                    );
177                    let mut output_location = output_sample.slice_axis_mut(
178                        Axis(property_axis - 1),
179                        output_slice_info
180                    );
181                    output_location.assign(&value);
182                }
183            }
184        } else {
185            // fallback to the general case
186            for move_item in movements {
187                let input_sample = input.index_axis(Axis(0), move_item.sample_in);
188                let mut output_sample = output.index_axis_mut(Axis(0), move_item.sample_out);
189
190                let value = input_sample.slice_axis(
191                    // see above for property_axis - 1 explanation
192                    Axis(property_axis - 1),
193                    Slice::from(move_item.properties_start_in..(move_item.properties_start_in + move_item.properties_length))
194                );
195                let mut output_location = output_sample.slice_axis_mut(
196                    Axis(property_axis - 1),
197                    Slice::from(move_item.properties_start_out..(move_item.properties_start_out + move_item.properties_length))
198                );
199                output_location.assign(&value);
200            }
201        }
202    }
203
204    fn device(&self) -> DLDevice {
205        DLDevice::cpu()
206    }
207
208    fn dtype(&self) -> DLDataType {
209        T::get_dlpack_data_type()
210    }
211
212    fn as_dlpack(
213        &self,
214        device: DLDevice,
215        stream: Option<i64>,
216        max_version: DLPackVersion,
217    ) -> Result<DLPackTensor, Error> {
218        if stream.is_some() {
219            // we only support CPU for now
220            return Err(Error {
221                code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
222                message: "CPU arrays can not be used with a stream".into(),
223            });
224        }
225        let vendored_version = DLPackVersion::current();
226        if max_version.major != vendored_version.major {
227            return Err(Error {
228                code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
229                message: format!(
230                    "invalid `max_version` in ndarray::ArrayD<T>::as_dlpack: \
231                    we got v{}.{}, but we support v{}.{}",
232                    max_version.major, max_version.minor,
233                    vendored_version.major, vendored_version.minor
234                ),
235            });
236        }
237
238        let ndarray_device = DLDevice::cpu();
239
240        if device.device_type != ndarray_device.device_type || device.device_id != ndarray_device.device_id {
241            return Err(Error {
242                code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
243                message: format!(
244                    "Requested DLPack device ({}) does not match array device ({})",
245                    device, ndarray_device
246                ),
247            });
248        }
249
250        let tensor: DLPackTensor = ReadOnly(Arc::clone(self)).try_into().map_err(|e| Error {
251            code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
252            message: format!("failed to convert ndarray to DLPack: {:?}", e),
253        })?;
254
255        Ok(tensor)
256    }
257
258    #[allow(clippy::enum_glob_use)]
259    fn from_dlpack(&self, dlpack_tensor: DLPackTensor) -> Result<Box<dyn Array>, Error> {
260        use DLDataTypeCode::*;
261
262        let dtype = dlpack_tensor.dtype();
263
264        if dtype.lanes != 1 {
265            return Err(Error {
266                code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
267                message: "Only DLPack tensors with lanes == 1 are supported".into(),
268            });
269        }
270
271        let map_error = |e| Error {
272            code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
273            message: format!("failed to convert DLPack to ndarray: {:?}", e),
274        };
275
276        if dtype.code == kDLFloat && dtype.bits == 64 {
277            let array: ndarray::ArrayD<f64> = dlpack_tensor.try_into().map_err(map_error)?;
278            return Ok(Box::new(Arc::new(RwLock::new(array))));
279        } else if dtype.code == kDLFloat && dtype.bits == 32 {
280            let array: ndarray::ArrayD<f32> = dlpack_tensor.try_into().map_err(map_error)?;
281            return Ok(Box::new(Arc::new(RwLock::new(array))));
282        } else if dtype.code == kDLInt && dtype.bits == 8 {
283            let array: ndarray::ArrayD<i8> = dlpack_tensor.try_into().map_err(map_error)?;
284            return Ok(Box::new(Arc::new(RwLock::new(array))));
285        } else if dtype.code == kDLInt && dtype.bits == 16 {
286            let array: ndarray::ArrayD<i16> = dlpack_tensor.try_into().map_err(map_error)?;
287            return Ok(Box::new(Arc::new(RwLock::new(array))));
288        } else if dtype.code == kDLInt && dtype.bits == 32 {
289            let array: ndarray::ArrayD<i32> = dlpack_tensor.try_into().map_err(map_error)?;
290            return Ok(Box::new(Arc::new(RwLock::new(array))));
291        } else if dtype.code == kDLInt && dtype.bits == 64 {
292            let array: ndarray::ArrayD<i64> = dlpack_tensor.try_into().map_err(map_error)?;
293            return Ok(Box::new(Arc::new(RwLock::new(array))));
294        } else if dtype.code == kDLUInt && dtype.bits == 8 {
295            let array: ndarray::ArrayD<u8> = dlpack_tensor.try_into().map_err(map_error)?;
296            return Ok(Box::new(Arc::new(RwLock::new(array))));
297        } else if dtype.code == kDLUInt && dtype.bits == 16 {
298            let array: ndarray::ArrayD<u16> = dlpack_tensor.try_into().map_err(map_error)?;
299            return Ok(Box::new(Arc::new(RwLock::new(array))));
300        } else if dtype.code == kDLUInt && dtype.bits == 32 {
301            let array: ndarray::ArrayD<u32> = dlpack_tensor.try_into().map_err(map_error)?;
302            return Ok(Box::new(Arc::new(RwLock::new(array))));
303        } else if dtype.code == kDLUInt && dtype.bits == 64 {
304            let array: ndarray::ArrayD<u64> = dlpack_tensor.try_into().map_err(map_error)?;
305            return Ok(Box::new(Arc::new(RwLock::new(array))));
306        } else if dtype.code == kDLBool && dtype.bits == 8 {
307            let array: ndarray::ArrayD<bool> = dlpack_tensor.try_into().map_err(map_error)?;
308            return Ok(Box::new(Arc::new(RwLock::new(array))));
309        } else {
310            return Err(Error {
311                code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
312                message: format!("Unsupported DLPack dtype {}", dtype),
313            });
314        }
315    }
316}
317
318#[cfg(test)]
319mod tests {
320    use dlpk::{DLPackPointerCast, GetDLPackDataType, sys::{DLDataTypeCode, DLDevice, DLPackVersion}};
321    use crate::MtsArray;
322
323    #[test]
324    fn ndarray_as_mts_array() {
325        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3, 4]);
326        let mts_array = MtsArray::from(data);
327
328        assert_eq!(mts_array.shape().unwrap(), [2, 3, 4]);
329
330        let fill_value = MtsArray::from(ndarray::Array::from_elem(vec![], 42.0));
331
332        let created = mts_array.create(&[2, 3, 4], fill_value.as_ref()).unwrap();
333        assert_eq!(created.shape().unwrap(), [2, 3, 4]);
334    }
335
336    #[test]
337    fn ndarray_as_mts_array_dlpack() {
338        let data = ndarray::Array::<f64, _>::zeros(vec![4, 5, 6]);
339        let mts_array = MtsArray::from(data);
340
341        let dl_managed = mts_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
342
343        assert_eq!(dl_managed.n_dims(), 3);
344        assert_eq!(dl_managed.shape(), [4, 5, 6]);
345
346        assert_eq!(dl_managed.dtype().code, DLDataTypeCode::kDLFloat);
347        assert_eq!(dl_managed.dtype().bits, 64);
348        assert_eq!(dl_managed.dtype().lanes, 1);
349    }
350
351    #[test]
352    fn ndarray_all_dtypes() {
353        fn test_for_dtype<T>(code: DLDataTypeCode, bits: u8) where T: Send + Sync + Clone + Default + GetDLPackDataType + DLPackPointerCast + 'static {
354            let data = ndarray::Array::<T, _>::from_elem(vec![2, 2], T::default());
355            let mts_array = MtsArray::from(data);
356
357            assert_eq!(mts_array.shape().unwrap(), [2, 2]);
358
359            // Should be able to export as DLPack
360            let dl_managed = mts_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
361            assert_eq!(dl_managed.dtype().code, code);
362            assert_eq!(dl_managed.dtype().bits, bits);
363            assert_eq!(dl_managed.dtype().lanes, 1);
364
365
366            // And `create` should make an array of the same type (i32)
367            let fill_value = MtsArray::from(ndarray::Array::from_elem(vec![], T::default()));
368
369            let created = mts_array.create(&[1, 1], fill_value.as_ref()).unwrap();
370            let dl_managed = created.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
371
372            assert_eq!(dl_managed.dtype().code, code);
373            assert_eq!(dl_managed.dtype().bits, bits);
374            assert_eq!(dl_managed.dtype().lanes, 1);
375        }
376
377        test_for_dtype::<bool>(DLDataTypeCode::kDLBool, 8);
378        test_for_dtype::<f64>(DLDataTypeCode::kDLFloat, 64);
379        test_for_dtype::<f32>(DLDataTypeCode::kDLFloat, 32);
380        test_for_dtype::<i8>(DLDataTypeCode::kDLInt, 8);
381        test_for_dtype::<i16>(DLDataTypeCode::kDLInt, 16);
382        test_for_dtype::<i32>(DLDataTypeCode::kDLInt, 32);
383        test_for_dtype::<i64>(DLDataTypeCode::kDLInt, 64);
384        test_for_dtype::<u8>(DLDataTypeCode::kDLUInt, 8);
385        test_for_dtype::<u16>(DLDataTypeCode::kDLUInt, 16);
386        test_for_dtype::<u32>(DLDataTypeCode::kDLUInt, 32);
387        test_for_dtype::<u64>(DLDataTypeCode::kDLUInt, 64);
388    }
389
390    #[test]
391    fn ndarray_device() {
392        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
393        let mts_array = MtsArray::from(data);
394
395        assert_eq!(mts_array.device().unwrap(), DLDevice::cpu());
396    }
397
398    #[test]
399    fn as_dlpack_rejects_stream() {
400        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
401        let mts_array = MtsArray::from(data);
402        match mts_array.as_dlpack(DLDevice::cpu(), Some(42), DLPackVersion::current()) {
403            Err(e) => assert!(e.message.contains("stream"), "{}", e.message),
404            Ok(_) => panic!("expected error for non-null stream"),
405        }
406    }
407
408    #[test]
409    fn as_dlpack_rejects_wrong_device() {
410        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
411        let mts_array = MtsArray::from(data);
412        let cuda = DLDevice {
413            device_type: dlpk::sys::DLDeviceType::kDLCUDA,
414            device_id: 0,
415        };
416        match mts_array.as_dlpack(cuda, None, DLPackVersion::current()) {
417            Err(e) => assert!(e.message.contains("does not match"), "{}", e.message),
418            Ok(_) => panic!("expected error for CUDA device on CPU array"),
419        }
420    }
421
422    #[test]
423    fn as_dlpack_rejects_incompatible_version() {
424        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
425        let mts_array = MtsArray::from(data);
426
427        let bad_version = DLPackVersion { major: 99, minor: 0 };
428        match mts_array.as_dlpack(DLDevice::cpu(), None, bad_version) {
429            Err(e) => assert!(e.message.contains("version"), "{}", e.message),
430            Ok(_) => panic!("expected error for incompatible DLPack version"),
431        }
432    }
433
434    #[test]
435    #[allow(clippy::float_cmp)]
436    fn from_dlpack() {
437        let mut f64_data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
438        f64_data[[0, 0]] = 1.573;
439        f64_data[[1, 2]] = -42.0;
440        let f64_array = MtsArray::from(f64_data);
441
442        let mut i16_data = ndarray::Array::<i16, _>::zeros(vec![2, 5, 10]);
443        i16_data[[0, 1, 3]] = 3;
444        i16_data[[1, 2, 4]] = -42;
445        let i16_array = MtsArray::from(i16_data);
446
447        let f64_dl_tensor = f64_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
448        let i16_dl_tensor = i16_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
449
450        let new_f64_array = f64_array.from_dlpack(f64_dl_tensor).unwrap();
451        let new_i16_array = i16_array.from_dlpack(i16_dl_tensor).unwrap();
452
453        assert_eq!(f64_array.origin().unwrap(), i16_array.origin().unwrap());
454        assert_eq!(new_f64_array.origin().unwrap(), f64_array.origin().unwrap());
455        assert_eq!(new_i16_array.origin().unwrap(), i16_array.origin().unwrap());
456
457        let new_f64_dl_tensor = new_f64_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
458        let new_i16_dl_tensor = new_i16_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
459
460        let new_f64_data: ndarray::ArrayD<f64> = new_f64_dl_tensor.try_into().unwrap();
461        let new_i16_data: ndarray::ArrayD<i16> = new_i16_dl_tensor.try_into().unwrap();
462
463        assert_eq!(new_f64_data[[0, 0]], 1.573);
464        assert_eq!(new_f64_data[[1, 2]], -42.0);
465
466        assert_eq!(new_i16_data[[0, 1, 3]], 3);
467        assert_eq!(new_i16_data[[1, 2, 4]], -42);
468    }
469}