Skip to main content

metatensor/block/
block_ref.rs

1use std::ffi::{CStr, CString};
2use std::iter::FusedIterator;
3
4use crate::c_api::{mts_block_t, mts_array_t};
5use crate::c_api::MTS_INVALID_PARAMETER_ERROR;
6
7use crate::errors::check_status;
8use crate::{ArrayRef, Labels, Error};
9
10use super::{TensorBlock, LazyMetadata};
11
12/// Reference to a [`TensorBlock`]
13#[derive(Debug, Clone, Copy)]
14pub struct TensorBlockRef<'a> {
15    ptr: *const mts_block_t,
16    marker: std::marker::PhantomData<&'a mts_block_t>,
17}
18
19// SAFETY: Send is fine since TensorBlockRef does not implement Drop
20unsafe impl Send for TensorBlockRef<'_> {}
21// SAFETY: Sync is fine since there is no internal mutability in TensorBlockRef
22unsafe impl Sync for TensorBlockRef<'_> {}
23
24/// All the basic data in a `TensorBlockRef` as a struct with separate fields.
25///
26/// This can be useful when you need to borrow different fields on this struct
27/// separately. They are separate in the underlying metatensor-core, but since
28/// we go through the C API to access them, we need to re-expose them as
29/// separate fields for the rust compiler to be able to understand that.
30///
31/// The metadata is initialized lazily on first access, to not pay the cost of
32/// allocation/reference count increase if some metadata is not used.
33#[derive(Debug)]
34pub struct TensorBlockData<'a> {
35    pub values: ArrayRef<'a>,
36    pub samples: LazyMetadata<Labels>,
37    pub components: LazyMetadata<Vec<Labels>>,
38    pub properties: LazyMetadata<Labels>,
39}
40
41impl<'a> TensorBlockRef<'a> {
42    /// Create a new `TensorBlockRef` from the given raw `mts_block_t`
43    ///
44    /// This is a **VERY** unsafe function, creating a lifetime out of thin air.
45    /// Make sure the lifetime is actually constrained by the lifetime of the
46    /// owner of this `mts_block_t`.
47    pub(crate) unsafe fn from_raw(ptr: *const mts_block_t) -> TensorBlockRef<'a> {
48        assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
49
50        TensorBlockRef {
51            ptr: ptr,
52            marker: std::marker::PhantomData,
53        }
54    }
55
56    /// Get the underlying raw pointer
57    pub(crate) fn as_ptr(&self) -> *const mts_block_t {
58        self.ptr
59    }
60}
61
62/// Get a gradient from this block
63fn block_gradient(block: *const mts_block_t, parameter: &CStr) -> Option<*const mts_block_t> {
64    let mut gradient_block = std::ptr::null_mut();
65    let status = unsafe { crate::c_api::mts_block_gradient(
66            // the cast to mut pointer is fine since we are only returning a
67            // non-mut mts_block_t below
68            block.cast_mut(),
69            parameter.as_ptr(),
70            &mut gradient_block
71        )
72    };
73
74    match crate::errors::check_status(status) {
75        Ok(()) => Some(gradient_block.cast_const()),
76        Err(error) => {
77            if error.code == Some(MTS_INVALID_PARAMETER_ERROR) {
78                // there is no array for this gradient
79                None
80            } else {
81                panic!("failed to get the gradient from a block: {:?}", error)
82            }
83        }
84    }
85}
86
87pub(super) fn get_samples(ptr: *const mts_block_t) -> Labels {
88    unsafe {
89        TensorBlockRef::from_raw(ptr).samples()
90    }
91}
92
93pub(super) fn get_components(ptr: *const mts_block_t) -> Vec<Labels> {
94    unsafe {
95        TensorBlockRef::from_raw(ptr).components()
96    }
97}
98
99pub(super) fn get_properties(ptr: *const mts_block_t) -> Labels {
100    unsafe {
101        TensorBlockRef::from_raw(ptr).properties()
102    }
103}
104
105impl<'a> TensorBlockRef<'a> {
106    /// Get all the data and metadata inside this `TensorBlockRef` as a
107    /// struct with separate fields, to allow borrowing them separately.
108    #[inline]
109    pub fn data(&'a self) -> TensorBlockData<'a> {
110        TensorBlockData {
111            values: self.values(),
112            samples: LazyMetadata::new(get_samples, self.as_ptr()),
113            components: LazyMetadata::new(get_components, self.as_ptr()),
114            properties: LazyMetadata::new(get_properties, self.as_ptr()),
115        }
116    }
117
118    /// Get the array for the values in this block
119    #[inline]
120    pub fn values(&self) -> ArrayRef<'a> {
121        let mut array = mts_array_t::null();
122        unsafe {
123            crate::errors::check_status(crate::c_api::mts_block_data(
124                self.as_ptr().cast_mut(),
125                &mut array
126            )).expect("failed to get the array for a block");
127        };
128
129        // SAFETY: we can return an `ArrayRef` with lifetime `'a` (instead of
130        // `'self`) (which allows to get multiple references to the BasicBlock
131        // simultaneously), because there is no way to also get a mutable
132        // reference to the block at the same time (since we are already holding
133        // a const reference to the block itself).
134        unsafe { ArrayRef::from_raw(array) }
135    }
136
137    #[inline]
138    fn labels(&self, dimension: usize) -> Labels {
139        let ptr = unsafe {
140            crate::c_api::mts_block_labels(self.as_ptr(), dimension)
141        };
142        crate::errors::check_ptr(ptr).expect("failed to get labels");
143
144        unsafe { Labels::from_raw(ptr) }
145    }
146
147    /// Get the samples for this block
148    #[inline]
149    pub fn samples(&self) -> Labels {
150        return self.labels(0);
151    }
152
153    /// Get the components for this block
154    #[inline]
155    pub fn components(&self) -> Vec<Labels> {
156        let values = self.values();
157        let shape = values.shape().expect("failed to get the data shape");
158
159        let mut result = Vec::new();
160        for i in 1..(shape.len() - 1) {
161            result.push(self.labels(i));
162        }
163        return result;
164    }
165
166    /// Get the properties for this block
167    #[inline]
168    pub fn properties(&self) -> Labels {
169        let values = self.values();
170        let shape = values.shape().expect("failed to get the data shape");
171
172        return self.labels(shape.len() - 1);
173    }
174
175    /// Get the full list of gradients in this block
176    #[inline]
177    pub fn gradient_list(&self) -> Vec<&'a str> {
178        let mut parameters_ptr = std::ptr::null();
179        let mut parameters_count = 0;
180        unsafe {
181            check_status(crate::c_api::mts_block_gradients_list(
182                self.as_ptr(),
183                &mut parameters_ptr,
184                &mut parameters_count
185            )).expect("failed to get gradient list");
186        }
187
188        if parameters_count == 0 {
189            return Vec::new();
190        } else {
191            assert!(!parameters_ptr.is_null());
192            // SAFETY: we can return strings with the `'a` lifetime (instead of
193            // `'self`), because there is no way to also get a mutable reference
194            // to the gradient parameters at the same time.
195            unsafe {
196                let parameters = std::slice::from_raw_parts(parameters_ptr, parameters_count);
197                return parameters.iter()
198                    .map(|&ptr| CStr::from_ptr(ptr).to_str().unwrap())
199                    .collect();
200            }
201        }
202    }
203
204    /// Get the data and metadata for the gradient with respect to the given
205    /// parameter in this block, if it exists.
206    #[inline]
207    pub fn gradient(&self, parameter: &str) -> Option<TensorBlockRef<'a>> {
208        // SAFETY: we can return a TensorBlockRef with lifetime `'a` (instead of
209        // `'self`) for the same reasons as in the `values` function.
210        let parameter = CString::new(parameter).expect("invalid C string");
211
212        block_gradient(self.as_ptr(), &parameter)
213            .map(|gradient_block| {
214                // SAFETY: the lifetime of the block is the same as
215                // the lifetime of self, both are constrained to the
216                // root TensorMap/TensorBlock
217                unsafe { TensorBlockRef::from_raw(gradient_block) }
218        })
219    }
220
221    /// Clone this block, cloning all the data and metadata contained inside.
222    ///
223    /// This can fail if the external data held inside an `mts_array_t` can not
224    /// be cloned.
225    #[inline]
226    pub fn try_clone(&self) -> Result<TensorBlock, Error> {
227        let ptr = unsafe {
228            crate::c_api::mts_block_copy(self.as_ptr())
229        };
230        crate::errors::check_ptr(ptr)?;
231
232        return Ok(unsafe { TensorBlock::from_raw(ptr) });
233    }
234
235    /// Get an iterator over parameter/[`TensorBlockRef`] pairs for all gradients in
236    /// this block
237    #[inline]
238    pub fn gradients(&self) -> GradientsIter<'_> {
239        GradientsIter {
240            parameters: self.gradient_list().into_iter(),
241            block: self.as_ptr(),
242        }
243    }
244
245    /// Save the given block to the file at `path`
246    ///
247    /// This is a convenience function calling [`crate::io::save_block`]
248    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
249        return crate::io::save_block(path, *self);
250    }
251
252    /// Save the given block to an in-memory buffer
253    ///
254    /// This is a convenience function calling [`crate::io::save_block_buffer`]
255    pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
256        return crate::io::save_block_buffer(*self, buffer);
257    }
258}
259
260/// Iterator over parameter/[`TensorBlockRef`] pairs for all gradients in a
261/// [`TensorBlockRef`]
262pub struct GradientsIter<'a> {
263    parameters: std::vec::IntoIter<&'a str>,
264    block: *const mts_block_t,
265}
266
267impl<'a> Iterator for GradientsIter<'a> {
268    type Item = (&'a str, TensorBlockRef<'a>);
269
270    #[inline]
271    fn next(&mut self) -> Option<Self::Item> {
272        self.parameters.next().map(|parameter| {
273            let parameter_c = CString::new(parameter).expect("invalid C string");
274            let block = block_gradient(self.block, &parameter_c).expect("missing gradient");
275
276            // SAFETY: the lifetime of the block is the same as the lifetime of
277            // the GradientsIter, both are constrained to the root
278            // TensorMap/TensorBlock
279            let block = unsafe { TensorBlockRef::from_raw(block) };
280            return (parameter, block);
281        })
282    }
283
284    fn size_hint(&self) -> (usize, Option<usize>) {
285        (self.len(), Some(self.len()))
286    }
287}
288
289impl ExactSizeIterator for GradientsIter<'_> {
290    #[inline]
291    fn len(&self) -> usize {
292        self.parameters.len()
293    }
294}
295
296impl FusedIterator for GradientsIter<'_> {}
297
#[cfg(test)]
mod tests {
    use crate::{Labels, TensorBlock};

    #[test]
    #[allow(clippy::float_cmp)]
    fn gradients() {
        let properties = Labels::new(["p"], &[[-2], [0], [1]]);
        let mut block = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], 1.0),
            &Labels::new(["s"], &[[0], [1]]), &[], &properties,
        ).unwrap();

        // each gradient is filled with a different constant, to tell them apart
        let gradients = [("g", -1.0), ("f", -2.0)];

        for (parameter, fill) in gradients {
            let gradient = TensorBlock::new(
                ndarray::Array::from_elem(vec![2, 3], fill),
                &Labels::new(["sample"], &[[0], [1]]), &[], &properties,
            ).unwrap();
            block.add_gradient(parameter, gradient).unwrap();
        }

        let block = block.as_ref();
        for (parameter, expected) in gradients {
            let gradient = block.gradient(parameter).unwrap();
            let values = gradient.values().to_ndarray_lock::<f64>().read().unwrap();
            assert_eq!(values[[0, 0]], expected);
        }

        assert!(block.gradient("h").is_none());

        // gradients are iterated over in insertion order
        let mut iter = block.gradients();
        assert_eq!(iter.len(), 2);
        let order: Vec<&str> = iter.by_ref().map(|(parameter, _)| parameter).collect();
        assert_eq!(order, ["g", "f"]);
        assert!(iter.next().is_none());
    }

    #[test]
    fn block_data() {
        let samples = Labels::new(["samples"], &[[0], [1]]);
        let components = [Labels::new(["component"], &[[0]])];
        let properties = Labels::new(["properties"], &[[-2], [0], [1]]);

        let block = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 1, 3], 1.0),
            &samples,
            &components,
            &properties,
        ).unwrap();
        let block = block.as_ref();

        // access through the methods of TensorBlockRef
        let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
        assert_eq!(*values, ndarray::Array::from_elem(vec![2, 1, 3], 1.0));
        assert_eq!(block.samples(), samples);
        assert_eq!(block.components(), components);
        assert_eq!(block.properties(), properties);

        // access through the separate fields of TensorBlockData
        let data = block.data();
        let values = data.values.to_ndarray_lock::<f64>().read().unwrap();
        assert_eq!(*values, ndarray::Array::from_elem(vec![2, 1, 3], 1.0));
        assert_eq!(*data.samples, samples);
        assert_eq!(*data.components, components);
        assert_eq!(*data.properties, properties);
    }
}