metatensor/
tensor.rs

1use std::ffi::CString;
2use std::iter::FusedIterator;
3
4use crate::block::TensorBlockRefMut;
5use crate::c_api::{mts_tensormap_t, mts_labels_t};
6
7use crate::errors::{check_status, check_ptr};
8use crate::{Error, TensorBlock, TensorBlockRef, Labels, LabelValue};
9
10/// [`TensorMap`] is the main user-facing struct of this library, and can
11/// store any kind of data used in atomistic machine learning.
12///
13/// A tensor map contains a list of `TensorBlock`s, each one associated with a
14/// key in the form of a single `Labels` entry.
15///
16/// It provides functions to merge blocks together by moving some of these keys
17/// to the samples or properties labels of the blocks, transforming the sparse
18/// representation of the data to a dense one.
19pub struct TensorMap {
20    pub(crate) ptr: *mut mts_tensormap_t,
21    /// cache for the keys labels
22    keys: Labels,
23}
24
25// SAFETY: Send is fine since we can free a TensorMap from any thread
26unsafe impl Send for TensorMap {}
27// SAFETY: Sync is fine since there is no internal mutability in TensorMap
28unsafe impl Sync for TensorMap {}
29
30impl std::fmt::Debug for TensorMap {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        use crate::labels::pretty_print_labels;
33        writeln!(f, "Tensormap @ {:p} {{", self.ptr)?;
34
35        write!(f, "    keys: ")?;
36        pretty_print_labels(self.keys(), "    ", f)?;
37        writeln!(f, "}}")
38    }
39}
40
41impl std::ops::Drop for TensorMap {
42    #[allow(unused_must_use)]
43    fn drop(&mut self) {
44        unsafe {
45            crate::c_api::mts_tensormap_free(self.ptr);
46        }
47    }
48}
49
50impl TensorMap {
51    /// Create a new `TensorMap` with the given keys and blocks.
52    ///
53    /// The number of keys must match the number of blocks, and all the blocks
54    /// must contain the same kind of data (same labels names, same gradients
55    /// defined on all blocks).
56    #[allow(clippy::needless_pass_by_value)]
57    #[inline]
58    pub fn new(keys: Labels, mut blocks: Vec<TensorBlock>) -> Result<TensorMap, Error> {
59        let ptr = unsafe {
60            crate::c_api::mts_tensormap(
61                keys.as_mts_labels_t(),
62                // this cast is fine because TensorBlock is `repr(transparent)`
63                // to a `*mut mts_block_t` (through `TensorBlockRefMut`, and
64                // `TensorBlockRef`).
65                blocks.as_mut_ptr().cast::<*mut crate::c_api::mts_block_t>(),
66                blocks.len()
67            )
68        };
69
70        for block in blocks {
71            // we give ownership of the blocks to the new tensormap, so we
72            // should not free them again from Rust
73            std::mem::forget(block);
74        }
75
76        check_ptr(ptr)?;
77
78        return Ok(unsafe { TensorMap::from_raw(ptr) });
79    }
80
81    /// Create a new `TensorMap` from a raw pointer.
82    ///
83    /// This function takes ownership of the pointer, and will call
84    /// `mts_tensormap_free` on it when the `TensorMap` goes out of scope.
85    ///
86    /// # Safety
87    ///
88    /// The pointer must be non-null and created by `mts_tensormap` or
89    /// `TensorMap::into_raw`.
90    pub unsafe fn from_raw(ptr: *mut mts_tensormap_t) -> TensorMap {
91        assert!(!ptr.is_null());
92
93        let mut keys = mts_labels_t::null();
94        check_status(crate::c_api::mts_tensormap_keys(
95            ptr,
96            &mut keys
97        )).expect("failed to get the keys");
98
99        let keys = Labels::from_raw(keys);
100
101        return TensorMap {
102            ptr,
103            keys
104        };
105    }
106
107    /// Extract the underlying raw pointer.
108    ///
109    /// The pointer should be passed back to `TensorMap::from_raw` or
110    /// `mts_tensormap_free` to release the memory corresponding to this
111    /// `TensorMap`.
112    pub fn into_raw(mut map: TensorMap) -> *mut mts_tensormap_t {
113        let ptr = map.ptr;
114        map.ptr = std::ptr::null_mut();
115        return ptr;
116    }
117
118    /// Clone this `TensorMap`, cloning all the data and metadata contained inside.
119    ///
120    /// This can fail if the external data held inside an `mts_array_t` can not
121    /// be cloned.
122    #[inline]
123    pub fn try_clone(&self) -> Result<TensorMap, Error> {
124        let ptr = unsafe {
125            crate::c_api::mts_tensormap_copy(self.ptr)
126        };
127        crate::errors::check_ptr(ptr)?;
128
129        return Ok(unsafe { TensorMap::from_raw(ptr) });
130    }
131
132    /// Load a `TensorMap` from the file at `path`
133    ///
134    /// This is a convenience function calling [`crate::io::load`]
135    pub fn load(path: impl AsRef<std::path::Path>) -> Result<TensorMap, Error> {
136        return crate::io::load(path);
137    }
138
139    /// Load a `TensorMap` from an in-memory buffer
140    ///
141    /// This is a convenience function calling [`crate::io::load_buffer`]
142    pub fn load_buffer(buffer: &[u8]) -> Result<TensorMap, Error> {
143        return crate::io::load_buffer(buffer);
144    }
145
146    /// Save the given tensor to the file at `path`
147    ///
148    /// This is a convenience function calling [`crate::io::save`]
149    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
150        return crate::io::save(path, self);
151    }
152
153    /// Save the given tensor to an in-memory buffer
154    ///
155    /// This is a convenience function calling [`crate::io::save_buffer`]
156    pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
157        return crate::io::save_buffer(self, buffer);
158    }
159
160    /// Get the keys defined in this `TensorMap`
161    #[inline]
162    pub fn keys(&self) -> &Labels {
163        &self.keys
164    }
165
166    /// Get a reference to the block at the given `index` in this `TensorMap`
167    ///
168    /// # Panics
169    ///
170    /// If the index is out of bounds
171    #[inline]
172    pub fn block_by_id(&self, index: usize) -> TensorBlockRef<'_> {
173
174        let mut block = std::ptr::null_mut();
175        unsafe {
176            check_status(crate::c_api::mts_tensormap_block_by_id(
177                self.ptr,
178                &mut block,
179                index,
180            )).expect("failed to get a block");
181        }
182
183        return unsafe { TensorBlockRef::from_raw(block) }
184    }
185
186    /// Get a mutable reference to the block at the given `index` in this `TensorMap`
187    ///
188    /// # Panics
189    ///
190    /// If the index is out of bounds
191    #[inline]
192    pub fn block_mut_by_id(&mut self, index: usize) -> TensorBlockRefMut<'_> {
193        return unsafe { TensorMap::raw_block_mut_by_id(self.ptr, index) };
194    }
195
196    /// Implementation of `block_mut_by_id` which does not borrow the
197    /// `mts_tensormap_t` pointer.
198    ///
199    /// This is used to provide references to multiple blocks at the same time
200    /// in the iterators.
201    ///
202    /// # Safety
203    ///
204    /// This should be called with a valid `mts_tensormap_t`, and the lifetime
205    /// `'a` should be properly constrained to the lifetime of the owner of
206    /// `ptr`.
207    #[inline]
208    unsafe fn raw_block_mut_by_id<'a>(ptr: *mut mts_tensormap_t, index: usize) -> TensorBlockRefMut<'a> {
209        let mut block = std::ptr::null_mut();
210
211        check_status(crate::c_api::mts_tensormap_block_by_id(
212            ptr,
213            &mut block,
214            index,
215        )).expect("failed to get a block");
216
217        return TensorBlockRefMut::from_raw(block);
218    }
219
220    /// Get the index of blocks matching the given selection.
221    ///
222    /// The selection must contains a single entry, defining the requested key
223    /// or keys. If the selection contains only a subset of the dimensions of the
224    /// keys, there can be multiple matching blocks.
225    #[inline]
226    pub fn blocks_matching(&self, selection: &Labels) -> Result<Vec<usize>, Error> {
227        let mut indexes = vec![0; self.keys().count()];
228        let mut matching = indexes.len();
229        unsafe {
230            check_status(crate::c_api::mts_tensormap_blocks_matching(
231                self.ptr,
232                indexes.as_mut_ptr(),
233                &mut matching,
234                selection.as_mts_labels_t(),
235            ))?;
236        }
237        indexes.resize(matching, 0);
238
239        return Ok(indexes);
240    }
241
242    /// Get the index of the single block matching the given selection.
243    ///
244    /// This function is similar to [`TensorMap::blocks_matching`], but also
245    /// returns an error if more than one block matches the selection.
246    #[inline]
247    pub fn block_matching(&self, selection: &Labels) -> Result<usize, Error> {
248        let matching = self.blocks_matching(selection)?;
249        if matching.len() != 1 {
250            let selection_str = selection.names()
251                .iter().zip(&selection[0])
252                .map(|(name, value)| format!("{} = {}", name, value))
253                .collect::<Vec<_>>()
254                .join(", ");
255
256
257            if matching.is_empty() {
258                return Err(Error {
259                    code: None,
260                    message: format!(
261                        "no blocks matched the selection ({})",
262                        selection_str
263                    ),
264                });
265            } else {
266                return Err(Error {
267                    code: None,
268                    message: format!(
269                        "{} blocks matched the selection ({}), expected only one",
270                        matching.len(),
271                        selection_str
272                    ),
273                });
274            }
275        }
276
277        return Ok(matching[0])
278    }
279
280    /// Get a reference to the block matching the given selection.
281    ///
282    /// This function uses [`TensorMap::blocks_matching`] under the hood to find
283    /// the matching block.
284    #[inline]
285    pub fn block(&self, selection: &Labels) -> Result<TensorBlockRef<'_>, Error> {
286        let id = self.block_matching(selection)?;
287        return Ok(self.block_by_id(id));
288    }
289
290    /// Get a reference to every blocks in this `TensorMap`
291    #[inline]
292    pub fn blocks(&self) -> Vec<TensorBlockRef<'_>> {
293        let mut blocks = Vec::new();
294        for i in 0..self.keys().count() {
295            blocks.push(self.block_by_id(i));
296        }
297        return blocks;
298    }
299
300    /// Get a mutable reference to every blocks in this `TensorMap`
301    #[inline]
302    pub fn blocks_mut(&mut self) -> Vec<TensorBlockRefMut<'_>> {
303        let mut blocks = Vec::new();
304        for i in 0..self.keys().count() {
305            blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
306        }
307        return blocks;
308    }
309
310    /// Merge blocks with the same value for selected keys dimensions along the
311    /// samples axis.
312    ///
313    /// The dimensions (names) of `keys_to_move` will be moved from the keys to
314    /// the sample labels, and blocks with the same remaining keys dimensions
315    /// will be merged together along the sample axis.
316    ///
317    /// `keys_to_move` must be empty (`keys_to_move.count() == 0`), and the new
318    /// sample labels will contain entries corresponding to the merged blocks'
319    /// keys.
320    ///
321    /// The new sample labels will contains all of the merged blocks sample
322    /// labels. The order of the samples is controlled by `sort_samples`. If
323    /// `sort_samples` is true, samples are re-ordered to keep them
324    /// lexicographically sorted. Otherwise they are kept in the order in which
325    /// they appear in the blocks.
326    ///
327    /// This function is only implemented if all merged block have the same
328    /// property labels.
329    #[inline]
330    pub fn keys_to_samples(&self, keys_to_move: &Labels, sort_samples: bool) -> Result<TensorMap, Error> {
331        let ptr = unsafe {
332            crate::c_api::mts_tensormap_keys_to_samples(
333                self.ptr,
334                keys_to_move.as_mts_labels_t(),
335                sort_samples,
336            )
337        };
338
339        check_ptr(ptr)?;
340        return Ok(unsafe { TensorMap::from_raw(ptr) });
341    }
342
343    /// Merge blocks with the same value for selected keys dimensions along the
344    /// property axis.
345    ///
346    /// The dimensions (names) of `keys_to_move` will be moved from the keys to
347    /// the property labels, and blocks with the same remaining keys dimensions
348    /// will be merged together along the property axis.
349    ///
350    /// If `keys_to_move` does not contains any entries (`keys_to_move.count()
351    /// == 0`), then the new property labels will contain entries corresponding
352    /// to the merged blocks only. For example, merging a block with key `a=0`
353    /// and properties `p=1, 2` with a block with key `a=2` and properties `p=1,
354    /// 3` will produce a block with properties `a, p = (0, 1), (0, 2), (2, 1),
355    /// (2, 3)`.
356    ///
357    /// If `keys_to_move` contains entries, then the property labels must be the
358    /// same for all the merged blocks. In that case, the merged property labels
359    /// will contains each of the entries of `keys_to_move` and then the current
360    /// property labels. For example, using `a=2, 3` in `keys_to_move`, and
361    /// blocks with properties `p=1, 2` will result in `a, p = (2, 1), (2, 2),
362    /// (3, 1), (3, 2)`.
363    ///
364    /// The new sample labels will contains all of the merged blocks sample
365    /// labels. The order of the samples is controlled by `sort_samples`. If
366    /// `sort_samples` is true, samples are re-ordered to keep them
367    /// lexicographically sorted. Otherwise they are kept in the order in which
368    /// they appear in the blocks.
369    #[inline]
370    pub fn keys_to_properties(&self, keys_to_move: &Labels, sort_samples: bool) -> Result<TensorMap, Error> {
371        let ptr = unsafe {
372            crate::c_api::mts_tensormap_keys_to_properties(
373                self.ptr,
374                keys_to_move.as_mts_labels_t(),
375                sort_samples,
376            )
377        };
378
379        check_ptr(ptr)?;
380        return Ok(unsafe { TensorMap::from_raw(ptr) });
381    }
382
383    /// Move the given dimensions from the component labels to the property
384    /// labels for each block in this `TensorMap`.
385    #[inline]
386    pub fn components_to_properties(&self, dimensions: &[&str]) -> Result<TensorMap, Error> {
387        let dimensions_c = dimensions.iter()
388            .map(|&v| CString::new(v).expect("unexpected NULL byte"))
389            .collect::<Vec<_>>();
390
391        let dimensions_ptr = dimensions_c.iter()
392            .map(|v| v.as_ptr())
393            .collect::<Vec<_>>();
394
395
396        let ptr = unsafe {
397            crate::c_api::mts_tensormap_components_to_properties(
398                self.ptr,
399                dimensions_ptr.as_ptr(),
400                dimensions.len(),
401            )
402        };
403
404        check_ptr(ptr)?;
405        return Ok(unsafe { TensorMap::from_raw(ptr) });
406    }
407
408    /// Get an iterator over the keys and associated blocks
409    #[inline]
410    pub fn iter(&self) -> TensorMapIter<'_> {
411        return TensorMapIter {
412            inner: self.keys().iter().zip(self.blocks())
413        };
414    }
415
416    /// Get an iterator over the keys and associated blocks, with read-write
417    /// access to the blocks
418    #[inline]
419    pub fn iter_mut(&mut self) -> TensorMapIterMut<'_> {
420        // we can not use `self.blocks_mut()` here, since it would
421        // double-borrow self
422        let mut blocks = Vec::new();
423        for i in 0..self.keys().count() {
424            blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
425        }
426
427        return TensorMapIterMut {
428            inner: self.keys().into_iter().zip(blocks)
429        };
430    }
431
432    /// Get a parallel iterator over the keys and associated blocks
433    #[cfg(feature = "rayon")]
434    #[inline]
435    pub fn par_iter(&self) -> TensorMapParIter {
436        use rayon::prelude::*;
437        TensorMapParIter {
438            inner: self.keys().par_iter().zip_eq(self.blocks().into_par_iter())
439        }
440    }
441
442    /// Get a parallel iterator over the keys and associated blocks, with
443    /// read-write access to the blocks
444    #[cfg(feature = "rayon")]
445    #[inline]
446    pub fn par_iter_mut(&mut self) -> TensorMapParIterMut {
447        use rayon::prelude::*;
448
449        // we can not use `self.blocks_mut()` here, since it would
450        // double-borrow self
451        let mut blocks = Vec::new();
452        for i in 0..self.keys().count() {
453            blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
454        }
455
456        TensorMapParIterMut {
457            inner: self.keys().par_iter().zip_eq(blocks)
458        }
459    }
460}
461
462/******************************************************************************/
463
464/// Iterator over key/block pairs in a [`TensorMap`]
465pub struct TensorMapIter<'a> {
466    inner: std::iter::Zip<crate::labels::LabelsIter<'a>, std::vec::IntoIter<TensorBlockRef<'a>>>
467}
468
469impl<'a> Iterator for TensorMapIter<'a> {
470    type Item = (&'a [LabelValue], TensorBlockRef<'a>);
471
472    #[inline]
473    fn next(&mut self) -> Option<Self::Item> {
474        self.inner.next()
475    }
476
477    fn size_hint(&self) -> (usize, Option<usize>) {
478        self.inner.size_hint()
479    }
480}
481
482impl ExactSizeIterator for TensorMapIter<'_> {
483    #[inline]
484    fn len(&self) -> usize {
485        self.inner.len()
486    }
487}
488
489impl FusedIterator for TensorMapIter<'_> {}
490
491impl<'a> IntoIterator for &'a TensorMap {
492    type Item = (&'a [LabelValue], TensorBlockRef<'a>);
493
494    type IntoIter = TensorMapIter<'a>;
495
496    fn into_iter(self) -> Self::IntoIter {
497        self.iter()
498    }
499}
500
501/******************************************************************************/
502
503/// Iterator over key/block pairs in a [`TensorMap`], with mutable access to the
504/// blocks
505pub struct TensorMapIterMut<'a> {
506    inner: std::iter::Zip<crate::labels::LabelsIter<'a>, std::vec::IntoIter<TensorBlockRefMut<'a>>>
507}
508
509impl<'a> Iterator for TensorMapIterMut<'a> {
510    type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);
511
512    #[inline]
513    fn next(&mut self) -> Option<Self::Item> {
514        self.inner.next()
515    }
516
517    fn size_hint(&self) -> (usize, Option<usize>) {
518        self.inner.size_hint()
519    }
520}
521
522impl ExactSizeIterator for TensorMapIterMut<'_> {
523    #[inline]
524    fn len(&self) -> usize {
525        self.inner.len()
526    }
527}
528
529impl FusedIterator for TensorMapIterMut<'_> {}
530
531impl<'a> IntoIterator for &'a mut TensorMap {
532    type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);
533
534    type IntoIter = TensorMapIterMut<'a>;
535
536    fn into_iter(self) -> Self::IntoIter {
537        self.iter_mut()
538    }
539}
540
541
542/******************************************************************************/
543
/// Parallel iterator over key/block pairs in a [`TensorMap`]
///
/// Created by `TensorMap::par_iter`; all the work is forwarded to the wrapped
/// `ZipEq` of keys and blocks.
#[cfg(feature = "rayon")]
pub struct TensorMapParIter<'a> {
    inner: rayon::iter::ZipEq<crate::labels::LabelsParIter<'a>, rayon::vec::IntoIter<TensorBlockRef<'a>>>,
}

#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for TensorMapParIter<'a> {
    type Item = (&'a [LabelValue], TensorBlockRef<'a>);

    #[inline]
    fn drive_unindexed<U>(self, consumer: U) -> U::Result
    where
        U: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        rayon::iter::ParallelIterator::drive_unindexed(self.inner, consumer)
    }
}

#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for TensorMapParIter<'_> {
    #[inline]
    fn len(&self) -> usize {
        rayon::iter::IndexedParallelIterator::len(&self.inner)
    }

    #[inline]
    fn drive<K>(self, consumer: K) -> K::Result
    where
        K: rayon::iter::plumbing::Consumer<Self::Item>,
    {
        rayon::iter::IndexedParallelIterator::drive(self.inner, consumer)
    }

    #[inline]
    fn with_producer<P>(self, callback: P) -> P::Output
    where
        P: rayon::iter::plumbing::ProducerCallback<Self::Item>,
    {
        rayon::iter::IndexedParallelIterator::with_producer(self.inner, callback)
    }
}
579
580/******************************************************************************/
581
/// Parallel iterator over key/block pairs in a [`TensorMap`], with mutable
/// access to the blocks
///
/// Created by `TensorMap::par_iter_mut`; all the work is forwarded to the
/// wrapped `ZipEq` of keys and blocks.
#[cfg(feature = "rayon")]
pub struct TensorMapParIterMut<'a> {
    inner: rayon::iter::ZipEq<crate::labels::LabelsParIter<'a>, rayon::vec::IntoIter<TensorBlockRefMut<'a>>>,
}

#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for TensorMapParIterMut<'a> {
    type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);

    #[inline]
    fn drive_unindexed<U>(self, consumer: U) -> U::Result
    where
        U: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        rayon::iter::ParallelIterator::drive_unindexed(self.inner, consumer)
    }
}

#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for TensorMapParIterMut<'_> {
    #[inline]
    fn len(&self) -> usize {
        rayon::iter::IndexedParallelIterator::len(&self.inner)
    }

    #[inline]
    fn drive<K>(self, consumer: K) -> K::Result
    where
        K: rayon::iter::plumbing::Consumer<Self::Item>,
    {
        rayon::iter::IndexedParallelIterator::drive(self.inner, consumer)
    }

    #[inline]
    fn with_producer<P>(self, callback: P) -> P::Output
    where
        P: rayon::iter::plumbing::ProducerCallback<Self::Item>,
    {
        rayon::iter::IndexedParallelIterator::with_producer(self.inner, callback)
    }
}
618
619/******************************************************************************/
620
#[cfg(test)]
mod tests {
    use crate::{Labels, TensorBlock, TensorMap};

    /// Check the shared and mutable iterators over a tensor map: every block
    /// must be visited exactly once, paired with its key, and modifications
    /// made through the mutable iterator must be visible afterwards.
    #[test]
    #[allow(clippy::cast_lossless, clippy::float_cmp)]
    fn iter() {
        // the first value in each block is equal to the key of this block
        let block_1 = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![2, 3], 1.0),
            &Labels::new(["samples"], &[[0], [1]]),
            &[],
            &Labels::new(["properties"], &[[-2], [0], [1]]),
        ).unwrap();

        let block_2 = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![1, 1], 3.0),
            &Labels::new(["samples"], &[[1]]),
            &[],
            &Labels::new(["properties"], &[[1]]),
        ).unwrap();

        let block_3 = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![3, 2], -4.0),
            &Labels::new(["samples"], &[[0], [1], [3]]),
            &[],
            &Labels::new(["properties"], &[[-2], [1]]),
        ).unwrap();

        let mut tensor = TensorMap::new(
            Labels::new(["key"], &[[1], [3], [-4]]),
            vec![block_1, block_2, block_3],
        ).unwrap();

        // iterate over keys & blocks
        assert_eq!(tensor.iter().len(), 3);
        let mut visited = 0;
        for (key, block) in &tensor {
            assert_eq!(block.values().to_array()[[0, 0]], key[0].i32() as f64);
            visited += 1;
        }
        assert_eq!(visited, 3);

        // iterate over keys & blocks mutably
        let mut visited = 0;
        for (key, mut block) in &mut tensor {
            let array = block.values_mut().to_array_mut();
            *array *= 2.0;
            assert_eq!(array[[0, 0]], 2.0 * (key[0].i32() as f64));
            visited += 1;
        }
        assert_eq!(visited, 3);

        // the modifications done through the mutable iterator are kept
        for (key, block) in &tensor {
            assert_eq!(block.values().to_array()[[0, 0]], 2.0 * (key[0].i32() as f64));
        }
    }
}