dlpk/
ndarray.rs

//! Conversions between DLPack tensors and `ndarray` arrays. This module
//! requires the `ndarray` feature to be enabled.
3//!
4//! The following conversions are supported:
5//!
6//! - `DLPackTensor` => `ndarray::Array` (makes a copy of the data)
7//! - `DLPackTensor` => `ndarray::ArcArray` (makes a copy of the data)
8//! - `DLPackTensorRef` => `ndarray::ArrayView`
9//! - `DLPackTensorRefMut` => `ndarray::ArrayViewMut`
10//! - `ndarray::Array` => `DLPackTensor`
11//! - `&ndarray::Array` => `DLPackTensorRef`
12//! - `&mut ndarray::Array` => `DLPackTensorRefMut`
//! - `&ndarray::ArrayView` => `DLPackTensorRef`
//! - `&ndarray::ArrayViewMut` => `DLPackTensorRefMut`
//! - `&ndarray::ArcArray` => `DLPackTensor` (shares the data)
16//! - `&ndarray::ArcArray` => `DLPackTensorRef`
17//!
18//! # Examples
19//!
20//! ```no_run
21//! use dlpk::{DLPackTensor, DLPackTensorRef};
22//! # fn get_tensor_from_somewhere() -> DLPackTensor { unimplemented!() }
23//!
24//! let tensor: DLPackTensor = get_tensor_from_somewhere();
25//!
26//! // makes a copy of the data
27//! let array: ndarray::ArrayD<f32> = tensor.try_into().unwrap();
28//!
29//! // no copy, share data with the original tensor
30//! let tensor: DLPackTensor = get_tensor_from_somewhere();
31//! let tensor_ref: DLPackTensorRef = tensor.as_ref();
32//! let reference: ndarray::ArrayView2<f32> = tensor_ref.try_into().unwrap();
33//!
34//! // convert an ndarray array into a DLPack tensor
35//! let array = ndarray::Array::from_elem((2, 3), 1.0f32);
36//! let tensor: DLPackTensor = array.clone().try_into().unwrap();
37//!
38//! let tensor_ref: DLPackTensorRef = (&array).try_into().unwrap();
39//! ```
40
41use ndarray::{Array, ArcArray, Dimension, ShapeBuilder};
42
43use crate::data_types::{CastError, DLPackPointerCast, GetDLPackDataType};
44use crate::sys;
45use crate::{DLPackTensor, DLPackTensorRef, DLPackTensorRefMut};
46
47#[cfg(feature = "pyo3")]
48use pyo3::PyErr;
49
50/// Possible error causes when converting between ndarray and DLPack
51#[derive(Debug)]
52pub enum DLPackNDarrayError {
53    /// ndarray only support data which lives on CPU
54    DeviceShouldBeCpu(sys::DLDevice),
55    /// The DLPack type can not be converted to a supported Rust type
56    InvalidType(CastError),
57    /// The shape/stride of the data does not match expectations
58    ShapeError(ndarray::ShapeError),
59}
60
61impl From<CastError> for DLPackNDarrayError {
62    fn from(err: CastError) -> Self {
63        DLPackNDarrayError::InvalidType(err)
64    }
65}
66
67impl From<ndarray::ShapeError> for DLPackNDarrayError {
68    fn from(err: ndarray::ShapeError) -> Self {
69        DLPackNDarrayError::ShapeError(err)
70    }
71}
72
/// Conversion errors are surfaced to Python as `ValueError`, using the
/// `Display` output of the error as the exception message.
#[cfg(feature = "pyo3")]
impl From<DLPackNDarrayError> for PyErr {
    fn from(error: DLPackNDarrayError) -> Self {
        let message = error.to_string();
        pyo3::exceptions::PyValueError::new_err(message)
    }
}
79
80
81impl std::fmt::Display for DLPackNDarrayError {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        match self {
84            DLPackNDarrayError::DeviceShouldBeCpu(device) => {
85                write!(f, "can not convert from device {} (only cpu is supported)", device)
86            }
87            DLPackNDarrayError::InvalidType(error) => {
88                write!(f, "type conversion error: {}", error)
89            }
90            DLPackNDarrayError::ShapeError(error) => {
91                write!(f, "shape error: {}", error)
92            }
93        }
94    }
95}
96
97impl std::error::Error for DLPackNDarrayError {
98    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
99        match self {
100            DLPackNDarrayError::DeviceShouldBeCpu(_) => None,
101            DLPackNDarrayError::InvalidType(err) => Some(err),
102            DLPackNDarrayError::ShapeError(err) => Some(err),
103        }
104    }
105}
106
107/*****************************************************************************/
108/*                            DLPack => ndarray                              */
109/*****************************************************************************/
110
111impl<'a, T, D> TryFrom<DLPackTensorRef<'a>> for ndarray::ArrayView<'a, T, D> where
112    T: DLPackPointerCast + 'static,
113    D: DimFromVec + 'static,
114{
115    type Error = DLPackNDarrayError;
116
117    fn try_from(tensor: DLPackTensorRef<'a>) -> Result<Self, Self::Error> {
118        if tensor.device().device_type != sys::DLDeviceType::kDLCPU {
119            return Err(DLPackNDarrayError::DeviceShouldBeCpu(tensor.device()))
120        }
121
122        let ptr = tensor.data_ptr::<T>()?;
123        let shape = tensor.shape().iter().map(|&s| s as usize).collect::<Vec<_>>();
124        let shape = <D as DimFromVec>::dim_from_vec(shape)?;
125
126        let array = match DLPackTensorRef::strides(&tensor) {
127            Some(strides) =>{
128                let s_vec = strides.iter().map(|&s| s as usize).collect::<Vec<_>>();
129                let dim_strides = <D as DimFromVec>::dim_from_vec(s_vec)?;
130                let shape = shape.strides(dim_strides);
131                unsafe { ndarray::ArrayView::from_shape_ptr(shape, ptr) }
132            }
133            None => unsafe { ndarray::ArrayView::from_shape_ptr(shape, ptr) }
134        };
135
136        return Ok(array);
137    }
138}
139
140impl<'a, T, D> TryFrom<DLPackTensorRefMut<'a>> for ndarray::ArrayViewMut<'a, T, D> where
141    T: DLPackPointerCast + 'static,
142    D: DimFromVec + 'static,
143{
144    type Error = DLPackNDarrayError;
145
146    fn try_from(mut tensor: DLPackTensorRefMut<'a>) -> Result<Self, Self::Error> {
147        if tensor.device().device_type != sys::DLDeviceType::kDLCPU {
148            return Err(DLPackNDarrayError::DeviceShouldBeCpu(tensor.device()))
149        }
150
151        let ptr = tensor.data_ptr_mut::<T>()?;
152        let shape = tensor.shape().iter().map(|&s| s as usize).collect::<Vec<_>>();
153        let shape = <D as DimFromVec>::dim_from_vec(shape)?;
154
155        let array;
156        if let Some(strides) = DLPackTensorRefMut::strides(&tensor) {
157            let strides = strides.iter().map(|&s| s as usize).collect::<Vec<_>>();
158            let strides = <D as DimFromVec>::dim_from_vec(strides)?;
159            let shape = shape.strides(strides);
160            array = unsafe {
161                ndarray::ArrayViewMut::<T, _>::from_shape_ptr(shape, ptr)
162            };
163        } else {
164            array = unsafe {
165                ndarray::ArrayViewMut::<T, _>::from_shape_ptr(shape, ptr)
166            };
167        }
168
169        return Ok(array);
170    }
171}
172
173/// This implementation provides a conversion from a DLPack `DLPackTensor` to an
174/// `ndarray::Array`.
175///
176/// **Note:** This conversion makes a copy of the underlying tensor data. The
177/// original DLPack tensor memory is released after the copy is complete.
178///
179impl<T, D> TryFrom<DLPackTensor> for Array<T, D>
180where
181    D: Dimension + DimFromVec + 'static,
182    T: DLPackPointerCast + Clone + 'static,
183{
184    type Error = DLPackNDarrayError;
185
186    fn try_from(tensor: DLPackTensor) -> Result<Self, Self::Error> {
187        let tensor_view = tensor.as_ref();
188        let array_view: ndarray::ArrayView<T, D> = tensor_view.try_into()?;
189        Ok(array_view.to_owned())
190    }
191}
192
193/// This implementation provides a conversion from a DLPack `DLPackTensor` to an
194/// `ndarray::ArcArray`.
195///
196/// **Note:** This conversion makes a copy of the underlying tensor data.
197impl<T, D> TryFrom<DLPackTensor> for ArcArray<T, D>
198where
199    D: Dimension + DimFromVec + 'static,
200    T: DLPackPointerCast + Clone + 'static,
201{
202    type Error = DLPackNDarrayError;
203
204    fn try_from(tensor: DLPackTensor) -> Result<Self, Self::Error> {
205        let array: Array<T, D> = tensor.try_into()?;
206        Ok(array.into())
207    }
208}
209
210/*****************************************************************************/
211/*                            ndarray => DLPack                              */
212/*****************************************************************************/
213
214fn array_to_tensor_view<'a, S, D, T>(array: &'a ndarray::ArrayBase<S, D>) -> Result<sys::DLTensor, DLPackNDarrayError> where
215    D: ndarray::Dimension,
216    S: ndarray::RawData<Elem = T>,
217    T: GetDLPackDataType,
218{
219    // SAFETY: we make sure that shape and strides are valid for the lifetime of
220    // the array
221    let shape: &'a [_] = array.shape();
222    let strides: &'a[_] = ndarray::ArrayBase::strides(array);
223
224    // we need a `*const i64` for DLTensor, but we have usize and isize.
225    // on 64-bit targets, isize will be the same as i64, so that's fine.
226    if std::mem::size_of::<isize>() != std::mem::size_of::<i64>() {
227        unimplemented!("DLPack conversion is only supported on 64-bit targets")
228    }
229    let strides = strides.as_ptr().cast_mut().cast();
230
231    // usize will have the same binary representation as i64 for striclty
232    // positive values, which is the most important case here.
233    if std::mem::size_of::<isize>() != std::mem::size_of::<i64>() {
234        unimplemented!("DLPack conversion is only supported on 64-bit targets")
235    }
236    let ndim = shape.len() as i32;
237    let shape = shape.as_ptr().cast_mut().cast::<i64>();
238
239    let device = sys::DLDevice {
240        device_type: sys::DLDeviceType::kDLCPU,
241        device_id: 0,
242    };
243
244    return Ok(sys::DLTensor {
245        data: array.as_ptr().cast_mut().cast(),
246        device: device,
247        ndim: ndim,
248        dtype: T::get_dlpack_data_type(),
249        shape: shape,
250        strides: strides,
251        byte_offset: 0,
252    });
253}
254
255impl<'a, T, D> TryFrom<&'a ndarray::ArrayView<'a, T, D>> for DLPackTensorRef<'a> where
256    D: ndarray::Dimension,
257    T: GetDLPackDataType,
258{
259    type Error = DLPackNDarrayError;
260
261    fn try_from(array: &'a ndarray::ArrayView<'a, T, D>) -> Result<Self, Self::Error> {
262        let tensor = array_to_tensor_view(array)?;
263
264        return Ok(unsafe {
265            // SAFETY: we are constraining the lifetime of the return value
266            DLPackTensorRef::from_raw(tensor)
267        });
268    }
269}
270
271impl<'a, T, D> TryFrom<&'a ndarray::ArrayViewMut<'a, T, D>> for DLPackTensorRefMut<'a> where
272    D: ndarray::Dimension,
273    T: GetDLPackDataType,
274{
275    type Error = DLPackNDarrayError;
276
277    fn try_from(array: &'a ndarray::ArrayViewMut<'a, T, D>) -> Result<Self, Self::Error> {
278        let tensor = array_to_tensor_view(array)?;
279
280        return Ok(unsafe {
281            // SAFETY: we are constraining the lifetime of the return value, and
282            // returning a mut ref from a mut ref
283            DLPackTensorRefMut::from_raw(tensor)
284        });
285    }
286}
287
288impl<'a, T, D> TryFrom<&'a ndarray::Array<T, D>> for DLPackTensorRef<'a> where
289    D: ndarray::Dimension,
290    T: GetDLPackDataType,
291{
292    type Error = DLPackNDarrayError;
293
294    fn try_from(array: &'a ndarray::Array<T, D>) -> Result<Self, Self::Error> {
295        let tensor = array_to_tensor_view(array)?;
296
297        return Ok(unsafe {
298            // SAFETY: we are constraining the lifetime of the return value, and
299            // returning a mut ref from a mut ref
300            DLPackTensorRef::from_raw(tensor)
301        });
302    }
303}
304
305impl<'a, T, D> TryFrom<&'a ArcArray<T, D>> for DLPackTensorRef<'a> where
306    D: ndarray::Dimension,
307    T: GetDLPackDataType,
308{
309    type Error = DLPackNDarrayError;
310
311    fn try_from(array: &'a ArcArray<T, D>) -> Result<Self, Self::Error> {
312        let tensor = array_to_tensor_view(array)?;
313
314        return Ok(unsafe {
315            // SAFETY: we are constraining the lifetime of the return value
316            DLPackTensorRef::from_raw(tensor)
317        });
318    }
319}
320
321impl<'a, T, D> TryFrom<&'a mut ndarray::Array<T, D>> for DLPackTensorRefMut<'a> where
322    D: ndarray::Dimension,
323    T: GetDLPackDataType,
324{
325    type Error = DLPackNDarrayError;
326
327    fn try_from(array: &'a mut ndarray::Array<T, D>) -> Result<Self, Self::Error> {
328        let tensor = array_to_tensor_view(array)?;
329
330        return Ok(unsafe {
331            // SAFETY: we are constraining the lifetime of the return value, and
332            // returning a mut ref from a mut ref
333            DLPackTensorRefMut::from_raw(tensor)
334        });
335    }
336}
337
338/// Internal trait that will convert a `Vec<usize>` into one of ndarray's Dim
339/// type.
340pub trait DimFromVec where Self: ndarray::Dimension {
341    fn dim_from_vec(vec: Vec<usize>) -> Result<Self, ndarray::ShapeError>;
342}
343
344macro_rules! impl_dim_for_vec_array {
345    ($N: expr) => {
346        impl DimFromVec for ndarray::Dim<[ndarray::Ix; $N]> {
347            fn dim_from_vec(vec: Vec<usize>) -> Result<Self, ndarray::ShapeError> {
348                let shape: [ndarray::Ix; $N] = match vec.try_into() {
349                    Ok(shape) => shape,
350                    Err(_) => {
351                        return Err(ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape));
352                    },
353                };
354
355                return Ok(ndarray::Dim(shape));
356            }
357        }
358    };
359}
360
361impl_dim_for_vec_array!(0);
362impl_dim_for_vec_array!(1);
363impl_dim_for_vec_array!(2);
364impl_dim_for_vec_array!(3);
365impl_dim_for_vec_array!(4);
366impl_dim_for_vec_array!(5);
367impl_dim_for_vec_array!(6);
368
369impl DimFromVec for ndarray::IxDyn {
370    fn dim_from_vec(shape: Vec<usize>) -> Result<Self, ndarray::ShapeError> {
371        return Ok(ndarray::Dim(shape));
372    }
373}
374
375// Private struct to manage the lifetime of the array and its shape/strides
376struct ManagerContext<T> {
377    array: T,
378    shape: Vec<i64>,
379    strides: Vec<i64>,
380}
381
382unsafe extern "C" fn deleter_fn<T>(manager: *mut sys::DLManagedTensorVersioned) {
383    // Reconstruct the box and drop it, freeing the memory.
384    let ctx = (*manager).manager_ctx.cast::<ManagerContext<T>>();
385    let _ = Box::from_raw(ctx);
386}
387
388impl<T, D> TryFrom<Array<T, D>> for DLPackTensor
389where
390    D: Dimension,
391    T: GetDLPackDataType + 'static,
392{
393    type Error = DLPackNDarrayError;
394
395    fn try_from(array: Array<T, D>) -> Result<Self, Self::Error> {
396        let shape: Vec<i64> = array.shape().iter().map(|&s| s as i64).collect();
397        let strides: Vec<i64> = array.strides().iter().map(|&s| s as i64).collect();
398
399        let mut ctx = Box::new(ManagerContext {
400            array,
401            shape,
402            strides,
403        });
404
405        let dl_tensor = sys::DLTensor {
406            // Casting to a mut pointer is not necessarily safe, but is required
407            // by DLPack. The data can be mutated through this pointer, we
408            // should try to find a way to make this work in Rust type system in
409            // the future.
410            data: ctx.array.as_ptr().cast_mut().cast(),
411            device: sys::DLDevice {
412                device_type: sys::DLDeviceType::kDLCPU,
413                device_id: 0,
414            },
415            ndim: ctx.shape.len() as i32,
416            dtype: T::get_dlpack_data_type(),
417            shape: ctx.shape.as_mut_ptr(),
418            strides: ctx.strides.as_mut_ptr(),
419            byte_offset: 0,
420        };
421
422        let managed_tensor = sys::DLManagedTensorVersioned {
423            version: sys::DLPackVersion::current(),
424            manager_ctx: Box::into_raw(ctx).cast(),
425            deleter: Some(deleter_fn::<Array<T, D>>),
426            flags: sys::DLPACK_FLAG_BITMASK_IS_COPIED,
427            dl_tensor,
428        };
429
430        unsafe {
431            Ok(DLPackTensor::from_raw(managed_tensor))
432        }
433    }
434}
435
436/// Convert a shared `ArcArray` into a `DLPackTensor`.
437/// This is ZERO-COPY: it increments the reference count of the data.
438impl<'a, T, D> TryFrom<&'a ArcArray<T, D>> for DLPackTensor
439where
440    D: Dimension,
441    T: GetDLPackDataType + 'static + Clone,
442{
443    type Error = DLPackNDarrayError;
444
445    fn try_from(array: &'a ArcArray<T, D>) -> Result<Self, Self::Error> {
446        let shared_view = array.clone();
447
448        let shape: Vec<i64> = shared_view.shape().iter().map(|&s| s as i64).collect();
449        let strides: Vec<i64> = shared_view.strides().iter().map(|&s| s as i64).collect();
450        let ndim = shape.len() as i32;
451
452        let mut ctx = Box::new(ManagerContext {
453            array: shared_view,
454            shape,
455            strides,
456        });
457
458
459        let dl_tensor = sys::DLTensor {
460            // Same as above, casting to a mut pointer is not necessarily safe.
461            data: ctx.array.as_ptr().cast_mut().cast(),
462            device: sys::DLDevice {
463                device_type: sys::DLDeviceType::kDLCPU,
464                device_id: 0,
465            },
466            ndim,
467            dtype: T::get_dlpack_data_type(),
468            shape: ctx.shape.as_mut_ptr(),
469            strides: ctx.strides.as_mut_ptr(),
470            byte_offset: 0,
471        };
472
473        let managed_tensor = sys::DLManagedTensorVersioned {
474            version: sys::DLPackVersion::current(),
475            manager_ctx: Box::into_raw(ctx).cast(),
476            deleter: Some(deleter_fn::<ArcArray<T, D>>),
477            flags: 0,
478            dl_tensor,
479        };
480
481        unsafe {
482            Ok(DLPackTensor::from_raw(managed_tensor))
483        }
484    }
485}
486
487#[cfg(test)]
488mod tests {
489    use super::*;
490    use crate::sys::{DLDevice, DLDeviceType, DLTensor};
491    use ndarray::prelude::*;
492    use ndarray::ArcArray2;
493
494    #[test]
495    fn test_dlpack_to_ndarray() {
496        let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
497        let mut shape = vec![2i64, 3];
498        let mut strides = vec![3i64, 1];
499
500        let dl_tensor = DLTensor {
501            data: data.as_ptr().cast_mut().cast(),
502            device: DLDevice {
503                device_type: DLDeviceType::kDLCPU,
504                device_id: 0,
505            },
506            ndim: 2,
507            dtype: f32::get_dlpack_data_type(),
508            shape: shape.as_mut_ptr(),
509            strides: strides.as_mut_ptr(),
510            byte_offset: 0,
511        };
512
513        let dlpack_ref = unsafe { DLPackTensorRef::from_raw(dl_tensor) };
514        let array_view = ArrayView2::<f32>::try_from(dlpack_ref).unwrap();
515
516        let expected = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
517        assert_eq!(array_view, expected);
518    }
519
520    #[test]
521    fn test_dlpack_to_ndarray_f_contiguous() {
522        let mut data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
523        let mut shape = vec![2i64, 3];
524        // Fortran-contiguous strides
525        let mut strides = vec![1i64, 2];
526
527        let dl_tensor = DLTensor {
528            data: data.as_mut_ptr().cast(),
529            device: DLDevice {
530                device_type: DLDeviceType::kDLCPU,
531                device_id: 0,
532            },
533            ndim: 2,
534            dtype: f32::get_dlpack_data_type(),
535            shape: shape.as_mut_ptr(),
536            strides: strides.as_mut_ptr(),
537            byte_offset: 0,
538        };
539
540        let dlpack_ref = unsafe { DLPackTensorRef::from_raw(dl_tensor) };
541        let array_view = ArrayView2::<f32>::try_from(dlpack_ref).unwrap();
542
543        assert!(!array_view.is_standard_layout());
544        let expected = arr2(&[[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]]);
545        assert_eq!(array_view, expected);
546    }
547
548    #[test]
549    fn test_dlpack_to_ndarray_wrong_device() {
550        let mut data = vec![1.0f32];
551        let mut shape = vec![1i64];
552
553        let dl_tensor = DLTensor {
554            data: data.as_mut_ptr().cast(),
555            device: DLDevice {
556                device_type: DLDeviceType::kDLCUDA,
557                device_id: 0,
558            },
559            ndim: 1,
560            dtype: f32::get_dlpack_data_type(),
561            shape: shape.as_mut_ptr(),
562            strides: std::ptr::null_mut(),
563            byte_offset: 0,
564        };
565
566        let dlpack_ref = unsafe { DLPackTensorRef::from_raw(dl_tensor) };
567        let result = ArrayView1::<f32>::try_from(dlpack_ref);
568        assert!(matches!(result, Err(DLPackNDarrayError::DeviceShouldBeCpu(_))));
569    }
570
571    #[test]
572    fn test_ndarray_to_dlpack() {
573        let array = arr2(&[[1i64, 2, 3], [4, 5, 6]]);
574        let view = array.view();
575        let dlpack_ref = DLPackTensorRef::try_from(&view).unwrap();
576        let raw = dlpack_ref.raw;
577
578        assert_eq!(raw.ndim, 2);
579        assert_eq!(raw.device.device_type, DLDeviceType::kDLCPU);
580        assert_eq!(raw.dtype, i64::get_dlpack_data_type());
581        assert_eq!(raw.data as *const i64, array.as_ptr());
582
583        let shape = unsafe { std::slice::from_raw_parts(raw.shape, 2) };
584        assert_eq!(shape, &[2, 3]);
585
586        let strides = unsafe { std::slice::from_raw_parts(raw.strides, 2) };
587        assert_eq!(strides, &[3, 1]);
588    }
589
590    #[test]
591    fn test_dlpack_to_ndarray_mut() {
592        let mut data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
593        let mut shape = vec![2i64, 3];
594        let mut strides = vec![3i64, 1];
595
596        let dl_tensor = DLTensor {
597            data: data.as_mut_ptr().cast(),
598            device: DLDevice {
599                device_type: DLDeviceType::kDLCPU,
600                device_id: 0,
601            },
602            ndim: 2,
603            dtype: f32::get_dlpack_data_type(),
604            shape: shape.as_mut_ptr(),
605            strides: strides.as_mut_ptr(),
606            byte_offset: 0,
607        };
608
609        let dlpack_ref_mut = unsafe { DLPackTensorRefMut::from_raw(dl_tensor) };
610        let mut array_view_mut = ArrayViewMut2::<f32>::try_from(dlpack_ref_mut).unwrap();
611
612        array_view_mut[[0, 0]] = 100.0;
613        assert_eq!(data[0], 100.0);
614
615        let expected = arr2(&[[100.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
616        assert_eq!(array_view_mut, expected);
617    }
618
619    #[test]
620    fn test_ndarray_to_managed_tensor() {
621        let array = arr2(&[[1i64, 2, 3], [4, 5, 6]]);
622        // The original array is moved into the manager.
623        let tensor: DLPackTensor = array.try_into().unwrap();
624
625        let raw = unsafe {
626            &tensor.raw.as_ref().dl_tensor
627        };
628        assert_eq!(raw.ndim, 2);
629        assert_eq!(raw.device.device_type, DLDeviceType::kDLCPU);
630        assert_eq!(raw.dtype, i64::get_dlpack_data_type());
631
632        let shape = unsafe { std::slice::from_raw_parts(raw.shape, 2) };
633        assert_eq!(shape, &[2, 3]);
634
635        let strides = unsafe { std::slice::from_raw_parts(raw.strides, 2) };
636        assert_eq!(strides, &[3, 1]);
637
638        // To check correctness, we can create a view from the managed tensor's data.
639        let view = unsafe {
640            let tensor_ref = DLPackTensorRef::from_raw(raw.clone());
641            ndarray::ArrayView2::<i64>::try_from(tensor_ref).unwrap()
642        };
643        assert_eq!(view, arr2(&[[1, 2, 3], [4, 5, 6]]));
644    }
645
646    #[test]
647    fn test_roundtrip_conversion() {
648        let original_array = arr2(&[[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]]);
649        let tensor: DLPackTensor = original_array.clone().try_into().unwrap();
650        let final_array: Array<f32, _> = tensor.try_into().unwrap();
651
652        assert_eq!(original_array, final_array);
653    }
654
655    #[test]
656    fn test_arc_array_to_dlpack_share() {
657        let array = ArcArray2::from_shape_vec((2, 3), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
658        let ptr = array.as_ptr();
659
660        // Conversion to DLPackTensor (Owned) should share data
661        let tensor: DLPackTensor = (&array).try_into().unwrap();
662        let raw = unsafe { &tensor.raw.as_ref().dl_tensor };
663
664        assert_eq!(raw.data as *const f32, ptr);
665
666        // Convert back to Array (Copy)
667        let array_copy: Array<f32, _> = tensor.try_into().unwrap();
668        assert_eq!(array, array_copy);
669        // Pointers should differ due to copy
670        assert_ne!(array_copy.as_ptr(), ptr);
671    }
672
673    #[test]
674    fn test_dlpack_to_arc_array() {
675        let array = arr2(&[[10.0f32, 11.0], [12.0, 13.0]]);
676        let tensor: DLPackTensor = array.clone().try_into().unwrap();
677
678        let arc_array: ArcArray<f32, _> = tensor.try_into().unwrap();
679        assert_eq!(arc_array, array);
680    }
681
682    #[test]
683    fn test_arc_array_to_dlpack_ref() {
684        let array = ArcArray2::from_shape_vec((2, 2), vec![1, 2, 3, 4]).unwrap();
685        let tensor_ref: DLPackTensorRef = (&array).try_into().unwrap();
686
687        assert_eq!(tensor_ref.n_dims(), 2);
688        let shape = tensor_ref.shape();
689        assert_eq!(shape, &[2, 2]);
690    }
691
692    #[test]
693    fn test_array_conversion_permits_mutation() {
694        let array = arr2(&[[1.0f32, 2.0], [3.0, 4.0]]);
695        let mut tensor: DLPackTensor = array.try_into().unwrap();
696
697        // This should not panic because flags include IS_COPIED
698        // and do not include READ_ONLY.
699        let mut tensor_mut = tensor.as_mut();
700        let ptr = tensor_mut.data_ptr_mut::<f32>().unwrap();
701
702        unsafe {
703            *ptr = 42.0;
704        }
705
706        let val = tensor.as_ref().data_ptr::<f32>().unwrap();
707        assert_eq!(unsafe { *val }, 42.0);
708    }
709
710    #[test]
711    fn test_arc_array_conversion_allows_readonly_access() {
712        let array = ArcArray2::from_elem((2, 2), 1.0f32);
713        let tensor: DLPackTensor = (&array).try_into().unwrap();
714
715        // Standard immutable access should remain functional.
716        let tensor_ref = tensor.as_ref();
717        assert_eq!(tensor_ref.dtype(), f32::get_dlpack_data_type());
718    }
719}