1use std::ffi::{CString, CStr};
2use std::iter::FusedIterator;
3
4use crate::c_api::{mts_block_t, mts_array_t, MTS_INVALID_PARAMETER_ERROR};
5use crate::{ArrayRef, ArrayRefMut, Labels, Error};
6
7use super::{TensorBlockRef, LazyMetadata};
8use super::block_ref::{get_samples, get_components, get_properties};
9
10#[derive(Debug)]
12pub struct TensorBlockRefMut<'a> {
13 ptr: *mut mts_block_t,
14 marker: std::marker::PhantomData<&'a mut mts_block_t>,
15}
16
17unsafe impl Send for TensorBlockRefMut<'_> {}
19unsafe impl Sync for TensorBlockRefMut<'_> {}
22
23#[derive(Debug)]
33pub struct TensorBlockDataMut<'a> {
34 pub values: ArrayRefMut<'a>,
35 pub samples: LazyMetadata<Labels>,
36 pub components: LazyMetadata<Vec<Labels>>,
37 pub properties: LazyMetadata<Labels>,
38}
39
40fn block_gradient(block: *mut mts_block_t, parameter: &CStr) -> Option<*mut mts_block_t> {
42 let mut gradient_block = std::ptr::null_mut();
43 let status = unsafe { crate::c_api::mts_block_gradient(
44 block,
45 parameter.as_ptr(),
46 &mut gradient_block
47 )
48 };
49
50 match crate::errors::check_status(status) {
51 Ok(()) => Some(gradient_block),
52 Err(error) => {
53 if error.code == Some(MTS_INVALID_PARAMETER_ERROR) {
54 None
56 } else {
57 panic!("failed to get the gradient from a block: {:?}", error)
58 }
59 }
60 }
61}
62
63impl<'a> TensorBlockRefMut<'a> {
64 pub unsafe fn from_raw(ptr: *mut mts_block_t) -> TensorBlockRefMut<'a> {
72 assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
73
74 TensorBlockRefMut {
75 ptr: ptr,
76 marker: std::marker::PhantomData,
77 }
78 }
79
80 pub fn as_ptr(&self) -> *const mts_block_t {
82 self.ptr
83 }
84
85 pub fn as_mut_ptr(&mut self) -> *mut mts_block_t {
87 self.ptr
88 }
89
90 #[inline]
92 pub fn as_ref(&self) -> TensorBlockRef<'_> {
93 unsafe {
94 TensorBlockRef::from_raw(self.as_ptr())
95 }
96 }
97
98 #[inline]
101 pub fn data_mut(&mut self) -> TensorBlockDataMut<'_> {
102 let samples = LazyMetadata::new(get_samples, self.as_ptr());
103 let components = LazyMetadata::new(get_components, self.as_ptr());
104 let properties = LazyMetadata::new(get_properties, self.as_ptr());
105
106 TensorBlockDataMut {
107 values: self.values_mut(),
109 samples: samples,
110 components: components,
111 properties: properties,
112 }
113 }
114
115 #[inline]
117 pub fn device(&self) -> Result<dlpk::sys::DLDevice, Error> {
118 self.as_ref().device()
119 }
120
121 #[inline]
123 pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
124 self.as_ref().dtype()
125 }
126
127 #[inline]
129 pub fn values_mut(&mut self) -> ArrayRefMut<'_> {
130 let mut array = mts_array_t::null();
131 unsafe {
132 crate::errors::check_status(crate::c_api::mts_block_data(
133 self.as_mut_ptr(),
134 &mut array
135 )).expect("failed to get the array for a block");
136 };
137
138 unsafe { ArrayRefMut::from_raw(array) }
140 }
141
142 #[inline]
144 pub fn values(&self) -> ArrayRef<'_> {
145 return self.as_ref().values();
146 }
147
148 #[inline]
150 pub fn samples(&self) -> Labels {
151 return self.as_ref().samples();
152 }
153
154 #[inline]
156 pub fn components(&self) -> Vec<Labels> {
157 return self.as_ref().components();
158 }
159
160 #[inline]
162 pub fn properties(&self) -> Labels {
163 return self.as_ref().properties();
164 }
165
166 #[inline]
169 pub fn gradient_mut(&mut self, parameter: &str) -> Option<TensorBlockRefMut<'_>> {
170 let parameter = CString::new(parameter).expect("invalid C string");
171
172 block_gradient(self.as_mut_ptr(), ¶meter)
173 .map(|gradient_block| {
174 unsafe { TensorBlockRefMut::from_raw(gradient_block) }
177 })
178 }
179
180 #[inline]
183 pub fn gradients_mut(&mut self) -> GradientsMutIter<'_> {
184 let block_ptr = self.as_mut_ptr();
185 GradientsMutIter {
186 parameters: self.as_ref().gradient_list().into_iter(),
187 block: block_ptr,
188 }
189 }
190
191 pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
195 self.as_ref().save(path)
196 }
197
198 pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
202 self.as_ref().save_buffer(buffer)
203 }
204}
205
206pub struct GradientsMutIter<'a> {
209 parameters: std::vec::IntoIter<&'a str>,
210 block: *mut mts_block_t,
211}
212
213impl<'a> Iterator for GradientsMutIter<'a> {
214 type Item = (&'a str, TensorBlockRefMut<'a>);
215
216 #[inline]
217 fn next(&mut self) -> Option<Self::Item> {
218 self.parameters.next().map(|parameter| {
219 let parameter_c = CString::new(parameter).expect("invalid C string");
220 let block = block_gradient(self.block, ¶meter_c).expect("missing gradient");
221
222 let block = unsafe { TensorBlockRefMut::from_raw(block) };
226 return (parameter, block);
227 })
228 }
229
230 fn size_hint(&self) -> (usize, Option<usize>) {
231 self.parameters.size_hint()
232 }
233}
234
235impl ExactSizeIterator for GradientsMutIter<'_> {
236 #[inline]
237 fn len(&self) -> usize {
238 self.parameters.len()
239 }
240}
241
242impl FusedIterator for GradientsMutIter<'_> {}
243
#[cfg(test)]
mod tests {
    use crate::{Labels, TensorBlock};

    /// Gradients can be accessed by name and through the iterator, in the
    /// order they were added to the block.
    #[test]
    #[allow(clippy::float_cmp)]
    fn gradients() {
        // (parameter name, value filling the whole gradient array)
        let gradients = [("g", -1.0_f64), ("f", -2.0_f64)];

        let properties = Labels::new(["p"], [[-2], [0], [1]]);
        let gradient_samples = Labels::new(["sample"], [[0], [1]]);

        let mut block = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], 1.0),
            &Labels::new(["s"], [[0], [1]]),
            &[],
            &properties,
        ).unwrap();

        for (parameter, fill) in gradients {
            let gradient = TensorBlock::new(
                ndarray::Array::from_elem(vec![2, 3], fill),
                &gradient_samples,
                &[],
                &properties,
            ).unwrap();
            block.add_gradient(parameter, gradient).unwrap();
        }

        let mut block = block.as_ref_mut();

        // access by name
        for (parameter, expected) in gradients {
            let gradient = block.gradient_mut(parameter).unwrap();
            let values = gradient.values().to_ndarray_lock::<f64>().read().unwrap();
            assert_eq!(values[[0, 0]], expected);
        }

        // unknown gradients are reported as `None`
        assert!(block.gradient_mut("h").is_none());

        // access through the iterator
        let mut iter = block.gradients_mut();
        assert_eq!(iter.len(), 2);

        for (expected, _) in gradients {
            let (parameter, _) = iter.next().unwrap();
            assert_eq!(parameter, expected);
        }
        assert!(iter.next().is_none());
    }

    /// The individual accessors and `data_mut` give back the data used to
    /// create the block.
    #[test]
    fn block_data() {
        let expected_values = ndarray::Array::from_elem(vec![2, 1, 3], 1.0);
        let samples = Labels::new(["samples"], [[0], [1]]);
        let components = [Labels::new(["component"], [[0]])];
        let properties = Labels::new(["properties"], [[-2], [0], [1]]);

        let mut block = TensorBlock::new(
            expected_values.clone(),
            &samples,
            &components,
            &properties,
        ).unwrap();
        let mut block = block.as_ref_mut();

        {
            // individual accessors
            let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
            assert_eq!(*values, expected_values);
            assert_eq!(block.samples(), samples);
            assert_eq!(block.components(), components);
            assert_eq!(block.properties(), properties);
        }

        // everything at once
        let data = block.data_mut();
        let values = data.values.to_ndarray_lock::<f64>().read().unwrap();
        assert_eq!(*values, expected_values);
        assert_eq!(*data.samples, samples);
        assert_eq!(*data.components, components);
        assert_eq!(*data.properties, properties);
    }

    /// A block created from an `ndarray` array of `f64` reports a CPU device
    /// and a 64-bit float data type.
    #[test]
    fn device_and_dtype() {
        let samples = Labels::new(["samples"], [[0], [1]]);
        let properties = Labels::new(["properties"], [[-2], [0], [1]]);

        let mut block = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], 1.0),
            &samples,
            &[],
            &properties,
        ).unwrap();
        let block = block.as_ref_mut();

        let device = block.device().unwrap();
        assert_eq!(device.device_type, dlpk::sys::DLDeviceType::kDLCPU);

        let dtype = block.dtype().unwrap();
        assert_eq!(dtype.code, dlpk::sys::DLDataTypeCode::kDLFloat);
        assert_eq!(dtype.bits, 64);
    }
}