Skip to main content

metatensor/
tensor.rs

1use std::ffi::{CStr, CString};
2use std::iter::FusedIterator;
3
4use crate::block::TensorBlockRefMut;
5use crate::c_api::mts_tensormap_t;
6
7use crate::errors::{check_status, check_ptr};
8use crate::{Error, LabelValue, Labels, MtsArray, TensorBlock, TensorBlockRef};
9
10/// [`TensorMap`] is the main user-facing struct of this library, and can
11/// store any kind of data used in atomistic machine learning.
12///
13/// A tensor map contains a list of `TensorBlock`s, each one associated with a
14/// key in the form of a single `Labels` entry.
15///
16/// It provides functions to merge blocks together by moving some of these keys
17/// to the samples or properties labels of the blocks, transforming the sparse
18/// representation of the data to a dense one.
19pub struct TensorMap {
20    pub(crate) ptr: *mut mts_tensormap_t,
21    /// cache for the keys labels
22    keys: Labels,
23}
24
25// SAFETY: Send is fine since we can free a TensorMap from any thread
26unsafe impl Send for TensorMap {}
27// SAFETY: Sync is fine since there is no internal mutability in TensorMap
28unsafe impl Sync for TensorMap {}
29
30impl std::fmt::Debug for TensorMap {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        use crate::labels::pretty_print_labels;
33        writeln!(f, "Tensormap @ {:p} {{", self.ptr)?;
34
35        write!(f, "    keys: ")?;
36        pretty_print_labels(self.keys(), "    ", f)?;
37        writeln!(f, "}}")
38    }
39}
40
41impl std::ops::Drop for TensorMap {
42    #[allow(unused_must_use)]
43    fn drop(&mut self) {
44        unsafe {
45            crate::c_api::mts_tensormap_free(self.ptr);
46        }
47    }
48}
49
50impl TensorMap {
51    /// Create a new `TensorMap` with the given keys and blocks.
52    ///
53    /// The number of keys must match the number of blocks, and all the blocks
54    /// must contain the same kind of data (same labels names, same gradients
55    /// defined on all blocks).
56    #[allow(clippy::needless_pass_by_value)]
57    #[inline]
58    pub fn new(keys: Labels, mut blocks: Vec<TensorBlock>) -> Result<TensorMap, Error> {
59        let ptr = unsafe {
60            crate::c_api::mts_tensormap(
61                keys.as_mts_labels_t(),
62                // this cast is fine because TensorBlock is `repr(transparent)`
63                // to a `*mut mts_block_t` (through `TensorBlockRefMut`, and
64                // `TensorBlockRef`).
65                blocks.as_mut_ptr().cast::<*mut crate::c_api::mts_block_t>(),
66                blocks.len()
67            )
68        };
69
70        for block in blocks {
71            // we give ownership of the blocks to the new tensormap, so we
72            // should not free them again from Rust
73            std::mem::forget(block);
74        }
75
76        check_ptr(ptr)?;
77
78        return Ok(unsafe { TensorMap::from_raw(ptr) });
79    }
80
81    /// Create a new `TensorMap` from a raw pointer.
82    ///
83    /// This function takes ownership of the pointer, and will call
84    /// `mts_tensormap_free` on it when the `TensorMap` goes out of scope.
85    ///
86    /// # Safety
87    ///
88    /// The pointer must be non-null and created by
89    /// [`crate::c_api::mts_tensormap`] or [`TensorMap::into_raw`].
90    pub unsafe fn from_raw(ptr: *mut mts_tensormap_t) -> TensorMap {
91        assert!(!ptr.is_null());
92
93        let keys_ptr = crate::c_api::mts_tensormap_keys(ptr);
94        assert!(!keys_ptr.is_null(), "failed to get the keys");
95        let keys = Labels::from_raw(keys_ptr);
96
97        return TensorMap {
98            ptr,
99            keys
100        };
101    }
102
103    /// Extract the underlying raw pointer.
104    ///
105    /// The pointer should be passed back to [`TensorMap::from_raw`] or
106    /// [`crate::c_api::mts_tensormap_free`] to release the memory corresponding
107    /// to this `TensorMap`.
108    pub fn into_raw(mut tensor: TensorMap) -> *mut mts_tensormap_t {
109        return std::mem::replace(&mut tensor.ptr, std::ptr::null_mut());
110    }
111
112    /// Get the underlying raw pointer.
113    ///
114    /// After a call, this `TensorMap` is still managing the corresponding
115    /// memory. To fully release the pointer, use [`TensorMap::into_raw`].
116    pub fn as_ptr(&self) -> *const mts_tensormap_t {
117        self.ptr
118    }
119
120    /// Get the underlying (mutable) raw pointer
121    ///
122    /// After a call, this `TensorMap` is still managing the corresponding
123    /// memory. To fully release the pointer, use [`TensorMap::into_raw`].
124    pub fn as_mut_ptr(&mut self) -> *mut mts_tensormap_t {
125        self.ptr
126    }
127
128    /// Clone this `TensorMap`, cloning all the data and metadata contained inside.
129    ///
130    /// This can fail if the external data held inside an `mts_array_t` can not
131    /// be cloned.
132    #[inline]
133    pub fn try_clone(&self) -> Result<TensorMap, Error> {
134        let ptr = unsafe {
135            crate::c_api::mts_tensormap_copy(self.ptr)
136        };
137        crate::errors::check_ptr(ptr)?;
138
139        return Ok(unsafe { TensorMap::from_raw(ptr) });
140    }
141
142    /// Load a `TensorMap` from the file at `path`
143    ///
144    /// This is a convenience function calling [`crate::io::load`]
145    pub fn load(path: impl AsRef<std::path::Path>) -> Result<TensorMap, Error> {
146        return crate::io::load(path);
147    }
148
149    /// Load a `TensorMap` from an in-memory buffer
150    ///
151    /// This is a convenience function calling [`crate::io::load_buffer`]
152    pub fn load_buffer(buffer: &[u8]) -> Result<TensorMap, Error> {
153        return crate::io::load_buffer(buffer);
154    }
155
156    /// Save the given tensor to the file at `path`
157    ///
158    /// This is a convenience function calling [`crate::io::save`]
159    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
160        return crate::io::save(path, self);
161    }
162
163    /// Save the given tensor to an in-memory buffer
164    ///
165    /// This is a convenience function calling [`crate::io::save_buffer`]
166    pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
167        return crate::io::save_buffer(self, buffer);
168    }
169
170    /// Get the device on which the values of this `TensorMap` are stored.
171    #[inline]
172    pub fn device(&self) -> Result<dlpk::sys::DLDevice, Error> {
173        let mut device = dlpk::sys::DLDevice::cpu();
174        unsafe {
175            check_status(crate::c_api::mts_tensormap_device(
176                self.ptr,
177                &mut device,
178            ))?;
179        }
180        return Ok(device);
181    }
182
183    /// Get the data type of the values of this `TensorMap`.
184    #[inline]
185    pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
186        let mut dtype = dlpk::sys::DLDataType {
187            code: dlpk::sys::DLDataTypeCode::kDLFloat,
188            bits: 0,
189            lanes: 0,
190        };
191        unsafe {
192            check_status(crate::c_api::mts_tensormap_dtype(
193                self.ptr,
194                &mut dtype,
195            ))?;
196        }
197        return Ok(dtype);
198    }
199
200    /// Get the keys defined in this `TensorMap`
201    #[inline]
202    pub fn keys(&self) -> &Labels {
203        &self.keys
204    }
205
206    /// Get a reference to the block at the given `index` in this `TensorMap`
207    ///
208    /// # Panics
209    ///
210    /// If the index is out of bounds
211    #[inline]
212    pub fn block_by_id(&self, index: usize) -> TensorBlockRef<'_> {
213
214        let mut block = std::ptr::null_mut();
215        unsafe {
216            check_status(crate::c_api::mts_tensormap_block_by_id(
217                self.ptr,
218                &mut block,
219                index,
220            )).expect("failed to get a block");
221        }
222
223        return unsafe { TensorBlockRef::from_raw(block) }
224    }
225
226    /// Get a mutable reference to the block at the given `index` in this `TensorMap`
227    ///
228    /// # Panics
229    ///
230    /// If the index is out of bounds
231    #[inline]
232    pub fn block_mut_by_id(&mut self, index: usize) -> TensorBlockRefMut<'_> {
233        return unsafe { TensorMap::raw_block_mut_by_id(self.ptr, index) };
234    }
235
236    /// Implementation of `block_mut_by_id` which does not borrow the
237    /// `mts_tensormap_t` pointer.
238    ///
239    /// This is used to provide references to multiple blocks at the same time
240    /// in the iterators.
241    ///
242    /// # Safety
243    ///
244    /// This should be called with a valid `mts_tensormap_t`, and the lifetime
245    /// `'a` should be properly constrained to the lifetime of the owner of
246    /// `ptr`.
247    #[inline]
248    unsafe fn raw_block_mut_by_id<'a>(ptr: *mut mts_tensormap_t, index: usize) -> TensorBlockRefMut<'a> {
249        let mut block = std::ptr::null_mut();
250
251        check_status(
252            crate::c_api::mts_tensormap_block_by_id(
253            ptr,
254            &mut block,
255            index,
256        )).expect("failed to get a block");
257
258        return TensorBlockRefMut::from_raw(block);
259    }
260
261    /// Get a reference to the block matching the given selection.
262    #[inline]
263    pub fn block(&self, selection: &Labels) -> Result<TensorBlockRef<'_>, Error> {
264        let matching = self.keys.select(selection)?;
265        if matching.len() != 1 {
266            let selection_str = selection.names()
267                .iter()
268                .zip(&selection[0])
269                .map(|(name, value)| format!("{} = {}", name, value))
270                .collect::<Vec<_>>()
271                .join(", ");
272
273            if matching.is_empty() {
274                return Err(Error {
275                    code: None,
276                    message: format!(
277                        "no blocks matched the selection ({})",
278                        selection_str
279                    ),
280                });
281            } else {
282                return Err(Error {
283                    code: None,
284                    message: format!(
285                        "{} blocks matched the selection ({}), expected only one",
286                        matching.len(),
287                        selection_str
288                    ),
289                });
290            }
291        }
292
293        return Ok(self.block_by_id(matching[0]));
294    }
295
296    /// Get a reference to every blocks in this `TensorMap`
297    #[inline]
298    pub fn blocks(&self) -> Vec<TensorBlockRef<'_>> {
299        let mut blocks = Vec::new();
300        for i in 0..self.keys().count() {
301            blocks.push(self.block_by_id(i));
302        }
303        return blocks;
304    }
305
306    /// Get a mutable reference to every blocks in this `TensorMap`
307    #[inline]
308    pub fn blocks_mut(&mut self) -> Vec<TensorBlockRefMut<'_>> {
309        let mut blocks = Vec::new();
310        for i in 0..self.keys().count() {
311            blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
312        }
313        return blocks;
314    }
315
316    /// Merge blocks with the same value for selected keys dimensions along the
317    /// samples axis.
318    ///
319    /// The dimensions (names) of `keys_to_move` will be moved from the keys to
320    /// the sample labels, and blocks with the same remaining keys dimensions
321    /// will be merged together along the sample axis.
322    ///
323    /// `keys_to_move` must be empty (`keys_to_move.count() == 0`), and the new
324    /// sample labels will contain entries corresponding to the merged blocks'
325    /// keys.
326    ///
327    /// The new sample labels will contain all of the merged blocks sample
328    /// labels. The order of the samples is controlled by `sort_samples`. If
329    /// `sort_samples` is true, samples are re-ordered to keep them
330    /// lexicographically sorted. Otherwise they are kept in the order in which
331    /// they appear in the blocks.
332    #[inline]
333    pub fn keys_to_samples(&self, keys_to_move: &Labels, fill_value: MtsArray, sort_samples: bool) -> Result<TensorMap, Error> {
334        let ptr = unsafe {
335            crate::c_api::mts_tensormap_keys_to_samples(
336                self.ptr,
337                keys_to_move.as_mts_labels_t(),
338                fill_value.into_raw(),
339                sort_samples,
340            )
341        };
342
343        check_ptr(ptr)?;
344        return Ok(unsafe { TensorMap::from_raw(ptr) });
345    }
346
347    /// Merge blocks with the same value for selected keys dimensions along the
348    /// property axis.
349    ///
350    /// The dimensions (names) of `keys_to_move` will be moved from the keys to
351    /// the property labels, and blocks with the same remaining keys dimensions
352    /// will be merged together along the property axis.
353    ///
354    /// If `keys_to_move` does not contain any entries (`keys_to_move.count()
355    /// == 0`), then the new property labels will contain entries corresponding
356    /// to the merged blocks only. For example, merging a block with key `a=0`
357    /// and properties `p=1, 2` with a block with key `a=2` and properties `p=1,
358    /// 3` will produce a block with properties `a, p = (0, 1), (0, 2), (2, 1),
359    /// (2, 3)`.
360    ///
361    /// If `keys_to_move` contains entries, then the property labels must be the
362    /// same for all the merged blocks. In that case, the merged property labels
363    /// will contain each of the entries of `keys_to_move` and then the current
364    /// property labels. For example, using `a=2, 3` in `keys_to_move`, and
365    /// blocks with properties `p=1, 2` will result in `a, p = (2, 1), (2, 2),
366    /// (3, 1), (3, 2)`.
367    ///
368    /// The new sample labels will contain all of the merged blocks sample
369    /// labels. The order of the samples is controlled by `sort_samples`. If
370    /// `sort_samples` is true, samples are re-ordered to keep them
371    /// lexicographically sorted. Otherwise they are kept in the order in which
372    /// they appear in the blocks.
373    #[inline]
374    pub fn keys_to_properties(&self, keys_to_move: &Labels, fill_value: MtsArray, sort_samples: bool) -> Result<TensorMap, Error> {
375        let ptr = unsafe {
376            crate::c_api::mts_tensormap_keys_to_properties(
377                self.ptr,
378                keys_to_move.as_mts_labels_t(),
379                fill_value.into_raw(),
380                sort_samples,
381            )
382        };
383
384        check_ptr(ptr)?;
385        return Ok(unsafe { TensorMap::from_raw(ptr) });
386    }
387
388    /// Move the given dimensions from the component labels to the property
389    /// labels for each block in this `TensorMap`.
390    #[inline]
391    pub fn components_to_properties(&self, dimensions: &[&str]) -> Result<TensorMap, Error> {
392        let dimensions_c = dimensions.iter()
393            .map(|&v| CString::new(v).expect("unexpected NULL byte"))
394            .collect::<Vec<_>>();
395
396        let dimensions_ptr = dimensions_c.iter()
397            .map(|v| v.as_ptr())
398            .collect::<Vec<_>>();
399
400
401        let ptr = unsafe {
402            crate::c_api::mts_tensormap_components_to_properties(
403                self.ptr,
404                dimensions_ptr.as_ptr(),
405                dimensions.len(),
406            )
407        };
408
409        check_ptr(ptr)?;
410        return Ok(unsafe { TensorMap::from_raw(ptr) });
411    }
412
413    /// Get an iterator over the keys and associated blocks
414    #[inline]
415    pub fn iter(&self) -> TensorMapIter<'_> {
416        return TensorMapIter {
417            inner: self.keys().into_iter().zip(self.blocks())
418        };
419    }
420
421    /// Get an iterator over the keys and associated blocks, with read-write
422    /// access to the blocks
423    #[inline]
424    pub fn iter_mut(&mut self) -> TensorMapIterMut<'_> {
425        // we can not use `self.blocks_mut()` here, since it would
426        // double-borrow self
427        let mut blocks = Vec::new();
428        for i in 0..self.keys().count() {
429            blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
430        }
431
432        return TensorMapIterMut {
433            inner: self.keys().into_iter().zip(blocks)
434        };
435    }
436
437    /// Get a parallel iterator over the keys and associated blocks
438    #[cfg(feature = "rayon")]
439    #[inline]
440    pub fn par_iter(&self) -> TensorMapParIter<'_> {
441        use rayon::prelude::*;
442        TensorMapParIter {
443            inner: self.keys().par_iter().zip_eq(self.blocks().into_par_iter())
444        }
445    }
446
447    /// Get a parallel iterator over the keys and associated blocks, with
448    /// read-write access to the blocks
449    #[cfg(feature = "rayon")]
450    #[inline]
451    pub fn par_iter_mut(&mut self) -> TensorMapParIterMut<'_> {
452        use rayon::prelude::*;
453
454        // we can not use `self.blocks_mut()` here, since it would
455        // double-borrow self
456        let mut blocks = Vec::new();
457        for i in 0..self.keys().count() {
458            blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
459        }
460
461        TensorMapParIterMut {
462            inner: self.keys().par_iter().zip_eq(blocks)
463        }
464    }
465
466    /// Set or update the info (i.e. global metadata) `value` associated with
467    /// `key` for this `TensorMap`.
468    pub fn set_info(&mut self, key: &str, value: &str) {
469        let mut key = key.to_owned().into_bytes();
470        key.push(b'\0');
471
472        let mut value = value.to_owned().into_bytes();
473        value.push(b'\0');
474
475        unsafe {
476            check_status(crate::c_api::mts_tensormap_set_info(
477                self.ptr, key.as_ptr().cast(), value.as_ptr().cast()
478            )).expect("failed to set info");
479        }
480    }
481
482    /// Get the info (i.e. global metadata) with the given `key` for this
483    /// `TensorMap`.
484    pub fn get_info(&self, key: &str) -> Option<&str> {
485        let mut key = key.to_owned().into_bytes();
486        key.push(b'\0');
487
488        let mut value = std::ptr::null();
489
490        unsafe {
491            check_status(crate::c_api::mts_tensormap_get_info(
492                self.ptr, key.as_ptr().cast(), &mut value
493            )).expect("failed to set info");
494        }
495
496        if value.is_null() {
497            return None;
498        }
499
500        let c_str = unsafe { CStr::from_ptr(value) };
501        return Some(c_str.to_str().expect("invalid UTF-8 string"));
502    }
503
504    /// Get an iterator over all the key/value info pairs stored in this
505    /// `TensorMap`.
506    pub fn info(&self) -> TensorMapInfoIter<'_> {
507        let mut keys = std::ptr::null();
508        let mut count = 0;
509        unsafe {
510            check_status(crate::c_api::mts_tensormap_info_keys(
511                self.ptr,
512                &mut keys,
513                &mut count,
514            )).expect("failed to get info keys");
515        };
516
517        let keys = unsafe {
518            std::slice::from_raw_parts(keys, count)
519        };
520        let keys = keys.iter()
521            .map(|&k| {
522                let c_str = unsafe { CStr::from_ptr(k) };
523                c_str.to_str().expect("invalid UTF-8 string")
524            })
525            .collect::<Vec<_>>();
526
527        TensorMapInfoIter {
528            keys: keys,
529            tensor: self,
530            index: 0,
531            count,
532        }
533    }
534}
535
536/******************************************************************************/
537
538/// Iterator over key/block pairs in a [`TensorMap`]
539pub struct TensorMapIter<'a> {
540    inner: std::iter::Zip<crate::labels::LabelsIter<'a>, std::vec::IntoIter<TensorBlockRef<'a>>>
541}
542
543impl<'a> Iterator for TensorMapIter<'a> {
544    type Item = (&'a [LabelValue], TensorBlockRef<'a>);
545
546    #[inline]
547    fn next(&mut self) -> Option<Self::Item> {
548        self.inner.next()
549    }
550
551    fn size_hint(&self) -> (usize, Option<usize>) {
552        self.inner.size_hint()
553    }
554}
555
556impl ExactSizeIterator for TensorMapIter<'_> {
557    #[inline]
558    fn len(&self) -> usize {
559        self.inner.len()
560    }
561}
562
563impl FusedIterator for TensorMapIter<'_> {}
564
565impl<'a> IntoIterator for &'a TensorMap {
566    type Item = (&'a [LabelValue], TensorBlockRef<'a>);
567
568    type IntoIter = TensorMapIter<'a>;
569
570    fn into_iter(self) -> Self::IntoIter {
571        self.iter()
572    }
573}
574
575/******************************************************************************/
576
577/// Iterator over key/block pairs in a [`TensorMap`], with mutable access to the
578/// blocks
579pub struct TensorMapIterMut<'a> {
580    inner: std::iter::Zip<crate::labels::LabelsIter<'a>, std::vec::IntoIter<TensorBlockRefMut<'a>>>
581}
582
583impl<'a> Iterator for TensorMapIterMut<'a> {
584    type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);
585
586    #[inline]
587    fn next(&mut self) -> Option<Self::Item> {
588        self.inner.next()
589    }
590
591    fn size_hint(&self) -> (usize, Option<usize>) {
592        self.inner.size_hint()
593    }
594}
595
596impl ExactSizeIterator for TensorMapIterMut<'_> {
597    #[inline]
598    fn len(&self) -> usize {
599        self.inner.len()
600    }
601}
602
603impl FusedIterator for TensorMapIterMut<'_> {}
604
605impl<'a> IntoIterator for &'a mut TensorMap {
606    type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);
607
608    type IntoIter = TensorMapIterMut<'a>;
609
610    fn into_iter(self) -> Self::IntoIter {
611        self.iter_mut()
612    }
613}
614
615
616/******************************************************************************/
617
618/// Parallel iterator over key/block pairs in a [`TensorMap`]
619#[cfg(feature = "rayon")]
620pub struct TensorMapParIter<'a> {
621    inner: rayon::iter::ZipEq<crate::labels::LabelsParIter<'a>, rayon::vec::IntoIter<TensorBlockRef<'a>>>,
622}
623
624#[cfg(feature = "rayon")]
625impl<'a> rayon::iter::ParallelIterator for TensorMapParIter<'a> {
626    type Item = (&'a [LabelValue], TensorBlockRef<'a>);
627
628    #[inline]
629    fn drive_unindexed<C>(self, consumer: C) -> C::Result
630    where
631        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item> {
632        self.inner.drive_unindexed(consumer)
633    }
634}
635
636#[cfg(feature = "rayon")]
637impl rayon::iter::IndexedParallelIterator for TensorMapParIter<'_> {
638    #[inline]
639    fn len(&self) -> usize {
640        self.inner.len()
641    }
642
643    #[inline]
644    fn drive<C: rayon::iter::plumbing::Consumer<Self::Item>>(self, consumer: C) -> C::Result {
645        self.inner.drive(consumer)
646    }
647
648    #[inline]
649    fn with_producer<CB: rayon::iter::plumbing::ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
650        self.inner.with_producer(callback)
651    }
652}
653
654/******************************************************************************/
655
656/// Parallel iterator over key/block pairs in a [`TensorMap`], with mutable
657/// access to the blocks
658#[cfg(feature = "rayon")]
659pub struct TensorMapParIterMut<'a> {
660    inner: rayon::iter::ZipEq<crate::labels::LabelsParIter<'a>, rayon::vec::IntoIter<TensorBlockRefMut<'a>>>,
661}
662
663#[cfg(feature = "rayon")]
664impl<'a> rayon::iter::ParallelIterator for TensorMapParIterMut<'a> {
665    type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);
666
667    #[inline]
668    fn drive_unindexed<C>(self, consumer: C) -> C::Result
669    where
670        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item> {
671        self.inner.drive_unindexed(consumer)
672    }
673}
674
675#[cfg(feature = "rayon")]
676impl rayon::iter::IndexedParallelIterator for TensorMapParIterMut<'_> {
677    #[inline]
678    fn len(&self) -> usize {
679        self.inner.len()
680    }
681
682    #[inline]
683    fn drive<C: rayon::iter::plumbing::Consumer<Self::Item>>(self, consumer: C) -> C::Result {
684        self.inner.drive(consumer)
685    }
686
687    #[inline]
688    fn with_producer<CB: rayon::iter::plumbing::ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
689        self.inner.with_producer(callback)
690    }
691}
692
693/******************************************************************************/
694
695/// Iterator over info key/value pairs in a `TensorMap`
696pub struct TensorMapInfoIter<'a> {
697    keys: Vec<&'a str>,
698    tensor: &'a TensorMap,
699    index: usize,
700    count: usize,
701}
702
703impl<'a> Iterator for TensorMapInfoIter<'a> {
704    type Item = (&'a str, &'a str);
705
706    #[inline]
707    fn next(&mut self) -> Option<Self::Item> {
708        if self.index >= self.count {
709            return None;
710        }
711        let key = self.keys[self.index];
712        let value = self.tensor.get_info(key).expect("missing info");
713        self.index += 1;
714        return Some((key, value));
715    }
716
717    fn size_hint(&self) -> (usize, Option<usize>) {
718        (self.count, Some(self.count))
719    }
720}
721
722impl ExactSizeIterator for TensorMapInfoIter<'_> {
723    #[inline]
724    fn len(&self) -> usize {
725        self.count
726    }
727}
728
729impl FusedIterator for TensorMapInfoIter<'_> {}
730
731
732/******************************************************************************/
733
734#[cfg(test)]
735#[allow(clippy::float_cmp)]
736mod tests {
737    use crate::{Labels, TensorBlock, TensorMap};
738
739    fn test_tensor() -> TensorMap {
740        let block_1 = TensorBlock::new(
741            ndarray::Array::from_elem(vec![2, 3], 1.0),
742            &Labels::new(["samples"], [[0], [1]]),
743            &[],
744            &Labels::new(["properties"], [[-2], [0], [1]]),
745        ).unwrap();
746
747        let block_2 = TensorBlock::new(
748            ndarray::Array::from_elem(vec![1, 1], 3.0),
749            &Labels::new(["samples"], [[1]]),
750            &[],
751            &Labels::new(["properties"], [[1]]),
752        ).unwrap();
753
754        let block_3 = TensorBlock::new(
755            ndarray::Array::from_elem(vec![3, 2], -4.0),
756            &Labels::new(["samples"], [[0], [1], [3]]),
757            &[],
758            &Labels::new(["properties"], [[-2], [1]]),
759        ).unwrap();
760
761        return TensorMap::new(
762            Labels::new(["key", "other"], [[1, 0], [3, 1], [-4, 0]]),
763            vec![block_1, block_2, block_3],
764        ).unwrap();
765    }
766
767    #[test]
768    fn block_access() {
769        let mut tensor = test_tensor();
770
771        let block = tensor.block_by_id(1);
772        assert_eq!(block.values().shape().unwrap(), [1, 1]);
773
774        let block = tensor.block_mut_by_id(2);
775        assert_eq!(block.values().shape().unwrap(), [3, 2]);
776
777        let selection = Labels::new(["key"], [[1]]);
778
779        let block = tensor.block(&selection).unwrap();
780        {
781            let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
782            assert_eq!(values.shape(), [2, 3]);
783        }
784
785        let blocks = tensor.blocks();
786        assert_eq!(blocks[0].values().shape().unwrap(), [2, 3]);
787        assert_eq!(blocks[1].values().shape().unwrap(), [1, 1]);
788        assert_eq!(blocks[2].values().shape().unwrap(), [3, 2]);
789
790        let blocks = tensor.blocks_mut();
791        assert_eq!(blocks[0].values().shape().unwrap(), [2, 3]);
792        assert_eq!(blocks[1].values().shape().unwrap(), [1, 1]);
793        assert_eq!(blocks[2].values().shape().unwrap(), [3, 2]);
794    }
795
796    #[test]
797    fn iter() {
798        let mut tensor = test_tensor();
799
800        // iterate over keys & blocks
801        for (key, block) in &tensor {
802            let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
803            assert_eq!(values[[0, 0]], f64::from(key[0].i32()));
804        }
805
806        // iterate over keys & blocks mutably
807        for (key, mut block) in &mut tensor {
808            let array = block.values_mut().get_ndarray_mut::<f64>();
809            *array *= 2.0;
810            assert_eq!(array[[0, 0]], 2.0 * f64::from(key[0].i32()));
811        }
812    }
813
814    #[cfg(feature = "rayon")]
815    #[test]
816    fn par_iter() {
817        use rayon::iter::ParallelIterator;
818
819        let mut tensor = test_tensor();
820
821        // iterate over keys & blocks
822        tensor.par_iter().for_each(|(key, block)| {
823            let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
824            assert_eq!(values[[0, 0]], f64::from(key[0].i32()));
825        });
826
827        // iterate over keys & blocks mutably
828        tensor.par_iter_mut().for_each(|(key, mut block)| {
829            let array = block.values_mut().get_ndarray_mut::<f64>();
830            *array *= 2.0;
831            assert_eq!(array[[0, 0]], 2.0 * f64::from(key[0].i32()));
832        });
833    }
834
835    #[test]
836    fn info() {
837        let mut tensor = test_tensor();
838        tensor.set_info("creator", "unit test");
839        tensor.set_info("version", "1.0");
840
841        assert_eq!(tensor.get_info("creator").unwrap(), "unit test");
842        assert_eq!(tensor.get_info("version").unwrap(), "1.0");
843        assert!(tensor.get_info("missing").is_none());
844
845        let mut info_iter = tensor.info();
846        let (key, value) = info_iter.next().unwrap();
847        assert_eq!(key, "creator");
848        assert_eq!(value, "unit test");
849        let (key, value) = info_iter.next().unwrap();
850        assert_eq!(key, "version");
851        assert_eq!(value, "1.0");
852        assert!(info_iter.next().is_none());
853    }
854
855    #[test]
856    fn device_and_dtype() {
857        let tensor = test_tensor();
858
859        let device = tensor.device().unwrap();
860        assert_eq!(device.device_type, dlpk::sys::DLDeviceType::kDLCPU);
861
862        let dtype = tensor.dtype().unwrap();
863        assert_eq!(dtype.code, dlpk::sys::DLDataTypeCode::kDLFloat);
864        assert_eq!(dtype.bits, 64);
865    }
866
867    #[test]
868    fn tensor_map_into_raw() {
869        let tensor = test_tensor();
870        let raw = TensorMap::into_raw(tensor);
871
872        let recovered = unsafe { TensorMap::from_raw(raw) };
873        assert_eq!(
874            recovered.keys(),
875            &Labels::new(["key", "other"], [[1, 0], [3, 1], [-4, 0]])
876        );
877    }
878}