metatensor/block/
block_ref.rs

1use std::ffi::{CStr, CString};
2use std::iter::FusedIterator;
3
4use crate::c_api::{mts_block_t, mts_array_t, mts_labels_t};
5use crate::c_api::MTS_INVALID_PARAMETER_ERROR;
6
7use crate::errors::check_status;
8use crate::{ArrayRef, Labels, Error};
9
10use super::{TensorBlock, LazyMetadata};
11
12/// Reference to a [`TensorBlock`]
13#[derive(Debug, Clone, Copy)]
14pub struct TensorBlockRef<'a> {
15    ptr: *const mts_block_t,
16    marker: std::marker::PhantomData<&'a mts_block_t>,
17}
18
19// SAFETY: Send is fine since TensorBlockRef does not implement Drop
20unsafe impl Send for TensorBlockRef<'_> {}
21// SAFETY: Sync is fine since there is no internal mutability in TensorBlockRef
22unsafe impl Sync for TensorBlockRef<'_> {}
23
24/// All the basic data in a `TensorBlockRef` as a struct with separate fields.
25///
26/// This can be useful when you need to borrow different fields on this struct
27/// separately. They are separate in the underlying metatensor-core, but since we
28/// go through the C API to access them, we need to re-expose them as separate
29/// fields for the rust compiler to be able to understand that.
30///
31/// The metadata is initialized lazily on first access, to not pay the cost of
32/// allocation/reference count increase if some metadata is not used.
33#[derive(Debug)]
34pub struct TensorBlockData<'a> {
35    pub values: ArrayRef<'a>,
36    pub samples: LazyMetadata<Labels>,
37    pub components: LazyMetadata<Vec<Labels>>,
38    pub properties: LazyMetadata<Labels>,
39}
40
41impl<'a> TensorBlockRef<'a> {
42    /// Create a new `TensorBlockRef` from the given raw `mts_block_t`
43    ///
44    /// This is a **VERY** unsafe function, creating a lifetime out of thin air.
45    /// Make sure the lifetime is actually constrained by the lifetime of the
46    /// owner of this `mts_block_t`.
47    pub(crate) unsafe fn from_raw(ptr: *const mts_block_t) -> TensorBlockRef<'a> {
48        assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
49
50        TensorBlockRef {
51            ptr: ptr,
52            marker: std::marker::PhantomData,
53        }
54    }
55
56    /// Get the underlying raw pointer
57    pub(crate) fn as_ptr(&self) -> *const mts_block_t {
58        self.ptr
59    }
60}
61
62/// Get a gradient from this block
63fn block_gradient(block: *const mts_block_t, parameter: &CStr) -> Option<*const mts_block_t> {
64    let mut gradient_block = std::ptr::null_mut();
65    let status = unsafe { crate::c_api::mts_block_gradient(
66            // the cast to mut pointer is fine since we are only returning a
67            // non-mut mts_block_t below
68            block.cast_mut(),
69            parameter.as_ptr(),
70            &mut gradient_block
71        )
72    };
73
74    match crate::errors::check_status(status) {
75        Ok(()) => Some(gradient_block.cast_const()),
76        Err(error) => {
77            if error.code == Some(MTS_INVALID_PARAMETER_ERROR) {
78                // there is no array for this gradient
79                None
80            } else {
81                panic!("failed to get the gradient from a block: {:?}", error)
82            }
83        }
84    }
85}
86
87pub(super) fn get_samples(ptr: *const mts_block_t) -> Labels {
88    unsafe {
89        TensorBlockRef::from_raw(ptr).samples()
90    }
91}
92
93pub(super) fn get_components(ptr: *const mts_block_t) -> Vec<Labels> {
94    unsafe {
95        TensorBlockRef::from_raw(ptr).components()
96    }
97}
98
99pub(super) fn get_properties(ptr: *const mts_block_t) -> Labels {
100    unsafe {
101        TensorBlockRef::from_raw(ptr).properties()
102    }
103}
104
105impl<'a> TensorBlockRef<'a> {
106    /// Get all the data and metadata inside this `TensorBlockRef` as a
107    /// struct with separate fields, to allow borrowing them separately.
108    #[inline]
109    pub fn data(&'a self) -> TensorBlockData<'a> {
110        TensorBlockData {
111            values: self.values(),
112            samples: LazyMetadata::new(get_samples, self.as_ptr()),
113            components: LazyMetadata::new(get_components, self.as_ptr()),
114            properties: LazyMetadata::new(get_properties, self.as_ptr()),
115        }
116    }
117
118    /// Get the array for the values in this block
119    #[inline]
120    pub fn values(&self) -> ArrayRef<'a> {
121        let mut array = mts_array_t::null();
122        unsafe {
123            crate::errors::check_status(crate::c_api::mts_block_data(
124                self.as_ptr().cast_mut(),
125                &mut array
126            )).expect("failed to get the array for a block");
127        };
128
129        // SAFETY: we can return an `ArrayRef` with lifetime `'a` (instead of
130        // `'self`) (which allows to get multiple references to the BasicBlock
131        // simultaneously), because there is no way to also get a mutable
132        // reference to the block at the same time (since we are already holding
133        // a const reference to the block itself).
134        unsafe { ArrayRef::from_raw(array) }
135    }
136
137    #[inline]
138    fn labels(&self, dimension: usize) -> Labels {
139        let mut labels = mts_labels_t::null();
140        unsafe {
141            check_status(crate::c_api::mts_block_labels(
142                self.as_ptr(),
143                dimension,
144                &mut labels,
145            )).expect("failed to get labels");
146        }
147        return unsafe { Labels::from_raw(labels) };
148    }
149
150    /// Get the samples for this block
151    #[inline]
152    pub fn samples(&self) -> Labels {
153        return self.labels(0);
154    }
155
156    /// Get the components for this block
157    #[inline]
158    pub fn components(&self) -> Vec<Labels> {
159        let values = self.values();
160        let shape = values.as_raw().shape().expect("failed to get the data shape");
161
162        let mut result = Vec::new();
163        for i in 1..(shape.len() - 1) {
164            result.push(self.labels(i));
165        }
166        return result;
167    }
168
169    /// Get the properties for this block
170    #[inline]
171    pub fn properties(&self) -> Labels {
172        let values = self.values();
173        let shape = values.as_raw().shape().expect("failed to get the data shape");
174
175        return self.labels(shape.len() - 1);
176    }
177
178    /// Get the full list of gradients in this block
179    #[inline]
180    pub fn gradient_list(&self) -> Vec<&'a str> {
181        let mut parameters_ptr = std::ptr::null();
182        let mut parameters_count = 0;
183        unsafe {
184            check_status(crate::c_api::mts_block_gradients_list(
185                self.as_ptr(),
186                &mut parameters_ptr,
187                &mut parameters_count
188            )).expect("failed to get gradient list");
189        }
190
191        if parameters_count == 0 {
192            return Vec::new();
193        } else {
194            assert!(!parameters_ptr.is_null());
195            // SAFETY: we can return strings with the `'a` lifetime (instead of
196            // `'self`), because there is no way to also get a mutable reference
197            // to the gradient parameters at the same time.
198            unsafe {
199                let parameters = std::slice::from_raw_parts(parameters_ptr, parameters_count);
200                return parameters.iter()
201                    .map(|&ptr| CStr::from_ptr(ptr).to_str().unwrap())
202                    .collect();
203            }
204        }
205    }
206
207    /// Get the data and metadata for the gradient with respect to the given
208    /// parameter in this block, if it exists.
209    #[inline]
210    pub fn gradient(&self, parameter: &str) -> Option<TensorBlockRef<'a>> {
211        // SAFETY: we can return a TensorBlockRef with lifetime `'a` (instead of
212        // `'self`) for the same reasons as in the `values` function.
213        let parameter = CString::new(parameter).expect("invalid C string");
214
215        block_gradient(self.as_ptr(), &parameter)
216            .map(|gradient_block| {
217                // SAFETY: the lifetime of the block is the same as
218                // the lifetime of self, both are constrained to the
219                // root TensorMap/TensorBlock
220                unsafe { TensorBlockRef::from_raw(gradient_block) }
221        })
222    }
223
224    /// Clone this block, cloning all the data and metadata contained inside.
225    ///
226    /// This can fail if the external data held inside an `mts_array_t` can not
227    /// be cloned.
228    #[inline]
229    pub fn try_clone(&self) -> Result<TensorBlock, Error> {
230        let ptr = unsafe {
231            crate::c_api::mts_block_copy(self.as_ptr())
232        };
233        crate::errors::check_ptr(ptr)?;
234
235        return Ok(unsafe { TensorBlock::from_raw(ptr) });
236    }
237
238    /// Get an iterator over parameter/[`TensorBlockRef`] pairs for all gradients in
239    /// this block
240    #[inline]
241    pub fn gradients(&self) -> GradientsIter<'_> {
242        GradientsIter {
243            parameters: self.gradient_list().into_iter(),
244            block: self.as_ptr(),
245        }
246    }
247
248    /// Save the given block to the file at `path`
249    ///
250    /// This is a convenience function calling [`crate::io::save_block`]
251    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
252        return crate::io::save_block(path, *self);
253    }
254
255    /// Save the given block to an in-memory buffer
256    ///
257    /// This is a convenience function calling [`crate::io::save_block_buffer`]
258    pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
259        return crate::io::save_block_buffer(*self, buffer);
260    }
261}
262
263/// Iterator over parameter/[`TensorBlockRef`] pairs for all gradients in a
264/// [`TensorBlockRef`]
265pub struct GradientsIter<'a> {
266    parameters: std::vec::IntoIter<&'a str>,
267    block: *const mts_block_t,
268}
269
270impl<'a> Iterator for GradientsIter<'a> {
271    type Item = (&'a str, TensorBlockRef<'a>);
272
273    #[inline]
274    fn next(&mut self) -> Option<Self::Item> {
275        self.parameters.next().map(|parameter| {
276            let parameter_c = CString::new(parameter).expect("invalid C string");
277            let block = block_gradient(self.block, &parameter_c).expect("missing gradient");
278
279            // SAFETY: the lifetime of the block is the same as the lifetime of
280            // the GradientsIter, both are constrained to the root
281            // TensorMap/TensorBlock
282            let block = unsafe { TensorBlockRef::from_raw(block) };
283            return (parameter, block);
284        })
285    }
286
287    fn size_hint(&self) -> (usize, Option<usize>) {
288        (self.len(), Some(self.len()))
289    }
290}
291
292impl ExactSizeIterator for GradientsIter<'_> {
293    #[inline]
294    fn len(&self) -> usize {
295        self.parameters.len()
296    }
297}
298
299impl FusedIterator for GradientsIter<'_> {}
300
#[cfg(test)]
mod tests {
    // TODO: add tests for the gradient accessors and for the gradients
    // iterator (`GradientsIter`)
}