// metatensor/block/owned.rs
1use crate::c_api::mts_block_t;
2use crate::errors::check_status;
3use crate::{ArrayRef, Error, Labels, MtsArray};
4
5use super::{TensorBlockRef, TensorBlockRefMut};
6
7/// A single block, containing both values & optionally gradients of these
8/// values w.r.t. any relevant quantity.
9#[derive(Debug)]
10#[repr(transparent)]
11pub struct TensorBlock {
12    ptr: *mut mts_block_t,
13}
14
15// SAFETY: TensorBlock can be freed from any thread
16unsafe impl Send for TensorBlock {}
17// SAFETY: Sync is fine since there is no internal mutability in TensorBlock
18unsafe impl Sync for TensorBlock {}
19
20impl std::ops::Drop for TensorBlock {
21    #[allow(unused_must_use)]
22    fn drop(&mut self) {
23        unsafe {
24            crate::c_api::mts_block_free(self.as_mut_ptr());
25        }
26    }
27}
28
29impl TensorBlock {
30    /// Create a new `TensorBlock` from a raw pointer.
31    ///
32    /// This function takes ownership of the pointer, and will call
33    /// `mts_block_free` on it when the `TensorBlock` goes out of scope.
34    ///
35    /// # Safety
36    ///
37    /// The pointer must be non-null and point to a owned block, not a reference
38    /// to a block from inside a [`TensorMap`](crate::TensorMap).
39    pub unsafe fn from_raw(ptr: *mut mts_block_t) -> TensorBlock {
40        assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
41
42        TensorBlock {
43            ptr: ptr,
44        }
45    }
46
47    /// Extract the underlying raw pointer.
48    ///
49    /// The pointer should be passed back to [`TensorBlock::from_raw`] or
50    /// [`crate::c_api::mts_block_free`] to release the memory corresponding
51    /// to this `TensorBlock`.
52    pub fn into_raw(mut block: TensorBlock) -> *mut mts_block_t {
53        return std::mem::replace(&mut block.ptr, std::ptr::null_mut());
54    }
55
56    /// Get the underlying raw pointer.
57    ///
58    /// After a call, this `TensorBlock` is still managing the corresponding
59    /// memory. To fully release the pointer, use [`TensorBlock::into_raw`].
60    pub fn as_ptr(&self) -> *const mts_block_t {
61        self.ptr
62    }
63
64    /// Get the underlying (mutable) raw pointer
65    ///
66    /// After a call, this `TensorBlock` is still managing the corresponding
67    /// memory. To fully release the pointer, use [`TensorBlock::into_raw`].
68    pub fn as_mut_ptr(&mut self) -> *mut mts_block_t {
69        self.ptr
70    }
71
72    /// Get a non mutable reference to this block
73    #[inline]
74    pub fn as_ref(&self) -> TensorBlockRef<'_> {
75        unsafe {
76            TensorBlockRef::from_raw(self.as_ptr())
77        }
78    }
79
80    /// Get a non mutable reference to this block
81    #[inline]
82    pub fn as_ref_mut(&mut self) -> TensorBlockRefMut<'_> {
83        unsafe {
84            TensorBlockRefMut::from_raw(self.as_mut_ptr())
85        }
86    }
87
88    /// Get the device on which the values of this block are stored.
89    #[inline]
90    pub fn device(&self) -> Result<dlpk::sys::DLDevice, Error> {
91        self.as_ref().device()
92    }
93
94    /// Get the data type of the values of this block.
95    #[inline]
96    pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
97        self.as_ref().dtype()
98    }
99
100    /// Get the array for the values in this block
101    #[inline]
102    pub fn values(&self) -> ArrayRef<'_> {
103        return self.as_ref().values();
104    }
105
106    /// Get the samples for this block
107    #[inline]
108    pub fn samples(&self) -> Labels {
109        return self.as_ref().samples();
110    }
111
112    /// Get the components for this block
113    #[inline]
114    pub fn components(&self) -> Vec<Labels> {
115        return self.as_ref().components();
116    }
117
118    /// Get the properties for this block
119    #[inline]
120    pub fn properties(&self) -> Labels {
121        return self.as_ref().properties();
122    }
123
124    /// Create a new [`TensorBlock`] containing the given data, described by the
125    /// `samples`, `components`, and `properties` labels. The block is
126    /// initialized without any gradients.
127    #[inline]
128    pub fn new(
129        values: impl Into<MtsArray>,
130        samples: &Labels,
131        components: &[Labels],
132        properties: &Labels
133    ) -> Result<TensorBlock, Error> {
134        let mut c_components = Vec::new();
135        for component in components {
136            c_components.push(component.as_mts_labels_t());
137        }
138
139        let ptr = unsafe {
140            crate::c_api::mts_block(
141                values.into().into_raw(),
142                samples.as_mts_labels_t(),
143                c_components.as_ptr(),
144                c_components.len(),
145                properties.as_mts_labels_t(),
146            )
147        };
148
149        crate::errors::check_ptr(ptr)?;
150
151        return Ok(unsafe { TensorBlock::from_raw(ptr) });
152    }
153
154    /// Add a gradient with respect to `parameter` to this block.
155    ///
156    /// The property of the gradient should match the ones of this block. The
157    /// components of the gradients must contain at least the same entries as
158    /// the value components, and can prepend other components.
159    #[allow(clippy::needless_pass_by_value)]
160    #[inline]
161    pub fn add_gradient(
162        &mut self,
163        parameter: &str,
164        mut gradient: TensorBlock
165    ) -> Result<(), Error> {
166        let mut parameter = parameter.to_owned().into_bytes();
167        parameter.push(b'\0');
168
169
170        let gradient_ptr = gradient.as_ref_mut().as_mut_ptr();
171        // we give ownership of the gradient to `self`, so we should not free
172        // them again from here
173        std::mem::forget(gradient);
174
175        unsafe {
176            check_status(crate::c_api::mts_block_add_gradient(
177                self.as_ref_mut().as_mut_ptr(),
178                parameter.as_ptr().cast(),
179                gradient_ptr,
180            ))?;
181        }
182
183        return Ok(());
184    }
185
186    /// Load a `TensorBlock` from the file at `path`
187    ///
188    /// This is a convenience function calling [`crate::io::load_block`]
189    pub fn load(path: impl AsRef<std::path::Path>) -> Result<TensorBlock, Error> {
190        return crate::io::load_block(path);
191    }
192
193    /// Load a `TensorBlock` from an in-memory buffer
194    ///
195    /// This is a convenience function calling [`crate::io::load_block_buffer`]
196    pub fn load_buffer(buffer: &[u8]) -> Result<TensorBlock, Error> {
197        return crate::io::load_block_buffer(buffer);
198    }
199
200    /// Save the given block to the file at `path`
201    ///
202    /// This is a convenience function calling [`crate::io::save_block`]
203    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
204        self.as_ref().save(path)
205    }
206
207    /// Save the given block to an in-memory buffer
208    ///
209    /// This is a convenience function calling [`crate::io::save_block_buffer`]
210    pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
211        self.as_ref().save_buffer(buffer)
212    }
213}
214
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::c_api::mts_block_t;

    #[test]
    fn block() {
        let samples = Labels::new(["samples"], [[0], [1]]);
        let components = [Labels::new(["component"], [[0]])];
        let properties = Labels::new(["properties"], [[-2], [0], [1]]);

        let block = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 1, 3], 1.0),
            &samples,
            &components,
            &properties,
        ).unwrap();

        let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
        let expected_values = ndarray::Array::from_elem(vec![2, 1, 3], 1.0);
        assert_eq!(*values, expected_values);

        assert_eq!(block.samples(), Labels::new(["samples"], [[0], [1]]));
        assert_eq!(block.components(), [Labels::new(["component"], [[0]])]);
        assert_eq!(block.properties(), Labels::new(["properties"], [[-2], [0], [1]]));
    }

    #[test]
    fn check_repr() {
        // `TensorMap::new` casts `*mut TensorBlock` to `*mut mts_block_t`; that
        // cast is only legal while both types share the exact same layout.
        let block_size = std::mem::size_of::<TensorBlock>();
        let block_align = std::mem::align_of::<TensorBlock>();

        assert_eq!(block_size, std::mem::size_of::<*mut mts_block_t>());
        assert_eq!(block_align, std::mem::align_of::<*mut mts_block_t>());
    }

    #[test]
    fn device_and_dtype() {
        let samples = Labels::new(["samples"], [[0], [1]]);
        let properties = Labels::new(["properties"], [[-2], [0], [1]]);

        let block = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], 1.0),
            &samples,
            &[],
            &properties,
        ).unwrap();

        let device_type = block.device().unwrap().device_type;
        assert_eq!(device_type, dlpk::sys::DLDeviceType::kDLCPU);

        let dtype = block.dtype().unwrap();
        assert_eq!(dtype.code, dlpk::sys::DLDataTypeCode::kDLFloat);
        assert_eq!(dtype.bits, 64);
    }

    #[test]
    fn tensor_block_into_raw() {
        let samples = Labels::new(["samples"], [[0], [1], [4]]);
        let properties = Labels::new(["properties"], [[5], [3]]);

        let original = TensorBlock::new(
            ndarray::Array::from_elem(vec![3, 2], 1.0),
            &samples,
            &[],
            &properties,
        ).unwrap();

        // round-trip through the raw pointer, then check nothing was lost
        let recovered = unsafe { TensorBlock::from_raw(TensorBlock::into_raw(original)) };

        assert_eq!(recovered.samples(), Labels::new(["samples"], [[0], [1], [4]]));
        assert_eq!(recovered.properties(), Labels::new(["properties"], [[5], [3]]));
    }
}