metatensor/data/array_ref.rs
1use crate::c_api::mts_array_t;
2use crate::data::origin::get_data_origin;
3
4use super::Array;
5
6/// Reference to a data array in metatensor-core
7///
8/// The data array can come from any origin, this struct provides facilities to
9/// access data that was created through the [`Array`] trait, and in particular
10/// as `ndarray::ArrayD` instances.
11#[derive(Debug, Clone, Copy)]
12pub struct ArrayRef<'a> {
13 array: mts_array_t,
14 /// `ArrayRef` should behave like `&'a mts_array_t`
15 marker: std::marker::PhantomData<&'a mts_array_t>,
16}
17
18impl<'a> ArrayRef<'a> {
19 /// Create a new `ArrayRef` from the given raw `mts_array_t`
20 ///
21 /// This is a **VERY** unsafe function, creating a lifetime out of thin air.
22 /// Make sure the lifetime is actually constrained by the lifetime of the
23 /// owner of this `mts_array_t`.
24 pub unsafe fn from_raw(array: mts_array_t) -> ArrayRef<'a> {
25 ArrayRef {
26 array,
27 marker: std::marker::PhantomData,
28 }
29 }
30
31 /// Get the underlying array as an `&dyn Any` instance.
32 ///
33 /// This function panics if the array was not created though this crate and
34 /// the [`Array`] trait.
35 #[inline]
36 pub fn as_any(&self) -> &dyn std::any::Any {
37 let origin = self.array.origin().unwrap_or(0);
38 assert_eq!(
39 origin, *super::array::RUST_DATA_ORIGIN,
40 "this array was not created as a rust Array (origin is '{}')",
41 get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
42 );
43
44 let array = self.array.ptr.cast::<Box<dyn Array>>();
45 unsafe {
46 return (*array).as_any();
47 }
48 }
49
50 /// Get a reference to the underlying array as an `&dyn Any` instance,
51 /// re-using the same lifetime as the `ArrayRef`.
52 ///
53 /// This function panics if the array was not created though this crate and
54 /// the [`Array`] trait.
55 #[inline]
56 pub fn to_any(self) -> &'a dyn std::any::Any {
57 let origin = self.array.origin().unwrap_or(0);
58 assert_eq!(
59 origin, *super::array::RUST_DATA_ORIGIN,
60 "this array was not created as a rust Array (origin is '{}')",
61 get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
62 );
63
64 let array = self.array.ptr.cast::<Box<dyn Array>>();
65 unsafe {
66 return (*array).as_any();
67 }
68 }
69
70 /// Get the data in this `ArrayRef` as a `ndarray::ArrayD`. This function
71 /// will panic if the data in this `mts_array_t` is not a `ndarray::ArrayD`.
72 #[inline]
73 pub fn as_array(&self) -> &ndarray::ArrayD<f64> {
74 self.as_any().downcast_ref().expect("this is not a ndarray::ArrayD")
75 }
76
77 /// Transform this `ArrayRef` into a reference to an `ndarray::ArrayD`,
78 /// keeping the lifetime of the `ArrayRef`.
79 ///
80 /// This function will panic if the data in this `mts_array_t` is not a
81 /// `ndarray::ArrayD`.
82 #[inline]
83 pub fn to_array(self) -> &'a ndarray::ArrayD<f64> {
84 self.to_any().downcast_ref().expect("this is not a ndarray::ArrayD")
85 }
86
87 /// Get the raw underlying `mts_array_t`
88 pub fn as_raw(&self) -> &mts_array_t {
89 &self.array
90 }
91}
92
93/// Mutable reference to a data array in metatensor-core
94///
95/// The data array can come from any origin, this struct provides facilities to
96/// access data that was created through the [`Array`] trait, and in particular
97/// as `ndarray::ArrayD` instances.
98#[derive(Debug)]
99pub struct ArrayRefMut<'a> {
100 array: mts_array_t,
101 /// `ArrayRefMut` should behave like `&'a mut mts_array_t`
102 marker: std::marker::PhantomData<&'a mut mts_array_t>,
103}
104
105impl<'a> ArrayRefMut<'a> {
106 /// Create a new `ArrayRefMut` from the given raw `mts_array_t`
107 ///
108 /// This is a **VERY** unsafe function, creating a lifetime out of thin air,
109 /// and allowing mutable access to the `mts_array_t`. Make sure the lifetime
110 /// is actually constrained by the lifetime of the owner of this
111 /// `mts_array_t`; and that the owner is mutably borrowed by this
112 /// `ArrayRefMut`.
113 #[inline]
114 pub unsafe fn new(array: mts_array_t) -> ArrayRefMut<'a> {
115 ArrayRefMut {
116 array,
117 marker: std::marker::PhantomData,
118 }
119 }
120
121 /// Get the underlying array as an `&dyn Any` instance.
122 ///
123 /// This function panics if the array was not created though this crate and
124 /// the [`Array`] trait.
125 #[inline]
126 pub fn as_any(&self) -> &dyn std::any::Any {
127 let origin = self.array.origin().unwrap_or(0);
128 assert_eq!(
129 origin, *super::array::RUST_DATA_ORIGIN,
130 "this array was not created as a rust Array (origin is '{}')",
131 get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
132 );
133
134 let array = self.array.ptr.cast::<Box<dyn Array>>();
135 unsafe {
136 return (*array).as_any();
137 }
138 }
139
140 /// Get the underlying array as an `&dyn Any` instance,
141 /// re-using the same lifetime as the `ArrayRefMut`.
142 ///
143 /// This function panics if the array was not created though this crate and
144 /// the [`Array`] trait.
145 #[inline]
146 pub fn to_any(&self) -> &'a dyn std::any::Any {
147 let origin = self.array.origin().unwrap_or(0);
148 assert_eq!(
149 origin, *super::array::RUST_DATA_ORIGIN,
150 "this array was not created as a rust Array (origin is '{}')",
151 get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
152 );
153
154 let array = self.array.ptr.cast::<Box<dyn Array>>();
155 unsafe {
156 return (*array).as_any();
157 }
158 }
159
160 /// Get the underlying array as an `&mut dyn Any` instance.
161 ///
162 /// This function panics if the array was not created though this crate and
163 /// the [`Array`] trait.
164 #[inline]
165 pub fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
166 let origin = self.array.origin().unwrap_or(0);
167 assert_eq!(
168 origin, *super::array::RUST_DATA_ORIGIN,
169 "this array was not created as a rust Array (origin is '{}')",
170 get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
171 );
172
173 let array = self.array.ptr.cast::<Box<dyn Array>>();
174 unsafe {
175 return (*array).as_any_mut();
176 }
177 }
178
179 /// Get the underlying array as an `&mut dyn Any` instance, re-using the
180 /// same lifetime as the `ArrayRefMut`.
181 ///
182 /// This function panics if the array was not created though this crate and
183 /// the [`Array`] trait.
184 #[inline]
185 pub fn to_any_mut(self) -> &'a mut dyn std::any::Any {
186 let origin = self.array.origin().unwrap_or(0);
187 assert_eq!(
188 origin, *super::array::RUST_DATA_ORIGIN,
189 "this array was not created as a rust Array (origin is '{}')",
190 get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
191 );
192
193 let array = self.array.ptr.cast::<Box<dyn Array>>();
194 unsafe {
195 return (*array).as_any_mut();
196 }
197 }
198
199 /// Get the data in this `ArrayRef` as a `ndarray::ArrayD`. This function
200 /// will panic if the data in this `mts_array_t` is not a `ndarray::ArrayD`.
201 #[inline]
202 pub fn as_array(&self) -> &ndarray::ArrayD<f64> {
203 self.as_any().downcast_ref().expect("this is not a ndarray::ArrayD")
204 }
205
206 /// Transform this `ArrayRefMut` into a reference to an `ndarray::ArrayD`,
207 /// keeping the lifetime of the `ArrayRefMut`.
208 ///
209 /// This function will panic if the data in this `mts_array_t` is not a
210 /// `ndarray::ArrayD`.
211 #[inline]
212 pub fn to_array(&self) -> &ndarray::ArrayD<f64> {
213 self.to_any().downcast_ref().expect("this is not a ndarray::ArrayD")
214 }
215
216 /// Get the data in this `ArrayRef` as a mutable reference to an
217 /// `ndarray::ArrayD`. This function will panic if the data in this
218 /// `mts_array_t` is not a `ndarray::ArrayD`.
219 #[inline]
220 pub fn as_array_mut(&mut self) -> &mut ndarray::ArrayD<f64> {
221 self.as_any_mut().downcast_mut().expect("this is not a ndarray::ArrayD")
222 }
223
224 /// Transform this `ArrayRefMut` into a mutable reference to an
225 /// `ndarray::ArrayD`, keeping the lifetime of the `ArrayRefMut`.
226 ///
227 /// This function will panic if the data in this `mts_array_t` is not a
228 /// `ndarray::ArrayD`.
229 #[inline]
230 pub fn to_array_mut(self) -> &'a mut ndarray::ArrayD<f64> {
231 self.to_any_mut().downcast_mut().expect("this is not a ndarray::ArrayD")
232 }
233
234 /// Get the raw underlying `mts_array_t`
235 pub fn as_raw(&self) -> &mts_array_t {
236 &self.array
237 }
238
239 /// Get a mutable reference to the raw underlying `mts_array_t`
240 pub fn as_raw_mut(&mut self) -> &mut mts_array_t {
241 &mut self.array
242 }
243}