1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
use std::ffi::{CStr, CString};
use std::iter::FusedIterator;

use crate::c_api::{mts_block_t, mts_array_t, mts_labels_t};
use crate::c_api::MTS_INVALID_PARAMETER_ERROR;

use crate::errors::check_status;
use crate::{ArrayRef, Labels, Error};

use super::{TensorBlock, LazyMetadata};

/// Reference to a [`TensorBlock`]
///
/// This is a `Copy` wrapper around a raw `*const mts_block_t`, carrying a
/// lifetime `'a` which should be constrained by the owner of the block.
#[derive(Debug, Clone, Copy)]
pub struct TensorBlockRef<'a> {
    /// Raw pointer to the block; `from_raw` asserts that it is not NULL
    ptr: *const mts_block_t,
    /// Zero-sized marker attaching the lifetime `'a` to this reference
    marker: std::marker::PhantomData<&'a mts_block_t>,
}

// SAFETY: `TensorBlockRef` does not implement `Drop`, so sending it to another
// thread is fine
unsafe impl Send for TensorBlockRef<'_> {}
// SAFETY: `TensorBlockRef` has no internal mutability, so sharing it between
// threads is fine
unsafe impl Sync for TensorBlockRef<'_> {}

/// All the basic data in a `TensorBlockRef` as a struct with separate fields.
///
/// This can be useful when you need to borrow different fields on this struct
/// separately. They are separate in the underlying metatensor-core, but since we
/// go through the C API to access them, we need to re-expose them as separate
/// fields for the rust compiler to be able to understand that.
///
/// The metadata is initialized lazily on first access, to not pay the cost of
/// allocation/reference count increase if some metadata is not used.
#[derive(Debug)]
pub struct TensorBlockData<'a> {
    /// Array containing the values of the block
    pub values: ArrayRef<'a>,
    /// Samples labels of the block, loaded lazily
    pub samples: LazyMetadata<Labels>,
    /// Components labels of the block (one `Labels` per component dimension),
    /// loaded lazily
    pub components: LazyMetadata<Vec<Labels>>,
    /// Properties labels of the block, loaded lazily
    pub properties: LazyMetadata<Labels>,
}

impl<'a> TensorBlockRef<'a> {
    /// Wrap the given raw `mts_block_t` pointer in a new `TensorBlockRef`.
    ///
    /// # Safety
    ///
    /// This is a **VERY** unsafe function: the lifetime `'a` is created out of
    /// thin air. The caller must make sure this lifetime is actually
    /// constrained by the lifetime of the owner of this `mts_block_t`.
    ///
    /// # Panics
    ///
    /// If `ptr` is NULL.
    pub(crate) unsafe fn from_raw(ptr: *const mts_block_t) -> TensorBlockRef<'a> {
        assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");

        let marker = std::marker::PhantomData;
        TensorBlockRef { ptr, marker }
    }

    /// Access the raw pointer wrapped by this reference
    pub(super) fn as_ptr(&self) -> *const mts_block_t {
        let raw = self.ptr;
        raw
    }
}

/// Get a gradient from this block
fn block_gradient(block: *const mts_block_t, parameter: &CStr) -> Option<*const mts_block_t> {
    let mut gradient_block = std::ptr::null_mut();
    let status = unsafe { crate::c_api::mts_block_gradient(
            // the cast to mut pointer is fine since we are only returning a
            // non-mut mts_block_t below
            block.cast_mut(),
            parameter.as_ptr(),
            &mut gradient_block
        )
    };

    match crate::errors::check_status(status) {
        Ok(()) => Some(gradient_block.cast_const()),
        Err(error) => {
            if error.code == Some(MTS_INVALID_PARAMETER_ERROR) {
                // there is no array for this gradient
                None
            } else {
                panic!("failed to get the gradient from a block: {:?}", error)
            }
        }
    }
}

/// Load the samples of the block behind `ptr`. This is the loader given to
/// `LazyMetadata::new` in `TensorBlockRef::data`.
pub(super) fn get_samples(ptr: *const mts_block_t) -> Labels {
    // NOTE(review): `ptr` is only checked for NULL (inside `from_raw`); callers
    // are presumably expected to pass a pointer to a live block.
    let block = unsafe { TensorBlockRef::from_raw(ptr) };
    block.samples()
}

/// Load the components of the block behind `ptr`. This is the loader given to
/// `LazyMetadata::new` in `TensorBlockRef::data`.
pub(super) fn get_components(ptr: *const mts_block_t) -> Vec<Labels> {
    // NOTE(review): `ptr` is only checked for NULL (inside `from_raw`); callers
    // are presumably expected to pass a pointer to a live block.
    let block = unsafe { TensorBlockRef::from_raw(ptr) };
    block.components()
}

/// Load the properties of the block behind `ptr`. This is the loader given to
/// `LazyMetadata::new` in `TensorBlockRef::data`.
pub(super) fn get_properties(ptr: *const mts_block_t) -> Labels {
    // NOTE(review): `ptr` is only checked for NULL (inside `from_raw`); callers
    // are presumably expected to pass a pointer to a live block.
    let block = unsafe { TensorBlockRef::from_raw(ptr) };
    block.properties()
}

impl<'a> TensorBlockRef<'a> {
    /// Get all the data and metadata inside this `TensorBlockRef` as a
    /// struct with separate fields, to allow borrowing them separately.
    ///
    /// The values are fetched immediately, while the samples, components and
    /// properties are wrapped in [`LazyMetadata`] and only loaded when first
    /// accessed.
    ///
    /// The returned data is bound to the lifetime `'a` of the referenced
    /// block, not to the (possibly shorter) borrow of `self`.
    ///
    /// # Panics
    ///
    /// If the values array can not be retrieved, see [`TensorBlockRef::values`].
    #[inline]
    pub fn data(&self) -> TensorBlockData<'a> {
        // the lazy metadata only needs the raw pointer, not a borrow of `self`
        let ptr = self.as_ptr();

        TensorBlockData {
            values: self.values(),
            samples: LazyMetadata::new(get_samples, ptr),
            components: LazyMetadata::new(get_components, ptr),
            properties: LazyMetadata::new(get_properties, ptr),
        }
    }

    /// Get the array for the values in this block
    ///
    /// # Panics
    ///
    /// If the C API fails to give access to the array.
    #[inline]
    pub fn values(&self) -> ArrayRef<'a> {
        let mut raw_array = mts_array_t::null();
        let status = unsafe {
            crate::c_api::mts_block_data(self.as_ptr().cast_mut(), &mut raw_array)
        };
        crate::errors::check_status(status).expect("failed to get the array for a block");

        // SAFETY: the `ArrayRef` can use the lifetime `'a` instead of the
        // lifetime of the borrow of `self` (which allows getting multiple
        // references to the block simultaneously): since we are already
        // holding a const reference to the block itself, there is no way to
        // also get a mutable reference to the block at the same time.
        unsafe { ArrayRef::from_raw(raw_array) }
    }

    /// Get the labels stored along the given `dimension` of this block
    ///
    /// # Panics
    ///
    /// If the C API fails to provide the labels.
    #[inline]
    fn labels(&self, dimension: usize) -> Labels {
        let mut raw_labels = mts_labels_t::null();
        let status = unsafe {
            crate::c_api::mts_block_labels(self.as_ptr(), dimension, &mut raw_labels)
        };
        check_status(status).expect("failed to get labels");

        unsafe { Labels::from_raw(raw_labels) }
    }

    /// Get the samples for this block, i.e. the labels along dimension 0
    #[inline]
    pub fn samples(&self) -> Labels {
        const SAMPLES_DIMENSION: usize = 0;
        self.labels(SAMPLES_DIMENSION)
    }

    /// Get the components for this block: one set of labels for every
    /// dimension of the values strictly between the first and the last one.
    ///
    /// # Panics
    ///
    /// If the shape of the values or any of the labels can not be retrieved.
    #[inline]
    pub fn components(&self) -> Vec<Labels> {
        let array = self.values();
        let shape = array.as_raw().shape().expect("failed to get the data shape");

        // dimension 0 is used for the samples and the last one for the properties
        let last_dimension = shape.len() - 1;
        (1..last_dimension)
            .map(|dimension| self.labels(dimension))
            .collect()
    }

    /// Get the properties for this block, i.e. the labels along the last
    /// dimension of the values.
    ///
    /// # Panics
    ///
    /// If the shape of the values or the labels can not be retrieved.
    #[inline]
    pub fn properties(&self) -> Labels {
        let array = self.values();
        let dimensions_count = array.as_raw()
            .shape()
            .expect("failed to get the data shape")
            .len();

        self.labels(dimensions_count - 1)
    }

    /// Get the full list of gradients in this block, as the names of the
    /// corresponding parameters.
    ///
    /// # Panics
    ///
    /// If the C API fails to provide the list, or if one of the parameters is
    /// not valid UTF-8.
    #[inline]
    pub fn gradient_list(&self) -> Vec<&'a str> {
        let mut names_ptr = std::ptr::null();
        let mut names_count = 0;
        let status = unsafe {
            crate::c_api::mts_block_gradients_list(
                self.as_ptr(),
                &mut names_ptr,
                &mut names_count,
            )
        };
        check_status(status).expect("failed to get gradient list");

        if names_count == 0 {
            return Vec::new();
        }

        assert!(!names_ptr.is_null());

        // SAFETY: we can return strings with the `'a` lifetime (instead of
        // the lifetime of the borrow of `self`), because there is no way to
        // also get a mutable reference to the gradient parameters at the same
        // time.
        let names = unsafe { std::slice::from_raw_parts(names_ptr, names_count) };
        names.iter()
            .map(|&name| {
                let name = unsafe { CStr::from_ptr(name) };
                name.to_str().unwrap()
            })
            .collect()
    }

    /// Get the data and metadata for the gradient with respect to the given
    /// parameter in this block, if it exists.
    ///
    /// # Panics
    ///
    /// If `parameter` contains a NUL byte, or if the C API fails with an error
    /// other than `MTS_INVALID_PARAMETER_ERROR`.
    #[inline]
    pub fn gradient(&self, parameter: &str) -> Option<TensorBlockRef<'a>> {
        let c_parameter = CString::new(parameter).expect("invalid C string");
        let gradient = block_gradient(self.as_ptr(), &c_parameter)?;

        // SAFETY: we can return a TensorBlockRef with lifetime `'a` (instead
        // of the lifetime of the borrow of `self`) for the same reasons as in
        // the `values` function: the lifetime of the gradient block is the
        // same as the lifetime of self, both are constrained to the root
        // TensorMap/TensorBlock
        Some(unsafe { TensorBlockRef::from_raw(gradient) })
    }

    /// Make a copy of this block, including all the data and metadata
    /// contained inside.
    ///
    /// # Errors
    ///
    /// This can fail if the external data held inside an `mts_array_t` can not
    /// be cloned.
    #[inline]
    pub fn try_clone(&self) -> Result<TensorBlock, Error> {
        let copy = unsafe { crate::c_api::mts_block_copy(self.as_ptr()) };
        // report the failure before handing the pointer to `TensorBlock`
        crate::errors::check_ptr(copy)?;

        let block = unsafe { TensorBlock::from_raw(copy) };
        Ok(block)
    }

    /// Get an iterator over parameter/[`TensorBlockRef`] pairs for all gradients in
    /// this block
    #[inline]
    pub fn gradients(&self) -> GradientsIter<'_> {
        GradientsIter {
            parameters: self.gradient_list().into_iter(),
            block: self.as_ptr(),
        }
    }
}

/// Iterator over parameter/[`TensorBlockRef`] pairs for all gradients in a
/// [`TensorBlockRef`]
///
/// This is created by [`TensorBlockRef::gradients`].
#[derive(Debug)]
pub struct GradientsIter<'a> {
    /// Names of the gradient parameters which have not been yielded yet
    parameters: std::vec::IntoIter<&'a str>,
    /// Raw pointer to the block in which the gradients are looked up
    block: *const mts_block_t,
}

impl<'a> Iterator for GradientsIter<'a> {
    type Item = (&'a str, TensorBlockRef<'a>);

    /// Yield the next parameter together with the corresponding gradient
    /// block, looking the gradient up in the parent block.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let parameter = self.parameters.next()?;

        let c_parameter = CString::new(parameter).expect("invalid C string");
        let gradient = block_gradient(self.block, &c_parameter).expect("missing gradient");

        // SAFETY: the lifetime of the block is the same as the lifetime of
        // the GradientsIter, both are constrained to the root
        // TensorMap/TensorBlock
        let gradient = unsafe { TensorBlockRef::from_raw(gradient) };

        Some((parameter, gradient))
    }

    /// The number of remaining items is exactly the number of remaining
    /// parameters.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.len();
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for GradientsIter<'_> {
    /// Number of gradients left to iterate over, taken from the remaining
    /// parameters.
    #[inline]
    fn len(&self) -> usize {
        let remaining = &self.parameters;
        remaining.len()
    }
}

// `next` returns `None` as soon as the inner `std::vec::IntoIter` (itself a
// `FusedIterator`) is exhausted
impl FusedIterator for GradientsIter<'_> {}

#[cfg(test)]
mod tests {
    // TODO: add tests checking the gradient code (`TensorBlockRef::gradient`)
    // and the gradient iterator code (`TensorBlockRef::gradients` /
    // `GradientsIter`)
}