1use std::sync::{Arc, RwLock, TryLockError};
2
3use dlpk::sys::{DLDevice, DLPackVersion, DLDataType};
4use dlpk::{DLDataTypeCode, DLPackPointerCast, DLPackTensor, GetDLPackDataType, ReadOnly};
5
6use crate::errors::Error;
7use crate::c_api::mts_data_movement_t;
8
9use super::{Array, MtsArray};
10
11impl<T, D> From<ndarray::Array<T, D>> for MtsArray where
12 T: 'static + Clone + Send + Default + Sync + GetDLPackDataType + DLPackPointerCast,
13 D: ndarray::Dimension
14{
15 fn from(value: ndarray::Array<T, D>) -> Self {
16 let array = Arc::new(RwLock::new(value.into_dyn()));
17 let boxed: Box<dyn Array> = Box::new(array);
18 return MtsArray::from(boxed);
19 }
20}
21
22impl<T> Array for Arc<RwLock<ndarray::ArrayD<T>>>
23where
24 T: 'static + Send + Sync + Clone + Default + GetDLPackDataType + DLPackPointerCast,
25{
26 fn as_any(&self) -> &dyn std::any::Any {
27 self
28 }
29
30 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
31 self
32 }
33
34 fn create(&self, shape: &[usize], fill_value: MtsArray) -> Box<dyn Array> {
35 let cpu_device = DLDevice::cpu();
36 let max_version = DLPackVersion::current();
37 let fill_value_dlpack = fill_value.as_dlpack(cpu_device, None, max_version).expect("failed to extract fill_value as DLPack");
38
39 assert_eq!(fill_value_dlpack.shape(), &[], "fill_value must be a single scalar");
41 assert_eq!(fill_value_dlpack.device(), cpu_device, "fill_value must be on CPU");
42
43 let fill_value_ptr = fill_value_dlpack.data_ptr::<T>().expect("dtype mismatch between array and fill_value");
44 let fill_value_scalar = unsafe { std::ptr::read(fill_value_ptr) };
45
46 let array = ndarray::Array::from_elem(shape, fill_value_scalar);
47 return Box::new(Arc::new(RwLock::new(array)));
48 }
49
50 fn copy(&self, device: DLDevice) -> Box<dyn Array> {
51 assert_eq!(device, DLDevice::cpu(), "Rust ndarray data can only be copied to CPU device");
52 return Box::new(self.clone());
53 }
54
55 fn shape(&self) -> Vec<usize> {
56 match self.try_read() {
57 Ok(lock) => lock.shape().to_vec(),
58 Err(TryLockError::Poisoned(_)) => panic!("array lock is poisoned"),
59 Err(TryLockError::WouldBlock) => panic!("array is already locked"),
60 }
61 }
62
63 fn reshape(&mut self, shape: &[usize]) {
64 let mut lock = match self.try_write() {
65 Ok(lock) => lock,
66 Err(TryLockError::Poisoned(_)) => panic!("array lock is poisoned"),
67 Err(TryLockError::WouldBlock) => panic!("array is already locked"),
68 };
69 let array = std::mem::take(&mut *lock);
70 let array = array.into_shape_clone(shape).expect("invalid shape");
71 let _ = std::mem::replace(&mut *lock, array);
72 }
73
74 fn swap_axes(&mut self, axis_1: usize, axis_2: usize) {
75 let mut lock = match self.try_write() {
76 Ok(lock) => lock,
77 Err(TryLockError::Poisoned(_)) => panic!("array lock is poisoned"),
78 Err(TryLockError::WouldBlock) => panic!("array is already locked"),
79 };
80 lock.swap_axes(axis_1, axis_2);
81 }
82
83 fn move_data(
84 &mut self,
85 input: &dyn Array,
86 movements: &[mts_data_movement_t],
87 ) {
88 use ndarray::{Axis, Slice};
89
90 let input = input.as_any().downcast_ref::<Self>().expect("input must be a ndarray of the same type");
91 let input = match input.try_read() {
92 Ok(lock) => lock,
93 Err(TryLockError::Poisoned(_)) => panic!("input array lock is poisoned"),
94 Err(TryLockError::WouldBlock) => panic!("input array is already locked"),
95 };
96
97 let mut output = match self.try_write() {
98 Ok(lock) => lock,
99 Err(TryLockError::Poisoned(_)) => panic!("output array lock is poisoned"),
100 Err(TryLockError::WouldBlock) => panic!("output array is already locked"),
101 };
102
103 if movements.is_empty() {
104 return;
105 }
106
107 let first_prop_start_in = movements[0].properties_start_in;
109 let first_prop_start_out = movements[0].properties_start_out;
110 let first_prop_len = movements[0].properties_length;
111
112 let mut constant_properties = true;
113 let mut contiguous_input_samples = true;
114 let mut contiguous_output_samples = true;
115
116 for w in movements.windows(2) {
117 if w[0].properties_start_in != first_prop_start_in ||
118 w[0].properties_start_out != first_prop_start_out ||
119 w[0].properties_length != first_prop_len {
120 constant_properties = false;
121 break;
122 }
123
124 if w[1].sample_in != w[0].sample_in + 1 {
125 contiguous_input_samples = false;
126 }
127
128 if w[1].sample_out != w[0].sample_out + 1 {
129 contiguous_output_samples = false;
130 }
131 }
132
133 if constant_properties {
134 let last = movements.last().unwrap();
135 if last.properties_start_in != first_prop_start_in ||
136 last.properties_start_out != first_prop_start_out ||
137 last.properties_length != first_prop_len {
138 constant_properties = false;
139 }
140 }
141
142 let property_axis = output.shape().len() - 1;
143
144 if constant_properties {
145 let input_slice_info = Slice::from(first_prop_start_in..(first_prop_start_in + first_prop_len));
146 let output_slice_info = Slice::from(first_prop_start_out..(first_prop_start_out + first_prop_len));
147
148 if contiguous_input_samples && contiguous_output_samples {
149 let sample_start_in = movements[0].sample_in;
150 let sample_start_out = movements[0].sample_out;
151 let sample_count = movements.len();
152
153 let input_samples = input.slice_axis(
154 Axis(0),
155 Slice::from(sample_start_in..(sample_start_in + sample_count))
156 );
157 let mut output_samples = output.slice_axis_mut(
158 Axis(0),
159 Slice::from(sample_start_out..(sample_start_out + sample_count))
160 );
161
162 let value = input_samples.slice_axis(Axis(property_axis), input_slice_info);
163 let mut output_location = output_samples.slice_axis_mut(Axis(property_axis), output_slice_info);
164
165 output_location.assign(&value);
166 } else {
167 for move_item in movements {
168 let input_sample = input.index_axis(Axis(0), move_item.sample_in);
169 let mut output_sample = output.index_axis_mut(Axis(0), move_item.sample_out);
170
171 let value = input_sample.slice_axis(
172 Axis(property_axis - 1),
175 input_slice_info
176 );
177 let mut output_location = output_sample.slice_axis_mut(
178 Axis(property_axis - 1),
179 output_slice_info
180 );
181 output_location.assign(&value);
182 }
183 }
184 } else {
185 for move_item in movements {
187 let input_sample = input.index_axis(Axis(0), move_item.sample_in);
188 let mut output_sample = output.index_axis_mut(Axis(0), move_item.sample_out);
189
190 let value = input_sample.slice_axis(
191 Axis(property_axis - 1),
193 Slice::from(move_item.properties_start_in..(move_item.properties_start_in + move_item.properties_length))
194 );
195 let mut output_location = output_sample.slice_axis_mut(
196 Axis(property_axis - 1),
197 Slice::from(move_item.properties_start_out..(move_item.properties_start_out + move_item.properties_length))
198 );
199 output_location.assign(&value);
200 }
201 }
202 }
203
204 fn device(&self) -> DLDevice {
205 DLDevice::cpu()
206 }
207
208 fn dtype(&self) -> DLDataType {
209 T::get_dlpack_data_type()
210 }
211
212 fn as_dlpack(
213 &self,
214 device: DLDevice,
215 stream: Option<i64>,
216 max_version: DLPackVersion,
217 ) -> Result<DLPackTensor, Error> {
218 if stream.is_some() {
219 return Err(Error {
221 code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
222 message: "CPU arrays can not be used with a stream".into(),
223 });
224 }
225 let vendored_version = DLPackVersion::current();
226 if max_version.major != vendored_version.major {
227 return Err(Error {
228 code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
229 message: format!(
230 "invalid `max_version` in ndarray::ArrayD<T>::as_dlpack: \
231 we got v{}.{}, but we support v{}.{}",
232 max_version.major, max_version.minor,
233 vendored_version.major, vendored_version.minor
234 ),
235 });
236 }
237
238 let ndarray_device = DLDevice::cpu();
239
240 if device.device_type != ndarray_device.device_type || device.device_id != ndarray_device.device_id {
241 return Err(Error {
242 code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
243 message: format!(
244 "Requested DLPack device ({}) does not match array device ({})",
245 device, ndarray_device
246 ),
247 });
248 }
249
250 let tensor: DLPackTensor = ReadOnly(Arc::clone(self)).try_into().map_err(|e| Error {
251 code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
252 message: format!("failed to convert ndarray to DLPack: {:?}", e),
253 })?;
254
255 Ok(tensor)
256 }
257
258 #[allow(clippy::enum_glob_use)]
259 fn from_dlpack(&self, dlpack_tensor: DLPackTensor) -> Result<Box<dyn Array>, Error> {
260 use DLDataTypeCode::*;
261
262 let dtype = dlpack_tensor.dtype();
263
264 if dtype.lanes != 1 {
265 return Err(Error {
266 code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
267 message: "Only DLPack tensors with lanes == 1 are supported".into(),
268 });
269 }
270
271 let map_error = |e| Error {
272 code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
273 message: format!("failed to convert DLPack to ndarray: {:?}", e),
274 };
275
276 if dtype.code == kDLFloat && dtype.bits == 64 {
277 let array: ndarray::ArrayD<f64> = dlpack_tensor.try_into().map_err(map_error)?;
278 return Ok(Box::new(Arc::new(RwLock::new(array))));
279 } else if dtype.code == kDLFloat && dtype.bits == 32 {
280 let array: ndarray::ArrayD<f32> = dlpack_tensor.try_into().map_err(map_error)?;
281 return Ok(Box::new(Arc::new(RwLock::new(array))));
282 } else if dtype.code == kDLInt && dtype.bits == 8 {
283 let array: ndarray::ArrayD<i8> = dlpack_tensor.try_into().map_err(map_error)?;
284 return Ok(Box::new(Arc::new(RwLock::new(array))));
285 } else if dtype.code == kDLInt && dtype.bits == 16 {
286 let array: ndarray::ArrayD<i16> = dlpack_tensor.try_into().map_err(map_error)?;
287 return Ok(Box::new(Arc::new(RwLock::new(array))));
288 } else if dtype.code == kDLInt && dtype.bits == 32 {
289 let array: ndarray::ArrayD<i32> = dlpack_tensor.try_into().map_err(map_error)?;
290 return Ok(Box::new(Arc::new(RwLock::new(array))));
291 } else if dtype.code == kDLInt && dtype.bits == 64 {
292 let array: ndarray::ArrayD<i64> = dlpack_tensor.try_into().map_err(map_error)?;
293 return Ok(Box::new(Arc::new(RwLock::new(array))));
294 } else if dtype.code == kDLUInt && dtype.bits == 8 {
295 let array: ndarray::ArrayD<u8> = dlpack_tensor.try_into().map_err(map_error)?;
296 return Ok(Box::new(Arc::new(RwLock::new(array))));
297 } else if dtype.code == kDLUInt && dtype.bits == 16 {
298 let array: ndarray::ArrayD<u16> = dlpack_tensor.try_into().map_err(map_error)?;
299 return Ok(Box::new(Arc::new(RwLock::new(array))));
300 } else if dtype.code == kDLUInt && dtype.bits == 32 {
301 let array: ndarray::ArrayD<u32> = dlpack_tensor.try_into().map_err(map_error)?;
302 return Ok(Box::new(Arc::new(RwLock::new(array))));
303 } else if dtype.code == kDLUInt && dtype.bits == 64 {
304 let array: ndarray::ArrayD<u64> = dlpack_tensor.try_into().map_err(map_error)?;
305 return Ok(Box::new(Arc::new(RwLock::new(array))));
306 } else if dtype.code == kDLBool && dtype.bits == 8 {
307 let array: ndarray::ArrayD<bool> = dlpack_tensor.try_into().map_err(map_error)?;
308 return Ok(Box::new(Arc::new(RwLock::new(array))));
309 } else {
310 return Err(Error {
311 code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
312 message: format!("Unsupported DLPack dtype {}", dtype),
313 });
314 }
315 }
316}
317
#[cfg(test)]
mod tests {
    use std::sync::{Arc, RwLock};

    use dlpk::{DLPackPointerCast, GetDLPackDataType, sys::{DLDataTypeCode, DLDevice, DLPackVersion}};

    use crate::MtsArray;
    use crate::c_api::mts_data_movement_t;

    use super::Array;

    #[test]
    #[allow(clippy::float_cmp)]
    fn ndarray_as_mts_array() {
        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3, 4]);
        let mts_array = MtsArray::from(data);

        assert_eq!(mts_array.shape().unwrap(), [2, 3, 4]);

        let fill_value = MtsArray::from(ndarray::Array::from_elem(vec![], 42.0));

        let created = mts_array.create(&[2, 3, 4], fill_value.as_ref()).unwrap();
        assert_eq!(created.shape().unwrap(), [2, 3, 4]);

        // every entry of the created array must hold the fill value
        let created_dlpack = created.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
        let created_data: ndarray::ArrayD<f64> = created_dlpack.try_into().unwrap();
        assert_eq!(created_data.len(), 24);
        assert!(created_data.iter().all(|&value| value == 42.0));
    }

    #[test]
    fn ndarray_as_mts_array_dlpack() {
        let data = ndarray::Array::<f64, _>::zeros(vec![4, 5, 6]);
        let mts_array = MtsArray::from(data);

        let dl_managed = mts_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();

        assert_eq!(dl_managed.n_dims(), 3);
        assert_eq!(dl_managed.shape(), [4, 5, 6]);

        assert_eq!(dl_managed.dtype().code, DLDataTypeCode::kDLFloat);
        assert_eq!(dl_managed.dtype().bits, 64);
        assert_eq!(dl_managed.dtype().lanes, 1);
    }

    #[test]
    fn ndarray_all_dtypes() {
        /// Check that arrays of type `T`, and arrays created from them, export
        /// the expected DLPack dtype.
        fn test_for_dtype<T>(code: DLDataTypeCode, bits: u8) where T: Send + Sync + Clone + Default + GetDLPackDataType + DLPackPointerCast + 'static {
            let data = ndarray::Array::<T, _>::from_elem(vec![2, 2], T::default());
            let mts_array = MtsArray::from(data);

            assert_eq!(mts_array.shape().unwrap(), [2, 2]);

            let dl_managed = mts_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
            assert_eq!(dl_managed.dtype().code, code);
            assert_eq!(dl_managed.dtype().bits, bits);
            assert_eq!(dl_managed.dtype().lanes, 1);

            let fill_value = MtsArray::from(ndarray::Array::from_elem(vec![], T::default()));

            let created = mts_array.create(&[1, 1], fill_value.as_ref()).unwrap();
            let dl_managed = created.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();

            assert_eq!(dl_managed.dtype().code, code);
            assert_eq!(dl_managed.dtype().bits, bits);
            assert_eq!(dl_managed.dtype().lanes, 1);
        }

        test_for_dtype::<bool>(DLDataTypeCode::kDLBool, 8);
        test_for_dtype::<f64>(DLDataTypeCode::kDLFloat, 64);
        test_for_dtype::<f32>(DLDataTypeCode::kDLFloat, 32);
        test_for_dtype::<i8>(DLDataTypeCode::kDLInt, 8);
        test_for_dtype::<i16>(DLDataTypeCode::kDLInt, 16);
        test_for_dtype::<i32>(DLDataTypeCode::kDLInt, 32);
        test_for_dtype::<i64>(DLDataTypeCode::kDLInt, 64);
        test_for_dtype::<u8>(DLDataTypeCode::kDLUInt, 8);
        test_for_dtype::<u16>(DLDataTypeCode::kDLUInt, 16);
        test_for_dtype::<u32>(DLDataTypeCode::kDLUInt, 32);
        test_for_dtype::<u64>(DLDataTypeCode::kDLUInt, 64);
    }

    #[test]
    fn ndarray_device() {
        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
        let mts_array = MtsArray::from(data);

        assert_eq!(mts_array.device().unwrap(), DLDevice::cpu());
    }

    #[test]
    fn ndarray_reshape_and_swap_axes() {
        let data = ndarray::ArrayD::<i32>::from_shape_vec(vec![2, 3], (0..6).collect()).unwrap();
        let mut array = Arc::new(RwLock::new(data));

        Array::swap_axes(&mut array, 0, 1);
        assert_eq!(Array::shape(&array), [3, 2]);

        // reshaping a transposed array keeps the logical (row-major) order
        Array::reshape(&mut array, &[6]);
        assert_eq!(Array::shape(&array), [6]);

        let reshaped = array.read().unwrap();
        let values: Vec<i32> = reshaped.iter().copied().collect();
        assert_eq!(values, [0, 3, 1, 4, 2, 5]);
    }

    /// Build a single data movement.
    fn movement(
        sample_in: usize,
        sample_out: usize,
        properties_start_in: usize,
        properties_start_out: usize,
        properties_length: usize,
    ) -> mts_data_movement_t {
        mts_data_movement_t {
            sample_in,
            sample_out,
            properties_start_in,
            properties_start_out,
            properties_length,
        }
    }

    /// Move data from a 3x4 input holding `0..12` (row-major) into a zeroed
    /// 3x4 output, and compare the output (row-major) against `expected`.
    fn check_move_data(movements: &[mts_data_movement_t], expected: &[i32]) {
        let input = Arc::new(RwLock::new(
            ndarray::ArrayD::<i32>::from_shape_vec(vec![3, 4], (0..12).collect()).unwrap()
        ));
        let mut output = Arc::new(RwLock::new(ndarray::ArrayD::<i32>::zeros(vec![3, 4])));

        Array::move_data(&mut output, &input, movements);

        let output = output.read().unwrap();
        let values: Vec<i32> = output.iter().copied().collect();
        assert_eq!(values, expected);
    }

    #[test]
    fn ndarray_move_data() {
        // contiguous samples sharing the same properties (single bulk copy)
        check_move_data(
            &[movement(0, 1, 1, 0, 2), movement(1, 2, 1, 0, 2)],
            &[0, 0, 0, 0, 1, 2, 0, 0, 5, 6, 0, 0],
        );

        // same properties, but non-contiguous samples (per-sample copies)
        check_move_data(
            &[movement(2, 0, 0, 2, 2), movement(0, 1, 0, 2, 2)],
            &[0, 0, 8, 9, 0, 0, 0, 1, 0, 0, 0, 0],
        );

        // different properties for each movement
        check_move_data(
            &[movement(1, 2, 0, 3, 1), movement(2, 0, 1, 0, 3)],
            &[9, 10, 11, 0, 0, 0, 0, 0, 0, 0, 0, 4],
        );

        // no movements leaves the output untouched
        check_move_data(&[], &[0; 12]);
    }

    #[test]
    fn as_dlpack_rejects_stream() {
        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
        let mts_array = MtsArray::from(data);
        match mts_array.as_dlpack(DLDevice::cpu(), Some(42), DLPackVersion::current()) {
            Err(e) => assert!(e.message.contains("stream"), "{}", e.message),
            Ok(_) => panic!("expected error for non-null stream"),
        }
    }

    #[test]
    fn as_dlpack_rejects_wrong_device() {
        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
        let mts_array = MtsArray::from(data);
        let cuda = DLDevice {
            device_type: dlpk::sys::DLDeviceType::kDLCUDA,
            device_id: 0,
        };
        match mts_array.as_dlpack(cuda, None, DLPackVersion::current()) {
            Err(e) => assert!(e.message.contains("does not match"), "{}", e.message),
            Ok(_) => panic!("expected error for CUDA device on CPU array"),
        }
    }

    #[test]
    fn as_dlpack_rejects_incompatible_version() {
        let data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
        let mts_array = MtsArray::from(data);

        let bad_version = DLPackVersion { major: 99, minor: 0 };
        match mts_array.as_dlpack(DLDevice::cpu(), None, bad_version) {
            Err(e) => assert!(e.message.contains("version"), "{}", e.message),
            Ok(_) => panic!("expected error for incompatible DLPack version"),
        }
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn from_dlpack() {
        let mut f64_data = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
        f64_data[[0, 0]] = 1.573;
        f64_data[[1, 2]] = -42.0;
        let f64_array = MtsArray::from(f64_data);

        let mut i16_data = ndarray::Array::<i16, _>::zeros(vec![2, 5, 10]);
        i16_data[[0, 1, 3]] = 3;
        i16_data[[1, 2, 4]] = -42;
        let i16_array = MtsArray::from(i16_data);

        let f64_dl_tensor = f64_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
        let i16_dl_tensor = i16_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();

        let new_f64_array = f64_array.from_dlpack(f64_dl_tensor).unwrap();
        let new_i16_array = i16_array.from_dlpack(i16_dl_tensor).unwrap();

        assert_eq!(f64_array.origin().unwrap(), i16_array.origin().unwrap());
        assert_eq!(new_f64_array.origin().unwrap(), f64_array.origin().unwrap());
        assert_eq!(new_i16_array.origin().unwrap(), i16_array.origin().unwrap());

        let new_f64_dl_tensor = new_f64_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();
        let new_i16_dl_tensor = new_i16_array.as_dlpack(DLDevice::cpu(), None, DLPackVersion::current()).unwrap();

        let new_f64_data: ndarray::ArrayD<f64> = new_f64_dl_tensor.try_into().unwrap();
        let new_i16_data: ndarray::ArrayD<i16> = new_i16_dl_tensor.try_into().unwrap();

        assert_eq!(new_f64_data[[0, 0]], 1.573);
        assert_eq!(new_f64_data[[1, 2]], -42.0);

        assert_eq!(new_i16_data[[0, 1, 3]], 3);
        assert_eq!(new_i16_data[[1, 2, 4]], -42);
    }
}
469}