1use crate::c_api::mts_block_t;
2use crate::errors::check_status;
3use crate::{ArrayRef, Error, Labels, MtsArray};
4
5use super::{TensorBlockRef, TensorBlockRefMut};
6
7#[derive(Debug)]
10#[repr(transparent)]
11pub struct TensorBlock {
12 ptr: *mut mts_block_t,
13}
14
15unsafe impl Send for TensorBlock {}
17unsafe impl Sync for TensorBlock {}
19
20impl std::ops::Drop for TensorBlock {
21 #[allow(unused_must_use)]
22 fn drop(&mut self) {
23 unsafe {
24 crate::c_api::mts_block_free(self.as_mut_ptr());
25 }
26 }
27}
28
29impl TensorBlock {
30 pub unsafe fn from_raw(ptr: *mut mts_block_t) -> TensorBlock {
40 assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
41
42 TensorBlock {
43 ptr: ptr,
44 }
45 }
46
47 pub fn into_raw(mut block: TensorBlock) -> *mut mts_block_t {
53 return std::mem::replace(&mut block.ptr, std::ptr::null_mut());
54 }
55
56 pub fn as_ptr(&self) -> *const mts_block_t {
61 self.ptr
62 }
63
64 pub fn as_mut_ptr(&mut self) -> *mut mts_block_t {
69 self.ptr
70 }
71
72 #[inline]
74 pub fn as_ref(&self) -> TensorBlockRef<'_> {
75 unsafe {
76 TensorBlockRef::from_raw(self.as_ptr())
77 }
78 }
79
80 #[inline]
82 pub fn as_ref_mut(&mut self) -> TensorBlockRefMut<'_> {
83 unsafe {
84 TensorBlockRefMut::from_raw(self.as_mut_ptr())
85 }
86 }
87
88 #[inline]
90 pub fn device(&self) -> Result<dlpk::sys::DLDevice, Error> {
91 self.as_ref().device()
92 }
93
94 #[inline]
96 pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
97 self.as_ref().dtype()
98 }
99
100 #[inline]
102 pub fn values(&self) -> ArrayRef<'_> {
103 return self.as_ref().values();
104 }
105
106 #[inline]
108 pub fn samples(&self) -> Labels {
109 return self.as_ref().samples();
110 }
111
112 #[inline]
114 pub fn components(&self) -> Vec<Labels> {
115 return self.as_ref().components();
116 }
117
118 #[inline]
120 pub fn properties(&self) -> Labels {
121 return self.as_ref().properties();
122 }
123
124 #[inline]
128 pub fn new(
129 values: impl Into<MtsArray>,
130 samples: &Labels,
131 components: &[Labels],
132 properties: &Labels
133 ) -> Result<TensorBlock, Error> {
134 let mut c_components = Vec::new();
135 for component in components {
136 c_components.push(component.as_mts_labels_t());
137 }
138
139 let ptr = unsafe {
140 crate::c_api::mts_block(
141 values.into().into_raw(),
142 samples.as_mts_labels_t(),
143 c_components.as_ptr(),
144 c_components.len(),
145 properties.as_mts_labels_t(),
146 )
147 };
148
149 crate::errors::check_ptr(ptr)?;
150
151 return Ok(unsafe { TensorBlock::from_raw(ptr) });
152 }
153
154 #[allow(clippy::needless_pass_by_value)]
160 #[inline]
161 pub fn add_gradient(
162 &mut self,
163 parameter: &str,
164 mut gradient: TensorBlock
165 ) -> Result<(), Error> {
166 let mut parameter = parameter.to_owned().into_bytes();
167 parameter.push(b'\0');
168
169
170 let gradient_ptr = gradient.as_ref_mut().as_mut_ptr();
171 std::mem::forget(gradient);
174
175 unsafe {
176 check_status(crate::c_api::mts_block_add_gradient(
177 self.as_ref_mut().as_mut_ptr(),
178 parameter.as_ptr().cast(),
179 gradient_ptr,
180 ))?;
181 }
182
183 return Ok(());
184 }
185
186 pub fn load(path: impl AsRef<std::path::Path>) -> Result<TensorBlock, Error> {
190 return crate::io::load_block(path);
191 }
192
193 pub fn load_buffer(buffer: &[u8]) -> Result<TensorBlock, Error> {
197 return crate::io::load_block_buffer(buffer);
198 }
199
200 pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
204 self.as_ref().save(path)
205 }
206
207 pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
211 self.as_ref().save_buffer(buffer)
212 }
213}
214
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::c_api::mts_block_t;

    #[test]
    fn block() {
        // Build the labels once and reuse them for both construction and the
        // round-trip checks below.
        let samples = Labels::new(["samples"], [[0], [1]]);
        let components = [Labels::new(["component"], [[0]])];
        let properties = Labels::new(["properties"], [[-2], [0], [1]]);

        let block = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 1, 3], 1.0),
            &samples,
            &components,
            &properties,
        ).unwrap();

        let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
        let expected_values = ndarray::Array::from_elem(vec![2, 1, 3], 1.0);
        assert_eq!(*values, expected_values);

        assert_eq!(block.samples(), samples);
        assert_eq!(block.components(), components);
        assert_eq!(block.properties(), properties);
    }

    #[test]
    fn check_repr() {
        use std::mem::{align_of, size_of};

        // `repr(transparent)` must make the wrapper layout-identical to the
        // raw pointer it holds.
        assert_eq!(size_of::<TensorBlock>(), size_of::<*mut mts_block_t>());
        assert_eq!(align_of::<TensorBlock>(), align_of::<*mut mts_block_t>());
    }

    #[test]
    fn device_and_dtype() {
        let samples = Labels::new(["samples"], [[0], [1]]);
        let properties = Labels::new(["properties"], [[-2], [0], [1]]);

        let block = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], 1.0),
            &samples,
            &[],
            &properties,
        ).unwrap();

        // the values are a CPU ndarray of f64
        let device_type = block.device().unwrap().device_type;
        assert_eq!(device_type, dlpk::sys::DLDeviceType::kDLCPU);

        let data_type = block.dtype().unwrap();
        assert_eq!(data_type.code, dlpk::sys::DLDataTypeCode::kDLFloat);
        assert_eq!(data_type.bits, 64);
    }

    #[test]
    fn tensor_block_into_raw() {
        let samples = Labels::new(["samples"], [[0], [1], [4]]);
        let properties = Labels::new(["properties"], [[5], [3]]);

        let original = TensorBlock::new(
            ndarray::Array::from_elem(vec![3, 2], 1.0),
            &samples,
            &[],
            &properties,
        ).unwrap();

        let raw_block = TensorBlock::into_raw(original);

        // SAFETY: `raw_block` comes straight from `into_raw`, so this test
        // owns it exclusively and it is wrapped exactly once.
        let recovered = unsafe { TensorBlock::from_raw(raw_block) };

        assert_eq!(recovered.samples(), samples);
        assert_eq!(recovered.properties(), properties);
    }
}