1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
use crate::c_api::mts_block_t;
use crate::errors::check_status;
use crate::{Array, ArrayRef, Labels, Error};

use super::{TensorBlockRef, TensorBlockRefMut};

/// A single block, containing both values & optionally gradients of these
/// values w.r.t. any relevant quantity.
///
/// This is an owning handle around a raw `mts_block_t` pointer: the pointer
/// is released with `mts_block_free` when the `TensorBlock` is dropped.
// `repr(transparent)` makes `TensorBlock` share its layout with
// `*mut mts_block_t`. Code elsewhere in the crate (e.g. `TensorMap::new`)
// casts `*mut TensorBlock` to `*mut mts_block_t` and relies on this.
#[derive(Debug)]
#[repr(transparent)]
pub struct TensorBlock {
    // owned pointer to the underlying block; expected to be non-NULL
    // (`TensorBlock::from_raw` asserts this)
    ptr: *mut mts_block_t,
}

// SAFETY: `TensorBlock` has no internal mutability, so handing out
// `&TensorBlock` to several threads at once is fine.
unsafe impl Sync for TensorBlock {}
// SAFETY: the owned block may be released (through `Drop`) from any thread.
unsafe impl Send for TensorBlock {}

impl Drop for TensorBlock {
    /// Release the owned block through the C API.
    ///
    /// Whatever `mts_block_free` returns is deliberately discarded: `drop`
    /// has no way to report a failure to the caller.
    fn drop(&mut self) {
        let raw = self.as_mut_ptr();
        // SAFETY: `raw` is the owned pointer this handle was built from (see
        // `TensorBlock::from_raw`), and `drop` runs at most once per value.
        let _ = unsafe { crate::c_api::mts_block_free(raw) };
    }
}

impl TensorBlock {
    /// Create a new `TensorBlock` from a raw pointer.
    ///
    /// This function takes ownership of the pointer, and will call
    /// `mts_block_free` on it when the `TensorBlock` goes out of scope.
    ///
    /// # Safety
    ///
    /// The pointer must be non-null and point to an owned block, not a
    /// reference to a block from inside a [`TensorMap`](crate::TensorMap).
    ///
    /// # Panics
    ///
    /// Panics if `ptr` is NULL.
    pub(crate) unsafe fn from_raw(ptr: *mut mts_block_t) -> TensorBlock {
        assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");

        TensorBlock { ptr }
    }

    /// Get the underlying raw pointer
    pub(super) fn as_ptr(&self) -> *const mts_block_t {
        self.ptr
    }

    /// Get the underlying (mutable) raw pointer
    pub(super) fn as_mut_ptr(&mut self) -> *mut mts_block_t {
        self.ptr
    }

    /// Get a non mutable reference to this block
    #[inline]
    pub fn as_ref(&self) -> TensorBlockRef<'_> {
        // SAFETY: `self.ptr` is the owned, non-NULL block pointer of `self`,
        // and the returned reference borrows `self`, so it cannot outlive it.
        unsafe {
            TensorBlockRef::from_raw(self.as_ptr())
        }
    }

    /// Get a mutable reference to this block
    #[inline]
    pub fn as_ref_mut(&mut self) -> TensorBlockRefMut<'_> {
        // SAFETY: same as `as_ref`; the exclusive borrow of `self` guarantees
        // no other reference to the block exists while this one is alive.
        unsafe {
            TensorBlockRefMut::from_raw(self.as_mut_ptr())
        }
    }

    /// Get the array for the values in this block
    #[inline]
    pub fn values(&self) -> ArrayRef<'_> {
        return self.as_ref().values();
    }

    /// Get the samples for this block
    #[inline]
    pub fn samples(&self) -> Labels {
        return self.as_ref().samples();
    }

    /// Get the components for this block
    #[inline]
    pub fn components(&self) -> Vec<Labels> {
        return self.as_ref().components();
    }

    /// Get the properties for this block
    #[inline]
    pub fn properties(&self) -> Labels {
        return self.as_ref().properties();
    }

    /// Create a new [`TensorBlock`] containing the given data, described by the
    /// `samples`, `components`, and `properties` labels. The block is
    /// initialized without any gradients.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `check_ptr` when the underlying
    /// `mts_block` call fails to create the block.
    #[inline]
    pub fn new(
        data: impl Array,
        samples: &Labels,
        components: &[Labels],
        properties: &Labels
    ) -> Result<TensorBlock, Error> {
        // C-compatible views of the component labels; the vector must stay
        // alive for the whole `mts_block` call below
        let c_components = components.iter()
            .map(|component| component.as_mts_labels_t())
            .collect::<Vec<_>>();

        // SAFETY: the pointer/length pair passed for the components describes
        // `c_components` exactly, and every labels argument is borrowed for
        // the duration of the call.
        let ptr = unsafe {
            crate::c_api::mts_block(
                (Box::new(data) as Box<dyn Array>).into(),
                samples.as_mts_labels_t(),
                c_components.as_ptr(),
                c_components.len(),
                properties.as_mts_labels_t(),
            )
        };

        crate::errors::check_ptr(ptr)?;

        // SAFETY: `check_ptr` succeeded, so `ptr` is a freshly created block
        // that is now owned by the returned `TensorBlock`.
        return Ok(unsafe { TensorBlock::from_raw(ptr) });
    }

    /// Add a gradient with respect to `parameter` to this block.
    ///
    /// The properties of the gradient should match the ones of this block. The
    /// components of the gradients must contain at least the same entries as
    /// the value components, and can prepend other components.
    ///
    /// `gradient` is handed over to the C API and is never freed from the Rust
    /// side, whether or not the call succeeds.
    ///
    /// # Errors
    ///
    /// Returns an error if `mts_block_add_gradient` reports a failure status.
    #[allow(clippy::needless_pass_by_value)]
    #[inline]
    pub fn add_gradient(
        &mut self,
        parameter: &str,
        mut gradient: TensorBlock
    ) -> Result<(), Error> {
        // NUL-terminated copy of `parameter` for the C API.
        // NOTE(review): an interior NUL byte in `parameter` silently truncates
        // the name seen by the C side; consider rejecting it explicitly.
        let mut c_parameter = Vec::with_capacity(parameter.len() + 1);
        c_parameter.extend_from_slice(parameter.as_bytes());
        c_parameter.push(b'\0');

        let gradient_ptr = gradient.as_mut_ptr();
        // we give ownership of the gradient to `self`, so we must not free
        // it again from here
        std::mem::forget(gradient);

        // SAFETY: `self` owns a valid block, `c_parameter` is NUL-terminated
        // and outlives the call, and `gradient_ptr` is an owned block whose
        // Rust-side destructor was suppressed just above.
        unsafe {
            check_status(crate::c_api::mts_block_add_gradient(
                self.as_mut_ptr(),
                c_parameter.as_ptr().cast(),
                gradient_ptr,
            ))?;
        }

        return Ok(());
    }
}


#[cfg(test)]
mod tests {
    use std::mem::{align_of, size_of};

    use super::*;
    use crate::c_api::mts_block_t;

    /// `TensorMap::new` reinterprets `*mut TensorBlock` as `*mut mts_block_t`;
    /// that cast is only legal while `TensorBlock` has exactly the size and
    /// alignment of the raw pointer it wraps.
    #[test]
    fn check_repr() {
        type RawBlock = *mut mts_block_t;

        assert_eq!(size_of::<TensorBlock>(), size_of::<RawBlock>());
        assert_eq!(align_of::<TensorBlock>(), align_of::<RawBlock>());
    }
}