metatensor/data/
array.rs

1use std::ops::Range;
2use std::os::raw::c_void;
3
4use once_cell::sync::Lazy;
5
6use crate::c_api::{mts_array_t, mts_data_origin_t, mts_sample_mapping_t, mts_status_t};
7
8/// The Array trait is used by metatensor to manage different kind of data array
9/// with a single API. Metatensor only knows about `Box<dyn Array>`, and
10/// manipulate the data through the functions on this trait.
11///
12/// This corresponds to the `mts_array_t` struct in metatensor-core.
13pub trait Array: std::any::Any + Send + Sync {
14    /// Get the array as a `Any` reference
15    fn as_any(&self) -> &dyn std::any::Any;
16
17    /// Get the array as a mutable `Any` reference
18    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
19
20    /// Create a new array with the same options as the current one (data type,
21    /// data location, etc.) and the requested `shape`.
22    ///
23    /// The new array should be filled with zeros.
24    fn create(&self, shape: &[usize]) -> Box<dyn Array>;
25
26    /// Make a copy of this `array`
27    ///
28    /// The new array is expected to have the same data origin and parameters
29    /// (data type, data location, etc.)
30    fn copy(&self) -> Box<dyn Array>;
31
32    /// Get the underlying data storage as a contiguous slice
33    ///
34    /// This function is allowed to panic if the data is not accessible in RAM,
35    /// not stored as 64-bit floating point values, or not stored as a
36    /// C-contiguous array.
37    fn data(&mut self) -> &mut [f64];
38
39    /// Get the shape of the array
40    fn shape(&self) -> &[usize];
41
42    /// Change the shape of the array to the given `shape`
43    fn reshape(&mut self, shape: &[usize]);
44
45    /// Swap the axes `axis_1` and `axis_2` in this array
46    fn swap_axes(&mut self, axis_1: usize, axis_2: usize);
47
48    /// Set entries in `self` taking data from the `input` array.
49    ///
50    /// The `output` array is guaranteed to be created by calling
51    /// `mts_array_t::create` with one of the arrays in the same block or tensor
52    /// map as the `input`.
53    ///
54    /// The `samples` indicate where the data should be moved from `input` to
55    /// `output`.
56    ///
57    /// This function should copy data from `input[sample.input, ..., :]` to
58    /// `array[sample.output, ..., properties]` for each sample in `samples`.
59    /// All indexes are 0-based.
60    fn move_samples_from(
61        &mut self,
62        input: &dyn Array,
63        samples: &[mts_sample_mapping_t],
64        properties: Range<usize>,
65    );
66}
67
68impl From<Box<dyn Array>> for mts_array_t {
69    fn from(array: Box<dyn Array>) -> Self {
70        // We need to box the box to make sure the pointer is a normal 1-word
71        // pointer (`Box<dyn Trait>` contains a 2-words, *fat* pointer which can
72        // not be casted to `*mut c_void`)
73        let array = Box::new(array);
74
75        return mts_array_t {
76            ptr: Box::into_raw(array).cast(),
77            origin: Some(rust_array_origin),
78            data: Some(rust_array_data),
79            shape: Some(rust_array_shape),
80            reshape: Some(rust_array_reshape),
81            swap_axes: Some(rust_array_swap_axes),
82            create: Some(rust_array_create),
83            copy: Some(rust_array_copy),
84            destroy: Some(rust_array_destroy),
85            move_samples_from: Some(rust_array_move_samples_from),
86        }
87    }
88}
89
90macro_rules! check_pointers {
91    ($pointer: ident) => {
92        if $pointer.is_null() {
93            panic!(
94                "got invalid NULL pointer for {} at {}:{}",
95                stringify!($pointer), file!(), line!()
96            );
97        }
98    };
99    ($($pointer: ident),* $(,)?) => {
100        $(check_pointers!($pointer);)*
101    }
102}
103
104pub(super) static RUST_DATA_ORIGIN: Lazy<mts_data_origin_t> = Lazy::new(|| {
105    super::origin::register_data_origin("rust.Box<dyn Array>".into()).expect("failed to register a new origin")
106});
107
108/// Implementation of `mts_array_t.origin` using `Box<dyn Array>`
109unsafe extern "C" fn rust_array_origin(
110    array: *const c_void,
111    origin: *mut mts_data_origin_t
112) -> mts_status_t {
113    crate::errors::catch_unwind(|| {
114        check_pointers!(array, origin);
115        *origin = *RUST_DATA_ORIGIN;
116    })
117}
118
119/// Implementation of `mts_array_t.shape` using `Box<dyn Array>`
120unsafe extern "C" fn rust_array_shape(
121    array: *const c_void,
122    shape: *mut *const usize,
123    shape_count: *mut usize,
124) -> mts_status_t {
125    crate::errors::catch_unwind(|| {
126        check_pointers!(array, shape, shape_count);
127        let array = array.cast::<Box<dyn Array>>();
128        let rust_shape = (*array).shape();
129
130        *shape = rust_shape.as_ptr();
131        *shape_count = rust_shape.len();
132    })
133}
134
135/// Implementation of `mts_array_t.reshape` using `Box<dyn Array>`
136#[allow(clippy::cast_possible_truncation)]
137unsafe extern "C" fn rust_array_reshape(
138    array: *mut c_void,
139    shape: *const usize,
140    shape_count: usize,
141) -> mts_status_t {
142    crate::errors::catch_unwind(|| {
143        assert!(shape_count > 0);
144        assert!(!shape.is_null());
145        check_pointers!(array);
146        let array = array.cast::<Box<dyn Array>>();
147        let shape = std::slice::from_raw_parts(shape, shape_count);
148        (*array).reshape(shape);
149    })
150}
151
152/// Implementation of `mts_array_t.swap_axes` using `Box<dyn Array>`
153#[allow(clippy::cast_possible_truncation)]
154unsafe extern "C" fn rust_array_swap_axes(
155    array: *mut c_void,
156    axis_1: usize,
157    axis_2: usize,
158) -> mts_status_t {
159    crate::errors::catch_unwind(|| {
160        check_pointers!(array);
161        let array = array.cast::<Box<dyn Array>>();
162        (*array).swap_axes(axis_1, axis_2);
163    })
164}
165
166/// Implementation of `mts_array_t.create` using `Box<dyn Array>`
167#[allow(clippy::cast_possible_truncation)]
168unsafe extern "C" fn rust_array_create(
169    array: *const c_void,
170    shape: *const usize,
171    shape_count: usize,
172    array_storage: *mut mts_array_t,
173) -> mts_status_t {
174    crate::errors::catch_unwind(|| {
175        assert!(shape_count > 0);
176        assert!(!shape.is_null());
177        check_pointers!(array, shape, array_storage);
178        let array = array.cast::<Box<dyn Array>>();
179
180        let shape = std::slice::from_raw_parts(shape, shape_count);
181        let new_array = (*array).create(shape);
182
183        *array_storage = new_array.into();
184    })
185}
186
187/// Implementation of `mts_array_t.data` for `Box<dyn Array>`
188unsafe extern "C" fn rust_array_data(
189    array: *mut c_void,
190    data: *mut *mut f64,
191) -> mts_status_t {
192    crate::errors::catch_unwind(|| {
193        check_pointers!(array, data);
194        let array = array.cast::<Box<dyn Array>>();
195        *data = (*array).data().as_mut_ptr();
196    })
197}
198
199
200/// Implementation of `mts_array_t.copy` using `Box<dyn Array>`
201unsafe extern "C" fn rust_array_copy(
202    array: *const c_void,
203    array_storage: *mut mts_array_t,
204) -> mts_status_t {
205    crate::errors::catch_unwind(|| {
206        check_pointers!(array, array_storage);
207        let array = array.cast::<Box<dyn Array>>();
208        *array_storage = (*array).copy().into();
209    })
210}
211
212/// Implementation of `mts_array_t.destroy` for `Box<dyn Array>`
213unsafe extern "C" fn rust_array_destroy(
214    array: *mut c_void,
215) {
216    if !array.is_null() {
217        let array = array.cast::<Box<dyn Array>>();
218        let boxed = Box::from_raw(array);
219        std::mem::drop(boxed);
220    }
221}
222
223/// Implementation of `mts_array_t.move_sample` using `Box<dyn Array>`
224#[allow(clippy::cast_possible_truncation)]
225unsafe extern "C" fn rust_array_move_samples_from(
226    output: *mut c_void,
227    input: *const c_void,
228    samples: *const mts_sample_mapping_t,
229    samples_count: usize,
230    property_start: usize,
231    property_end: usize,
232) -> mts_status_t {
233    crate::errors::catch_unwind(|| {
234        check_pointers!(output, input);
235        let output = output.cast::<Box<dyn Array>>();
236        let input = input.cast::<Box<dyn Array>>();
237
238        let samples = if samples_count == 0 {
239            &[]
240        } else {
241            check_pointers!(samples);
242            std::slice::from_raw_parts(samples, samples_count)
243        };
244
245        (*output).move_samples_from(&**input, samples, property_start..property_end);
246    })
247}
248
249/******************************************************************************/
250
251impl Array for ndarray::ArrayD<f64> {
252    fn as_any(&self) -> &dyn std::any::Any {
253        self
254    }
255
256    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
257        self
258    }
259
260    fn create(&self, shape: &[usize]) -> Box<dyn Array> {
261        return Box::new(ndarray::Array::from_elem(shape, 0.0));
262    }
263
264    fn copy(&self) -> Box<dyn Array> {
265        return Box::new(self.clone());
266    }
267
268    fn data(&mut self) -> &mut [f64] {
269        return self.as_slice_mut().expect("array is not contiguous")
270    }
271
272    fn shape(&self) -> &[usize] {
273        return self.shape();
274    }
275
276    fn reshape(&mut self, shape: &[usize]) {
277        let mut array = std::mem::take(self);
278        array = array.to_shape(shape).expect("invalid shape").to_owned();
279        std::mem::swap(self, &mut array);
280    }
281
282    fn swap_axes(&mut self, axis_1: usize, axis_2: usize) {
283        self.swap_axes(axis_1, axis_2);
284    }
285
286    fn move_samples_from(
287        &mut self,
288        input: &dyn Array,
289        samples: &[mts_sample_mapping_t],
290        property: Range<usize>,
291    ) {
292        use ndarray::{Axis, Slice};
293
294        // -2 since we also remove one axis with `index_axis_mut` below
295        let property_axis = self.shape().len() - 2;
296
297        let input = input.as_any().downcast_ref::<ndarray::ArrayD<f64>>().expect("input must be a ndarray");
298        for sample in samples {
299            let value = input.index_axis(Axis(0), sample.input);
300
301            let mut output_location = self.index_axis_mut(Axis(0), sample.output);
302            let mut output_location = output_location.slice_axis_mut(
303                Axis(property_axis), Slice::from(property.clone())
304            );
305
306            output_location.assign(&value);
307        }
308    }
309}
310
311/******************************************************************************/
312
/// An implementation of the [`Array`] trait without any data.
///
/// Only the shape of the array is tracked; there are no values stored.
#[derive(Clone, Debug)]
pub struct EmptyArray {
    shape: Vec<usize>,
}

impl EmptyArray {
    /// Create a new `EmptyArray` with the given `shape`.
    pub fn new(shape: Vec<usize>) -> Self {
        Self { shape }
    }
}
327
328impl Array for EmptyArray {
329    fn as_any(&self) -> &dyn std::any::Any {
330        self
331    }
332
333    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
334        self
335    }
336
337    fn data(&mut self) -> &mut [f64] {
338        panic!("can not call Array::data() for EmptyArray");
339    }
340
341    fn create(&self, shape: &[usize]) -> Box<dyn Array> {
342        Box::new(EmptyArray { shape: shape.to_vec() })
343    }
344
345    fn copy(&self) -> Box<dyn Array> {
346        Box::new(EmptyArray { shape: self.shape.clone() })
347    }
348
349    fn shape(&self) -> &[usize] {
350        &self.shape
351    }
352
353    fn reshape(&mut self, shape: &[usize]) {
354        self.shape = shape.to_vec();
355    }
356
357    fn swap_axes(&mut self, axis_1: usize, axis_2: usize) {
358        self.shape.swap(axis_1, axis_2);
359    }
360
361    fn move_samples_from(&mut self, _: &dyn Array, _: &[mts_sample_mapping_t], _: Range<usize>) {
362        panic!("can not call Array::move_samples_from() for EmptyArray");
363    }
364}