1use std::sync::OnceLock;
2use std::ffi::{CStr, CString};
3use std::collections::BTreeSet;
4use std::iter::FusedIterator;
5
6use crate::MtsArray;
7use crate::c_api::mts_labels_t;
8use crate::errors::check_ptr;
9use crate::errors::{Error, check_status};
10
/// A single value inside an entry of [`Labels`].
///
/// This is a `#[repr(transparent)]` wrapper around `i32`. The pointer casts
/// between `LabelValue` buffers and the C API done in `Labels::values_cpu`
/// and `Labels::position` rely on this layout guarantee.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct LabelValue(i32);

impl PartialEq<i32> for LabelValue {
    #[inline]
    fn eq(&self, other: &i32) -> bool {
        self.0 == *other
    }
}

impl PartialEq<LabelValue> for i32 {
    #[inline]
    fn eq(&self, other: &LabelValue) -> bool {
        *self == other.0
    }
}

impl std::fmt::Debug for LabelValue {
    /// Formats exactly like the wrapped `i32`, including any width, fill or
    /// alignment flags requested by the caller.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}

impl std::fmt::Display for LabelValue {
    /// Formats exactly like the wrapped `i32`, including any width, fill or
    /// alignment flags requested by the caller.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}

impl From<u32> for LabelValue {
    /// # Panics
    ///
    /// If `value` is larger than `i32::MAX`.
    #[inline]
    fn from(value: u32) -> LabelValue {
        LabelValue(i32::try_from(value).expect("u32 value is too large to be stored in a LabelValue"))
    }
}

impl From<i32> for LabelValue {
    #[inline]
    fn from(value: i32) -> LabelValue {
        LabelValue(value)
    }
}

impl From<usize> for LabelValue {
    /// # Panics
    ///
    /// If `value` is larger than `i32::MAX`.
    #[inline]
    fn from(value: usize) -> LabelValue {
        LabelValue(i32::try_from(value).expect("usize value is too large to be stored in a LabelValue"))
    }
}

impl From<isize> for LabelValue {
    /// # Panics
    ///
    /// If `value` is outside of the `i32::MIN..=i32::MAX` range.
    #[inline]
    fn from(value: isize) -> LabelValue {
        LabelValue(i32::try_from(value).expect("isize value is out of range for a LabelValue"))
    }
}

impl LabelValue {
    /// Wrap an `i32` in a `LabelValue`.
    #[inline]
    pub fn new(value: i32) -> LabelValue {
        LabelValue(value)
    }

    /// Get the value as a `usize`.
    ///
    /// Negative values trigger a panic in debug builds; in release builds
    /// they are converted with `as`, wrapping around to large values.
    #[inline]
    #[allow(clippy::cast_sign_loss)] // guarded by the debug assertion above the cast
    pub fn usize(self) -> usize {
        debug_assert!(self.0 >= 0);
        self.0 as usize
    }

    /// Get the value as an `isize`.
    #[inline]
    pub fn isize(self) -> isize {
        self.0 as isize
    }

    /// Get the value as an `i32`.
    #[inline]
    pub fn i32(self) -> i32 {
        self.0
    }
}
106
107pub struct Labels {
120 pub(crate) ptr: *const mts_labels_t,
121 values_cpu_ptr: OnceLock<*const LabelValue>,
122 count: OnceLock<usize>,
123 size: OnceLock<usize>,
124}
125
126unsafe impl Send for Labels {}
129unsafe impl Sync for Labels {}
132
133impl std::fmt::Debug for Labels {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 pretty_print_labels(self, "", f)
136 }
137}
138
139pub(crate) fn pretty_print_labels(
141 labels: &Labels,
142 offset: &str,
143 f: &mut std::fmt::Formatter<'_>
144) -> std::fmt::Result {
145 let names = labels.names();
146
147 writeln!(f, "Labels @ {:p} {{", labels.ptr)?;
148 writeln!(f, "{} {}", offset, names.join(", "))?;
149
150 let widths = names.iter().map(|s| s.len()).collect::<Vec<_>>();
151 for values in labels {
152 write!(f, "{} ", offset)?;
153 for (value, width) in values.iter().zip(&widths) {
154 write!(f, "{:^width$} ", value.isize(), width=width)?;
155 }
156 writeln!(f)?;
157 }
158
159 writeln!(f, "{}}}", offset)
160}
161
162impl Clone for Labels {
163 #[inline]
164 fn clone(&self) -> Self {
165 let ptr = unsafe { crate::c_api::mts_labels_clone(self.ptr) };
166 assert!(!ptr.is_null(), "failed to clone Labels");
167 Labels {
168 ptr,
169 values_cpu_ptr: self.values_cpu_ptr.clone(),
170 count: self.count.clone(),
171 size: self.size.clone(),
172 }
173 }
174}
175
176impl std::ops::Drop for Labels {
177 #[allow(unused_must_use)]
178 fn drop(&mut self) {
179 unsafe {
180 crate::c_api::mts_labels_free(self.ptr);
181 }
182 }
183}
184
185impl Labels {
186 #[inline]
196 pub fn new<'a>(names: impl AsRef<[&'a str]>, values: impl Into<MtsArray>) -> Labels {
197 Self::new_impl(names.as_ref(), values, crate::c_api::mts_labels)
198 }
199
200 #[inline]
211 pub fn new_assume_unique<'a>(names: impl AsRef<[&'a str]>, values: impl Into<MtsArray>) -> Labels {
212 Self::new_impl(names.as_ref(), values, crate::c_api::mts_labels_assume_unique)
213 }
214
215 fn new_impl(
217 names: &[&str],
218 values: impl Into<MtsArray>,
219 creator: unsafe extern "C" fn(
220 *const *const std::os::raw::c_char,
221 usize,
222 crate::c_api::mts_array_t,
223 ) -> *const mts_labels_t
224 ) -> Labels {
225 let n_unique_names = names.iter().collect::<BTreeSet<_>>().len();
226 assert!(n_unique_names == names.len(), "invalid labels: the same name is used multiple times");
227
228 let mut raw_names = Vec::new();
229 let mut raw_names_ptr = Vec::new();
230 for name in names {
231 let c_name = CString::new(*name).expect("name contains a NULL byte");
232 raw_names_ptr.push(c_name.as_ptr());
233 raw_names.push(c_name);
234 }
235
236 let array: MtsArray = values.into();
237 let ptr = unsafe {
238 creator(
239 raw_names_ptr.as_ptr(),
240 raw_names.len(),
241 array.into_raw(),
242 )
243 };
244 check_ptr(ptr).expect("invalid labels");
245
246 unsafe { Labels::from_raw(ptr) }
247 }
248
249 pub fn as_mts_labels_t(&self) -> *const mts_labels_t {
251 self.ptr
252 }
253
254 #[inline]
264 pub unsafe fn from_raw(ptr: *const mts_labels_t) -> Labels {
265 assert!(!ptr.is_null(), "expected mts_labels_t pointer to not be NULL");
266 Labels {
267 ptr,
268 values_cpu_ptr: OnceLock::new(),
269 size: OnceLock::new(),
270 count: OnceLock::new(),
271 }
272 }
273
274 #[inline]
281 pub fn into_raw(mut labels: Labels) -> *const mts_labels_t {
282 return std::mem::replace(&mut labels.ptr, std::ptr::null());
283 }
284
285 #[inline]
287 pub fn empty<'a>(names: impl AsRef<[&'a str]>) -> Labels {
288 let names = names.as_ref();
289 let array = ndarray::Array::<i32, _>::from_shape_vec(
290 vec![0, names.len()], vec![]
291 ).expect("shape mismatch when creating empty labels array");
292 Labels::new(names, array)
293 }
294
295 #[inline]
298 pub fn single() -> Labels {
299 Labels::new(["_"], vec![[0i32]])
300 }
301
302 pub fn load(path: impl AsRef<std::path::Path>) -> Result<Labels, Error> {
306 return crate::io::load_labels(path);
307 }
308
309 pub fn load_buffer(buffer: &[u8]) -> Result<Labels, Error> {
313 return crate::io::load_labels_buffer(buffer);
314 }
315
316 pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
320 return crate::io::save_labels(path, self);
321 }
322
323 pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
327 return crate::io::save_labels_buffer(self, buffer);
328 }
329
330 #[inline]
332 pub fn size(&self) -> usize {
333 *self.size.get_or_init(|| self.names().len())
334 }
335
336 #[inline]
338 pub fn names(&self) -> Vec<&str> {
339 let mut names_ptr = std::ptr::null();
340 let mut size = 0;
341 unsafe {
342 check_status(crate::c_api::mts_labels_dimensions(self.ptr, &mut names_ptr, &mut size))
343 .expect("failed to get labels dimensions");
344 }
345
346 if size == 0 {
347 return Vec::new();
348 }
349
350 unsafe {
351 let names = std::slice::from_raw_parts(names_ptr, size);
352 return names.iter()
353 .map(|&ptr| CStr::from_ptr(ptr).to_str().expect("invalid UTF8"))
354 .collect();
355 }
356 }
357
358 pub fn values(&self) -> MtsArray {
360 let mut array = crate::c_api::mts_array_t::null();
361 unsafe {
362 check_status(crate::c_api::mts_labels_values(
363 self.ptr, &mut array,
364 )).expect("failed to get labels values array");
365 }
366
367 return MtsArray::from_raw(array);
368 }
369
370 pub fn values_cpu(&self) -> &[LabelValue] {
373 let values_cpu_ptr = self.values_cpu_ptr.get_or_init(|| {
374 let mut values_cpu_ptr = std::ptr::null();
375 let mut count = 0;
376 let mut size = 0;
377
378 unsafe {
379 check_status(crate::c_api::mts_labels_values_cpu(
380 self.ptr,
381 &mut values_cpu_ptr,
382 &mut count,
383 &mut size
384 )).expect("failed to get CPU values for Labels");
385 }
386
387 debug_assert_eq!(count, self.count());
388 debug_assert_eq!(size, self.size());
389
390 values_cpu_ptr.cast()
391 });
392
393 unsafe {
394 let count = self.count();
395 let size = self.size();
396 std::slice::from_raw_parts(*values_cpu_ptr, count * size)
397 }
398 }
399
400 #[inline]
402 pub fn device(&self) -> dlpk::DLDevice {
403 let array = self.values();
404 return array.device().expect("failed to get the array device");
405 }
406
407 #[inline]
409 pub fn count(&self) -> usize {
410 return *self.count.get_or_init(|| self.values().shape().expect("failed to get the array shape")[0]);
411 }
412
413 #[inline]
415 pub fn is_empty(&self) -> bool {
416 self.count() == 0
417 }
418
419 #[inline]
421 pub fn contains(&self, label: &[LabelValue]) -> bool {
422 return self.position(label).is_some();
423 }
424
425 #[inline]
428 pub fn position(&self, value: &[LabelValue]) -> Option<usize> {
429 assert!(value.len() == self.size(), "invalid size of index in Labels::position");
430
431 let mut result = 0;
432 unsafe {
433 check_status(crate::c_api::mts_labels_position(
434 self.ptr,
435 value.as_ptr().cast(),
436 value.len(),
437 &mut result,
438 )).expect("failed to check label position");
439 }
440
441 return result.try_into().ok();
442 }
443
444 #[inline]
454 pub fn union(
455 &self,
456 other: &Labels,
457 first_mapping: Option<&mut [i64]>,
458 second_mapping: Option<&mut [i64]>,
459 ) -> Result<Labels, Error> {
460 let mut output: *const mts_labels_t = std::ptr::null();
461 let (first_mapping, first_mapping_count) = if let Some(m) = first_mapping {
462 (m.as_mut_ptr(), m.len())
463 } else {
464 (std::ptr::null_mut(), 0)
465 };
466
467 let (second_mapping, second_mapping_count) = if let Some(m) = second_mapping {
468 (m.as_mut_ptr(), m.len())
469 } else {
470 (std::ptr::null_mut(), 0)
471 };
472
473 unsafe {
474 check_status(crate::c_api::mts_labels_union(
475 self.ptr,
476 other.ptr,
477 &mut output,
478 first_mapping,
479 first_mapping_count,
480 second_mapping,
481 second_mapping_count,
482 ))?;
483
484 return Ok(Labels::from_raw(output));
485 }
486 }
487
488 #[inline]
500 pub fn intersection(
501 &self,
502 other: &Labels,
503 first_mapping: Option<&mut [i64]>,
504 second_mapping: Option<&mut [i64]>,
505 ) -> Result<Labels, Error> {
506 let mut output: *const mts_labels_t = std::ptr::null();
507 let (first_mapping, first_mapping_count) = if let Some(m) = first_mapping {
508 (m.as_mut_ptr(), m.len())
509 } else {
510 (std::ptr::null_mut(), 0)
511 };
512
513 let (second_mapping, second_mapping_count) = if let Some(m) = second_mapping {
514 (m.as_mut_ptr(), m.len())
515 } else {
516 (std::ptr::null_mut(), 0)
517 };
518
519 unsafe {
520 check_status(crate::c_api::mts_labels_intersection(
521 self.ptr,
522 other.ptr,
523 &mut output,
524 first_mapping,
525 first_mapping_count,
526 second_mapping,
527 second_mapping_count,
528 ))?;
529
530 return Ok(Labels::from_raw(output));
531 }
532 }
533
534 #[inline]
544 pub fn difference(
545 &self,
546 other: &Labels,
547 mapping: Option<&mut [i64]>,
548 ) -> Result<Labels, Error> {
549 let mut output: *const mts_labels_t = std::ptr::null();
550 let (mapping, mapping_count) = if let Some(m) = mapping {
551 (m.as_mut_ptr(), m.len())
552 } else {
553 (std::ptr::null_mut(), 0)
554 };
555
556 unsafe {
557 check_status(crate::c_api::mts_labels_difference(
558 self.ptr,
559 other.ptr,
560 &mut output,
561 mapping,
562 mapping_count,
563 ))?;
564
565 return Ok(Labels::from_raw(output));
566 }
567 }
568
569 #[allow(clippy::cast_possible_truncation)]
577 pub fn select(&self, selection: &Labels) -> Result<Vec<usize>, Error> {
578 let mut selected = vec![0; self.count()];
579 let mut selected_count = selected.len();
580
581 unsafe {
582 check_status(crate::c_api::mts_labels_select(
583 self.as_mts_labels_t(),
584 selection.as_mts_labels_t(),
585 selected.as_mut_ptr(),
586 &mut selected_count
587 ))?;
588 }
589
590 selected.resize(selected_count, 0);
591
592 return Ok(selected.into_iter().map(|s| s as usize).collect());
593 }
594
595 #[inline]
597 pub fn iter(&self) -> LabelsIter<'_> {
598 return LabelsIter {
599 ptr: self.values_cpu().as_ptr(),
600 cur: 0,
601 len: self.count(),
602 chunk_len: self.size(),
603 phantom: std::marker::PhantomData,
604 };
605 }
606
607 #[cfg(feature = "rayon")]
609 #[inline]
610 pub fn par_iter(&self) -> LabelsParIter<'_> {
611 use rayon::prelude::*;
612 return LabelsParIter {
613 chunks: self.values_cpu().par_chunks_exact(self.size())
614 };
615 }
616
617 #[inline]
619 pub fn iter_fixed_size<const N: usize>(&self) -> LabelsFixedSizeIter<'_, N> {
620 assert!(N == self.size(),
621 "wrong label size in `iter_fixed_size`: the entries contains {} element \
622 but this function was called with size of {}",
623 self.size(), N
624 );
625
626 return LabelsFixedSizeIter {
627 values: self.values_cpu()
628 };
629 }
630}
631
632impl std::cmp::PartialEq<Labels> for Labels {
633 #[inline]
634 fn eq(&self, other: &Labels) -> bool {
635 if self.names() != other.names() {
636 return false;
637 }
638
639 if self.count() != other.count() {
640 return false;
641 }
642
643 if self.device() != other.device() {
644 return false;
645 }
646
647 if self.device().device_type == dlpk::DLDeviceType::kDLExtDev {
648 return true;
652 } else {
653 return self.values_cpu() == other.values_cpu();
654 }
655 }
656}
657
658impl std::ops::Index<usize> for Labels {
659 type Output = [LabelValue];
660
661 #[inline]
662 fn index(&self, i: usize) -> &[LabelValue] {
663 let start = i * self.size();
664 let stop = (i + 1) * self.size();
665 &self.values_cpu()[start..stop]
666 }
667}
668
669pub struct LabelsIter<'a> {
671 ptr: *const LabelValue,
673 cur: usize,
675 len: usize,
677 chunk_len: usize,
679 phantom: std::marker::PhantomData<&'a LabelValue>,
680}
681
682impl<'a> Iterator for LabelsIter<'a> {
683 type Item = &'a [LabelValue];
684
685 #[inline]
686 fn next(&mut self) -> Option<Self::Item> {
687 if self.cur < self.len {
688 unsafe {
689 let data = self.ptr.add(self.cur * self.chunk_len);
691 self.cur += 1;
692 Some(std::slice::from_raw_parts(data, self.chunk_len))
694 }
695 } else {
696 None
697 }
698 }
699
700 #[inline]
701 fn size_hint(&self) -> (usize, Option<usize>) {
702 let remaining = self.len - self.cur;
703 return (remaining, Some(remaining));
704 }
705}
706
707impl ExactSizeIterator for LabelsIter<'_> {
708 #[inline]
709 fn len(&self) -> usize {
710 self.len
711 }
712}
713
714impl FusedIterator for LabelsIter<'_> {}
715
716impl<'a> IntoIterator for &'a Labels {
717 type IntoIter = LabelsIter<'a>;
718 type Item = &'a [LabelValue];
719
720 #[inline]
721 fn into_iter(self) -> Self::IntoIter {
722 self.iter()
723 }
724}
725
/// Parallel iterator over the entries of [`Labels`], created by
/// `Labels::par_iter`; each item is one entry as a slice of values.
#[cfg(feature = "rayon")]
#[derive(Debug, Clone)]
pub struct LabelsParIter<'a> {
    chunks: rayon::slice::ChunksExact<'a, LabelValue>,
}

#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for LabelsParIter<'a> {
    type Item = &'a [LabelValue];

    #[inline]
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        self.chunks.drive_unindexed(consumer)
    }

    /// Forward the length known by the underlying chunks, so that consumers
    /// such as `collect` can take their exact-length fast paths. This stays
    /// consistent with `drive_unindexed`, which drives the same chunks.
    #[inline]
    fn opt_len(&self) -> Option<usize> {
        rayon::iter::ParallelIterator::opt_len(&self.chunks)
    }
}

#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for LabelsParIter<'_> {
    #[inline]
    fn len(&self) -> usize {
        self.chunks.len()
    }

    #[inline]
    fn drive<C: rayon::iter::plumbing::Consumer<Self::Item>>(self, consumer: C) -> C::Result {
        self.chunks.drive(consumer)
    }

    #[inline]
    fn with_producer<CB: rayon::iter::plumbing::ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
        self.chunks.with_producer(callback)
    }
}
762
763#[derive(Debug, Clone)]
765pub struct LabelsFixedSizeIter<'a, const N: usize> {
766 values: &'a [LabelValue],
767}
768
769impl<'a, const N: usize> Iterator for LabelsFixedSizeIter<'a, N> {
770 type Item = &'a [LabelValue; N];
771
772 #[inline]
773 fn next(&mut self) -> Option<Self::Item> {
774 if self.values.is_empty() {
775 return None
776 }
777
778 let (value, rest) = self.values.split_at(N);
779 self.values = rest;
780 return Some(value.try_into().expect("wrong size in FixedSizeIter::next"));
781 }
782
783 fn size_hint(&self) -> (usize, Option<usize>) {
784 (self.len(), Some(self.len()))
785 }
786}
787
788impl<const N: usize> ExactSizeIterator for LabelsFixedSizeIter<'_, N> {
789 #[inline]
790 fn len(&self) -> usize {
791 self.values.len() / N
792 }
793}
794
795
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn labels() {
        let labels = Labels::new(["foo", "bar"], vec![[2, 3], [1, 243], [-4, -2413]]);
        assert_eq!(labels.names(), &["foo", "bar"]);
        assert_eq!(labels.size(), 2);
        assert_eq!(labels.count(), 3);
        assert!(!labels.is_empty());

        assert_eq!(labels[0], [2, 3]);
        assert_eq!(labels[1], [1, 243]);
        assert_eq!(labels[2], [-4, -2413]);

        let labels = Labels::new(&[] as &[&str], Vec::<[i32; 0]>::new());
        assert_eq!(labels.size(), 0);
        assert_eq!(labels.count(), 0);

        let labels = Labels::new_assume_unique(["foo", "bar"], vec![[2, 3], [1, 243]]);
        assert_eq!(labels.names(), &["foo", "bar"]);
        assert_eq!(labels.size(), 2);
        assert_eq!(labels.count(), 2);
    }

    #[test]
    fn direct_construct() {
        let labels = Labels::new(
            ["foo", "bar"],
            vec![[2, 3], [1, 243], [-4, -2413]],
        );

        assert_eq!(labels.names(), &["foo", "bar"]);
        assert_eq!(labels.size(), 2);
        assert_eq!(labels.count(), 3);

        assert_eq!(labels[0], [2, 3]);
        assert_eq!(labels[1], [1, 243]);
        assert_eq!(labels[2], [-4, -2413]);
    }

    #[test]
    fn iter() {
        let labels = Labels::new(["foo", "bar"], vec![[2, 3], [1, 2], [4, 3]]);
        let mut iter = labels.iter();
        assert_eq!(iter.len(), 3);

        assert_eq!(iter.next().unwrap(), &[2, 3]);
        assert_eq!(iter.next().unwrap(), &[1, 2]);
        assert_eq!(iter.next().unwrap(), &[4, 3]);
        assert_eq!(iter.next(), None);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn par_iter() {
        use rayon::iter::IndexedParallelIterator;

        let labels = Labels::new(["foo", "bar"], vec![[2, 3], [1, 2], [4, 3]]);
        let iter = labels.par_iter();
        assert_eq!(iter.len(), 3);

        let mut values = Vec::new();
        iter.collect_into_vec(&mut values);

        assert_eq!(values, [&[2, 3], &[1, 2], &[4, 3]]);
    }

    #[test]
    fn iter_fixed_size() {
        let labels = Labels::new(["foo", "bar"], vec![[1, 2], [2, 3]]);

        for (i, [a, b]) in labels.iter_fixed_size().enumerate() {
            assert_eq!(a.usize(), 1 + i);
            assert_eq!(b.usize(), 2 + i);
        }
    }

    #[test]
    #[should_panic(expected = "wrong label size in `iter_fixed_size`: the entries contains 2 element but this function was called with size of 3")]
    fn iter_fixed_size_wrong_size() {
        let labels = Labels::new(["foo", "bar"], Vec::<[i32; 2]>::new());

        for [_, _, _] in labels.iter_fixed_size() {}
    }

    #[test]
    #[should_panic(expected = "invalid parameter: '33 bar' is not a valid label name")]
    fn invalid_label_name() {
        Labels::new(["foo", "33 bar"], Vec::<[i32; 2]>::new());
    }

    #[test]
    #[should_panic(expected = "invalid labels: the same name is used multiple times")]
    fn duplicated_label_name() {
        Labels::new(["foo", "bar", "foo"], Vec::<[i32; 3]>::new());
    }

    #[test]
    #[should_panic(expected = "can not have the same label entry multiple times: [0, 1] is already present")]
    fn duplicated_label_entry() {
        Labels::new(["foo", "bar"], vec![[0, 1], [0, 1]]);
    }

    #[test]
    fn single_label() {
        let labels = Labels::single();
        assert_eq!(labels.names(), &["_"]);
        assert_eq!(labels.size(), 1);
        assert_eq!(labels.count(), 1);
    }

    #[test]
    fn empty_label() {
        let labels = Labels::empty(vec!["foo", "bar"]);

        assert!(labels.is_empty());
        assert_eq!(labels.count(), 0);
        assert_eq!(labels.size(), 2);
    }

    #[test]
    fn position() {
        let labels = Labels::new(["foo", "bar"], vec![[1, 2], [2, 3]]);

        assert!(labels.contains(&[LabelValue::new(1), LabelValue::new(2)]));
        assert_eq!(labels.position(&[LabelValue::new(1), LabelValue::new(2)]), Some(0));

        assert!(labels.contains(&[LabelValue::new(2), LabelValue::new(3)]));
        assert_eq!(labels.position(&[LabelValue::new(2), LabelValue::new(3)]), Some(1));

        assert!(!labels.contains(&[LabelValue::new(3), LabelValue::new(3)]));
        assert_eq!(labels.position(&[LabelValue::new(3), LabelValue::new(3)]), None);
    }

    #[test]
    fn indexing() {
        let labels = Labels::new(
            ["foo", "bar"],
            vec![[2, 3], [1, 243], [-4, -2413]],
        );

        assert_eq!(labels[1], [1, 243]);
        assert_eq!(labels[2], [-4, -2413]);
    }

    #[test]
    fn debug() {
        let labels = Labels::new(
            ["foo", "bar"],
            vec![[2, 3], [1, 243], [-4, -2413]],
        );
        let output = format!("{:?}", labels);

        // Compare token by token, so that the check does not depend on the
        // exact column padding chosen by `pretty_print_labels`.
        let header = format!("Labels @ {:p} {{", labels.ptr);
        let expected = format!("{} foo, bar 2 3 1 243 -4 -2413 }}", header);
        assert_eq!(
            output.split_whitespace().collect::<Vec<_>>(),
            expected.split_whitespace().collect::<Vec<_>>(),
        );

        // header, names, three entries, closing brace
        assert_eq!(output.lines().count(), 6);
        assert!(output.starts_with(&header));
        assert!(output.ends_with("}\n"));
    }

    #[test]
    fn union() {
        let first = Labels::new(["aa", "bb"], [[0, 1], [1, 2]]);
        let second = Labels::new(["aa", "bb"], [[2, 3], [1, 2], [4, 5]]);

        let mut first_mapping = vec![0; first.count()];
        let mut second_mapping = vec![0; second.count()];
        let union = first.union(&second, Some(&mut first_mapping), Some(&mut second_mapping)).unwrap();

        assert_eq!(union.names(), ["aa", "bb"]);
        assert_eq!(union.values_cpu(), [0, 1, 1, 2, 2, 3, 4, 5]);

        assert_eq!(first_mapping, [0, 1]);
        assert_eq!(second_mapping, [2, 1, 3]);
    }

    #[test]
    fn intersection() {
        let first = Labels::new(["aa", "bb"], [[0, 1], [1, 2]]);
        let second = Labels::new(["aa", "bb"], [[2, 3], [1, 2], [4, 5]]);

        let mut first_mapping = vec![0_i64; first.count()];
        let mut second_mapping = vec![0_i64; second.count()];
        let intersection = first.intersection(&second, Some(&mut first_mapping), Some(&mut second_mapping)).unwrap();

        assert_eq!(intersection.names(), ["aa", "bb"]);
        assert_eq!(intersection.values_cpu(), [1, 2]);

        assert_eq!(first_mapping, [-1, 0]);
        assert_eq!(second_mapping, [-1, 0, -1]);
    }

    #[test]
    fn difference() {
        let first = Labels::new(["aa", "bb"], [[0, 1], [1, 2]]);
        let second = Labels::new(["aa", "bb"], [[2, 3], [1, 2], [4, 5]]);

        let mut mapping = vec![0_i64; first.count()];
        let difference = first.difference(&second, Some(&mut mapping)).unwrap();

        assert_eq!(difference.names(), ["aa", "bb"]);
        assert_eq!(difference.values_cpu(), [0, 1]);

        assert_eq!(mapping, [0, -1]);
    }

    #[test]
    fn selection() {
        let labels = Labels::new(["aa", "bb"], [[1, 1], [1, 2], [3, 2], [2, 1]]);

        // selection on a subset of the dimensions
        let selection = Labels::new(["aa"], [[1], [2], [5]]);
        let selected = labels.select(&selection).unwrap();
        assert_eq!(selected, [0, 1, 3]);

        // selection on all dimensions
        let selection = Labels::new(["aa", "bb"], [[1, 1], [2, 1], [5, 1], [1, 2]]);
        let selected = labels.select(&selection).unwrap();
        assert_eq!(selected, [0, 3, 1]);

        // empty selection
        let selection = Labels::empty(vec!["aa"]);
        let selected = labels.select(&selection).unwrap();
        assert_eq!(selected, []);

        // selection using a dimension that does not exist
        let selection = Labels::empty(vec!["aaa"]);
        let err = labels.select(&selection).unwrap_err();
        assert_eq!(err.message,
            "invalid parameter: 'aaa' in selection is not part of these Labels"
        );
    }

    #[test]
    fn labels_into_raw() {
        let original = Labels::new(["foo", "bar"], [[1, 2], [3, 4], [5, 6]]);
        let raw = Labels::into_raw(original);

        let recovered = unsafe { Labels::from_raw(raw) };
        assert_eq!(recovered, Labels::new(["foo", "bar"], [[1, 2], [3, 4], [5, 6]]));
    }
}