Skip to main content

metatensor/data/
ndarray_array.rs

1use std::sync::{Arc, RwLock, TryLockError};
2
3use dlpk::sys::{DLDevice, DLPackVersion, DLDataType};
4use dlpk::{DLDataTypeCode, DLPackPointerCast, DLPackTensor, GetDLPackDataType, ReadOnly};
5
6use crate::errors::Error;
7use crate::c_api::mts_data_movement_t;
8
9use super::{Array, MtsArray};
10
11impl<T> From<ndarray::ArrayD<T>> for MtsArray where T: 'static + Clone + Send + Default + Sync + GetDLPackDataType + DLPackPointerCast {
12    fn from(value: ndarray::ArrayD<T>) -> Self {
13        let array = Arc::new(RwLock::new(value));
14        let boxed: Box<dyn Array> = Box::new(array);
15        return MtsArray::from(boxed);
16    }
17}
18
19impl<T> Array for Arc<RwLock<ndarray::ArrayD<T>>>
20where
21    T: 'static + Send + Sync + Clone + Default + GetDLPackDataType + DLPackPointerCast,
22{
23    fn as_any(&self) -> &dyn std::any::Any {
24        self
25    }
26
27    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
28        self
29    }
30
31    fn create(&self, shape: &[usize], fill_value: MtsArray) -> Box<dyn Array> {
32        let cpu_device = DLDevice::cpu();
33        let max_version = DLPackVersion::current();
34        let fill_value_dlpack = fill_value.as_dlpack(cpu_device, None, max_version).expect("failed to extract fill_value as DLPack");
35
36        // Validate fill_value shape from the DLPack tensor directly
37        assert_eq!(fill_value_dlpack.shape(), &[], "fill_value must be a single scalar");
38        assert_eq!(fill_value_dlpack.device(), cpu_device, "fill_value must be on CPU");
39
40        let fill_value_ptr = fill_value_dlpack.data_ptr::<T>().expect("dtype mismatch between array and fill_value");
41        let fill_value_scalar = unsafe { std::ptr::read(fill_value_ptr) };
42
43        let array = ndarray::Array::from_elem(shape, fill_value_scalar);
44        return Box::new(Arc::new(RwLock::new(array)));
45    }
46
47    fn copy(&self, device: DLDevice) -> Box<dyn Array> {
48        assert_eq!(device, DLDevice::cpu(), "Rust ndarray data can only be copied to CPU device");
49        return Box::new(self.clone());
50    }
51
52    fn shape(&self) -> Vec<usize> {
53        match self.try_read() {
54            Ok(lock) => lock.shape().to_vec(),
55            Err(TryLockError::Poisoned(_)) => panic!("array lock is poisoned"),
56            Err(TryLockError::WouldBlock) => panic!("array is already locked"),
57        }
58    }
59
60    fn reshape(&mut self, shape: &[usize]) {
61        let mut lock = match self.try_write() {
62            Ok(lock) => lock,
63            Err(TryLockError::Poisoned(_)) => panic!("array lock is poisoned"),
64            Err(TryLockError::WouldBlock) => panic!("array is already locked"),
65        };
66        let array = std::mem::take(&mut *lock);
67        let array = array.into_shape_clone(shape).expect("invalid shape");
68        let _ = std::mem::replace(&mut *lock, array);
69    }
70
71    fn swap_axes(&mut self, axis_1: usize, axis_2: usize) {
72        let mut lock = match self.try_write() {
73            Ok(lock) => lock,
74            Err(TryLockError::Poisoned(_)) => panic!("array lock is poisoned"),
75            Err(TryLockError::WouldBlock) => panic!("array is already locked"),
76        };
77        lock.swap_axes(axis_1, axis_2);
78    }
79
80    fn move_data(
81        &mut self,
82        input: &dyn Array,
83        movements: &[mts_data_movement_t],
84    ) {
85        use ndarray::{Axis, Slice};
86
87        let input = input.as_any().downcast_ref::<Self>().expect("input must be a ndarray of the same type");
88        let input = match input.try_read() {
89            Ok(lock) => lock,
90            Err(TryLockError::Poisoned(_)) => panic!("input array lock is poisoned"),
91            Err(TryLockError::WouldBlock) => panic!("input array is already locked"),
92        };
93
94        let mut output = match self.try_write() {
95            Ok(lock) => lock,
96            Err(TryLockError::Poisoned(_)) => panic!("output array lock is poisoned"),
97            Err(TryLockError::WouldBlock) => panic!("output array is already locked"),
98        };
99
100        if movements.is_empty() {
101            return;
102        }
103
104        // Check if we can use the optimized path (all moves have same property structure)
105        let first_prop_start_in = movements[0].properties_start_in;
106        let first_prop_start_out = movements[0].properties_start_out;
107        let first_prop_len = movements[0].properties_length;
108
109        let mut constant_properties = true;
110        let mut contiguous_input_samples = true;
111        let mut contiguous_output_samples = true;
112
113        for w in movements.windows(2) {
114            if w[0].properties_start_in != first_prop_start_in ||
115               w[0].properties_start_out != first_prop_start_out ||
116               w[0].properties_length != first_prop_len {
117                constant_properties = false;
118                break;
119            }
120
121            if w[1].sample_in != w[0].sample_in + 1 {
122                contiguous_input_samples = false;
123            }
124
125            if w[1].sample_out != w[0].sample_out + 1 {
126                contiguous_output_samples = false;
127            }
128        }
129
130        if constant_properties {
131            let last = movements.last().unwrap();
132            if last.properties_start_in != first_prop_start_in ||
133               last.properties_start_out != first_prop_start_out ||
134               last.properties_length != first_prop_len {
135                constant_properties = false;
136            }
137        }
138
139        let property_axis = output.shape().len() - 1;
140
141        if constant_properties {
142            let input_slice_info = Slice::from(first_prop_start_in..(first_prop_start_in + first_prop_len));
143            let output_slice_info = Slice::from(first_prop_start_out..(first_prop_start_out + first_prop_len));
144
145            if contiguous_input_samples && contiguous_output_samples {
146                let sample_start_in = movements[0].sample_in;
147                let sample_start_out = movements[0].sample_out;
148                let sample_count = movements.len();
149
150                let input_samples = input.slice_axis(
151                    Axis(0),
152                    Slice::from(sample_start_in..(sample_start_in + sample_count))
153                );
154                let mut output_samples = output.slice_axis_mut(
155                    Axis(0),
156                    Slice::from(sample_start_out..(sample_start_out + sample_count))
157                );
158
159                let value = input_samples.slice_axis(Axis(property_axis), input_slice_info);
160                let mut output_location = output_samples.slice_axis_mut(Axis(property_axis), output_slice_info);
161
162                output_location.assign(&value);
163            } else {
164                for move_item in movements {
165                    let input_sample = input.index_axis(Axis(0), move_item.sample_in);
166                    let mut output_sample = output.index_axis_mut(Axis(0), move_item.sample_out);
167
168                    let value = input_sample.slice_axis(
169                        // property_axis - 1 because we are slicing the sample
170                        // axis out, so the property axis is now one less
171                        Axis(property_axis - 1),
172                        input_slice_info
173                    );
174                    let mut output_location = output_sample.slice_axis_mut(
175                        Axis(property_axis - 1),
176                        output_slice_info
177                    );
178                    output_location.assign(&value);
179                }
180            }
181        } else {
182            // fallback to the general case
183            for move_item in movements {
184                let input_sample = input.index_axis(Axis(0), move_item.sample_in);
185                let mut output_sample = output.index_axis_mut(Axis(0), move_item.sample_out);
186
187                let value = input_sample.slice_axis(
188                    // see above for property_axis - 1 explanation
189                    Axis(property_axis - 1),
190                    Slice::from(move_item.properties_start_in..(move_item.properties_start_in + move_item.properties_length))
191                );
192                let mut output_location = output_sample.slice_axis_mut(
193                    Axis(property_axis - 1),
194                    Slice::from(move_item.properties_start_out..(move_item.properties_start_out + move_item.properties_length))
195                );
196                output_location.assign(&value);
197            }
198        }
199    }
200
201    fn device(&self) -> DLDevice {
202        DLDevice::cpu()
203    }
204
205    fn dtype(&self) -> DLDataType {
206        T::get_dlpack_data_type()
207    }
208
209    fn as_dlpack(
210        &self,
211        device: DLDevice,
212        stream: Option<i64>,
213        max_version: DLPackVersion,
214    ) -> Result<DLPackTensor, Error> {
215        if stream.is_some() {
216            // we only support CPU for now
217            return Err(Error {
218                code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
219                message: "CPU arrays can not be used with a stream".into(),
220            });
221        }
222        let vendored_version = DLPackVersion::current();
223        if max_version.major != vendored_version.major {
224            return Err(Error {
225                code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
226                message: format!(
227                    "invalid `max_version` in ndarray::ArrayD<T>::as_dlpack: \
228                    we got v{}.{}, but we support v{}.{}",
229                    max_version.major, max_version.minor,
230                    vendored_version.major, vendored_version.minor
231                ),
232            });
233        }
234
235        let ndarray_device = DLDevice::cpu();
236
237        if device.device_type != ndarray_device.device_type || device.device_id != ndarray_device.device_id {
238            return Err(Error {
239                code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
240                message: format!(
241                    "Requested DLPack device ({}) does not match array device ({})",
242                    device, ndarray_device
243                ),
244            });
245        }
246
247        let tensor: DLPackTensor = ReadOnly(Arc::clone(self)).try_into().map_err(|e| Error {
248            code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
249            message: format!("failed to convert ndarray to DLPack: {:?}", e),
250        })?;
251
252        Ok(tensor)
253    }
254
255    #[allow(clippy::enum_glob_use)]
256    fn from_dlpack(&self, dlpack_tensor: DLPackTensor) -> Result<Box<dyn Array>, Error> {
257        use DLDataTypeCode::*;
258
259        let dtype = dlpack_tensor.dtype();
260
261        if dtype.lanes != 1 {
262            return Err(Error {
263                code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
264                message: "Only DLPack tensors with lanes == 1 are supported".into(),
265            });
266        }
267
268        let map_error = |e| Error {
269            code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
270            message: format!("failed to convert DLPack to ndarray: {:?}", e),
271        };
272
273        if dtype.code == kDLFloat && dtype.bits == 64 {
274            let array: ndarray::ArrayD<f64> = dlpack_tensor.try_into().map_err(map_error)?;
275            return Ok(Box::new(Arc::new(RwLock::new(array))));
276        } else if dtype.code == kDLFloat && dtype.bits == 32 {
277            let array: ndarray::ArrayD<f32> = dlpack_tensor.try_into().map_err(map_error)?;
278            return Ok(Box::new(Arc::new(RwLock::new(array))));
279        } else if dtype.code == kDLInt && dtype.bits == 8 {
280            let array: ndarray::ArrayD<i8> = dlpack_tensor.try_into().map_err(map_error)?;
281            return Ok(Box::new(Arc::new(RwLock::new(array))));
282        } else if dtype.code == kDLInt && dtype.bits == 16 {
283            let array: ndarray::ArrayD<i16> = dlpack_tensor.try_into().map_err(map_error)?;
284            return Ok(Box::new(Arc::new(RwLock::new(array))));
285        } else if dtype.code == kDLInt && dtype.bits == 32 {
286            let array: ndarray::ArrayD<i32> = dlpack_tensor.try_into().map_err(map_error)?;
287            return Ok(Box::new(Arc::new(RwLock::new(array))));
288        } else if dtype.code == kDLInt && dtype.bits == 64 {
289            let array: ndarray::ArrayD<i64> = dlpack_tensor.try_into().map_err(map_error)?;
290            return Ok(Box::new(Arc::new(RwLock::new(array))));
291        } else if dtype.code == kDLUInt && dtype.bits == 8 {
292            let array: ndarray::ArrayD<u8> = dlpack_tensor.try_into().map_err(map_error)?;
293            return Ok(Box::new(Arc::new(RwLock::new(array))));
294        } else if dtype.code == kDLUInt && dtype.bits == 16 {
295            let array: ndarray::ArrayD<u16> = dlpack_tensor.try_into().map_err(map_error)?;
296            return Ok(Box::new(Arc::new(RwLock::new(array))));
297        } else if dtype.code == kDLUInt && dtype.bits == 32 {
298            let array: ndarray::ArrayD<u32> = dlpack_tensor.try_into().map_err(map_error)?;
299            return Ok(Box::new(Arc::new(RwLock::new(array))));
300        } else if dtype.code == kDLUInt && dtype.bits == 64 {
301            let array: ndarray::ArrayD<u64> = dlpack_tensor.try_into().map_err(map_error)?;
302            return Ok(Box::new(Arc::new(RwLock::new(array))));
303        } else if dtype.code == kDLBool && dtype.bits == 8 {
304            let array: ndarray::ArrayD<bool> = dlpack_tensor.try_into().map_err(map_error)?;
305            return Ok(Box::new(Arc::new(RwLock::new(array))));
306        } else {
307            return Err(Error {
308                code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
309                message: format!("Unsupported DLPack dtype {}", dtype),
310            });
311        }
312    }
313}
314
315#[cfg(test)]
316mod tests {
317    use dlpk::{DLPackPointerCast, GetDLPackDataType, sys::{DLDataTypeCode, DLDevice, DLPackVersion}};
318    use crate::MtsArray;
319
320    #[test]
321    fn ndarray_as_mts_array() {
322        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3, 4]);
323        let mts_array = MtsArray::from(data);
324
325        assert_eq!(mts_array.shape().unwrap(), [2, 3, 4]);
326
327        let fill_value = MtsArray::from(ndarray::Array::from_elem(vec![], 42.0));
328
329        let created = mts_array.create(&[2, 3, 4], fill_value.as_ref()).unwrap();
330        assert_eq!(created.shape().unwrap(), [2, 3, 4]);
331    }
332
333    #[test]
334    fn ndarray_as_mts_array_dlpack() {
335        let data = ndarray::Array::<f64, _>::zeros(vec![4, 5, 6]);
336        let mts_array = MtsArray::from(data);
337
338        let dl_managed = mts_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
339
340        assert_eq!(dl_managed.n_dims(), 3);
341        assert_eq!(dl_managed.shape(), [4, 5, 6]);
342
343        assert_eq!(dl_managed.dtype().code, DLDataTypeCode::kDLFloat);
344        assert_eq!(dl_managed.dtype().bits, 64);
345        assert_eq!(dl_managed.dtype().lanes, 1);
346    }
347
348    #[test]
349    fn ndarray_all_dtypes() {
350        fn test_for_dtype<T>(code: DLDataTypeCode, bits: u8) where T: Send + Sync + Clone + Default + GetDLPackDataType + DLPackPointerCast + 'static {
351            let data = ndarray::Array::<T, _>::from_elem(vec![2, 2], T::default());
352            let mts_array = MtsArray::from(data);
353
354            assert_eq!(mts_array.shape().unwrap(), [2, 2]);
355
356            // Should be able to export as DLPack
357            let dl_managed = mts_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
358            assert_eq!(dl_managed.dtype().code, code);
359            assert_eq!(dl_managed.dtype().bits, bits);
360            assert_eq!(dl_managed.dtype().lanes, 1);
361
362
363            // And `create` should make an array of the same type (i32)
364            let fill_value = MtsArray::from(ndarray::Array::from_elem(vec![], T::default()));
365
366            let created = mts_array.create(&[1, 1], fill_value.as_ref()).unwrap();
367            let dl_managed = created.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
368
369            assert_eq!(dl_managed.dtype().code, code);
370            assert_eq!(dl_managed.dtype().bits, bits);
371            assert_eq!(dl_managed.dtype().lanes, 1);
372        }
373
374        test_for_dtype::<bool>(DLDataTypeCode::kDLBool, 8);
375        test_for_dtype::<f64>(DLDataTypeCode::kDLFloat, 64);
376        test_for_dtype::<f32>(DLDataTypeCode::kDLFloat, 32);
377        test_for_dtype::<i8>(DLDataTypeCode::kDLInt, 8);
378        test_for_dtype::<i16>(DLDataTypeCode::kDLInt, 16);
379        test_for_dtype::<i32>(DLDataTypeCode::kDLInt, 32);
380        test_for_dtype::<i64>(DLDataTypeCode::kDLInt, 64);
381        test_for_dtype::<u8>(DLDataTypeCode::kDLUInt, 8);
382        test_for_dtype::<u16>(DLDataTypeCode::kDLUInt, 16);
383        test_for_dtype::<u32>(DLDataTypeCode::kDLUInt, 32);
384        test_for_dtype::<u64>(DLDataTypeCode::kDLUInt, 64);
385    }
386
387    #[test]
388    fn ndarray_device() {
389        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
390        let mts_array = MtsArray::from(data);
391
392        assert_eq!(mts_array.device().unwrap(), DLDevice::cpu());
393    }
394
395    #[test]
396    fn as_dlpack_rejects_stream() {
397        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
398        let mts_array = MtsArray::from(data);
399        match mts_array.as_dlpack(DLDevice::cpu(), Some(42), DLPackVersion::current()) {
400            Err(e) => assert!(e.message.contains("stream"), "{}", e.message),
401            Ok(_) => panic!("expected error for non-null stream"),
402        }
403    }
404
405    #[test]
406    fn as_dlpack_rejects_wrong_device() {
407        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
408        let mts_array = MtsArray::from(data);
409        let cuda = DLDevice {
410            device_type: dlpk::sys::DLDeviceType::kDLCUDA,
411            device_id: 0,
412        };
413        match mts_array.as_dlpack(cuda, None, DLPackVersion::current()) {
414            Err(e) => assert!(e.message.contains("does not match"), "{}", e.message),
415            Ok(_) => panic!("expected error for CUDA device on CPU array"),
416        }
417    }
418
419    #[test]
420    fn as_dlpack_rejects_incompatible_version() {
421        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
422        let mts_array = MtsArray::from(data);
423
424        let bad_version = DLPackVersion { major: 99, minor: 0 };
425        match mts_array.as_dlpack(DLDevice::cpu(), None, bad_version) {
426            Err(e) => assert!(e.message.contains("version"), "{}", e.message),
427            Ok(_) => panic!("expected error for incompatible DLPack version"),
428        }
429    }
430
431    #[test]
432    #[allow(clippy::float_cmp)]
433    fn from_dlpack() {
434        let mut f64_data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
435        f64_data[[0, 0]] = 1.573;
436        f64_data[[1, 2]] = -42.0;
437        let f64_array = MtsArray::from(f64_data);
438
439        let mut i16_data = ndarray::Array::<i16, _>::zeros(vec![2, 5, 10]);
440        i16_data[[0, 1, 3]] = 3;
441        i16_data[[1, 2, 4]] = -42;
442        let i16_array = MtsArray::from(i16_data);
443
444        let f64_dl_tensor = f64_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
445        let i16_dl_tensor = i16_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
446
447        let new_f64_array = f64_array.from_dlpack(f64_dl_tensor).unwrap();
448        let new_i16_array = i16_array.from_dlpack(i16_dl_tensor).unwrap();
449
450        assert_eq!(f64_array.origin().unwrap(), i16_array.origin().unwrap());
451        assert_eq!(new_f64_array.origin().unwrap(), f64_array.origin().unwrap());
452        assert_eq!(new_i16_array.origin().unwrap(), i16_array.origin().unwrap());
453
454        let new_f64_dl_tensor = new_f64_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
455        let new_i16_dl_tensor = new_i16_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
456
457        let new_f64_data: ndarray::ArrayD<f64> = new_f64_dl_tensor.try_into().unwrap();
458        let new_i16_data: ndarray::ArrayD<i16> = new_i16_dl_tensor.try_into().unwrap();
459
460        assert_eq!(new_f64_data[[0, 0]], 1.573);
461        assert_eq!(new_f64_data[[1, 2]], -42.0);
462
463        assert_eq!(new_i16_data[[0, 1, 3]], 3);
464        assert_eq!(new_i16_data[[1, 2, 4]], -42);
465    }
466}