metatensor/
tensor.rs

1use std::ffi::{CStr, CString};
2use std::iter::FusedIterator;
3
4use crate::block::TensorBlockRefMut;
5use crate::c_api::{mts_tensormap_t, mts_labels_t};
6
7use crate::errors::{check_status, check_ptr};
8use crate::{Error, TensorBlock, TensorBlockRef, Labels, LabelValue};
9
10/// [`TensorMap`] is the main user-facing struct of this library, and can
11/// store any kind of data used in atomistic machine learning.
12///
13/// A tensor map contains a list of `TensorBlock`s, each one associated with a
14/// key in the form of a single `Labels` entry.
15///
16/// It provides functions to merge blocks together by moving some of these keys
17/// to the samples or properties labels of the blocks, transforming the sparse
18/// representation of the data to a dense one.
19pub struct TensorMap {
20    pub(crate) ptr: *mut mts_tensormap_t,
21    /// cache for the keys labels
22    keys: Labels,
23}
24
25// SAFETY: Send is fine since we can free a TensorMap from any thread
26unsafe impl Send for TensorMap {}
27// SAFETY: Sync is fine since there is no internal mutability in TensorMap
28unsafe impl Sync for TensorMap {}
29
30impl std::fmt::Debug for TensorMap {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        use crate::labels::pretty_print_labels;
33        writeln!(f, "Tensormap @ {:p} {{", self.ptr)?;
34
35        write!(f, "    keys: ")?;
36        pretty_print_labels(self.keys(), "    ", f)?;
37        writeln!(f, "}}")
38    }
39}
40
41impl std::ops::Drop for TensorMap {
42    #[allow(unused_must_use)]
43    fn drop(&mut self) {
44        unsafe {
45            crate::c_api::mts_tensormap_free(self.ptr);
46        }
47    }
48}
49
50impl TensorMap {
51    /// Create a new `TensorMap` with the given keys and blocks.
52    ///
53    /// The number of keys must match the number of blocks, and all the blocks
54    /// must contain the same kind of data (same labels names, same gradients
55    /// defined on all blocks).
56    #[allow(clippy::needless_pass_by_value)]
57    #[inline]
58    pub fn new(keys: Labels, mut blocks: Vec<TensorBlock>) -> Result<TensorMap, Error> {
59        let ptr = unsafe {
60            crate::c_api::mts_tensormap(
61                keys.as_mts_labels_t(),
62                // this cast is fine because TensorBlock is `repr(transparent)`
63                // to a `*mut mts_block_t` (through `TensorBlockRefMut`, and
64                // `TensorBlockRef`).
65                blocks.as_mut_ptr().cast::<*mut crate::c_api::mts_block_t>(),
66                blocks.len()
67            )
68        };
69
70        for block in blocks {
71            // we give ownership of the blocks to the new tensormap, so we
72            // should not free them again from Rust
73            std::mem::forget(block);
74        }
75
76        check_ptr(ptr)?;
77
78        return Ok(unsafe { TensorMap::from_raw(ptr) });
79    }
80
81    /// Create a new `TensorMap` from a raw pointer.
82    ///
83    /// This function takes ownership of the pointer, and will call
84    /// `mts_tensormap_free` on it when the `TensorMap` goes out of scope.
85    ///
86    /// # Safety
87    ///
88    /// The pointer must be non-null and created by `mts_tensormap` or
89    /// `TensorMap::into_raw`.
90    pub unsafe fn from_raw(ptr: *mut mts_tensormap_t) -> TensorMap {
91        assert!(!ptr.is_null());
92
93        let mut keys = mts_labels_t::null();
94        check_status(crate::c_api::mts_tensormap_keys(
95            ptr,
96            &mut keys
97        )).expect("failed to get the keys");
98
99        let keys = Labels::from_raw(keys);
100
101        return TensorMap {
102            ptr,
103            keys
104        };
105    }
106
107    /// Extract the underlying raw pointer.
108    ///
109    /// The pointer should be passed back to `TensorMap::from_raw` or
110    /// `mts_tensormap_free` to release the memory corresponding to this
111    /// `TensorMap`.
112    pub fn into_raw(mut map: TensorMap) -> *mut mts_tensormap_t {
113        let ptr = map.ptr;
114        map.ptr = std::ptr::null_mut();
115        return ptr;
116    }
117
118    /// Clone this `TensorMap`, cloning all the data and metadata contained inside.
119    ///
120    /// This can fail if the external data held inside an `mts_array_t` can not
121    /// be cloned.
122    #[inline]
123    pub fn try_clone(&self) -> Result<TensorMap, Error> {
124        let ptr = unsafe {
125            crate::c_api::mts_tensormap_copy(self.ptr)
126        };
127        crate::errors::check_ptr(ptr)?;
128
129        return Ok(unsafe { TensorMap::from_raw(ptr) });
130    }
131
132    /// Load a `TensorMap` from the file at `path`
133    ///
134    /// This is a convenience function calling [`crate::io::load`]
135    pub fn load(path: impl AsRef<std::path::Path>) -> Result<TensorMap, Error> {
136        return crate::io::load(path);
137    }
138
139    /// Load a `TensorMap` from an in-memory buffer
140    ///
141    /// This is a convenience function calling [`crate::io::load_buffer`]
142    pub fn load_buffer(buffer: &[u8]) -> Result<TensorMap, Error> {
143        return crate::io::load_buffer(buffer);
144    }
145
146    /// Save the given tensor to the file at `path`
147    ///
148    /// This is a convenience function calling [`crate::io::save`]
149    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
150        return crate::io::save(path, self);
151    }
152
153    /// Save the given tensor to an in-memory buffer
154    ///
155    /// This is a convenience function calling [`crate::io::save_buffer`]
156    pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
157        return crate::io::save_buffer(self, buffer);
158    }
159
160    /// Get the keys defined in this `TensorMap`
161    #[inline]
162    pub fn keys(&self) -> &Labels {
163        &self.keys
164    }
165
166    /// Get a reference to the block at the given `index` in this `TensorMap`
167    ///
168    /// # Panics
169    ///
170    /// If the index is out of bounds
171    #[inline]
172    pub fn block_by_id(&self, index: usize) -> TensorBlockRef<'_> {
173
174        let mut block = std::ptr::null_mut();
175        unsafe {
176            check_status(crate::c_api::mts_tensormap_block_by_id(
177                self.ptr,
178                &mut block,
179                index,
180            )).expect("failed to get a block");
181        }
182
183        return unsafe { TensorBlockRef::from_raw(block) }
184    }
185
186    /// Get a mutable reference to the block at the given `index` in this `TensorMap`
187    ///
188    /// # Panics
189    ///
190    /// If the index is out of bounds
191    #[inline]
192    pub fn block_mut_by_id(&mut self, index: usize) -> TensorBlockRefMut<'_> {
193        return unsafe { TensorMap::raw_block_mut_by_id(self.ptr, index) };
194    }
195
196    /// Implementation of `block_mut_by_id` which does not borrow the
197    /// `mts_tensormap_t` pointer.
198    ///
199    /// This is used to provide references to multiple blocks at the same time
200    /// in the iterators.
201    ///
202    /// # Safety
203    ///
204    /// This should be called with a valid `mts_tensormap_t`, and the lifetime
205    /// `'a` should be properly constrained to the lifetime of the owner of
206    /// `ptr`.
207    #[inline]
208    unsafe fn raw_block_mut_by_id<'a>(ptr: *mut mts_tensormap_t, index: usize) -> TensorBlockRefMut<'a> {
209        let mut block = std::ptr::null_mut();
210
211        check_status(crate::c_api::mts_tensormap_block_by_id(
212            ptr,
213            &mut block,
214            index,
215        )).expect("failed to get a block");
216
217        return TensorBlockRefMut::from_raw(block);
218    }
219
220    /// Get the index of blocks matching the given selection.
221    ///
222    /// The selection must contains a single entry, defining the requested key
223    /// or keys. If the selection contains only a subset of the dimensions of the
224    /// keys, there can be multiple matching blocks.
225    #[inline]
226    pub fn blocks_matching(&self, selection: &Labels) -> Result<Vec<usize>, Error> {
227        let mut indexes = vec![0; self.keys().count()];
228        let mut matching = indexes.len();
229        unsafe {
230            check_status(crate::c_api::mts_tensormap_blocks_matching(
231                self.ptr,
232                indexes.as_mut_ptr(),
233                &mut matching,
234                selection.as_mts_labels_t(),
235            ))?;
236        }
237        indexes.resize(matching, 0);
238
239        return Ok(indexes);
240    }
241
242    /// Get the index of the single block matching the given selection.
243    ///
244    /// This function is similar to [`TensorMap::blocks_matching`], but also
245    /// returns an error if more than one block matches the selection.
246    #[inline]
247    pub fn block_matching(&self, selection: &Labels) -> Result<usize, Error> {
248        let matching = self.blocks_matching(selection)?;
249        if matching.len() != 1 {
250            let selection_str = selection.names()
251                .iter().zip(&selection[0])
252                .map(|(name, value)| format!("{} = {}", name, value))
253                .collect::<Vec<_>>()
254                .join(", ");
255
256
257            if matching.is_empty() {
258                return Err(Error {
259                    code: None,
260                    message: format!(
261                        "no blocks matched the selection ({})",
262                        selection_str
263                    ),
264                });
265            } else {
266                return Err(Error {
267                    code: None,
268                    message: format!(
269                        "{} blocks matched the selection ({}), expected only one",
270                        matching.len(),
271                        selection_str
272                    ),
273                });
274            }
275        }
276
277        return Ok(matching[0])
278    }
279
280    /// Get a reference to the block matching the given selection.
281    ///
282    /// This function uses [`TensorMap::blocks_matching`] under the hood to find
283    /// the matching block.
284    #[inline]
285    pub fn block(&self, selection: &Labels) -> Result<TensorBlockRef<'_>, Error> {
286        let id = self.block_matching(selection)?;
287        return Ok(self.block_by_id(id));
288    }
289
290    /// Get a reference to every blocks in this `TensorMap`
291    #[inline]
292    pub fn blocks(&self) -> Vec<TensorBlockRef<'_>> {
293        let mut blocks = Vec::new();
294        for i in 0..self.keys().count() {
295            blocks.push(self.block_by_id(i));
296        }
297        return blocks;
298    }
299
300    /// Get a mutable reference to every blocks in this `TensorMap`
301    #[inline]
302    pub fn blocks_mut(&mut self) -> Vec<TensorBlockRefMut<'_>> {
303        let mut blocks = Vec::new();
304        for i in 0..self.keys().count() {
305            blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
306        }
307        return blocks;
308    }
309
310    /// Merge blocks with the same value for selected keys dimensions along the
311    /// samples axis.
312    ///
313    /// The dimensions (names) of `keys_to_move` will be moved from the keys to
314    /// the sample labels, and blocks with the same remaining keys dimensions
315    /// will be merged together along the sample axis.
316    ///
317    /// `keys_to_move` must be empty (`keys_to_move.count() == 0`), and the new
318    /// sample labels will contain entries corresponding to the merged blocks'
319    /// keys.
320    ///
321    /// The new sample labels will contains all of the merged blocks sample
322    /// labels. The order of the samples is controlled by `sort_samples`. If
323    /// `sort_samples` is true, samples are re-ordered to keep them
324    /// lexicographically sorted. Otherwise they are kept in the order in which
325    /// they appear in the blocks.
326    ///
327    /// This function is only implemented if all merged block have the same
328    /// property labels.
329    #[inline]
330    pub fn keys_to_samples(&self, keys_to_move: &Labels, sort_samples: bool) -> Result<TensorMap, Error> {
331        let ptr = unsafe {
332            crate::c_api::mts_tensormap_keys_to_samples(
333                self.ptr,
334                keys_to_move.as_mts_labels_t(),
335                sort_samples,
336            )
337        };
338
339        check_ptr(ptr)?;
340        return Ok(unsafe { TensorMap::from_raw(ptr) });
341    }
342
343    /// Merge blocks with the same value for selected keys dimensions along the
344    /// property axis.
345    ///
346    /// The dimensions (names) of `keys_to_move` will be moved from the keys to
347    /// the property labels, and blocks with the same remaining keys dimensions
348    /// will be merged together along the property axis.
349    ///
350    /// If `keys_to_move` does not contains any entries (`keys_to_move.count()
351    /// == 0`), then the new property labels will contain entries corresponding
352    /// to the merged blocks only. For example, merging a block with key `a=0`
353    /// and properties `p=1, 2` with a block with key `a=2` and properties `p=1,
354    /// 3` will produce a block with properties `a, p = (0, 1), (0, 2), (2, 1),
355    /// (2, 3)`.
356    ///
357    /// If `keys_to_move` contains entries, then the property labels must be the
358    /// same for all the merged blocks. In that case, the merged property labels
359    /// will contains each of the entries of `keys_to_move` and then the current
360    /// property labels. For example, using `a=2, 3` in `keys_to_move`, and
361    /// blocks with properties `p=1, 2` will result in `a, p = (2, 1), (2, 2),
362    /// (3, 1), (3, 2)`.
363    ///
364    /// The new sample labels will contains all of the merged blocks sample
365    /// labels. The order of the samples is controlled by `sort_samples`. If
366    /// `sort_samples` is true, samples are re-ordered to keep them
367    /// lexicographically sorted. Otherwise they are kept in the order in which
368    /// they appear in the blocks.
369    #[inline]
370    pub fn keys_to_properties(&self, keys_to_move: &Labels, sort_samples: bool) -> Result<TensorMap, Error> {
371        let ptr = unsafe {
372            crate::c_api::mts_tensormap_keys_to_properties(
373                self.ptr,
374                keys_to_move.as_mts_labels_t(),
375                sort_samples,
376            )
377        };
378
379        check_ptr(ptr)?;
380        return Ok(unsafe { TensorMap::from_raw(ptr) });
381    }
382
383    /// Move the given dimensions from the component labels to the property
384    /// labels for each block in this `TensorMap`.
385    #[inline]
386    pub fn components_to_properties(&self, dimensions: &[&str]) -> Result<TensorMap, Error> {
387        let dimensions_c = dimensions.iter()
388            .map(|&v| CString::new(v).expect("unexpected NULL byte"))
389            .collect::<Vec<_>>();
390
391        let dimensions_ptr = dimensions_c.iter()
392            .map(|v| v.as_ptr())
393            .collect::<Vec<_>>();
394
395
396        let ptr = unsafe {
397            crate::c_api::mts_tensormap_components_to_properties(
398                self.ptr,
399                dimensions_ptr.as_ptr(),
400                dimensions.len(),
401            )
402        };
403
404        check_ptr(ptr)?;
405        return Ok(unsafe { TensorMap::from_raw(ptr) });
406    }
407
408    /// Get an iterator over the keys and associated blocks
409    #[inline]
410    pub fn iter(&self) -> TensorMapIter<'_> {
411        return TensorMapIter {
412            inner: self.keys().iter().zip(self.blocks())
413        };
414    }
415
416    /// Get an iterator over the keys and associated blocks, with read-write
417    /// access to the blocks
418    #[inline]
419    pub fn iter_mut(&mut self) -> TensorMapIterMut<'_> {
420        // we can not use `self.blocks_mut()` here, since it would
421        // double-borrow self
422        let mut blocks = Vec::new();
423        for i in 0..self.keys().count() {
424            blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
425        }
426
427        return TensorMapIterMut {
428            inner: self.keys().into_iter().zip(blocks)
429        };
430    }
431
432    /// Get a parallel iterator over the keys and associated blocks
433    #[cfg(feature = "rayon")]
434    #[inline]
435    pub fn par_iter(&self) -> TensorMapParIter<'_> {
436        use rayon::prelude::*;
437        TensorMapParIter {
438            inner: self.keys().par_iter().zip_eq(self.blocks().into_par_iter())
439        }
440    }
441
442    /// Get a parallel iterator over the keys and associated blocks, with
443    /// read-write access to the blocks
444    #[cfg(feature = "rayon")]
445    #[inline]
446    pub fn par_iter_mut(&mut self) -> TensorMapParIterMut<'_> {
447        use rayon::prelude::*;
448
449        // we can not use `self.blocks_mut()` here, since it would
450        // double-borrow self
451        let mut blocks = Vec::new();
452        for i in 0..self.keys().count() {
453            blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
454        }
455
456        TensorMapParIterMut {
457            inner: self.keys().par_iter().zip_eq(blocks)
458        }
459    }
460
461    /// Set or update the info (i.e. global metadata) `value` associated with
462    /// `key` for this `TensorMap`.
463    pub fn set_info(&mut self, key: &str, value: &str) {
464        let mut key = key.to_owned().into_bytes();
465        key.push(b'\0');
466
467        let mut value = value.to_owned().into_bytes();
468        value.push(b'\0');
469
470        unsafe {
471            check_status(crate::c_api::mts_tensormap_set_info(
472                self.ptr, key.as_ptr().cast(), value.as_ptr().cast()
473            )).expect("failed to set info");
474        }
475    }
476
477    /// Get the info (i.e. global metadata) with the given `key` for this
478    /// `TensorMap`.
479    pub fn get_info(&self, key: &str) -> Option<&str> {
480        let mut key = key.to_owned().into_bytes();
481        key.push(b'\0');
482
483        let mut value = std::ptr::null();
484
485        unsafe {
486            check_status(crate::c_api::mts_tensormap_get_info(
487                self.ptr, key.as_ptr().cast(), &mut value
488            )).expect("failed to set info");
489        }
490
491        if value.is_null() {
492            return None;
493        }
494
495        let c_str = unsafe { CStr::from_ptr(value) };
496        return Some(c_str.to_str().expect("invalid UTF-8 string"));
497    }
498
499    /// Get an iterator over all the key/value info pairs stored in this
500    /// `TensorMap`.
501    pub fn info(&self) -> TensorMapInfoIter<'_> {
502        let mut keys = std::ptr::null();
503        let mut count = 0;
504        unsafe {
505            check_status(crate::c_api::mts_tensormap_info_keys(
506                self.ptr,
507                &mut keys,
508                &mut count,
509            )).expect("failed to get info keys");
510        };
511
512        let keys = unsafe {
513            std::slice::from_raw_parts(keys, count)
514        };
515        let keys = keys.iter()
516            .map(|&k| {
517                let c_str = unsafe { CStr::from_ptr(k) };
518                c_str.to_str().expect("invalid UTF-8 string")
519            })
520            .collect::<Vec<_>>();
521
522        TensorMapInfoIter {
523            keys: keys,
524            tensor: self,
525            index: 0,
526            count,
527        }
528    }
529}
530
531/******************************************************************************/
532
533/// Iterator over key/block pairs in a [`TensorMap`]
534pub struct TensorMapIter<'a> {
535    inner: std::iter::Zip<crate::labels::LabelsIter<'a>, std::vec::IntoIter<TensorBlockRef<'a>>>
536}
537
538impl<'a> Iterator for TensorMapIter<'a> {
539    type Item = (&'a [LabelValue], TensorBlockRef<'a>);
540
541    #[inline]
542    fn next(&mut self) -> Option<Self::Item> {
543        self.inner.next()
544    }
545
546    fn size_hint(&self) -> (usize, Option<usize>) {
547        self.inner.size_hint()
548    }
549}
550
551impl ExactSizeIterator for TensorMapIter<'_> {
552    #[inline]
553    fn len(&self) -> usize {
554        self.inner.len()
555    }
556}
557
558impl FusedIterator for TensorMapIter<'_> {}
559
560impl<'a> IntoIterator for &'a TensorMap {
561    type Item = (&'a [LabelValue], TensorBlockRef<'a>);
562
563    type IntoIter = TensorMapIter<'a>;
564
565    fn into_iter(self) -> Self::IntoIter {
566        self.iter()
567    }
568}
569
570/******************************************************************************/
571
572/// Iterator over key/block pairs in a [`TensorMap`], with mutable access to the
573/// blocks
574pub struct TensorMapIterMut<'a> {
575    inner: std::iter::Zip<crate::labels::LabelsIter<'a>, std::vec::IntoIter<TensorBlockRefMut<'a>>>
576}
577
578impl<'a> Iterator for TensorMapIterMut<'a> {
579    type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);
580
581    #[inline]
582    fn next(&mut self) -> Option<Self::Item> {
583        self.inner.next()
584    }
585
586    fn size_hint(&self) -> (usize, Option<usize>) {
587        self.inner.size_hint()
588    }
589}
590
591impl ExactSizeIterator for TensorMapIterMut<'_> {
592    #[inline]
593    fn len(&self) -> usize {
594        self.inner.len()
595    }
596}
597
598impl FusedIterator for TensorMapIterMut<'_> {}
599
600impl<'a> IntoIterator for &'a mut TensorMap {
601    type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);
602
603    type IntoIter = TensorMapIterMut<'a>;
604
605    fn into_iter(self) -> Self::IntoIter {
606        self.iter_mut()
607    }
608}
609
610
611/******************************************************************************/
612
/// Parallel iterator over key/block pairs in a [`TensorMap`]
///
/// Produced by [`TensorMap::par_iter`]; every method forwards to the wrapped
/// rayon `ZipEq` iterator.
#[cfg(feature = "rayon")]
pub struct TensorMapParIter<'a> {
    inner: rayon::iter::ZipEq<crate::labels::LabelsParIter<'a>, rayon::vec::IntoIter<TensorBlockRef<'a>>>,
}

#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for TensorMapParIter<'a> {
    type Item = (&'a [LabelValue], TensorBlockRef<'a>);

    #[inline]
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        rayon::iter::ParallelIterator::drive_unindexed(self.inner, consumer)
    }
}

#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for TensorMapParIter<'_> {
    #[inline]
    fn len(&self) -> usize {
        rayon::iter::IndexedParallelIterator::len(&self.inner)
    }

    #[inline]
    fn drive<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::Consumer<Self::Item>,
    {
        rayon::iter::IndexedParallelIterator::drive(self.inner, consumer)
    }

    #[inline]
    fn with_producer<CB>(self, callback: CB) -> CB::Output
    where
        CB: rayon::iter::plumbing::ProducerCallback<Self::Item>,
    {
        rayon::iter::IndexedParallelIterator::with_producer(self.inner, callback)
    }
}
648
649/******************************************************************************/
650
/// Parallel iterator over key/block pairs in a [`TensorMap`], with mutable
/// access to the blocks
///
/// Produced by [`TensorMap::par_iter_mut`]; every method forwards to the
/// wrapped rayon `ZipEq` iterator.
#[cfg(feature = "rayon")]
pub struct TensorMapParIterMut<'a> {
    inner: rayon::iter::ZipEq<crate::labels::LabelsParIter<'a>, rayon::vec::IntoIter<TensorBlockRefMut<'a>>>,
}

#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for TensorMapParIterMut<'a> {
    type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);

    #[inline]
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        rayon::iter::ParallelIterator::drive_unindexed(self.inner, consumer)
    }
}

#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for TensorMapParIterMut<'_> {
    #[inline]
    fn len(&self) -> usize {
        rayon::iter::IndexedParallelIterator::len(&self.inner)
    }

    #[inline]
    fn drive<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::Consumer<Self::Item>,
    {
        rayon::iter::IndexedParallelIterator::drive(self.inner, consumer)
    }

    #[inline]
    fn with_producer<CB>(self, callback: CB) -> CB::Output
    where
        CB: rayon::iter::plumbing::ProducerCallback<Self::Item>,
    {
        rayon::iter::IndexedParallelIterator::with_producer(self.inner, callback)
    }
}
687
688/******************************************************************************/
689
690/// Iterator over info key/value pairs in a `TensorMap`
691pub struct TensorMapInfoIter<'a> {
692    keys: Vec<&'a str>,
693    tensor: &'a TensorMap,
694    index: usize,
695    count: usize,
696}
697
698impl<'a> Iterator for TensorMapInfoIter<'a> {
699    type Item = (&'a str, &'a str);
700
701    #[inline]
702    fn next(&mut self) -> Option<Self::Item> {
703        if self.index >= self.count {
704            return None;
705        }
706        let key = self.keys[self.index];
707        let value = self.tensor.get_info(key).expect("missing info");
708        self.index += 1;
709        return Some((key, value));
710    }
711
712    fn size_hint(&self) -> (usize, Option<usize>) {
713        (self.count, Some(self.count))
714    }
715}
716
717impl ExactSizeIterator for TensorMapInfoIter<'_> {
718    #[inline]
719    fn len(&self) -> usize {
720        self.count
721    }
722}
723
724impl FusedIterator for TensorMapInfoIter<'_> {}
725
726
727/******************************************************************************/
728
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use crate::{Labels, TensorBlock, TensorMap};

    /// Build a tensor map with three blocks, with keys
    /// `(key, other) = (1, 0), (3, 1), (-4, 0)`. Every value inside a block is
    /// equal to the first key dimension of that block.
    fn test_tensor() -> TensorMap {
        let first = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![2, 3], 1.0),
            &Labels::new(["samples"], &[[0], [1]]),
            &[],
            &Labels::new(["properties"], &[[-2], [0], [1]]),
        ).unwrap();

        let second = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![1, 1], 3.0),
            &Labels::new(["samples"], &[[1]]),
            &[],
            &Labels::new(["properties"], &[[1]]),
        ).unwrap();

        let third = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![3, 2], -4.0),
            &Labels::new(["samples"], &[[0], [1], [3]]),
            &[],
            &Labels::new(["properties"], &[[-2], [1]]),
        ).unwrap();

        let keys = Labels::new(["key", "other"], &[[1, 0], [3, 1], [-4, 0]]);
        return TensorMap::new(keys, vec![first, second, third]).unwrap();
    }

    #[test]
    fn block_access() {
        let mut tensor = test_tensor();

        let second = tensor.block_by_id(1);
        assert_eq!(second.values().as_array().shape(), [1, 1]);

        let third = tensor.block_mut_by_id(2);
        assert_eq!(third.values().as_array().shape(), [3, 2]);

        // selecting on `key` alone identifies exactly one block
        let by_key = Labels::new(["key"], &[[1]]);
        assert_eq!(tensor.block_matching(&by_key).unwrap(), 0);
        assert_eq!(tensor.blocks_matching(&by_key).unwrap(), [0]);

        let first = tensor.block(&by_key).unwrap();
        assert_eq!(first.values().as_array().shape(), [2, 3]);

        // selecting on `other` matches two blocks, so only the plural lookup
        // succeeds
        let by_other = Labels::new(["other"], &[[0]]);
        assert!(tensor.block_matching(&by_other).is_err());
        assert_eq!(tensor.blocks_matching(&by_other).unwrap(), [0, 2]);

        let expected_shapes: [[usize; 2]; 3] = [[2, 3], [1, 1], [3, 2]];

        let all_blocks = tensor.blocks();
        for (block, shape) in all_blocks.iter().zip(&expected_shapes) {
            assert_eq!(block.values().as_array().shape(), &shape[..]);
        }

        let all_blocks_mut = tensor.blocks_mut();
        for (block, shape) in all_blocks_mut.iter().zip(&expected_shapes) {
            assert_eq!(block.values().as_array().shape(), &shape[..]);
        }
    }

    #[test]
    fn iter() {
        let mut tensor = test_tensor();

        // read-only iteration: each block stores its first key value
        for (key, block) in &tensor {
            let expected = f64::from(key[0].i32());
            assert_eq!(block.values().to_array()[[0, 0]], expected);
        }

        // mutable iteration: scale every block in place
        for (key, mut block) in &mut tensor {
            let values = block.values_mut().to_array_mut();
            *values *= 2.0;
            let expected = 2.0 * f64::from(key[0].i32());
            assert_eq!(values[[0, 0]], expected);
        }
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn par_iter() {
        use rayon::iter::ParallelIterator;

        let mut tensor = test_tensor();

        // read-only parallel iteration
        tensor.par_iter().for_each(|(key, block)| {
            let expected = f64::from(key[0].i32());
            assert_eq!(block.values().to_array()[[0, 0]], expected);
        });

        // mutable parallel iteration
        tensor.par_iter_mut().for_each(|(key, mut block)| {
            let values = block.values_mut().to_array_mut();
            *values *= 2.0;
            let expected = 2.0 * f64::from(key[0].i32());
            assert_eq!(values[[0, 0]], expected);
        });
    }

    #[test]
    fn info() {
        let mut tensor = test_tensor();
        tensor.set_info("creator", "unit test");
        tensor.set_info("version", "1.0");

        assert_eq!(tensor.get_info("creator").unwrap(), "unit test");
        assert_eq!(tensor.get_info("version").unwrap(), "1.0");
        assert!(tensor.get_info("missing").is_none());

        // iteration yields every pair, in insertion order, and then stops
        let pairs = tensor.info().collect::<Vec<_>>();
        assert_eq!(pairs, [("creator", "unit test"), ("version", "1.0")]);
    }
}