1use std::ffi::{CString, CStr};
2use std::iter::FusedIterator;
3
4use crate::c_api::{mts_block_t, mts_array_t, MTS_INVALID_PARAMETER_ERROR};
5use crate::{ArrayRef, ArrayRefMut, Labels, Error};
6
7use super::{TensorBlockRef, LazyMetadata};
8use super::block_ref::{get_samples, get_components, get_properties};
9
/// Mutable reference to a metatensor block (`mts_block_t`), valid for the
/// lifetime `'a`.
///
/// The same type is used for a block and for its gradients
/// (`TensorBlockRefMut::gradient_mut` returns a `TensorBlockRefMut`).
#[derive(Debug)]
pub struct TensorBlockRefMut<'a> {
    /// Raw pointer to the block; `from_raw` asserts it is not NULL.
    ptr: *mut mts_block_t,
    /// Makes this type behave like `&'a mut mts_block_t` for the borrow
    /// checker (lifetime and variance) without storing a real reference.
    marker: std::marker::PhantomData<&'a mut mts_block_t>,
}

// SAFETY (review): `TensorBlockRefMut` stands in for `&'a mut mts_block_t`,
// and `&mut T` is `Send`/`Sync` whenever `T` is. These impls assume that the
// C-side `mts_block_t` may be accessed from any thread — confirm against the
// metatensor-core thread-safety guarantees.
unsafe impl Send for TensorBlockRefMut<'_> {}
unsafe impl Sync for TensorBlockRefMut<'_> {}
22
/// Mutable reference to the values of a block, together with the block
/// metadata, as returned by `TensorBlockRefMut::data_mut`.
///
/// Because these are separate fields, callers can hold a mutable borrow of
/// `values` while reading `samples`, `components` and `properties`.
#[derive(Debug)]
pub struct TensorBlockDataMut<'a> {
    /// Mutable reference to the values of the block
    pub values: ArrayRefMut<'a>,
    /// Samples of the block. NOTE(review): `LazyMetadata` presumably loads
    /// the labels on first access — confirm in its definition.
    pub samples: LazyMetadata<Labels>,
    /// Components of the block (one `Labels` per component dimension)
    pub components: LazyMetadata<Vec<Labels>>,
    /// Properties of the block
    pub properties: LazyMetadata<Labels>,
}
39
40fn block_gradient(block: *mut mts_block_t, parameter: &CStr) -> Option<*mut mts_block_t> {
42    let mut gradient_block = std::ptr::null_mut();
43    let status = unsafe { crate::c_api::mts_block_gradient(
44            block,
45            parameter.as_ptr(),
46            &mut gradient_block
47        )
48    };
49
50    match crate::errors::check_status(status) {
51        Ok(()) => Some(gradient_block),
52        Err(error) => {
53            if error.code == Some(MTS_INVALID_PARAMETER_ERROR) {
54                None
56            } else {
57                panic!("failed to get the gradient from a block: {:?}", error)
58            }
59        }
60    }
61}
62
63impl<'a> TensorBlockRefMut<'a> {
64    pub(crate) unsafe fn from_raw(ptr: *mut mts_block_t) -> TensorBlockRefMut<'a> {
72        assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
73
74        TensorBlockRefMut {
75            ptr: ptr,
76            marker: std::marker::PhantomData,
77        }
78    }
79
80    pub(super) fn as_ptr(&self) -> *const mts_block_t {
82        self.ptr
83    }
84
85    pub(super) fn as_mut_ptr(&mut self) -> *mut mts_block_t {
87        self.ptr
88    }
89
90    #[inline]
92    pub fn as_ref(&self) -> TensorBlockRef<'_> {
93        unsafe {
94            TensorBlockRef::from_raw(self.as_ptr())
95        }
96    }
97
98    #[inline]
101    pub fn data_mut(&mut self) -> TensorBlockDataMut<'_> {
102        let samples = LazyMetadata::new(get_samples, self.as_ptr());
103        let components = LazyMetadata::new(get_components, self.as_ptr());
104        let properties = LazyMetadata::new(get_properties, self.as_ptr());
105
106        TensorBlockDataMut {
107            values: self.values_mut(),
109            samples: samples,
110            components: components,
111            properties: properties,
112        }
113    }
114
115    #[inline]
117    pub fn values_mut(&mut self) -> ArrayRefMut<'_> {
118        let mut array = mts_array_t::null();
119        unsafe {
120            crate::errors::check_status(crate::c_api::mts_block_data(
121                self.as_mut_ptr(),
122                &mut array
123            )).expect("failed to get the array for a block");
124        };
125
126        unsafe { ArrayRefMut::new(array) }
128    }
129
130    #[inline]
132    pub fn values(&self) -> ArrayRef<'_> {
133        return self.as_ref().values();
134    }
135
136    #[inline]
138    pub fn samples(&self) -> Labels {
139        return self.as_ref().samples();
140    }
141
142    #[inline]
144    pub fn components(&self) -> Vec<Labels> {
145        return self.as_ref().components();
146    }
147
148    #[inline]
150    pub fn properties(&self) -> Labels {
151        return self.as_ref().properties();
152    }
153
154    #[inline]
157    pub fn gradient_mut(&mut self, parameter: &str) -> Option<TensorBlockRefMut<'_>> {
158        let parameter = CString::new(parameter).expect("invalid C string");
159
160        block_gradient(self.as_mut_ptr(), ¶meter)
161            .map(|gradient_block| {
162                unsafe { TensorBlockRefMut::from_raw(gradient_block) }
165            })
166    }
167
168    #[inline]
171    pub fn gradients_mut(&mut self) -> GradientsMutIter<'_> {
172        let block_ptr = self.as_mut_ptr();
173        GradientsMutIter {
174            parameters: self.as_ref().gradient_list().into_iter(),
175            block: block_ptr,
176        }
177    }
178
179    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
183        self.as_ref().save(path)
184    }
185
186    pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
190        self.as_ref().save_buffer(buffer)
191    }
192}
193
194pub struct GradientsMutIter<'a> {
197    parameters: std::vec::IntoIter<&'a str>,
198    block: *mut mts_block_t,
199}
200
201impl<'a> Iterator for GradientsMutIter<'a> {
202    type Item = (&'a str, TensorBlockRefMut<'a>);
203
204    #[inline]
205    fn next(&mut self) -> Option<Self::Item> {
206        self.parameters.next().map(|parameter| {
207            let parameter_c = CString::new(parameter).expect("invalid C string");
208            let block = block_gradient(self.block, ¶meter_c).expect("missing gradient");
209
210            let block = unsafe { TensorBlockRefMut::from_raw(block) };
214            return (parameter, block);
215        })
216    }
217
218    fn size_hint(&self) -> (usize, Option<usize>) {
219        self.parameters.size_hint()
220    }
221}
222
223impl ExactSizeIterator for GradientsMutIter<'_> {
224    #[inline]
225    fn len(&self) -> usize {
226        self.parameters.len()
227    }
228}
229
// `next` returns `None` exactly when the underlying `std::vec::IntoIter` of
// parameter names does, and that iterator is itself fused, so this one is too.
impl FusedIterator for GradientsMutIter<'_> {}
231
#[cfg(test)]
mod tests {
    use crate::{Labels, TensorBlock};

    #[test]
    #[allow(clippy::float_cmp)]
    fn gradients() {
        let properties = Labels::new(["p"], &[[-2], [0], [1]]);
        let gradient_samples = Labels::new(["sample"], &[[0], [1]]);

        let mut block = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![2, 3], 1.0),
            &Labels::new(["s"], &[[0], [1]]),
            &[],
            &properties,
        ).unwrap();

        // All gradients share shape and metadata; only the fill value differs,
        // which is what lets us tell them apart below.
        let expected: [(&str, f64); 2] = [("g", -1.0), ("f", -2.0)];
        for &(parameter, fill) in &expected {
            let gradient = TensorBlock::new(
                ndarray::ArrayD::from_elem(vec![2, 3], fill),
                &gradient_samples,
                &[],
                &properties,
            ).unwrap();
            block.add_gradient(parameter, gradient).unwrap();
        }

        let mut block = block.as_ref_mut();
        for &(parameter, fill) in &expected {
            let gradient = block.gradient_mut(parameter).unwrap();
            assert_eq!(gradient.values().as_array()[[0, 0]], fill);
        }
        assert!(block.gradient_mut("h").is_none());

        // the iterator yields the gradients in the order they were added
        let mut gradients = block.gradients_mut();
        assert_eq!(gradients.len(), 2);
        assert_eq!(gradients.next().map(|(parameter, _)| parameter), Some("g"));
        assert_eq!(gradients.next().map(|(parameter, _)| parameter), Some("f"));
        assert!(gradients.next().is_none());
    }

    #[test]
    fn block_data() {
        let values = ndarray::ArrayD::from_elem(vec![2, 1, 3], 1.0_f64);
        let samples = Labels::new(["samples"], &[[0], [1]]);
        let components = [Labels::new(["component"], &[[0]])];
        let properties = Labels::new(["properties"], &[[-2], [0], [1]]);

        let mut block = TensorBlock::new(
            values.clone(), &samples, &components, &properties
        ).unwrap();
        let mut block = block.as_ref_mut();

        // accessors on the mutable reference itself
        assert_eq!(block.values().as_array(), values);
        assert_eq!(block.samples(), samples);
        assert_eq!(block.components(), components);
        assert_eq!(block.properties(), properties);

        // the same data, seen through `data_mut`
        let data = block.data_mut();
        assert_eq!(data.values.as_array(), values);
        assert_eq!(*data.samples, samples);
        assert_eq!(*data.components, components);
        assert_eq!(*data.properties, properties);
    }
}