1use std::ffi::{CStr, CString};
2use std::iter::FusedIterator;
3
4use crate::block::TensorBlockRefMut;
5use crate::c_api::mts_tensormap_t;
6
7use crate::errors::{check_status, check_ptr};
8use crate::{Error, LabelValue, Labels, MtsArray, TensorBlock, TensorBlockRef};
9
10pub struct TensorMap {
20 pub(crate) ptr: *mut mts_tensormap_t,
21 keys: Labels,
23}
24
25unsafe impl Send for TensorMap {}
27unsafe impl Sync for TensorMap {}
29
30impl std::fmt::Debug for TensorMap {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 use crate::labels::pretty_print_labels;
33 writeln!(f, "Tensormap @ {:p} {{", self.ptr)?;
34
35 write!(f, " keys: ")?;
36 pretty_print_labels(self.keys(), " ", f)?;
37 writeln!(f, "}}")
38 }
39}
40
41impl std::ops::Drop for TensorMap {
42 #[allow(unused_must_use)]
43 fn drop(&mut self) {
44 unsafe {
45 crate::c_api::mts_tensormap_free(self.ptr);
46 }
47 }
48}
49
50impl TensorMap {
51 #[allow(clippy::needless_pass_by_value)]
57 #[inline]
58 pub fn new(keys: Labels, mut blocks: Vec<TensorBlock>) -> Result<TensorMap, Error> {
59 let ptr = unsafe {
60 crate::c_api::mts_tensormap(
61 keys.as_mts_labels_t(),
62 blocks.as_mut_ptr().cast::<*mut crate::c_api::mts_block_t>(),
66 blocks.len()
67 )
68 };
69
70 for block in blocks {
71 std::mem::forget(block);
74 }
75
76 check_ptr(ptr)?;
77
78 return Ok(unsafe { TensorMap::from_raw(ptr) });
79 }
80
81 pub unsafe fn from_raw(ptr: *mut mts_tensormap_t) -> TensorMap {
91 assert!(!ptr.is_null());
92
93 let keys_ptr = crate::c_api::mts_tensormap_keys(ptr);
94 assert!(!keys_ptr.is_null(), "failed to get the keys");
95 let keys = Labels::from_raw(keys_ptr);
96
97 return TensorMap {
98 ptr,
99 keys
100 };
101 }
102
103 pub fn into_raw(mut map: TensorMap) -> *mut mts_tensormap_t {
109 let ptr = map.ptr;
110 map.ptr = std::ptr::null_mut();
111 return ptr;
112 }
113
114 #[inline]
119 pub fn try_clone(&self) -> Result<TensorMap, Error> {
120 let ptr = unsafe {
121 crate::c_api::mts_tensormap_copy(self.ptr)
122 };
123 crate::errors::check_ptr(ptr)?;
124
125 return Ok(unsafe { TensorMap::from_raw(ptr) });
126 }
127
128 pub fn load(path: impl AsRef<std::path::Path>) -> Result<TensorMap, Error> {
132 return crate::io::load(path);
133 }
134
135 pub fn load_buffer(buffer: &[u8]) -> Result<TensorMap, Error> {
139 return crate::io::load_buffer(buffer);
140 }
141
142 pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
146 return crate::io::save(path, self);
147 }
148
149 pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
153 return crate::io::save_buffer(self, buffer);
154 }
155
156 #[inline]
158 pub fn keys(&self) -> &Labels {
159 &self.keys
160 }
161
162 #[inline]
168 pub fn block_by_id(&self, index: usize) -> TensorBlockRef<'_> {
169
170 let mut block = std::ptr::null_mut();
171 unsafe {
172 check_status(crate::c_api::mts_tensormap_block_by_id(
173 self.ptr,
174 &mut block,
175 index,
176 )).expect("failed to get a block");
177 }
178
179 return unsafe { TensorBlockRef::from_raw(block) }
180 }
181
182 #[inline]
188 pub fn block_mut_by_id(&mut self, index: usize) -> TensorBlockRefMut<'_> {
189 return unsafe { TensorMap::raw_block_mut_by_id(self.ptr, index) };
190 }
191
192 #[inline]
204 unsafe fn raw_block_mut_by_id<'a>(ptr: *mut mts_tensormap_t, index: usize) -> TensorBlockRefMut<'a> {
205 let mut block = std::ptr::null_mut();
206
207 check_status(
208 crate::c_api::mts_tensormap_block_by_id(
209 ptr,
210 &mut block,
211 index,
212 )).expect("failed to get a block");
213
214 return TensorBlockRefMut::from_raw(block);
215 }
216
217 #[inline]
219 pub fn block(&self, selection: &Labels) -> Result<TensorBlockRef<'_>, Error> {
220 let matching = self.keys.select(selection)?;
221 if matching.len() != 1 {
222 let selection_str = selection.names()
223 .iter()
224 .zip(&selection[0])
225 .map(|(name, value)| format!("{} = {}", name, value))
226 .collect::<Vec<_>>()
227 .join(", ");
228
229 if matching.is_empty() {
230 return Err(Error {
231 code: None,
232 message: format!(
233 "no blocks matched the selection ({})",
234 selection_str
235 ),
236 });
237 } else {
238 return Err(Error {
239 code: None,
240 message: format!(
241 "{} blocks matched the selection ({}), expected only one",
242 matching.len(),
243 selection_str
244 ),
245 });
246 }
247 }
248
249 return Ok(self.block_by_id(matching[0]));
250 }
251
252 #[inline]
254 pub fn blocks(&self) -> Vec<TensorBlockRef<'_>> {
255 let mut blocks = Vec::new();
256 for i in 0..self.keys().count() {
257 blocks.push(self.block_by_id(i));
258 }
259 return blocks;
260 }
261
262 #[inline]
264 pub fn blocks_mut(&mut self) -> Vec<TensorBlockRefMut<'_>> {
265 let mut blocks = Vec::new();
266 for i in 0..self.keys().count() {
267 blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
268 }
269 return blocks;
270 }
271
272 #[inline]
289 pub fn keys_to_samples(&self, keys_to_move: &Labels, fill_value: MtsArray, sort_samples: bool) -> Result<TensorMap, Error> {
290 let ptr = unsafe {
291 crate::c_api::mts_tensormap_keys_to_samples(
292 self.ptr,
293 keys_to_move.as_mts_labels_t(),
294 fill_value.into_raw(),
295 sort_samples,
296 )
297 };
298
299 check_ptr(ptr)?;
300 return Ok(unsafe { TensorMap::from_raw(ptr) });
301 }
302
303 #[inline]
330 pub fn keys_to_properties(&self, keys_to_move: &Labels, fill_value: MtsArray, sort_samples: bool) -> Result<TensorMap, Error> {
331 let ptr = unsafe {
332 crate::c_api::mts_tensormap_keys_to_properties(
333 self.ptr,
334 keys_to_move.as_mts_labels_t(),
335 fill_value.into_raw(),
336 sort_samples,
337 )
338 };
339
340 check_ptr(ptr)?;
341 return Ok(unsafe { TensorMap::from_raw(ptr) });
342 }
343
344 #[inline]
347 pub fn components_to_properties(&self, dimensions: &[&str]) -> Result<TensorMap, Error> {
348 let dimensions_c = dimensions.iter()
349 .map(|&v| CString::new(v).expect("unexpected NULL byte"))
350 .collect::<Vec<_>>();
351
352 let dimensions_ptr = dimensions_c.iter()
353 .map(|v| v.as_ptr())
354 .collect::<Vec<_>>();
355
356
357 let ptr = unsafe {
358 crate::c_api::mts_tensormap_components_to_properties(
359 self.ptr,
360 dimensions_ptr.as_ptr(),
361 dimensions.len(),
362 )
363 };
364
365 check_ptr(ptr)?;
366 return Ok(unsafe { TensorMap::from_raw(ptr) });
367 }
368
369 #[inline]
371 pub fn iter(&self) -> TensorMapIter<'_> {
372 return TensorMapIter {
373 inner: self.keys().into_iter().zip(self.blocks())
374 };
375 }
376
377 #[inline]
380 pub fn iter_mut(&mut self) -> TensorMapIterMut<'_> {
381 let mut blocks = Vec::new();
384 for i in 0..self.keys().count() {
385 blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
386 }
387
388 return TensorMapIterMut {
389 inner: self.keys().into_iter().zip(blocks)
390 };
391 }
392
393 #[cfg(feature = "rayon")]
395 #[inline]
396 pub fn par_iter(&self) -> TensorMapParIter<'_> {
397 use rayon::prelude::*;
398 TensorMapParIter {
399 inner: self.keys().par_iter().zip_eq(self.blocks().into_par_iter())
400 }
401 }
402
403 #[cfg(feature = "rayon")]
406 #[inline]
407 pub fn par_iter_mut(&mut self) -> TensorMapParIterMut<'_> {
408 use rayon::prelude::*;
409
410 let mut blocks = Vec::new();
413 for i in 0..self.keys().count() {
414 blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
415 }
416
417 TensorMapParIterMut {
418 inner: self.keys().par_iter().zip_eq(blocks)
419 }
420 }
421
422 pub fn set_info(&mut self, key: &str, value: &str) {
425 let mut key = key.to_owned().into_bytes();
426 key.push(b'\0');
427
428 let mut value = value.to_owned().into_bytes();
429 value.push(b'\0');
430
431 unsafe {
432 check_status(crate::c_api::mts_tensormap_set_info(
433 self.ptr, key.as_ptr().cast(), value.as_ptr().cast()
434 )).expect("failed to set info");
435 }
436 }
437
438 pub fn get_info(&self, key: &str) -> Option<&str> {
441 let mut key = key.to_owned().into_bytes();
442 key.push(b'\0');
443
444 let mut value = std::ptr::null();
445
446 unsafe {
447 check_status(crate::c_api::mts_tensormap_get_info(
448 self.ptr, key.as_ptr().cast(), &mut value
449 )).expect("failed to set info");
450 }
451
452 if value.is_null() {
453 return None;
454 }
455
456 let c_str = unsafe { CStr::from_ptr(value) };
457 return Some(c_str.to_str().expect("invalid UTF-8 string"));
458 }
459
460 pub fn info(&self) -> TensorMapInfoIter<'_> {
463 let mut keys = std::ptr::null();
464 let mut count = 0;
465 unsafe {
466 check_status(crate::c_api::mts_tensormap_info_keys(
467 self.ptr,
468 &mut keys,
469 &mut count,
470 )).expect("failed to get info keys");
471 };
472
473 let keys = unsafe {
474 std::slice::from_raw_parts(keys, count)
475 };
476 let keys = keys.iter()
477 .map(|&k| {
478 let c_str = unsafe { CStr::from_ptr(k) };
479 c_str.to_str().expect("invalid UTF-8 string")
480 })
481 .collect::<Vec<_>>();
482
483 TensorMapInfoIter {
484 keys: keys,
485 tensor: self,
486 index: 0,
487 count,
488 }
489 }
490}
491
492pub struct TensorMapIter<'a> {
496 inner: std::iter::Zip<crate::labels::LabelsIter<'a>, std::vec::IntoIter<TensorBlockRef<'a>>>
497}
498
499impl<'a> Iterator for TensorMapIter<'a> {
500 type Item = (&'a [LabelValue], TensorBlockRef<'a>);
501
502 #[inline]
503 fn next(&mut self) -> Option<Self::Item> {
504 self.inner.next()
505 }
506
507 fn size_hint(&self) -> (usize, Option<usize>) {
508 self.inner.size_hint()
509 }
510}
511
512impl ExactSizeIterator for TensorMapIter<'_> {
513 #[inline]
514 fn len(&self) -> usize {
515 self.inner.len()
516 }
517}
518
519impl FusedIterator for TensorMapIter<'_> {}
520
521impl<'a> IntoIterator for &'a TensorMap {
522 type Item = (&'a [LabelValue], TensorBlockRef<'a>);
523
524 type IntoIter = TensorMapIter<'a>;
525
526 fn into_iter(self) -> Self::IntoIter {
527 self.iter()
528 }
529}
530
531pub struct TensorMapIterMut<'a> {
536 inner: std::iter::Zip<crate::labels::LabelsIter<'a>, std::vec::IntoIter<TensorBlockRefMut<'a>>>
537}
538
539impl<'a> Iterator for TensorMapIterMut<'a> {
540 type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);
541
542 #[inline]
543 fn next(&mut self) -> Option<Self::Item> {
544 self.inner.next()
545 }
546
547 fn size_hint(&self) -> (usize, Option<usize>) {
548 self.inner.size_hint()
549 }
550}
551
552impl ExactSizeIterator for TensorMapIterMut<'_> {
553 #[inline]
554 fn len(&self) -> usize {
555 self.inner.len()
556 }
557}
558
559impl FusedIterator for TensorMapIterMut<'_> {}
560
561impl<'a> IntoIterator for &'a mut TensorMap {
562 type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);
563
564 type IntoIter = TensorMapIterMut<'a>;
565
566 fn into_iter(self) -> Self::IntoIter {
567 self.iter_mut()
568 }
569}
570
571
/// Parallel iterator over the keys and associated blocks of a `TensorMap`,
/// created by `TensorMap::par_iter`.
#[cfg(feature = "rayon")]
pub struct TensorMapParIter<'a> {
    /// Parallel iterator over the keys, zipped with the block references
    /// collected when the iterator was created.
    inner: rayon::iter::ZipEq<
        crate::labels::LabelsParIter<'a>,
        rayon::vec::IntoIter<TensorBlockRef<'a>>,
    >,
}
579
#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for TensorMapParIter<'a> {
    /// A key (one entry of the keys `Labels`) and the matching block.
    type Item = (&'a [LabelValue], TensorBlockRef<'a>);

    /// Forward the consumer to the inner zipped parallel iterator.
    #[inline]
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        return self.inner.drive_unindexed(consumer);
    }
}
591
// every method forwards to the inner zipped parallel iterator
#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for TensorMapParIter<'_> {
    /// Number of `(key, block)` pairs produced by this iterator.
    #[inline]
    fn len(&self) -> usize {
        return self.inner.len();
    }

    #[inline]
    fn drive<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::Consumer<Self::Item>,
    {
        return self.inner.drive(consumer);
    }

    #[inline]
    fn with_producer<CB>(self, callback: CB) -> CB::Output
    where
        CB: rayon::iter::plumbing::ProducerCallback<Self::Item>,
    {
        return self.inner.with_producer(callback);
    }
}
609
/// Parallel iterator over the keys and associated blocks of a `TensorMap`,
/// with mutable access to the blocks. Created by `TensorMap::par_iter_mut`.
#[cfg(feature = "rayon")]
pub struct TensorMapParIterMut<'a> {
    /// Parallel iterator over the keys, zipped with the mutable block
    /// references collected when the iterator was created.
    inner: rayon::iter::ZipEq<
        crate::labels::LabelsParIter<'a>,
        rayon::vec::IntoIter<TensorBlockRefMut<'a>>,
    >,
}
618
#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for TensorMapParIterMut<'a> {
    /// A key (one entry of the keys `Labels`) and the matching mutable block.
    type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);

    /// Forward the consumer to the inner zipped parallel iterator.
    #[inline]
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        return self.inner.drive_unindexed(consumer);
    }
}
630
// every method forwards to the inner zipped parallel iterator
#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for TensorMapParIterMut<'_> {
    /// Number of `(key, block)` pairs produced by this iterator.
    #[inline]
    fn len(&self) -> usize {
        return self.inner.len();
    }

    #[inline]
    fn drive<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::Consumer<Self::Item>,
    {
        return self.inner.drive(consumer);
    }

    #[inline]
    fn with_producer<CB>(self, callback: CB) -> CB::Output
    where
        CB: rayon::iter::plumbing::ProducerCallback<Self::Item>,
    {
        return self.inner.with_producer(callback);
    }
}
648
649pub struct TensorMapInfoIter<'a> {
653 keys: Vec<&'a str>,
654 tensor: &'a TensorMap,
655 index: usize,
656 count: usize,
657}
658
659impl<'a> Iterator for TensorMapInfoIter<'a> {
660 type Item = (&'a str, &'a str);
661
662 #[inline]
663 fn next(&mut self) -> Option<Self::Item> {
664 if self.index >= self.count {
665 return None;
666 }
667 let key = self.keys[self.index];
668 let value = self.tensor.get_info(key).expect("missing info");
669 self.index += 1;
670 return Some((key, value));
671 }
672
673 fn size_hint(&self) -> (usize, Option<usize>) {
674 (self.count, Some(self.count))
675 }
676}
677
678impl ExactSizeIterator for TensorMapInfoIter<'_> {
679 #[inline]
680 fn len(&self) -> usize {
681 self.count
682 }
683}
684
685impl FusedIterator for TensorMapInfoIter<'_> {}
686
687
688#[cfg(test)]
691#[allow(clippy::float_cmp)]
692mod tests {
693 use crate::{Labels, TensorBlock, TensorMap};
694
695 fn test_tensor() -> TensorMap {
696 let block_1 = TensorBlock::new(
697 ndarray::Array::from_elem(vec![2, 3], 1.0),
698 &Labels::new(["samples"], &[[0], [1]]),
699 &[],
700 &Labels::new(["properties"], &[[-2], [0], [1]]),
701 ).unwrap();
702
703 let block_2 = TensorBlock::new(
704 ndarray::Array::from_elem(vec![1, 1], 3.0),
705 &Labels::new(["samples"], &[[1]]),
706 &[],
707 &Labels::new(["properties"], &[[1]]),
708 ).unwrap();
709
710 let block_3 = TensorBlock::new(
711 ndarray::Array::from_elem(vec![3, 2], -4.0),
712 &Labels::new(["samples"], &[[0], [1], [3]]),
713 &[],
714 &Labels::new(["properties"], &[[-2], [1]]),
715 ).unwrap();
716
717 return TensorMap::new(
718 Labels::new(["key", "other"], &[[1, 0], [3, 1], [-4, 0]]),
719 vec![block_1, block_2, block_3],
720 ).unwrap();
721 }
722
723 #[test]
724 fn block_access() {
725 let mut tensor = test_tensor();
726
727 let block = tensor.block_by_id(1);
728 assert_eq!(block.values().shape().unwrap(), [1, 1]);
729
730 let block = tensor.block_mut_by_id(2);
731 assert_eq!(block.values().shape().unwrap(), [3, 2]);
732
733 let selection = Labels::new(["key"], &[[1]]);
734
735 let block = tensor.block(&selection).unwrap();
736 {
737 let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
738 assert_eq!(values.shape(), [2, 3]);
739 }
740
741 let blocks = tensor.blocks();
742 assert_eq!(blocks[0].values().shape().unwrap(), [2, 3]);
743 assert_eq!(blocks[1].values().shape().unwrap(), [1, 1]);
744 assert_eq!(blocks[2].values().shape().unwrap(), [3, 2]);
745
746 let blocks = tensor.blocks_mut();
747 assert_eq!(blocks[0].values().shape().unwrap(), [2, 3]);
748 assert_eq!(blocks[1].values().shape().unwrap(), [1, 1]);
749 assert_eq!(blocks[2].values().shape().unwrap(), [3, 2]);
750 }
751
752 #[test]
753 fn iter() {
754 let mut tensor = test_tensor();
755
756 for (key, block) in &tensor {
758 let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
759 assert_eq!(values[[0, 0]], f64::from(key[0].i32()));
760 }
761
762 for (key, mut block) in &mut tensor {
764 let array = block.values_mut().get_ndarray_mut::<f64>();
765 *array *= 2.0;
766 assert_eq!(array[[0, 0]], 2.0 * f64::from(key[0].i32()));
767 }
768 }
769
770 #[cfg(feature = "rayon")]
771 #[test]
772 fn par_iter() {
773 use rayon::iter::ParallelIterator;
774
775 let mut tensor = test_tensor();
776
777 tensor.par_iter().for_each(|(key, block)| {
779 let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
780 assert_eq!(values[[0, 0]], f64::from(key[0].i32()));
781 });
782
783 tensor.par_iter_mut().for_each(|(key, mut block)| {
785 let array = block.values_mut().get_ndarray_mut::<f64>();
786 *array *= 2.0;
787 assert_eq!(array[[0, 0]], 2.0 * f64::from(key[0].i32()));
788 });
789 }
790
791 #[test]
792 fn info() {
793 let mut tensor = test_tensor();
794 tensor.set_info("creator", "unit test");
795 tensor.set_info("version", "1.0");
796
797 assert_eq!(tensor.get_info("creator").unwrap(), "unit test");
798 assert_eq!(tensor.get_info("version").unwrap(), "1.0");
799 assert!(tensor.get_info("missing").is_none());
800
801 let mut info_iter = tensor.info();
802 let (key, value) = info_iter.next().unwrap();
803 assert_eq!(key, "creator");
804 assert_eq!(value, "unit test");
805 let (key, value) = info_iter.next().unwrap();
806 assert_eq!(key, "version");
807 assert_eq!(value, "1.0");
808 assert!(info_iter.next().is_none());
809 }
810}