1use std::ffi::{CString, CStr};
2use std::iter::FusedIterator;
3
4use crate::c_api::{mts_block_t, mts_array_t, MTS_INVALID_PARAMETER_ERROR};
5use crate::{ArrayRef, ArrayRefMut, Labels, Error};
6
7use super::{TensorBlockRef, LazyMetadata};
8use super::block_ref::{get_samples, get_components, get_properties};
9
/// Mutable reference to a block, wrapping a raw `mts_block_t` pointer.
///
/// The lifetime `'a` ties this handle to whatever owns the underlying block
/// (for example, the tests below obtain one from `TensorBlock::as_ref_mut`).
#[derive(Debug)]
pub struct TensorBlockRefMut<'a> {
    /// Pointer to the underlying block; `from_raw` asserts it is not NULL.
    ptr: *mut mts_block_t,
    /// Zero-sized marker making this type behave like `&'a mut mts_block_t`
    /// for the borrow checker, without actually storing a reference.
    marker: std::marker::PhantomData<&'a mut mts_block_t>,
}
16
17unsafe impl Send for TensorBlockRefMut<'_> {}
19unsafe impl Sync for TensorBlockRefMut<'_> {}
22
/// All the data of a block, obtained at once through
/// `TensorBlockRefMut::data_mut`.
///
/// Getting everything together lets callers read the metadata while the
/// values are borrowed mutably, which the separate `&self` getters
/// (`samples()`, `components()`, ...) would not allow alongside
/// `values_mut()`.
#[derive(Debug)]
pub struct TensorBlockDataMut<'a> {
    /// Mutable view of the values array of the block.
    pub values: ArrayRefMut<'a>,
    /// Sample labels of the block; dereference to get the `Labels`.
    /// `data_mut` builds this from a getter function and the raw block
    /// pointer, so presumably the labels are only fetched when first
    /// accessed (TODO confirm in `LazyMetadata`).
    pub samples: LazyMetadata<Labels>,
    /// Component labels of the block; dereference to get the `Vec<Labels>`.
    pub components: LazyMetadata<Vec<Labels>>,
    /// Property labels of the block; dereference to get the `Labels`.
    pub properties: LazyMetadata<Labels>,
}
39
40fn block_gradient(block: *mut mts_block_t, parameter: &CStr) -> Option<*mut mts_block_t> {
42 let mut gradient_block = std::ptr::null_mut();
43 let status = unsafe { crate::c_api::mts_block_gradient(
44 block,
45 parameter.as_ptr(),
46 &mut gradient_block
47 )
48 };
49
50 match crate::errors::check_status(status) {
51 Ok(()) => Some(gradient_block),
52 Err(error) => {
53 if error.code == Some(MTS_INVALID_PARAMETER_ERROR) {
54 None
56 } else {
57 panic!("failed to get the gradient from a block: {:?}", error)
58 }
59 }
60 }
61}
62
63impl<'a> TensorBlockRefMut<'a> {
64 pub(crate) unsafe fn from_raw(ptr: *mut mts_block_t) -> TensorBlockRefMut<'a> {
72 assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
73
74 TensorBlockRefMut {
75 ptr: ptr,
76 marker: std::marker::PhantomData,
77 }
78 }
79
80 pub(super) fn as_ptr(&self) -> *const mts_block_t {
82 self.ptr
83 }
84
85 pub(super) fn as_mut_ptr(&mut self) -> *mut mts_block_t {
87 self.ptr
88 }
89
90 #[inline]
92 pub fn as_ref(&self) -> TensorBlockRef<'_> {
93 unsafe {
94 TensorBlockRef::from_raw(self.as_ptr())
95 }
96 }
97
98 #[inline]
101 pub fn data_mut(&mut self) -> TensorBlockDataMut<'_> {
102 let samples = LazyMetadata::new(get_samples, self.as_ptr());
103 let components = LazyMetadata::new(get_components, self.as_ptr());
104 let properties = LazyMetadata::new(get_properties, self.as_ptr());
105
106 TensorBlockDataMut {
107 values: self.values_mut(),
109 samples: samples,
110 components: components,
111 properties: properties,
112 }
113 }
114
115 #[inline]
117 pub fn values_mut(&mut self) -> ArrayRefMut<'_> {
118 let mut array = mts_array_t::null();
119 unsafe {
120 crate::errors::check_status(crate::c_api::mts_block_data(
121 self.as_mut_ptr(),
122 &mut array
123 )).expect("failed to get the array for a block");
124 };
125
126 unsafe { ArrayRefMut::new(array) }
128 }
129
130 #[inline]
132 pub fn values(&self) -> ArrayRef<'_> {
133 return self.as_ref().values();
134 }
135
136 #[inline]
138 pub fn samples(&self) -> Labels {
139 return self.as_ref().samples();
140 }
141
142 #[inline]
144 pub fn components(&self) -> Vec<Labels> {
145 return self.as_ref().components();
146 }
147
148 #[inline]
150 pub fn properties(&self) -> Labels {
151 return self.as_ref().properties();
152 }
153
154 #[inline]
157 pub fn gradient_mut(&mut self, parameter: &str) -> Option<TensorBlockRefMut<'_>> {
158 let parameter = CString::new(parameter).expect("invalid C string");
159
160 block_gradient(self.as_mut_ptr(), ¶meter)
161 .map(|gradient_block| {
162 unsafe { TensorBlockRefMut::from_raw(gradient_block) }
165 })
166 }
167
168 #[inline]
171 pub fn gradients_mut(&mut self) -> GradientsMutIter<'_> {
172 let block_ptr = self.as_mut_ptr();
173 GradientsMutIter {
174 parameters: self.as_ref().gradient_list().into_iter(),
175 block: block_ptr,
176 }
177 }
178
179 pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
183 self.as_ref().save(path)
184 }
185
186 pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
190 self.as_ref().save_buffer(buffer)
191 }
192}
193
194pub struct GradientsMutIter<'a> {
197 parameters: std::vec::IntoIter<&'a str>,
198 block: *mut mts_block_t,
199}
200
201impl<'a> Iterator for GradientsMutIter<'a> {
202 type Item = (&'a str, TensorBlockRefMut<'a>);
203
204 #[inline]
205 fn next(&mut self) -> Option<Self::Item> {
206 self.parameters.next().map(|parameter| {
207 let parameter_c = CString::new(parameter).expect("invalid C string");
208 let block = block_gradient(self.block, ¶meter_c).expect("missing gradient");
209
210 let block = unsafe { TensorBlockRefMut::from_raw(block) };
214 return (parameter, block);
215 })
216 }
217
218 fn size_hint(&self) -> (usize, Option<usize>) {
219 self.parameters.size_hint()
220 }
221}
222
223impl ExactSizeIterator for GradientsMutIter<'_> {
224 #[inline]
225 fn len(&self) -> usize {
226 self.parameters.len()
227 }
228}
229
230impl FusedIterator for GradientsMutIter<'_> {}
231
#[cfg(test)]
mod tests {
    use crate::{Labels, TensorBlock};

    /// Gradients are reachable by name through `gradient_mut`, unknown names
    /// give `None`, and `gradients_mut` visits them in insertion order.
    #[test]
    #[allow(clippy::float_cmp)]
    fn gradients() {
        let properties = Labels::new(["p"], &[[-2], [0], [1]]);
        let mut block = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![2, 3], 1.0),
            &Labels::new(["s"], &[[0], [1]]),
            &[],
            &properties,
        ).unwrap();

        let gradient_samples = Labels::new(["sample"], &[[0], [1]]);
        let make_gradient = |value: f64| {
            TensorBlock::new(
                ndarray::ArrayD::from_elem(vec![2, 3], value),
                &gradient_samples,
                &[],
                &properties,
            ).unwrap()
        };
        block.add_gradient("g", make_gradient(-1.0)).unwrap();
        block.add_gradient("f", make_gradient(-2.0)).unwrap();

        let mut block = block.as_ref_mut();
        for (parameter, expected) in [("g", -1.0), ("f", -2.0)] {
            let gradient = block.gradient_mut(parameter).unwrap();
            assert_eq!(gradient.values().as_array()[[0, 0]], expected);
        }
        assert!(block.gradient_mut("h").is_none());

        let mut gradients = block.gradients_mut();
        assert_eq!(gradients.len(), 2);

        let (first, _) = gradients.next().unwrap();
        assert_eq!(first, "g");
        let (second, _) = gradients.next().unwrap();
        assert_eq!(second, "f");
        assert!(gradients.next().is_none());
    }

    /// Values and metadata are the same whether they are read through the
    /// individual getters or all at once through `data_mut`.
    #[test]
    fn block_data() {
        let values: ndarray::ArrayD<f64> = ndarray::ArrayD::from_elem(vec![2, 1, 3], 1.0);
        let samples = Labels::new(["samples"], &[[0], [1]]);
        let components = [Labels::new(["component"], &[[0]])];
        let properties = Labels::new(["properties"], &[[-2], [0], [1]]);

        let mut owned = TensorBlock::new(
            values.clone(), &samples, &components, &properties,
        ).unwrap();
        let mut block = owned.as_ref_mut();

        assert_eq!(block.values().as_array(), values);
        assert_eq!(block.samples(), samples);
        assert_eq!(block.components(), components);
        assert_eq!(block.properties(), properties);

        let data = block.data_mut();
        assert_eq!(data.values.as_array(), values);
        assert_eq!(*data.samples, samples);
        assert_eq!(*data.components, components);
        assert_eq!(*data.properties, properties);
    }
}