Skip to main content

metatensor/block/
block_mut.rs

1use std::ffi::{CString, CStr};
2use std::iter::FusedIterator;
3
4use crate::c_api::{mts_block_t, mts_array_t, MTS_INVALID_PARAMETER_ERROR};
5use crate::{ArrayRef, ArrayRefMut, Labels, Error};
6
7use super::{TensorBlockRef, LazyMetadata};
8use super::block_ref::{get_samples, get_components, get_properties};
9
10/// Mutable reference to a [`TensorBlock`](crate::TensorBlock)
11#[derive(Debug)]
12pub struct TensorBlockRefMut<'a> {
13    ptr: *mut mts_block_t,
14    marker: std::marker::PhantomData<&'a mut mts_block_t>,
15}
16
17// SAFETY: Send is fine since TensorBlockRefMut does not implement Drop
18unsafe impl Send for TensorBlockRefMut<'_> {}
19// SAFETY: Sync is fine since there is no internal mutability in TensorBlockRefMut
20// (all mutations still require a `&mut TensorBlockRefMut`)
21unsafe impl Sync for TensorBlockRefMut<'_> {}
22
23/// All the basic data in a `TensorBlockRefMut` as a struct with separate fields.
24///
25/// This can be useful when you need to borrow different fields on this struct
26/// separately. They are separate in the underlying metatensor-core, but since we
27/// go through the C API to access them, we need to re-expose them as separate
28/// fields for the rust compiler to be able to understand that.
29///
30/// The metadata is initialized lazily on first access, to not pay the cost of
31/// allocation/reference count increase if some metadata is not used.
32#[derive(Debug)]
33pub struct TensorBlockDataMut<'a> {
34    pub values: ArrayRefMut<'a>,
35    pub samples: LazyMetadata<Labels>,
36    pub components: LazyMetadata<Vec<Labels>>,
37    pub properties: LazyMetadata<Labels>,
38}
39
40/// Get a gradient from this block
41fn block_gradient(block: *mut mts_block_t, parameter: &CStr) -> Option<*mut mts_block_t> {
42    let mut gradient_block = std::ptr::null_mut();
43    let status = unsafe { crate::c_api::mts_block_gradient(
44            block,
45            parameter.as_ptr(),
46            &mut gradient_block
47        )
48    };
49
50    match crate::errors::check_status(status) {
51        Ok(()) => Some(gradient_block),
52        Err(error) => {
53            if error.code == Some(MTS_INVALID_PARAMETER_ERROR) {
54                // there is no array for this gradient
55                None
56            } else {
57                panic!("failed to get the gradient from a block: {:?}", error)
58            }
59        }
60    }
61}
62
63impl<'a> TensorBlockRefMut<'a> {
64    /// Create a new `TensorBlockRefMut` from the given raw `mts_block_t`
65    ///
66    /// This is a **VERY** unsafe function, creating a lifetime out of thin air,
67    /// and allowing mutable access to the `mts_block_t`. Make sure the lifetime
68    /// is actually constrained by the lifetime of the owner of this
69    /// `mts_block_t`; and that the owner is mutably borrowed by this
70    /// `TensorBlockRefMut`.
71    pub unsafe fn from_raw(ptr: *mut mts_block_t) -> TensorBlockRefMut<'a> {
72        assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
73
74        TensorBlockRefMut {
75            ptr: ptr,
76            marker: std::marker::PhantomData,
77        }
78    }
79
80    /// Get the underlying raw pointer
81    pub fn as_ptr(&self) -> *const mts_block_t {
82        self.ptr
83    }
84
85    /// Get the underlying (mutable) raw pointer
86    pub fn as_mut_ptr(&mut self) -> *mut mts_block_t {
87        self.ptr
88    }
89
90    /// Get a non mutable reference to this block
91    #[inline]
92    pub fn as_ref(&self) -> TensorBlockRef<'_> {
93        unsafe {
94            TensorBlockRef::from_raw(self.as_ptr())
95        }
96    }
97
98    /// Get all the data and metadata inside this `TensorBlockRefMut` as a
99    /// struct with separate fields, to allow borrowing them separately.
100    #[inline]
101    pub fn data_mut(&mut self) -> TensorBlockDataMut<'_> {
102        let samples = LazyMetadata::new(get_samples, self.as_ptr());
103        let components = LazyMetadata::new(get_components, self.as_ptr());
104        let properties = LazyMetadata::new(get_properties, self.as_ptr());
105
106        TensorBlockDataMut {
107            // SAFETY: we are returning an `ArrayRefMut` mutably borrowing from `self`
108            values: self.values_mut(),
109            samples: samples,
110            components: components,
111            properties: properties,
112        }
113    }
114
115    /// Get the device on which the values of this block are stored.
116    #[inline]
117    pub fn device(&self) -> Result<dlpk::sys::DLDevice, Error> {
118        self.as_ref().device()
119    }
120
121    /// Get the data type of the values of this block.
122    #[inline]
123    pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
124        self.as_ref().dtype()
125    }
126
127    /// Get a mutable reference to the values in this block
128    #[inline]
129    pub fn values_mut(&mut self) -> ArrayRefMut<'_> {
130        let mut array = mts_array_t::null();
131        unsafe {
132            crate::errors::check_status(crate::c_api::mts_block_data(
133                self.as_mut_ptr(),
134                &mut array
135            )).expect("failed to get the array for a block");
136        };
137
138        // SAFETY: we are returning an `ArrayRefMut` mutably borrowing from `self`
139        unsafe { ArrayRefMut::from_raw(array) }
140    }
141
142    /// Get the array for the values in this block
143    #[inline]
144    pub fn values(&self) -> ArrayRef<'_> {
145        return self.as_ref().values();
146    }
147
148    /// Get the samples for this block
149    #[inline]
150    pub fn samples(&self) -> Labels {
151        return self.as_ref().samples();
152    }
153
154    /// Get the components for this block
155    #[inline]
156    pub fn components(&self) -> Vec<Labels> {
157        return self.as_ref().components();
158    }
159
160    /// Get the properties for this block
161    #[inline]
162    pub fn properties(&self) -> Labels {
163        return self.as_ref().properties();
164    }
165
166    /// Get a mutable reference to the data and metadata for the gradient with
167    /// respect to the given parameter in this block, if it exists.
168    #[inline]
169    pub fn gradient_mut(&mut self, parameter: &str) -> Option<TensorBlockRefMut<'_>> {
170        let parameter = CString::new(parameter).expect("invalid C string");
171
172        block_gradient(self.as_mut_ptr(), &parameter)
173            .map(|gradient_block| {
174                // SAFETY: we are returning an `TensorBlockRefMut` mutably
175                // borrowing from `self`
176                unsafe { TensorBlockRefMut::from_raw(gradient_block) }
177            })
178    }
179
180    /// Get an iterator over parameter/[`TensorBlockRefMut`] pairs for all gradients
181    /// in this block
182    #[inline]
183    pub fn gradients_mut(&mut self) -> GradientsMutIter<'_> {
184        let block_ptr = self.as_mut_ptr();
185        GradientsMutIter {
186            parameters: self.as_ref().gradient_list().into_iter(),
187            block: block_ptr,
188        }
189    }
190
191    /// Save the given block to the file at `path`
192    ///
193    /// This is a convenience function calling [`crate::io::save_block`]
194    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
195        self.as_ref().save(path)
196    }
197
198    /// Save the given block to an in-memory buffer
199    ///
200    /// This is a convenience function calling [`crate::io::save_block_buffer`]
201    pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
202        self.as_ref().save_buffer(buffer)
203    }
204}
205
206/// Iterator over parameter/[`TensorBlockRefMut`] pairs for all gradients in a
207/// [`TensorBlockRefMut`]
208pub struct GradientsMutIter<'a> {
209    parameters: std::vec::IntoIter<&'a str>,
210    block: *mut mts_block_t,
211}
212
213impl<'a> Iterator for GradientsMutIter<'a> {
214    type Item = (&'a str, TensorBlockRefMut<'a>);
215
216    #[inline]
217    fn next(&mut self) -> Option<Self::Item> {
218        self.parameters.next().map(|parameter| {
219            let parameter_c = CString::new(parameter).expect("invalid C string");
220            let block = block_gradient(self.block, &parameter_c).expect("missing gradient");
221
222            // SAFETY: all blocks are disjoint, and we are only returning a
223            // mutable reference to each once. The reference lifetime is
224            // constrained by the lifetime of the parent TensorBlockRefMut
225            let block = unsafe { TensorBlockRefMut::from_raw(block) };
226            return (parameter, block);
227        })
228    }
229
230    fn size_hint(&self) -> (usize, Option<usize>) {
231        self.parameters.size_hint()
232    }
233}
234
235impl ExactSizeIterator for GradientsMutIter<'_> {
236    #[inline]
237    fn len(&self) -> usize {
238        self.parameters.len()
239    }
240}
241
242impl FusedIterator for GradientsMutIter<'_> {}
243
#[cfg(test)]
mod tests {
    use crate::{Labels, TensorBlock};

    #[test]
    // the values are stored and read back, never computed, so exact float
    // equality is what this test wants to check
    #[allow(clippy::float_cmp)]
    fn gradients() {
        // (parameter, value filling the gradient), in insertion order
        let gradients = [("g", -1.0), ("f", -2.0)];

        let properties = Labels::new(["p"], [[-2], [0], [1]]);
        let mut tensor_block = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], 1.0),
            &Labels::new(["s"], [[0], [1]]), &[], &properties,
        ).unwrap();

        for (parameter, fill) in gradients {
            let gradient = TensorBlock::new(
                ndarray::Array::from_elem(vec![2, 3], fill),
                &Labels::new(["sample"], [[0], [1]]), &[], &properties,
            ).unwrap();
            tensor_block.add_gradient(parameter, gradient).unwrap();
        }

        let mut block = tensor_block.as_ref_mut();
        for (parameter, expected) in gradients {
            let gradient = block.gradient_mut(parameter).unwrap();
            let values = gradient.values().to_ndarray_lock::<f64>().read().unwrap();
            assert_eq!(values[[0, 0]], expected);
        }

        assert!(block.gradient_mut("h").is_none());

        let mut iter = block.gradients_mut();
        assert_eq!(iter.len(), 2);
        let parameters: Vec<&str> = iter.by_ref().map(|(parameter, _)| parameter).collect();
        assert_eq!(parameters, ["g", "f"]);
        assert!(iter.next().is_none());
    }

    #[test]
    fn block_data() {
        let samples = Labels::new(["samples"], [[0], [1]]);
        let components = [Labels::new(["component"], [[0]])];
        let properties = Labels::new(["properties"], [[-2], [0], [1]]);
        let expected_values = ndarray::Array::from_elem(vec![2, 1, 3], 1.0);

        let mut tensor_block = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 1, 3], 1.0),
            &samples,
            &components,
            &properties,
        ).unwrap();
        let mut block = tensor_block.as_ref_mut();

        // accessors on the TensorBlockRefMut itself
        {
            let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
            assert_eq!(*values, expected_values);
            assert_eq!(block.samples(), samples);
            assert_eq!(block.components(), components);
            assert_eq!(block.properties(), properties);
        }

        // the same data through the split-borrow struct
        let data = block.data_mut();
        let values = data.values.to_ndarray_lock::<f64>().read().unwrap();
        assert_eq!(*values, expected_values);
        assert_eq!(*data.samples, samples);
        assert_eq!(*data.components, components);
        assert_eq!(*data.properties, properties);
    }

    #[test]
    fn device_and_dtype() {
        let mut tensor_block = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], 1.0),
            &Labels::new(["samples"], [[0], [1]]),
            &[],
            &Labels::new(["properties"], [[-2], [0], [1]]),
        ).unwrap();
        let block = tensor_block.as_ref_mut();

        // a block built from this ndarray is expected to be on the CPU and
        // to hold 64-bit floats
        assert_eq!(block.device().unwrap().device_type, dlpk::sys::DLDeviceType::kDLCPU);

        let dtype = block.dtype().unwrap();
        assert_eq!(dtype.code, dlpk::sys::DLDataTypeCode::kDLFloat);
        assert_eq!(dtype.bits, 64);
    }
}