Skip to main content

metatensor/
labels.rs

1use std::sync::OnceLock;
2use std::ffi::{CStr, CString};
3use std::collections::BTreeSet;
4use std::iter::FusedIterator;
5
6use crate::MtsArray;
7use crate::c_api::mts_labels_t;
8use crate::errors::check_ptr;
9use crate::errors::{Error, check_status};
10
/// A single value inside a label.
///
/// Internally this is stored as a 32-bit signed integer; helper methods are
/// available to read it back as `usize`, `isize` or `i32`.
#[repr(transparent)]
#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct LabelValue(i32);
18
19impl PartialEq<i32> for LabelValue {
20    #[inline]
21    fn eq(&self, other: &i32) -> bool {
22        self.0 == *other
23    }
24}
25
26impl PartialEq<LabelValue> for i32 {
27    #[inline]
28    fn eq(&self, other: &LabelValue) -> bool {
29        *self == other.0
30    }
31}
32
33impl std::fmt::Debug for LabelValue {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        write!(f, "{}", self.0)
36    }
37}
38
39impl std::fmt::Display for LabelValue {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        write!(f, "{}", self.0)
42    }
43}
44
45#[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
46impl From<u32> for LabelValue {
47    #[inline]
48    fn from(value: u32) -> LabelValue {
49        assert!(value < i32::MAX as u32);
50        LabelValue(value as i32)
51    }
52}
53
54impl From<i32> for LabelValue {
55    #[inline]
56    fn from(value: i32) -> LabelValue {
57        LabelValue(value)
58    }
59}
60
61#[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
62impl From<usize> for LabelValue {
63    #[inline]
64    fn from(value: usize) -> LabelValue {
65        assert!(value < i32::MAX as usize);
66        LabelValue(value as i32)
67    }
68}
69
70#[allow(clippy::cast_possible_truncation)]
71impl From<isize> for LabelValue {
72    #[inline]
73    fn from(value: isize) -> LabelValue {
74        assert!(value < i32::MAX as isize && value > i32::MIN as isize);
75        LabelValue(value as i32)
76    }
77}
78
79impl LabelValue {
80    /// Create a `LabelValue` with the given `value`
81    #[inline]
82    pub fn new(value: i32) -> LabelValue {
83        LabelValue(value)
84    }
85
86    /// Get the integer value of this `LabelValue` as a usize
87    #[inline]
88    #[allow(clippy::cast_sign_loss)]
89    pub fn usize(self) -> usize {
90        debug_assert!(self.0 >= 0);
91        self.0 as usize
92    }
93
94    /// Get the integer value of this `LabelValue` as an isize
95    #[inline]
96    pub fn isize(self) -> isize {
97        self.0 as isize
98    }
99
100    /// Get the integer value of this `LabelValue` as an i32
101    #[inline]
102    pub fn i32(self) -> i32 {
103        self.0
104    }
105}
106
107/// A set of labels used to carry metadata associated with a tensor map.
108///
109/// This is similar to a list of named tuples, but stored as a 2D array of shape
110/// `(labels.count(), labels.size())`, with a of set names associated with the
111/// columns of this array. Each row/entry in this array is unique, and they are
112/// often (but not always) sorted in  lexicographic order.
113///
114/// The main way to construct a new set of labels is to use a `LabelsBuilder`.
115///
116/// Labels are internally reference counted and immutable, so cloning a `Labels`
117/// should be a cheap operation.
118pub struct Labels {
119    pub(crate) ptr: *const mts_labels_t,
120    values_cpu_ptr: OnceLock<*const LabelValue>,
121    count: OnceLock<usize>,
122    size: OnceLock<usize>,
123}
124
125// Labels can be sent to other thread safely since mts_labels_t uses an
126// `Arc<metatensor_core::Labels>`, so freeing them from another thread is fine
127unsafe impl Send for Labels {}
128// &Labels can be sent to other thread safely since the interior mutability
129// (values array) uses OnceCell, which is Sync.
130unsafe impl Sync for Labels {}
131
132impl std::fmt::Debug for Labels {
133    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134        pretty_print_labels(self, "", f)
135    }
136}
137
138/// Helper function to print labels in a Debug mode
139pub(crate) fn pretty_print_labels(
140    labels: &Labels,
141    offset: &str,
142    f: &mut std::fmt::Formatter<'_>
143) -> std::fmt::Result {
144    let names = labels.names();
145
146    writeln!(f, "Labels @ {:p} {{", labels.ptr)?;
147    writeln!(f, "{}    {}", offset, names.join(", "))?;
148
149    let widths = names.iter().map(|s| s.len()).collect::<Vec<_>>();
150    for values in labels {
151        write!(f, "{}    ", offset)?;
152        for (value, width) in values.iter().zip(&widths) {
153            write!(f, "{:^width$}  ", value.isize(), width=width)?;
154        }
155        writeln!(f)?;
156    }
157
158    writeln!(f, "{}}}", offset)
159}
160
161impl Clone for Labels {
162    #[inline]
163    fn clone(&self) -> Self {
164        let ptr = unsafe { crate::c_api::mts_labels_clone(self.ptr) };
165        assert!(!ptr.is_null(), "failed to clone Labels");
166        Labels {
167            ptr,
168            values_cpu_ptr: self.values_cpu_ptr.clone(),
169            count: self.count.clone(),
170            size: self.size.clone(),
171        }
172    }
173}
174
175impl std::ops::Drop for Labels {
176    #[allow(unused_must_use)]
177    fn drop(&mut self) {
178        unsafe {
179            crate::c_api::mts_labels_free(self.ptr);
180        }
181    }
182}
183
184impl Labels {
185    /// Create a new set of Labels with the given names and values.
186    ///
187    /// This is a convenience function replacing the manual use of
188    /// `LabelsBuilder`. If you need more flexibility or incremental `Labels`
189    /// construction, use `LabelsBuilder`.
190    ///
191    /// # Panics
192    ///
193    /// If the set of names is not valid, or any of the value is duplicated
194    #[inline]
195    pub fn new<T, const N: usize>(names: [&str; N], values: &[[T; N]]) -> Labels
196        where T: Copy + Into<LabelValue>
197    {
198        let mut builder = LabelsBuilder::new(names.to_vec());
199        for entry in values {
200            builder.add(entry);
201        }
202        return builder.finish();
203    }
204
205    /// Get a pointer to the underlying `mts_labels_t`
206    pub fn as_mts_labels_t(&self) -> *const mts_labels_t {
207        self.ptr
208    }
209
210    /// Create a new set of `Labels` from a raw `*mut mts_labels_t` pointer.
211    ///
212    /// This function takes ownership of the pointer and will call
213    /// `mts_labels_free` on it when dropped.
214    ///
215    /// # Safety
216    ///
217    /// The pointer must be non-null and returned by one of the metatensor-core
218    /// functions that create `mts_labels_t`.
219    #[inline]
220    pub unsafe fn from_raw(ptr: *const mts_labels_t) -> Labels {
221        assert!(!ptr.is_null(), "expected mts_labels_t pointer to not be NULL");
222        Labels {
223            ptr,
224            values_cpu_ptr: OnceLock::new(),
225            size: OnceLock::new(),
226            count: OnceLock::new(),
227        }
228    }
229
230    /// Create a set of `Labels` with the given names, containing no entries.
231    #[inline]
232    pub fn empty(names: Vec<&str>) -> Labels {
233        return LabelsBuilder::new(names).finish()
234    }
235
236    /// Create a set of `Labels` containing a single entry, to be used when
237    /// there is no relevant information to store.
238    #[inline]
239    pub fn single() -> Labels {
240        let mut builder = LabelsBuilder::new(vec!["_"]);
241        builder.add(&[0]);
242        return builder.finish();
243    }
244
245    /// Load `Labels` from the file at `path`
246    ///
247    /// This is a convenience function calling [`crate::io::load_labels`]
248    pub fn load(path: impl AsRef<std::path::Path>) -> Result<Labels, Error> {
249        return crate::io::load_labels(path);
250    }
251
252    /// Load a `TensorMap` from an in-memory buffer
253    ///
254    /// This is a convenience function calling [`crate::io::load_buffer`]
255    pub fn load_buffer(buffer: &[u8]) -> Result<Labels, Error> {
256        return crate::io::load_labels_buffer(buffer);
257    }
258
259    /// Save the given tensor to the file at `path`
260    ///
261    /// This is a convenience function calling [`crate::io::save`]
262    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
263        return crate::io::save_labels(path, self);
264    }
265
266    /// Save the given tensor to an in-memory buffer
267    ///
268    /// This is a convenience function calling [`crate::io::save_buffer`]
269    pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
270        return crate::io::save_labels_buffer(self, buffer);
271    }
272
273    /// Get the number of entries/named values in a single label
274    #[inline]
275    pub fn size(&self) -> usize {
276        *self.size.get_or_init(|| self.names().len())
277    }
278
279    /// Get the names of the entries/columns in this set of labels
280    #[inline]
281    pub fn names(&self) -> Vec<&str> {
282        let mut names_ptr = std::ptr::null();
283        let mut size = 0;
284        unsafe {
285            check_status(crate::c_api::mts_labels_dimensions(self.ptr, &mut names_ptr, &mut size))
286                .expect("failed to get labels dimensions");
287        }
288
289        if size == 0 {
290            return Vec::new();
291        }
292
293        unsafe {
294            let names = std::slice::from_raw_parts(names_ptr, size);
295            return names.iter()
296                        .map(|&ptr| CStr::from_ptr(ptr).to_str().expect("invalid UTF8"))
297                        .collect();
298        }
299    }
300
301    /// Get the values of these labels as a `MtsArray`
302    pub fn values(&self) -> MtsArray {
303        let mut array = crate::c_api::mts_array_t::null();
304        unsafe {
305            check_status(crate::c_api::mts_labels_values(
306                self.ptr, &mut array,
307            )).expect("failed to get labels values array");
308        }
309
310        return MtsArray::from_raw(array);
311    }
312
313    /// Get the values of these Labels on CPU, potentially copying them from
314    /// another device.
315    pub fn values_cpu(&self) -> &[LabelValue] {
316        let values_cpu_ptr = self.values_cpu_ptr.get_or_init(|| {
317            let mut values_cpu_ptr = std::ptr::null();
318            let mut count = 0;
319            let mut size = 0;
320
321            unsafe {
322                check_status(crate::c_api::mts_labels_values_cpu(
323                    self.ptr,
324                    &mut values_cpu_ptr,
325                    &mut count,
326                    &mut size
327                )).expect("failed to get CPU values for Labels");
328            }
329
330            debug_assert_eq!(count, self.count());
331            debug_assert_eq!(size, self.size());
332
333            values_cpu_ptr.cast()
334        });
335
336        unsafe {
337            let count = self.count();
338            let size = self.size();
339            std::slice::from_raw_parts(*values_cpu_ptr, count * size)
340        }
341    }
342
343    /// Get the total number of entries in this set of labels
344    #[inline]
345    pub fn device(&self) -> dlpk::DLDevice {
346        let array = self.values();
347        return array.device().expect("failed to get the array device");
348    }
349
350    /// Get the total number of entries in this set of labels
351    #[inline]
352    pub fn count(&self) -> usize {
353        return *self.count.get_or_init(|| self.values().shape().expect("failed to get the array shape")[0]);
354    }
355
356    /// Check if this set of Labels is empty (contains no entry)
357    #[inline]
358    pub fn is_empty(&self) -> bool {
359        self.count() == 0
360    }
361
362    /// Check whether the given `label` is part of this set of labels
363    #[inline]
364    pub fn contains(&self, label: &[LabelValue]) -> bool {
365        return self.position(label).is_some();
366    }
367
368    /// Get the position (i.e. row index) of the given label in the full labels
369    /// array, or None.
370    #[inline]
371    pub fn position(&self, value: &[LabelValue]) -> Option<usize> {
372        assert!(value.len() == self.size(), "invalid size of index in Labels::position");
373
374        let mut result = 0;
375        unsafe {
376            check_status(crate::c_api::mts_labels_position(
377                self.ptr,
378                value.as_ptr().cast(),
379                value.len(),
380                &mut result,
381            )).expect("failed to check label position");
382        }
383
384        return result.try_into().ok();
385    }
386
387    /// Take the union of `self` with `other`.
388    ///
389    /// If requested, this function can also give the positions in the union
390    /// where each entry of the input `Labels` ended up.
391    ///
392    /// If `first_mapping` (respectively `second_mapping`) is `Some`, it should
393    /// contain a slice of length `self.count()` (respectively `other.count()`)
394    /// that will be filled with the position of the entries in `self`
395    /// (respectively `other`) in the union.
396    #[inline]
397    pub fn union(
398        &self,
399        other: &Labels,
400        first_mapping: Option<&mut [i64]>,
401        second_mapping: Option<&mut [i64]>,
402    ) -> Result<Labels, Error> {
403        let mut output: *const mts_labels_t = std::ptr::null();
404        let (first_mapping, first_mapping_count) = if let Some(m) = first_mapping {
405            (m.as_mut_ptr(), m.len())
406        } else {
407            (std::ptr::null_mut(), 0)
408        };
409
410        let (second_mapping, second_mapping_count) = if let Some(m) = second_mapping {
411            (m.as_mut_ptr(), m.len())
412        } else {
413            (std::ptr::null_mut(), 0)
414        };
415
416        unsafe {
417            check_status(crate::c_api::mts_labels_union(
418                self.ptr,
419                other.ptr,
420                &mut output,
421                first_mapping,
422                first_mapping_count,
423                second_mapping,
424                second_mapping_count,
425            ))?;
426
427            return Ok(Labels::from_raw(output));
428        }
429    }
430
431    /// Take the intersection of self with `other`.
432    ///
433    /// If requested, this function can also give the positions in the
434    /// intersection where each entry of the input `Labels` ended up.
435    ///
436    /// If `first_mapping` (respectively `second_mapping`) is `Some`, it should
437    /// contain a slice of length `self.count()` (respectively `other.count()`)
438    /// that will be filled by with the position of the entries in `self`
439    /// (respectively `other`) in the intersection. If an entry in `self` or
440    /// `other` are not used in the intersection, the mapping for this entry
441    /// will be set to `-1`.
442    #[inline]
443    pub fn intersection(
444        &self,
445        other: &Labels,
446        first_mapping: Option<&mut [i64]>,
447        second_mapping: Option<&mut [i64]>,
448    ) -> Result<Labels, Error> {
449        let mut output: *const mts_labels_t = std::ptr::null();
450        let (first_mapping, first_mapping_count) = if let Some(m) = first_mapping {
451            (m.as_mut_ptr(), m.len())
452        } else {
453            (std::ptr::null_mut(), 0)
454        };
455
456        let (second_mapping, second_mapping_count) = if let Some(m) = second_mapping {
457            (m.as_mut_ptr(), m.len())
458        } else {
459            (std::ptr::null_mut(), 0)
460        };
461
462        unsafe {
463            check_status(crate::c_api::mts_labels_intersection(
464                self.ptr,
465                other.ptr,
466                &mut output,
467                first_mapping,
468                first_mapping_count,
469                second_mapping,
470                second_mapping_count,
471            ))?;
472
473            return Ok(Labels::from_raw(output));
474        }
475    }
476
477    /// Take the set difference of `self` with `other`.
478    ///
479    /// If requested, this function can also give the positions in the
480    /// difference where each entry of `self` ended up.
481    ///
482    /// If `mapping` is `Some`, it should contain a slice of length
483    /// `self.count()` that will be filled by with the position of the entries
484    /// in `self` in the difference. If an entry is not used in the difference,
485    ///  the mapping for this entry will be set to `-1`.
486    #[inline]
487    pub fn difference(
488        &self,
489        other: &Labels,
490        mapping: Option<&mut [i64]>,
491    ) -> Result<Labels, Error> {
492        let mut output: *const mts_labels_t = std::ptr::null();
493        let (mapping, mapping_count) = if let Some(m) = mapping {
494            (m.as_mut_ptr(), m.len())
495        } else {
496            (std::ptr::null_mut(), 0)
497        };
498
499        unsafe {
500            check_status(crate::c_api::mts_labels_difference(
501                self.ptr,
502                other.ptr,
503                &mut output,
504                mapping,
505                mapping_count,
506            ))?;
507
508            return Ok(Labels::from_raw(output));
509        }
510    }
511
512    /// Select entries in these `Labels` that match the `selection`.
513    ///
514    /// The selection's names must be a subset of the names of these labels.
515    ///
516    /// All entries in these `Labels` that match one of the entry in the
517    /// `selection` for all the selection's dimension will be picked. Any entry
518    /// in the `selection` but not in these `Labels` will be ignored.
519    #[allow(clippy::cast_possible_truncation)]
520    pub fn select(&self, selection: &Labels) -> Result<Vec<usize>, Error> {
521        let mut selected = vec![0; self.count()];
522        let mut selected_count = selected.len();
523
524        unsafe {
525            check_status(crate::c_api::mts_labels_select(
526                self.as_mts_labels_t(),
527                selection.as_mts_labels_t(),
528                selected.as_mut_ptr(),
529                &mut selected_count
530            ))?;
531        }
532
533        selected.resize(selected_count, 0);
534
535        return Ok(selected.into_iter().map(|s| s as usize).collect());
536    }
537
538    /// Iterate over the entries in this set of labels
539    #[inline]
540    pub fn iter(&self) -> LabelsIter<'_> {
541        return LabelsIter {
542            ptr: self.values_cpu().as_ptr(),
543            cur: 0,
544            len: self.count(),
545            chunk_len: self.size(),
546            phantom: std::marker::PhantomData,
547        };
548    }
549
550    /// Iterate over the entries in this set of labels in parallel
551    #[cfg(feature = "rayon")]
552    #[inline]
553    pub fn par_iter(&self) -> LabelsParIter<'_> {
554        use rayon::prelude::*;
555        return LabelsParIter {
556            chunks: self.values_cpu().par_chunks_exact(self.size())
557        };
558    }
559
560    /// Iterate over the entries in this set of labels as fixed-size arrays
561    #[inline]
562    pub fn iter_fixed_size<const N: usize>(&self) -> LabelsFixedSizeIter<'_, N> {
563        assert!(N == self.size(),
564            "wrong label size in `iter_fixed_size`: the entries contains {} element \
565            but this function was called with size of {}",
566            self.size(), N
567        );
568
569        return LabelsFixedSizeIter {
570            values: self.values_cpu()
571        };
572    }
573}
574
575impl std::cmp::PartialEq<Labels> for Labels {
576    #[inline]
577    fn eq(&self, other: &Labels) -> bool {
578        if self.names() != other.names() {
579            return false;
580        }
581
582        if self.count() != other.count() {
583            return false;
584        }
585
586        if self.device() != other.device() {
587            return false;
588        }
589
590        if self.device().device_type == dlpk::DLDeviceType::kDLExtDev {
591            // kDLExtDev is used for torch's meta device, which has no data
592            // associated, so we consider all Labels as equal as long as they
593            // have the same dimensions and number of entries
594            return true;
595        } else {
596            return self.values_cpu() == other.values_cpu();
597        }
598    }
599}
600
601impl std::ops::Index<usize> for Labels {
602    type Output = [LabelValue];
603
604    #[inline]
605    fn index(&self, i: usize) -> &[LabelValue] {
606        let start = i * self.size();
607        let stop = (i + 1) * self.size();
608        &self.values_cpu()[start..stop]
609    }
610}
611
612/// iterator over [`Labels`] entries
613pub struct LabelsIter<'a> {
614    /// start of the labels values
615    ptr: *const LabelValue,
616    /// Current entry index
617    cur: usize,
618    /// number of entries
619    len: usize,
620    /// size of an entry/the labels
621    chunk_len: usize,
622    phantom: std::marker::PhantomData<&'a LabelValue>,
623}
624
625impl<'a> Iterator for LabelsIter<'a> {
626    type Item = &'a [LabelValue];
627
628    #[inline]
629    fn next(&mut self) -> Option<Self::Item> {
630        if self.cur < self.len {
631            unsafe {
632                // SAFETY: this should be in-bounds
633                let data = self.ptr.add(self.cur * self.chunk_len);
634                self.cur += 1;
635                // SAFETY: the pointer should be valid for 'a
636                Some(std::slice::from_raw_parts(data, self.chunk_len))
637            }
638        } else {
639            None
640        }
641    }
642
643    #[inline]
644    fn size_hint(&self) -> (usize, Option<usize>) {
645        let remaining = self.len - self.cur;
646        return (remaining, Some(remaining));
647    }
648}
649
650impl ExactSizeIterator for LabelsIter<'_> {
651    #[inline]
652    fn len(&self) -> usize {
653        self.len
654    }
655}
656
657impl FusedIterator for LabelsIter<'_> {}
658
659impl<'a> IntoIterator for &'a Labels {
660    type IntoIter = LabelsIter<'a>;
661    type Item = &'a [LabelValue];
662
663    #[inline]
664    fn into_iter(self) -> Self::IntoIter {
665        self.iter()
666    }
667}
668
/// Parallel iterator over the entries of a set of [`Labels`], yielding one
/// slice per entry. Created by `Labels::par_iter`.
#[cfg(feature = "rayon")]
#[derive(Debug, Clone)]
pub struct LabelsParIter<'a> {
    /// the labels values, split into one chunk per entry
    chunks: rayon::slice::ChunksExact<'a, LabelValue>,
}
675
#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for LabelsParIter<'a> {
    type Item = &'a [LabelValue];

    #[inline]
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item> {
        self.chunks.drive_unindexed(consumer)
    }

    /// Report the exact number of entries.
    ///
    /// The default implementation returns `None`, which prevents rayon from
    /// using its exact-length fast paths (e.g. when collecting into a `Vec`).
    #[inline]
    fn opt_len(&self) -> Option<usize> {
        rayon::iter::ParallelIterator::opt_len(&self.chunks)
    }
}
687
#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for LabelsParIter<'_> {
    /// Number of entries in the labels.
    #[inline]
    fn len(&self) -> usize {
        self.chunks.len()
    }

    /// Delegate indexed driving to the underlying chunks.
    #[inline]
    fn drive<Cons>(self, consumer: Cons) -> Cons::Result
    where
        Cons: rayon::iter::plumbing::Consumer<Self::Item>,
    {
        self.chunks.drive(consumer)
    }

    /// Delegate producer creation to the underlying chunks.
    #[inline]
    fn with_producer<Callback>(self, callback: Callback) -> Callback::Output
    where
        Callback: rayon::iter::plumbing::ProducerCallback<Self::Item>,
    {
        self.chunks.with_producer(callback)
    }
}
705
706/// Iterator over entries in a set of [`Labels`] as fixed size arrays
707#[derive(Debug, Clone)]
708pub struct LabelsFixedSizeIter<'a, const N: usize> {
709    values: &'a [LabelValue],
710}
711
712impl<'a, const N: usize> Iterator for LabelsFixedSizeIter<'a, N> {
713    type Item = &'a [LabelValue; N];
714
715    #[inline]
716    fn next(&mut self) -> Option<Self::Item> {
717        if self.values.is_empty() {
718            return None
719        }
720
721        let (value, rest) = self.values.split_at(N);
722        self.values = rest;
723        return Some(value.try_into().expect("wrong size in FixedSizeIter::next"));
724    }
725
726    fn size_hint(&self) -> (usize, Option<usize>) {
727        (self.len(), Some(self.len()))
728    }
729}
730
731impl<const N: usize> ExactSizeIterator for LabelsFixedSizeIter<'_, N> {
732    #[inline]
733    fn len(&self) -> usize {
734        self.values.len() / N
735    }
736}
737
/// Builder for [`Labels`], accumulating entries one at a time before creating
/// the final labels with `finish` or `finish_assume_unique`.
#[derive(Debug, Clone)]
pub struct LabelsBuilder {
    /// names of the dimensions, one per value in each entry
    names: Vec<String>,
    /// all values added so far, stored entry after entry
    values: Vec<i32>,
}
745
746impl LabelsBuilder {
747    /// Create a new empty `LabelsBuilder` with the given `names`
748    #[inline]
749    pub fn new(names: Vec<&str>) -> LabelsBuilder {
750        let n_unique_names = names.iter().collect::<BTreeSet<_>>().len();
751        assert!(n_unique_names == names.len(), "invalid labels: the same name is used multiple times");
752
753        LabelsBuilder {
754            names: names.into_iter().map(|s| s.into()).collect(),
755            values: Vec::new(),
756        }
757    }
758
759    /// Reserve space for `additional` other entries in the labels.
760    #[inline]
761    pub fn reserve(&mut self, additional: usize) {
762        self.values.reserve(additional * self.names.len());
763    }
764
765    /// Get the number of labels in a single value
766    #[inline]
767    pub fn size(&self) -> usize {
768        self.names.len()
769    }
770
771    /// Add a single `entry` to this set of labels.
772    ///
773    /// This function will panic when attempting to add the same `label` more
774    /// than once.
775    #[inline]
776    pub fn add<T>(&mut self, entry: &[T]) where T: Clone + Into<LabelValue> {
777        assert_eq!(
778            self.size(), entry.len(),
779            "wrong size for added label: got {}, but expected {}",
780            entry.len(), self.size()
781        );
782
783        // SmallVec allows us to convert everything to `LabelValue` without
784        // requiring an extra heap allocation
785        for e in entry {
786            self.values.push(Into::<LabelValue>::into(e.clone()).i32());
787        }
788    }
789
790    /// Common implementation for `finish` and `finish_unchecked`.
791    fn finish_with(
792        self,
793        creator: unsafe extern "C" fn(
794            *const *const std::os::raw::c_char,
795            usize,
796            crate::c_api::mts_array_t,
797        ) -> *const mts_labels_t
798    ) -> Labels {
799        let mut raw_names = Vec::new();
800        let mut raw_names_ptr = Vec::new();
801
802        for name in &self.names {
803            let name = CString::new(&**name).expect("name contains a NULL byte");
804            raw_names_ptr.push(name.as_ptr());
805            raw_names.push(name);
806        }
807
808        let size = raw_names_ptr.len();
809        let count = if size == 0 {
810            assert!(self.values.is_empty());
811            0
812        } else {
813            self.values.len() / size
814        };
815
816        // Wrap raw values in an ndarray-backed mts_array_t
817        let array = ndarray::Array::from_shape_vec(vec![count, size], self.values)
818                .expect("shape mismatch when creating labels array");
819        let array: MtsArray = array.into();
820
821        let ptr = unsafe {
822            creator(
823                raw_names_ptr.as_ptr(),
824                size,
825                array.into_raw(),
826            )
827        };
828        check_ptr(ptr).expect("invalid labels");
829
830        unsafe { Labels::from_raw(ptr) }
831    }
832
833    /// Finish building the `Labels`.
834    ///
835    /// This function checks that all entries in the labels are unique.
836    #[inline]
837    pub fn finish(self) -> Labels {
838        self.finish_with(crate::c_api::mts_labels)
839    }
840
841    /// Finish building the `Labels`, assuming that all entries are unique.
842    ///
843    /// This is faster than `finish` as it does not perform a uniqueness check
844    /// on the labels entries. It is the caller's responsibility to ensure that
845    /// entries are unique.
846    ///
847    /// # Panics
848    ///
849    /// If the set of names is not valid (contains duplicates or invalid names).
850    #[inline]
851    pub fn finish_assume_unique(self) -> Labels {
852        self.finish_with(crate::c_api::mts_labels_assume_unique)
853    }
854}
855
856
857#[cfg(test)]
858mod tests {
859    use super::*;
860
861    #[test]
862    fn labels() {
863        let mut builder = LabelsBuilder::new(vec!["foo", "bar"]);
864        builder.add(&[2, 3]);
865        builder.add(&[1, 243]);
866        builder.add(&[-4, -2413]);
867
868        let labels = builder.finish();
869        assert_eq!(labels.names(), &["foo", "bar"]);
870        assert_eq!(labels.size(), 2);
871        assert_eq!(labels.count(), 3);
872        assert!(!labels.is_empty());
873
874        assert_eq!(labels[0], [2, 3]);
875        assert_eq!(labels[1], [1, 243]);
876        assert_eq!(labels[2], [-4, -2413]);
877
878        let builder = LabelsBuilder::new(vec![]);
879        let labels = builder.finish();
880        assert_eq!(labels.size(), 0);
881        assert_eq!(labels.count(), 0);
882
883        let mut builder = LabelsBuilder::new(vec!["foo", "bar"]);
884        builder.add(&[2, 3]);
885        builder.add(&[1, 243]);
886        let labels = builder.finish_assume_unique();
887        assert_eq!(labels.names(), &["foo", "bar"]);
888        assert_eq!(labels.size(), 2);
889        assert_eq!(labels.count(), 2);
890    }
891
892    #[test]
893    fn direct_construct() {
894        let labels = Labels::new(
895            ["foo", "bar"],
896            &[
897                [2, 3],
898                [1, 243],
899                [-4, -2413],
900            ]
901        );
902
903        assert_eq!(labels.names(), &["foo", "bar"]);
904        assert_eq!(labels.size(), 2);
905        assert_eq!(labels.count(), 3);
906
907        assert_eq!(labels[0], [2, 3]);
908        assert_eq!(labels[1], [1, 243]);
909        assert_eq!(labels[2], [-4, -2413]);
910    }
911
912    #[test]
913    fn iter() {
914        let mut builder = LabelsBuilder::new(vec!["foo", "bar"]);
915        builder.add(&[2, 3]);
916        builder.add(&[1, 2]);
917        builder.add(&[4, 3]);
918
919        let labels = builder.finish();
920        let mut iter = labels.iter();
921        assert_eq!(iter.len(), 3);
922
923        assert_eq!(iter.next().unwrap(), &[2, 3]);
924        assert_eq!(iter.next().unwrap(), &[1, 2]);
925        assert_eq!(iter.next().unwrap(), &[4, 3]);
926        assert_eq!(iter.next(), None);
927    }
928
929    #[cfg( feature = "rayon")]
930    #[test]
931    fn par_iter() {
932        use rayon::iter::IndexedParallelIterator;
933
934        let mut builder = LabelsBuilder::new(vec!["foo", "bar"]);
935        builder.add(&[2, 3]);
936        builder.add(&[1, 2]);
937        builder.add(&[4, 3]);
938
939        let labels = builder.finish();
940        let iter = labels.par_iter();
941        assert_eq!(iter.len(), 3);
942
943        let mut values = Vec::new();
944        iter.collect_into_vec(&mut values);
945
946        assert_eq!(values, [&[2, 3], &[1, 2], &[4, 3]]);
947    }
948
949    #[test]
950    fn iter_fixed_size() {
951        let mut builder = LabelsBuilder::new(vec!["foo", "bar"]);
952        builder.add(&[1, 2]);
953        builder.add(&[2, 3]);
954
955        let labels = builder.finish();
956
957        for (i, [a, b]) in labels.iter_fixed_size().enumerate() {
958            assert_eq!(a.usize(), 1 + i);
959            assert_eq!(b.usize(), 2 + i);
960        }
961    }
962
963    #[test]
964    #[should_panic(expected = "wrong label size in `iter_fixed_size`: the entries contains 2 element but this function was called with size of 3")]
965    fn iter_fixed_size_wrong_size() {
966        let labels = LabelsBuilder::new(vec!["foo", "bar"]).finish();
967
968        for [_, _, _] in labels.iter_fixed_size() {}
969    }
970
971    #[test]
972    #[should_panic(expected = "'33 bar' is not a valid label name")]
973    fn invalid_label_name() {
974        LabelsBuilder::new(vec!["foo", "33 bar"]).finish();
975    }
976
977    #[test]
978    #[should_panic(expected = "invalid labels: the same name is used multiple times")]
979    fn duplicated_label_name() {
980        LabelsBuilder::new(vec!["foo", "bar", "foo"]).finish();
981    }
982
983    #[test]
984    #[should_panic(expected = "can not have the same label entry multiple times: [0, 1] is already present")]
985    fn duplicated_label_entry() {
986        let mut builder = LabelsBuilder::new(vec!["foo", "bar"]);
987        builder.add(&[0, 1]);
988        builder.add(&[0, 1]);
989        builder.finish();
990    }
991
992    #[test]
993    fn single_label() {
994        let labels = Labels::single();
995        assert_eq!(labels.names(), &["_"]);
996        assert_eq!(labels.size(), 1);
997        assert_eq!(labels.count(), 1);
998    }
999
1000    #[test]
1001    fn empty_label() {
1002        let labels = LabelsBuilder::new(vec!["foo", "bar"]).finish();
1003
1004        assert!(labels.is_empty());
1005        assert_eq!(labels.count(), 0);
1006        assert_eq!(labels.size(), 2);
1007    }
1008
1009    #[test]
1010    fn position() {
1011        let mut builder = LabelsBuilder::new(vec!["foo", "bar"]);
1012        builder.add(&[1, 2]);
1013        builder.add(&[2, 3]);
1014        let labels = builder.finish();
1015
1016        assert!(labels.contains(&[LabelValue::new(1), LabelValue::new(2)]));
1017        assert_eq!(labels.position(&[LabelValue::new(1), LabelValue::new(2)]), Some(0));
1018
1019        assert!(labels.contains(&[LabelValue::new(2), LabelValue::new(3)]));
1020        assert_eq!(labels.position(&[LabelValue::new(2), LabelValue::new(3)]), Some(1));
1021
1022        assert!(!labels.contains(&[LabelValue::new(3), LabelValue::new(3)]));
1023        assert_eq!(labels.position(&[LabelValue::new(3), LabelValue::new(3)]), None);
1024    }
1025
1026    #[test]
1027    fn indexing() {
1028        let labels = Labels::new(
1029            ["foo", "bar"],
1030            &[
1031                [2, 3],
1032                [1, 243],
1033                [-4, -2413],
1034            ]
1035        );
1036
1037        assert_eq!(labels[1], [1, 243]);
1038        assert_eq!(labels[2], [-4, -2413]);
1039    }
1040
1041    #[test]
1042    fn debug() {
1043        let labels = Labels::new(
1044            ["foo", "bar"],
1045            &[
1046                [2, 3],
1047                [1, 243],
1048                [-4, -2413],
1049            ]
1050        );
1051
1052        let expected = format!(
1053            "Labels @ {:p} {{\n    foo, bar\n     2    3   \n     1   243  \n    -4   -2413  \n}}\n",
1054            labels.ptr
1055        );
1056        assert_eq!(format!("{:?}", labels), expected);
1057    }
1058
1059    #[test]
1060    fn union() {
1061        let first = Labels::new(["aa", "bb"], &[[0, 1], [1, 2]]);
1062        let second = Labels::new(["aa", "bb"], &[[2, 3], [1, 2], [4, 5]]);
1063
1064        let mut first_mapping = vec![0; first.count()];
1065        let mut second_mapping = vec![0; second.count()];
1066        let union = first.union(&second, Some(&mut first_mapping), Some(&mut second_mapping)).unwrap();
1067
1068        assert_eq!(union.names(), ["aa", "bb"]);
1069        assert_eq!(union.values_cpu(), [0, 1, 1, 2, 2, 3, 4, 5]);
1070
1071        assert_eq!(first_mapping, [0, 1]);
1072        assert_eq!(second_mapping, [2, 1, 3]);
1073    }
1074
1075    #[test]
1076    fn intersection() {
1077        let first = Labels::new(["aa", "bb"], &[[0, 1], [1, 2]]);
1078        let second = Labels::new(["aa", "bb"], &[[2, 3], [1, 2], [4, 5]]);
1079
1080        let mut first_mapping = vec![0_i64; first.count()];
1081        let mut second_mapping = vec![0_i64; second.count()];
1082        let union = first.intersection(&second, Some(&mut first_mapping), Some(&mut second_mapping)).unwrap();
1083
1084        assert_eq!(union.names(), ["aa", "bb"]);
1085        assert_eq!(union.values_cpu(), [1, 2]);
1086
1087        assert_eq!(first_mapping, [-1, 0]);
1088        assert_eq!(second_mapping, [-1, 0, -1]);
1089    }
1090
1091    #[test]
1092    fn difference() {
1093        let first = Labels::new(["aa", "bb"], &[[0, 1], [1, 2]]);
1094        let second = Labels::new(["aa", "bb"], &[[2, 3], [1, 2], [4, 5]]);
1095
1096        let mut mapping = vec![0_i64; first.count()];
1097        let union = first.difference(&second, Some(&mut mapping)).unwrap();
1098
1099        assert_eq!(union.names(), ["aa", "bb"]);
1100        assert_eq!(union.values_cpu(), [0, 1]);
1101
1102        assert_eq!(mapping, [0, -1]);
1103    }
1104
1105    #[test]
1106    fn selection() {
1107        // selection with a subset of names
1108        let labels = Labels::new(["aa", "bb"], &[[1, 1], [1, 2], [3, 2], [2, 1]]);
1109        let selection = Labels::new(["aa"], &[[1], [2], [5]]);
1110
1111        let selected = labels.select(&selection).unwrap();
1112        assert_eq!(selected, [0, 1, 3]);
1113
1114        // selection with the same names
1115        let selection = Labels::new(["aa", "bb"], &[[1, 1], [2, 1], [5, 1], [1, 2]]);
1116        let selected = labels.select(&selection).unwrap();
1117        assert_eq!(selected, [0, 3, 1]);
1118
1119        // empty selection
1120        let selection = Labels::empty(vec!["aa"]);
1121        let selected = labels.select(&selection).unwrap();
1122        assert_eq!(selected, []);
1123
1124        // invalid selection names
1125        let selection = Labels::empty(vec!["aaa"]);
1126        let err = labels.select(&selection).unwrap_err();
1127        assert_eq!(err.message,
1128            "invalid parameter: 'aaa' in selection is not part of these Labels"
1129        );
1130    }
1131}