1use std::ptr::NonNull;
2use std::sync::{Arc, RwLock};
3
4use ndarray::ArrayD;
5
6use dlpk::sys::DLDevice;
7
8use crate::c_api::{mts_array_t, mts_data_origin_t, mts_data_movement_t};
9use crate::Error;
10
11use super::external::{check_status_external, MtsArray};
12use super::origin::get_data_origin;
13
/// Shared (read-only) reference to a data array, wrapping a raw
/// `mts_array_t`.
///
/// The lifetime `'a` is carried by a `PhantomData<&'a mts_array_t>` marker,
/// so the borrow checker treats an `ArrayRef<'a>` like a `&'a mts_array_t`.
/// The type is `Copy`, so it can be passed around by value.
#[derive(Debug, Clone, Copy)]
pub struct ArrayRef<'a> {
    // Raw array (data pointer + callbacks) this reference forwards to.
    // `ArrayRef::from_raw` stores it with `destroy` set to `None`.
    array: mts_array_t,
    // Zero-sized marker tying this value to the lifetime `'a`.
    marker: std::marker::PhantomData<&'a mts_array_t>,
}
25
26impl<'a> ArrayRef<'a> {
27 pub unsafe fn from_raw(array: mts_array_t) -> ArrayRef<'a> {
33 ArrayRef {
34 array: mts_array_t {
35 destroy: None,
39 ..array
40 },
41 marker: std::marker::PhantomData,
42 }
43 }
44
45 #[inline]
50 pub fn as_any(&self) -> &dyn std::any::Any {
51 let origin = self.origin().unwrap_or(0);
52 assert_eq!(
53 origin, *super::array::RUST_DATA_ORIGIN,
54 "this array was not created as a rust Array (origin is '{}')",
55 get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
56 );
57
58 let array = self.array.ptr.cast::<super::array::RustArray>();
59 unsafe {
60 return (*array).as_any();
61 }
62 }
63
64 #[inline]
70 pub fn to_any(self) -> &'a dyn std::any::Any {
71 let origin = self.origin().unwrap_or(0);
72 assert_eq!(
73 origin, *super::array::RUST_DATA_ORIGIN,
74 "this array was not created as a rust Array (origin is '{}')",
75 get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
76 );
77
78 let array = self.array.ptr.cast::<super::array::RustArray>();
79 unsafe {
80 return (*array).as_any();
81 }
82 }
83
84 #[inline]
90 pub fn as_ndarray_lock<T>(&self) -> &Arc<RwLock<ArrayD<T>>> where T: 'static {
91 self.as_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
92 }
93
94 #[inline]
100 pub fn to_ndarray_lock<T>(self) -> &'a Arc<RwLock<ArrayD<T>>> where T: 'static {
101 self.to_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
102 }
103
104 pub fn as_raw(&self) -> &mts_array_t {
106 &self.array
107 }
108
109 pub fn origin(&self) -> Result<mts_data_origin_t, Error> {
113 let function = self.array.origin.expect("mts_array_t.origin function is NULL");
114
115 let mut origin = 0;
116 unsafe {
117 check_status_external(
118 function(self.array.ptr, &mut origin),
119 "mts_array_t.origin",
120 )?;
121 }
122
123 return Ok(origin);
124 }
125
126 pub fn device(&self) -> Result<DLDevice, Error> {
130 let function = self.array.device.expect("mts_array_t.device function is NULL");
131
132 let mut device = DLDevice::cpu();
133 unsafe {
134 check_status_external(
135 function(self.array.ptr, &mut device),
136 "mts_array_t.device",
137 )?;
138 }
139
140 return Ok(device);
141 }
142
143 pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
147 let function = self.array.dtype.expect("mts_array_t.dtype function is NULL");
148
149 let mut dtype = dlpk::sys::DLDataType { code: dlpk::sys::DLDataTypeCode::kDLFloat, bits: 0, lanes: 0 };
150 unsafe {
151 check_status_external(
152 function(self.array.ptr, &mut dtype),
153 "mts_array_t.dtype",
154 )?;
155 }
156
157 return Ok(dtype);
158 }
159
160 pub fn as_dlpack(
164 &self,
165 device: DLDevice,
166 stream: Option<i64>,
167 max_version: dlpk::sys::DLPackVersion,
168 ) -> Result<dlpk::DLPackTensor, Error> {
169 let function = self.array.as_dlpack.expect("mts_array_t.as_dlpack function is NULL");
170
171 let mut tensor = std::ptr::null_mut();
172 let stream_c = stream.as_ref().map_or(std::ptr::null(), |s| s as *const i64);
173
174 unsafe {
175 check_status_external(
176 function(self.array.ptr, &mut tensor, device, stream_c, max_version),
177 "mts_array_t.as_dlpack",
178 )?;
179 }
180
181 let tensor = NonNull::new(tensor).expect("got a NULL DLManagedTensorVersioned from `as_dlpack`");
182 let tensor = unsafe {
183 dlpk::DLPackTensor::from_ptr(tensor)
184 };
185
186 return Ok(tensor);
187 }
188
189 pub fn shape(&self) -> Result<&[usize], Error> {
193 let function = self.array.shape.expect("mts_array_t.shape function is NULL");
194
195 let mut shape = std::ptr::null();
196 let mut shape_count: usize = 0;
197
198 unsafe {
199 check_status_external(
200 function(self.array.ptr, &mut shape, &mut shape_count),
201 "mts_array_t.shape"
202 )?;
203 }
204
205 if shape_count == 0 {
206 return Ok(&[]);
207 } else {
208 assert!(!shape.is_null());
209 let shape = unsafe {
210 std::slice::from_raw_parts(shape, shape_count)
211 };
212 return Ok(shape);
213 }
214 }
215
216 pub fn create(&self, shape: &[usize], fill_value: ArrayRef<'_>) -> Result<MtsArray, Error> {
221 let function = self.array.create.expect("mts_array_t.create function is NULL");
222
223 let mut new_array = mts_array_t::null();
224 unsafe {
225 check_status_external(
226 function(
227 self.array.ptr,
228 shape.as_ptr(),
229 shape.len(),
230 *fill_value.as_raw(),
231 &mut new_array
232 ),
233 "mts_array_t.create",
234 )?;
235 }
236
237 return Ok(MtsArray::from_raw(new_array));
238 }
239
240 pub fn copy(&self) -> Result<MtsArray, Error> {
244 let function = self.array.copy.expect("mts_array_t.copy function is NULL");
245 let mut new_array = mts_array_t::null();
246 unsafe {
247 check_status_external(
248 function(self.array.ptr, &mut new_array),
249 "mts_array_t.copy",
250 )?;
251 }
252
253 return Ok(MtsArray::from_raw(new_array));
254 }
255}
256
/// Exclusive (mutable) reference to a data array, wrapping a raw
/// `mts_array_t`.
///
/// The lifetime `'a` is carried by a `PhantomData<&'a mut mts_array_t>`
/// marker, so the borrow checker treats an `ArrayRefMut<'a>` like a
/// `&'a mut mts_array_t`. Unlike `ArrayRef`, this type is neither `Clone`
/// nor `Copy`.
#[derive(Debug)]
pub struct ArrayRefMut<'a> {
    // Raw array (data pointer + callbacks) this reference forwards to.
    // `ArrayRefMut::from_raw` stores it with `destroy` set to `None`.
    array: mts_array_t,
    // Zero-sized marker tying this value to the lifetime `'a`.
    marker: std::marker::PhantomData<&'a mut mts_array_t>,
}
268
269impl<'a> ArrayRefMut<'a> {
270 #[inline]
278 pub unsafe fn from_raw(array: mts_array_t) -> ArrayRefMut<'a> {
279 ArrayRefMut {
280 array: mts_array_t {
281 destroy: None,
285 ..array
286 },
287 marker: std::marker::PhantomData,
288 }
289 }
290
291 #[inline]
296 pub fn as_any(&self) -> &dyn std::any::Any {
297 let origin = self.origin().unwrap_or(0);
298 assert_eq!(
299 origin, *super::array::RUST_DATA_ORIGIN,
300 "this array was not created as a rust Array (origin is '{}')",
301 get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
302 );
303
304 let array = self.array.ptr.cast::<super::array::RustArray>();
305 unsafe {
306 return (*array).as_any();
307 }
308 }
309
310 #[inline]
316 pub fn to_any(self) -> &'a dyn std::any::Any {
317 let origin = self.origin().unwrap_or(0);
318 assert_eq!(
319 origin, *super::array::RUST_DATA_ORIGIN,
320 "this array was not created as a rust Array (origin is '{}')",
321 get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
322 );
323
324 let array = self.array.ptr.cast::<super::array::RustArray>();
325 unsafe {
326 return (*array).as_any();
327 }
328 }
329
330 #[inline]
335 pub fn to_any_mut(self) -> &'a mut dyn std::any::Any {
336 let origin = self.origin().unwrap_or(0);
337 assert_eq!(
338 origin, *super::array::RUST_DATA_ORIGIN,
339 "this array was not created as a rust Array (origin is '{}')",
340 get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
341 );
342
343 let array = self.array.ptr.cast::<super::array::RustArray>();
344 unsafe {
345 return (*array).as_any_mut();
346 }
347 }
348
349 #[inline]
355 pub fn as_ndarray_lock<T>(&self) -> &Arc<RwLock<ArrayD<T>>> where T: 'static {
356 self.as_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
357 }
358
359 #[inline]
365 pub fn to_ndarray_lock<T>(self) -> &'a Arc<RwLock<ArrayD<T>>> where T: 'static {
366 self.to_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
367 }
368
369 #[inline]
380 pub fn get_ndarray_mut<T>(self) -> &'a mut ArrayD<T> where T: 'static {
381 let arc = self.to_any_mut().downcast_mut::<Arc<RwLock<ArrayD<T>>>>().expect("this is not an Arc<RwLock<ArrayD>>");
382 let lock = Arc::get_mut(arc).expect("the outer Arc already has multiple owners");
383 return lock.get_mut().expect("lock was poisoned");
384 }
385
386 pub fn as_raw(&self) -> &mts_array_t {
388 &self.array
389 }
390
391 pub fn as_raw_mut(&mut self) -> &mut mts_array_t {
393 &mut self.array
394 }
395
396 pub fn origin(&self) -> Result<mts_data_origin_t, Error> {
400 let function = self.array.origin.expect("mts_array_t.origin function is NULL");
401
402 let mut origin = 0;
403 unsafe {
404 check_status_external(
405 function(self.array.ptr, &mut origin),
406 "mts_array_t.origin",
407 )?;
408 }
409
410 return Ok(origin);
411 }
412
413 pub fn device(&self) -> Result<DLDevice, Error> {
417 let function = self.array.device.expect("mts_array_t.device function is NULL");
418
419 let mut device = DLDevice::cpu();
420 unsafe {
421 check_status_external(
422 function(self.array.ptr, &mut device),
423 "mts_array_t.device",
424 )?;
425 }
426
427 return Ok(device);
428 }
429
430 pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
434 let function = self.array.dtype.expect("mts_array_t.dtype function is NULL");
435
436 let mut dtype = dlpk::sys::DLDataType { code: dlpk::sys::DLDataTypeCode::kDLFloat, bits: 0, lanes: 0 };
437 unsafe {
438 check_status_external(
439 function(self.array.ptr, &mut dtype),
440 "mts_array_t.dtype",
441 )?;
442 }
443
444 return Ok(dtype);
445 }
446
447 pub fn as_dlpack(
451 &self,
452 device: DLDevice,
453 stream: Option<i64>,
454 max_version: dlpk::sys::DLPackVersion,
455 ) -> Result<dlpk::DLPackTensor, Error> {
456 let function = self.array.as_dlpack.expect("mts_array_t.as_dlpack function is NULL");
457
458 let mut tensor = std::ptr::null_mut();
459 let stream_c = stream.as_ref().map_or(std::ptr::null(), |s| s as *const i64);
460
461 unsafe {
462 check_status_external(
463 function(self.array.ptr, &mut tensor, device, stream_c, max_version),
464 "mts_array_t.as_dlpack",
465 )?;
466 }
467
468 let tensor = NonNull::new(tensor).expect("got a NULL DLManagedTensorVersioned from `as_dlpack`");
469 let tensor = unsafe {
470 dlpk::DLPackTensor::from_ptr(tensor)
471 };
472
473 return Ok(tensor);
474 }
475
476 pub fn shape(&self) -> Result<&[usize], Error> {
480 let function = self.array.shape.expect("mts_array_t.shape function is NULL");
481
482 let mut shape = std::ptr::null();
483 let mut shape_count: usize = 0;
484
485 unsafe {
486 check_status_external(
487 function(self.array.ptr, &mut shape, &mut shape_count),
488 "mts_array_t.shape"
489 )?;
490 }
491
492 if shape_count == 0 {
493 return Ok(&[]);
494 } else {
495 assert!(!shape.is_null());
496 let shape = unsafe {
497 std::slice::from_raw_parts(shape, shape_count)
498 };
499 return Ok(shape);
500 }
501 }
502
503 pub fn reshape(&mut self, shape: &[usize]) -> Result<(), Error> {
507 let function = self.array.reshape.expect("mts_array_t.reshape function is NULL");
508
509 unsafe {
510 check_status_external(
511 function(self.array.ptr, shape.as_ptr(), shape.len()),
512 "mts_array_t.reshape",
513 )?;
514 }
515
516 return Ok(());
517 }
518
519 pub fn swap_axes(&mut self, axis_1: usize, axis_2: usize) -> Result<(), Error> {
523 let function = self.array.swap_axes.expect("mts_array_t.swap_axes function is NULL");
524
525 unsafe {
526 check_status_external(
527 function(self.array.ptr, axis_1, axis_2),
528 "mts_array_t.swap_axes",
529 )?;
530 }
531
532 return Ok(());
533 }
534
535 pub fn create(&self, shape: &[usize], fill_value: ArrayRef<'_>) -> Result<MtsArray, Error> {
540 let function = self.array.create.expect("mts_array_t.create function is NULL");
541
542 let mut new_array = mts_array_t::null();
543 unsafe {
544 check_status_external(
545 function(
546 self.array.ptr,
547 shape.as_ptr(),
548 shape.len(),
549 *fill_value.as_raw(),
550 &mut new_array
551 ),
552 "mts_array_t.create",
553 )?;
554 }
555
556 return Ok(MtsArray::from_raw(new_array));
557 }
558
559 pub fn copy(&self) -> Result<MtsArray, Error> {
563 let function = self.array.copy.expect("mts_array_t.copy function is NULL");
564 let mut new_array = mts_array_t::null();
565 unsafe {
566 check_status_external(
567 function(self.array.ptr, &mut new_array),
568 "mts_array_t.copy",
569 )?;
570 }
571
572 return Ok(MtsArray::from_raw(new_array));
573 }
574
575 pub fn move_data<'input>(
579 &mut self,
580 input: impl Into<ArrayRef<'input>>,
581 moves: &[mts_data_movement_t],
582 ) -> Result<(), Error> {
583 let function = self.array.move_data.expect("mts_array_t.move_data function is NULL");
584
585 let input = input.into();
586 unsafe {
587 check_status_external(
588 function(self.array.ptr, input.as_raw().ptr, moves.as_ptr(), moves.len()),
589 "mts_array_t.move_data",
590 )?;
591 }
592
593 return Ok(());
594 }
595}