1use std::ffi::{CStr, CString};
2use std::iter::FusedIterator;
3
4use crate::c_api::{mts_block_t, mts_array_t};
5use crate::c_api::MTS_INVALID_PARAMETER_ERROR;
6
7use crate::errors::check_status;
8use crate::{ArrayRef, Labels, Error};
9
10use super::{TensorBlock, LazyMetadata};
11
/// Borrowed reference to a `mts_block_t`, valid for the lifetime `'a`.
///
/// This type only stores a raw pointer and is `Copy` (so it can not have a
/// `Drop` implementation): it never frees the underlying block.
#[derive(Debug, Clone, Copy)]
pub struct TensorBlockRef<'a> {
    /// Pointer to the block; `from_raw` rejects NULL pointers
    ptr: *const mts_block_t,
    /// Ties `ptr` to the lifetime `'a`, as if it were a `&'a mts_block_t`
    marker: std::marker::PhantomData<&'a mts_block_t>,
}
18
// SAFETY: `TensorBlockRef` only holds a `*const mts_block_t`, and every method
// in this file accesses the block through `&self`. NOTE(review): soundness
// relies on the metatensor C API allowing the same block to be read from
// several threads concurrently — confirm against the C library's guarantees.
unsafe impl Send for TensorBlockRef<'_> {}
unsafe impl Sync for TensorBlockRef<'_> {}
23
/// All the data and metadata of a block, as produced by `TensorBlockRef::data`.
///
/// NOTE(review): the metadata fields are wrapped in `LazyMetadata`, which
/// presumably only builds the `Labels` on first access — confirm against the
/// `LazyMetadata` implementation.
#[derive(Debug)]
pub struct TensorBlockData<'a> {
    /// Values of the block, borrowed for `'a`
    pub values: ArrayRef<'a>,
    /// Labels of the first dimension of `values`
    pub samples: LazyMetadata<Labels>,
    /// Labels of every dimension of `values` between the first and the last
    pub components: LazyMetadata<Vec<Labels>>,
    /// Labels of the last dimension of `values`
    pub properties: LazyMetadata<Labels>,
}
40
41impl<'a> TensorBlockRef<'a> {
42 pub unsafe fn from_raw(ptr: *const mts_block_t) -> TensorBlockRef<'a> {
48 assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
49
50 TensorBlockRef {
51 ptr: ptr,
52 marker: std::marker::PhantomData,
53 }
54 }
55
56 pub fn as_ptr(&self) -> *const mts_block_t {
58 self.ptr
59 }
60}
61
62fn block_gradient(block: *const mts_block_t, parameter: &CStr) -> Option<*const mts_block_t> {
64 let mut gradient_block = std::ptr::null_mut();
65 let status = unsafe { crate::c_api::mts_block_gradient(
66 block.cast_mut(),
69 parameter.as_ptr(),
70 &mut gradient_block
71 )
72 };
73
74 match crate::errors::check_status(status) {
75 Ok(()) => Some(gradient_block.cast_const()),
76 Err(error) => {
77 if error.code == Some(MTS_INVALID_PARAMETER_ERROR) {
78 None
80 } else {
81 panic!("failed to get the gradient from a block: {:?}", error)
82 }
83 }
84 }
85}
86
87pub(super) fn get_samples(ptr: *const mts_block_t) -> Labels {
88 unsafe {
89 TensorBlockRef::from_raw(ptr).samples()
90 }
91}
92
93pub(super) fn get_components(ptr: *const mts_block_t) -> Vec<Labels> {
94 unsafe {
95 TensorBlockRef::from_raw(ptr).components()
96 }
97}
98
99pub(super) fn get_properties(ptr: *const mts_block_t) -> Labels {
100 unsafe {
101 TensorBlockRef::from_raw(ptr).properties()
102 }
103}
104
105impl<'a> TensorBlockRef<'a> {
106 #[inline]
108 pub fn device(&self) -> Result<dlpk::sys::DLDevice, Error> {
109 let mut device = dlpk::sys::DLDevice::cpu();
110 unsafe {
111 check_status(crate::c_api::mts_block_device(
112 self.as_ptr(),
113 &mut device,
114 ))?;
115 }
116 return Ok(device);
117 }
118
119 #[inline]
121 pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
122 let mut dtype = dlpk::sys::DLDataType {
123 code: dlpk::sys::DLDataTypeCode::kDLFloat,
124 bits: 0,
125 lanes: 0,
126 };
127 unsafe {
128 check_status(crate::c_api::mts_block_dtype(
129 self.as_ptr(),
130 &mut dtype,
131 ))?;
132 }
133 return Ok(dtype);
134 }
135
136 #[inline]
139 pub fn data(&'a self) -> TensorBlockData<'a> {
140 TensorBlockData {
141 values: self.values(),
142 samples: LazyMetadata::new(get_samples, self.as_ptr()),
143 components: LazyMetadata::new(get_components, self.as_ptr()),
144 properties: LazyMetadata::new(get_properties, self.as_ptr()),
145 }
146 }
147
148 #[inline]
150 pub fn values(&self) -> ArrayRef<'a> {
151 let mut array = mts_array_t::null();
152 unsafe {
153 crate::errors::check_status(crate::c_api::mts_block_data(
154 self.as_ptr().cast_mut(),
155 &mut array
156 )).expect("failed to get the array for a block");
157 };
158
159 unsafe { ArrayRef::from_raw(array) }
165 }
166
167 #[inline]
168 fn labels(&self, dimension: usize) -> Labels {
169 let ptr = unsafe {
170 crate::c_api::mts_block_labels(self.as_ptr(), dimension)
171 };
172 crate::errors::check_ptr(ptr).expect("failed to get labels");
173
174 unsafe { Labels::from_raw(ptr) }
175 }
176
177 #[inline]
179 pub fn samples(&self) -> Labels {
180 return self.labels(0);
181 }
182
183 #[inline]
185 pub fn components(&self) -> Vec<Labels> {
186 let values = self.values();
187 let shape = values.shape().expect("failed to get the data shape");
188
189 let mut result = Vec::new();
190 for i in 1..(shape.len() - 1) {
191 result.push(self.labels(i));
192 }
193 return result;
194 }
195
196 #[inline]
198 pub fn properties(&self) -> Labels {
199 let values = self.values();
200 let shape = values.shape().expect("failed to get the data shape");
201
202 return self.labels(shape.len() - 1);
203 }
204
205 #[inline]
207 pub fn gradient_list(&self) -> Vec<&'a str> {
208 let mut parameters_ptr = std::ptr::null();
209 let mut parameters_count = 0;
210 unsafe {
211 check_status(crate::c_api::mts_block_gradients_list(
212 self.as_ptr(),
213 &mut parameters_ptr,
214 &mut parameters_count
215 )).expect("failed to get gradient list");
216 }
217
218 if parameters_count == 0 {
219 return Vec::new();
220 } else {
221 assert!(!parameters_ptr.is_null());
222 unsafe {
226 let parameters = std::slice::from_raw_parts(parameters_ptr, parameters_count);
227 return parameters.iter()
228 .map(|&ptr| CStr::from_ptr(ptr).to_str().unwrap())
229 .collect();
230 }
231 }
232 }
233
234 #[inline]
237 pub fn gradient(&self, parameter: &str) -> Option<TensorBlockRef<'a>> {
238 let parameter = CString::new(parameter).expect("invalid C string");
241
242 block_gradient(self.as_ptr(), ¶meter)
243 .map(|gradient_block| {
244 unsafe { TensorBlockRef::from_raw(gradient_block) }
248 })
249 }
250
251 #[inline]
256 pub fn try_clone(&self) -> Result<TensorBlock, Error> {
257 let ptr = unsafe {
258 crate::c_api::mts_block_copy(self.as_ptr())
259 };
260 crate::errors::check_ptr(ptr)?;
261
262 return Ok(unsafe { TensorBlock::from_raw(ptr) });
263 }
264
265 #[inline]
268 pub fn gradients(&self) -> GradientsIter<'_> {
269 GradientsIter {
270 parameters: self.gradient_list().into_iter(),
271 block: self.as_ptr(),
272 }
273 }
274
275 pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
279 return crate::io::save_block(path, *self);
280 }
281
282 pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
286 return crate::io::save_block_buffer(*self, buffer);
287 }
288}
289
290pub struct GradientsIter<'a> {
293 parameters: std::vec::IntoIter<&'a str>,
294 block: *const mts_block_t,
295}
296
297impl<'a> Iterator for GradientsIter<'a> {
298 type Item = (&'a str, TensorBlockRef<'a>);
299
300 #[inline]
301 fn next(&mut self) -> Option<Self::Item> {
302 self.parameters.next().map(|parameter| {
303 let parameter_c = CString::new(parameter).expect("invalid C string");
304 let block = block_gradient(self.block, ¶meter_c).expect("missing gradient");
305
306 let block = unsafe { TensorBlockRef::from_raw(block) };
310 return (parameter, block);
311 })
312 }
313
314 fn size_hint(&self) -> (usize, Option<usize>) {
315 (self.len(), Some(self.len()))
316 }
317}
318
319impl ExactSizeIterator for GradientsIter<'_> {
320 #[inline]
321 fn len(&self) -> usize {
322 self.parameters.len()
323 }
324}
325
326impl FusedIterator for GradientsIter<'_> {}
327
#[cfg(test)]
mod tests {
    use crate::{Labels, TensorBlock};

    /// Check `gradient` and `gradients` on a block with two gradients ("g"
    /// and "f"), including the `None` result for an unknown parameter.
    #[test]
    #[allow(clippy::float_cmp)] // the arrays hold exact constants, so `==` on floats is intended
    fn gradients() {
        let properties = Labels::new(["p"], [[-2], [0], [1]]);
        let mut block = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], 1.0),
            &Labels::new(["s"], [[0], [1]]), &[], &properties,
        ).unwrap();

        block.add_gradient("g", TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], -1.0),
            &Labels::new(["sample"], [[0], [1]]), &[], &properties,
        ).unwrap()).unwrap();

        block.add_gradient("f", TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], -2.0),
            &Labels::new(["sample"], [[0], [1]]), &[], &properties,
        ).unwrap()).unwrap();


        let block = block.as_ref();
        let gradient = block.gradient("g").unwrap();
        let values = gradient.values().to_ndarray_lock::<f64>().read().unwrap();
        assert_eq!(values[[0, 0]], -1.0);

        let gradient = block.gradient("f").unwrap();
        let values = gradient.values().to_ndarray_lock::<f64>().read().unwrap();
        assert_eq!(values[[0, 0]], -2.0);

        // a gradient that was never added is reported as `None`
        assert!(block.gradient("h").is_none());

        let mut iter = block.gradients();
        assert_eq!(iter.len(), 2);

        // the test expects gradients in the order they were added
        assert_eq!(iter.next().unwrap().0, "g");
        assert_eq!(iter.next().unwrap().0, "f");
        assert!(iter.next().is_none());
    }

    /// Check that the individual accessors and `data()` return the same
    /// values and metadata for a 3-dimensional block.
    #[test]
    fn block_data() {
        let block = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 1, 3], 1.0),
            &Labels::new(["samples"], [[0], [1]]),
            &[Labels::new(["component"], [[0]])],
            &Labels::new(["properties"], [[-2], [0], [1]]),
        ).unwrap();
        let block = block.as_ref();

        let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
        assert_eq!(*values, ndarray::Array::from_elem(vec![2, 1, 3], 1.0));
        assert_eq!(block.samples(), Labels::new(["samples"], [[0], [1]]));
        assert_eq!(block.components(), [Labels::new(["component"], [[0]])]);
        assert_eq!(block.properties(), Labels::new(["properties"], [[-2], [0], [1]]));

        // same checks, going through the `TensorBlockData` returned by `data()`
        let block = block.data();
        let values = block.values.to_ndarray_lock::<f64>().read().unwrap();
        assert_eq!(*values, ndarray::Array::from_elem(vec![2, 1, 3], 1.0));
        assert_eq!(*block.samples, Labels::new(["samples"], [[0], [1]]));
        assert_eq!(*block.components, [Labels::new(["component"], [[0]])]);
        assert_eq!(*block.properties, Labels::new(["properties"], [[-2], [0], [1]]));
    }

    /// Check that a block built from an `f64` ndarray reports a CPU device and
    /// a 64-bit float dtype.
    #[test]
    fn device_and_dtype() {
        let block = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], 1.0),
            &Labels::new(["samples"], [[0], [1]]),
            &[],
            &Labels::new(["properties"], [[-2], [0], [1]]),
        ).unwrap();
        let block = block.as_ref();

        let device = block.device().unwrap();
        assert_eq!(device.device_type, dlpk::sys::DLDeviceType::kDLCPU);

        let dtype = block.dtype().unwrap();
        assert_eq!(dtype.code, dlpk::sys::DLDataTypeCode::kDLFloat);
        assert_eq!(dtype.bits, 64);
    }
}