1use std::ffi::{CStr, CString};
2use std::iter::FusedIterator;
3
4use crate::c_api::{mts_block_t, mts_array_t};
5use crate::c_api::MTS_INVALID_PARAMETER_ERROR;
6
7use crate::errors::check_status;
8use crate::{ArrayRef, Labels, Error};
9
10use super::{TensorBlock, LazyMetadata};
11
/// Reference to a `mts_block_t` that is owned somewhere else.
///
/// This only stores a raw pointer. The `PhantomData<&'a mts_block_t>` marker
/// makes the borrow checker treat this type like a shared
/// `&'a mts_block_t`, tying every `TensorBlockRef<'a>` to the lifetime `'a`.
/// The type is `Copy`, so it can be passed around by value.
#[derive(Debug, Clone, Copy)]
pub struct TensorBlockRef<'a> {
    // raw pointer to the block; `from_raw` asserts it is not NULL
    ptr: *const mts_block_t,
    // carries the lifetime `'a` without storing an actual reference
    marker: std::marker::PhantomData<&'a mts_block_t>,
}
18
// Raw pointers are neither `Send` nor `Sync`, so these impls have to be
// written by hand.
// SAFETY: NOTE(review): these impls assume that metatensor-core allows a
// block to be accessed from several threads at the same time through this
// read-oriented reference, and that the block outlives `'a` whichever thread
// holds the reference — confirm against the C API thread-safety guarantees.
unsafe impl Send for TensorBlockRef<'_> {}
unsafe impl Sync for TensorBlockRef<'_> {}
23
/// All the data and metadata of a block, as returned by
/// `TensorBlockRef::data`.
#[derive(Debug)]
pub struct TensorBlockData<'a> {
    /// Array containing the values of the block
    pub values: ArrayRef<'a>,
    /// Labels of the samples (first axis of the values), wrapped in
    /// `LazyMetadata` (presumably only loaded when first accessed)
    pub samples: LazyMetadata<Labels>,
    /// One set of labels for each component axis (every axis between the
    /// samples and the properties)
    pub components: LazyMetadata<Vec<Labels>>,
    /// Labels of the properties (last axis of the values)
    pub properties: LazyMetadata<Labels>,
}
40
41impl<'a> TensorBlockRef<'a> {
42 pub(crate) unsafe fn from_raw(ptr: *const mts_block_t) -> TensorBlockRef<'a> {
48 assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
49
50 TensorBlockRef {
51 ptr: ptr,
52 marker: std::marker::PhantomData,
53 }
54 }
55
56 pub(crate) fn as_ptr(&self) -> *const mts_block_t {
58 self.ptr
59 }
60}
61
62fn block_gradient(block: *const mts_block_t, parameter: &CStr) -> Option<*const mts_block_t> {
64 let mut gradient_block = std::ptr::null_mut();
65 let status = unsafe { crate::c_api::mts_block_gradient(
66 block.cast_mut(),
69 parameter.as_ptr(),
70 &mut gradient_block
71 )
72 };
73
74 match crate::errors::check_status(status) {
75 Ok(()) => Some(gradient_block.cast_const()),
76 Err(error) => {
77 if error.code == Some(MTS_INVALID_PARAMETER_ERROR) {
78 None
80 } else {
81 panic!("failed to get the gradient from a block: {:?}", error)
82 }
83 }
84 }
85}
86
87pub(super) fn get_samples(ptr: *const mts_block_t) -> Labels {
88 unsafe {
89 TensorBlockRef::from_raw(ptr).samples()
90 }
91}
92
93pub(super) fn get_components(ptr: *const mts_block_t) -> Vec<Labels> {
94 unsafe {
95 TensorBlockRef::from_raw(ptr).components()
96 }
97}
98
99pub(super) fn get_properties(ptr: *const mts_block_t) -> Labels {
100 unsafe {
101 TensorBlockRef::from_raw(ptr).properties()
102 }
103}
104
105impl<'a> TensorBlockRef<'a> {
106 #[inline]
109 pub fn data(&'a self) -> TensorBlockData<'a> {
110 TensorBlockData {
111 values: self.values(),
112 samples: LazyMetadata::new(get_samples, self.as_ptr()),
113 components: LazyMetadata::new(get_components, self.as_ptr()),
114 properties: LazyMetadata::new(get_properties, self.as_ptr()),
115 }
116 }
117
118 #[inline]
120 pub fn values(&self) -> ArrayRef<'a> {
121 let mut array = mts_array_t::null();
122 unsafe {
123 crate::errors::check_status(crate::c_api::mts_block_data(
124 self.as_ptr().cast_mut(),
125 &mut array
126 )).expect("failed to get the array for a block");
127 };
128
129 unsafe { ArrayRef::from_raw(array) }
135 }
136
137 #[inline]
138 fn labels(&self, dimension: usize) -> Labels {
139 let ptr = unsafe {
140 crate::c_api::mts_block_labels(self.as_ptr(), dimension)
141 };
142 crate::errors::check_ptr(ptr).expect("failed to get labels");
143
144 unsafe { Labels::from_raw(ptr) }
145 }
146
147 #[inline]
149 pub fn samples(&self) -> Labels {
150 return self.labels(0);
151 }
152
153 #[inline]
155 pub fn components(&self) -> Vec<Labels> {
156 let values = self.values();
157 let shape = values.shape().expect("failed to get the data shape");
158
159 let mut result = Vec::new();
160 for i in 1..(shape.len() - 1) {
161 result.push(self.labels(i));
162 }
163 return result;
164 }
165
166 #[inline]
168 pub fn properties(&self) -> Labels {
169 let values = self.values();
170 let shape = values.shape().expect("failed to get the data shape");
171
172 return self.labels(shape.len() - 1);
173 }
174
175 #[inline]
177 pub fn gradient_list(&self) -> Vec<&'a str> {
178 let mut parameters_ptr = std::ptr::null();
179 let mut parameters_count = 0;
180 unsafe {
181 check_status(crate::c_api::mts_block_gradients_list(
182 self.as_ptr(),
183 &mut parameters_ptr,
184 &mut parameters_count
185 )).expect("failed to get gradient list");
186 }
187
188 if parameters_count == 0 {
189 return Vec::new();
190 } else {
191 assert!(!parameters_ptr.is_null());
192 unsafe {
196 let parameters = std::slice::from_raw_parts(parameters_ptr, parameters_count);
197 return parameters.iter()
198 .map(|&ptr| CStr::from_ptr(ptr).to_str().unwrap())
199 .collect();
200 }
201 }
202 }
203
204 #[inline]
207 pub fn gradient(&self, parameter: &str) -> Option<TensorBlockRef<'a>> {
208 let parameter = CString::new(parameter).expect("invalid C string");
211
212 block_gradient(self.as_ptr(), ¶meter)
213 .map(|gradient_block| {
214 unsafe { TensorBlockRef::from_raw(gradient_block) }
218 })
219 }
220
221 #[inline]
226 pub fn try_clone(&self) -> Result<TensorBlock, Error> {
227 let ptr = unsafe {
228 crate::c_api::mts_block_copy(self.as_ptr())
229 };
230 crate::errors::check_ptr(ptr)?;
231
232 return Ok(unsafe { TensorBlock::from_raw(ptr) });
233 }
234
235 #[inline]
238 pub fn gradients(&self) -> GradientsIter<'_> {
239 GradientsIter {
240 parameters: self.gradient_list().into_iter(),
241 block: self.as_ptr(),
242 }
243 }
244
245 pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
249 return crate::io::save_block(path, *self);
250 }
251
252 pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
256 return crate::io::save_block_buffer(*self, buffer);
257 }
258}
259
260pub struct GradientsIter<'a> {
263 parameters: std::vec::IntoIter<&'a str>,
264 block: *const mts_block_t,
265}
266
267impl<'a> Iterator for GradientsIter<'a> {
268 type Item = (&'a str, TensorBlockRef<'a>);
269
270 #[inline]
271 fn next(&mut self) -> Option<Self::Item> {
272 self.parameters.next().map(|parameter| {
273 let parameter_c = CString::new(parameter).expect("invalid C string");
274 let block = block_gradient(self.block, ¶meter_c).expect("missing gradient");
275
276 let block = unsafe { TensorBlockRef::from_raw(block) };
280 return (parameter, block);
281 })
282 }
283
284 fn size_hint(&self) -> (usize, Option<usize>) {
285 (self.len(), Some(self.len()))
286 }
287}
288
289impl ExactSizeIterator for GradientsIter<'_> {
290 #[inline]
291 fn len(&self) -> usize {
292 self.parameters.len()
293 }
294}
295
// `next` forwards to `std::vec::IntoIter::next`, which is itself fused, so
// once this iterator returns `None` every later call returns `None` as well.
impl FusedIterator for GradientsIter<'_> {}
297
#[cfg(test)]
mod tests {
    use crate::{Labels, TensorBlock};

    #[test]
    // all compared values are small integers, stored exactly in f64
    #[allow(clippy::float_cmp)]
    fn gradients() {
        let properties = Labels::new(["p"], &[[-2], [0], [1]]);
        let mut block = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], 1.0),
            &Labels::new(["s"], &[[0], [1]]), &[], &properties,
        ).unwrap();

        block.add_gradient("g", TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], -1.0),
            &Labels::new(["sample"], &[[0], [1]]), &[], &properties,
        ).unwrap()).unwrap();

        block.add_gradient("f", TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], -2.0),
            &Labels::new(["sample"], &[[0], [1]]), &[], &properties,
        ).unwrap()).unwrap();

        let block = block.as_ref();

        // a 2D block has no component axis
        assert!(block.components().is_empty());

        // gradients are listed in insertion order
        assert_eq!(block.gradient_list(), ["g", "f"]);

        // each read guard below is a temporary dropped at the end of its
        // statement, so the same array lock is never held twice at once
        let gradient = block.gradient("g").unwrap();
        assert_eq!(gradient.values().to_ndarray_lock::<f64>().read().unwrap()[[0, 0]], -1.0);

        let gradient = block.gradient("f").unwrap();
        assert_eq!(gradient.values().to_ndarray_lock::<f64>().read().unwrap()[[0, 0]], -2.0);

        assert!(block.gradient("h").is_none());

        let mut iter = block.gradients();
        assert_eq!(iter.len(), 2);
        assert_eq!(iter.size_hint(), (2, Some(2)));

        let (parameter, gradient) = iter.next().unwrap();
        assert_eq!(parameter, "g");
        assert_eq!(gradient.values().to_ndarray_lock::<f64>().read().unwrap()[[0, 0]], -1.0);

        let (parameter, gradient) = iter.next().unwrap();
        assert_eq!(parameter, "f");
        assert_eq!(gradient.values().to_ndarray_lock::<f64>().read().unwrap()[[0, 0]], -2.0);

        assert_eq!(iter.len(), 0);
        assert!(iter.next().is_none());
        // the iterator is fused: it keeps returning `None` once exhausted
        assert!(iter.next().is_none());
    }

    #[test]
    fn block_data() {
        let block = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 1, 3], 1.0),
            &Labels::new(["samples"], &[[0], [1]]),
            &[Labels::new(["component"], &[[0]])],
            &Labels::new(["properties"], &[[-2], [0], [1]]),
        ).unwrap();
        let block = block.as_ref();

        let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
        assert_eq!(*values, ndarray::Array::from_elem(vec![2, 1, 3], 1.0));
        assert_eq!(block.samples(), Labels::new(["samples"], &[[0], [1]]));
        assert_eq!(block.components(), [Labels::new(["component"], &[[0]])]);
        assert_eq!(block.properties(), Labels::new(["properties"], &[[-2], [0], [1]]));

        let block = block.data();
        let values = block.values.to_ndarray_lock::<f64>().read().unwrap();
        assert_eq!(*values, ndarray::Array::from_elem(vec![2, 1, 3], 1.0));
        assert_eq!(*block.samples, Labels::new(["samples"], &[[0], [1]]));
        assert_eq!(*block.components, [Labels::new(["component"], &[[0]])]);
        assert_eq!(*block.properties, Labels::new(["properties"], &[[-2], [0], [1]]));
    }
}