metatensor/
labels.rs

1use std:: ffi::CStr;
2use std::ffi::CString;
3use std::collections::BTreeSet;
4use std::iter::FusedIterator;
5
6use smallvec::SmallVec;
7
8use crate::c_api::mts_labels_t;
9use crate::errors::{Error, check_status};
10
/// One integer entry of a label.
///
/// The value is stored as a 32-bit signed integer; a few helper functions
/// give access to it as `usize`, `isize` or `i32`.
#[repr(transparent)]
#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct LabelValue(i32);
18
19impl PartialEq<i32> for LabelValue {
20    #[inline]
21    fn eq(&self, other: &i32) -> bool {
22        self.0 == *other
23    }
24}
25
26impl PartialEq<LabelValue> for i32 {
27    #[inline]
28    fn eq(&self, other: &LabelValue) -> bool {
29        *self == other.0
30    }
31}
32
33impl std::fmt::Debug for LabelValue {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        write!(f, "{}", self.0)
36    }
37}
38
39impl std::fmt::Display for LabelValue {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        write!(f, "{}", self.0)
42    }
43}
44
45#[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
46impl From<u32> for LabelValue {
47    #[inline]
48    fn from(value: u32) -> LabelValue {
49        assert!(value < i32::MAX as u32);
50        LabelValue(value as i32)
51    }
52}
53
54impl From<i32> for LabelValue {
55    #[inline]
56    fn from(value: i32) -> LabelValue {
57        LabelValue(value)
58    }
59}
60
61#[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
62impl From<usize> for LabelValue {
63    #[inline]
64    fn from(value: usize) -> LabelValue {
65        assert!(value < i32::MAX as usize);
66        LabelValue(value as i32)
67    }
68}
69
70#[allow(clippy::cast_possible_truncation)]
71impl From<isize> for LabelValue {
72    #[inline]
73    fn from(value: isize) -> LabelValue {
74        assert!(value < i32::MAX as isize && value > i32::MIN as isize);
75        LabelValue(value as i32)
76    }
77}
78
79impl LabelValue {
80    /// Create a `LabelValue` with the given `value`
81    #[inline]
82    pub fn new(value: i32) -> LabelValue {
83        LabelValue(value)
84    }
85
86    /// Get the integer value of this `LabelValue` as a usize
87    #[inline]
88    #[allow(clippy::cast_sign_loss)]
89    pub fn usize(self) -> usize {
90        debug_assert!(self.0 >= 0);
91        self.0 as usize
92    }
93
94    /// Get the integer value of this `LabelValue` as an isize
95    #[inline]
96    pub fn isize(self) -> isize {
97        self.0 as isize
98    }
99
100    /// Get the integer value of this `LabelValue` as an i32
101    #[inline]
102    pub fn i32(self) -> i32 {
103        self.0
104    }
105}
106
107/// A set of labels used to carry metadata associated with a tensor map.
108///
109/// This is similar to a list of named tuples, but stored as a 2D array of shape
110/// `(labels.count(), labels.size())`, with a of set names associated with the
111/// columns of this array. Each row/entry in this array is unique, and they are
112/// often (but not always) sorted in  lexicographic order.
113///
114/// The main way to construct a new set of labels is to use a `LabelsBuilder`.
115///
116/// Labels are internally reference counted and immutable, so cloning a `Labels`
117/// should be a cheap operation.
118pub struct Labels {
119    pub(crate) raw: mts_labels_t,
120}
121
122// Labels can be sent to other thread safely since mts_labels_t uses an
123// `Arc<metatensor_core::Labels>`, so freeing them from another thread is fine
124unsafe impl Send for Labels {}
125// &Labels can be sent to other thread safely since there is no un-synchronized
126// interior mutability (`user_data` is protected by RwLock).
127unsafe impl Sync for Labels {}
128
129impl std::fmt::Debug for Labels {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        pretty_print_labels(self, "", f)
132    }
133}
134
135/// Helper function to print labels in a Debug mode
136pub(crate) fn pretty_print_labels(
137    labels: &Labels,
138    offset: &str,
139    f: &mut std::fmt::Formatter<'_>
140) -> std::fmt::Result {
141    let names = labels.names();
142
143    writeln!(f, "Labels @ {:p} {{", labels.raw.internal_ptr_)?;
144    writeln!(f, "{}    {}", offset, names.join(", "))?;
145
146    let widths = names.iter().map(|s| s.len()).collect::<Vec<_>>();
147    for values in labels {
148        write!(f, "{}    ", offset)?;
149        for (value, width) in values.iter().zip(&widths) {
150            write!(f, "{:^width$}  ", value.isize(), width=width)?;
151        }
152        writeln!(f)?;
153    }
154
155    writeln!(f, "{}}}", offset)
156}
157
158impl Clone for Labels {
159    #[inline]
160    fn clone(&self) -> Self {
161        let mut clone = mts_labels_t::null();
162        unsafe {
163            check_status(crate::c_api::mts_labels_clone(self.raw, &mut clone)).expect("failed to clone Labels");
164        }
165
166        return unsafe { Labels::from_raw(clone) };
167    }
168}
169
170impl std::ops::Drop for Labels {
171    #[allow(unused_must_use)]
172    fn drop(&mut self) {
173        unsafe {
174            crate::c_api::mts_labels_free(&mut self.raw);
175        }
176    }
177}
178
179impl Labels {
180    /// Create a new set of Labels with the given names and values.
181    ///
182    /// This is a convenience function replacing the manual use of
183    /// `LabelsBuilder`. If you need more flexibility or incremental `Labels`
184    /// construction, use `LabelsBuilder`.
185    ///
186    /// # Panics
187    ///
188    /// If the set of names is not valid, or any of the value is duplicated
189    #[inline]
190    pub fn new<T, const N: usize>(names: [&str; N], values: &[[T; N]]) -> Labels
191        where T: Copy + Into<LabelValue>
192    {
193        let mut builder = LabelsBuilder::new(names.to_vec());
194        for entry in values {
195            builder.add(entry);
196        }
197        return builder.finish();
198    }
199
200    /// Create a set of `Labels` with the given names, containing no entries.
201    #[inline]
202    pub fn empty(names: Vec<&str>) -> Labels {
203        return LabelsBuilder::new(names).finish()
204    }
205
206    /// Create a set of `Labels` containing a single entry, to be used when
207    /// there is no relevant information to store.
208    #[inline]
209    pub fn single() -> Labels {
210        let mut builder = LabelsBuilder::new(vec!["_"]);
211        builder.add(&[0]);
212        return builder.finish();
213    }
214
215    /// Load `Labels` from the file at `path`
216    ///
217    /// This is a convenience function calling [`crate::io::load_labels`]
218    pub fn load(path: impl AsRef<std::path::Path>) -> Result<Labels, Error> {
219        return crate::io::load_labels(path);
220    }
221
222    /// Load a `TensorMap` from an in-memory buffer
223    ///
224    /// This is a convenience function calling [`crate::io::load_buffer`]
225    pub fn load_buffer(buffer: &[u8]) -> Result<Labels, Error> {
226        return crate::io::load_labels_buffer(buffer);
227    }
228
229    /// Save the given tensor to the file at `path`
230    ///
231    /// This is a convenience function calling [`crate::io::save`]
232    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
233        return crate::io::save_labels(path, self);
234    }
235
236    /// Save the given tensor to an in-memory buffer
237    ///
238    /// This is a convenience function calling [`crate::io::save_buffer`]
239    pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
240        return crate::io::save_labels_buffer(self, buffer);
241    }
242
243    /// Get the number of entries/named values in a single label
244    #[inline]
245    pub fn size(&self) -> usize {
246        self.raw.size
247    }
248
249    /// Get the names of the entries/columns in this set of labels
250    #[inline]
251    pub fn names(&self) -> Vec<&str> {
252        if self.raw.size == 0 {
253            return Vec::new();
254        } else {
255            unsafe {
256                let names = std::slice::from_raw_parts(self.raw.names, self.raw.size);
257                return names.iter()
258                            .map(|&ptr| CStr::from_ptr(ptr).to_str().expect("invalid UTF8"))
259                            .collect();
260            }
261        }
262    }
263
264    /// Get the total number of entries in this set of labels
265    #[inline]
266    pub fn count(&self) -> usize {
267        return self.raw.count;
268    }
269
270    /// Check if this set of Labels is empty (contains no entry)
271    #[inline]
272    pub fn is_empty(&self) -> bool {
273        self.count() == 0
274    }
275
276    /// Check whether the given `label` is part of this set of labels
277    #[inline]
278    pub fn contains(&self, label: &[LabelValue]) -> bool {
279        return self.position(label).is_some();
280    }
281
282    /// Get the position (i.e. row index) of the given label in the full labels
283    /// array, or None.
284    #[inline]
285    pub fn position(&self, value: &[LabelValue]) -> Option<usize> {
286        assert!(value.len() == self.size(), "invalid size of index in Labels::position");
287
288        let mut result = 0;
289        unsafe {
290            check_status(crate::c_api::mts_labels_position(
291                self.raw,
292                value.as_ptr().cast(),
293                value.len(),
294                &mut result,
295            )).expect("failed to check label position");
296        }
297
298        return result.try_into().ok();
299    }
300
301    /// Take the union of `self` with `other`.
302    ///
303    /// If requested, this function can also give the positions in the union
304    /// where each entry of the input `Labels` ended up.
305    ///
306    /// If `first_mapping` (respectively `second_mapping`) is `Some`, it should
307    /// contain a slice of length `self.count()` (respectively `other.count()`)
308    /// that will be filled with the position of the entries in `self`
309    /// (respectively `other`) in the union.
310    #[inline]
311    pub fn union(
312        &self,
313        other: &Labels,
314        first_mapping: Option<&mut [i64]>,
315        second_mapping: Option<&mut [i64]>,
316    ) -> Result<Labels, Error> {
317        let mut output = mts_labels_t::null();
318        let (first_mapping, first_mapping_count) = if let Some(m) = first_mapping {
319            (m.as_mut_ptr(), m.len())
320        } else {
321            (std::ptr::null_mut(), 0)
322        };
323
324        let (second_mapping, second_mapping_count) = if let Some(m) = second_mapping {
325            (m.as_mut_ptr(), m.len())
326        } else {
327            (std::ptr::null_mut(), 0)
328        };
329
330        unsafe {
331            check_status(crate::c_api::mts_labels_union(
332                self.raw,
333                other.raw,
334                &mut output,
335                first_mapping,
336                first_mapping_count,
337                second_mapping,
338                second_mapping_count,
339            ))?;
340
341            return Ok(Labels::from_raw(output));
342        }
343    }
344
345    /// Take the intersection of self with `other`.
346    ///
347    /// If requested, this function can also give the positions in the
348    /// intersection where each entry of the input `Labels` ended up.
349    ///
350    /// If `first_mapping` (respectively `second_mapping`) is `Some`, it should
351    /// contain a slice of length `self.count()` (respectively `other.count()`)
352    /// that will be filled by with the position of the entries in `self`
353    /// (respectively `other`) in the intersection. If an entry in `self` or
354    /// `other` are not used in the intersection, the mapping for this entry
355    /// will be set to `-1`.
356    #[inline]
357    pub fn intersection(
358        &self,
359        other: &Labels,
360        first_mapping: Option<&mut [i64]>,
361        second_mapping: Option<&mut [i64]>,
362    ) -> Result<Labels, Error> {
363        let mut output = mts_labels_t::null();
364        let (first_mapping, first_mapping_count) = if let Some(m) = first_mapping {
365            (m.as_mut_ptr(), m.len())
366        } else {
367            (std::ptr::null_mut(), 0)
368        };
369
370        let (second_mapping, second_mapping_count) = if let Some(m) = second_mapping {
371            (m.as_mut_ptr(), m.len())
372        } else {
373            (std::ptr::null_mut(), 0)
374        };
375
376        unsafe {
377            check_status(crate::c_api::mts_labels_intersection(
378                self.raw,
379                other.raw,
380                &mut output,
381                first_mapping,
382                first_mapping_count,
383                second_mapping,
384                second_mapping_count,
385            ))?;
386
387            return Ok(Labels::from_raw(output));
388        }
389    }
390
391    /// Take the set difference of `self` with `other`.
392    ///
393    /// If requested, this function can also give the positions in the
394    /// difference where each entry of `self` ended up.
395    ///
396    /// If `mapping` is `Some`, it should contain a slice of length
397    /// `self.count()` that will be filled by with the position of the entries
398    /// in `self` in the difference. If an entry is not used in the difference,
399    ///  the mapping for this entry will be set to `-1`.
400    #[inline]
401    pub fn difference(
402        &self,
403        other: &Labels,
404        mapping: Option<&mut [i64]>,
405    ) -> Result<Labels, Error> {
406        let mut output = mts_labels_t::null();
407        let (mapping, mapping_count) = if let Some(m) = mapping {
408            (m.as_mut_ptr(), m.len())
409        } else {
410            (std::ptr::null_mut(), 0)
411        };
412
413        unsafe {
414            check_status(crate::c_api::mts_labels_difference(
415                self.raw,
416                other.raw,
417                &mut output,
418                mapping,
419                mapping_count,
420            ))?;
421
422            return Ok(Labels::from_raw(output));
423        }
424    }
425
426    /// Iterate over the entries in this set of labels
427    #[inline]
428    pub fn iter(&self) -> LabelsIter<'_> {
429        return LabelsIter {
430            ptr: self.values().as_ptr(),
431            cur: 0,
432            len: self.count(),
433            chunk_len: self.size(),
434            phantom: std::marker::PhantomData,
435        };
436    }
437
438    /// Iterate over the entries in this set of labels in parallel
439    #[cfg(feature = "rayon")]
440    #[inline]
441    pub fn par_iter(&self) -> LabelsParIter<'_> {
442        use rayon::prelude::*;
443        return LabelsParIter {
444            chunks: self.values().par_chunks_exact(self.raw.size)
445        };
446    }
447
448    /// Iterate over the entries in this set of labels as fixed-size arrays
449    #[inline]
450    pub fn iter_fixed_size<const N: usize>(&self) -> LabelsFixedSizeIter<N> {
451        assert!(N == self.size(),
452            "wrong label size in `iter_fixed_size`: the entries contains {} element \
453            but this function was called with size of {}",
454            self.size(), N
455        );
456
457        return LabelsFixedSizeIter {
458            values: self.values()
459        };
460    }
461
462    /// Select entries in these `Labels` that match the `selection`.
463    ///
464    /// The selection's names must be a subset of the names of these labels.
465    ///
466    /// All entries in these `Labels` that match one of the entry in the
467    /// `selection` for all the selection's dimension will be picked. Any entry
468    /// in the `selection` but not in these `Labels` will be ignored.
469    pub fn select(&self, selection: &Labels) -> Result<Vec<i64>, Error> {
470        let mut selected = vec![-1; self.count()];
471        let mut selected_count = selected.len();
472
473        unsafe {
474            check_status(crate::c_api::mts_labels_select(
475                self.as_mts_labels_t(),
476                selection.as_mts_labels_t(),
477                selected.as_mut_ptr(),
478                &mut selected_count
479            ))?;
480        }
481
482        selected.resize(selected_count, 0);
483
484        return Ok(selected);
485    }
486
487    pub(crate) fn values(&self) -> &[LabelValue] {
488        if self.count() == 0 || self.size() == 0 {
489            return &[]
490        } else {
491            unsafe {
492                std::slice::from_raw_parts(self.raw.values.cast(), self.count() * self.size())
493            }
494        }
495    }
496}
497
498impl Labels {
499    /// Get the underlying `mts_labels_t`
500    pub(crate) fn as_mts_labels_t(&self) -> mts_labels_t {
501        return self.raw;
502    }
503
504    /// Create a new set of `Labels` from a raw `mts_labels_t`.
505    ///
506    /// This function takes ownership of the `mts_labels_t` and will call
507    /// `mts_labels_free` on it.
508    ///
509    /// # Safety
510    ///
511    /// The raw `mts_labels_t` must have been returned by one of the function
512    /// returning `mts_labels_t` in metatensor-core
513    #[inline]
514    pub unsafe fn from_raw(raw: mts_labels_t) -> Labels {
515        assert!(!raw.internal_ptr_.is_null(), "expected mts_labels_t.internal_ptr_ to not be NULL");
516        Labels {
517            raw: raw,
518        }
519    }
520}
521
522impl std::cmp::PartialEq<Labels> for Labels {
523    #[inline]
524    fn eq(&self, other: &Labels) -> bool {
525        self.names() == other.names() && self.values() == other.values()
526    }
527}
528
529impl std::ops::Index<usize> for Labels {
530    type Output = [LabelValue];
531
532    #[inline]
533    fn index(&self, i: usize) -> &[LabelValue] {
534        let start = i * self.size();
535        let stop = (i + 1) * self.size();
536        &self.values()[start..stop]
537    }
538}
539
540/// iterator over [`Labels`] entries
541pub struct LabelsIter<'a> {
542    /// start of the labels values
543    ptr: *const LabelValue,
544    /// Current entry index
545    cur: usize,
546    /// number of entries
547    len: usize,
548    /// size of an entry/the labels
549    chunk_len: usize,
550    phantom: std::marker::PhantomData<&'a LabelValue>,
551}
552
553impl<'a> Iterator for LabelsIter<'a> {
554    type Item = &'a [LabelValue];
555
556    #[inline]
557    fn next(&mut self) -> Option<Self::Item> {
558        if self.cur < self.len {
559            unsafe {
560                // SAFETY: this should be in-bounds
561                let data = self.ptr.add(self.cur * self.chunk_len);
562                self.cur += 1;
563                // SAFETY: the pointer should be valid for 'a
564                Some(std::slice::from_raw_parts(data, self.chunk_len))
565            }
566        } else {
567            None
568        }
569    }
570
571    #[inline]
572    fn size_hint(&self) -> (usize, Option<usize>) {
573        let remaining = self.len - self.cur;
574        return (remaining, Some(remaining));
575    }
576}
577
578impl ExactSizeIterator for LabelsIter<'_> {
579    #[inline]
580    fn len(&self) -> usize {
581        self.len
582    }
583}
584
585impl FusedIterator for LabelsIter<'_> {}
586
587impl<'a> IntoIterator for &'a Labels {
588    type IntoIter = LabelsIter<'a>;
589    type Item = &'a [LabelValue];
590
591    #[inline]
592    fn into_iter(self) -> Self::IntoIter {
593        self.iter()
594    }
595}
596
/// Parallel iterator over the entries of a set of [`Labels`], created by
/// `Labels::par_iter`
#[cfg(feature = "rayon")]
#[derive(Clone, Debug)]
pub struct LabelsParIter<'a> {
    /// parallel iterator over the values, one chunk for each entry
    chunks: rayon::slice::ChunksExact<'a, LabelValue>,
}
603
#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for LabelsParIter<'a> {
    type Item = &'a [LabelValue];

    /// Forward the `consumer` to the underlying chunks iterator
    #[inline]
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        let LabelsParIter { chunks } = self;
        chunks.drive_unindexed(consumer)
    }
}
615
#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for LabelsParIter<'_> {
    /// Number of entries this iterator will produce
    #[inline]
    fn len(&self) -> usize {
        self.chunks.len()
    }

    /// Forward the `consumer` to the underlying chunks iterator
    #[inline]
    fn drive<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::Consumer<Self::Item>,
    {
        let LabelsParIter { chunks } = self;
        chunks.drive(consumer)
    }

    /// Forward the `callback` to the underlying chunks iterator
    #[inline]
    fn with_producer<CB>(self, callback: CB) -> CB::Output
    where
        CB: rayon::iter::plumbing::ProducerCallback<Self::Item>,
    {
        let LabelsParIter { chunks } = self;
        chunks.with_producer(callback)
    }
}
633
634/// Iterator over entries in a set of [`Labels`] as fixed size arrays
635#[derive(Debug, Clone)]
636pub struct LabelsFixedSizeIter<'a, const N: usize> {
637    values: &'a [LabelValue],
638}
639
640impl<'a, const N: usize> Iterator for LabelsFixedSizeIter<'a, N> {
641    type Item = &'a [LabelValue; N];
642
643    #[inline]
644    fn next(&mut self) -> Option<Self::Item> {
645        if self.values.is_empty() {
646            return None
647        }
648
649        let (value, rest) = self.values.split_at(N);
650        self.values = rest;
651        return Some(value.try_into().expect("wrong size in FixedSizeIter::next"));
652    }
653
654    fn size_hint(&self) -> (usize, Option<usize>) {
655        (self.len(), Some(self.len()))
656    }
657}
658
659impl<const N: usize> ExactSizeIterator for LabelsFixedSizeIter<'_, N> {
660    #[inline]
661    fn len(&self) -> usize {
662        self.values.len() / N
663    }
664}
665
666/// Builder for [`Labels`]
667#[derive(Debug, Clone)]
668pub struct LabelsBuilder {
669    // cf `Labels` for the documentation of the fields
670    names: Vec<String>,
671    values: Vec<LabelValue>,
672}
673
674impl LabelsBuilder {
675    /// Create a new empty `LabelsBuilder` with the given `names`
676    #[inline]
677    pub fn new(names: Vec<&str>) -> LabelsBuilder {
678        let n_unique_names = names.iter().collect::<BTreeSet<_>>().len();
679        assert!(n_unique_names == names.len(), "invalid labels: the same name is used multiple times");
680
681        LabelsBuilder {
682            names: names.into_iter().map(|s| s.into()).collect(),
683            values: Vec::new(),
684        }
685    }
686
687    /// Reserve space for `additional` other entries in the labels.
688    #[inline]
689    pub fn reserve(&mut self, additional: usize) {
690        self.values.reserve(additional * self.names.len());
691    }
692
693    /// Get the number of labels in a single value
694    #[inline]
695    pub fn size(&self) -> usize {
696        self.names.len()
697    }
698
699    /// Add a single `entry` to this set of labels.
700    ///
701    /// This function will panic when attempting to add the same `label` more
702    /// than once.
703    #[inline]
704    pub fn add<T>(&mut self, entry: &[T]) where T: Copy + Into<LabelValue> {
705        assert_eq!(
706            self.size(), entry.len(),
707            "wrong size for added label: got {}, but expected {}",
708            entry.len(), self.size()
709        );
710
711        // SmallVec allows us to convert everything to `LabelValue` without
712        // requiring an extra heap allocation
713        let entry = entry.iter().copied().map(Into::into).collect::<SmallVec<[LabelValue; 16]>>();
714        self.values.extend(&entry);
715    }
716
717    /// Finish building the `Labels`
718    #[inline]
719    pub fn finish(self) -> Labels {
720        let mut raw_names = Vec::new();
721        let mut raw_names_ptr = Vec::new();
722
723        let mut raw_labels = if self.names.is_empty() {
724            assert!(self.values.is_empty());
725            mts_labels_t::null()
726        } else {
727            for name in &self.names {
728                let name = CString::new(&**name).expect("name contains a NULL byte");
729                raw_names_ptr.push(name.as_ptr());
730                raw_names.push(name);
731            }
732
733            mts_labels_t {
734                internal_ptr_: std::ptr::null_mut(),
735                names: raw_names_ptr.as_ptr(),
736                values: self.values.as_ptr().cast(),
737                size: self.size(),
738                count: self.values.len() / self.size(),
739            }
740        };
741
742        unsafe {
743            check_status(
744                crate::c_api::mts_labels_create(&mut raw_labels)
745            ).expect("invalid labels?");
746        }
747
748        return unsafe { Labels::from_raw(raw_labels) };
749    }
750}
751
752
#[cfg(test)]
mod tests {
    use super::*;

    /// Labels with two dimensions and three entries, shared by multiple tests
    fn foo_bar_labels() -> Labels {
        Labels::new(
            ["foo", "bar"],
            &[
                [2, 3],
                [1, 243],
                [-4, -2413],
            ]
        )
    }

    #[test]
    fn labels() {
        // build the labels one entry at the time
        let mut builder = LabelsBuilder::new(vec!["foo", "bar"]);
        builder.add(&[2, 3]);
        builder.add(&[1, 243]);
        builder.add(&[-4, -2413]);
        let labels = builder.finish();

        assert_eq!(labels.names(), &["foo", "bar"]);
        assert_eq!(labels.size(), 2);
        assert_eq!(labels.count(), 3);

        assert_eq!(labels[0], [2, 3]);
        assert_eq!(labels[1], [1, 243]);
        assert_eq!(labels[2], [-4, -2413]);

        // labels without any name or entry
        let no_names = LabelsBuilder::new(vec![]).finish();
        assert_eq!(no_names.size(), 0);
        assert_eq!(no_names.count(), 0);
    }

    #[test]
    fn direct_construct() {
        let labels = foo_bar_labels();

        assert_eq!(labels.names(), &["foo", "bar"]);
        assert_eq!(labels.size(), 2);
        assert_eq!(labels.count(), 3);

        assert_eq!(labels[0], [2, 3]);
        assert_eq!(labels[1], [1, 243]);
        assert_eq!(labels[2], [-4, -2413]);
    }

    #[test]
    fn labels_iter() {
        let mut builder = LabelsBuilder::new(vec!["foo", "bar"]);
        builder.add(&[2, 3]);
        builder.add(&[1, 2]);
        builder.add(&[4, 3]);
        let labels = builder.finish();

        let mut entries = labels.iter();
        assert_eq!(entries.len(), 3);

        assert_eq!(entries.next().unwrap(), &[2, 3]);
        assert_eq!(entries.next().unwrap(), &[1, 2]);
        assert_eq!(entries.next().unwrap(), &[4, 3]);
        assert_eq!(entries.next(), None);
    }

    #[test]
    #[should_panic(expected = "'33 bar' is not a valid label name")]
    fn invalid_label_name() {
        LabelsBuilder::new(vec!["foo", "33 bar"]).finish();
    }

    #[test]
    #[should_panic(expected = "invalid labels: the same name is used multiple times")]
    fn duplicated_label_name() {
        LabelsBuilder::new(vec!["foo", "bar", "foo"]).finish();
    }

    #[test]
    #[should_panic(expected = "can not have the same label entry multiple time: [0, 1] is already present")]
    fn duplicated_label_entry() {
        let mut builder = LabelsBuilder::new(vec!["foo", "bar"]);
        for _ in 0..2 {
            builder.add(&[0, 1]);
        }
        builder.finish();
    }

    #[test]
    fn single_label() {
        let single = Labels::single();
        assert_eq!(single.names(), &["_"]);
        assert_eq!(single.size(), 1);
        assert_eq!(single.count(), 1);
    }

    #[test]
    fn indexing() {
        let labels = foo_bar_labels();

        assert_eq!(labels[1], [1, 243]);
        assert_eq!(labels[2], [-4, -2413]);
    }

    #[test]
    fn iter() {
        let labels = foo_bar_labels();
        let mut entries = labels.iter();

        assert_eq!(entries.next().unwrap(), &[2, 3]);
        assert_eq!(entries.next().unwrap(), &[1, 243]);
        assert_eq!(entries.next().unwrap(), &[-4, -2413]);
        assert_eq!(entries.next(), None);
    }

    #[test]
    fn debug() {
        let labels = foo_bar_labels();

        // the values are centered below the corresponding name
        let expected = format!(
            "Labels @ {:p} {{\n    foo, bar\n     2    3   \n     1   243  \n    -4   -2413  \n}}\n",
            labels.as_mts_labels_t().internal_ptr_
        );
        assert_eq!(format!("{:?}", labels), expected);
    }

    #[test]
    fn union() {
        let first = Labels::new(["aa", "bb"], &[[0, 1], [1, 2]]);
        let second = Labels::new(["aa", "bb"], &[[2, 3], [1, 2], [4, 5]]);

        let mut first_mapping = vec![0; first.count()];
        let mut second_mapping = vec![0; second.count()];
        let result = first.union(
            &second,
            Some(&mut first_mapping),
            Some(&mut second_mapping),
        ).unwrap();

        assert_eq!(result.names(), ["aa", "bb"]);
        assert_eq!(result.values(), [0, 1, 1, 2, 2, 3, 4, 5]);

        assert_eq!(first_mapping, [0, 1]);
        assert_eq!(second_mapping, [2, 1, 3]);
    }

    #[test]
    fn intersection() {
        let first = Labels::new(["aa", "bb"], &[[0, 1], [1, 2]]);
        let second = Labels::new(["aa", "bb"], &[[2, 3], [1, 2], [4, 5]]);

        let mut first_mapping = vec![0_i64; first.count()];
        let mut second_mapping = vec![0_i64; second.count()];
        let result = first.intersection(
            &second,
            Some(&mut first_mapping),
            Some(&mut second_mapping),
        ).unwrap();

        assert_eq!(result.names(), ["aa", "bb"]);
        assert_eq!(result.values(), [1, 2]);

        // entries missing from the intersection are mapped to -1
        assert_eq!(first_mapping, [-1, 0]);
        assert_eq!(second_mapping, [-1, 0, -1]);
    }

    #[test]
    fn difference() {
        let first = Labels::new(["aa", "bb"], &[[0, 1], [1, 2]]);
        let second = Labels::new(["aa", "bb"], &[[2, 3], [1, 2], [4, 5]]);

        let mut mapping = vec![0_i64; first.count()];
        let result = first.difference(&second, Some(&mut mapping)).unwrap();

        assert_eq!(result.names(), ["aa", "bb"]);
        assert_eq!(result.values(), [0, 1]);

        // entries missing from the difference are mapped to -1
        assert_eq!(mapping, [0, -1]);
    }

    #[test]
    fn selection() {
        let labels = Labels::new(["aa", "bb"], &[[1, 1], [1, 2], [3, 2], [2, 1]]);

        // selection with a subset of names
        let subset = Labels::new(["aa"], &[[1], [2], [5]]);
        assert_eq!(labels.select(&subset).unwrap(), [0, 1, 3]);

        // selection with the same names
        let same_names = Labels::new(["aa", "bb"], &[[1, 1], [2, 1], [5, 1], [1, 2]]);
        assert_eq!(labels.select(&same_names).unwrap(), [0, 3, 1]);

        // empty selection
        let nothing = Labels::empty(vec!["aa"]);
        assert_eq!(labels.select(&nothing).unwrap(), []);

        // invalid selection names
        let invalid = Labels::empty(vec!["aaa"]);
        let error = labels.select(&invalid).unwrap_err();
        assert_eq!(error.message,
            "invalid parameter: 'aaa' in selection is not part of these Labels"
        );
    }
}