1use std::sync::OnceLock;
2use std::ffi::{CStr, CString};
3use std::collections::BTreeSet;
4use std::iter::FusedIterator;
5
6use crate::MtsArray;
7use crate::c_api::mts_labels_t;
8use crate::errors::check_ptr;
9use crate::errors::{Error, check_status};
10
/// A single integer value stored inside a label entry.
///
/// This is a thin wrapper around an `i32`. Thanks to `#[repr(transparent)]`,
/// it has the same in-memory representation as the wrapped integer.
///
/// Values can be compared directly with `i32`, and created from `i32`,
/// `u32`, `usize` and `isize`. The conversions from types wider than `i32`
/// panic when the value does not fit in an `i32`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct LabelValue(i32);

impl PartialEq<i32> for LabelValue {
    #[inline]
    fn eq(&self, other: &i32) -> bool {
        self.0 == *other
    }
}

impl PartialEq<LabelValue> for i32 {
    #[inline]
    fn eq(&self, other: &LabelValue) -> bool {
        *self == other.0
    }
}

impl std::fmt::Debug for LabelValue {
    /// Print the wrapped integer, without any decoration.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl std::fmt::Display for LabelValue {
    /// Print the wrapped integer, without any decoration.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

// the casts below are guarded by the assertions preceding them
#[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
impl From<u32> for LabelValue {
    /// Convert an `u32` to a `LabelValue`.
    ///
    /// # Panics
    ///
    /// Panics if `value` is larger than `i32::MAX`.
    #[inline]
    fn from(value: u32) -> LabelValue {
        // `i32::MAX` itself is representable, so the bound is inclusive
        assert!(value <= i32::MAX as u32);
        LabelValue(value as i32)
    }
}

impl From<i32> for LabelValue {
    #[inline]
    fn from(value: i32) -> LabelValue {
        LabelValue(value)
    }
}

// the casts below are guarded by the assertions preceding them
#[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
impl From<usize> for LabelValue {
    /// Convert an `usize` to a `LabelValue`.
    ///
    /// # Panics
    ///
    /// Panics if `value` is larger than `i32::MAX`.
    #[inline]
    fn from(value: usize) -> LabelValue {
        // `i32::MAX` itself is representable, so the bound is inclusive
        assert!(value <= i32::MAX as usize);
        LabelValue(value as i32)
    }
}

// the cast below is guarded by the assertion preceding it
#[allow(clippy::cast_possible_truncation)]
impl From<isize> for LabelValue {
    /// Convert an `isize` to a `LabelValue`.
    ///
    /// # Panics
    ///
    /// Panics if `value` is outside of the `i32::MIN..=i32::MAX` range.
    #[inline]
    fn from(value: isize) -> LabelValue {
        // both `i32::MIN` and `i32::MAX` are representable: inclusive bounds
        assert!(value <= i32::MAX as isize && value >= i32::MIN as isize);
        LabelValue(value as i32)
    }
}

impl LabelValue {
    /// Create a `LabelValue` wrapping the given `value`.
    #[inline]
    pub fn new(value: i32) -> LabelValue {
        LabelValue(value)
    }

    /// Get the value as an `usize`.
    ///
    /// In debug builds this asserts that the value is not negative; in
    /// release builds a negative value is converted with a plain `as` cast.
    #[inline]
    #[allow(clippy::cast_sign_loss)]
    pub fn usize(self) -> usize {
        debug_assert!(self.0 >= 0);
        self.0 as usize
    }

    /// Get the value as an `isize`.
    #[inline]
    pub fn isize(self) -> isize {
        self.0 as isize
    }

    /// Get the value as an `i32`.
    #[inline]
    pub fn i32(self) -> i32 {
        self.0
    }
}
106
107pub struct Labels {
119 pub(crate) ptr: *const mts_labels_t,
120 values_cpu_ptr: OnceLock<*const LabelValue>,
121 count: OnceLock<usize>,
122 size: OnceLock<usize>,
123}
124
125unsafe impl Send for Labels {}
128unsafe impl Sync for Labels {}
131
132impl std::fmt::Debug for Labels {
133 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134 pretty_print_labels(self, "", f)
135 }
136}
137
138pub(crate) fn pretty_print_labels(
140 labels: &Labels,
141 offset: &str,
142 f: &mut std::fmt::Formatter<'_>
143) -> std::fmt::Result {
144 let names = labels.names();
145
146 writeln!(f, "Labels @ {:p} {{", labels.ptr)?;
147 writeln!(f, "{} {}", offset, names.join(", "))?;
148
149 let widths = names.iter().map(|s| s.len()).collect::<Vec<_>>();
150 for values in labels {
151 write!(f, "{} ", offset)?;
152 for (value, width) in values.iter().zip(&widths) {
153 write!(f, "{:^width$} ", value.isize(), width=width)?;
154 }
155 writeln!(f)?;
156 }
157
158 writeln!(f, "{}}}", offset)
159}
160
161impl Clone for Labels {
162 #[inline]
163 fn clone(&self) -> Self {
164 let ptr = unsafe { crate::c_api::mts_labels_clone(self.ptr) };
165 assert!(!ptr.is_null(), "failed to clone Labels");
166 Labels {
167 ptr,
168 values_cpu_ptr: self.values_cpu_ptr.clone(),
169 count: self.count.clone(),
170 size: self.size.clone(),
171 }
172 }
173}
174
175impl std::ops::Drop for Labels {
176 #[allow(unused_must_use)]
177 fn drop(&mut self) {
178 unsafe {
179 crate::c_api::mts_labels_free(self.ptr);
180 }
181 }
182}
183
184impl Labels {
185 #[inline]
195 pub fn new<T, const N: usize>(names: [&str; N], values: &[[T; N]]) -> Labels
196 where T: Copy + Into<LabelValue>
197 {
198 let mut builder = LabelsBuilder::new(names.to_vec());
199 for entry in values {
200 builder.add(entry);
201 }
202 return builder.finish();
203 }
204
205 pub fn as_mts_labels_t(&self) -> *const mts_labels_t {
207 self.ptr
208 }
209
210 #[inline]
220 pub unsafe fn from_raw(ptr: *const mts_labels_t) -> Labels {
221 assert!(!ptr.is_null(), "expected mts_labels_t pointer to not be NULL");
222 Labels {
223 ptr,
224 values_cpu_ptr: OnceLock::new(),
225 size: OnceLock::new(),
226 count: OnceLock::new(),
227 }
228 }
229
230 #[inline]
232 pub fn empty(names: Vec<&str>) -> Labels {
233 return LabelsBuilder::new(names).finish()
234 }
235
236 #[inline]
239 pub fn single() -> Labels {
240 let mut builder = LabelsBuilder::new(vec!["_"]);
241 builder.add(&[0]);
242 return builder.finish();
243 }
244
245 pub fn load(path: impl AsRef<std::path::Path>) -> Result<Labels, Error> {
249 return crate::io::load_labels(path);
250 }
251
252 pub fn load_buffer(buffer: &[u8]) -> Result<Labels, Error> {
256 return crate::io::load_labels_buffer(buffer);
257 }
258
259 pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
263 return crate::io::save_labels(path, self);
264 }
265
266 pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
270 return crate::io::save_labels_buffer(self, buffer);
271 }
272
273 #[inline]
275 pub fn size(&self) -> usize {
276 *self.size.get_or_init(|| self.names().len())
277 }
278
279 #[inline]
281 pub fn names(&self) -> Vec<&str> {
282 let mut names_ptr = std::ptr::null();
283 let mut size = 0;
284 unsafe {
285 check_status(crate::c_api::mts_labels_dimensions(self.ptr, &mut names_ptr, &mut size))
286 .expect("failed to get labels dimensions");
287 }
288
289 if size == 0 {
290 return Vec::new();
291 }
292
293 unsafe {
294 let names = std::slice::from_raw_parts(names_ptr, size);
295 return names.iter()
296 .map(|&ptr| CStr::from_ptr(ptr).to_str().expect("invalid UTF8"))
297 .collect();
298 }
299 }
300
301 pub fn values(&self) -> MtsArray {
303 let mut array = crate::c_api::mts_array_t::null();
304 unsafe {
305 check_status(crate::c_api::mts_labels_values(
306 self.ptr, &mut array,
307 )).expect("failed to get labels values array");
308 }
309
310 return MtsArray::from_raw(array);
311 }
312
313 pub fn values_cpu(&self) -> &[LabelValue] {
316 let values_cpu_ptr = self.values_cpu_ptr.get_or_init(|| {
317 let mut values_cpu_ptr = std::ptr::null();
318 let mut count = 0;
319 let mut size = 0;
320
321 unsafe {
322 check_status(crate::c_api::mts_labels_values_cpu(
323 self.ptr,
324 &mut values_cpu_ptr,
325 &mut count,
326 &mut size
327 )).expect("failed to get CPU values for Labels");
328 }
329
330 debug_assert_eq!(count, self.count());
331 debug_assert_eq!(size, self.size());
332
333 values_cpu_ptr.cast()
334 });
335
336 unsafe {
337 let count = self.count();
338 let size = self.size();
339 std::slice::from_raw_parts(*values_cpu_ptr, count * size)
340 }
341 }
342
343 #[inline]
345 pub fn device(&self) -> dlpk::DLDevice {
346 let array = self.values();
347 return array.device().expect("failed to get the array device");
348 }
349
350 #[inline]
352 pub fn count(&self) -> usize {
353 return *self.count.get_or_init(|| self.values().shape().expect("failed to get the array shape")[0]);
354 }
355
356 #[inline]
358 pub fn is_empty(&self) -> bool {
359 self.count() == 0
360 }
361
362 #[inline]
364 pub fn contains(&self, label: &[LabelValue]) -> bool {
365 return self.position(label).is_some();
366 }
367
368 #[inline]
371 pub fn position(&self, value: &[LabelValue]) -> Option<usize> {
372 assert!(value.len() == self.size(), "invalid size of index in Labels::position");
373
374 let mut result = 0;
375 unsafe {
376 check_status(crate::c_api::mts_labels_position(
377 self.ptr,
378 value.as_ptr().cast(),
379 value.len(),
380 &mut result,
381 )).expect("failed to check label position");
382 }
383
384 return result.try_into().ok();
385 }
386
387 #[inline]
397 pub fn union(
398 &self,
399 other: &Labels,
400 first_mapping: Option<&mut [i64]>,
401 second_mapping: Option<&mut [i64]>,
402 ) -> Result<Labels, Error> {
403 let mut output: *const mts_labels_t = std::ptr::null();
404 let (first_mapping, first_mapping_count) = if let Some(m) = first_mapping {
405 (m.as_mut_ptr(), m.len())
406 } else {
407 (std::ptr::null_mut(), 0)
408 };
409
410 let (second_mapping, second_mapping_count) = if let Some(m) = second_mapping {
411 (m.as_mut_ptr(), m.len())
412 } else {
413 (std::ptr::null_mut(), 0)
414 };
415
416 unsafe {
417 check_status(crate::c_api::mts_labels_union(
418 self.ptr,
419 other.ptr,
420 &mut output,
421 first_mapping,
422 first_mapping_count,
423 second_mapping,
424 second_mapping_count,
425 ))?;
426
427 return Ok(Labels::from_raw(output));
428 }
429 }
430
431 #[inline]
443 pub fn intersection(
444 &self,
445 other: &Labels,
446 first_mapping: Option<&mut [i64]>,
447 second_mapping: Option<&mut [i64]>,
448 ) -> Result<Labels, Error> {
449 let mut output: *const mts_labels_t = std::ptr::null();
450 let (first_mapping, first_mapping_count) = if let Some(m) = first_mapping {
451 (m.as_mut_ptr(), m.len())
452 } else {
453 (std::ptr::null_mut(), 0)
454 };
455
456 let (second_mapping, second_mapping_count) = if let Some(m) = second_mapping {
457 (m.as_mut_ptr(), m.len())
458 } else {
459 (std::ptr::null_mut(), 0)
460 };
461
462 unsafe {
463 check_status(crate::c_api::mts_labels_intersection(
464 self.ptr,
465 other.ptr,
466 &mut output,
467 first_mapping,
468 first_mapping_count,
469 second_mapping,
470 second_mapping_count,
471 ))?;
472
473 return Ok(Labels::from_raw(output));
474 }
475 }
476
477 #[inline]
487 pub fn difference(
488 &self,
489 other: &Labels,
490 mapping: Option<&mut [i64]>,
491 ) -> Result<Labels, Error> {
492 let mut output: *const mts_labels_t = std::ptr::null();
493 let (mapping, mapping_count) = if let Some(m) = mapping {
494 (m.as_mut_ptr(), m.len())
495 } else {
496 (std::ptr::null_mut(), 0)
497 };
498
499 unsafe {
500 check_status(crate::c_api::mts_labels_difference(
501 self.ptr,
502 other.ptr,
503 &mut output,
504 mapping,
505 mapping_count,
506 ))?;
507
508 return Ok(Labels::from_raw(output));
509 }
510 }
511
512 #[allow(clippy::cast_possible_truncation)]
520 pub fn select(&self, selection: &Labels) -> Result<Vec<usize>, Error> {
521 let mut selected = vec![0; self.count()];
522 let mut selected_count = selected.len();
523
524 unsafe {
525 check_status(crate::c_api::mts_labels_select(
526 self.as_mts_labels_t(),
527 selection.as_mts_labels_t(),
528 selected.as_mut_ptr(),
529 &mut selected_count
530 ))?;
531 }
532
533 selected.resize(selected_count, 0);
534
535 return Ok(selected.into_iter().map(|s| s as usize).collect());
536 }
537
538 #[inline]
540 pub fn iter(&self) -> LabelsIter<'_> {
541 return LabelsIter {
542 ptr: self.values_cpu().as_ptr(),
543 cur: 0,
544 len: self.count(),
545 chunk_len: self.size(),
546 phantom: std::marker::PhantomData,
547 };
548 }
549
550 #[cfg(feature = "rayon")]
552 #[inline]
553 pub fn par_iter(&self) -> LabelsParIter<'_> {
554 use rayon::prelude::*;
555 return LabelsParIter {
556 chunks: self.values_cpu().par_chunks_exact(self.size())
557 };
558 }
559
560 #[inline]
562 pub fn iter_fixed_size<const N: usize>(&self) -> LabelsFixedSizeIter<'_, N> {
563 assert!(N == self.size(),
564 "wrong label size in `iter_fixed_size`: the entries contains {} element \
565 but this function was called with size of {}",
566 self.size(), N
567 );
568
569 return LabelsFixedSizeIter {
570 values: self.values_cpu()
571 };
572 }
573}
574
575impl std::cmp::PartialEq<Labels> for Labels {
576 #[inline]
577 fn eq(&self, other: &Labels) -> bool {
578 if self.names() != other.names() {
579 return false;
580 }
581
582 if self.count() != other.count() {
583 return false;
584 }
585
586 if self.device() != other.device() {
587 return false;
588 }
589
590 if self.device().device_type == dlpk::DLDeviceType::kDLExtDev {
591 return true;
595 } else {
596 return self.values_cpu() == other.values_cpu();
597 }
598 }
599}
600
601impl std::ops::Index<usize> for Labels {
602 type Output = [LabelValue];
603
604 #[inline]
605 fn index(&self, i: usize) -> &[LabelValue] {
606 let start = i * self.size();
607 let stop = (i + 1) * self.size();
608 &self.values_cpu()[start..stop]
609 }
610}
611
612pub struct LabelsIter<'a> {
614 ptr: *const LabelValue,
616 cur: usize,
618 len: usize,
620 chunk_len: usize,
622 phantom: std::marker::PhantomData<&'a LabelValue>,
623}
624
625impl<'a> Iterator for LabelsIter<'a> {
626 type Item = &'a [LabelValue];
627
628 #[inline]
629 fn next(&mut self) -> Option<Self::Item> {
630 if self.cur < self.len {
631 unsafe {
632 let data = self.ptr.add(self.cur * self.chunk_len);
634 self.cur += 1;
635 Some(std::slice::from_raw_parts(data, self.chunk_len))
637 }
638 } else {
639 None
640 }
641 }
642
643 #[inline]
644 fn size_hint(&self) -> (usize, Option<usize>) {
645 let remaining = self.len - self.cur;
646 return (remaining, Some(remaining));
647 }
648}
649
650impl ExactSizeIterator for LabelsIter<'_> {
651 #[inline]
652 fn len(&self) -> usize {
653 self.len
654 }
655}
656
657impl FusedIterator for LabelsIter<'_> {}
658
659impl<'a> IntoIterator for &'a Labels {
660 type IntoIter = LabelsIter<'a>;
661 type Item = &'a [LabelValue];
662
663 #[inline]
664 fn into_iter(self) -> Self::IntoIter {
665 self.iter()
666 }
667}
668
/// Parallel iterator over the entries of labels, yielding each entry as a
/// slice. All the work is forwarded to rayon's `ChunksExact`.
#[cfg(feature = "rayon")]
#[derive(Debug, Clone)]
pub struct LabelsParIter<'a> {
    chunks: rayon::slice::ChunksExact<'a, LabelValue>,
}

#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for LabelsParIter<'a> {
    type Item = &'a [LabelValue];

    #[inline]
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        rayon::iter::ParallelIterator::drive_unindexed(self.chunks, consumer)
    }
}

#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for LabelsParIter<'_> {
    #[inline]
    fn len(&self) -> usize {
        rayon::iter::IndexedParallelIterator::len(&self.chunks)
    }

    #[inline]
    fn drive<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::Consumer<Self::Item>,
    {
        rayon::iter::IndexedParallelIterator::drive(self.chunks, consumer)
    }

    #[inline]
    fn with_producer<CB>(self, callback: CB) -> CB::Output
    where
        CB: rayon::iter::plumbing::ProducerCallback<Self::Item>,
    {
        rayon::iter::IndexedParallelIterator::with_producer(self.chunks, callback)
    }
}
705
706#[derive(Debug, Clone)]
708pub struct LabelsFixedSizeIter<'a, const N: usize> {
709 values: &'a [LabelValue],
710}
711
712impl<'a, const N: usize> Iterator for LabelsFixedSizeIter<'a, N> {
713 type Item = &'a [LabelValue; N];
714
715 #[inline]
716 fn next(&mut self) -> Option<Self::Item> {
717 if self.values.is_empty() {
718 return None
719 }
720
721 let (value, rest) = self.values.split_at(N);
722 self.values = rest;
723 return Some(value.try_into().expect("wrong size in FixedSizeIter::next"));
724 }
725
726 fn size_hint(&self) -> (usize, Option<usize>) {
727 (self.len(), Some(self.len()))
728 }
729}
730
731impl<const N: usize> ExactSizeIterator for LabelsFixedSizeIter<'_, N> {
732 #[inline]
733 fn len(&self) -> usize {
734 self.values.len() / N
735 }
736}
737
738#[derive(Debug, Clone)]
740pub struct LabelsBuilder {
741 names: Vec<String>,
743 values: Vec<i32>,
744}
745
746impl LabelsBuilder {
747 #[inline]
749 pub fn new(names: Vec<&str>) -> LabelsBuilder {
750 let n_unique_names = names.iter().collect::<BTreeSet<_>>().len();
751 assert!(n_unique_names == names.len(), "invalid labels: the same name is used multiple times");
752
753 LabelsBuilder {
754 names: names.into_iter().map(|s| s.into()).collect(),
755 values: Vec::new(),
756 }
757 }
758
759 #[inline]
761 pub fn reserve(&mut self, additional: usize) {
762 self.values.reserve(additional * self.names.len());
763 }
764
765 #[inline]
767 pub fn size(&self) -> usize {
768 self.names.len()
769 }
770
771 #[inline]
776 pub fn add<T>(&mut self, entry: &[T]) where T: Clone + Into<LabelValue> {
777 assert_eq!(
778 self.size(), entry.len(),
779 "wrong size for added label: got {}, but expected {}",
780 entry.len(), self.size()
781 );
782
783 for e in entry {
786 self.values.push(Into::<LabelValue>::into(e.clone()).i32());
787 }
788 }
789
790 fn finish_with(
792 self,
793 creator: unsafe extern "C" fn(
794 *const *const std::os::raw::c_char,
795 usize,
796 crate::c_api::mts_array_t,
797 ) -> *const mts_labels_t
798 ) -> Labels {
799 let mut raw_names = Vec::new();
800 let mut raw_names_ptr = Vec::new();
801
802 for name in &self.names {
803 let name = CString::new(&**name).expect("name contains a NULL byte");
804 raw_names_ptr.push(name.as_ptr());
805 raw_names.push(name);
806 }
807
808 let size = raw_names_ptr.len();
809 let count = if size == 0 {
810 assert!(self.values.is_empty());
811 0
812 } else {
813 self.values.len() / size
814 };
815
816 let array = ndarray::Array::from_shape_vec(vec![count, size], self.values)
818 .expect("shape mismatch when creating labels array");
819 let array: MtsArray = array.into();
820
821 let ptr = unsafe {
822 creator(
823 raw_names_ptr.as_ptr(),
824 size,
825 array.into_raw(),
826 )
827 };
828 check_ptr(ptr).expect("invalid labels");
829
830 unsafe { Labels::from_raw(ptr) }
831 }
832
833 #[inline]
837 pub fn finish(self) -> Labels {
838 self.finish_with(crate::c_api::mts_labels)
839 }
840
841 #[inline]
851 pub fn finish_assume_unique(self) -> Labels {
852 self.finish_with(crate::c_api::mts_labels_assume_unique)
853 }
854}
855
856
#[cfg(test)]
mod tests {
    use super::*;

    /// Labels created with the builder expose their names, sizes and entries
    #[test]
    fn labels() {
        let mut three_entries = LabelsBuilder::new(vec!["foo", "bar"]);
        three_entries.add(&[2, 3]);
        three_entries.add(&[1, 243]);
        three_entries.add(&[-4, -2413]);

        let labels = three_entries.finish();
        assert_eq!(labels.names(), &["foo", "bar"]);
        assert_eq!(labels.size(), 2);
        assert_eq!(labels.count(), 3);
        assert!(!labels.is_empty());

        assert_eq!(labels[0], [2, 3]);
        assert_eq!(labels[1], [1, 243]);
        assert_eq!(labels[2], [-4, -2413]);

        // no dimension and no entry
        let labels = LabelsBuilder::new(vec![]).finish();
        assert_eq!(labels.size(), 0);
        assert_eq!(labels.count(), 0);

        // same builder, different finishing function
        let mut two_entries = LabelsBuilder::new(vec!["foo", "bar"]);
        two_entries.add(&[2, 3]);
        two_entries.add(&[1, 243]);
        let labels = two_entries.finish_assume_unique();
        assert_eq!(labels.names(), &["foo", "bar"]);
        assert_eq!(labels.size(), 2);
        assert_eq!(labels.count(), 2);
    }

    /// `Labels::new` creates the same labels as the builder
    #[test]
    fn direct_construct() {
        let entries = [
            [2, 3],
            [1, 243],
            [-4, -2413],
        ];
        let labels = Labels::new(["foo", "bar"], &entries);

        assert_eq!(labels.names(), &["foo", "bar"]);
        assert_eq!(labels.size(), 2);
        assert_eq!(labels.count(), 3);

        assert_eq!(labels[0], [2, 3]);
        assert_eq!(labels[1], [1, 243]);
        assert_eq!(labels[2], [-4, -2413]);
    }

    /// Sequential iteration yields the entries in order
    #[test]
    fn iter() {
        let mut builder = LabelsBuilder::new(vec!["foo", "bar"]);
        builder.add(&[2, 3]);
        builder.add(&[1, 2]);
        builder.add(&[4, 3]);
        let labels = builder.finish();

        let mut entries = labels.iter();
        assert_eq!(entries.len(), 3);

        assert_eq!(entries.next().unwrap(), &[2, 3]);
        assert_eq!(entries.next().unwrap(), &[1, 2]);
        assert_eq!(entries.next().unwrap(), &[4, 3]);
        assert_eq!(entries.next(), None);
    }

    /// Parallel iteration yields the same entries as sequential iteration
    #[cfg( feature = "rayon")]
    #[test]
    fn par_iter() {
        use rayon::iter::IndexedParallelIterator;

        let mut builder = LabelsBuilder::new(vec!["foo", "bar"]);
        builder.add(&[2, 3]);
        builder.add(&[1, 2]);
        builder.add(&[4, 3]);
        let labels = builder.finish();

        let entries = labels.par_iter();
        assert_eq!(entries.len(), 3);

        let mut collected = Vec::new();
        entries.collect_into_vec(&mut collected);

        assert_eq!(collected, [&[2, 3], &[1, 2], &[4, 3]]);
    }

    /// Fixed-size iteration yields arrays which can be destructured
    #[test]
    fn iter_fixed_size() {
        let mut builder = LabelsBuilder::new(vec!["foo", "bar"]);
        builder.add(&[1, 2]);
        builder.add(&[2, 3]);
        let labels = builder.finish();

        for (index, [first, second]) in labels.iter_fixed_size().enumerate() {
            assert_eq!(first.usize(), 1 + index);
            assert_eq!(second.usize(), 2 + index);
        }
    }

    /// Fixed-size iteration refuses a size different from the labels size
    #[test]
    #[should_panic(expected = "wrong label size in `iter_fixed_size`: the entries contains 2 element but this function was called with size of 3")]
    fn iter_fixed_size_wrong_size() {
        let labels = LabelsBuilder::new(vec!["foo", "bar"]).finish();

        for [_, _, _] in labels.iter_fixed_size() {}
    }

    /// Invalid dimension names are rejected when finishing the builder
    #[test]
    #[should_panic(expected = "'33 bar' is not a valid label name")]
    fn invalid_label_name() {
        LabelsBuilder::new(vec!["foo", "33 bar"]).finish();
    }

    /// Duplicated dimension names are rejected when creating the builder
    #[test]
    #[should_panic(expected = "invalid labels: the same name is used multiple times")]
    fn duplicated_label_name() {
        LabelsBuilder::new(vec!["foo", "bar", "foo"]).finish();
    }

    /// Duplicated entries are rejected when finishing the builder
    #[test]
    #[should_panic(expected = "can not have the same label entry multiple times: [0, 1] is already present")]
    fn duplicated_label_entry() {
        let mut builder = LabelsBuilder::new(vec!["foo", "bar"]);
        builder.add(&[0, 1]);
        builder.add(&[0, 1]);
        builder.finish();
    }

    /// `Labels::single` has one dimension named "_" and one entry
    #[test]
    fn single_label() {
        let single = Labels::single();
        assert_eq!(single.names(), &["_"]);
        assert_eq!(single.size(), 1);
        assert_eq!(single.count(), 1);
    }

    /// Labels without entries keep their dimensions
    #[test]
    fn empty_label() {
        let empty = LabelsBuilder::new(vec!["foo", "bar"]).finish();

        assert!(empty.is_empty());
        assert_eq!(empty.count(), 0);
        assert_eq!(empty.size(), 2);
    }

    /// `position` and `contains` agree on present and missing entries
    #[test]
    fn position() {
        let mut builder = LabelsBuilder::new(vec!["foo", "bar"]);
        builder.add(&[1, 2]);
        builder.add(&[2, 3]);
        let labels = builder.finish();

        let first = [LabelValue::new(1), LabelValue::new(2)];
        assert!(labels.contains(&first));
        assert_eq!(labels.position(&first), Some(0));

        let second = [LabelValue::new(2), LabelValue::new(3)];
        assert!(labels.contains(&second));
        assert_eq!(labels.position(&second), Some(1));

        let missing = [LabelValue::new(3), LabelValue::new(3)];
        assert!(!labels.contains(&missing));
        assert_eq!(labels.position(&missing), None);
    }

    /// Indexing gives access to individual entries
    #[test]
    fn indexing() {
        let entries = [
            [2, 3],
            [1, 243],
            [-4, -2413],
        ];
        let labels = Labels::new(["foo", "bar"], &entries);

        assert_eq!(labels[1], [1, 243]);
        assert_eq!(labels[2], [-4, -2413]);
    }

    /// The debug representation is a table with one line per entry
    #[test]
    fn debug() {
        let entries = [
            [2, 3],
            [1, 243],
            [-4, -2413],
        ];
        let labels = Labels::new(["foo", "bar"], &entries);

        let expected = format!(
            "Labels @ {:p} {{\n foo, bar\n 2 3 \n 1 243 \n -4 -2413 \n}}\n",
            labels.ptr
        );
        assert_eq!(format!("{:?}", labels), expected);
    }

    /// The union contains all entries, the mappings give their positions
    #[test]
    fn union() {
        let first = Labels::new(["aa", "bb"], &[[0, 1], [1, 2]]);
        let second = Labels::new(["aa", "bb"], &[[2, 3], [1, 2], [4, 5]]);

        let mut first_mapping = vec![0; first.count()];
        let mut second_mapping = vec![0; second.count()];
        let result = first.union(&second, Some(&mut first_mapping), Some(&mut second_mapping)).unwrap();

        assert_eq!(result.names(), ["aa", "bb"]);
        assert_eq!(result.values_cpu(), [0, 1, 1, 2, 2, 3, 4, 5]);

        assert_eq!(first_mapping, [0, 1]);
        assert_eq!(second_mapping, [2, 1, 3]);
    }

    /// The intersection contains shared entries, others are mapped to -1
    #[test]
    fn intersection() {
        let first = Labels::new(["aa", "bb"], &[[0, 1], [1, 2]]);
        let second = Labels::new(["aa", "bb"], &[[2, 3], [1, 2], [4, 5]]);

        let mut first_mapping = vec![0_i64; first.count()];
        let mut second_mapping = vec![0_i64; second.count()];
        let result = first.intersection(&second, Some(&mut first_mapping), Some(&mut second_mapping)).unwrap();

        assert_eq!(result.names(), ["aa", "bb"]);
        assert_eq!(result.values_cpu(), [1, 2]);

        assert_eq!(first_mapping, [-1, 0]);
        assert_eq!(second_mapping, [-1, 0, -1]);
    }

    /// The difference contains entries only in `first`, others are mapped
    /// to -1
    #[test]
    fn difference() {
        let first = Labels::new(["aa", "bb"], &[[0, 1], [1, 2]]);
        let second = Labels::new(["aa", "bb"], &[[2, 3], [1, 2], [4, 5]]);

        let mut mapping = vec![0_i64; first.count()];
        let result = first.difference(&second, Some(&mut mapping)).unwrap();

        assert_eq!(result.names(), ["aa", "bb"]);
        assert_eq!(result.values_cpu(), [0, 1]);

        assert_eq!(mapping, [0, -1]);
    }

    /// Selections with a subset of dimensions, all dimensions, no entries,
    /// and an unknown dimension
    #[test]
    fn selection() {
        let labels = Labels::new(["aa", "bb"], &[[1, 1], [1, 2], [3, 2], [2, 1]]);

        // selection using a subset of the dimensions
        let subset = Labels::new(["aa"], &[[1], [2], [5]]);
        let selected = labels.select(&subset).unwrap();
        assert_eq!(selected, [0, 1, 3]);

        // selection using all the dimensions
        let all_dimensions = Labels::new(["aa", "bb"], &[[1, 1], [2, 1], [5, 1], [1, 2]]);
        let selected = labels.select(&all_dimensions).unwrap();
        assert_eq!(selected, [0, 3, 1]);

        // selection without any entry
        let no_entries = Labels::empty(vec!["aa"]);
        let selected = labels.select(&no_entries).unwrap();
        assert_eq!(selected, []);

        // selection using a dimension which is not in the labels
        let unknown_dimension = Labels::empty(vec!["aaa"]);
        let err = labels.select(&unknown_dimension).unwrap_err();
        assert_eq!(err.message,
            "invalid parameter: 'aaa' in selection is not part of these Labels"
        );
    }
}