use std::ops::Range;
use std::os::raw::c_void;
use once_cell::sync::Lazy;
use crate::c_api::{mts_array_t, mts_data_origin_t, mts_sample_mapping_t, mts_status_t};
/// A Rust implementation of the array interface exposed to C as `mts_array_t`.
///
/// Any `Box<dyn Array>` can be converted into an `mts_array_t` (see the
/// `From<Box<dyn Array>>` implementation for `mts_array_t`), and each method
/// below backs one of the `mts_array_t` callbacks (`rust_array_*` functions).
///
/// The `Any` super-trait, together with `as_any`/`as_any_mut`, allows
/// downcasting back to the concrete array type. For example, the `ndarray`
/// implementation of `move_samples_from` downcasts its `input` this way.
pub trait Array: std::any::Any + Send + Sync {
    /// Get this array as `&dyn Any`, to allow downcasting to the concrete type.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Get this array as `&mut dyn Any`, to allow downcasting to the concrete
    /// type.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;

    /// Create a new array with the given `shape`, of the same kind as `self`.
    /// The `ndarray` implementation fills it with zeros; `EmptyArray` only
    /// stores the shape.
    fn create(&self, shape: &[usize]) -> Box<dyn Array>;

    /// Create a copy of this array (shape and data).
    fn copy(&self) -> Box<dyn Array>;

    /// Get the data of this array as a mutable slice of `f64`.
    ///
    /// NOTE(review): the C side presumably expects the elements in row-major
    /// order matching `shape()` — confirm against the `mts_array_t` docs.
    fn data(&mut self) -> &mut [f64];

    /// Get the shape (size along each axis) of this array.
    fn shape(&self) -> &[usize];

    /// Change the shape of this array to `shape`.
    fn reshape(&mut self, shape: &[usize]);

    /// Swap the axes `axis_1` and `axis_2` of this array.
    fn swap_axes(&mut self, axis_1: usize, axis_2: usize);

    /// Set entries of `self` from entries of `input`.
    ///
    /// In the `ndarray` implementation, for each mapping in `samples`, the
    /// entry `sample.input` along the first axis of `input` is copied into the
    /// entry `sample.output` along the first axis of `self`. Only the
    /// `properties` range along the last axis of `self` is written.
    fn move_samples_from(
        &mut self,
        input: &dyn Array,
        samples: &[mts_sample_mapping_t],
        properties: Range<usize>,
    );
}
/// Wrap a boxed Rust [`Array`] into the C-compatible `mts_array_t` vtable.
///
/// `Box<dyn Array>` is a fat pointer (data + vtable), so it is boxed a second
/// time to obtain a thin pointer that can be stored in the `ptr` field. Every
/// `rust_array_*` callback casts `ptr` back to `*const/*mut Box<dyn Array>`.
/// Ownership of the array moves into the returned `mts_array_t`. The outer
/// allocation is released by `rust_array_destroy`, which is installed as the
/// `destroy` callback.
impl From<Box<dyn Array>> for mts_array_t {
    fn from(array: Box<dyn Array>) -> Self {
        let thin: *mut Box<dyn Array> = Box::into_raw(Box::new(array));

        mts_array_t {
            ptr: thin.cast(),
            origin: Some(rust_array_origin),
            data: Some(rust_array_data),
            shape: Some(rust_array_shape),
            reshape: Some(rust_array_reshape),
            swap_axes: Some(rust_array_swap_axes),
            create: Some(rust_array_create),
            copy: Some(rust_array_copy),
            destroy: Some(rust_array_destroy),
            move_samples_from: Some(rust_array_move_samples_from),
        }
    }
}
macro_rules! check_pointers {
($pointer: ident) => {
if $pointer.is_null() {
panic!(
"got invalid NULL pointer for {} at {}:{}",
stringify!($pointer), file!(), line!()
);
}
};
($($pointer: ident),* $(,)?) => {
$(check_pointers!($pointer);)*
}
}
/// Data origin used for every array created from a `Box<dyn Array>`, reported
/// by `rust_array_origin`.
///
/// The origin is registered under the name `"rust.Box<dyn Array>"` the first
/// time this static is accessed. A registration failure panics.
pub(super) static RUST_DATA_ORIGIN: Lazy<mts_data_origin_t> = Lazy::new(|| {
    let origin_name = "rust.Box<dyn Array>".into();
    super::origin::register_data_origin(origin_name)
        .expect("failed to register a new origin")
});
/// `mts_array_t::origin` callback for Rust arrays.
///
/// Writes the origin registered for `Box<dyn Array>` (`RUST_DATA_ORIGIN`) into
/// `*origin`. Both pointers must be non-NULL; otherwise `check_pointers!`
/// panics. That panic is presumably turned into an error status by
/// `crate::errors::catch_unwind`.
unsafe extern fn rust_array_origin(
    array: *const c_void,
    origin: *mut mts_data_origin_t
) -> mts_status_t {
    crate::errors::catch_unwind(|| {
        check_pointers!(array, origin);
        let registered: mts_data_origin_t = *RUST_DATA_ORIGIN;
        *origin = registered;
    })
}
/// `mts_array_t::shape` callback for Rust arrays.
///
/// Stores a pointer to the array's shape in `*shape` and the number of
/// dimensions in `*shape_count`. The pointer refers to memory borrowed from
/// the array (`Array::shape` returns `&[usize]` tied to the array). It should
/// therefore be treated as valid only until the array is next modified.
unsafe extern fn rust_array_shape(
    array: *const c_void,
    shape: *mut *const usize,
    shape_count: *mut usize,
) -> mts_status_t {
    crate::errors::catch_unwind(|| {
        check_pointers!(array, shape, shape_count);

        let dimensions = (*array.cast::<Box<dyn Array>>()).shape();
        *shape_count = dimensions.len();
        *shape = dimensions.as_ptr();
    })
}
/// `mts_array_t::reshape` callback for Rust arrays.
///
/// Reshapes the array to the `shape_count` dimensions stored at `shape`.
/// `shape_count` must be non-zero and all pointers must be non-NULL.
/// Violations panic, and the panic is presumably converted into an error
/// status by `crate::errors::catch_unwind`.
unsafe extern fn rust_array_reshape(
    array: *mut c_void,
    shape: *const usize,
    shape_count: usize,
) -> mts_status_t {
    crate::errors::catch_unwind(|| {
        assert!(shape_count > 0);
        // use the same NULL check as the other callbacks, so the error message
        // names the offending pointer and the source location
        check_pointers!(array, shape);
        let array = array.cast::<Box<dyn Array>>();

        let shape = std::slice::from_raw_parts(shape, shape_count);
        (*array).reshape(shape);
    })
}
/// `mts_array_t::swap_axes` callback for Rust arrays.
///
/// Swaps the axes `axis_1` and `axis_2` of the array behind `array`, which
/// must be non-NULL.
unsafe extern fn rust_array_swap_axes(
    array: *mut c_void,
    axis_1: usize,
    axis_2: usize,
) -> mts_status_t {
    crate::errors::catch_unwind(|| {
        check_pointers!(array);
        let boxed = &mut *array.cast::<Box<dyn Array>>();
        boxed.swap_axes(axis_1, axis_2);
    })
}
/// `mts_array_t::create` callback for Rust arrays.
///
/// Creates a new array with the `shape_count` dimensions stored at `shape`,
/// using `Array::create` on the array behind `array`. The result is stored in
/// `*array_storage`. `shape_count` must be non-zero and all pointers must be
/// non-NULL.
unsafe extern fn rust_array_create(
    array: *const c_void,
    shape: *const usize,
    shape_count: usize,
    array_storage: *mut mts_array_t,
) -> mts_status_t {
    crate::errors::catch_unwind(|| {
        assert!(shape_count > 0);
        // `check_pointers!` covers `shape` as well, and reports the variable
        // name and location; a separate NULL assertion would be redundant
        check_pointers!(array, shape, array_storage);
        let array = array.cast::<Box<dyn Array>>();

        let shape = std::slice::from_raw_parts(shape, shape_count);
        let new_array = (*array).create(shape);

        // NOTE(review): plain assignment drops the previous `*array_storage`
        // value first. If `mts_array_t` implements `Drop`, callers must pass an
        // initialized (e.g. null) array here — confirm against the C API
        // contract.
        *array_storage = new_array.into();
    })
}
/// `mts_array_t::data` callback for Rust arrays.
///
/// Stores a pointer to the first element of the slice returned by
/// `Array::data` into `*data`. Both pointers must be non-NULL.
/// Implementations of `Array::data` may panic (e.g. `EmptyArray`); such panics
/// are presumably reported through the returned status by
/// `crate::errors::catch_unwind`.
unsafe extern fn rust_array_data(
    array: *mut c_void,
    data: *mut *mut f64,
) -> mts_status_t {
    crate::errors::catch_unwind(|| {
        check_pointers!(array, data);
        let storage: &mut [f64] = (*array.cast::<Box<dyn Array>>()).data();
        *data = storage.as_mut_ptr();
    })
}
unsafe extern fn rust_array_copy(
array: *const c_void,
array_storage: *mut mts_array_t,
) -> mts_status_t {
crate::errors::catch_unwind(|| {
check_pointers!(array, array_storage);
let array = array.cast::<Box<dyn Array>>();
*array_storage = (*array).copy().into();
})
}
/// `mts_array_t::destroy` callback for Rust arrays.
///
/// Frees the `Box<Box<dyn Array>>` allocated by the `From<Box<dyn Array>>`
/// conversion. A NULL pointer is ignored. Must be called at most once per
/// array.
///
/// NOTE(review): unlike the other callbacks, this one is not wrapped in
/// `catch_unwind`, so a panicking `Drop` of the inner array would unwind
/// across the `extern` boundary.
unsafe extern fn rust_array_destroy(
    array: *mut c_void,
) {
    if array.is_null() {
        return;
    }
    drop(Box::from_raw(array.cast::<Box<dyn Array>>()));
}
/// `mts_array_t::move_samples_from` callback for Rust arrays.
///
/// Forwards to `Array::move_samples_from` on `output`. It passes `input`, the
/// `samples_count` mappings stored at `samples`, and the property range
/// `property_start..property_end`. `samples` may only be NULL when
/// `samples_count` is zero.
///
/// NOTE(review): `output` and `input` must point to different arrays. If they
/// were equal, the `&mut` and `&` references created below would alias.
unsafe extern fn rust_array_move_samples_from(
    output: *mut c_void,
    input: *const c_void,
    samples: *const mts_sample_mapping_t,
    samples_count: usize,
    property_start: usize,
    property_end: usize,
) -> mts_status_t {
    crate::errors::catch_unwind(|| {
        check_pointers!(output, input);
        let destination = &mut *output.cast::<Box<dyn Array>>();
        let source: &dyn Array = &**input.cast::<Box<dyn Array>>();

        // an empty mapping list is allowed to come with a NULL pointer
        let mapping: &[mts_sample_mapping_t] = match samples_count {
            0 => &[],
            count => {
                check_pointers!(samples);
                std::slice::from_raw_parts(samples, count)
            }
        };

        destination.move_samples_from(source, mapping, property_start..property_end);
    })
}
/// [`Array`] implementation for owned, dynamically-dimensioned `ndarray`
/// arrays of `f64`.
impl Array for ndarray::ArrayD<f64> {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    /// Create a new array with the given `shape`, filled with zeros.
    fn create(&self, shape: &[usize]) -> Box<dyn Array> {
        Box::new(ndarray::Array::from_elem(shape, 0.0))
    }

    /// Create a deep copy of this array.
    fn copy(&self) -> Box<dyn Array> {
        Box::new(self.clone())
    }

    /// Get all the elements of the array as one mutable slice, in row-major
    /// order.
    ///
    /// `swap_axes` can leave the array in a non-standard memory layout, for
    /// which `as_slice_mut` returns `None`. Previously this made `data()`
    /// panic after a swap. Such arrays are now first copied into standard
    /// (row-major) layout.
    fn data(&mut self) -> &mut [f64] {
        if !self.is_standard_layout() {
            *self = self.as_standard_layout().into_owned();
        }
        self.as_slice_mut().expect("array is not contiguous")
    }

    fn shape(&self) -> &[usize] {
        // resolves to the inherent `ArrayBase::shape`, which takes precedence
        // over this trait method (no recursion)
        self.shape()
    }

    /// Reshape the array, interpreting the elements in row-major order.
    /// Panics with "invalid shape" if the element count does not match.
    fn reshape(&mut self, shape: &[usize]) {
        // `to_shape` only copies when a view is impossible; `into_owned` then
        // reuses that copy instead of cloning the data a second time
        *self = self.to_shape(shape).expect("invalid shape").into_owned();
    }

    fn swap_axes(&mut self, axis_1: usize, axis_2: usize) {
        // inherent `ArrayBase::swap_axes`, not this trait method
        self.swap_axes(axis_1, axis_2);
    }

    /// For each mapping in `samples`, copy entry `sample.input` along the first
    /// axis of `input` into entry `sample.output` along the first axis of
    /// `self`. Only the `property` range of the last axis of `self` is written.
    /// `input` must also be an `ndarray::ArrayD<f64>`.
    fn move_samples_from(
        &mut self,
        input: &dyn Array,
        samples: &[mts_sample_mapping_t],
        property: Range<usize>,
    ) {
        use ndarray::{Axis, Slice};

        // properties live on the last axis of `self`. `index_axis_mut` below
        // removes the first axis, so in the resulting view they are on axis
        // `ndim - 2`. A checked subtraction gives a clear error instead of
        // wrapping around for arrays with fewer than two dimensions.
        let property_axis = self.shape().len().checked_sub(2)
            .expect("array must have at least two dimensions (samples and properties)");

        let input = input.as_any()
            .downcast_ref::<ndarray::ArrayD<f64>>()
            .expect("input must be a ndarray");

        for sample in samples {
            let value = input.index_axis(Axis(0), sample.input);

            let mut output_location = self.index_axis_mut(Axis(0), sample.output);
            let mut output_location = output_location.slice_axis_mut(
                Axis(property_axis), Slice::from(property.clone())
            );

            output_location.assign(&value);
        }
    }
}
/// An array that only records a shape, without storing any data.
#[derive(Debug, Clone)]
pub struct EmptyArray {
    /// Size of the array along each axis.
    shape: Vec<usize>,
}

impl EmptyArray {
    /// Create a new `EmptyArray` with the given `shape`.
    pub fn new(shape: Vec<usize>) -> EmptyArray {
        Self { shape }
    }
}
/// [`Array`] implementation for [`EmptyArray`].
///
/// Shape operations (`create`, `copy`, `shape`, `reshape`, `swap_axes`) only
/// act on the stored shape. `reshape` does not validate the element count.
/// Accessing data (`data`, `move_samples_from`) panics.
impl Array for EmptyArray {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn create(&self, shape: &[usize]) -> Box<dyn Array> {
        let shape = Vec::from(shape);
        Box::new(EmptyArray { shape })
    }

    fn copy(&self) -> Box<dyn Array> {
        Box::new(self.clone())
    }

    fn data(&mut self) -> &mut [f64] {
        panic!("can not call Array::data() for EmptyArray");
    }

    fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    fn reshape(&mut self, shape: &[usize]) {
        self.shape.clear();
        self.shape.extend_from_slice(shape);
    }

    fn swap_axes(&mut self, axis_1: usize, axis_2: usize) {
        // panics if either axis is out of bounds, like `slice::swap`
        self.shape.swap(axis_1, axis_2);
    }

    fn move_samples_from(&mut self, _: &dyn Array, _: &[mts_sample_mapping_t], _: Range<usize>) {
        panic!("can not call Array::move_samples_from() for EmptyArray");
    }
}