1mod origin;
2
3mod array;
4pub use self::array::Array;
5
6mod empty;
7pub use self::empty::EmptyArray;
8
9mod ndarray_array;
10
11mod external;
12pub use self::external::MtsArray;
13
14mod array_ref;
15pub use self::array_ref::{ArrayRef, ArrayRefMut};
16
17pub use dlpk::sys::{DLDataType, DLDataTypeCode};
18pub use dlpk::sys::{DLDeviceType, DLDevice};
19pub use dlpk::sys::DLPackVersion;
20pub use dlpk::DLPackTensor;
21
22pub use metatensor_sys::mts_data_movement_t;
23
#[cfg(test)]
mod tests {
    use ndarray::Array;
    use crate::c_api::mts_data_movement_t;

    use super::*;
    use super::origin::get_data_origin;

    // Checks that `reshape` and `swap_axes` on an `MtsArray` update the
    // reported shape, and that `ArrayRef` / `ArrayRefMut` views report (and,
    // for the mutable view, can change) the same shape.
    #[test]
    fn shape() {
        let mut array = MtsArray::from(Array::from_elem(vec![3, 4, 2], 1.0));

        assert_eq!(array.shape().unwrap(), [3, 4, 2]);
        array.reshape(&[12, 2]).unwrap();
        assert_eq!(array.shape().unwrap(), [12, 2]);
        assert_eq!(*array.as_ndarray::<f64>(), Array::from_elem(vec![12, 2], 1.0));

        array.swap_axes(0, 1).unwrap();
        assert_eq!(array.shape().unwrap(), [2, 12]);

        let array_ref = array.as_ref();
        assert_eq!(array_ref.shape().unwrap(), [2, 12]);
        assert_eq!(
            *array_ref.as_ndarray_lock::<f64>().read().unwrap(),
            Array::from_elem(vec![2, 12], 1.0)
        );

        // `array_ref` is not used past this point, so its shared borrow of
        // `array` has ended and a mutable view can be taken.
        let mut array_ref_mut = array.as_mut();
        assert_eq!(array_ref_mut.shape().unwrap(), [2, 12]);

        array_ref_mut.reshape(&[6, 4]).unwrap();
        assert_eq!(array_ref_mut.shape().unwrap(), [6, 4]);
        assert_eq!(
            *array_ref_mut.as_ndarray_lock::<f64>().read().unwrap(),
            Array::from_elem(vec![6, 4], 1.0)
        );
    }

    // Checks that an `MtsArray` built from an ndarray reports the
    // "RustArray" origin, and that `create` produces a new array with the
    // requested shape, the same origin, and every entry set to the fill value.
    #[test]
    fn create() {
        let array = MtsArray::from(Array::from_elem(vec![4, 2], 1.0));

        assert_eq!(get_data_origin(array.origin().unwrap()).unwrap(), "RustArray");
        assert_eq!(*array.as_ndarray::<f64>(), Array::from_elem(vec![4, 2], 1.0));

        let fill_value = MtsArray::from(Array::from_elem(vec![], 42.0));
        let other = array.create(&[5, 3, 7, 12], fill_value.as_ref()).unwrap();
        assert_eq!(other.shape().unwrap(), [5, 3, 7, 12]);
        assert_eq!(get_data_origin(other.origin().unwrap()).unwrap(), "RustArray");
        assert_eq!(*other.as_ndarray::<f64>(), Array::from_elem(vec![5, 3, 7, 12], 42.0));
    }

    // Checks that `move_data` copies one block of properties from a given
    // input sample into a given output sample/property offset, leaving the
    // rest of the output untouched.
    #[test]
    fn move_data() {
        // Each input entry encodes its own position as `10 * sample + property`
        // (axis 0 and the last axis). With distinct values, the test fails if
        // `move_data` reads the wrong input sample or property offset. A
        // uniform input could not detect either mistake.
        let array = MtsArray::from(Array::from_shape_fn(
            vec![3, 2, 2, 4],
            |index: ndarray::IxDyn| (10 * index[0] + index[3]) as f64,
        ));

        let fill_value = MtsArray::from(Array::from_elem(vec![], 0.0));
        let mut other = array.create(&[1, 2, 2, 8], fill_value.as_ref()).unwrap();
        assert_eq!(*other.as_ndarray::<f64>(), Array::from_elem(vec![1, 2, 2, 8], 0.0));

        let mapping = mts_data_movement_t {
            sample_in: 1,
            sample_out: 0,
            properties_start_in: 0,
            properties_start_out: 2,
            properties_length: 4,
        };
        other.move_data(&array, &[mapping]).unwrap();

        // Input sample 1, properties 0..4 hold 10, 11, 12, 13. They land in
        // output sample 0, properties 2..6, for every one of the 2x2 entries
        // along the middle axes. Everything else keeps the 0.0 fill value.
        let expected = Array::from_shape_vec(vec![1, 2, 2, 8], vec![
            0.0, 0.0, 10.0, 11.0, 12.0, 13.0, 0.0, 0.0,
            0.0, 0.0, 10.0, 11.0, 12.0, 13.0, 0.0, 0.0,
            0.0, 0.0, 10.0, 11.0, 12.0, 13.0, 0.0, 0.0,
            0.0, 0.0, 10.0, 11.0, 12.0, 13.0, 0.0, 0.0,
        ]).unwrap();
        assert_eq!(*other.as_ndarray::<f64>(), expected);
    }
}