metatensor/block/
block_mut.rs

1use std::ffi::{CString, CStr};
2use std::iter::FusedIterator;
3
4use crate::c_api::{mts_block_t, mts_array_t, MTS_INVALID_PARAMETER_ERROR};
5use crate::{ArrayRef, ArrayRefMut, Labels, Error};
6
7use super::{TensorBlockRef, LazyMetadata};
8use super::block_ref::{get_samples, get_components, get_properties};
9
10/// Mutable reference to a [`TensorBlock`](crate::TensorBlock)
11#[derive(Debug)]
12pub struct TensorBlockRefMut<'a> {
13    ptr: *mut mts_block_t,
14    marker: std::marker::PhantomData<&'a mut mts_block_t>,
15}
16
17// SAFETY: Send is fine since TensorBlockRefMut does not implement Drop
18unsafe impl Send for TensorBlockRefMut<'_> {}
19// SAFETY: Sync is fine since there is no internal mutability in TensorBlockRefMut
20// (all mutations still require a `&mut TensorBlockRefMut`)
21unsafe impl Sync for TensorBlockRefMut<'_> {}
22
23/// All the basic data in a `TensorBlockRefMut` as a struct with separate fields.
24///
25/// This can be useful when you need to borrow different fields on this struct
26/// separately. They are separate in the underlying metatensor-core, but since we
27/// go through the C API to access them, we need to re-expose them as separate
28/// fields for the rust compiler to be able to understand that.
29///
30/// The metadata is initialized lazily on first access, to not pay the cost of
31/// allocation/reference count increase if some metadata is not used.
32#[derive(Debug)]
33pub struct TensorBlockDataMut<'a> {
34    pub values: ArrayRefMut<'a>,
35    pub samples: LazyMetadata<Labels>,
36    pub components: LazyMetadata<Vec<Labels>>,
37    pub properties: LazyMetadata<Labels>,
38}
39
40/// Get a gradient from this block
41fn block_gradient(block: *mut mts_block_t, parameter: &CStr) -> Option<*mut mts_block_t> {
42    let mut gradient_block = std::ptr::null_mut();
43    let status = unsafe { crate::c_api::mts_block_gradient(
44            block,
45            parameter.as_ptr(),
46            &mut gradient_block
47        )
48    };
49
50    match crate::errors::check_status(status) {
51        Ok(()) => Some(gradient_block),
52        Err(error) => {
53            if error.code == Some(MTS_INVALID_PARAMETER_ERROR) {
54                // there is no array for this gradient
55                None
56            } else {
57                panic!("failed to get the gradient from a block: {:?}", error)
58            }
59        }
60    }
61}
62
63impl<'a> TensorBlockRefMut<'a> {
64    /// Create a new `TensorBlockRefMut` from the given raw `mts_block_t`
65    ///
66    /// This is a **VERY** unsafe function, creating a lifetime out of thin air,
67    /// and allowing mutable access to the `mts_block_t`. Make sure the lifetime
68    /// is actually constrained by the lifetime of the owner of this
69    /// `mts_block_t`; and that the owner is mutably borrowed by this
70    /// `TensorBlockRefMut`.
71    pub(crate) unsafe fn from_raw(ptr: *mut mts_block_t) -> TensorBlockRefMut<'a> {
72        assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
73
74        TensorBlockRefMut {
75            ptr: ptr,
76            marker: std::marker::PhantomData,
77        }
78    }
79
80    /// Get the underlying raw pointer
81    pub(super) fn as_ptr(&self) -> *const mts_block_t {
82        self.ptr
83    }
84
85    /// Get the underlying (mutable) raw pointer
86    pub(super) fn as_mut_ptr(&mut self) -> *mut mts_block_t {
87        self.ptr
88    }
89
90    /// Get a non mutable reference to this block
91    #[inline]
92    pub fn as_ref(&self) -> TensorBlockRef<'_> {
93        unsafe {
94            TensorBlockRef::from_raw(self.as_ptr())
95        }
96    }
97
98    /// Get all the data and metadata inside this `TensorBlockRefMut` as a
99    /// struct with separate fields, to allow borrowing them separately.
100    #[inline]
101    pub fn data_mut(&mut self) -> TensorBlockDataMut<'_> {
102        let samples = LazyMetadata::new(get_samples, self.as_ptr());
103        let components = LazyMetadata::new(get_components, self.as_ptr());
104        let properties = LazyMetadata::new(get_properties, self.as_ptr());
105
106        TensorBlockDataMut {
107            // SAFETY: we are returning an `ArrayRefMut` mutably borrowing from `self`
108            values: self.values_mut(),
109            samples: samples,
110            components: components,
111            properties: properties,
112        }
113    }
114
115    /// Get a mutable reference to the values in this block
116    #[inline]
117    pub fn values_mut(&mut self) -> ArrayRefMut<'_> {
118        let mut array = mts_array_t::null();
119        unsafe {
120            crate::errors::check_status(crate::c_api::mts_block_data(
121                self.as_mut_ptr(),
122                &mut array
123            )).expect("failed to get the array for a block");
124        };
125
126        // SAFETY: we are returning an `ArrayRefMut` mutably borrowing from `self`
127        unsafe { ArrayRefMut::new(array) }
128    }
129
130    /// Get the array for the values in this block
131    #[inline]
132    pub fn values(&self) -> ArrayRef<'_> {
133        return self.as_ref().values();
134    }
135
136    /// Get the samples for this block
137    #[inline]
138    pub fn samples(&self) -> Labels {
139        return self.as_ref().samples();
140    }
141
142    /// Get the components for this block
143    #[inline]
144    pub fn components(&self) -> Vec<Labels> {
145        return self.as_ref().components();
146    }
147
148    /// Get the properties for this block
149    #[inline]
150    pub fn properties(&self) -> Labels {
151        return self.as_ref().properties();
152    }
153
154    /// Get a mutable reference to the data and metadata for the gradient with
155    /// respect to the given parameter in this block, if it exists.
156    #[inline]
157    pub fn gradient_mut(&mut self, parameter: &str) -> Option<TensorBlockRefMut<'_>> {
158        let parameter = CString::new(parameter).expect("invalid C string");
159
160        block_gradient(self.as_mut_ptr(), &parameter)
161            .map(|gradient_block| {
162                // SAFETY: we are returning an `TensorBlockRefMut` mutably
163                // borrowing from `self`
164                unsafe { TensorBlockRefMut::from_raw(gradient_block) }
165            })
166    }
167
168    /// Get an iterator over parameter/[`TensorBlockRefMut`] pairs for all gradients
169    /// in this block
170    #[inline]
171    pub fn gradients_mut(&mut self) -> GradientsMutIter<'_> {
172        let block_ptr = self.as_mut_ptr();
173        GradientsMutIter {
174            parameters: self.as_ref().gradient_list().into_iter(),
175            block: block_ptr,
176        }
177    }
178
179    /// Save the given block to the file at `path`
180    ///
181    /// This is a convenience function calling [`crate::io::save_block`]
182    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
183        self.as_ref().save(path)
184    }
185
186    /// Save the given block to an in-memory buffer
187    ///
188    /// This is a convenience function calling [`crate::io::save_block_buffer`]
189    pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
190        self.as_ref().save_buffer(buffer)
191    }
192}
193
194/// Iterator over parameter/[`TensorBlockRefMut`] pairs for all gradients in a
195/// [`TensorBlockRefMut`]
196pub struct GradientsMutIter<'a> {
197    parameters: std::vec::IntoIter<&'a str>,
198    block: *mut mts_block_t,
199}
200
201impl<'a> Iterator for GradientsMutIter<'a> {
202    type Item = (&'a str, TensorBlockRefMut<'a>);
203
204    #[inline]
205    fn next(&mut self) -> Option<Self::Item> {
206        self.parameters.next().map(|parameter| {
207            let parameter_c = CString::new(parameter).expect("invalid C string");
208            let block = block_gradient(self.block, &parameter_c).expect("missing gradient");
209
210            // SAFETY: all blocks are disjoint, and we are only returning a
211            // mutable reference to each once. The reference lifetime is
212            // constrained by the lifetime of the parent TensorBlockRefMut
213            let block = unsafe { TensorBlockRefMut::from_raw(block) };
214            return (parameter, block);
215        })
216    }
217
218    fn size_hint(&self) -> (usize, Option<usize>) {
219        self.parameters.size_hint()
220    }
221}
222
223impl ExactSizeIterator for GradientsMutIter<'_> {
224    #[inline]
225    fn len(&self) -> usize {
226        self.parameters.len()
227    }
228}
229
230impl FusedIterator for GradientsMutIter<'_> {}
231
#[cfg(test)]
mod tests {
    // TODO: add unit tests for `TensorBlockRefMut` and `GradientsMutIter`
}