// metatensor/block/owned.rs

use std::ffi::CString;

use crate::c_api::mts_block_t;
use crate::errors::check_status;
use crate::{Array, ArrayRef, Labels, Error};

use super::{TensorBlockRef, TensorBlockRefMut};
6
7/// A single block, containing both values & optionally gradients of these
8/// values w.r.t. any relevant quantity.
9#[derive(Debug)]
10#[repr(transparent)]
11pub struct TensorBlock {
12    ptr: *mut mts_block_t,
13}
14
15// SAFETY: TensorBlock can be freed from any thread
16unsafe impl Send for TensorBlock {}
17// SAFETY: Sync is fine since there is no internal mutability in TensorBlock
18unsafe impl Sync for TensorBlock {}
19
20impl std::ops::Drop for TensorBlock {
21    #[allow(unused_must_use)]
22    fn drop(&mut self) {
23        unsafe {
24            crate::c_api::mts_block_free(self.as_mut_ptr());
25        }
26    }
27}
28
29impl TensorBlock {
30    /// Create a new `TensorBlock` from a raw pointer.
31    ///
32    /// This function takes ownership of the pointer, and will call
33    /// `mts_block_free` on it when the `TensorBlock` goes out of scope.
34    ///
35    /// # Safety
36    ///
37    /// The pointer must be non-null and point to a owned block, not a reference
38    /// to a block from inside a [`TensorMap`](crate::TensorMap).
39    pub(crate) unsafe fn from_raw(ptr: *mut mts_block_t) -> TensorBlock {
40        assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
41
42        TensorBlock {
43            ptr: ptr,
44        }
45    }
46
47    /// Get the underlying raw pointer
48    pub(super) fn as_ptr(&self) -> *const mts_block_t {
49        self.ptr
50    }
51
52    /// Get the underlying (mutable) raw pointer
53    pub(super) fn as_mut_ptr(&mut self) -> *mut mts_block_t {
54        self.ptr
55    }
56
57    /// Get a non mutable reference to this block
58    #[inline]
59    pub fn as_ref(&self) -> TensorBlockRef<'_> {
60        unsafe {
61            TensorBlockRef::from_raw(self.as_ptr())
62        }
63    }
64
65    /// Get a non mutable reference to this block
66    #[inline]
67    pub fn as_ref_mut(&mut self) -> TensorBlockRefMut<'_> {
68        unsafe {
69            TensorBlockRefMut::from_raw(self.as_mut_ptr())
70        }
71    }
72
73    /// Get the array for the values in this block
74    #[inline]
75    pub fn values(&self) -> ArrayRef<'_> {
76        return self.as_ref().values();
77    }
78
79    /// Get the samples for this block
80    #[inline]
81    pub fn samples(&self) -> Labels {
82        return self.as_ref().samples();
83    }
84
85    /// Get the components for this block
86    #[inline]
87    pub fn components(&self) -> Vec<Labels> {
88        return self.as_ref().components();
89    }
90
91    /// Get the properties for this block
92    #[inline]
93    pub fn properties(&self) -> Labels {
94        return self.as_ref().properties();
95    }
96
97    /// Create a new [`TensorBlock`] containing the given data, described by the
98    /// `samples`, `components`, and `properties` labels. The block is
99    /// initialized without any gradients.
100    #[inline]
101    pub fn new(
102        data: impl Array,
103        samples: &Labels,
104        components: &[Labels],
105        properties: &Labels
106    ) -> Result<TensorBlock, Error> {
107        let mut c_components = Vec::new();
108        for component in components {
109            c_components.push(component.as_mts_labels_t());
110        }
111
112        let ptr = unsafe {
113            crate::c_api::mts_block(
114                (Box::new(data) as Box<dyn Array>).into(),
115                samples.as_mts_labels_t(),
116                c_components.as_ptr(),
117                c_components.len(),
118                properties.as_mts_labels_t(),
119            )
120        };
121
122        crate::errors::check_ptr(ptr)?;
123
124        return Ok(unsafe { TensorBlock::from_raw(ptr) });
125    }
126
127    /// Add a gradient with respect to `parameter` to this block.
128    ///
129    /// The property of the gradient should match the ones of this block. The
130    /// components of the gradients must contain at least the same entries as
131    /// the value components, and can prepend other components.
132    #[allow(clippy::needless_pass_by_value)]
133    #[inline]
134    pub fn add_gradient(
135        &mut self,
136        parameter: &str,
137        mut gradient: TensorBlock
138    ) -> Result<(), Error> {
139        let mut parameter = parameter.to_owned().into_bytes();
140        parameter.push(b'\0');
141
142
143        let gradient_ptr = gradient.as_ref_mut().as_mut_ptr();
144        // we give ownership of the gradient to `self`, so we should not free
145        // them again from here
146        std::mem::forget(gradient);
147
148        unsafe {
149            check_status(crate::c_api::mts_block_add_gradient(
150                self.as_ref_mut().as_mut_ptr(),
151                parameter.as_ptr().cast(),
152                gradient_ptr,
153            ))?;
154        }
155
156        return Ok(());
157    }
158
159    /// Load a `TensorBlock` from the file at `path`
160    ///
161    /// This is a convenience function calling [`crate::io::load_block`]
162    pub fn load(path: impl AsRef<std::path::Path>) -> Result<TensorBlock, Error> {
163        return crate::io::load_block(path);
164    }
165
166    /// Load a `TensorBlock` from an in-memory buffer
167    ///
168    /// This is a convenience function calling [`crate::io::load_block_buffer`]
169    pub fn load_buffer(buffer: &[u8]) -> Result<TensorBlock, Error> {
170        return crate::io::load_block_buffer(buffer);
171    }
172
173    /// Save the given block to the file at `path`
174    ///
175    /// This is a convenience function calling [`crate::io::save_block`]
176    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
177        self.as_ref().save(path)
178    }
179
180    /// Save the given block to an in-memory buffer
181    ///
182    /// This is a convenience function calling [`crate::io::save_block_buffer`]
183    pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
184        self.as_ref().save_buffer(buffer)
185    }
186}
187
188
#[cfg(test)]
mod tests {
    use crate::c_api::mts_block_t;
    use super::*;

    #[test]
    fn check_repr() {
        // we are casting `*mut TensorBlock` to `*mut mts_block_t` in TensorMap::new,
        // this is only legal because TensorBlock == *mut mts_block_t
        assert_eq!(std::mem::size_of::<TensorBlock>(), std::mem::size_of::<*mut mts_block_t>());
        assert_eq!(std::mem::align_of::<TensorBlock>(), std::mem::align_of::<*mut mts_block_t>());
    }

    #[test]
    fn check_send_sync() {
        // compile-time check that the manual `Send`/`Sync` impls are in place
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<TensorBlock>();
    }
}