metatensor/data/array_ref.rs

1use crate::c_api::mts_array_t;
2use crate::data::origin::get_data_origin;
3
4use super::Array;
5
6/// Reference to a data array in metatensor-core
7///
8/// The data array can come from any origin, this struct provides facilities to
9/// access data that was created through the [`Array`] trait, and in particular
10/// as `ndarray::ArrayD` instances.
11#[derive(Debug, Clone, Copy)]
12pub struct ArrayRef<'a> {
13    array: mts_array_t,
14    /// `ArrayRef` should behave like `&'a mts_array_t`
15    marker: std::marker::PhantomData<&'a mts_array_t>,
16}
17
18impl<'a> ArrayRef<'a> {
19    /// Create a new `ArrayRef` from the given raw `mts_array_t`
20    ///
21    /// This is a **VERY** unsafe function, creating a lifetime out of thin air.
22    /// Make sure the lifetime is actually constrained by the lifetime of the
23    /// owner of this `mts_array_t`.
24    pub unsafe fn from_raw(array: mts_array_t) -> ArrayRef<'a> {
25        ArrayRef {
26            array,
27            marker: std::marker::PhantomData,
28        }
29    }
30
31    /// Get the underlying array as an `&dyn Any` instance.
32    ///
33    /// This function panics if the array was not created though this crate and
34    /// the [`Array`] trait.
35    #[inline]
36    pub fn as_any(&self) -> &dyn std::any::Any {
37        let origin = self.array.origin().unwrap_or(0);
38        assert_eq!(
39            origin, *super::array::RUST_DATA_ORIGIN,
40            "this array was not created as a rust Array (origin is '{}')",
41            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
42        );
43
44        let array = self.array.ptr.cast::<Box<dyn Array>>();
45        unsafe {
46            return (*array).as_any();
47        }
48    }
49
50    /// Get a reference to the underlying array as an `&dyn Any` instance,
51    /// re-using the same lifetime as the `ArrayRef`.
52    ///
53    /// This function panics if the array was not created though this crate and
54    /// the [`Array`] trait.
55    #[inline]
56    pub fn to_any(self) -> &'a dyn std::any::Any {
57        let origin = self.array.origin().unwrap_or(0);
58        assert_eq!(
59            origin, *super::array::RUST_DATA_ORIGIN,
60            "this array was not created as a rust Array (origin is '{}')",
61            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
62        );
63
64        let array = self.array.ptr.cast::<Box<dyn Array>>();
65        unsafe {
66            return (*array).as_any();
67        }
68    }
69
70    /// Get the data in this `ArrayRef` as a `ndarray::ArrayD`. This function
71    /// will panic if the data in this `mts_array_t` is not a `ndarray::ArrayD`.
72    #[inline]
73    pub fn as_array(&self) -> &ndarray::ArrayD<f64> {
74        self.as_any().downcast_ref().expect("this is not a ndarray::ArrayD")
75    }
76
77    /// Transform this `ArrayRef` into a reference to an `ndarray::ArrayD`,
78    /// keeping the lifetime of the `ArrayRef`.
79    ///
80    /// This function will panic if the data in this `mts_array_t` is not a
81    /// `ndarray::ArrayD`.
82    #[inline]
83    pub fn to_array(self) -> &'a ndarray::ArrayD<f64> {
84        self.to_any().downcast_ref().expect("this is not a ndarray::ArrayD")
85    }
86
87    /// Get the raw underlying `mts_array_t`
88    pub fn as_raw(&self) -> &mts_array_t {
89        &self.array
90    }
91}
92
93/// Mutable reference to a data array in metatensor-core
94///
95/// The data array can come from any origin, this struct provides facilities to
96/// access data that was created through the [`Array`] trait, and in particular
97/// as `ndarray::ArrayD` instances.
98#[derive(Debug)]
99pub struct ArrayRefMut<'a> {
100    array: mts_array_t,
101    /// `ArrayRefMut` should behave like `&'a mut mts_array_t`
102    marker: std::marker::PhantomData<&'a mut mts_array_t>,
103}
104
105impl<'a> ArrayRefMut<'a> {
106    /// Create a new `ArrayRefMut` from the given raw `mts_array_t`
107    ///
108    /// This is a **VERY** unsafe function, creating a lifetime out of thin air,
109    /// and allowing mutable access to the `mts_array_t`. Make sure the lifetime
110    /// is actually constrained by the lifetime of the owner of this
111    /// `mts_array_t`; and that the owner is mutably borrowed by this
112    /// `ArrayRefMut`.
113    #[inline]
114    pub unsafe fn new(array: mts_array_t) -> ArrayRefMut<'a> {
115        ArrayRefMut {
116            array,
117            marker: std::marker::PhantomData,
118        }
119    }
120
121    /// Get the underlying array as an `&dyn Any` instance.
122    ///
123    /// This function panics if the array was not created though this crate and
124    /// the [`Array`] trait.
125    #[inline]
126    pub fn as_any(&self) -> &dyn std::any::Any {
127        let origin = self.array.origin().unwrap_or(0);
128        assert_eq!(
129            origin, *super::array::RUST_DATA_ORIGIN,
130            "this array was not created as a rust Array (origin is '{}')",
131            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
132        );
133
134        let array = self.array.ptr.cast::<Box<dyn Array>>();
135        unsafe {
136            return (*array).as_any();
137        }
138    }
139
140    /// Get the underlying array as an `&dyn Any` instance,
141    /// re-using the same lifetime as the `ArrayRefMut`.
142    ///
143    /// This function panics if the array was not created though this crate and
144    /// the [`Array`] trait.
145    #[inline]
146    pub fn to_any(&self) -> &'a dyn std::any::Any {
147        let origin = self.array.origin().unwrap_or(0);
148        assert_eq!(
149            origin, *super::array::RUST_DATA_ORIGIN,
150            "this array was not created as a rust Array (origin is '{}')",
151            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
152        );
153
154        let array = self.array.ptr.cast::<Box<dyn Array>>();
155        unsafe {
156            return (*array).as_any();
157        }
158    }
159
160    /// Get the underlying array as an `&mut dyn Any` instance.
161    ///
162    /// This function panics if the array was not created though this crate and
163    /// the [`Array`] trait.
164    #[inline]
165    pub fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
166        let origin = self.array.origin().unwrap_or(0);
167        assert_eq!(
168            origin, *super::array::RUST_DATA_ORIGIN,
169            "this array was not created as a rust Array (origin is '{}')",
170            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
171        );
172
173        let array = self.array.ptr.cast::<Box<dyn Array>>();
174        unsafe {
175            return (*array).as_any_mut();
176        }
177    }
178
179    /// Get the underlying array as an `&mut dyn Any` instance, re-using the
180    /// same lifetime as the `ArrayRefMut`.
181    ///
182    /// This function panics if the array was not created though this crate and
183    /// the [`Array`] trait.
184    #[inline]
185    pub fn to_any_mut(self) -> &'a mut dyn std::any::Any {
186        let origin = self.array.origin().unwrap_or(0);
187        assert_eq!(
188            origin, *super::array::RUST_DATA_ORIGIN,
189            "this array was not created as a rust Array (origin is '{}')",
190            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
191        );
192
193        let array = self.array.ptr.cast::<Box<dyn Array>>();
194        unsafe {
195            return (*array).as_any_mut();
196        }
197    }
198
199    /// Get the data in this `ArrayRef` as a `ndarray::ArrayD`. This function
200    /// will panic if the data in this `mts_array_t` is not a `ndarray::ArrayD`.
201    #[inline]
202    pub fn as_array(&self) -> &ndarray::ArrayD<f64> {
203        self.as_any().downcast_ref().expect("this is not a ndarray::ArrayD")
204    }
205
206    /// Transform this `ArrayRefMut` into a reference to an `ndarray::ArrayD`,
207    /// keeping the lifetime of the `ArrayRefMut`.
208    ///
209    /// This function will panic if the data in this `mts_array_t` is not a
210    /// `ndarray::ArrayD`.
211    #[inline]
212    pub fn to_array(&self) -> &ndarray::ArrayD<f64> {
213        self.to_any().downcast_ref().expect("this is not a ndarray::ArrayD")
214    }
215
216    /// Get the data in this `ArrayRef` as a mutable reference to an
217    /// `ndarray::ArrayD`. This function will panic if the data in this
218    /// `mts_array_t` is not a `ndarray::ArrayD`.
219    #[inline]
220    pub fn as_array_mut(&mut self) -> &mut ndarray::ArrayD<f64> {
221        self.as_any_mut().downcast_mut().expect("this is not a ndarray::ArrayD")
222    }
223
224    /// Transform this `ArrayRefMut` into a mutable reference to an
225    /// `ndarray::ArrayD`, keeping the lifetime of the `ArrayRefMut`.
226    ///
227    /// This function will panic if the data in this `mts_array_t` is not a
228    /// `ndarray::ArrayD`.
229    #[inline]
230    pub fn to_array_mut(self) -> &'a mut ndarray::ArrayD<f64> {
231        self.to_any_mut().downcast_mut().expect("this is not a ndarray::ArrayD")
232    }
233
234    /// Get the raw underlying `mts_array_t`
235    pub fn as_raw(&self) -> &mts_array_t {
236        &self.array
237    }
238
239    /// Get a mutable reference to the raw underlying `mts_array_t`
240    pub fn as_raw_mut(&mut self) -> &mut mts_array_t {
241        &mut self.array
242    }
243}