1use std::sync::{Arc, RwLock, TryLockError};
2
3use dlpk::sys::{DLDevice, DLPackVersion, DLDataType};
4use dlpk::{DLDataTypeCode, DLPackPointerCast, DLPackTensor, GetDLPackDataType, ReadOnly};
5
6use crate::errors::Error;
7use crate::c_api::mts_data_movement_t;
8
9use super::{Array, MtsArray};
10
11impl<T> From<ndarray::ArrayD<T>> for MtsArray where T: 'static + Clone + Send + Default + Sync + GetDLPackDataType + DLPackPointerCast {
12 fn from(value: ndarray::ArrayD<T>) -> Self {
13 let array = Arc::new(RwLock::new(value));
14 let boxed: Box<dyn Array> = Box::new(array);
15 return MtsArray::from(boxed);
16 }
17}
18
19impl<T> Array for Arc<RwLock<ndarray::ArrayD<T>>>
20where
21 T: 'static + Send + Sync + Clone + Default + GetDLPackDataType + DLPackPointerCast,
22{
23 fn as_any(&self) -> &dyn std::any::Any {
24 self
25 }
26
27 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
28 self
29 }
30
31 fn create(&self, shape: &[usize], fill_value: MtsArray) -> Box<dyn Array> {
32 let cpu_device = DLDevice::cpu();
33 let max_version = DLPackVersion::current();
34 let fill_value_dlpack = fill_value.as_dlpack(cpu_device, None, max_version).expect("failed to extract fill_value as DLPack");
35
36 assert_eq!(fill_value_dlpack.shape(), &[], "fill_value must be a single scalar");
38 assert_eq!(fill_value_dlpack.device(), cpu_device, "fill_value must be on CPU");
39
40 let fill_value_ptr = fill_value_dlpack.data_ptr::<T>().expect("dtype mismatch between array and fill_value");
41 let fill_value_scalar = unsafe { std::ptr::read(fill_value_ptr) };
42
43 let array = ndarray::Array::from_elem(shape, fill_value_scalar);
44 return Box::new(Arc::new(RwLock::new(array)));
45 }
46
47 fn copy(&self, device: DLDevice) -> Box<dyn Array> {
48 assert_eq!(device, DLDevice::cpu(), "Rust ndarray data can only be copied to CPU device");
49 return Box::new(self.clone());
50 }
51
52 fn shape(&self) -> Vec<usize> {
53 match self.try_read() {
54 Ok(lock) => lock.shape().to_vec(),
55 Err(TryLockError::Poisoned(_)) => panic!("array lock is poisoned"),
56 Err(TryLockError::WouldBlock) => panic!("array is already locked"),
57 }
58 }
59
60 fn reshape(&mut self, shape: &[usize]) {
61 let mut lock = match self.try_write() {
62 Ok(lock) => lock,
63 Err(TryLockError::Poisoned(_)) => panic!("array lock is poisoned"),
64 Err(TryLockError::WouldBlock) => panic!("array is already locked"),
65 };
66 let array = std::mem::take(&mut *lock);
67 let array = array.into_shape_clone(shape).expect("invalid shape");
68 let _ = std::mem::replace(&mut *lock, array);
69 }
70
71 fn swap_axes(&mut self, axis_1: usize, axis_2: usize) {
72 let mut lock = match self.try_write() {
73 Ok(lock) => lock,
74 Err(TryLockError::Poisoned(_)) => panic!("array lock is poisoned"),
75 Err(TryLockError::WouldBlock) => panic!("array is already locked"),
76 };
77 lock.swap_axes(axis_1, axis_2);
78 }
79
80 fn move_data(
81 &mut self,
82 input: &dyn Array,
83 movements: &[mts_data_movement_t],
84 ) {
85 use ndarray::{Axis, Slice};
86
87 let input = input.as_any().downcast_ref::<Self>().expect("input must be a ndarray of the same type");
88 let input = match input.try_read() {
89 Ok(lock) => lock,
90 Err(TryLockError::Poisoned(_)) => panic!("input array lock is poisoned"),
91 Err(TryLockError::WouldBlock) => panic!("input array is already locked"),
92 };
93
94 let mut output = match self.try_write() {
95 Ok(lock) => lock,
96 Err(TryLockError::Poisoned(_)) => panic!("output array lock is poisoned"),
97 Err(TryLockError::WouldBlock) => panic!("output array is already locked"),
98 };
99
100 if movements.is_empty() {
101 return;
102 }
103
104 let first_prop_start_in = movements[0].properties_start_in;
106 let first_prop_start_out = movements[0].properties_start_out;
107 let first_prop_len = movements[0].properties_length;
108
109 let mut constant_properties = true;
110 let mut contiguous_input_samples = true;
111 let mut contiguous_output_samples = true;
112
113 for w in movements.windows(2) {
114 if w[0].properties_start_in != first_prop_start_in ||
115 w[0].properties_start_out != first_prop_start_out ||
116 w[0].properties_length != first_prop_len {
117 constant_properties = false;
118 break;
119 }
120
121 if w[1].sample_in != w[0].sample_in + 1 {
122 contiguous_input_samples = false;
123 }
124
125 if w[1].sample_out != w[0].sample_out + 1 {
126 contiguous_output_samples = false;
127 }
128 }
129
130 if constant_properties {
131 let last = movements.last().unwrap();
132 if last.properties_start_in != first_prop_start_in ||
133 last.properties_start_out != first_prop_start_out ||
134 last.properties_length != first_prop_len {
135 constant_properties = false;
136 }
137 }
138
139 let property_axis = output.shape().len() - 1;
140
141 if constant_properties {
142 let input_slice_info = Slice::from(first_prop_start_in..(first_prop_start_in + first_prop_len));
143 let output_slice_info = Slice::from(first_prop_start_out..(first_prop_start_out + first_prop_len));
144
145 if contiguous_input_samples && contiguous_output_samples {
146 let sample_start_in = movements[0].sample_in;
147 let sample_start_out = movements[0].sample_out;
148 let sample_count = movements.len();
149
150 let input_samples = input.slice_axis(
151 Axis(0),
152 Slice::from(sample_start_in..(sample_start_in + sample_count))
153 );
154 let mut output_samples = output.slice_axis_mut(
155 Axis(0),
156 Slice::from(sample_start_out..(sample_start_out + sample_count))
157 );
158
159 let value = input_samples.slice_axis(Axis(property_axis), input_slice_info);
160 let mut output_location = output_samples.slice_axis_mut(Axis(property_axis), output_slice_info);
161
162 output_location.assign(&value);
163 } else {
164 for move_item in movements {
165 let input_sample = input.index_axis(Axis(0), move_item.sample_in);
166 let mut output_sample = output.index_axis_mut(Axis(0), move_item.sample_out);
167
168 let value = input_sample.slice_axis(
169 Axis(property_axis - 1),
172 input_slice_info
173 );
174 let mut output_location = output_sample.slice_axis_mut(
175 Axis(property_axis - 1),
176 output_slice_info
177 );
178 output_location.assign(&value);
179 }
180 }
181 } else {
182 for move_item in movements {
184 let input_sample = input.index_axis(Axis(0), move_item.sample_in);
185 let mut output_sample = output.index_axis_mut(Axis(0), move_item.sample_out);
186
187 let value = input_sample.slice_axis(
188 Axis(property_axis - 1),
190 Slice::from(move_item.properties_start_in..(move_item.properties_start_in + move_item.properties_length))
191 );
192 let mut output_location = output_sample.slice_axis_mut(
193 Axis(property_axis - 1),
194 Slice::from(move_item.properties_start_out..(move_item.properties_start_out + move_item.properties_length))
195 );
196 output_location.assign(&value);
197 }
198 }
199 }
200
201 fn device(&self) -> DLDevice {
202 DLDevice::cpu()
203 }
204
205 fn dtype(&self) -> DLDataType {
206 T::get_dlpack_data_type()
207 }
208
209 fn as_dlpack(
210 &self,
211 device: DLDevice,
212 stream: Option<i64>,
213 max_version: DLPackVersion,
214 ) -> Result<DLPackTensor, Error> {
215 if stream.is_some() {
216 return Err(Error {
218 code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
219 message: "CPU arrays can not be used with a stream".into(),
220 });
221 }
222 let vendored_version = DLPackVersion::current();
223 if max_version.major != vendored_version.major {
224 return Err(Error {
225 code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
226 message: format!(
227 "invalid `max_version` in ndarray::ArrayD<T>::as_dlpack: \
228 we got v{}.{}, but we support v{}.{}",
229 max_version.major, max_version.minor,
230 vendored_version.major, vendored_version.minor
231 ),
232 });
233 }
234
235 let ndarray_device = DLDevice::cpu();
236
237 if device.device_type != ndarray_device.device_type || device.device_id != ndarray_device.device_id {
238 return Err(Error {
239 code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
240 message: format!(
241 "Requested DLPack device ({}) does not match array device ({})",
242 device, ndarray_device
243 ),
244 });
245 }
246
247 let tensor: DLPackTensor = ReadOnly(Arc::clone(self)).try_into().map_err(|e| Error {
248 code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
249 message: format!("failed to convert ndarray to DLPack: {:?}", e),
250 })?;
251
252 Ok(tensor)
253 }
254
255 #[allow(clippy::enum_glob_use)]
256 fn from_dlpack(&self, dlpack_tensor: DLPackTensor) -> Result<Box<dyn Array>, Error> {
257 use DLDataTypeCode::*;
258
259 let dtype = dlpack_tensor.dtype();
260
261 if dtype.lanes != 1 {
262 return Err(Error {
263 code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
264 message: "Only DLPack tensors with lanes == 1 are supported".into(),
265 });
266 }
267
268 let map_error = |e| Error {
269 code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
270 message: format!("failed to convert DLPack to ndarray: {:?}", e),
271 };
272
273 if dtype.code == kDLFloat && dtype.bits == 64 {
274 let array: ndarray::ArrayD<f64> = dlpack_tensor.try_into().map_err(map_error)?;
275 return Ok(Box::new(Arc::new(RwLock::new(array))));
276 } else if dtype.code == kDLFloat && dtype.bits == 32 {
277 let array: ndarray::ArrayD<f32> = dlpack_tensor.try_into().map_err(map_error)?;
278 return Ok(Box::new(Arc::new(RwLock::new(array))));
279 } else if dtype.code == kDLInt && dtype.bits == 8 {
280 let array: ndarray::ArrayD<i8> = dlpack_tensor.try_into().map_err(map_error)?;
281 return Ok(Box::new(Arc::new(RwLock::new(array))));
282 } else if dtype.code == kDLInt && dtype.bits == 16 {
283 let array: ndarray::ArrayD<i16> = dlpack_tensor.try_into().map_err(map_error)?;
284 return Ok(Box::new(Arc::new(RwLock::new(array))));
285 } else if dtype.code == kDLInt && dtype.bits == 32 {
286 let array: ndarray::ArrayD<i32> = dlpack_tensor.try_into().map_err(map_error)?;
287 return Ok(Box::new(Arc::new(RwLock::new(array))));
288 } else if dtype.code == kDLInt && dtype.bits == 64 {
289 let array: ndarray::ArrayD<i64> = dlpack_tensor.try_into().map_err(map_error)?;
290 return Ok(Box::new(Arc::new(RwLock::new(array))));
291 } else if dtype.code == kDLUInt && dtype.bits == 8 {
292 let array: ndarray::ArrayD<u8> = dlpack_tensor.try_into().map_err(map_error)?;
293 return Ok(Box::new(Arc::new(RwLock::new(array))));
294 } else if dtype.code == kDLUInt && dtype.bits == 16 {
295 let array: ndarray::ArrayD<u16> = dlpack_tensor.try_into().map_err(map_error)?;
296 return Ok(Box::new(Arc::new(RwLock::new(array))));
297 } else if dtype.code == kDLUInt && dtype.bits == 32 {
298 let array: ndarray::ArrayD<u32> = dlpack_tensor.try_into().map_err(map_error)?;
299 return Ok(Box::new(Arc::new(RwLock::new(array))));
300 } else if dtype.code == kDLUInt && dtype.bits == 64 {
301 let array: ndarray::ArrayD<u64> = dlpack_tensor.try_into().map_err(map_error)?;
302 return Ok(Box::new(Arc::new(RwLock::new(array))));
303 } else if dtype.code == kDLBool && dtype.bits == 8 {
304 let array: ndarray::ArrayD<bool> = dlpack_tensor.try_into().map_err(map_error)?;
305 return Ok(Box::new(Arc::new(RwLock::new(array))));
306 } else {
307 return Err(Error {
308 code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
309 message: format!("Unsupported DLPack dtype {}", dtype),
310 });
311 }
312 }
313}
314
#[cfg(test)]
mod tests {
    use dlpk::{DLPackPointerCast, GetDLPackDataType, sys::{DLDataTypeCode, DLDevice, DLPackVersion}};
    use crate::MtsArray;

    /// Wrapping an ndarray keeps its shape, and `create` builds an array of
    /// the requested shape that is actually filled with `fill_value`.
    #[test]
    // values are copied, never computed, so exact float comparison is intended
    #[allow(clippy::float_cmp)]
    fn ndarray_as_mts_array() {
        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3, 4]);
        let mts_array = MtsArray::from(data);

        assert_eq!(mts_array.shape().unwrap(), [2, 3, 4]);

        let fill_value = MtsArray::from(ndarray::Array::from_elem(vec![], 42.0));

        let created = mts_array.create(&[2, 3, 4], fill_value.as_ref()).unwrap();
        assert_eq!(created.shape().unwrap(), [2, 3, 4]);

        // read the new data back to check that every element is `fill_value`
        let created_dl = created.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
        let created_data: ndarray::ArrayD<f64> = created_dl.try_into().unwrap();
        assert!(created_data.iter().all(|&value| value == 42.0));
    }

    /// DLPack export reports the right rank, shape and dtype.
    #[test]
    fn ndarray_as_mts_array_dlpack() {
        let data = ndarray::Array::<f64, _>::zeros(vec![4, 5, 6]);
        let mts_array = MtsArray::from(data);

        let dl_managed = mts_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();

        assert_eq!(dl_managed.n_dims(), 3);
        assert_eq!(dl_managed.shape(), [4, 5, 6]);

        assert_eq!(dl_managed.dtype().code, DLDataTypeCode::kDLFloat);
        assert_eq!(dl_managed.dtype().bits, 64);
        assert_eq!(dl_managed.dtype().lanes, 1);
    }

    /// Every supported element type round-trips its DLPack dtype, both for
    /// the wrapped array and for arrays built with `create`.
    #[test]
    fn ndarray_all_dtypes() {
        fn test_for_dtype<T>(code: DLDataTypeCode, bits: u8) where T: Send + Sync + Clone + Default + GetDLPackDataType + DLPackPointerCast + 'static {
            let data = ndarray::Array::<T, _>::from_elem(vec![2, 2], T::default());
            let mts_array = MtsArray::from(data);

            assert_eq!(mts_array.shape().unwrap(), [2, 2]);

            let dl_managed = mts_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
            assert_eq!(dl_managed.dtype().code, code);
            assert_eq!(dl_managed.dtype().bits, bits);
            assert_eq!(dl_managed.dtype().lanes, 1);

            let fill_value = MtsArray::from(ndarray::Array::from_elem(vec![], T::default()));

            let created = mts_array.create(&[1, 1], fill_value.as_ref()).unwrap();
            let dl_managed = created.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();

            assert_eq!(dl_managed.dtype().code, code);
            assert_eq!(dl_managed.dtype().bits, bits);
            assert_eq!(dl_managed.dtype().lanes, 1);
        }

        test_for_dtype::<bool>(DLDataTypeCode::kDLBool, 8);
        test_for_dtype::<f64>(DLDataTypeCode::kDLFloat, 64);
        test_for_dtype::<f32>(DLDataTypeCode::kDLFloat, 32);
        test_for_dtype::<i8>(DLDataTypeCode::kDLInt, 8);
        test_for_dtype::<i16>(DLDataTypeCode::kDLInt, 16);
        test_for_dtype::<i32>(DLDataTypeCode::kDLInt, 32);
        test_for_dtype::<i64>(DLDataTypeCode::kDLInt, 64);
        test_for_dtype::<u8>(DLDataTypeCode::kDLUInt, 8);
        test_for_dtype::<u16>(DLDataTypeCode::kDLUInt, 16);
        test_for_dtype::<u32>(DLDataTypeCode::kDLUInt, 32);
        test_for_dtype::<u64>(DLDataTypeCode::kDLUInt, 64);
    }

    /// ndarray data always lives on the CPU.
    #[test]
    fn ndarray_device() {
        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
        let mts_array = MtsArray::from(data);

        assert_eq!(mts_array.device().unwrap(), DLDevice::cpu());
    }

    /// Requesting a stream for CPU data is an error.
    #[test]
    fn as_dlpack_rejects_stream() {
        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
        let mts_array = MtsArray::from(data);
        match mts_array.as_dlpack(DLDevice::cpu(), Some(42), DLPackVersion::current()) {
            Err(e) => assert!(e.message.contains("stream"), "{}", e.message),
            Ok(_) => panic!("expected error for non-null stream"),
        }
    }

    /// Exporting CPU data to a CUDA device is an error.
    #[test]
    fn as_dlpack_rejects_wrong_device() {
        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
        let mts_array = MtsArray::from(data);
        let cuda = DLDevice {
            device_type: dlpk::sys::DLDeviceType::kDLCUDA,
            device_id: 0,
        };
        match mts_array.as_dlpack(cuda, None, DLPackVersion::current()) {
            Err(e) => assert!(e.message.contains("does not match"), "{}", e.message),
            Ok(_) => panic!("expected error for CUDA device on CPU array"),
        }
    }

    /// A `max_version` with a different major version is rejected.
    #[test]
    fn as_dlpack_rejects_incompatible_version() {
        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
        let mts_array = MtsArray::from(data);

        let bad_version = DLPackVersion { major: 99, minor: 0 };
        match mts_array.as_dlpack(DLDevice::cpu(), None, bad_version) {
            Err(e) => assert!(e.message.contains("version"), "{}", e.message),
            Ok(_) => panic!("expected error for incompatible DLPack version"),
        }
    }

    /// Data survives an ndarray -> DLPack -> ndarray round trip, and the new
    /// arrays keep the same origin.
    #[test]
    #[allow(clippy::float_cmp)]
    fn from_dlpack() {
        let mut f64_data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
        f64_data[[0, 0]] = 1.573;
        f64_data[[1, 2]] = -42.0;
        let f64_array = MtsArray::from(f64_data);

        let mut i16_data = ndarray::Array::<i16, _>::zeros(vec![2, 5, 10]);
        i16_data[[0, 1, 3]] = 3;
        i16_data[[1, 2, 4]] = -42;
        let i16_array = MtsArray::from(i16_data);

        let f64_dl_tensor = f64_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
        let i16_dl_tensor = i16_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();

        let new_f64_array = f64_array.from_dlpack(f64_dl_tensor).unwrap();
        let new_i16_array = i16_array.from_dlpack(i16_dl_tensor).unwrap();

        assert_eq!(f64_array.origin().unwrap(), i16_array.origin().unwrap());
        assert_eq!(new_f64_array.origin().unwrap(), f64_array.origin().unwrap());
        assert_eq!(new_i16_array.origin().unwrap(), i16_array.origin().unwrap());

        let new_f64_dl_tensor = new_f64_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
        let new_i16_dl_tensor = new_i16_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();

        let new_f64_data: ndarray::ArrayD<f64> = new_f64_dl_tensor.try_into().unwrap();
        let new_i16_data: ndarray::ArrayD<i16> = new_i16_dl_tensor.try_into().unwrap();

        assert_eq!(new_f64_data[[0, 0]], 1.573);
        assert_eq!(new_f64_data[[1, 2]], -42.0);

        assert_eq!(new_i16_data[[0, 1, 3]], 3);
        assert_eq!(new_i16_data[[1, 2, 4]], -42);
    }
}