Skip to main content

metatensor/
tensor.rs

1use std::ffi::{CStr, CString};
2use std::iter::FusedIterator;
3
4use crate::block::TensorBlockRefMut;
5use crate::c_api::mts_tensormap_t;
6
7use crate::errors::{check_status, check_ptr};
8use crate::{Error, LabelValue, Labels, MtsArray, TensorBlock, TensorBlockRef};
9
10/// [`TensorMap`] is the main user-facing struct of this library, and can
11/// store any kind of data used in atomistic machine learning.
12///
13/// A tensor map contains a list of `TensorBlock`s, each one associated with a
14/// key in the form of a single `Labels` entry.
15///
16/// It provides functions to merge blocks together by moving some of these keys
17/// to the samples or properties labels of the blocks, transforming the sparse
18/// representation of the data to a dense one.
19pub struct TensorMap {
20    pub(crate) ptr: *mut mts_tensormap_t,
21    /// cache for the keys labels
22    keys: Labels,
23}
24
25// SAFETY: Send is fine since we can free a TensorMap from any thread
26unsafe impl Send for TensorMap {}
27// SAFETY: Sync is fine since there is no internal mutability in TensorMap
28unsafe impl Sync for TensorMap {}
29
30impl std::fmt::Debug for TensorMap {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        use crate::labels::pretty_print_labels;
33        writeln!(f, "Tensormap @ {:p} {{", self.ptr)?;
34
35        write!(f, "    keys: ")?;
36        pretty_print_labels(self.keys(), "    ", f)?;
37        writeln!(f, "}}")
38    }
39}
40
41impl std::ops::Drop for TensorMap {
42    #[allow(unused_must_use)]
43    fn drop(&mut self) {
44        unsafe {
45            crate::c_api::mts_tensormap_free(self.ptr);
46        }
47    }
48}
49
50impl TensorMap {
51    /// Create a new `TensorMap` with the given keys and blocks.
52    ///
53    /// The number of keys must match the number of blocks, and all the blocks
54    /// must contain the same kind of data (same labels names, same gradients
55    /// defined on all blocks).
56    #[allow(clippy::needless_pass_by_value)]
57    #[inline]
58    pub fn new(keys: Labels, mut blocks: Vec<TensorBlock>) -> Result<TensorMap, Error> {
59        let ptr = unsafe {
60            crate::c_api::mts_tensormap(
61                keys.as_mts_labels_t(),
62                // this cast is fine because TensorBlock is `repr(transparent)`
63                // to a `*mut mts_block_t` (through `TensorBlockRefMut`, and
64                // `TensorBlockRef`).
65                blocks.as_mut_ptr().cast::<*mut crate::c_api::mts_block_t>(),
66                blocks.len()
67            )
68        };
69
70        for block in blocks {
71            // we give ownership of the blocks to the new tensormap, so we
72            // should not free them again from Rust
73            std::mem::forget(block);
74        }
75
76        check_ptr(ptr)?;
77
78        return Ok(unsafe { TensorMap::from_raw(ptr) });
79    }
80
81    /// Create a new `TensorMap` from a raw pointer.
82    ///
83    /// This function takes ownership of the pointer, and will call
84    /// `mts_tensormap_free` on it when the `TensorMap` goes out of scope.
85    ///
86    /// # Safety
87    ///
88    /// The pointer must be non-null and created by `mts_tensormap` or
89    /// `TensorMap::into_raw`.
90    pub unsafe fn from_raw(ptr: *mut mts_tensormap_t) -> TensorMap {
91        assert!(!ptr.is_null());
92
93        let keys_ptr = crate::c_api::mts_tensormap_keys(ptr);
94        assert!(!keys_ptr.is_null(), "failed to get the keys");
95        let keys = Labels::from_raw(keys_ptr);
96
97        return TensorMap {
98            ptr,
99            keys
100        };
101    }
102
103    /// Extract the underlying raw pointer.
104    ///
105    /// The pointer should be passed back to `TensorMap::from_raw` or
106    /// `mts_tensormap_free` to release the memory corresponding to this
107    /// `TensorMap`.
108    pub fn into_raw(mut map: TensorMap) -> *mut mts_tensormap_t {
109        let ptr = map.ptr;
110        map.ptr = std::ptr::null_mut();
111        return ptr;
112    }
113
114    /// Clone this `TensorMap`, cloning all the data and metadata contained inside.
115    ///
116    /// This can fail if the external data held inside an `mts_array_t` can not
117    /// be cloned.
118    #[inline]
119    pub fn try_clone(&self) -> Result<TensorMap, Error> {
120        let ptr = unsafe {
121            crate::c_api::mts_tensormap_copy(self.ptr)
122        };
123        crate::errors::check_ptr(ptr)?;
124
125        return Ok(unsafe { TensorMap::from_raw(ptr) });
126    }
127
128    /// Load a `TensorMap` from the file at `path`
129    ///
130    /// This is a convenience function calling [`crate::io::load`]
131    pub fn load(path: impl AsRef<std::path::Path>) -> Result<TensorMap, Error> {
132        return crate::io::load(path);
133    }
134
135    /// Load a `TensorMap` from an in-memory buffer
136    ///
137    /// This is a convenience function calling [`crate::io::load_buffer`]
138    pub fn load_buffer(buffer: &[u8]) -> Result<TensorMap, Error> {
139        return crate::io::load_buffer(buffer);
140    }
141
142    /// Save the given tensor to the file at `path`
143    ///
144    /// This is a convenience function calling [`crate::io::save`]
145    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
146        return crate::io::save(path, self);
147    }
148
149    /// Save the given tensor to an in-memory buffer
150    ///
151    /// This is a convenience function calling [`crate::io::save_buffer`]
152    pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
153        return crate::io::save_buffer(self, buffer);
154    }
155
156    /// Get the keys defined in this `TensorMap`
157    #[inline]
158    pub fn keys(&self) -> &Labels {
159        &self.keys
160    }
161
162    /// Get a reference to the block at the given `index` in this `TensorMap`
163    ///
164    /// # Panics
165    ///
166    /// If the index is out of bounds
167    #[inline]
168    pub fn block_by_id(&self, index: usize) -> TensorBlockRef<'_> {
169
170        let mut block = std::ptr::null_mut();
171        unsafe {
172            check_status(crate::c_api::mts_tensormap_block_by_id(
173                self.ptr,
174                &mut block,
175                index,
176            )).expect("failed to get a block");
177        }
178
179        return unsafe { TensorBlockRef::from_raw(block) }
180    }
181
182    /// Get a mutable reference to the block at the given `index` in this `TensorMap`
183    ///
184    /// # Panics
185    ///
186    /// If the index is out of bounds
187    #[inline]
188    pub fn block_mut_by_id(&mut self, index: usize) -> TensorBlockRefMut<'_> {
189        return unsafe { TensorMap::raw_block_mut_by_id(self.ptr, index) };
190    }
191
192    /// Implementation of `block_mut_by_id` which does not borrow the
193    /// `mts_tensormap_t` pointer.
194    ///
195    /// This is used to provide references to multiple blocks at the same time
196    /// in the iterators.
197    ///
198    /// # Safety
199    ///
200    /// This should be called with a valid `mts_tensormap_t`, and the lifetime
201    /// `'a` should be properly constrained to the lifetime of the owner of
202    /// `ptr`.
203    #[inline]
204    unsafe fn raw_block_mut_by_id<'a>(ptr: *mut mts_tensormap_t, index: usize) -> TensorBlockRefMut<'a> {
205        let mut block = std::ptr::null_mut();
206
207        check_status(
208            crate::c_api::mts_tensormap_block_by_id(
209            ptr,
210            &mut block,
211            index,
212        )).expect("failed to get a block");
213
214        return TensorBlockRefMut::from_raw(block);
215    }
216
217    /// Get a reference to the block matching the given selection.
218    #[inline]
219    pub fn block(&self, selection: &Labels) -> Result<TensorBlockRef<'_>, Error> {
220        let matching = self.keys.select(selection)?;
221        if matching.len() != 1 {
222            let selection_str = selection.names()
223                .iter()
224                .zip(&selection[0])
225                .map(|(name, value)| format!("{} = {}", name, value))
226                .collect::<Vec<_>>()
227                .join(", ");
228
229            if matching.is_empty() {
230                return Err(Error {
231                    code: None,
232                    message: format!(
233                        "no blocks matched the selection ({})",
234                        selection_str
235                    ),
236                });
237            } else {
238                return Err(Error {
239                    code: None,
240                    message: format!(
241                        "{} blocks matched the selection ({}), expected only one",
242                        matching.len(),
243                        selection_str
244                    ),
245                });
246            }
247        }
248
249        return Ok(self.block_by_id(matching[0]));
250    }
251
252    /// Get a reference to every blocks in this `TensorMap`
253    #[inline]
254    pub fn blocks(&self) -> Vec<TensorBlockRef<'_>> {
255        let mut blocks = Vec::new();
256        for i in 0..self.keys().count() {
257            blocks.push(self.block_by_id(i));
258        }
259        return blocks;
260    }
261
262    /// Get a mutable reference to every blocks in this `TensorMap`
263    #[inline]
264    pub fn blocks_mut(&mut self) -> Vec<TensorBlockRefMut<'_>> {
265        let mut blocks = Vec::new();
266        for i in 0..self.keys().count() {
267            blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
268        }
269        return blocks;
270    }
271
272    /// Merge blocks with the same value for selected keys dimensions along the
273    /// samples axis.
274    ///
275    /// The dimensions (names) of `keys_to_move` will be moved from the keys to
276    /// the sample labels, and blocks with the same remaining keys dimensions
277    /// will be merged together along the sample axis.
278    ///
279    /// `keys_to_move` must be empty (`keys_to_move.count() == 0`), and the new
280    /// sample labels will contain entries corresponding to the merged blocks'
281    /// keys.
282    ///
283    /// The new sample labels will contain all of the merged blocks sample
284    /// labels. The order of the samples is controlled by `sort_samples`. If
285    /// `sort_samples` is true, samples are re-ordered to keep them
286    /// lexicographically sorted. Otherwise they are kept in the order in which
287    /// they appear in the blocks.
288    #[inline]
289    pub fn keys_to_samples(&self, keys_to_move: &Labels, fill_value: MtsArray, sort_samples: bool) -> Result<TensorMap, Error> {
290        let ptr = unsafe {
291            crate::c_api::mts_tensormap_keys_to_samples(
292                self.ptr,
293                keys_to_move.as_mts_labels_t(),
294                fill_value.into_raw(),
295                sort_samples,
296            )
297        };
298
299        check_ptr(ptr)?;
300        return Ok(unsafe { TensorMap::from_raw(ptr) });
301    }
302
303    /// Merge blocks with the same value for selected keys dimensions along the
304    /// property axis.
305    ///
306    /// The dimensions (names) of `keys_to_move` will be moved from the keys to
307    /// the property labels, and blocks with the same remaining keys dimensions
308    /// will be merged together along the property axis.
309    ///
310    /// If `keys_to_move` does not contain any entries (`keys_to_move.count()
311    /// == 0`), then the new property labels will contain entries corresponding
312    /// to the merged blocks only. For example, merging a block with key `a=0`
313    /// and properties `p=1, 2` with a block with key `a=2` and properties `p=1,
314    /// 3` will produce a block with properties `a, p = (0, 1), (0, 2), (2, 1),
315    /// (2, 3)`.
316    ///
317    /// If `keys_to_move` contains entries, then the property labels must be the
318    /// same for all the merged blocks. In that case, the merged property labels
319    /// will contain each of the entries of `keys_to_move` and then the current
320    /// property labels. For example, using `a=2, 3` in `keys_to_move`, and
321    /// blocks with properties `p=1, 2` will result in `a, p = (2, 1), (2, 2),
322    /// (3, 1), (3, 2)`.
323    ///
324    /// The new sample labels will contain all of the merged blocks sample
325    /// labels. The order of the samples is controlled by `sort_samples`. If
326    /// `sort_samples` is true, samples are re-ordered to keep them
327    /// lexicographically sorted. Otherwise they are kept in the order in which
328    /// they appear in the blocks.
329    #[inline]
330    pub fn keys_to_properties(&self, keys_to_move: &Labels, fill_value: MtsArray, sort_samples: bool) -> Result<TensorMap, Error> {
331        let ptr = unsafe {
332            crate::c_api::mts_tensormap_keys_to_properties(
333                self.ptr,
334                keys_to_move.as_mts_labels_t(),
335                fill_value.into_raw(),
336                sort_samples,
337            )
338        };
339
340        check_ptr(ptr)?;
341        return Ok(unsafe { TensorMap::from_raw(ptr) });
342    }
343
344    /// Move the given dimensions from the component labels to the property
345    /// labels for each block in this `TensorMap`.
346    #[inline]
347    pub fn components_to_properties(&self, dimensions: &[&str]) -> Result<TensorMap, Error> {
348        let dimensions_c = dimensions.iter()
349            .map(|&v| CString::new(v).expect("unexpected NULL byte"))
350            .collect::<Vec<_>>();
351
352        let dimensions_ptr = dimensions_c.iter()
353            .map(|v| v.as_ptr())
354            .collect::<Vec<_>>();
355
356
357        let ptr = unsafe {
358            crate::c_api::mts_tensormap_components_to_properties(
359                self.ptr,
360                dimensions_ptr.as_ptr(),
361                dimensions.len(),
362            )
363        };
364
365        check_ptr(ptr)?;
366        return Ok(unsafe { TensorMap::from_raw(ptr) });
367    }
368
369    /// Get an iterator over the keys and associated blocks
370    #[inline]
371    pub fn iter(&self) -> TensorMapIter<'_> {
372        return TensorMapIter {
373            inner: self.keys().into_iter().zip(self.blocks())
374        };
375    }
376
377    /// Get an iterator over the keys and associated blocks, with read-write
378    /// access to the blocks
379    #[inline]
380    pub fn iter_mut(&mut self) -> TensorMapIterMut<'_> {
381        // we can not use `self.blocks_mut()` here, since it would
382        // double-borrow self
383        let mut blocks = Vec::new();
384        for i in 0..self.keys().count() {
385            blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
386        }
387
388        return TensorMapIterMut {
389            inner: self.keys().into_iter().zip(blocks)
390        };
391    }
392
393    /// Get a parallel iterator over the keys and associated blocks
394    #[cfg(feature = "rayon")]
395    #[inline]
396    pub fn par_iter(&self) -> TensorMapParIter<'_> {
397        use rayon::prelude::*;
398        TensorMapParIter {
399            inner: self.keys().par_iter().zip_eq(self.blocks().into_par_iter())
400        }
401    }
402
403    /// Get a parallel iterator over the keys and associated blocks, with
404    /// read-write access to the blocks
405    #[cfg(feature = "rayon")]
406    #[inline]
407    pub fn par_iter_mut(&mut self) -> TensorMapParIterMut<'_> {
408        use rayon::prelude::*;
409
410        // we can not use `self.blocks_mut()` here, since it would
411        // double-borrow self
412        let mut blocks = Vec::new();
413        for i in 0..self.keys().count() {
414            blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
415        }
416
417        TensorMapParIterMut {
418            inner: self.keys().par_iter().zip_eq(blocks)
419        }
420    }
421
422    /// Set or update the info (i.e. global metadata) `value` associated with
423    /// `key` for this `TensorMap`.
424    pub fn set_info(&mut self, key: &str, value: &str) {
425        let mut key = key.to_owned().into_bytes();
426        key.push(b'\0');
427
428        let mut value = value.to_owned().into_bytes();
429        value.push(b'\0');
430
431        unsafe {
432            check_status(crate::c_api::mts_tensormap_set_info(
433                self.ptr, key.as_ptr().cast(), value.as_ptr().cast()
434            )).expect("failed to set info");
435        }
436    }
437
438    /// Get the info (i.e. global metadata) with the given `key` for this
439    /// `TensorMap`.
440    pub fn get_info(&self, key: &str) -> Option<&str> {
441        let mut key = key.to_owned().into_bytes();
442        key.push(b'\0');
443
444        let mut value = std::ptr::null();
445
446        unsafe {
447            check_status(crate::c_api::mts_tensormap_get_info(
448                self.ptr, key.as_ptr().cast(), &mut value
449            )).expect("failed to set info");
450        }
451
452        if value.is_null() {
453            return None;
454        }
455
456        let c_str = unsafe { CStr::from_ptr(value) };
457        return Some(c_str.to_str().expect("invalid UTF-8 string"));
458    }
459
460    /// Get an iterator over all the key/value info pairs stored in this
461    /// `TensorMap`.
462    pub fn info(&self) -> TensorMapInfoIter<'_> {
463        let mut keys = std::ptr::null();
464        let mut count = 0;
465        unsafe {
466            check_status(crate::c_api::mts_tensormap_info_keys(
467                self.ptr,
468                &mut keys,
469                &mut count,
470            )).expect("failed to get info keys");
471        };
472
473        let keys = unsafe {
474            std::slice::from_raw_parts(keys, count)
475        };
476        let keys = keys.iter()
477            .map(|&k| {
478                let c_str = unsafe { CStr::from_ptr(k) };
479                c_str.to_str().expect("invalid UTF-8 string")
480            })
481            .collect::<Vec<_>>();
482
483        TensorMapInfoIter {
484            keys: keys,
485            tensor: self,
486            index: 0,
487            count,
488        }
489    }
490}
491
492/******************************************************************************/
493
494/// Iterator over key/block pairs in a [`TensorMap`]
495pub struct TensorMapIter<'a> {
496    inner: std::iter::Zip<crate::labels::LabelsIter<'a>, std::vec::IntoIter<TensorBlockRef<'a>>>
497}
498
499impl<'a> Iterator for TensorMapIter<'a> {
500    type Item = (&'a [LabelValue], TensorBlockRef<'a>);
501
502    #[inline]
503    fn next(&mut self) -> Option<Self::Item> {
504        self.inner.next()
505    }
506
507    fn size_hint(&self) -> (usize, Option<usize>) {
508        self.inner.size_hint()
509    }
510}
511
512impl ExactSizeIterator for TensorMapIter<'_> {
513    #[inline]
514    fn len(&self) -> usize {
515        self.inner.len()
516    }
517}
518
519impl FusedIterator for TensorMapIter<'_> {}
520
521impl<'a> IntoIterator for &'a TensorMap {
522    type Item = (&'a [LabelValue], TensorBlockRef<'a>);
523
524    type IntoIter = TensorMapIter<'a>;
525
526    fn into_iter(self) -> Self::IntoIter {
527        self.iter()
528    }
529}
530
531/******************************************************************************/
532
533/// Iterator over key/block pairs in a [`TensorMap`], with mutable access to the
534/// blocks
535pub struct TensorMapIterMut<'a> {
536    inner: std::iter::Zip<crate::labels::LabelsIter<'a>, std::vec::IntoIter<TensorBlockRefMut<'a>>>
537}
538
539impl<'a> Iterator for TensorMapIterMut<'a> {
540    type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);
541
542    #[inline]
543    fn next(&mut self) -> Option<Self::Item> {
544        self.inner.next()
545    }
546
547    fn size_hint(&self) -> (usize, Option<usize>) {
548        self.inner.size_hint()
549    }
550}
551
552impl ExactSizeIterator for TensorMapIterMut<'_> {
553    #[inline]
554    fn len(&self) -> usize {
555        self.inner.len()
556    }
557}
558
559impl FusedIterator for TensorMapIterMut<'_> {}
560
561impl<'a> IntoIterator for &'a mut TensorMap {
562    type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);
563
564    type IntoIter = TensorMapIterMut<'a>;
565
566    fn into_iter(self) -> Self::IntoIter {
567        self.iter_mut()
568    }
569}
570
571
572/******************************************************************************/
573
/// Parallel iterator over key/block pairs in a [`TensorMap`]
#[cfg(feature = "rayon")]
pub struct TensorMapParIter<'a> {
    inner: rayon::iter::ZipEq<crate::labels::LabelsParIter<'a>, rayon::vec::IntoIter<TensorBlockRef<'a>>>,
}

#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for TensorMapParIter<'a> {
    type Item = (&'a [LabelValue], TensorBlockRef<'a>);

    #[inline]
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item> {
        rayon::iter::ParallelIterator::drive_unindexed(self.inner, consumer)
    }

    /// Forward the exact length of the inner indexed iterator, so that
    /// consumers such as `collect`/`par_extend` can use their fast path.
    #[inline]
    fn opt_len(&self) -> Option<usize> {
        rayon::iter::ParallelIterator::opt_len(&self.inner)
    }
}

#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for TensorMapParIter<'_> {
    #[inline]
    fn len(&self) -> usize {
        rayon::iter::IndexedParallelIterator::len(&self.inner)
    }

    #[inline]
    fn drive<C: rayon::iter::plumbing::Consumer<Self::Item>>(self, consumer: C) -> C::Result {
        rayon::iter::IndexedParallelIterator::drive(self.inner, consumer)
    }

    #[inline]
    fn with_producer<CB: rayon::iter::plumbing::ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
        rayon::iter::IndexedParallelIterator::with_producer(self.inner, callback)
    }
}
609
610/******************************************************************************/
611
/// Parallel iterator over key/block pairs in a [`TensorMap`], with mutable
/// access to the blocks
#[cfg(feature = "rayon")]
pub struct TensorMapParIterMut<'a> {
    inner: rayon::iter::ZipEq<crate::labels::LabelsParIter<'a>, rayon::vec::IntoIter<TensorBlockRefMut<'a>>>,
}

#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for TensorMapParIterMut<'a> {
    type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);

    #[inline]
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item> {
        rayon::iter::ParallelIterator::drive_unindexed(self.inner, consumer)
    }

    /// Forward the exact length of the inner indexed iterator, so that
    /// consumers such as `collect`/`par_extend` can use their fast path.
    #[inline]
    fn opt_len(&self) -> Option<usize> {
        rayon::iter::ParallelIterator::opt_len(&self.inner)
    }
}

#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for TensorMapParIterMut<'_> {
    #[inline]
    fn len(&self) -> usize {
        rayon::iter::IndexedParallelIterator::len(&self.inner)
    }

    #[inline]
    fn drive<C: rayon::iter::plumbing::Consumer<Self::Item>>(self, consumer: C) -> C::Result {
        rayon::iter::IndexedParallelIterator::drive(self.inner, consumer)
    }

    #[inline]
    fn with_producer<CB: rayon::iter::plumbing::ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
        rayon::iter::IndexedParallelIterator::with_producer(self.inner, callback)
    }
}
648
649/******************************************************************************/
650
651/// Iterator over info key/value pairs in a `TensorMap`
652pub struct TensorMapInfoIter<'a> {
653    keys: Vec<&'a str>,
654    tensor: &'a TensorMap,
655    index: usize,
656    count: usize,
657}
658
659impl<'a> Iterator for TensorMapInfoIter<'a> {
660    type Item = (&'a str, &'a str);
661
662    #[inline]
663    fn next(&mut self) -> Option<Self::Item> {
664        if self.index >= self.count {
665            return None;
666        }
667        let key = self.keys[self.index];
668        let value = self.tensor.get_info(key).expect("missing info");
669        self.index += 1;
670        return Some((key, value));
671    }
672
673    fn size_hint(&self) -> (usize, Option<usize>) {
674        (self.count, Some(self.count))
675    }
676}
677
678impl ExactSizeIterator for TensorMapInfoIter<'_> {
679    #[inline]
680    fn len(&self) -> usize {
681        self.count
682    }
683}
684
685impl FusedIterator for TensorMapInfoIter<'_> {}
686
687
688/******************************************************************************/
689
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use crate::{Labels, TensorBlock, TensorMap};

    /// Build the tensor shared by the tests below: three blocks where every
    /// value equals the first dimension (`key`) of the block's key.
    fn test_tensor() -> TensorMap {
        let keys = Labels::new(["key", "other"], &[[1, 0], [3, 1], [-4, 0]]);

        let blocks = vec![
            TensorBlock::new(
                ndarray::Array::from_elem(vec![2, 3], 1.0),
                &Labels::new(["samples"], &[[0], [1]]),
                &[],
                &Labels::new(["properties"], &[[-2], [0], [1]]),
            ).unwrap(),
            TensorBlock::new(
                ndarray::Array::from_elem(vec![1, 1], 3.0),
                &Labels::new(["samples"], &[[1]]),
                &[],
                &Labels::new(["properties"], &[[1]]),
            ).unwrap(),
            TensorBlock::new(
                ndarray::Array::from_elem(vec![3, 2], -4.0),
                &Labels::new(["samples"], &[[0], [1], [3]]),
                &[],
                &Labels::new(["properties"], &[[-2], [1]]),
            ).unwrap(),
        ];

        return TensorMap::new(keys, blocks).unwrap();
    }

    #[test]
    fn block_access() {
        let mut tensor = test_tensor();

        // by index, read-only then read-write
        let by_id = tensor.block_by_id(1);
        assert_eq!(by_id.values().shape().unwrap(), [1, 1]);

        let by_id_mut = tensor.block_mut_by_id(2);
        assert_eq!(by_id_mut.values().shape().unwrap(), [3, 2]);

        // through a selection on the keys
        let selection = Labels::new(["key"], &[[1]]);
        let selected = tensor.block(&selection).unwrap();
        {
            let values = selected.values().to_ndarray_lock::<f64>().read().unwrap();
            assert_eq!(values.shape(), [2, 3]);
        }

        // all blocks at once, read-only then read-write
        let all_blocks = tensor.blocks();
        assert_eq!(all_blocks[0].values().shape().unwrap(), [2, 3]);
        assert_eq!(all_blocks[1].values().shape().unwrap(), [1, 1]);
        assert_eq!(all_blocks[2].values().shape().unwrap(), [3, 2]);

        let all_blocks_mut = tensor.blocks_mut();
        assert_eq!(all_blocks_mut[0].values().shape().unwrap(), [2, 3]);
        assert_eq!(all_blocks_mut[1].values().shape().unwrap(), [1, 1]);
        assert_eq!(all_blocks_mut[2].values().shape().unwrap(), [3, 2]);
    }

    #[test]
    fn iter() {
        let mut tensor = test_tensor();

        // read-only iteration over keys & blocks
        for (key, block) in &tensor {
            let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
            let expected = f64::from(key[0].i32());
            assert_eq!(values[[0, 0]], expected);
        }

        // read-write iteration over keys & blocks
        for (key, mut block) in &mut tensor {
            let array = block.values_mut().get_ndarray_mut::<f64>();
            *array *= 2.0;
            let expected = 2.0 * f64::from(key[0].i32());
            assert_eq!(array[[0, 0]], expected);
        }
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn par_iter() {
        use rayon::iter::ParallelIterator;

        let mut tensor = test_tensor();

        // read-only parallel iteration over keys & blocks
        tensor.par_iter().for_each(|(key, block)| {
            let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
            let expected = f64::from(key[0].i32());
            assert_eq!(values[[0, 0]], expected);
        });

        // read-write parallel iteration over keys & blocks
        tensor.par_iter_mut().for_each(|(key, mut block)| {
            let array = block.values_mut().get_ndarray_mut::<f64>();
            *array *= 2.0;
            let expected = 2.0 * f64::from(key[0].i32());
            assert_eq!(array[[0, 0]], expected);
        });
    }

    #[test]
    fn info() {
        let mut tensor = test_tensor();
        tensor.set_info("creator", "unit test");
        tensor.set_info("version", "1.0");

        assert_eq!(tensor.get_info("creator"), Some("unit test"));
        assert_eq!(tensor.get_info("version"), Some("1.0"));
        assert_eq!(tensor.get_info("missing"), None);

        let all_info = tensor.info().collect::<Vec<_>>();
        assert_eq!(all_info, [("creator", "unit test"), ("version", "1.0")]);
    }
}