1use std::sync::{Arc, RwLock};
2
3use ndarray::ArrayD;
4
5use dlpk::sys::DLDevice;
6
7use crate::c_api::{mts_array_t, mts_data_origin_t, mts_data_movement_t};
8use crate::Error;
9use crate::errors::check_status;
10
11use super::external::MtsArray;
12use super::origin::get_data_origin;
13
14#[derive(Debug, Clone, Copy)]
20pub struct ArrayRef<'a> {
21 array: mts_array_t,
22 marker: std::marker::PhantomData<&'a mts_array_t>,
24}
25
26impl<'a> ArrayRef<'a> {
27 pub unsafe fn from_raw(array: mts_array_t) -> ArrayRef<'a> {
33 ArrayRef {
34 array: mts_array_t {
35 destroy: None,
39 ..array
40 },
41 marker: std::marker::PhantomData,
42 }
43 }
44
45 #[inline]
50 pub fn as_any(&self) -> &dyn std::any::Any {
51 let origin = self.origin().unwrap_or(0);
52 assert_eq!(
53 origin, *super::array::RUST_DATA_ORIGIN,
54 "this array was not created as a rust Array (origin is '{}')",
55 get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
56 );
57
58 let array = self.array.ptr.cast::<super::array::RustArray>();
59 unsafe {
60 return (*array).as_any();
61 }
62 }
63
64 #[inline]
70 pub fn to_any(self) -> &'a dyn std::any::Any {
71 let origin = self.origin().unwrap_or(0);
72 assert_eq!(
73 origin, *super::array::RUST_DATA_ORIGIN,
74 "this array was not created as a rust Array (origin is '{}')",
75 get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
76 );
77
78 let array = self.array.ptr.cast::<super::array::RustArray>();
79 unsafe {
80 return (*array).as_any();
81 }
82 }
83
84 #[inline]
90 pub fn as_ndarray_lock<T>(&self) -> &Arc<RwLock<ArrayD<T>>> where T: 'static {
91 self.as_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
92 }
93
94 #[inline]
100 pub fn to_ndarray_lock<T>(self) -> &'a Arc<RwLock<ArrayD<T>>> where T: 'static {
101 self.to_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
102 }
103
104 pub fn as_raw(&self) -> &mts_array_t {
106 &self.array
107 }
108
109 pub fn origin(&self) -> Result<mts_data_origin_t, Error> {
113 let function = self.array.origin.expect("mts_array_t.origin function is NULL");
114
115 let mut origin = 0;
116 unsafe {
117 check_status(function(self.array.ptr, &mut origin))?;
118 }
119
120 return Ok(origin);
121 }
122
123 pub fn device(&self) -> Result<DLDevice, Error> {
127 let function = self.array.device.expect("mts_array_t.device function is NULL");
128
129 let mut device = DLDevice::cpu();
130 unsafe {
131 check_status(function(self.array.ptr, &mut device))?;
132 }
133
134 return Ok(device);
135 }
136
137 pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
141 let function = self.array.dtype.expect("mts_array_t.dtype function is NULL");
142
143 let mut dtype = dlpk::sys::DLDataType { code: dlpk::sys::DLDataTypeCode::kDLFloat, bits: 0, lanes: 0 };
144 unsafe {
145 check_status(function(self.array.ptr, &mut dtype))?;
146 }
147
148 return Ok(dtype);
149 }
150
151 pub fn as_dlpack(
155 &self,
156 device: DLDevice,
157 stream: Option<i64>,
158 max_version: dlpk::sys::DLPackVersion,
159 ) -> Result<dlpk::DLPackTensor, Error> {
160 let function = self.array.as_dlpack.expect("mts_array_t.as_dlpack function is NULL");
161
162 let mut tensor = std::ptr::null_mut();
163 let stream_c = stream.as_ref().map_or(std::ptr::null(), |s| s as *const i64);
164
165 unsafe {
166 check_status(function(self.array.ptr, &mut tensor, device, stream_c, max_version))?;
167 }
168
169 let tensor = unsafe {
170 dlpk::DLPackTensor::from_ptr(tensor)
171 };
172
173 return Ok(tensor);
174 }
175
176 pub fn shape(&self) -> Result<&[usize], Error> {
180 let function = self.array.shape.expect("mts_array_t.shape function is NULL");
181
182 let mut shape = std::ptr::null();
183 let mut shape_count: usize = 0;
184
185 unsafe {
186 check_status(function(self.array.ptr, &mut shape, &mut shape_count))?;
187 }
188
189 if shape_count == 0 {
190 return Ok(&[]);
191 } else {
192 assert!(!shape.is_null());
193 let shape = unsafe {
194 std::slice::from_raw_parts(shape, shape_count)
195 };
196 return Ok(shape);
197 }
198 }
199
200 pub fn create(&self, shape: &[usize], fill_value: ArrayRef<'_>) -> Result<MtsArray, Error> {
205 let function = self.array.create.expect("mts_array_t.create function is NULL");
206
207 let mut new_array = mts_array_t::null();
208 unsafe {
209 check_status(function(
210 self.array.ptr,
211 shape.as_ptr(),
212 shape.len(),
213 *fill_value.as_raw(),
214 &mut new_array
215 ))?;
216 }
217
218 return Ok(MtsArray::from_raw(new_array));
219 }
220
221 pub fn copy(&self, device: DLDevice) -> Result<MtsArray, Error> {
225 let function = self.array.copy.expect("mts_array_t.copy function is NULL");
226 let mut new_array = mts_array_t::null();
227 unsafe {
228 check_status(function(self.array.ptr, device, &mut new_array))?;
229 }
230
231 return Ok(MtsArray::from_raw(new_array));
232 }
233}
234
235#[derive(Debug)]
241pub struct ArrayRefMut<'a> {
242 array: mts_array_t,
243 marker: std::marker::PhantomData<&'a mut mts_array_t>,
245}
246
247impl<'a> ArrayRefMut<'a> {
248 #[inline]
256 pub unsafe fn from_raw(array: mts_array_t) -> ArrayRefMut<'a> {
257 ArrayRefMut {
258 array: mts_array_t {
259 destroy: None,
263 ..array
264 },
265 marker: std::marker::PhantomData,
266 }
267 }
268
269 #[inline]
274 pub fn as_any(&self) -> &dyn std::any::Any {
275 let origin = self.origin().unwrap_or(0);
276 assert_eq!(
277 origin, *super::array::RUST_DATA_ORIGIN,
278 "this array was not created as a rust Array (origin is '{}')",
279 get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
280 );
281
282 let array = self.array.ptr.cast::<super::array::RustArray>();
283 unsafe {
284 return (*array).as_any();
285 }
286 }
287
288 #[inline]
294 pub fn to_any(self) -> &'a dyn std::any::Any {
295 let origin = self.origin().unwrap_or(0);
296 assert_eq!(
297 origin, *super::array::RUST_DATA_ORIGIN,
298 "this array was not created as a rust Array (origin is '{}')",
299 get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
300 );
301
302 let array = self.array.ptr.cast::<super::array::RustArray>();
303 unsafe {
304 return (*array).as_any();
305 }
306 }
307
308 #[inline]
313 pub fn to_any_mut(self) -> &'a mut dyn std::any::Any {
314 let origin = self.origin().unwrap_or(0);
315 assert_eq!(
316 origin, *super::array::RUST_DATA_ORIGIN,
317 "this array was not created as a rust Array (origin is '{}')",
318 get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
319 );
320
321 let array = self.array.ptr.cast::<super::array::RustArray>();
322 unsafe {
323 return (*array).as_any_mut();
324 }
325 }
326
327 #[inline]
333 pub fn as_ndarray_lock<T>(&self) -> &Arc<RwLock<ArrayD<T>>> where T: 'static {
334 self.as_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
335 }
336
337 #[inline]
343 pub fn to_ndarray_lock<T>(self) -> &'a Arc<RwLock<ArrayD<T>>> where T: 'static {
344 self.to_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
345 }
346
347 #[inline]
358 pub fn get_ndarray_mut<T>(self) -> &'a mut ArrayD<T> where T: 'static {
359 let arc = self.to_any_mut().downcast_mut::<Arc<RwLock<ArrayD<T>>>>().expect("this is not an Arc<RwLock<ArrayD>>");
360 let lock = Arc::get_mut(arc).expect("the outer Arc already has multiple owners");
361 return lock.get_mut().expect("lock was poisoned");
362 }
363
364 pub fn as_raw(&self) -> &mts_array_t {
366 &self.array
367 }
368
369 pub fn as_raw_mut(&mut self) -> &mut mts_array_t {
371 &mut self.array
372 }
373
374 pub fn origin(&self) -> Result<mts_data_origin_t, Error> {
378 let function = self.array.origin.expect("mts_array_t.origin function is NULL");
379
380 let mut origin = 0;
381 unsafe {
382 check_status(function(self.array.ptr, &mut origin))?;
383 }
384
385 return Ok(origin);
386 }
387
388 pub fn device(&self) -> Result<DLDevice, Error> {
392 let function = self.array.device.expect("mts_array_t.device function is NULL");
393
394 let mut device = DLDevice::cpu();
395 unsafe {
396 check_status(function(self.array.ptr, &mut device))?;
397 }
398
399 return Ok(device);
400 }
401
402 pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
406 let function = self.array.dtype.expect("mts_array_t.dtype function is NULL");
407
408 let mut dtype = dlpk::sys::DLDataType { code: dlpk::sys::DLDataTypeCode::kDLFloat, bits: 0, lanes: 0 };
409 unsafe {
410 check_status(function(self.array.ptr, &mut dtype))?;
411 }
412
413 return Ok(dtype);
414 }
415
416 pub fn as_dlpack(
420 &self,
421 device: DLDevice,
422 stream: Option<i64>,
423 max_version: dlpk::sys::DLPackVersion,
424 ) -> Result<dlpk::DLPackTensor, Error> {
425 let function = self.array.as_dlpack.expect("mts_array_t.as_dlpack function is NULL");
426
427 let mut tensor = std::ptr::null_mut();
428 let stream_c = stream.as_ref().map_or(std::ptr::null(), |s| s as *const i64);
429
430 unsafe {
431 check_status(function(
432 self.array.ptr,
433 &mut tensor,
434 device,
435 stream_c,
436 max_version
437 ))?;
438 }
439
440 let tensor = unsafe {
441 dlpk::DLPackTensor::from_ptr(tensor)
442 };
443
444 return Ok(tensor);
445 }
446
447 pub fn shape(&self) -> Result<&[usize], Error> {
451 let function = self.array.shape.expect("mts_array_t.shape function is NULL");
452
453 let mut shape = std::ptr::null();
454 let mut shape_count: usize = 0;
455
456 unsafe {
457 check_status(function(self.array.ptr, &mut shape, &mut shape_count))?;
458 }
459
460 if shape_count == 0 {
461 return Ok(&[]);
462 } else {
463 assert!(!shape.is_null());
464 let shape = unsafe {
465 std::slice::from_raw_parts(shape, shape_count)
466 };
467 return Ok(shape);
468 }
469 }
470
471 pub fn reshape(&mut self, shape: &[usize]) -> Result<(), Error> {
475 let function = self.array.reshape.expect("mts_array_t.reshape function is NULL");
476
477 unsafe {
478 check_status(function(self.array.ptr, shape.as_ptr(), shape.len()))?;
479 }
480
481 return Ok(());
482 }
483
484 pub fn swap_axes(&mut self, axis_1: usize, axis_2: usize) -> Result<(), Error> {
488 let function = self.array.swap_axes.expect("mts_array_t.swap_axes function is NULL");
489
490 unsafe {
491 check_status(function(self.array.ptr, axis_1, axis_2))?;
492 }
493
494 return Ok(());
495 }
496
497 pub fn create(&self, shape: &[usize], fill_value: ArrayRef<'_>) -> Result<MtsArray, Error> {
502 let function = self.array.create.expect("mts_array_t.create function is NULL");
503
504 let mut new_array = mts_array_t::null();
505 unsafe {
506 check_status(function(
507 self.array.ptr,
508 shape.as_ptr(),
509 shape.len(),
510 *fill_value.as_raw(),
511 &mut new_array
512 ))?;
513 }
514
515 return Ok(MtsArray::from_raw(new_array));
516 }
517
518 pub fn copy(&self, device: DLDevice) -> Result<MtsArray, Error> {
522 let function = self.array.copy.expect("mts_array_t.copy function is NULL");
523 let mut new_array = mts_array_t::null();
524 unsafe {
525 check_status(function(self.array.ptr, device, &mut new_array))?;
526 }
527
528 return Ok(MtsArray::from_raw(new_array));
529 }
530
531 pub fn move_data<'input>(
535 &mut self,
536 input: impl Into<ArrayRef<'input>>,
537 moves: &[mts_data_movement_t],
538 ) -> Result<(), Error> {
539 let function = self.array.move_data.expect("mts_array_t.move_data function is NULL");
540
541 let input = input.into();
542 unsafe {
543 check_status(function(
544 self.array.ptr,
545 input.as_raw().ptr,
546 moves.as_ptr(),
547 moves.len(),
548 ))?;
549 }
550
551 return Ok(());
552 }
553}