// metatensor/data/mod.rs

1mod origin;
2
3mod array;
4pub use self::array::Array;
5
6mod empty;
7pub use self::empty::EmptyArray;
8
9mod ndarray_array;
10
11mod external;
12pub use self::external::MtsArray;
13
14mod array_ref;
15pub use self::array_ref::{ArrayRef, ArrayRefMut};
16
17pub use dlpk::sys::{DLDataType, DLDataTypeCode};
18pub use dlpk::sys::{DLDeviceType, DLDevice};
19pub use dlpk::sys::DLPackVersion;
20pub use dlpk::DLPackTensor;
21
22pub use metatensor_sys::mts_data_movement_t;
23
#[cfg(test)]
mod tests {
    // The explicit import of `ndarray::Array` takes precedence over the
    // `Array` name brought in by the `super::*` glob below.
    use ndarray::Array;
    use crate::c_api::mts_data_movement_t;

    use super::*;
    use super::origin::get_data_origin;

    /// Shape queries and shape-changing operations (`reshape`, `swap_axes`)
    /// must agree between an owned `MtsArray` and the handles returned by
    /// `as_ref()` / `as_mut()`, with the element values unchanged.
    #[test]
    fn shape() {
        let mut array = MtsArray::from(Array::from_elem(vec![3, 4, 2], 1.0));

        assert_eq!(array.shape().unwrap(), [3, 4, 2]);
        array.reshape(&[12, 2]).unwrap();
        assert_eq!(array.shape().unwrap(), [12, 2]);
        assert_eq!(*array.as_ndarray::<f64>(), Array::from_elem(vec![12, 2], 1.0));

        array.swap_axes(0, 1).unwrap();
        assert_eq!(array.shape().unwrap(), [2, 12]);

        // An immutable handle reports the swapped shape and the same data.
        let array_ref = array.as_ref();
        assert_eq!(array_ref.shape().unwrap(), [2, 12]);
        assert_eq!(
            *array_ref.as_ndarray_lock::<f64>().read().unwrap(),
            Array::from_elem(vec![2, 12], 1.0)
        );

        // A mutable handle can itself be reshaped.
        let mut array_ref_mut = array.as_mut();
        assert_eq!(array_ref_mut.shape().unwrap(), [2, 12]);

        array_ref_mut.reshape(&[6, 4]).unwrap();
        assert_eq!(array_ref_mut.shape().unwrap(), [6, 4]);
        assert_eq!(
            *array_ref_mut.as_ndarray_lock::<f64>().read().unwrap(),
            Array::from_elem(vec![6, 4], 1.0)
        );
    }

    /// `create` returns a new array with the same origin ("RustArray"), the
    /// requested shape, and every entry set to the value held by the
    /// zero-dimensional `fill_value` array.
    #[test]
    fn create() {
        let array = MtsArray::from(Array::from_elem(vec![4, 2], 1.0));

        assert_eq!(get_data_origin(array.origin().unwrap()).unwrap(), "RustArray");
        assert_eq!(*array.as_ndarray::<f64>(), Array::from_elem(vec![4, 2], 1.0));

        // Zero-dimensional (shape `[]`) array holding the fill value.
        let fill_value = MtsArray::from(Array::from_elem(vec![], 42.0));
        let other = array.create(&[5, 3, 7, 12], fill_value.as_ref()).unwrap();
        assert_eq!(other.shape().unwrap(), [5, 3, 7, 12]);
        assert_eq!(get_data_origin(other.origin().unwrap()).unwrap(), "RustArray");
        assert_eq!(*other.as_ndarray::<f64>(), Array::from_elem(vec![5, 3, 7, 12], 42.0));
    }

    /// `move_data` copies the entries selected by each mapping from `array`
    /// into `other`, leaving every other entry of `other` untouched.
    #[test]
    fn move_data() {
        let array = MtsArray::from(Array::from_elem(vec![3, 2, 2, 4], 1.0));

        let fill_value = MtsArray::from(Array::from_elem(vec![], 0.0));
        let mut other = array.create(&[1, 2, 2, 8], fill_value.as_ref()).unwrap();
        assert_eq!(*other.as_ndarray::<f64>(), Array::from_elem(vec![1, 2, 2, 8], 0.0));

        // Copy sample 1 / properties 0..4 of `array` into
        // sample 0 / properties 2..6 of `other`.
        let mapping = mts_data_movement_t {
            sample_in: 1,
            sample_out: 0,
            properties_start_in: 0,
            properties_start_out: 2,
            properties_length: 4,
        };
        other.move_data(&array, &[mapping]).unwrap();

        // Along the last axis (size 8), only indices 2..6 receive the copied
        // 1.0 values; the rest keep the 0.0 fill value.
        let expected = Array::from_shape_vec(vec![1, 2, 2, 8], vec![
            0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0,
        ]).unwrap();
        assert_eq!(*other.as_ndarray::<f64>(), expected);
    }
}