metatensor/block/
block_ref.rs1use std::ffi::{CStr, CString};
2use std::iter::FusedIterator;
3
4use crate::c_api::{mts_block_t, mts_array_t, mts_labels_t};
5use crate::c_api::MTS_INVALID_PARAMETER_ERROR;
6
7use crate::errors::check_status;
8use crate::{ArrayRef, Labels, Error};
9
10use super::{TensorBlock, LazyMetadata};
11
12#[derive(Debug, Clone, Copy)]
14pub struct TensorBlockRef<'a> {
15 ptr: *const mts_block_t,
16 marker: std::marker::PhantomData<&'a mts_block_t>,
17}
18
19unsafe impl Send for TensorBlockRef<'_> {}
21unsafe impl Sync for TensorBlockRef<'_> {}
23
24#[derive(Debug)]
34pub struct TensorBlockData<'a> {
35 pub values: ArrayRef<'a>,
36 pub samples: LazyMetadata<Labels>,
37 pub components: LazyMetadata<Vec<Labels>>,
38 pub properties: LazyMetadata<Labels>,
39}
40
41impl<'a> TensorBlockRef<'a> {
42 pub(crate) unsafe fn from_raw(ptr: *const mts_block_t) -> TensorBlockRef<'a> {
48 assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
49
50 TensorBlockRef {
51 ptr: ptr,
52 marker: std::marker::PhantomData,
53 }
54 }
55
56 pub(crate) fn as_ptr(&self) -> *const mts_block_t {
58 self.ptr
59 }
60}
61
62fn block_gradient(block: *const mts_block_t, parameter: &CStr) -> Option<*const mts_block_t> {
64 let mut gradient_block = std::ptr::null_mut();
65 let status = unsafe { crate::c_api::mts_block_gradient(
66 block.cast_mut(),
69 parameter.as_ptr(),
70 &mut gradient_block
71 )
72 };
73
74 match crate::errors::check_status(status) {
75 Ok(()) => Some(gradient_block.cast_const()),
76 Err(error) => {
77 if error.code == Some(MTS_INVALID_PARAMETER_ERROR) {
78 None
80 } else {
81 panic!("failed to get the gradient from a block: {:?}", error)
82 }
83 }
84 }
85}
86
87pub(super) fn get_samples(ptr: *const mts_block_t) -> Labels {
88 unsafe {
89 TensorBlockRef::from_raw(ptr).samples()
90 }
91}
92
93pub(super) fn get_components(ptr: *const mts_block_t) -> Vec<Labels> {
94 unsafe {
95 TensorBlockRef::from_raw(ptr).components()
96 }
97}
98
99pub(super) fn get_properties(ptr: *const mts_block_t) -> Labels {
100 unsafe {
101 TensorBlockRef::from_raw(ptr).properties()
102 }
103}
104
105impl<'a> TensorBlockRef<'a> {
106 #[inline]
109 pub fn data(&'a self) -> TensorBlockData<'a> {
110 TensorBlockData {
111 values: self.values(),
112 samples: LazyMetadata::new(get_samples, self.as_ptr()),
113 components: LazyMetadata::new(get_components, self.as_ptr()),
114 properties: LazyMetadata::new(get_properties, self.as_ptr()),
115 }
116 }
117
118 #[inline]
120 pub fn values(&self) -> ArrayRef<'a> {
121 let mut array = mts_array_t::null();
122 unsafe {
123 crate::errors::check_status(crate::c_api::mts_block_data(
124 self.as_ptr().cast_mut(),
125 &mut array
126 )).expect("failed to get the array for a block");
127 };
128
129 unsafe { ArrayRef::from_raw(array) }
135 }
136
137 #[inline]
138 fn labels(&self, dimension: usize) -> Labels {
139 let mut labels = mts_labels_t::null();
140 unsafe {
141 check_status(crate::c_api::mts_block_labels(
142 self.as_ptr(),
143 dimension,
144 &mut labels,
145 )).expect("failed to get labels");
146 }
147 return unsafe { Labels::from_raw(labels) };
148 }
149
150 #[inline]
152 pub fn samples(&self) -> Labels {
153 return self.labels(0);
154 }
155
156 #[inline]
158 pub fn components(&self) -> Vec<Labels> {
159 let values = self.values();
160 let shape = values.as_raw().shape().expect("failed to get the data shape");
161
162 let mut result = Vec::new();
163 for i in 1..(shape.len() - 1) {
164 result.push(self.labels(i));
165 }
166 return result;
167 }
168
169 #[inline]
171 pub fn properties(&self) -> Labels {
172 let values = self.values();
173 let shape = values.as_raw().shape().expect("failed to get the data shape");
174
175 return self.labels(shape.len() - 1);
176 }
177
178 #[inline]
180 pub fn gradient_list(&self) -> Vec<&'a str> {
181 let mut parameters_ptr = std::ptr::null();
182 let mut parameters_count = 0;
183 unsafe {
184 check_status(crate::c_api::mts_block_gradients_list(
185 self.as_ptr(),
186 &mut parameters_ptr,
187 &mut parameters_count
188 )).expect("failed to get gradient list");
189 }
190
191 if parameters_count == 0 {
192 return Vec::new();
193 } else {
194 assert!(!parameters_ptr.is_null());
195 unsafe {
199 let parameters = std::slice::from_raw_parts(parameters_ptr, parameters_count);
200 return parameters.iter()
201 .map(|&ptr| CStr::from_ptr(ptr).to_str().unwrap())
202 .collect();
203 }
204 }
205 }
206
207 #[inline]
210 pub fn gradient(&self, parameter: &str) -> Option<TensorBlockRef<'a>> {
211 let parameter = CString::new(parameter).expect("invalid C string");
214
215 block_gradient(self.as_ptr(), ¶meter)
216 .map(|gradient_block| {
217 unsafe { TensorBlockRef::from_raw(gradient_block) }
221 })
222 }
223
224 #[inline]
229 pub fn try_clone(&self) -> Result<TensorBlock, Error> {
230 let ptr = unsafe {
231 crate::c_api::mts_block_copy(self.as_ptr())
232 };
233 crate::errors::check_ptr(ptr)?;
234
235 return Ok(unsafe { TensorBlock::from_raw(ptr) });
236 }
237
238 #[inline]
241 pub fn gradients(&self) -> GradientsIter<'_> {
242 GradientsIter {
243 parameters: self.gradient_list().into_iter(),
244 block: self.as_ptr(),
245 }
246 }
247
248 pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
252 return crate::io::save_block(path, *self);
253 }
254
255 pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
259 return crate::io::save_block_buffer(*self, buffer);
260 }
261}
262
263pub struct GradientsIter<'a> {
266 parameters: std::vec::IntoIter<&'a str>,
267 block: *const mts_block_t,
268}
269
270impl<'a> Iterator for GradientsIter<'a> {
271 type Item = (&'a str, TensorBlockRef<'a>);
272
273 #[inline]
274 fn next(&mut self) -> Option<Self::Item> {
275 self.parameters.next().map(|parameter| {
276 let parameter_c = CString::new(parameter).expect("invalid C string");
277 let block = block_gradient(self.block, ¶meter_c).expect("missing gradient");
278
279 let block = unsafe { TensorBlockRef::from_raw(block) };
283 return (parameter, block);
284 })
285 }
286
287 fn size_hint(&self) -> (usize, Option<usize>) {
288 (self.len(), Some(self.len()))
289 }
290}
291
292impl ExactSizeIterator for GradientsIter<'_> {
293 #[inline]
294 fn len(&self) -> usize {
295 self.parameters.len()
296 }
297}
298
299impl FusedIterator for GradientsIter<'_> {}
300
#[cfg(test)]
mod tests {
    // NOTE(review): no tests are visible in this module.
}