1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
use crate::c_api::mts_block_t;
use crate::errors::check_status;
use crate::{Array, ArrayRef, Labels, Error};
use super::{TensorBlockRef, TensorBlockRefMut};
/// A single block, containing both values & optionally gradients of these
/// values w.r.t. any relevant quantity.
///
/// A `TensorBlock` wraps a raw pointer to a `mts_block_t`, and releases it
/// with `mts_block_free` when the `TensorBlock` is dropped.
// `repr(transparent)` keeps the layout of this struct identical to the one of
// `*mut mts_block_t`; the `check_repr` test in this file asserts that size and
// alignment match.
#[derive(Debug)]
#[repr(transparent)]
pub struct TensorBlock {
/// Raw pointer to the block; `TensorBlock::from_raw` asserts it is not NULL
ptr: *mut mts_block_t,
}
// SAFETY: TensorBlock can be freed from any thread, so moving the owning
// handle to another thread (and dropping it there) is fine.
unsafe impl Send for TensorBlock {}
// SAFETY: Sync is fine since there is no internal mutability in TensorBlock:
// the mutating entry points visible in this file (`as_mut_ptr`, `as_ref_mut`
// and `add_gradient`) all require `&mut self`.
// NOTE(review): this also presumes that the read-only functions of the C API
// can be called concurrently on the same block -- confirm against the C API.
unsafe impl Sync for TensorBlock {}
impl Drop for TensorBlock {
    // `drop` has no way to report a failure to the caller, so whatever
    // `mts_block_free` returns is intentionally discarded here.
    #[allow(unused_must_use)]
    fn drop(&mut self) {
        let owned_block = self.as_mut_ptr();
        // SAFETY: `TensorBlock::from_raw` requires the pointer to be non-null
        // and to refer to an owned block, so releasing it here is expected to
        // happen exactly once per block.
        unsafe {
            crate::c_api::mts_block_free(owned_block);
        }
    }
}
impl TensorBlock {
    /// Create a new `TensorBlock` from a raw pointer.
    ///
    /// This function takes ownership of the pointer, and will call
    /// `mts_block_free` on it when the `TensorBlock` goes out of scope.
    ///
    /// # Panics
    ///
    /// This function panics if `ptr` is NULL.
    ///
    /// # Safety
    ///
    /// The pointer must be non-null and point to an owned block, not a
    /// reference to a block from inside a [`TensorMap`](crate::TensorMap).
    pub(crate) unsafe fn from_raw(ptr: *mut mts_block_t) -> TensorBlock {
        assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");

        TensorBlock { ptr }
    }

    /// Get the underlying raw pointer
    pub(super) fn as_ptr(&self) -> *const mts_block_t {
        self.ptr
    }

    /// Get the underlying (mutable) raw pointer
    pub(super) fn as_mut_ptr(&mut self) -> *mut mts_block_t {
        self.ptr
    }

    /// Get a non mutable reference to this block
    #[inline]
    pub fn as_ref(&self) -> TensorBlockRef<'_> {
        // SAFETY: the pointer was checked to be non-null in `from_raw`, and
        // the returned reference borrows `self`, so the block can not be
        // dropped while the reference is in use.
        unsafe { TensorBlockRef::from_raw(self.as_ptr()) }
    }

    /// Get a mutable reference to this block
    #[inline]
    pub fn as_ref_mut(&mut self) -> TensorBlockRefMut<'_> {
        // SAFETY: same as in `as_ref`; in addition the exclusive borrow of
        // `self` prevents handing out another reference at the same time.
        unsafe { TensorBlockRefMut::from_raw(self.as_mut_ptr()) }
    }

    /// Get the array for the values in this block
    #[inline]
    pub fn values(&self) -> ArrayRef<'_> {
        self.as_ref().values()
    }

    /// Get the samples for this block
    #[inline]
    pub fn samples(&self) -> Labels {
        self.as_ref().samples()
    }

    /// Get the components for this block
    #[inline]
    pub fn components(&self) -> Vec<Labels> {
        self.as_ref().components()
    }

    /// Get the properties for this block
    #[inline]
    pub fn properties(&self) -> Labels {
        self.as_ref().properties()
    }

    /// Create a new [`TensorBlock`] containing the given data, described by the
    /// `samples`, `components`, and `properties` labels. The block is
    /// initialized without any gradients.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `check_ptr` for the pointer returned by
    /// `mts_block` (presumably when that pointer is NULL because the C API
    /// rejected the data or labels -- confirm against `crate::errors`).
    #[inline]
    pub fn new(
        data: impl Array,
        samples: &Labels,
        components: &[Labels],
        properties: &Labels
    ) -> Result<TensorBlock, Error> {
        // only a pointer to this vector is passed to the C API, so it has to
        // stay alive until `mts_block` returns
        let c_components: Vec<_> = components.iter()
            .map(|component| component.as_mts_labels_t())
            .collect();

        let ptr = unsafe {
            crate::c_api::mts_block(
                (Box::new(data) as Box<dyn Array>).into(),
                samples.as_mts_labels_t(),
                c_components.as_ptr(),
                c_components.len(),
                properties.as_mts_labels_t(),
            )
        };

        crate::errors::check_ptr(ptr)?;

        // SAFETY: `ptr` was just created by `mts_block` and is not shared with
        // anything else, so this `TensorBlock` is its owner.
        Ok(unsafe { TensorBlock::from_raw(ptr) })
    }

    /// Add a gradient with respect to `parameter` to this block.
    ///
    /// The property of the gradient should match the ones of this block. The
    /// components of the gradients must contain at least the same entries as
    /// the value components, and can prepend other components.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `check_status` for the status returned
    /// by `mts_block_add_gradient`.
    // `gradient` is taken by value on purpose: its ownership is transferred to
    // `self` through the C API.
    #[allow(clippy::needless_pass_by_value)]
    #[inline]
    pub fn add_gradient(
        &mut self,
        parameter: &str,
        mut gradient: TensorBlock
    ) -> Result<(), Error> {
        // NUL-terminated copy of `parameter` for the C API, allocated once
        // with room for the terminator.
        // NOTE(review): a NUL byte inside `parameter` would presumably cut the
        // name short on the C side -- confirm whether this should be rejected.
        let mut c_parameter = Vec::with_capacity(parameter.len() + 1);
        c_parameter.extend_from_slice(parameter.as_bytes());
        c_parameter.push(b'\0');

        let gradient_ptr = gradient.as_ref_mut().as_mut_ptr();
        // we give ownership of the gradient to `self`, so we should not free
        // it again from here
        std::mem::forget(gradient);

        unsafe {
            check_status(crate::c_api::mts_block_add_gradient(
                self.as_ref_mut().as_mut_ptr(),
                c_parameter.as_ptr().cast(),
                gradient_ptr,
            ))?;
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use crate::c_api::mts_block_t;
    use super::*;

    /// `TensorBlock` must have the same size and alignment as the raw pointer
    /// it wraps.
    #[test]
    fn check_repr() {
        use std::mem::{align_of, size_of};

        // we are casting `*mut TensorBlock` to `*mut mts_block_t` in
        // TensorMap::new, which is only legal as long as a TensorBlock is
        // nothing more than a `*mut mts_block_t`
        type RawBlock = *mut mts_block_t;

        assert_eq!(size_of::<TensorBlock>(), size_of::<RawBlock>());
        assert_eq!(align_of::<TensorBlock>(), align_of::<RawBlock>());
    }
}