use ndarray::{Array, ArcArray, Dimension, ShapeBuilder};

use crate::data_types::{CastError, DLPackPointerCast, GetDLPackDataType};
use crate::sys;
use crate::{DLPackTensor, DLPackTensorRef, DLPackTensorRefMut};

#[cfg(feature = "pyo3")]
use pyo3::PyErr;
#[derive(Debug)]
pub enum DLPackNDarrayError {
    /// The tensor lives on a non-CPU device; only CPU memory can be viewed directly.
    DeviceShouldBeCpu(sys::DLDevice),
    /// The tensor's element type does not match the requested Rust type.
    InvalidType(CastError),
    /// The shape or strides could not be converted into the requested dimension type.
    ShapeError(ndarray::ShapeError),
}

impl From<CastError> for DLPackNDarrayError {
    fn from(err: CastError) -> Self {
        DLPackNDarrayError::InvalidType(err)
    }
}

impl From<ndarray::ShapeError> for DLPackNDarrayError {
    fn from(err: ndarray::ShapeError) -> Self {
        DLPackNDarrayError::ShapeError(err)
    }
}

#[cfg(feature = "pyo3")]
impl From<DLPackNDarrayError> for PyErr {
    fn from(err: DLPackNDarrayError) -> PyErr {
        pyo3::exceptions::PyValueError::new_err(err.to_string())
    }
}

impl std::fmt::Display for DLPackNDarrayError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DLPackNDarrayError::DeviceShouldBeCpu(device) => {
                write!(f, "cannot convert from device {} (only CPU is supported)", device)
            }
            DLPackNDarrayError::InvalidType(error) => {
                write!(f, "type conversion error: {}", error)
            }
            DLPackNDarrayError::ShapeError(error) => {
                write!(f, "shape error: {}", error)
            }
        }
    }
}

impl std::error::Error for DLPackNDarrayError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            DLPackNDarrayError::DeviceShouldBeCpu(_) => None,
            DLPackNDarrayError::InvalidType(err) => Some(err),
            DLPackNDarrayError::ShapeError(err) => Some(err),
        }
    }
}
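
/// Zero-copy view of a CPU DLPack tensor as an `ndarray::ArrayView`.
///
/// A minimal usage sketch, assuming a `tensor: DLPackTensorRef` obtained from some
/// producer (the element type and dimensionality here are illustrative):
///
/// ```ignore
/// let view: ndarray::ArrayView2<f32> = tensor.try_into()?;
/// ```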
impl<'a, T, D> TryFrom<DLPackTensorRef<'a>> for ndarray::ArrayView<'a, T, D>
where
    T: DLPackPointerCast + 'static,
    D: DimFromVec + 'static,
{
    type Error = DLPackNDarrayError;

    fn try_from(tensor: DLPackTensorRef<'a>) -> Result<Self, Self::Error> {
        if tensor.device().device_type != sys::DLDeviceType::kDLCPU {
            return Err(DLPackNDarrayError::DeviceShouldBeCpu(tensor.device()));
        }

        let ptr = tensor.data_ptr::<T>()?;
        let shape = tensor.shape().iter().map(|&s| s as usize).collect::<Vec<_>>();
        let shape = <D as DimFromVec>::dim_from_vec(shape)?;

        // A missing strides pointer means the tensor is compact and row-major,
        // which matches ndarray's default layout.
        let array = match DLPackTensorRef::strides(&tensor) {
            Some(strides) => {
                let strides = strides.iter().map(|&s| s as usize).collect::<Vec<_>>();
                let strides = <D as DimFromVec>::dim_from_vec(strides)?;
                let shape = shape.strides(strides);
                unsafe { ndarray::ArrayView::from_shape_ptr(shape, ptr) }
            }
            None => unsafe { ndarray::ArrayView::from_shape_ptr(shape, ptr) },
        };

        Ok(array)
    }
}
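
/// Mutable zero-copy view of a CPU DLPack tensor; the same device check and
/// stride handling as the immutable conversion apply.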
impl<'a, T, D> TryFrom<DLPackTensorRefMut<'a>> for ndarray::ArrayViewMut<'a, T, D>
where
    T: DLPackPointerCast + 'static,
    D: DimFromVec + 'static,
{
    type Error = DLPackNDarrayError;

    fn try_from(mut tensor: DLPackTensorRefMut<'a>) -> Result<Self, Self::Error> {
        if tensor.device().device_type != sys::DLDeviceType::kDLCPU {
            return Err(DLPackNDarrayError::DeviceShouldBeCpu(tensor.device()));
        }

        let ptr = tensor.data_ptr_mut::<T>()?;
        let shape = tensor.shape().iter().map(|&s| s as usize).collect::<Vec<_>>();
        let shape = <D as DimFromVec>::dim_from_vec(shape)?;

        let array = match DLPackTensorRefMut::strides(&tensor) {
            Some(strides) => {
                let strides = strides.iter().map(|&s| s as usize).collect::<Vec<_>>();
                let strides = <D as DimFromVec>::dim_from_vec(strides)?;
                let shape = shape.strides(strides);
                unsafe { ndarray::ArrayViewMut::from_shape_ptr(shape, ptr) }
            }
            None => unsafe { ndarray::ArrayViewMut::from_shape_ptr(shape, ptr) },
        };

        Ok(array)
    }
}
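
/// Convert an owned DLPack tensor into an owned `ndarray::Array` by copying the data.
///
/// A minimal sketch, assuming `tensor: DLPackTensor` holds CPU `f32` data:
///
/// ```ignore
/// let array: ndarray::Array2<f32> = tensor.try_into()?;
/// ```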
impl<T, D> TryFrom<DLPackTensor> for Array<T, D>
where
    D: Dimension + DimFromVec + 'static,
    T: DLPackPointerCast + Clone + 'static,
{
    type Error = DLPackNDarrayError;

    fn try_from(tensor: DLPackTensor) -> Result<Self, Self::Error> {
        let tensor_view = tensor.as_ref();
        let array_view: ndarray::ArrayView<T, D> = tensor_view.try_into()?;
        Ok(array_view.to_owned())
    }
}
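
/// Convert an owned DLPack tensor into a shared `ArcArray`, going through an owned
/// `Array` (one copy of the data).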
impl<T, D> TryFrom<DLPackTensor> for ArcArray<T, D>
where
    D: Dimension + DimFromVec + 'static,
    T: DLPackPointerCast + Clone + 'static,
{
    type Error = DLPackNDarrayError;

    fn try_from(tensor: DLPackTensor) -> Result<Self, Self::Error> {
        let array: Array<T, D> = tensor.try_into()?;
        Ok(array.into())
    }
}
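
/// Build a borrowed `sys::DLTensor` that points into an existing ndarray container.
/// The returned tensor borrows the array's data, shape, and strides, so it must not
/// outlive `array`.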
fn array_to_tensor_view<'a, S, D, T>(array: &'a ndarray::ArrayBase<S, D>) -> Result<sys::DLTensor, DLPackNDarrayError>
where
    D: ndarray::Dimension,
    S: ndarray::RawData<Elem = T>,
    T: GetDLPackDataType,
{
    let shape: &'a [_] = array.shape();
    let strides: &'a [_] = ndarray::ArrayBase::strides(array);

    // DLPack describes shape and strides as i64, while ndarray uses usize/isize.
    // Reinterpreting the pointers is only valid when the widths match.
    if std::mem::size_of::<isize>() != std::mem::size_of::<i64>() {
        unimplemented!("DLPack conversion is only supported on 64-bit targets")
    }

    let ndim = shape.len() as i32;
    let shape = shape.as_ptr().cast_mut().cast::<i64>();
    let strides = strides.as_ptr().cast_mut().cast();

    let device = sys::DLDevice {
        device_type: sys::DLDeviceType::kDLCPU,
        device_id: 0,
    };

    Ok(sys::DLTensor {
        data: array.as_ptr().cast_mut().cast(),
        device,
        ndim,
        dtype: T::get_dlpack_data_type(),
        shape,
        strides,
        byte_offset: 0,
    })
}
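
/// Borrow an `ndarray` view as a DLPack tensor reference without copying. The impls
/// below follow the same pattern for mutable views, owned arrays, and `ArcArray`s.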
impl<'a, T, D> TryFrom<&'a ndarray::ArrayView<'a, T, D>> for DLPackTensorRef<'a>
where
    D: ndarray::Dimension,
    T: GetDLPackDataType,
{
    type Error = DLPackNDarrayError;

    fn try_from(array: &'a ndarray::ArrayView<'a, T, D>) -> Result<Self, Self::Error> {
        let tensor = array_to_tensor_view(array)?;
        Ok(unsafe { DLPackTensorRef::from_raw(tensor) })
    }
}

impl<'a, T, D> TryFrom<&'a ndarray::ArrayViewMut<'a, T, D>> for DLPackTensorRefMut<'a>
where
    D: ndarray::Dimension,
    T: GetDLPackDataType,
{
    type Error = DLPackNDarrayError;

    fn try_from(array: &'a ndarray::ArrayViewMut<'a, T, D>) -> Result<Self, Self::Error> {
        let tensor = array_to_tensor_view(array)?;
        Ok(unsafe { DLPackTensorRefMut::from_raw(tensor) })
    }
}

impl<'a, T, D> TryFrom<&'a ndarray::Array<T, D>> for DLPackTensorRef<'a>
where
    D: ndarray::Dimension,
    T: GetDLPackDataType,
{
    type Error = DLPackNDarrayError;

    fn try_from(array: &'a ndarray::Array<T, D>) -> Result<Self, Self::Error> {
        let tensor = array_to_tensor_view(array)?;
        Ok(unsafe { DLPackTensorRef::from_raw(tensor) })
    }
}

impl<'a, T, D> TryFrom<&'a ArcArray<T, D>> for DLPackTensorRef<'a>
where
    D: ndarray::Dimension,
    T: GetDLPackDataType,
{
    type Error = DLPackNDarrayError;

    fn try_from(array: &'a ArcArray<T, D>) -> Result<Self, Self::Error> {
        let tensor = array_to_tensor_view(array)?;
        Ok(unsafe { DLPackTensorRef::from_raw(tensor) })
    }
}

impl<'a, T, D> TryFrom<&'a mut ndarray::Array<T, D>> for DLPackTensorRefMut<'a>
where
    D: ndarray::Dimension,
    T: GetDLPackDataType,
{
    type Error = DLPackNDarrayError;

    fn try_from(array: &'a mut ndarray::Array<T, D>) -> Result<Self, Self::Error> {
        let tensor = array_to_tensor_view(array)?;
        Ok(unsafe { DLPackTensorRefMut::from_raw(tensor) })
    }
}
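
/// Build an `ndarray` dimension value from a runtime `Vec<usize>`, failing with a
/// shape error when the lengths do not match.
///
/// A minimal sketch (the call site is illustrative):
///
/// ```ignore
/// let dim = <ndarray::Ix2 as DimFromVec>::dim_from_vec(vec![2, 3])?;
/// ```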
pub trait DimFromVec: ndarray::Dimension {
    fn dim_from_vec(vec: Vec<usize>) -> Result<Self, ndarray::ShapeError>;
}

macro_rules! impl_dim_for_vec_array {
    ($N:expr) => {
        impl DimFromVec for ndarray::Dim<[ndarray::Ix; $N]> {
            fn dim_from_vec(vec: Vec<usize>) -> Result<Self, ndarray::ShapeError> {
                let shape: [ndarray::Ix; $N] = vec
                    .try_into()
                    .map_err(|_| ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape))?;
                Ok(ndarray::Dim(shape))
            }
        }
    };
}

impl_dim_for_vec_array!(0);
impl_dim_for_vec_array!(1);
impl_dim_for_vec_array!(2);
impl_dim_for_vec_array!(3);
impl_dim_for_vec_array!(4);
impl_dim_for_vec_array!(5);
impl_dim_for_vec_array!(6);

impl DimFromVec for ndarray::IxDyn {
    fn dim_from_vec(shape: Vec<usize>) -> Result<Self, ndarray::ShapeError> {
        Ok(ndarray::Dim(shape))
    }
}
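
/// Keeps the exporting array (and the `i64` copies of its shape and strides that the
/// `DLTensor` points at) alive until the consumer calls the deleter.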
struct ManagerContext<T> {
    array: T,
    shape: Vec<i64>,
    strides: Vec<i64>,
}
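
/// Deleter installed on exported tensors: reclaims the boxed `ManagerContext` so the
/// array and its shape/stride buffers are dropped.
///
/// # Safety
///
/// Must be called at most once, with the `manager_ctx` that was created alongside it.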
unsafe extern "C" fn deleter_fn<T>(manager: *mut sys::DLManagedTensorVersioned) {
    let ctx = (*manager).manager_ctx.cast::<ManagerContext<T>>();
    // Rebox and drop the context, releasing the array and its shape/stride buffers.
    let _ = Box::from_raw(ctx);
}
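
/// Move an owned `Array` into a self-managed DLPack tensor. The array is stored in
/// the manager context and freed by the deleter when the consumer is done with it.
///
/// A minimal sketch:
///
/// ```ignore
/// let tensor: DLPackTensor = ndarray::arr2(&[[1.0f32, 2.0], [3.0, 4.0]]).try_into()?;
/// ```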
impl<T, D> TryFrom<Array<T, D>> for DLPackTensor
where
    D: Dimension,
    T: GetDLPackDataType + 'static,
{
    type Error = DLPackNDarrayError;

    fn try_from(array: Array<T, D>) -> Result<Self, Self::Error> {
        let shape: Vec<i64> = array.shape().iter().map(|&s| s as i64).collect();
        let strides: Vec<i64> = array.strides().iter().map(|&s| s as i64).collect();

        let mut ctx = Box::new(ManagerContext {
            array,
            shape,
            strides,
        });

        // All raw pointers below point into heap buffers owned by `ctx`, which the
        // deleter keeps alive until the consumer is done.
        let dl_tensor = sys::DLTensor {
            data: ctx.array.as_ptr().cast_mut().cast(),
            device: sys::DLDevice {
                device_type: sys::DLDeviceType::kDLCPU,
                device_id: 0,
            },
            ndim: ctx.shape.len() as i32,
            dtype: T::get_dlpack_data_type(),
            shape: ctx.shape.as_mut_ptr(),
            strides: ctx.strides.as_mut_ptr(),
            byte_offset: 0,
        };

        let managed_tensor = sys::DLManagedTensorVersioned {
            version: sys::DLPackVersion::current(),
            manager_ctx: Box::into_raw(ctx).cast(),
            deleter: Some(deleter_fn::<Array<T, D>>),
            // The moved-in array owns its buffer exclusively.
            flags: sys::DLPACK_FLAG_BITMASK_IS_COPIED,
            dl_tensor,
        };

        unsafe { Ok(DLPackTensor::from_raw(managed_tensor)) }
    }
}
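
/// Export a shared `ArcArray` without copying: the manager context holds a cheap
/// clone of the handle, keeping the shared buffer alive while the consumer uses it.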
impl<'a, T, D> TryFrom<&'a ArcArray<T, D>> for DLPackTensor
where
    D: Dimension,
    T: GetDLPackDataType + 'static + Clone,
{
    type Error = DLPackNDarrayError;

    fn try_from(array: &'a ArcArray<T, D>) -> Result<Self, Self::Error> {
        // Cloning an ArcArray only bumps the reference count; the data is shared.
        let shared_view = array.clone();

        let shape: Vec<i64> = shared_view.shape().iter().map(|&s| s as i64).collect();
        let strides: Vec<i64> = shared_view.strides().iter().map(|&s| s as i64).collect();
        let ndim = shape.len() as i32;

        let mut ctx = Box::new(ManagerContext {
            array: shared_view,
            shape,
            strides,
        });

        let dl_tensor = sys::DLTensor {
            data: ctx.array.as_ptr().cast_mut().cast(),
            device: sys::DLDevice {
                device_type: sys::DLDeviceType::kDLCPU,
                device_id: 0,
            },
            ndim,
            dtype: T::get_dlpack_data_type(),
            shape: ctx.shape.as_mut_ptr(),
            strides: ctx.strides.as_mut_ptr(),
            byte_offset: 0,
        };

        let managed_tensor = sys::DLManagedTensorVersioned {
            version: sys::DLPackVersion::current(),
            manager_ctx: Box::into_raw(ctx).cast(),
            deleter: Some(deleter_fn::<ArcArray<T, D>>),
            // The buffer may still be shared with other ArcArray handles, so it is
            // not flagged as a copy.
            flags: 0,
            dl_tensor,
        };

        unsafe { Ok(DLPackTensor::from_raw(managed_tensor)) }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::sys::{DLDevice, DLDeviceType, DLTensor};
    use ndarray::prelude::*;
    use ndarray::ArcArray2;

    #[test]
    fn test_dlpack_to_ndarray() {
        let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut shape = vec![2i64, 3];
        let mut strides = vec![3i64, 1];

        let dl_tensor = DLTensor {
            data: data.as_ptr().cast_mut().cast(),
            device: DLDevice {
                device_type: DLDeviceType::kDLCPU,
                device_id: 0,
            },
            ndim: 2,
            dtype: f32::get_dlpack_data_type(),
            shape: shape.as_mut_ptr(),
            strides: strides.as_mut_ptr(),
            byte_offset: 0,
        };

        let dlpack_ref = unsafe { DLPackTensorRef::from_raw(dl_tensor) };
        let array_view = ArrayView2::<f32>::try_from(dlpack_ref).unwrap();

        let expected = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        assert_eq!(array_view, expected);
    }

    #[test]
    fn test_dlpack_to_ndarray_f_contiguous() {
        let mut data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut shape = vec![2i64, 3];
        let mut strides = vec![1i64, 2];

        let dl_tensor = DLTensor {
            data: data.as_mut_ptr().cast(),
            device: DLDevice {
                device_type: DLDeviceType::kDLCPU,
                device_id: 0,
            },
            ndim: 2,
            dtype: f32::get_dlpack_data_type(),
            shape: shape.as_mut_ptr(),
            strides: strides.as_mut_ptr(),
            byte_offset: 0,
        };

        let dlpack_ref = unsafe { DLPackTensorRef::from_raw(dl_tensor) };
        let array_view = ArrayView2::<f32>::try_from(dlpack_ref).unwrap();

        assert!(!array_view.is_standard_layout());
        let expected = arr2(&[[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]]);
        assert_eq!(array_view, expected);
    }

    #[test]
    fn test_dlpack_to_ndarray_wrong_device() {
        let mut data = vec![1.0f32];
        let mut shape = vec![1i64];

        let dl_tensor = DLTensor {
            data: data.as_mut_ptr().cast(),
            device: DLDevice {
                device_type: DLDeviceType::kDLCUDA,
                device_id: 0,
            },
            ndim: 1,
            dtype: f32::get_dlpack_data_type(),
            shape: shape.as_mut_ptr(),
            strides: std::ptr::null_mut(),
            byte_offset: 0,
        };

        let dlpack_ref = unsafe { DLPackTensorRef::from_raw(dl_tensor) };
        let result = ArrayView1::<f32>::try_from(dlpack_ref);
        assert!(matches!(result, Err(DLPackNDarrayError::DeviceShouldBeCpu(_))));
    }

    #[test]
    fn test_ndarray_to_dlpack() {
        let array = arr2(&[[1i64, 2, 3], [4, 5, 6]]);
        let view = array.view();
        let dlpack_ref = DLPackTensorRef::try_from(&view).unwrap();
        let raw = dlpack_ref.raw;

        assert_eq!(raw.ndim, 2);
        assert_eq!(raw.device.device_type, DLDeviceType::kDLCPU);
        assert_eq!(raw.dtype, i64::get_dlpack_data_type());
        assert_eq!(raw.data as *const i64, array.as_ptr());

        let shape = unsafe { std::slice::from_raw_parts(raw.shape, 2) };
        assert_eq!(shape, &[2, 3]);

        let strides = unsafe { std::slice::from_raw_parts(raw.strides, 2) };
        assert_eq!(strides, &[3, 1]);
    }

    #[test]
    fn test_dlpack_to_ndarray_mut() {
        let mut data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut shape = vec![2i64, 3];
        let mut strides = vec![3i64, 1];

        let dl_tensor = DLTensor {
            data: data.as_mut_ptr().cast(),
            device: DLDevice {
                device_type: DLDeviceType::kDLCPU,
                device_id: 0,
            },
            ndim: 2,
            dtype: f32::get_dlpack_data_type(),
            shape: shape.as_mut_ptr(),
            strides: strides.as_mut_ptr(),
            byte_offset: 0,
        };

        let dlpack_ref_mut = unsafe { DLPackTensorRefMut::from_raw(dl_tensor) };
        let mut array_view_mut = ArrayViewMut2::<f32>::try_from(dlpack_ref_mut).unwrap();

        array_view_mut[[0, 0]] = 100.0;
        assert_eq!(data[0], 100.0);

        let expected = arr2(&[[100.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        assert_eq!(array_view_mut, expected);
    }

    #[test]
    fn test_ndarray_to_managed_tensor() {
        let array = arr2(&[[1i64, 2, 3], [4, 5, 6]]);
        let tensor: DLPackTensor = array.try_into().unwrap();

        let raw = unsafe { &tensor.raw.as_ref().dl_tensor };
        assert_eq!(raw.ndim, 2);
        assert_eq!(raw.device.device_type, DLDeviceType::kDLCPU);
        assert_eq!(raw.dtype, i64::get_dlpack_data_type());

        let shape = unsafe { std::slice::from_raw_parts(raw.shape, 2) };
        assert_eq!(shape, &[2, 3]);

        let strides = unsafe { std::slice::from_raw_parts(raw.strides, 2) };
        assert_eq!(strides, &[3, 1]);

        let view = unsafe {
            let tensor_ref = DLPackTensorRef::from_raw(raw.clone());
            ndarray::ArrayView2::<i64>::try_from(tensor_ref).unwrap()
        };
        assert_eq!(view, arr2(&[[1, 2, 3], [4, 5, 6]]));
    }

    #[test]
    fn test_roundtrip_conversion() {
        let original_array = arr2(&[[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let tensor: DLPackTensor = original_array.clone().try_into().unwrap();
        let final_array: Array<f32, _> = tensor.try_into().unwrap();

        assert_eq!(original_array, final_array);
    }

    #[test]
    fn test_arc_array_to_dlpack_share() {
        let array = ArcArray2::from_shape_vec((2, 3), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let ptr = array.as_ptr();

        let tensor: DLPackTensor = (&array).try_into().unwrap();
        let raw = unsafe { &tensor.raw.as_ref().dl_tensor };

        assert_eq!(raw.data as *const f32, ptr);

        let array_copy: Array<f32, _> = tensor.try_into().unwrap();
        assert_eq!(array, array_copy);
        assert_ne!(array_copy.as_ptr(), ptr);
    }

    #[test]
    fn test_dlpack_to_arc_array() {
        let array = arr2(&[[10.0f32, 11.0], [12.0, 13.0]]);
        let tensor: DLPackTensor = array.clone().try_into().unwrap();

        let arc_array: ArcArray<f32, _> = tensor.try_into().unwrap();
        assert_eq!(arc_array, array);
    }

    #[test]
    fn test_arc_array_to_dlpack_ref() {
        let array = ArcArray2::from_shape_vec((2, 2), vec![1, 2, 3, 4]).unwrap();
        let tensor_ref: DLPackTensorRef = (&array).try_into().unwrap();

        assert_eq!(tensor_ref.n_dims(), 2);
        let shape = tensor_ref.shape();
        assert_eq!(shape, &[2, 2]);
    }

    #[test]
    fn test_array_conversion_permits_mutation() {
        let array = arr2(&[[1.0f32, 2.0], [3.0, 4.0]]);
        let mut tensor: DLPackTensor = array.try_into().unwrap();

        let mut tensor_mut = tensor.as_mut();
        let ptr = tensor_mut.data_ptr_mut::<f32>().unwrap();

        unsafe {
            *ptr = 42.0;
        }

        let val = tensor.as_ref().data_ptr::<f32>().unwrap();
        assert_eq!(unsafe { *val }, 42.0);
    }

    #[test]
    fn test_arc_array_conversion_allows_readonly_access() {
        let array = ArcArray2::from_elem((2, 2), 1.0f32);
        let tensor: DLPackTensor = (&array).try_into().unwrap();

        let tensor_ref = tensor.as_ref();
        assert_eq!(tensor_ref.dtype(), f32::get_dlpack_data_type());
    }
}