metatensor/block/
owned.rs1use crate::c_api::mts_block_t;
2use crate::errors::check_status;
3use crate::{Array, ArrayRef, Labels, Error};
4
5use super::{TensorBlockRef, TensorBlockRefMut};
6
7#[derive(Debug)]
10#[repr(transparent)]
11pub struct TensorBlock {
12 ptr: *mut mts_block_t,
13}
14
15unsafe impl Send for TensorBlock {}
17unsafe impl Sync for TensorBlock {}
19
20impl std::ops::Drop for TensorBlock {
21 #[allow(unused_must_use)]
22 fn drop(&mut self) {
23 unsafe {
24 crate::c_api::mts_block_free(self.as_mut_ptr());
25 }
26 }
27}
28
29impl TensorBlock {
30 pub(crate) unsafe fn from_raw(ptr: *mut mts_block_t) -> TensorBlock {
40 assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
41
42 TensorBlock {
43 ptr: ptr,
44 }
45 }
46
47 pub(super) fn as_ptr(&self) -> *const mts_block_t {
49 self.ptr
50 }
51
52 pub(super) fn as_mut_ptr(&mut self) -> *mut mts_block_t {
54 self.ptr
55 }
56
57 #[inline]
59 pub fn as_ref(&self) -> TensorBlockRef<'_> {
60 unsafe {
61 TensorBlockRef::from_raw(self.as_ptr())
62 }
63 }
64
65 #[inline]
67 pub fn as_ref_mut(&mut self) -> TensorBlockRefMut<'_> {
68 unsafe {
69 TensorBlockRefMut::from_raw(self.as_mut_ptr())
70 }
71 }
72
73 #[inline]
75 pub fn values(&self) -> ArrayRef<'_> {
76 return self.as_ref().values();
77 }
78
79 #[inline]
81 pub fn samples(&self) -> Labels {
82 return self.as_ref().samples();
83 }
84
85 #[inline]
87 pub fn components(&self) -> Vec<Labels> {
88 return self.as_ref().components();
89 }
90
91 #[inline]
93 pub fn properties(&self) -> Labels {
94 return self.as_ref().properties();
95 }
96
97 #[inline]
101 pub fn new(
102 data: impl Array,
103 samples: &Labels,
104 components: &[Labels],
105 properties: &Labels
106 ) -> Result<TensorBlock, Error> {
107 let mut c_components = Vec::new();
108 for component in components {
109 c_components.push(component.as_mts_labels_t());
110 }
111
112 let ptr = unsafe {
113 crate::c_api::mts_block(
114 (Box::new(data) as Box<dyn Array>).into(),
115 samples.as_mts_labels_t(),
116 c_components.as_ptr(),
117 c_components.len(),
118 properties.as_mts_labels_t(),
119 )
120 };
121
122 crate::errors::check_ptr(ptr)?;
123
124 return Ok(unsafe { TensorBlock::from_raw(ptr) });
125 }
126
127 #[allow(clippy::needless_pass_by_value)]
133 #[inline]
134 pub fn add_gradient(
135 &mut self,
136 parameter: &str,
137 mut gradient: TensorBlock
138 ) -> Result<(), Error> {
139 let mut parameter = parameter.to_owned().into_bytes();
140 parameter.push(b'\0');
141
142
143 let gradient_ptr = gradient.as_ref_mut().as_mut_ptr();
144 std::mem::forget(gradient);
147
148 unsafe {
149 check_status(crate::c_api::mts_block_add_gradient(
150 self.as_ref_mut().as_mut_ptr(),
151 parameter.as_ptr().cast(),
152 gradient_ptr,
153 ))?;
154 }
155
156 return Ok(());
157 }
158
159 pub fn load(path: impl AsRef<std::path::Path>) -> Result<TensorBlock, Error> {
163 return crate::io::load_block(path);
164 }
165
166 pub fn load_buffer(buffer: &[u8]) -> Result<TensorBlock, Error> {
170 return crate::io::load_block_buffer(buffer);
171 }
172
173 pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
177 self.as_ref().save(path)
178 }
179
180 pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
184 self.as_ref().save_buffer(buffer)
185 }
186}
187
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::c_api::mts_block_t;

    #[test]
    fn check_repr() {
        // `TensorBlock` is `#[repr(transparent)]` over a raw pointer: it must
        // stay layout-compatible with `*mut mts_block_t`.
        type RawBlock = *mut mts_block_t;

        assert_eq!(std::mem::size_of::<TensorBlock>(), std::mem::size_of::<RawBlock>());
        assert_eq!(std::mem::align_of::<TensorBlock>(), std::mem::align_of::<RawBlock>());
    }
}
201}