// Source file: metatensor/block/block_ref.rs

1use std::ffi::{CStr, CString};
2use std::iter::FusedIterator;
3
4use crate::c_api::{mts_block_t, mts_array_t};
5use crate::c_api::MTS_INVALID_PARAMETER_ERROR;
6
7use crate::errors::check_status;
8use crate::{ArrayRef, Labels, Error};
9
10use super::{TensorBlock, LazyMetadata};
11
/// Reference to a [`TensorBlock`]
///
/// This is a small `Copy` handle wrapping a raw `*const mts_block_t`; the
/// `'a` lifetime ties it to the owner of the underlying block (see
/// [`TensorBlockRef::from_raw`]).
#[derive(Debug, Clone, Copy)]
pub struct TensorBlockRef<'a> {
    // raw pointer to the block; `from_raw` asserts that it is not NULL
    ptr: *const mts_block_t,
    // carries the `'a` lifetime, so this type behaves like `&'a mts_block_t`
    marker: std::marker::PhantomData<&'a mts_block_t>,
}
18
19// SAFETY: Send is fine since TensorBlockRef does not implement Drop
20unsafe impl Send for TensorBlockRef<'_> {}
21// SAFETY: Sync is fine since there is no internal mutability in TensorBlockRef
22unsafe impl Sync for TensorBlockRef<'_> {}
23
/// All the basic data in a `TensorBlockRef` as a struct with separate fields.
///
/// This can be useful when you need to borrow different fields on this struct
/// separately. They are separate in the underlying metatensor-core, but since
/// we go through the C API to access them, we need to re-expose them as
/// separate fields for the rust compiler to be able to understand that.
///
/// The metadata is initialized lazily on first access, to not pay the cost of
/// allocation/reference count increase if some metadata is not used.
#[derive(Debug)]
pub struct TensorBlockData<'a> {
    /// Array containing the values of the block
    pub values: ArrayRef<'a>,
    /// Labels for the first dimension of `values`, loaded on first access
    pub samples: LazyMetadata<Labels>,
    /// Labels for the dimensions between samples and properties (may be
    /// empty), loaded on first access
    pub components: LazyMetadata<Vec<Labels>>,
    /// Labels for the last dimension of `values`, loaded on first access
    pub properties: LazyMetadata<Labels>,
}
40
41impl<'a> TensorBlockRef<'a> {
42    /// Create a new `TensorBlockRef` from the given raw `mts_block_t`
43    ///
44    /// This is a **VERY** unsafe function, creating a lifetime out of thin air.
45    /// Make sure the lifetime is actually constrained by the lifetime of the
46    /// owner of this `mts_block_t`.
47    pub unsafe fn from_raw(ptr: *const mts_block_t) -> TensorBlockRef<'a> {
48        assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
49
50        TensorBlockRef {
51            ptr: ptr,
52            marker: std::marker::PhantomData,
53        }
54    }
55
56    /// Get the underlying raw pointer
57    pub fn as_ptr(&self) -> *const mts_block_t {
58        self.ptr
59    }
60}
61
62/// Get a gradient from this block
63fn block_gradient(block: *const mts_block_t, parameter: &CStr) -> Option<*const mts_block_t> {
64    let mut gradient_block = std::ptr::null_mut();
65    let status = unsafe { crate::c_api::mts_block_gradient(
66            // the cast to mut pointer is fine since we are only returning a
67            // non-mut mts_block_t below
68            block.cast_mut(),
69            parameter.as_ptr(),
70            &mut gradient_block
71        )
72    };
73
74    match crate::errors::check_status(status) {
75        Ok(()) => Some(gradient_block.cast_const()),
76        Err(error) => {
77            if error.code == Some(MTS_INVALID_PARAMETER_ERROR) {
78                // there is no array for this gradient
79                None
80            } else {
81                panic!("failed to get the gradient from a block: {:?}", error)
82            }
83        }
84    }
85}
86
87pub(super) fn get_samples(ptr: *const mts_block_t) -> Labels {
88    unsafe {
89        TensorBlockRef::from_raw(ptr).samples()
90    }
91}
92
93pub(super) fn get_components(ptr: *const mts_block_t) -> Vec<Labels> {
94    unsafe {
95        TensorBlockRef::from_raw(ptr).components()
96    }
97}
98
99pub(super) fn get_properties(ptr: *const mts_block_t) -> Labels {
100    unsafe {
101        TensorBlockRef::from_raw(ptr).properties()
102    }
103}
104
105impl<'a> TensorBlockRef<'a> {
106    /// Get the device on which the values of this block are stored.
107    #[inline]
108    pub fn device(&self) -> Result<dlpk::sys::DLDevice, Error> {
109        let mut device = dlpk::sys::DLDevice::cpu();
110        unsafe {
111            check_status(crate::c_api::mts_block_device(
112                self.as_ptr(),
113                &mut device,
114            ))?;
115        }
116        return Ok(device);
117    }
118
119    /// Get the data type of the values of this block.
120    #[inline]
121    pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
122        let mut dtype = dlpk::sys::DLDataType {
123            code: dlpk::sys::DLDataTypeCode::kDLFloat,
124            bits: 0,
125            lanes: 0,
126        };
127        unsafe {
128            check_status(crate::c_api::mts_block_dtype(
129                self.as_ptr(),
130                &mut dtype,
131            ))?;
132        }
133        return Ok(dtype);
134    }
135
136    /// Get all the data and metadata inside this `TensorBlockRef` as a
137    /// struct with separate fields, to allow borrowing them separately.
138    #[inline]
139    pub fn data(&'a self) -> TensorBlockData<'a> {
140        TensorBlockData {
141            values: self.values(),
142            samples: LazyMetadata::new(get_samples, self.as_ptr()),
143            components: LazyMetadata::new(get_components, self.as_ptr()),
144            properties: LazyMetadata::new(get_properties, self.as_ptr()),
145        }
146    }
147
148    /// Get the array for the values in this block
149    #[inline]
150    pub fn values(&self) -> ArrayRef<'a> {
151        let mut array = mts_array_t::null();
152        unsafe {
153            crate::errors::check_status(crate::c_api::mts_block_data(
154                self.as_ptr().cast_mut(),
155                &mut array
156            )).expect("failed to get the array for a block");
157        };
158
159        // SAFETY: we can return an `ArrayRef` with lifetime `'a` (instead of
160        // `'self`) (which allows to get multiple references to the BasicBlock
161        // simultaneously), because there is no way to also get a mutable
162        // reference to the block at the same time (since we are already holding
163        // a const reference to the block itself).
164        unsafe { ArrayRef::from_raw(array) }
165    }
166
167    #[inline]
168    fn labels(&self, dimension: usize) -> Labels {
169        let ptr = unsafe {
170            crate::c_api::mts_block_labels(self.as_ptr(), dimension)
171        };
172        crate::errors::check_ptr(ptr).expect("failed to get labels");
173
174        unsafe { Labels::from_raw(ptr) }
175    }
176
177    /// Get the samples for this block
178    #[inline]
179    pub fn samples(&self) -> Labels {
180        return self.labels(0);
181    }
182
183    /// Get the components for this block
184    #[inline]
185    pub fn components(&self) -> Vec<Labels> {
186        let values = self.values();
187        let shape = values.shape().expect("failed to get the data shape");
188
189        let mut result = Vec::new();
190        for i in 1..(shape.len() - 1) {
191            result.push(self.labels(i));
192        }
193        return result;
194    }
195
196    /// Get the properties for this block
197    #[inline]
198    pub fn properties(&self) -> Labels {
199        let values = self.values();
200        let shape = values.shape().expect("failed to get the data shape");
201
202        return self.labels(shape.len() - 1);
203    }
204
205    /// Get the full list of gradients in this block
206    #[inline]
207    pub fn gradient_list(&self) -> Vec<&'a str> {
208        let mut parameters_ptr = std::ptr::null();
209        let mut parameters_count = 0;
210        unsafe {
211            check_status(crate::c_api::mts_block_gradients_list(
212                self.as_ptr(),
213                &mut parameters_ptr,
214                &mut parameters_count
215            )).expect("failed to get gradient list");
216        }
217
218        if parameters_count == 0 {
219            return Vec::new();
220        } else {
221            assert!(!parameters_ptr.is_null());
222            // SAFETY: we can return strings with the `'a` lifetime (instead of
223            // `'self`), because there is no way to also get a mutable reference
224            // to the gradient parameters at the same time.
225            unsafe {
226                let parameters = std::slice::from_raw_parts(parameters_ptr, parameters_count);
227                return parameters.iter()
228                    .map(|&ptr| CStr::from_ptr(ptr).to_str().unwrap())
229                    .collect();
230            }
231        }
232    }
233
234    /// Get the data and metadata for the gradient with respect to the given
235    /// parameter in this block, if it exists.
236    #[inline]
237    pub fn gradient(&self, parameter: &str) -> Option<TensorBlockRef<'a>> {
238        // SAFETY: we can return a TensorBlockRef with lifetime `'a` (instead of
239        // `'self`) for the same reasons as in the `values` function.
240        let parameter = CString::new(parameter).expect("invalid C string");
241
242        block_gradient(self.as_ptr(), &parameter)
243            .map(|gradient_block| {
244                // SAFETY: the lifetime of the block is the same as
245                // the lifetime of self, both are constrained to the
246                // root TensorMap/TensorBlock
247                unsafe { TensorBlockRef::from_raw(gradient_block) }
248        })
249    }
250
251    /// Clone this block, cloning all the data and metadata contained inside.
252    ///
253    /// This can fail if the external data held inside an `mts_array_t` can not
254    /// be cloned.
255    #[inline]
256    pub fn try_clone(&self) -> Result<TensorBlock, Error> {
257        let ptr = unsafe {
258            crate::c_api::mts_block_copy(self.as_ptr())
259        };
260        crate::errors::check_ptr(ptr)?;
261
262        return Ok(unsafe { TensorBlock::from_raw(ptr) });
263    }
264
265    /// Get an iterator over parameter/[`TensorBlockRef`] pairs for all gradients in
266    /// this block
267    #[inline]
268    pub fn gradients(&self) -> GradientsIter<'_> {
269        GradientsIter {
270            parameters: self.gradient_list().into_iter(),
271            block: self.as_ptr(),
272        }
273    }
274
275    /// Save the given block to the file at `path`
276    ///
277    /// This is a convenience function calling [`crate::io::save_block`]
278    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
279        return crate::io::save_block(path, *self);
280    }
281
282    /// Save the given block to an in-memory buffer
283    ///
284    /// This is a convenience function calling [`crate::io::save_block_buffer`]
285    pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
286        return crate::io::save_block_buffer(*self, buffer);
287    }
288}
289
290/// Iterator over parameter/[`TensorBlockRef`] pairs for all gradients in a
291/// [`TensorBlockRef`]
292pub struct GradientsIter<'a> {
293    parameters: std::vec::IntoIter<&'a str>,
294    block: *const mts_block_t,
295}
296
297impl<'a> Iterator for GradientsIter<'a> {
298    type Item = (&'a str, TensorBlockRef<'a>);
299
300    #[inline]
301    fn next(&mut self) -> Option<Self::Item> {
302        self.parameters.next().map(|parameter| {
303            let parameter_c = CString::new(parameter).expect("invalid C string");
304            let block = block_gradient(self.block, &parameter_c).expect("missing gradient");
305
306            // SAFETY: the lifetime of the block is the same as the lifetime of
307            // the GradientsIter, both are constrained to the root
308            // TensorMap/TensorBlock
309            let block = unsafe { TensorBlockRef::from_raw(block) };
310            return (parameter, block);
311        })
312    }
313
314    fn size_hint(&self) -> (usize, Option<usize>) {
315        (self.len(), Some(self.len()))
316    }
317}
318
319impl ExactSizeIterator for GradientsIter<'_> {
320    #[inline]
321    fn len(&self) -> usize {
322        self.parameters.len()
323    }
324}
325
326impl FusedIterator for GradientsIter<'_> {}
327
#[cfg(test)]
mod tests {
    use crate::{Labels, TensorBlock};

    /// Build a 2x3 gradient block filled with `value`.
    fn gradient_block(value: f64, properties: &Labels) -> TensorBlock {
        TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], value),
            &Labels::new(["sample"], [[0], [1]]),
            &[],
            properties,
        ).unwrap()
    }

    #[test]
    #[allow(clippy::float_cmp)] // exact values are stored, no rounding involved
    fn gradients() {
        let properties = Labels::new(["p"], [[-2], [0], [1]]);
        let mut owned = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], 1.0),
            &Labels::new(["s"], [[0], [1]]),
            &[],
            &properties,
        ).unwrap();

        owned.add_gradient("g", gradient_block(-1.0, &properties)).unwrap();
        owned.add_gradient("f", gradient_block(-2.0, &properties)).unwrap();

        let block = owned.as_ref();
        for (parameter, expected) in [("g", -1.0), ("f", -2.0)] {
            let gradient = block.gradient(parameter).unwrap();
            let values = gradient.values().to_ndarray_lock::<f64>().read().unwrap();
            assert_eq!(values[[0, 0]], expected);
        }

        assert!(block.gradient("h").is_none());

        let mut gradients = block.gradients();
        assert_eq!(gradients.len(), 2);

        let parameters: Vec<&str> = gradients.by_ref().map(|(parameter, _)| parameter).collect();
        assert_eq!(parameters, ["g", "f"]);
        assert!(gradients.next().is_none());
    }

    #[test]
    fn block_data() {
        let samples = Labels::new(["samples"], [[0], [1]]);
        let components = [Labels::new(["component"], [[0]])];
        let properties = Labels::new(["properties"], [[-2], [0], [1]]);
        let expected = ndarray::Array::from_elem(vec![2, 1, 3], 1.0);

        let owned = TensorBlock::new(expected.clone(), &samples, &components, &properties).unwrap();
        let block = owned.as_ref();

        // direct accessors on the block reference
        let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
        assert_eq!(*values, expected);
        assert_eq!(block.samples(), samples);
        assert_eq!(block.components(), components);
        assert_eq!(block.properties(), properties);

        // the same checks, going through the lazily-loaded `TensorBlockData`
        let data = block.data();
        let values = data.values.to_ndarray_lock::<f64>().read().unwrap();
        assert_eq!(*values, expected);
        assert_eq!(*data.samples, samples);
        assert_eq!(*data.components, components);
        assert_eq!(*data.properties, properties);
    }

    #[test]
    fn device_and_dtype() {
        let owned = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], 1.0),
            &Labels::new(["samples"], [[0], [1]]),
            &[],
            &Labels::new(["properties"], [[-2], [0], [1]]),
        ).unwrap();
        let block = owned.as_ref();

        assert_eq!(block.device().unwrap().device_type, dlpk::sys::DLDeviceType::kDLCPU);

        let dtype = block.dtype().unwrap();
        assert_eq!((dtype.code, dtype.bits), (dlpk::sys::DLDataTypeCode::kDLFloat, 64));
    }
}