metatensor/block/
block_mut.rs1use std::ffi::{CString, CStr};
2use std::iter::FusedIterator;
3
4use crate::c_api::{mts_block_t, mts_array_t, MTS_INVALID_PARAMETER_ERROR};
5use crate::{ArrayRef, ArrayRefMut, Labels, Error};
6
7use super::{TensorBlockRef, LazyMetadata};
8use super::block_ref::{get_samples, get_components, get_properties};
9
/// Mutable handle on a `mts_block_t`, valid for the lifetime `'a`.
///
/// The handle only stores a raw pointer; all data access goes through the
/// metatensor C API.
#[derive(Debug)]
pub struct TensorBlockRefMut<'a> {
    /// Raw pointer to the underlying block. `from_raw` asserts it is non-NULL.
    ptr: *mut mts_block_t,
    /// Makes the borrow checker treat this struct as if it held a
    /// `&'a mut mts_block_t`, without storing one.
    marker: std::marker::PhantomData<&'a mut mts_block_t>,
}
16
// The raw pointer field opts `TensorBlockRefMut` out of the automatic `Send`
// and `Sync` implementations, so both are asserted manually here.
// NOTE(review): this presumes the underlying C API tolerates a block being
// moved to / shared with another thread — confirm against the C API contract.
unsafe impl Send for TensorBlockRefMut<'_> {}
unsafe impl Sync for TensorBlockRefMut<'_> {}
22
/// All the basic data of a block, bundled in a single struct with the values
/// accessible mutably. Built by `TensorBlockRefMut::data_mut`.
#[derive(Debug)]
pub struct TensorBlockDataMut<'a> {
    /// Mutable reference to the array holding the values of the block
    pub values: ArrayRefMut<'a>,
    /// Samples labels of the block, wrapped in `LazyMetadata`
    pub samples: LazyMetadata<Labels>,
    /// Components labels of the block (one `Labels` per component), wrapped
    /// in `LazyMetadata`
    pub components: LazyMetadata<Vec<Labels>>,
    /// Properties labels of the block, wrapped in `LazyMetadata`
    pub properties: LazyMetadata<Labels>,
}
39
40fn block_gradient(block: *mut mts_block_t, parameter: &CStr) -> Option<*mut mts_block_t> {
42 let mut gradient_block = std::ptr::null_mut();
43 let status = unsafe { crate::c_api::mts_block_gradient(
44 block,
45 parameter.as_ptr(),
46 &mut gradient_block
47 )
48 };
49
50 match crate::errors::check_status(status) {
51 Ok(()) => Some(gradient_block),
52 Err(error) => {
53 if error.code == Some(MTS_INVALID_PARAMETER_ERROR) {
54 None
56 } else {
57 panic!("failed to get the gradient from a block: {:?}", error)
58 }
59 }
60 }
61}
62
63impl<'a> TensorBlockRefMut<'a> {
64 pub(crate) unsafe fn from_raw(ptr: *mut mts_block_t) -> TensorBlockRefMut<'a> {
72 assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
73
74 TensorBlockRefMut {
75 ptr: ptr,
76 marker: std::marker::PhantomData,
77 }
78 }
79
80 pub(super) fn as_ptr(&self) -> *const mts_block_t {
82 self.ptr
83 }
84
85 pub(super) fn as_mut_ptr(&mut self) -> *mut mts_block_t {
87 self.ptr
88 }
89
90 #[inline]
92 pub fn as_ref(&self) -> TensorBlockRef<'_> {
93 unsafe {
94 TensorBlockRef::from_raw(self.as_ptr())
95 }
96 }
97
98 #[inline]
101 pub fn data_mut(&mut self) -> TensorBlockDataMut<'_> {
102 let samples = LazyMetadata::new(get_samples, self.as_ptr());
103 let components = LazyMetadata::new(get_components, self.as_ptr());
104 let properties = LazyMetadata::new(get_properties, self.as_ptr());
105
106 TensorBlockDataMut {
107 values: self.values_mut(),
109 samples: samples,
110 components: components,
111 properties: properties,
112 }
113 }
114
115 #[inline]
117 pub fn values_mut(&mut self) -> ArrayRefMut<'_> {
118 let mut array = mts_array_t::null();
119 unsafe {
120 crate::errors::check_status(crate::c_api::mts_block_data(
121 self.as_mut_ptr(),
122 &mut array
123 )).expect("failed to get the array for a block");
124 };
125
126 unsafe { ArrayRefMut::new(array) }
128 }
129
130 #[inline]
132 pub fn values(&self) -> ArrayRef<'_> {
133 return self.as_ref().values();
134 }
135
136 #[inline]
138 pub fn samples(&self) -> Labels {
139 return self.as_ref().samples();
140 }
141
142 #[inline]
144 pub fn components(&self) -> Vec<Labels> {
145 return self.as_ref().components();
146 }
147
148 #[inline]
150 pub fn properties(&self) -> Labels {
151 return self.as_ref().properties();
152 }
153
154 #[inline]
157 pub fn gradient_mut(&mut self, parameter: &str) -> Option<TensorBlockRefMut<'_>> {
158 let parameter = CString::new(parameter).expect("invalid C string");
159
160 block_gradient(self.as_mut_ptr(), ¶meter)
161 .map(|gradient_block| {
162 unsafe { TensorBlockRefMut::from_raw(gradient_block) }
165 })
166 }
167
168 #[inline]
171 pub fn gradients_mut(&mut self) -> GradientsMutIter<'_> {
172 let block_ptr = self.as_mut_ptr();
173 GradientsMutIter {
174 parameters: self.as_ref().gradient_list().into_iter(),
175 block: block_ptr,
176 }
177 }
178
179 pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
183 self.as_ref().save(path)
184 }
185
186 pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
190 self.as_ref().save_buffer(buffer)
191 }
192}
193
/// Iterator over the gradients of a block, yielding
/// `(parameter, TensorBlockRefMut)` pairs. Created by
/// `TensorBlockRefMut::gradients_mut`.
pub struct GradientsMutIter<'a> {
    /// Gradient parameter names that have not been yielded yet
    parameters: std::vec::IntoIter<&'a str>,
    /// Raw pointer to the block whose gradients are being iterated
    block: *mut mts_block_t,
}
200
201impl<'a> Iterator for GradientsMutIter<'a> {
202 type Item = (&'a str, TensorBlockRefMut<'a>);
203
204 #[inline]
205 fn next(&mut self) -> Option<Self::Item> {
206 self.parameters.next().map(|parameter| {
207 let parameter_c = CString::new(parameter).expect("invalid C string");
208 let block = block_gradient(self.block, ¶meter_c).expect("missing gradient");
209
210 let block = unsafe { TensorBlockRefMut::from_raw(block) };
214 return (parameter, block);
215 })
216 }
217
218 fn size_hint(&self) -> (usize, Option<usize>) {
219 self.parameters.size_hint()
220 }
221}
222
223impl ExactSizeIterator for GradientsMutIter<'_> {
224 #[inline]
225 fn len(&self) -> usize {
226 self.parameters.len()
227 }
228}
229
// `next` returns `None` exactly when the inner `std::vec::IntoIter` does, and
// that iterator is itself fused.
impl FusedIterator for GradientsMutIter<'_> {}
231
#[cfg(test)]
mod tests {
    // No unit tests are defined in this module.
}