1use std::ffi::{CStr, CString};
2use std::iter::FusedIterator;
3
4use crate::block::TensorBlockRefMut;
5use crate::c_api::{mts_tensormap_t, mts_labels_t};
6
7use crate::errors::{check_status, check_ptr};
8use crate::{Error, TensorBlock, TensorBlockRef, Labels, LabelValue};
9
10pub struct TensorMap {
20 pub(crate) ptr: *mut mts_tensormap_t,
21 keys: Labels,
23}
24
25unsafe impl Send for TensorMap {}
27unsafe impl Sync for TensorMap {}
29
30impl std::fmt::Debug for TensorMap {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 use crate::labels::pretty_print_labels;
33 writeln!(f, "Tensormap @ {:p} {{", self.ptr)?;
34
35 write!(f, " keys: ")?;
36 pretty_print_labels(self.keys(), " ", f)?;
37 writeln!(f, "}}")
38 }
39}
40
41impl std::ops::Drop for TensorMap {
42 #[allow(unused_must_use)]
43 fn drop(&mut self) {
44 unsafe {
45 crate::c_api::mts_tensormap_free(self.ptr);
46 }
47 }
48}
49
50impl TensorMap {
51 #[allow(clippy::needless_pass_by_value)]
57 #[inline]
58 pub fn new(keys: Labels, mut blocks: Vec<TensorBlock>) -> Result<TensorMap, Error> {
59 let ptr = unsafe {
60 crate::c_api::mts_tensormap(
61 keys.as_mts_labels_t(),
62 blocks.as_mut_ptr().cast::<*mut crate::c_api::mts_block_t>(),
66 blocks.len()
67 )
68 };
69
70 for block in blocks {
71 std::mem::forget(block);
74 }
75
76 check_ptr(ptr)?;
77
78 return Ok(unsafe { TensorMap::from_raw(ptr) });
79 }
80
81 pub unsafe fn from_raw(ptr: *mut mts_tensormap_t) -> TensorMap {
91 assert!(!ptr.is_null());
92
93 let mut keys = mts_labels_t::null();
94 check_status(crate::c_api::mts_tensormap_keys(
95 ptr,
96 &mut keys
97 )).expect("failed to get the keys");
98
99 let keys = Labels::from_raw(keys);
100
101 return TensorMap {
102 ptr,
103 keys
104 };
105 }
106
107 pub fn into_raw(mut map: TensorMap) -> *mut mts_tensormap_t {
113 let ptr = map.ptr;
114 map.ptr = std::ptr::null_mut();
115 return ptr;
116 }
117
118 #[inline]
123 pub fn try_clone(&self) -> Result<TensorMap, Error> {
124 let ptr = unsafe {
125 crate::c_api::mts_tensormap_copy(self.ptr)
126 };
127 crate::errors::check_ptr(ptr)?;
128
129 return Ok(unsafe { TensorMap::from_raw(ptr) });
130 }
131
132 pub fn load(path: impl AsRef<std::path::Path>) -> Result<TensorMap, Error> {
136 return crate::io::load(path);
137 }
138
139 pub fn load_buffer(buffer: &[u8]) -> Result<TensorMap, Error> {
143 return crate::io::load_buffer(buffer);
144 }
145
146 pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
150 return crate::io::save(path, self);
151 }
152
153 pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
157 return crate::io::save_buffer(self, buffer);
158 }
159
160 #[inline]
162 pub fn keys(&self) -> &Labels {
163 &self.keys
164 }
165
166 #[inline]
172 pub fn block_by_id(&self, index: usize) -> TensorBlockRef<'_> {
173
174 let mut block = std::ptr::null_mut();
175 unsafe {
176 check_status(crate::c_api::mts_tensormap_block_by_id(
177 self.ptr,
178 &mut block,
179 index,
180 )).expect("failed to get a block");
181 }
182
183 return unsafe { TensorBlockRef::from_raw(block) }
184 }
185
186 #[inline]
192 pub fn block_mut_by_id(&mut self, index: usize) -> TensorBlockRefMut<'_> {
193 return unsafe { TensorMap::raw_block_mut_by_id(self.ptr, index) };
194 }
195
196 #[inline]
208 unsafe fn raw_block_mut_by_id<'a>(ptr: *mut mts_tensormap_t, index: usize) -> TensorBlockRefMut<'a> {
209 let mut block = std::ptr::null_mut();
210
211 check_status(crate::c_api::mts_tensormap_block_by_id(
212 ptr,
213 &mut block,
214 index,
215 )).expect("failed to get a block");
216
217 return TensorBlockRefMut::from_raw(block);
218 }
219
220 #[inline]
226 pub fn blocks_matching(&self, selection: &Labels) -> Result<Vec<usize>, Error> {
227 let mut indexes = vec![0; self.keys().count()];
228 let mut matching = indexes.len();
229 unsafe {
230 check_status(crate::c_api::mts_tensormap_blocks_matching(
231 self.ptr,
232 indexes.as_mut_ptr(),
233 &mut matching,
234 selection.as_mts_labels_t(),
235 ))?;
236 }
237 indexes.resize(matching, 0);
238
239 return Ok(indexes);
240 }
241
242 #[inline]
247 pub fn block_matching(&self, selection: &Labels) -> Result<usize, Error> {
248 let matching = self.blocks_matching(selection)?;
249 if matching.len() != 1 {
250 let selection_str = selection.names()
251 .iter().zip(&selection[0])
252 .map(|(name, value)| format!("{} = {}", name, value))
253 .collect::<Vec<_>>()
254 .join(", ");
255
256
257 if matching.is_empty() {
258 return Err(Error {
259 code: None,
260 message: format!(
261 "no blocks matched the selection ({})",
262 selection_str
263 ),
264 });
265 } else {
266 return Err(Error {
267 code: None,
268 message: format!(
269 "{} blocks matched the selection ({}), expected only one",
270 matching.len(),
271 selection_str
272 ),
273 });
274 }
275 }
276
277 return Ok(matching[0])
278 }
279
280 #[inline]
285 pub fn block(&self, selection: &Labels) -> Result<TensorBlockRef<'_>, Error> {
286 let id = self.block_matching(selection)?;
287 return Ok(self.block_by_id(id));
288 }
289
290 #[inline]
292 pub fn blocks(&self) -> Vec<TensorBlockRef<'_>> {
293 let mut blocks = Vec::new();
294 for i in 0..self.keys().count() {
295 blocks.push(self.block_by_id(i));
296 }
297 return blocks;
298 }
299
300 #[inline]
302 pub fn blocks_mut(&mut self) -> Vec<TensorBlockRefMut<'_>> {
303 let mut blocks = Vec::new();
304 for i in 0..self.keys().count() {
305 blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
306 }
307 return blocks;
308 }
309
310 #[inline]
330 pub fn keys_to_samples(&self, keys_to_move: &Labels, sort_samples: bool) -> Result<TensorMap, Error> {
331 let ptr = unsafe {
332 crate::c_api::mts_tensormap_keys_to_samples(
333 self.ptr,
334 keys_to_move.as_mts_labels_t(),
335 sort_samples,
336 )
337 };
338
339 check_ptr(ptr)?;
340 return Ok(unsafe { TensorMap::from_raw(ptr) });
341 }
342
343 #[inline]
370 pub fn keys_to_properties(&self, keys_to_move: &Labels, sort_samples: bool) -> Result<TensorMap, Error> {
371 let ptr = unsafe {
372 crate::c_api::mts_tensormap_keys_to_properties(
373 self.ptr,
374 keys_to_move.as_mts_labels_t(),
375 sort_samples,
376 )
377 };
378
379 check_ptr(ptr)?;
380 return Ok(unsafe { TensorMap::from_raw(ptr) });
381 }
382
383 #[inline]
386 pub fn components_to_properties(&self, dimensions: &[&str]) -> Result<TensorMap, Error> {
387 let dimensions_c = dimensions.iter()
388 .map(|&v| CString::new(v).expect("unexpected NULL byte"))
389 .collect::<Vec<_>>();
390
391 let dimensions_ptr = dimensions_c.iter()
392 .map(|v| v.as_ptr())
393 .collect::<Vec<_>>();
394
395
396 let ptr = unsafe {
397 crate::c_api::mts_tensormap_components_to_properties(
398 self.ptr,
399 dimensions_ptr.as_ptr(),
400 dimensions.len(),
401 )
402 };
403
404 check_ptr(ptr)?;
405 return Ok(unsafe { TensorMap::from_raw(ptr) });
406 }
407
408 #[inline]
410 pub fn iter(&self) -> TensorMapIter<'_> {
411 return TensorMapIter {
412 inner: self.keys().iter().zip(self.blocks())
413 };
414 }
415
416 #[inline]
419 pub fn iter_mut(&mut self) -> TensorMapIterMut<'_> {
420 let mut blocks = Vec::new();
423 for i in 0..self.keys().count() {
424 blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
425 }
426
427 return TensorMapIterMut {
428 inner: self.keys().into_iter().zip(blocks)
429 };
430 }
431
432 #[cfg(feature = "rayon")]
434 #[inline]
435 pub fn par_iter(&self) -> TensorMapParIter<'_> {
436 use rayon::prelude::*;
437 TensorMapParIter {
438 inner: self.keys().par_iter().zip_eq(self.blocks().into_par_iter())
439 }
440 }
441
442 #[cfg(feature = "rayon")]
445 #[inline]
446 pub fn par_iter_mut(&mut self) -> TensorMapParIterMut<'_> {
447 use rayon::prelude::*;
448
449 let mut blocks = Vec::new();
452 for i in 0..self.keys().count() {
453 blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
454 }
455
456 TensorMapParIterMut {
457 inner: self.keys().par_iter().zip_eq(blocks)
458 }
459 }
460
461 pub fn set_info(&mut self, key: &str, value: &str) {
464 let mut key = key.to_owned().into_bytes();
465 key.push(b'\0');
466
467 let mut value = value.to_owned().into_bytes();
468 value.push(b'\0');
469
470 unsafe {
471 check_status(crate::c_api::mts_tensormap_set_info(
472 self.ptr, key.as_ptr().cast(), value.as_ptr().cast()
473 )).expect("failed to set info");
474 }
475 }
476
477 pub fn get_info(&self, key: &str) -> Option<&str> {
480 let mut key = key.to_owned().into_bytes();
481 key.push(b'\0');
482
483 let mut value = std::ptr::null();
484
485 unsafe {
486 check_status(crate::c_api::mts_tensormap_get_info(
487 self.ptr, key.as_ptr().cast(), &mut value
488 )).expect("failed to set info");
489 }
490
491 if value.is_null() {
492 return None;
493 }
494
495 let c_str = unsafe { CStr::from_ptr(value) };
496 return Some(c_str.to_str().expect("invalid UTF-8 string"));
497 }
498
499 pub fn info(&self) -> TensorMapInfoIter<'_> {
502 let mut keys = std::ptr::null();
503 let mut count = 0;
504 unsafe {
505 check_status(crate::c_api::mts_tensormap_info_keys(
506 self.ptr,
507 &mut keys,
508 &mut count,
509 )).expect("failed to get info keys");
510 };
511
512 let keys = unsafe {
513 std::slice::from_raw_parts(keys, count)
514 };
515 let keys = keys.iter()
516 .map(|&k| {
517 let c_str = unsafe { CStr::from_ptr(k) };
518 c_str.to_str().expect("invalid UTF-8 string")
519 })
520 .collect::<Vec<_>>();
521
522 TensorMapInfoIter {
523 keys: keys,
524 tensor: self,
525 index: 0,
526 count,
527 }
528 }
529}
530
531pub struct TensorMapIter<'a> {
535 inner: std::iter::Zip<crate::labels::LabelsIter<'a>, std::vec::IntoIter<TensorBlockRef<'a>>>
536}
537
538impl<'a> Iterator for TensorMapIter<'a> {
539 type Item = (&'a [LabelValue], TensorBlockRef<'a>);
540
541 #[inline]
542 fn next(&mut self) -> Option<Self::Item> {
543 self.inner.next()
544 }
545
546 fn size_hint(&self) -> (usize, Option<usize>) {
547 self.inner.size_hint()
548 }
549}
550
551impl ExactSizeIterator for TensorMapIter<'_> {
552 #[inline]
553 fn len(&self) -> usize {
554 self.inner.len()
555 }
556}
557
558impl FusedIterator for TensorMapIter<'_> {}
559
560impl<'a> IntoIterator for &'a TensorMap {
561 type Item = (&'a [LabelValue], TensorBlockRef<'a>);
562
563 type IntoIter = TensorMapIter<'a>;
564
565 fn into_iter(self) -> Self::IntoIter {
566 self.iter()
567 }
568}
569
570pub struct TensorMapIterMut<'a> {
575 inner: std::iter::Zip<crate::labels::LabelsIter<'a>, std::vec::IntoIter<TensorBlockRefMut<'a>>>
576}
577
578impl<'a> Iterator for TensorMapIterMut<'a> {
579 type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);
580
581 #[inline]
582 fn next(&mut self) -> Option<Self::Item> {
583 self.inner.next()
584 }
585
586 fn size_hint(&self) -> (usize, Option<usize>) {
587 self.inner.size_hint()
588 }
589}
590
591impl ExactSizeIterator for TensorMapIterMut<'_> {
592 #[inline]
593 fn len(&self) -> usize {
594 self.inner.len()
595 }
596}
597
598impl FusedIterator for TensorMapIterMut<'_> {}
599
600impl<'a> IntoIterator for &'a mut TensorMap {
601 type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);
602
603 type IntoIter = TensorMapIterMut<'a>;
604
605 fn into_iter(self) -> Self::IntoIter {
606 self.iter_mut()
607 }
608}
609
610
/// Parallel iterator over the `(key, block)` pairs of a [`TensorMap`],
/// created by [`TensorMap::par_iter`].
#[cfg(feature = "rayon")]
pub struct TensorMapParIter<'a> {
    /// key entries zipped (with equal lengths) with the matching blocks
    inner: rayon::iter::ZipEq<crate::labels::LabelsParIter<'a>, rayon::vec::IntoIter<TensorBlockRef<'a>>>,
}

#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for TensorMapParIter<'a> {
    type Item = (&'a [LabelValue], TensorBlockRef<'a>);

    #[inline]
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        rayon::iter::ParallelIterator::drive_unindexed(self.inner, consumer)
    }
}

#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for TensorMapParIter<'_> {
    #[inline]
    fn len(&self) -> usize {
        rayon::iter::IndexedParallelIterator::len(&self.inner)
    }

    #[inline]
    fn drive<C: rayon::iter::plumbing::Consumer<Self::Item>>(self, consumer: C) -> C::Result {
        rayon::iter::IndexedParallelIterator::drive(self.inner, consumer)
    }

    #[inline]
    fn with_producer<CB: rayon::iter::plumbing::ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
        rayon::iter::IndexedParallelIterator::with_producer(self.inner, callback)
    }
}
648
/// Parallel iterator over the `(key, mutable block)` pairs of a
/// [`TensorMap`], created by [`TensorMap::par_iter_mut`].
#[cfg(feature = "rayon")]
pub struct TensorMapParIterMut<'a> {
    /// key entries zipped (with equal lengths) with the matching mutable blocks
    inner: rayon::iter::ZipEq<crate::labels::LabelsParIter<'a>, rayon::vec::IntoIter<TensorBlockRefMut<'a>>>,
}

#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for TensorMapParIterMut<'a> {
    type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);

    #[inline]
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        rayon::iter::ParallelIterator::drive_unindexed(self.inner, consumer)
    }
}

#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for TensorMapParIterMut<'_> {
    #[inline]
    fn len(&self) -> usize {
        rayon::iter::IndexedParallelIterator::len(&self.inner)
    }

    #[inline]
    fn drive<C: rayon::iter::plumbing::Consumer<Self::Item>>(self, consumer: C) -> C::Result {
        rayon::iter::IndexedParallelIterator::drive(self.inner, consumer)
    }

    #[inline]
    fn with_producer<CB: rayon::iter::plumbing::ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
        rayon::iter::IndexedParallelIterator::with_producer(self.inner, callback)
    }
}
687
688pub struct TensorMapInfoIter<'a> {
692 keys: Vec<&'a str>,
693 tensor: &'a TensorMap,
694 index: usize,
695 count: usize,
696}
697
698impl<'a> Iterator for TensorMapInfoIter<'a> {
699 type Item = (&'a str, &'a str);
700
701 #[inline]
702 fn next(&mut self) -> Option<Self::Item> {
703 if self.index >= self.count {
704 return None;
705 }
706 let key = self.keys[self.index];
707 let value = self.tensor.get_info(key).expect("missing info");
708 self.index += 1;
709 return Some((key, value));
710 }
711
712 fn size_hint(&self) -> (usize, Option<usize>) {
713 (self.count, Some(self.count))
714 }
715}
716
717impl ExactSizeIterator for TensorMapInfoIter<'_> {
718 #[inline]
719 fn len(&self) -> usize {
720 self.count
721 }
722}
723
724impl FusedIterator for TensorMapInfoIter<'_> {}
725
726
// Unit tests for `TensorMap`: block access by index and selection,
// (parallel) iteration, and the info key/value store.
//
// `float_cmp` is allowed because the tests compare values that were stored
// unchanged or multiplied by 2.0, so exact equality is intended.
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use crate::{Labels, TensorBlock, TensorMap};

    /// Build a map with three blocks, keyed by the `(key, other)` entries
    /// `(1, 0)`, `(3, 1)` and `(-4, 0)`. Every value of a block is filled
    /// with that block's `key` (1.0, 3.0 and -4.0), which the iteration
    /// tests rely on.
    fn test_tensor() -> TensorMap {
        // shape [2, 3], all values 1.0
        let block_1 = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![2, 3], 1.0),
            &Labels::new(["samples"], &[[0], [1]]),
            &[],
            &Labels::new(["properties"], &[[-2], [0], [1]]),
        ).unwrap();

        // shape [1, 1], all values 3.0
        let block_2 = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![1, 1], 3.0),
            &Labels::new(["samples"], &[[1]]),
            &[],
            &Labels::new(["properties"], &[[1]]),
        ).unwrap();

        // shape [3, 2], all values -4.0
        let block_3 = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![3, 2], -4.0),
            &Labels::new(["samples"], &[[0], [1], [3]]),
            &[],
            &Labels::new(["properties"], &[[-2], [1]]),
        ).unwrap();

        return TensorMap::new(
            Labels::new(["key", "other"], &[[1, 0], [3, 1], [-4, 0]]),
            vec![block_1, block_2, block_3],
        ).unwrap();
    }

    #[test]
    fn block_access() {
        let mut tensor = test_tensor();

        // access by position, shared and mutable
        let block = tensor.block_by_id(1);
        assert_eq!(block.values().as_array().shape(), [1, 1]);

        let block = tensor.block_mut_by_id(2);
        assert_eq!(block.values().as_array().shape(), [3, 2]);

        // `key = 1` selects exactly the first block
        let selection = Labels::new(["key"], &[[1]]);
        assert_eq!(tensor.block_matching(&selection).unwrap(), 0);
        assert_eq!(tensor.blocks_matching(&selection).unwrap(), [0]);

        let block = tensor.block(&selection).unwrap();
        assert_eq!(block.values().as_array().shape(), [2, 3]);

        // `other = 0` selects two blocks, so the single-block lookup fails
        let selection = Labels::new(["other"], &[[0]]);
        assert!(tensor.block_matching(&selection).is_err());
        assert_eq!(tensor.blocks_matching(&selection).unwrap(), [0, 2]);

        // all blocks, in key order
        let blocks = tensor.blocks();
        assert_eq!(blocks[0].values().as_array().shape(), [2, 3]);
        assert_eq!(blocks[1].values().as_array().shape(), [1, 1]);
        assert_eq!(blocks[2].values().as_array().shape(), [3, 2]);

        let blocks = tensor.blocks_mut();
        assert_eq!(blocks[0].values().as_array().shape(), [2, 3]);
        assert_eq!(blocks[1].values().as_array().shape(), [1, 1]);
        assert_eq!(blocks[2].values().as_array().shape(), [3, 2]);
    }

    #[test]
    fn iter() {
        let mut tensor = test_tensor();

        // each block is filled with its `key` value (see `test_tensor`)
        for (key, block) in &tensor {
            assert_eq!(block.values().to_array()[[0, 0]], f64::from(key[0].i32()));
        }

        // mutating through the iterator doubles every value in place
        for (key, mut block) in &mut tensor {
            let array = block.values_mut().to_array_mut();
            *array *= 2.0;
            assert_eq!(array[[0, 0]], 2.0 * f64::from(key[0].i32()));
        }
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn par_iter() {
        use rayon::iter::ParallelIterator;

        let mut tensor = test_tensor();

        // same checks as `iter`, through the rayon parallel iterators
        tensor.par_iter().for_each(|(key, block)| {
            assert_eq!(block.values().to_array()[[0, 0]], f64::from(key[0].i32()));
        });

        tensor.par_iter_mut().for_each(|(key, mut block)| {
            let array = block.values_mut().to_array_mut();
            *array *= 2.0;
            assert_eq!(array[[0, 0]], 2.0 * f64::from(key[0].i32()));
        });
    }

    #[test]
    fn info() {
        let mut tensor = test_tensor();
        tensor.set_info("creator", "unit test");
        tensor.set_info("version", "1.0");

        assert_eq!(tensor.get_info("creator").unwrap(), "unit test");
        assert_eq!(tensor.get_info("version").unwrap(), "1.0");
        assert!(tensor.get_info("missing").is_none());

        // NOTE(review): "creator" comes before "version" both alphabetically
        // and by insertion, so this does not tell which order `info()`
        // guarantees
        let mut info_iter = tensor.info();
        let (key, value) = info_iter.next().unwrap();
        assert_eq!(key, "creator");
        assert_eq!(value, "unit test");
        let (key, value) = info_iter.next().unwrap();
        assert_eq!(key, "version");
        assert_eq!(value, "1.0");
        assert!(info_iter.next().is_none());
    }
}
850}