1use std::ops::Range;
2use std::os::raw::c_void;
3
4use once_cell::sync::Lazy;
5
6use crate::c_api::{mts_array_t, mts_data_origin_t, mts_sample_mapping_t, mts_status_t};
7
8pub trait Array: std::any::Any + Send + Sync {
14 fn as_any(&self) -> &dyn std::any::Any;
16
17 fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
19
20 fn create(&self, shape: &[usize]) -> Box<dyn Array>;
25
26 fn copy(&self) -> Box<dyn Array>;
31
32 fn data(&mut self) -> &mut [f64];
38
39 fn shape(&self) -> &[usize];
41
42 fn reshape(&mut self, shape: &[usize]);
44
45 fn swap_axes(&mut self, axis_1: usize, axis_2: usize);
47
48 fn move_samples_from(
61 &mut self,
62 input: &dyn Array,
63 samples: &[mts_sample_mapping_t],
64 properties: Range<usize>,
65 );
66}
67
68impl From<Box<dyn Array>> for mts_array_t {
69 fn from(array: Box<dyn Array>) -> Self {
70 let array = Box::new(array);
74
75 return mts_array_t {
76 ptr: Box::into_raw(array).cast(),
77 origin: Some(rust_array_origin),
78 data: Some(rust_array_data),
79 shape: Some(rust_array_shape),
80 reshape: Some(rust_array_reshape),
81 swap_axes: Some(rust_array_swap_axes),
82 create: Some(rust_array_create),
83 copy: Some(rust_array_copy),
84 destroy: Some(rust_array_destroy),
85 move_samples_from: Some(rust_array_move_samples_from),
86 }
87 }
88}
89
90macro_rules! check_pointers {
91 ($pointer: ident) => {
92 if $pointer.is_null() {
93 panic!(
94 "got invalid NULL pointer for {} at {}:{}",
95 stringify!($pointer), file!(), line!()
96 );
97 }
98 };
99 ($($pointer: ident),* $(,)?) => {
100 $(check_pointers!($pointer);)*
101 }
102}
103
104pub(super) static RUST_DATA_ORIGIN: Lazy<mts_data_origin_t> = Lazy::new(|| {
105 super::origin::register_data_origin("rust.Box<dyn Array>".into()).expect("failed to register a new origin")
106});
107
108unsafe extern "C" fn rust_array_origin(
110 array: *const c_void,
111 origin: *mut mts_data_origin_t
112) -> mts_status_t {
113 crate::errors::catch_unwind(|| {
114 check_pointers!(array, origin);
115 *origin = *RUST_DATA_ORIGIN;
116 })
117}
118
119unsafe extern "C" fn rust_array_shape(
121 array: *const c_void,
122 shape: *mut *const usize,
123 shape_count: *mut usize,
124) -> mts_status_t {
125 crate::errors::catch_unwind(|| {
126 check_pointers!(array, shape, shape_count);
127 let array = array.cast::<Box<dyn Array>>();
128 let rust_shape = (*array).shape();
129
130 *shape = rust_shape.as_ptr();
131 *shape_count = rust_shape.len();
132 })
133}
134
135#[allow(clippy::cast_possible_truncation)]
137unsafe extern "C" fn rust_array_reshape(
138 array: *mut c_void,
139 shape: *const usize,
140 shape_count: usize,
141) -> mts_status_t {
142 crate::errors::catch_unwind(|| {
143 assert!(shape_count > 0);
144 assert!(!shape.is_null());
145 check_pointers!(array);
146 let array = array.cast::<Box<dyn Array>>();
147 let shape = std::slice::from_raw_parts(shape, shape_count);
148 (*array).reshape(shape);
149 })
150}
151
152#[allow(clippy::cast_possible_truncation)]
154unsafe extern "C" fn rust_array_swap_axes(
155 array: *mut c_void,
156 axis_1: usize,
157 axis_2: usize,
158) -> mts_status_t {
159 crate::errors::catch_unwind(|| {
160 check_pointers!(array);
161 let array = array.cast::<Box<dyn Array>>();
162 (*array).swap_axes(axis_1, axis_2);
163 })
164}
165
166#[allow(clippy::cast_possible_truncation)]
168unsafe extern "C" fn rust_array_create(
169 array: *const c_void,
170 shape: *const usize,
171 shape_count: usize,
172 array_storage: *mut mts_array_t,
173) -> mts_status_t {
174 crate::errors::catch_unwind(|| {
175 assert!(shape_count > 0);
176 assert!(!shape.is_null());
177 check_pointers!(array, shape, array_storage);
178 let array = array.cast::<Box<dyn Array>>();
179
180 let shape = std::slice::from_raw_parts(shape, shape_count);
181 let new_array = (*array).create(shape);
182
183 *array_storage = new_array.into();
184 })
185}
186
187unsafe extern "C" fn rust_array_data(
189 array: *mut c_void,
190 data: *mut *mut f64,
191) -> mts_status_t {
192 crate::errors::catch_unwind(|| {
193 check_pointers!(array, data);
194 let array = array.cast::<Box<dyn Array>>();
195 *data = (*array).data().as_mut_ptr();
196 })
197}
198
199
200unsafe extern "C" fn rust_array_copy(
202 array: *const c_void,
203 array_storage: *mut mts_array_t,
204) -> mts_status_t {
205 crate::errors::catch_unwind(|| {
206 check_pointers!(array, array_storage);
207 let array = array.cast::<Box<dyn Array>>();
208 *array_storage = (*array).copy().into();
209 })
210}
211
212unsafe extern "C" fn rust_array_destroy(
214 array: *mut c_void,
215) {
216 if !array.is_null() {
217 let array = array.cast::<Box<dyn Array>>();
218 let boxed = Box::from_raw(array);
219 std::mem::drop(boxed);
220 }
221}
222
223#[allow(clippy::cast_possible_truncation)]
225unsafe extern "C" fn rust_array_move_samples_from(
226 output: *mut c_void,
227 input: *const c_void,
228 samples: *const mts_sample_mapping_t,
229 samples_count: usize,
230 property_start: usize,
231 property_end: usize,
232) -> mts_status_t {
233 crate::errors::catch_unwind(|| {
234 check_pointers!(output, input);
235 let output = output.cast::<Box<dyn Array>>();
236 let input = input.cast::<Box<dyn Array>>();
237
238 let samples = if samples_count == 0 {
239 &[]
240 } else {
241 check_pointers!(samples);
242 std::slice::from_raw_parts(samples, samples_count)
243 };
244
245 (*output).move_samples_from(&**input, samples, property_start..property_end);
246 })
247}
248
249impl Array for ndarray::ArrayD<f64> {
252 fn as_any(&self) -> &dyn std::any::Any {
253 self
254 }
255
256 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
257 self
258 }
259
260 fn create(&self, shape: &[usize]) -> Box<dyn Array> {
261 return Box::new(ndarray::Array::from_elem(shape, 0.0));
262 }
263
264 fn copy(&self) -> Box<dyn Array> {
265 return Box::new(self.clone());
266 }
267
268 fn data(&mut self) -> &mut [f64] {
269 return self.as_slice_mut().expect("array is not contiguous")
270 }
271
272 fn shape(&self) -> &[usize] {
273 return self.shape();
274 }
275
276 fn reshape(&mut self, shape: &[usize]) {
277 let mut array = std::mem::take(self);
278 array = array.to_shape(shape).expect("invalid shape").to_owned();
279 std::mem::swap(self, &mut array);
280 }
281
282 fn swap_axes(&mut self, axis_1: usize, axis_2: usize) {
283 self.swap_axes(axis_1, axis_2);
284 }
285
286 fn move_samples_from(
287 &mut self,
288 input: &dyn Array,
289 samples: &[mts_sample_mapping_t],
290 property: Range<usize>,
291 ) {
292 use ndarray::{Axis, Slice};
293
294 let property_axis = self.shape().len() - 2;
296
297 let input = input.as_any().downcast_ref::<ndarray::ArrayD<f64>>().expect("input must be a ndarray");
298 for sample in samples {
299 let value = input.index_axis(Axis(0), sample.input);
300
301 let mut output_location = self.index_axis_mut(Axis(0), sample.output);
302 let mut output_location = output_location.slice_axis_mut(
303 Axis(property_axis), Slice::from(property.clone())
304 );
305
306 output_location.assign(&value);
307 }
308 }
309}
310
/// An array that only stores a shape, without any data.
///
/// Shape operations only update the stored dimensions, while accessing
/// the data (`data()`, `move_samples_from()`) panics.
#[derive(Debug, Clone)]
pub struct EmptyArray {
    // size of every dimension of this array
    shape: Vec<usize>,
}

impl EmptyArray {
    /// Create an `EmptyArray` with the given `shape`.
    pub fn new(shape: Vec<usize>) -> EmptyArray {
        Self { shape }
    }
}
327
328impl Array for EmptyArray {
329 fn as_any(&self) -> &dyn std::any::Any {
330 self
331 }
332
333 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
334 self
335 }
336
337 fn data(&mut self) -> &mut [f64] {
338 panic!("can not call Array::data() for EmptyArray");
339 }
340
341 fn create(&self, shape: &[usize]) -> Box<dyn Array> {
342 Box::new(EmptyArray { shape: shape.to_vec() })
343 }
344
345 fn copy(&self) -> Box<dyn Array> {
346 Box::new(EmptyArray { shape: self.shape.clone() })
347 }
348
349 fn shape(&self) -> &[usize] {
350 &self.shape
351 }
352
353 fn reshape(&mut self, shape: &[usize]) {
354 self.shape = shape.to_vec();
355 }
356
357 fn swap_axes(&mut self, axis_1: usize, axis_2: usize) {
358 self.shape.swap(axis_1, axis_2);
359 }
360
361 fn move_samples_from(&mut self, _: &dyn Array, _: &[mts_sample_mapping_t], _: Range<usize>) {
362 panic!("can not call Array::move_samples_from() for EmptyArray");
363 }
364}