Skip to main content

metatensor/
labels.rs

1use std::sync::OnceLock;
2use std::ffi::{CStr, CString};
3use std::collections::BTreeSet;
4use std::iter::FusedIterator;
5
6use crate::MtsArray;
7use crate::c_api::mts_labels_t;
8use crate::errors::check_ptr;
9use crate::errors::{Error, check_status};
10
/// A single value inside a label.
///
/// This is represented as a 32-bit signed integer, with a couple of helper
/// functions to get its value as `usize`/`isize`/`i32`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct LabelValue(i32);

impl PartialEq<i32> for LabelValue {
    #[inline]
    fn eq(&self, other: &i32) -> bool {
        self.0 == *other
    }
}

impl PartialEq<LabelValue> for i32 {
    #[inline]
    fn eq(&self, other: &LabelValue) -> bool {
        *self == other.0
    }
}

impl std::fmt::Debug for LabelValue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // forward to the integer implementation so that formatting flags
        // (width, alignment, sign, ...) are honored
        std::fmt::Display::fmt(&self.0, f)
    }
}

impl std::fmt::Display for LabelValue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // forward to the integer implementation so that formatting flags
        // (width, alignment, sign, ...) are honored
        std::fmt::Display::fmt(&self.0, f)
    }
}

/// Convert `value` to a `LabelValue`, panicking with an explicit message if
/// it does not fit in an `i32`. The full `i32` range (including `i32::MIN`
/// and `i32::MAX`) is accepted.
#[inline]
fn checked_label_value<T>(value: T) -> LabelValue
where
    T: TryInto<i32> + Copy + std::fmt::Display,
{
    match TryInto::<i32>::try_into(value) {
        Ok(converted) => LabelValue(converted),
        Err(_) => panic!("{} does not fit in a LabelValue (i32)", value),
    }
}

impl From<u32> for LabelValue {
    /// # Panics
    ///
    /// If `value` is larger than `i32::MAX`.
    #[inline]
    fn from(value: u32) -> LabelValue {
        checked_label_value(value)
    }
}

impl From<i32> for LabelValue {
    #[inline]
    fn from(value: i32) -> LabelValue {
        LabelValue(value)
    }
}

impl From<usize> for LabelValue {
    /// # Panics
    ///
    /// If `value` is larger than `i32::MAX`.
    #[inline]
    fn from(value: usize) -> LabelValue {
        checked_label_value(value)
    }
}

impl From<isize> for LabelValue {
    /// # Panics
    ///
    /// If `value` is outside of the `i32::MIN..=i32::MAX` range.
    #[inline]
    fn from(value: isize) -> LabelValue {
        checked_label_value(value)
    }
}

impl LabelValue {
    /// Create a `LabelValue` with the given `value`
    #[inline]
    pub fn new(value: i32) -> LabelValue {
        LabelValue(value)
    }

    /// Get the integer value of this `LabelValue` as a usize.
    ///
    /// The value is expected to be non-negative; this is only checked in
    /// debug builds.
    #[inline]
    #[allow(clippy::cast_sign_loss)]
    pub fn usize(self) -> usize {
        debug_assert!(self.0 >= 0);
        self.0 as usize
    }

    /// Get the integer value of this `LabelValue` as an isize
    #[inline]
    pub fn isize(self) -> isize {
        self.0 as isize
    }

    /// Get the integer value of this `LabelValue` as an i32
    #[inline]
    pub fn i32(self) -> i32 {
        self.0
    }
}
106
/// A set of labels used to carry metadata associated with a tensor map.
///
/// This is similar to a list of named tuples, but stored as a 2D array of shape
/// `(labels.count(), labels.size())`, with a set of names associated with the
/// columns of this array. Each row/entry in this array is unique, and they are
/// often (but not always) sorted in lexicographic order.
///
/// The main way to construct a new set of labels is to use [`Labels::new`] or
/// [`Labels::new_assume_unique`].
///
/// Labels are internally reference counted and immutable, so cloning a `Labels`
/// should be a cheap operation.
pub struct Labels {
    /// Pointer to the underlying metatensor-core labels, owned by this value
    /// and released with `mts_labels_free` when it is dropped.
    pub(crate) ptr: *const mts_labels_t,
    /// Pointer to the values on CPU, obtained lazily through
    /// `mts_labels_values_cpu` the first time they are requested.
    values_cpu_ptr: OnceLock<*const LabelValue>,
    /// Number of entries (rows), computed lazily and then cached.
    count: OnceLock<usize>,
    /// Number of dimensions (columns), computed lazily from the names and then
    /// cached.
    size: OnceLock<usize>,
}
125
// Labels can be sent to other thread safely since mts_labels_t uses an
// `Arc<metatensor_core::Labels>`, so freeing them from another thread is fine
unsafe impl Send for Labels {}
// &Labels can be shared between threads safely since the only interior
// mutability (the cached values pointer, count and size) goes through
// `OnceLock`, which synchronizes its one-time initialization. The compiler
// can not derive this automatically because raw pointers (`ptr` and the
// cached values pointer) are not `Sync`.
// NOTE(review): this assumes the CPU values behind the cached pointer are
// never mutated once metatensor-core has handed them out.
unsafe impl Sync for Labels {}
132
133impl std::fmt::Debug for Labels {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        pretty_print_labels(self, "", f)
136    }
137}
138
139/// Helper function to print labels in a Debug mode
140pub(crate) fn pretty_print_labels(
141    labels: &Labels,
142    offset: &str,
143    f: &mut std::fmt::Formatter<'_>
144) -> std::fmt::Result {
145    let names = labels.names();
146
147    writeln!(f, "Labels @ {:p} {{", labels.ptr)?;
148    writeln!(f, "{}    {}", offset, names.join(", "))?;
149
150    let widths = names.iter().map(|s| s.len()).collect::<Vec<_>>();
151    for values in labels {
152        write!(f, "{}    ", offset)?;
153        for (value, width) in values.iter().zip(&widths) {
154            write!(f, "{:^width$}  ", value.isize(), width=width)?;
155        }
156        writeln!(f)?;
157    }
158
159    writeln!(f, "{}}}", offset)
160}
161
162impl Clone for Labels {
163    #[inline]
164    fn clone(&self) -> Self {
165        let ptr = unsafe { crate::c_api::mts_labels_clone(self.ptr) };
166        assert!(!ptr.is_null(), "failed to clone Labels");
167        Labels {
168            ptr,
169            values_cpu_ptr: self.values_cpu_ptr.clone(),
170            count: self.count.clone(),
171            size: self.size.clone(),
172        }
173    }
174}
175
176impl std::ops::Drop for Labels {
177    #[allow(unused_must_use)]
178    fn drop(&mut self) {
179        unsafe {
180            crate::c_api::mts_labels_free(self.ptr);
181        }
182    }
183}
184
185impl Labels {
186    /// Create a new set of Labels with the given names and values.
187    ///
188    /// The `values` can be any type that can be converted into an
189    /// [`MtsArray`], including `Vec<[i32; N]>`, `&[[i32; N]]`, or
190    /// `ndarray::Array`.
191    ///
192    /// # Panics
193    ///
194    /// If the set of names is not valid, or any of the value is duplicated.
195    #[inline]
196    pub fn new<'a>(names: impl AsRef<[&'a str]>, values: impl Into<MtsArray>) -> Labels {
197        Self::new_impl(names.as_ref(), values, crate::c_api::mts_labels)
198    }
199
200    /// Create a new set of Labels with the given names and values, without
201    /// checking that the entries are unique.
202    ///
203    /// This is faster than [`Labels::new`] as it does not perform a uniqueness
204    /// check on the labels entries. It is the caller's responsibility to ensure
205    /// that entries are unique.
206    ///
207    /// # Panics
208    ///
209    /// If the set of names is not valid.
210    #[inline]
211    pub fn new_assume_unique<'a>(names: impl AsRef<[&'a str]>, values: impl Into<MtsArray>) -> Labels {
212        Self::new_impl(names.as_ref(), values, crate::c_api::mts_labels_assume_unique)
213    }
214
215    /// Common implementation for `new` and `new_assume_unique`
216    fn new_impl(
217        names: &[&str],
218        values: impl Into<MtsArray>,
219        creator: unsafe extern "C" fn(
220            *const *const std::os::raw::c_char,
221            usize,
222            crate::c_api::mts_array_t,
223        ) -> *const mts_labels_t
224    ) -> Labels {
225        let n_unique_names = names.iter().collect::<BTreeSet<_>>().len();
226        assert!(n_unique_names == names.len(), "invalid labels: the same name is used multiple times");
227
228        let mut raw_names = Vec::new();
229        let mut raw_names_ptr = Vec::new();
230        for name in names {
231            let c_name = CString::new(*name).expect("name contains a NULL byte");
232            raw_names_ptr.push(c_name.as_ptr());
233            raw_names.push(c_name);
234        }
235
236        let array: MtsArray = values.into();
237        let ptr = unsafe {
238            creator(
239                raw_names_ptr.as_ptr(),
240                raw_names.len(),
241                array.into_raw(),
242            )
243        };
244        check_ptr(ptr).expect("invalid labels");
245
246        unsafe { Labels::from_raw(ptr) }
247    }
248
249    /// Get a pointer to the underlying `mts_labels_t`
250    pub fn as_mts_labels_t(&self) -> *const mts_labels_t {
251        self.ptr
252    }
253
254    /// Create a new set of `Labels` from a raw `*mut mts_labels_t` pointer.
255    ///
256    /// This function takes ownership of the pointer and will call
257    /// `mts_labels_free` on it when dropped.
258    ///
259    /// # Safety
260    ///
261    /// The pointer must be non-null and returned by one of the metatensor-core
262    /// functions that create `mts_labels_t`.
263    #[inline]
264    pub unsafe fn from_raw(ptr: *const mts_labels_t) -> Labels {
265        assert!(!ptr.is_null(), "expected mts_labels_t pointer to not be NULL");
266        Labels {
267            ptr,
268            values_cpu_ptr: OnceLock::new(),
269            size: OnceLock::new(),
270            count: OnceLock::new(),
271        }
272    }
273
274    /// Consume the Labels and get the underlying raw pointer.
275    ///
276    /// After calling this function, the user is responsible for free-ing the
277    /// data in `mts_labels_t`, either by re-creating labels with
278    /// [`Labels::from_raw`] or passing it to a C API function that will call
279    /// [`crate::c_api::mts_labels_free`] on it.
280    #[inline]
281    pub fn into_raw(mut labels: Labels) -> *const mts_labels_t {
282        return std::mem::replace(&mut labels.ptr, std::ptr::null());
283    }
284
285    /// Create a set of `Labels` with the given names, containing no entries.
286    #[inline]
287    pub fn empty<'a>(names: impl AsRef<[&'a str]>) -> Labels {
288        let names = names.as_ref();
289        let array = ndarray::Array::<i32, _>::from_shape_vec(
290            vec![0, names.len()], vec![]
291        ).expect("shape mismatch when creating empty labels array");
292        Labels::new(names, array)
293    }
294
295    /// Create a set of `Labels` containing a single entry, to be used when
296    /// there is no relevant information to store.
297    #[inline]
298    pub fn single() -> Labels {
299        Labels::new(["_"], vec![[0i32]])
300    }
301
302    /// Load `Labels` from the file at `path`
303    ///
304    /// This is a convenience function calling [`crate::io::load_labels`]
305    pub fn load(path: impl AsRef<std::path::Path>) -> Result<Labels, Error> {
306        return crate::io::load_labels(path);
307    }
308
309    /// Load a `TensorMap` from an in-memory buffer
310    ///
311    /// This is a convenience function calling [`crate::io::load_buffer`]
312    pub fn load_buffer(buffer: &[u8]) -> Result<Labels, Error> {
313        return crate::io::load_labels_buffer(buffer);
314    }
315
316    /// Save the given tensor to the file at `path`
317    ///
318    /// This is a convenience function calling [`crate::io::save`]
319    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
320        return crate::io::save_labels(path, self);
321    }
322
323    /// Save the given tensor to an in-memory buffer
324    ///
325    /// This is a convenience function calling [`crate::io::save_buffer`]
326    pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
327        return crate::io::save_labels_buffer(self, buffer);
328    }
329
330    /// Get the number of entries/named values in a single label
331    #[inline]
332    pub fn size(&self) -> usize {
333        *self.size.get_or_init(|| self.names().len())
334    }
335
336    /// Get the names of the entries/columns in this set of labels
337    #[inline]
338    pub fn names(&self) -> Vec<&str> {
339        let mut names_ptr = std::ptr::null();
340        let mut size = 0;
341        unsafe {
342            check_status(crate::c_api::mts_labels_dimensions(self.ptr, &mut names_ptr, &mut size))
343                .expect("failed to get labels dimensions");
344        }
345
346        if size == 0 {
347            return Vec::new();
348        }
349
350        unsafe {
351            let names = std::slice::from_raw_parts(names_ptr, size);
352            return names.iter()
353                        .map(|&ptr| CStr::from_ptr(ptr).to_str().expect("invalid UTF8"))
354                        .collect();
355        }
356    }
357
358    /// Get the values of these labels as a `MtsArray`
359    pub fn values(&self) -> MtsArray {
360        let mut array = crate::c_api::mts_array_t::null();
361        unsafe {
362            check_status(crate::c_api::mts_labels_values(
363                self.ptr, &mut array,
364            )).expect("failed to get labels values array");
365        }
366
367        return MtsArray::from_raw(array);
368    }
369
370    /// Get the values of these Labels on CPU, potentially copying them from
371    /// another device.
372    pub fn values_cpu(&self) -> &[LabelValue] {
373        let values_cpu_ptr = self.values_cpu_ptr.get_or_init(|| {
374            let mut values_cpu_ptr = std::ptr::null();
375            let mut count = 0;
376            let mut size = 0;
377
378            unsafe {
379                check_status(crate::c_api::mts_labels_values_cpu(
380                    self.ptr,
381                    &mut values_cpu_ptr,
382                    &mut count,
383                    &mut size
384                )).expect("failed to get CPU values for Labels");
385            }
386
387            debug_assert_eq!(count, self.count());
388            debug_assert_eq!(size, self.size());
389
390            values_cpu_ptr.cast()
391        });
392
393        unsafe {
394            let count = self.count();
395            let size = self.size();
396            std::slice::from_raw_parts(*values_cpu_ptr, count * size)
397        }
398    }
399
400    /// Get the total number of entries in this set of labels
401    #[inline]
402    pub fn device(&self) -> dlpk::DLDevice {
403        let array = self.values();
404        return array.device().expect("failed to get the array device");
405    }
406
407    /// Get the total number of entries in this set of labels
408    #[inline]
409    pub fn count(&self) -> usize {
410        return *self.count.get_or_init(|| self.values().shape().expect("failed to get the array shape")[0]);
411    }
412
413    /// Check if this set of Labels is empty (contains no entry)
414    #[inline]
415    pub fn is_empty(&self) -> bool {
416        self.count() == 0
417    }
418
419    /// Check whether the given `label` is part of this set of labels
420    #[inline]
421    pub fn contains(&self, label: &[LabelValue]) -> bool {
422        return self.position(label).is_some();
423    }
424
425    /// Get the position (i.e. row index) of the given label in the full labels
426    /// array, or None.
427    #[inline]
428    pub fn position(&self, value: &[LabelValue]) -> Option<usize> {
429        assert!(value.len() == self.size(), "invalid size of index in Labels::position");
430
431        let mut result = 0;
432        unsafe {
433            check_status(crate::c_api::mts_labels_position(
434                self.ptr,
435                value.as_ptr().cast(),
436                value.len(),
437                &mut result,
438            )).expect("failed to check label position");
439        }
440
441        return result.try_into().ok();
442    }
443
444    /// Take the union of `self` with `other`.
445    ///
446    /// If requested, this function can also give the positions in the union
447    /// where each entry of the input `Labels` ended up.
448    ///
449    /// If `first_mapping` (respectively `second_mapping`) is `Some`, it should
450    /// contain a slice of length `self.count()` (respectively `other.count()`)
451    /// that will be filled with the position of the entries in `self`
452    /// (respectively `other`) in the union.
453    #[inline]
454    pub fn union(
455        &self,
456        other: &Labels,
457        first_mapping: Option<&mut [i64]>,
458        second_mapping: Option<&mut [i64]>,
459    ) -> Result<Labels, Error> {
460        let mut output: *const mts_labels_t = std::ptr::null();
461        let (first_mapping, first_mapping_count) = if let Some(m) = first_mapping {
462            (m.as_mut_ptr(), m.len())
463        } else {
464            (std::ptr::null_mut(), 0)
465        };
466
467        let (second_mapping, second_mapping_count) = if let Some(m) = second_mapping {
468            (m.as_mut_ptr(), m.len())
469        } else {
470            (std::ptr::null_mut(), 0)
471        };
472
473        unsafe {
474            check_status(crate::c_api::mts_labels_union(
475                self.ptr,
476                other.ptr,
477                &mut output,
478                first_mapping,
479                first_mapping_count,
480                second_mapping,
481                second_mapping_count,
482            ))?;
483
484            return Ok(Labels::from_raw(output));
485        }
486    }
487
488    /// Take the intersection of self with `other`.
489    ///
490    /// If requested, this function can also give the positions in the
491    /// intersection where each entry of the input `Labels` ended up.
492    ///
493    /// If `first_mapping` (respectively `second_mapping`) is `Some`, it should
494    /// contain a slice of length `self.count()` (respectively `other.count()`)
495    /// that will be filled by with the position of the entries in `self`
496    /// (respectively `other`) in the intersection. If an entry in `self` or
497    /// `other` are not used in the intersection, the mapping for this entry
498    /// will be set to `-1`.
499    #[inline]
500    pub fn intersection(
501        &self,
502        other: &Labels,
503        first_mapping: Option<&mut [i64]>,
504        second_mapping: Option<&mut [i64]>,
505    ) -> Result<Labels, Error> {
506        let mut output: *const mts_labels_t = std::ptr::null();
507        let (first_mapping, first_mapping_count) = if let Some(m) = first_mapping {
508            (m.as_mut_ptr(), m.len())
509        } else {
510            (std::ptr::null_mut(), 0)
511        };
512
513        let (second_mapping, second_mapping_count) = if let Some(m) = second_mapping {
514            (m.as_mut_ptr(), m.len())
515        } else {
516            (std::ptr::null_mut(), 0)
517        };
518
519        unsafe {
520            check_status(crate::c_api::mts_labels_intersection(
521                self.ptr,
522                other.ptr,
523                &mut output,
524                first_mapping,
525                first_mapping_count,
526                second_mapping,
527                second_mapping_count,
528            ))?;
529
530            return Ok(Labels::from_raw(output));
531        }
532    }
533
534    /// Take the set difference of `self` with `other`.
535    ///
536    /// If requested, this function can also give the positions in the
537    /// difference where each entry of `self` ended up.
538    ///
539    /// If `mapping` is `Some`, it should contain a slice of length
540    /// `self.count()` that will be filled by with the position of the entries
541    /// in `self` in the difference. If an entry is not used in the difference,
542    ///  the mapping for this entry will be set to `-1`.
543    #[inline]
544    pub fn difference(
545        &self,
546        other: &Labels,
547        mapping: Option<&mut [i64]>,
548    ) -> Result<Labels, Error> {
549        let mut output: *const mts_labels_t = std::ptr::null();
550        let (mapping, mapping_count) = if let Some(m) = mapping {
551            (m.as_mut_ptr(), m.len())
552        } else {
553            (std::ptr::null_mut(), 0)
554        };
555
556        unsafe {
557            check_status(crate::c_api::mts_labels_difference(
558                self.ptr,
559                other.ptr,
560                &mut output,
561                mapping,
562                mapping_count,
563            ))?;
564
565            return Ok(Labels::from_raw(output));
566        }
567    }
568
569    /// Select entries in these `Labels` that match the `selection`.
570    ///
571    /// The selection's names must be a subset of the names of these labels.
572    ///
573    /// All entries in these `Labels` that match one of the entry in the
574    /// `selection` for all the selection's dimension will be picked. Any entry
575    /// in the `selection` but not in these `Labels` will be ignored.
576    #[allow(clippy::cast_possible_truncation)]
577    pub fn select(&self, selection: &Labels) -> Result<Vec<usize>, Error> {
578        let mut selected = vec![0; self.count()];
579        let mut selected_count = selected.len();
580
581        unsafe {
582            check_status(crate::c_api::mts_labels_select(
583                self.as_mts_labels_t(),
584                selection.as_mts_labels_t(),
585                selected.as_mut_ptr(),
586                &mut selected_count
587            ))?;
588        }
589
590        selected.resize(selected_count, 0);
591
592        return Ok(selected.into_iter().map(|s| s as usize).collect());
593    }
594
595    /// Iterate over the entries in this set of labels
596    #[inline]
597    pub fn iter(&self) -> LabelsIter<'_> {
598        return LabelsIter {
599            ptr: self.values_cpu().as_ptr(),
600            cur: 0,
601            len: self.count(),
602            chunk_len: self.size(),
603            phantom: std::marker::PhantomData,
604        };
605    }
606
607    /// Iterate over the entries in this set of labels in parallel
608    #[cfg(feature = "rayon")]
609    #[inline]
610    pub fn par_iter(&self) -> LabelsParIter<'_> {
611        use rayon::prelude::*;
612        return LabelsParIter {
613            chunks: self.values_cpu().par_chunks_exact(self.size())
614        };
615    }
616
617    /// Iterate over the entries in this set of labels as fixed-size arrays
618    #[inline]
619    pub fn iter_fixed_size<const N: usize>(&self) -> LabelsFixedSizeIter<'_, N> {
620        assert!(N == self.size(),
621            "wrong label size in `iter_fixed_size`: the entries contains {} element \
622            but this function was called with size of {}",
623            self.size(), N
624        );
625
626        return LabelsFixedSizeIter {
627            values: self.values_cpu()
628        };
629    }
630}
631
632impl std::cmp::PartialEq<Labels> for Labels {
633    #[inline]
634    fn eq(&self, other: &Labels) -> bool {
635        if self.names() != other.names() {
636            return false;
637        }
638
639        if self.count() != other.count() {
640            return false;
641        }
642
643        if self.device() != other.device() {
644            return false;
645        }
646
647        if self.device().device_type == dlpk::DLDeviceType::kDLExtDev {
648            // kDLExtDev is used for torch's meta device, which has no data
649            // associated, so we consider all Labels as equal as long as they
650            // have the same dimensions and number of entries
651            return true;
652        } else {
653            return self.values_cpu() == other.values_cpu();
654        }
655    }
656}
657
658impl std::ops::Index<usize> for Labels {
659    type Output = [LabelValue];
660
661    #[inline]
662    fn index(&self, i: usize) -> &[LabelValue] {
663        let start = i * self.size();
664        let stop = (i + 1) * self.size();
665        &self.values_cpu()[start..stop]
666    }
667}
668
669/// iterator over [`Labels`] entries
670pub struct LabelsIter<'a> {
671    /// start of the labels values
672    ptr: *const LabelValue,
673    /// Current entry index
674    cur: usize,
675    /// number of entries
676    len: usize,
677    /// size of an entry/the labels
678    chunk_len: usize,
679    phantom: std::marker::PhantomData<&'a LabelValue>,
680}
681
682impl<'a> Iterator for LabelsIter<'a> {
683    type Item = &'a [LabelValue];
684
685    #[inline]
686    fn next(&mut self) -> Option<Self::Item> {
687        if self.cur < self.len {
688            unsafe {
689                // SAFETY: this should be in-bounds
690                let data = self.ptr.add(self.cur * self.chunk_len);
691                self.cur += 1;
692                // SAFETY: the pointer should be valid for 'a
693                Some(std::slice::from_raw_parts(data, self.chunk_len))
694            }
695        } else {
696            None
697        }
698    }
699
700    #[inline]
701    fn size_hint(&self) -> (usize, Option<usize>) {
702        let remaining = self.len - self.cur;
703        return (remaining, Some(remaining));
704    }
705}
706
707impl ExactSizeIterator for LabelsIter<'_> {
708    #[inline]
709    fn len(&self) -> usize {
710        self.len
711    }
712}
713
714impl FusedIterator for LabelsIter<'_> {}
715
716impl<'a> IntoIterator for &'a Labels {
717    type IntoIter = LabelsIter<'a>;
718    type Item = &'a [LabelValue];
719
720    #[inline]
721    fn into_iter(self) -> Self::IntoIter {
722        self.iter()
723    }
724}
725
/// Parallel iterator over entries in a set of [`Labels`]
#[cfg(feature = "rayon")]
#[derive(Debug, Clone)]
pub struct LabelsParIter<'a> {
    /// The CPU values of the labels, split in chunks of `labels.size()`
    /// values (one chunk per entry)
    chunks: rayon::slice::ChunksExact<'a, LabelValue>,
}
732
#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for LabelsParIter<'a> {
    type Item = &'a [LabelValue];

    /// Delegate to the wrapped `ChunksExact` parallel iterator.
    #[inline]
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        rayon::iter::ParallelIterator::drive_unindexed(self.chunks, consumer)
    }
}
744
#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for LabelsParIter<'_> {
    /// Number of entries, taken from the wrapped `ChunksExact`.
    #[inline]
    fn len(&self) -> usize {
        rayon::iter::IndexedParallelIterator::len(&self.chunks)
    }

    /// Delegate to the wrapped `ChunksExact` parallel iterator.
    #[inline]
    fn drive<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::Consumer<Self::Item>,
    {
        rayon::iter::IndexedParallelIterator::drive(self.chunks, consumer)
    }

    /// Delegate to the wrapped `ChunksExact` parallel iterator.
    #[inline]
    fn with_producer<CB>(self, callback: CB) -> CB::Output
    where
        CB: rayon::iter::plumbing::ProducerCallback<Self::Item>,
    {
        rayon::iter::IndexedParallelIterator::with_producer(self.chunks, callback)
    }
}
762
763/// Iterator over entries in a set of [`Labels`] as fixed size arrays
764#[derive(Debug, Clone)]
765pub struct LabelsFixedSizeIter<'a, const N: usize> {
766    values: &'a [LabelValue],
767}
768
769impl<'a, const N: usize> Iterator for LabelsFixedSizeIter<'a, N> {
770    type Item = &'a [LabelValue; N];
771
772    #[inline]
773    fn next(&mut self) -> Option<Self::Item> {
774        if self.values.is_empty() {
775            return None
776        }
777
778        let (value, rest) = self.values.split_at(N);
779        self.values = rest;
780        return Some(value.try_into().expect("wrong size in FixedSizeIter::next"));
781    }
782
783    fn size_hint(&self) -> (usize, Option<usize>) {
784        (self.len(), Some(self.len()))
785    }
786}
787
788impl<const N: usize> ExactSizeIterator for LabelsFixedSizeIter<'_, N> {
789    #[inline]
790    fn len(&self) -> usize {
791        self.values.len() / N
792    }
793}
794
795
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn labels() {
        let labels = Labels::new(["foo", "bar"], vec![[2, 3], [1, 243], [-4, -2413]]);
        assert_eq!(labels.names(), &["foo", "bar"]);
        assert_eq!(labels.size(), 2);
        assert_eq!(labels.count(), 3);
        assert!(!labels.is_empty());

        assert_eq!(labels[0], [2, 3]);
        assert_eq!(labels[1], [1, 243]);
        assert_eq!(labels[2], [-4, -2413]);

        // labels without any dimension
        let labels = Labels::new(&[] as &[&str], Vec::<[i32; 0]>::new());
        assert_eq!(labels.size(), 0);
        assert_eq!(labels.count(), 0);

        let labels = Labels::new_assume_unique(["foo", "bar"], vec![[2, 3], [1, 243]]);
        assert_eq!(labels.names(), &["foo", "bar"]);
        assert_eq!(labels.size(), 2);
        assert_eq!(labels.count(), 2);
    }

    #[test]
    fn direct_construct() {
        let labels = Labels::new(
            ["foo", "bar"],
            vec![[2, 3], [1, 243], [-4, -2413]],
        );

        assert_eq!(labels.names(), &["foo", "bar"]);
        assert_eq!(labels.size(), 2);
        assert_eq!(labels.count(), 3);

        assert_eq!(labels[0], [2, 3]);
        assert_eq!(labels[1], [1, 243]);
        assert_eq!(labels[2], [-4, -2413]);
    }

    #[test]
    fn iter() {
        let labels = Labels::new(["foo", "bar"], vec![[2, 3], [1, 2], [4, 3]]);
        let mut iter = labels.iter();
        assert_eq!(iter.len(), 3);

        assert_eq!(iter.next().unwrap(), &[2, 3]);
        assert_eq!(iter.next().unwrap(), &[1, 2]);
        assert_eq!(iter.next().unwrap(), &[4, 3]);
        assert_eq!(iter.next(), None);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn par_iter() {
        use rayon::iter::IndexedParallelIterator;

        let labels = Labels::new(["foo", "bar"], vec![[2, 3], [1, 2], [4, 3]]);
        let iter = labels.par_iter();
        assert_eq!(iter.len(), 3);

        let mut values = Vec::new();
        iter.collect_into_vec(&mut values);

        assert_eq!(values, [&[2, 3], &[1, 2], &[4, 3]]);
    }

    #[test]
    fn iter_fixed_size() {
        let labels = Labels::new(["foo", "bar"], vec![[1, 2], [2, 3]]);

        for (i, [a, b]) in labels.iter_fixed_size().enumerate() {
            assert_eq!(a.usize(), 1 + i);
            assert_eq!(b.usize(), 2 + i);
        }
    }

    #[test]
    #[should_panic(expected = "wrong label size in `iter_fixed_size`: the entries contains 2 element but this function was called with size of 3")]
    fn iter_fixed_size_wrong_size() {
        let labels = Labels::new(["foo", "bar"], Vec::<[i32; 2]>::new());

        for [_, _, _] in labels.iter_fixed_size() {}
    }

    #[test]
    #[should_panic(expected = "invalid parameter: '33 bar' is not a valid label name")]
    fn invalid_label_name() {
        Labels::new(["foo", "33 bar"], Vec::<[i32; 2]>::new());
    }

    #[test]
    #[should_panic(expected = "invalid labels: the same name is used multiple times")]
    fn duplicated_label_name() {
        Labels::new(["foo", "bar", "foo"], Vec::<[i32; 3]>::new());
    }

    #[test]
    #[should_panic(expected = "can not have the same label entry multiple times: [0, 1] is already present")]
    fn duplicated_label_entry() {
        Labels::new(["foo", "bar"], vec![[0, 1], [0, 1]]);
    }

    #[test]
    fn single_label() {
        let labels = Labels::single();
        assert_eq!(labels.names(), &["_"]);
        assert_eq!(labels.size(), 1);
        assert_eq!(labels.count(), 1);
    }

    #[test]
    fn empty_label() {
        let labels = Labels::empty(vec!["foo", "bar"]);

        assert!(labels.is_empty());
        assert_eq!(labels.count(), 0);
        assert_eq!(labels.size(), 2);
    }

    #[test]
    fn position() {
        let labels = Labels::new(["foo", "bar"], vec![[1, 2], [2, 3]]);

        assert!(labels.contains(&[LabelValue::new(1), LabelValue::new(2)]));
        assert_eq!(labels.position(&[LabelValue::new(1), LabelValue::new(2)]), Some(0));

        assert!(labels.contains(&[LabelValue::new(2), LabelValue::new(3)]));
        assert_eq!(labels.position(&[LabelValue::new(2), LabelValue::new(3)]), Some(1));

        assert!(!labels.contains(&[LabelValue::new(3), LabelValue::new(3)]));
        assert_eq!(labels.position(&[LabelValue::new(3), LabelValue::new(3)]), None);
    }

    #[test]
    fn indexing() {
        let labels = Labels::new(
            ["foo", "bar"],
            vec![[2, 3], [1, 243], [-4, -2413]],
        );

        assert_eq!(labels[1], [1, 243]);
        assert_eq!(labels[2], [-4, -2413]);
    }

    #[test]
    fn debug() {
        let labels = Labels::new(
            ["foo", "bar"],
            vec![[2, 3], [1, 243], [-4, -2413]],
        );

        let expected = format!(
            "Labels @ {:p} {{\n    foo, bar\n     2    3   \n     1   243  \n    -4   -2413  \n}}\n",
            labels.ptr
        );
        assert_eq!(format!("{:?}", labels), expected);
    }

    #[test]
    fn union() {
        let first = Labels::new(["aa", "bb"], [[0, 1], [1, 2]]);
        let second = Labels::new(["aa", "bb"], [[2, 3], [1, 2], [4, 5]]);

        let mut first_mapping = vec![0; first.count()];
        let mut second_mapping = vec![0; second.count()];
        let union = first.union(&second, Some(&mut first_mapping), Some(&mut second_mapping)).unwrap();

        assert_eq!(union.names(), ["aa", "bb"]);
        assert_eq!(union.values_cpu(), [0, 1, 1, 2, 2, 3, 4, 5]);

        assert_eq!(first_mapping, [0, 1]);
        assert_eq!(second_mapping, [2, 1, 3]);
    }

    #[test]
    fn intersection() {
        let first = Labels::new(["aa", "bb"], [[0, 1], [1, 2]]);
        let second = Labels::new(["aa", "bb"], [[2, 3], [1, 2], [4, 5]]);

        let mut first_mapping = vec![0_i64; first.count()];
        let mut second_mapping = vec![0_i64; second.count()];
        let intersection = first.intersection(&second, Some(&mut first_mapping), Some(&mut second_mapping)).unwrap();

        assert_eq!(intersection.names(), ["aa", "bb"]);
        assert_eq!(intersection.values_cpu(), [1, 2]);

        // entries which are not part of the intersection are mapped to -1
        assert_eq!(first_mapping, [-1, 0]);
        assert_eq!(second_mapping, [-1, 0, -1]);
    }

    #[test]
    fn difference() {
        let first = Labels::new(["aa", "bb"], [[0, 1], [1, 2]]);
        let second = Labels::new(["aa", "bb"], [[2, 3], [1, 2], [4, 5]]);

        let mut mapping = vec![0_i64; first.count()];
        let difference = first.difference(&second, Some(&mut mapping)).unwrap();

        assert_eq!(difference.names(), ["aa", "bb"]);
        assert_eq!(difference.values_cpu(), [0, 1]);

        // entries which are not part of the difference are mapped to -1
        assert_eq!(mapping, [0, -1]);
    }

    #[test]
    fn selection() {
        // selection with a subset of names
        let labels = Labels::new(["aa", "bb"], [[1, 1], [1, 2], [3, 2], [2, 1]]);
        let selection = Labels::new(["aa"], [[1], [2], [5]]);

        let selected = labels.select(&selection).unwrap();
        assert_eq!(selected, [0, 1, 3]);

        // selection with the same names
        let selection = Labels::new(["aa", "bb"], [[1, 1], [2, 1], [5, 1], [1, 2]]);
        let selected = labels.select(&selection).unwrap();
        assert_eq!(selected, [0, 3, 1]);

        // empty selection
        let selection = Labels::empty(vec!["aa"]);
        let selected = labels.select(&selection).unwrap();
        assert_eq!(selected, []);

        // invalid selection names
        let selection = Labels::empty(vec!["aaa"]);
        let err = labels.select(&selection).unwrap_err();
        assert_eq!(err.message,
            "invalid parameter: 'aaa' in selection is not part of these Labels"
        );
    }

    #[test]
    fn labels_into_raw() {
        let original = Labels::new(["foo", "bar"], [[1, 2], [3, 4], [5, 6]]);
        let raw = Labels::into_raw(original);

        let recovered = unsafe { Labels::from_raw(raw) };
        assert_eq!(recovered, Labels::new(["foo", "bar"], [[1, 2], [3, 4], [5, 6]]));
    }
}