1use std::ffi::{CStr, CString};
2use std::iter::FusedIterator;
3
4use crate::block::TensorBlockRefMut;
5use crate::c_api::mts_tensormap_t;
6
7use crate::errors::{check_status, check_ptr};
8use crate::{Error, LabelValue, Labels, MtsArray, TensorBlock, TensorBlockRef};
9
/// Rust wrapper around a `mts_tensormap_t` pointer coming from the metatensor
/// C API, together with the keys of the tensor map.
pub struct TensorMap {
    /// Pointer to the underlying C tensor map. It is owned by this wrapper and
    /// released with `mts_tensormap_free` when the wrapper is dropped.
    pub(crate) ptr: *mut mts_tensormap_t,
    /// Keys of the tensor map, obtained from `mts_tensormap_keys` when the
    /// wrapper is created in `TensorMap::from_raw`.
    keys: Labels,
}
24
25unsafe impl Send for TensorMap {}
27unsafe impl Sync for TensorMap {}
29
30impl std::fmt::Debug for TensorMap {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 use crate::labels::pretty_print_labels;
33 writeln!(f, "Tensormap @ {:p} {{", self.ptr)?;
34
35 write!(f, " keys: ")?;
36 pretty_print_labels(self.keys(), " ", f)?;
37 writeln!(f, "}}")
38 }
39}
40
41impl std::ops::Drop for TensorMap {
42 #[allow(unused_must_use)]
43 fn drop(&mut self) {
44 unsafe {
45 crate::c_api::mts_tensormap_free(self.ptr);
46 }
47 }
48}
49
50impl TensorMap {
51 #[allow(clippy::needless_pass_by_value)]
57 #[inline]
58 pub fn new(keys: Labels, mut blocks: Vec<TensorBlock>) -> Result<TensorMap, Error> {
59 let ptr = unsafe {
60 crate::c_api::mts_tensormap(
61 keys.as_mts_labels_t(),
62 blocks.as_mut_ptr().cast::<*mut crate::c_api::mts_block_t>(),
66 blocks.len()
67 )
68 };
69
70 for block in blocks {
71 std::mem::forget(block);
74 }
75
76 check_ptr(ptr)?;
77
78 return Ok(unsafe { TensorMap::from_raw(ptr) });
79 }
80
81 pub unsafe fn from_raw(ptr: *mut mts_tensormap_t) -> TensorMap {
91 assert!(!ptr.is_null());
92
93 let keys_ptr = crate::c_api::mts_tensormap_keys(ptr);
94 assert!(!keys_ptr.is_null(), "failed to get the keys");
95 let keys = Labels::from_raw(keys_ptr);
96
97 return TensorMap {
98 ptr,
99 keys
100 };
101 }
102
103 pub fn into_raw(mut tensor: TensorMap) -> *mut mts_tensormap_t {
109 return std::mem::replace(&mut tensor.ptr, std::ptr::null_mut());
110 }
111
112 pub fn as_ptr(&self) -> *const mts_tensormap_t {
117 self.ptr
118 }
119
120 pub fn as_mut_ptr(&mut self) -> *mut mts_tensormap_t {
125 self.ptr
126 }
127
128 #[inline]
133 pub fn try_clone(&self) -> Result<TensorMap, Error> {
134 let ptr = unsafe {
135 crate::c_api::mts_tensormap_copy(self.ptr)
136 };
137 crate::errors::check_ptr(ptr)?;
138
139 return Ok(unsafe { TensorMap::from_raw(ptr) });
140 }
141
142 pub fn load(path: impl AsRef<std::path::Path>) -> Result<TensorMap, Error> {
146 return crate::io::load(path);
147 }
148
149 pub fn load_buffer(buffer: &[u8]) -> Result<TensorMap, Error> {
153 return crate::io::load_buffer(buffer);
154 }
155
156 pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
160 return crate::io::save(path, self);
161 }
162
163 pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
167 return crate::io::save_buffer(self, buffer);
168 }
169
170 #[inline]
172 pub fn device(&self) -> Result<dlpk::sys::DLDevice, Error> {
173 let mut device = dlpk::sys::DLDevice::cpu();
174 unsafe {
175 check_status(crate::c_api::mts_tensormap_device(
176 self.ptr,
177 &mut device,
178 ))?;
179 }
180 return Ok(device);
181 }
182
183 #[inline]
185 pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
186 let mut dtype = dlpk::sys::DLDataType {
187 code: dlpk::sys::DLDataTypeCode::kDLFloat,
188 bits: 0,
189 lanes: 0,
190 };
191 unsafe {
192 check_status(crate::c_api::mts_tensormap_dtype(
193 self.ptr,
194 &mut dtype,
195 ))?;
196 }
197 return Ok(dtype);
198 }
199
200 #[inline]
202 pub fn keys(&self) -> &Labels {
203 &self.keys
204 }
205
206 #[inline]
212 pub fn block_by_id(&self, index: usize) -> TensorBlockRef<'_> {
213
214 let mut block = std::ptr::null_mut();
215 unsafe {
216 check_status(crate::c_api::mts_tensormap_block_by_id(
217 self.ptr,
218 &mut block,
219 index,
220 )).expect("failed to get a block");
221 }
222
223 return unsafe { TensorBlockRef::from_raw(block) }
224 }
225
226 #[inline]
232 pub fn block_mut_by_id(&mut self, index: usize) -> TensorBlockRefMut<'_> {
233 return unsafe { TensorMap::raw_block_mut_by_id(self.ptr, index) };
234 }
235
236 #[inline]
248 unsafe fn raw_block_mut_by_id<'a>(ptr: *mut mts_tensormap_t, index: usize) -> TensorBlockRefMut<'a> {
249 let mut block = std::ptr::null_mut();
250
251 check_status(
252 crate::c_api::mts_tensormap_block_by_id(
253 ptr,
254 &mut block,
255 index,
256 )).expect("failed to get a block");
257
258 return TensorBlockRefMut::from_raw(block);
259 }
260
261 #[inline]
263 pub fn block(&self, selection: &Labels) -> Result<TensorBlockRef<'_>, Error> {
264 let matching = self.keys.select(selection)?;
265 if matching.len() != 1 {
266 let selection_str = selection.names()
267 .iter()
268 .zip(&selection[0])
269 .map(|(name, value)| format!("{} = {}", name, value))
270 .collect::<Vec<_>>()
271 .join(", ");
272
273 if matching.is_empty() {
274 return Err(Error {
275 code: None,
276 message: format!(
277 "no blocks matched the selection ({})",
278 selection_str
279 ),
280 });
281 } else {
282 return Err(Error {
283 code: None,
284 message: format!(
285 "{} blocks matched the selection ({}), expected only one",
286 matching.len(),
287 selection_str
288 ),
289 });
290 }
291 }
292
293 return Ok(self.block_by_id(matching[0]));
294 }
295
296 #[inline]
298 pub fn blocks(&self) -> Vec<TensorBlockRef<'_>> {
299 let mut blocks = Vec::new();
300 for i in 0..self.keys().count() {
301 blocks.push(self.block_by_id(i));
302 }
303 return blocks;
304 }
305
306 #[inline]
308 pub fn blocks_mut(&mut self) -> Vec<TensorBlockRefMut<'_>> {
309 let mut blocks = Vec::new();
310 for i in 0..self.keys().count() {
311 blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
312 }
313 return blocks;
314 }
315
316 #[inline]
333 pub fn keys_to_samples(&self, keys_to_move: &Labels, fill_value: MtsArray, sort_samples: bool) -> Result<TensorMap, Error> {
334 let ptr = unsafe {
335 crate::c_api::mts_tensormap_keys_to_samples(
336 self.ptr,
337 keys_to_move.as_mts_labels_t(),
338 fill_value.into_raw(),
339 sort_samples,
340 )
341 };
342
343 check_ptr(ptr)?;
344 return Ok(unsafe { TensorMap::from_raw(ptr) });
345 }
346
347 #[inline]
374 pub fn keys_to_properties(&self, keys_to_move: &Labels, fill_value: MtsArray, sort_samples: bool) -> Result<TensorMap, Error> {
375 let ptr = unsafe {
376 crate::c_api::mts_tensormap_keys_to_properties(
377 self.ptr,
378 keys_to_move.as_mts_labels_t(),
379 fill_value.into_raw(),
380 sort_samples,
381 )
382 };
383
384 check_ptr(ptr)?;
385 return Ok(unsafe { TensorMap::from_raw(ptr) });
386 }
387
388 #[inline]
391 pub fn components_to_properties(&self, dimensions: &[&str]) -> Result<TensorMap, Error> {
392 let dimensions_c = dimensions.iter()
393 .map(|&v| CString::new(v).expect("unexpected NULL byte"))
394 .collect::<Vec<_>>();
395
396 let dimensions_ptr = dimensions_c.iter()
397 .map(|v| v.as_ptr())
398 .collect::<Vec<_>>();
399
400
401 let ptr = unsafe {
402 crate::c_api::mts_tensormap_components_to_properties(
403 self.ptr,
404 dimensions_ptr.as_ptr(),
405 dimensions.len(),
406 )
407 };
408
409 check_ptr(ptr)?;
410 return Ok(unsafe { TensorMap::from_raw(ptr) });
411 }
412
413 #[inline]
415 pub fn iter(&self) -> TensorMapIter<'_> {
416 return TensorMapIter {
417 inner: self.keys().into_iter().zip(self.blocks())
418 };
419 }
420
421 #[inline]
424 pub fn iter_mut(&mut self) -> TensorMapIterMut<'_> {
425 let mut blocks = Vec::new();
428 for i in 0..self.keys().count() {
429 blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
430 }
431
432 return TensorMapIterMut {
433 inner: self.keys().into_iter().zip(blocks)
434 };
435 }
436
437 #[cfg(feature = "rayon")]
439 #[inline]
440 pub fn par_iter(&self) -> TensorMapParIter<'_> {
441 use rayon::prelude::*;
442 TensorMapParIter {
443 inner: self.keys().par_iter().zip_eq(self.blocks().into_par_iter())
444 }
445 }
446
447 #[cfg(feature = "rayon")]
450 #[inline]
451 pub fn par_iter_mut(&mut self) -> TensorMapParIterMut<'_> {
452 use rayon::prelude::*;
453
454 let mut blocks = Vec::new();
457 for i in 0..self.keys().count() {
458 blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
459 }
460
461 TensorMapParIterMut {
462 inner: self.keys().par_iter().zip_eq(blocks)
463 }
464 }
465
466 pub fn set_info(&mut self, key: &str, value: &str) {
469 let mut key = key.to_owned().into_bytes();
470 key.push(b'\0');
471
472 let mut value = value.to_owned().into_bytes();
473 value.push(b'\0');
474
475 unsafe {
476 check_status(crate::c_api::mts_tensormap_set_info(
477 self.ptr, key.as_ptr().cast(), value.as_ptr().cast()
478 )).expect("failed to set info");
479 }
480 }
481
482 pub fn get_info(&self, key: &str) -> Option<&str> {
485 let mut key = key.to_owned().into_bytes();
486 key.push(b'\0');
487
488 let mut value = std::ptr::null();
489
490 unsafe {
491 check_status(crate::c_api::mts_tensormap_get_info(
492 self.ptr, key.as_ptr().cast(), &mut value
493 )).expect("failed to set info");
494 }
495
496 if value.is_null() {
497 return None;
498 }
499
500 let c_str = unsafe { CStr::from_ptr(value) };
501 return Some(c_str.to_str().expect("invalid UTF-8 string"));
502 }
503
504 pub fn info(&self) -> TensorMapInfoIter<'_> {
507 let mut keys = std::ptr::null();
508 let mut count = 0;
509 unsafe {
510 check_status(crate::c_api::mts_tensormap_info_keys(
511 self.ptr,
512 &mut keys,
513 &mut count,
514 )).expect("failed to get info keys");
515 };
516
517 let keys = unsafe {
518 std::slice::from_raw_parts(keys, count)
519 };
520 let keys = keys.iter()
521 .map(|&k| {
522 let c_str = unsafe { CStr::from_ptr(k) };
523 c_str.to_str().expect("invalid UTF-8 string")
524 })
525 .collect::<Vec<_>>();
526
527 TensorMapInfoIter {
528 keys: keys,
529 tensor: self,
530 index: 0,
531 count,
532 }
533 }
534}
535
536pub struct TensorMapIter<'a> {
540 inner: std::iter::Zip<crate::labels::LabelsIter<'a>, std::vec::IntoIter<TensorBlockRef<'a>>>
541}
542
543impl<'a> Iterator for TensorMapIter<'a> {
544 type Item = (&'a [LabelValue], TensorBlockRef<'a>);
545
546 #[inline]
547 fn next(&mut self) -> Option<Self::Item> {
548 self.inner.next()
549 }
550
551 fn size_hint(&self) -> (usize, Option<usize>) {
552 self.inner.size_hint()
553 }
554}
555
556impl ExactSizeIterator for TensorMapIter<'_> {
557 #[inline]
558 fn len(&self) -> usize {
559 self.inner.len()
560 }
561}
562
563impl FusedIterator for TensorMapIter<'_> {}
564
565impl<'a> IntoIterator for &'a TensorMap {
566 type Item = (&'a [LabelValue], TensorBlockRef<'a>);
567
568 type IntoIter = TensorMapIter<'a>;
569
570 fn into_iter(self) -> Self::IntoIter {
571 self.iter()
572 }
573}
574
575pub struct TensorMapIterMut<'a> {
580 inner: std::iter::Zip<crate::labels::LabelsIter<'a>, std::vec::IntoIter<TensorBlockRefMut<'a>>>
581}
582
583impl<'a> Iterator for TensorMapIterMut<'a> {
584 type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);
585
586 #[inline]
587 fn next(&mut self) -> Option<Self::Item> {
588 self.inner.next()
589 }
590
591 fn size_hint(&self) -> (usize, Option<usize>) {
592 self.inner.size_hint()
593 }
594}
595
596impl ExactSizeIterator for TensorMapIterMut<'_> {
597 #[inline]
598 fn len(&self) -> usize {
599 self.inner.len()
600 }
601}
602
603impl FusedIterator for TensorMapIterMut<'_> {}
604
605impl<'a> IntoIterator for &'a mut TensorMap {
606 type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);
607
608 type IntoIter = TensorMapIterMut<'a>;
609
610 fn into_iter(self) -> Self::IntoIter {
611 self.iter_mut()
612 }
613}
614
615
/// Parallel iterator over the `(key, block)` pairs of a [`TensorMap`],
/// created by [`TensorMap::par_iter`]. Only available with the `rayon`
/// feature.
#[cfg(feature = "rayon")]
pub struct TensorMapParIter<'a> {
    inner: rayon::iter::ZipEq<crate::labels::LabelsParIter<'a>, rayon::vec::IntoIter<TensorBlockRef<'a>>>,
}

#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for TensorMapParIter<'a> {
    type Item = (&'a [LabelValue], TensorBlockRef<'a>);

    #[inline]
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item> {
        rayon::iter::ParallelIterator::drive_unindexed(self.inner, consumer)
    }
}

#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for TensorMapParIter<'_> {
    #[inline]
    fn len(&self) -> usize {
        rayon::iter::IndexedParallelIterator::len(&self.inner)
    }

    #[inline]
    fn drive<C: rayon::iter::plumbing::Consumer<Self::Item>>(self, consumer: C) -> C::Result {
        rayon::iter::IndexedParallelIterator::drive(self.inner, consumer)
    }

    #[inline]
    fn with_producer<CB: rayon::iter::plumbing::ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
        rayon::iter::IndexedParallelIterator::with_producer(self.inner, callback)
    }
}
653
/// Parallel iterator over `(key, mutable block)` pairs of a [`TensorMap`],
/// created by [`TensorMap::par_iter_mut`]. Only available with the `rayon`
/// feature.
#[cfg(feature = "rayon")]
pub struct TensorMapParIterMut<'a> {
    inner: rayon::iter::ZipEq<crate::labels::LabelsParIter<'a>, rayon::vec::IntoIter<TensorBlockRefMut<'a>>>,
}

#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for TensorMapParIterMut<'a> {
    type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);

    #[inline]
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item> {
        rayon::iter::ParallelIterator::drive_unindexed(self.inner, consumer)
    }
}

#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for TensorMapParIterMut<'_> {
    #[inline]
    fn len(&self) -> usize {
        rayon::iter::IndexedParallelIterator::len(&self.inner)
    }

    #[inline]
    fn drive<C: rayon::iter::plumbing::Consumer<Self::Item>>(self, consumer: C) -> C::Result {
        rayon::iter::IndexedParallelIterator::drive(self.inner, consumer)
    }

    #[inline]
    fn with_producer<CB: rayon::iter::plumbing::ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
        rayon::iter::IndexedParallelIterator::with_producer(self.inner, callback)
    }
}
692
693pub struct TensorMapInfoIter<'a> {
697 keys: Vec<&'a str>,
698 tensor: &'a TensorMap,
699 index: usize,
700 count: usize,
701}
702
703impl<'a> Iterator for TensorMapInfoIter<'a> {
704 type Item = (&'a str, &'a str);
705
706 #[inline]
707 fn next(&mut self) -> Option<Self::Item> {
708 if self.index >= self.count {
709 return None;
710 }
711 let key = self.keys[self.index];
712 let value = self.tensor.get_info(key).expect("missing info");
713 self.index += 1;
714 return Some((key, value));
715 }
716
717 fn size_hint(&self) -> (usize, Option<usize>) {
718 (self.count, Some(self.count))
719 }
720}
721
722impl ExactSizeIterator for TensorMapInfoIter<'_> {
723 #[inline]
724 fn len(&self) -> usize {
725 self.count
726 }
727}
728
729impl FusedIterator for TensorMapInfoIter<'_> {}
730
731
/// Unit tests for `TensorMap`: block access, sequential and parallel
/// iteration, info metadata, device/dtype queries and raw pointer round-trips.
#[cfg(test)]
// exact float comparisons are fine here: all values are small integers
// stored as f64
#[allow(clippy::float_cmp)]
mod tests {
    use crate::{Labels, TensorBlock, TensorMap};

    /// Build a tensor map with three blocks. Each block is filled with a
    /// constant equal to the first dimension ("key") of its key: 1.0 for
    /// [1, 0], 3.0 for [3, 1] and -4.0 for [-4, 0]. The iteration tests rely
    /// on this.
    fn test_tensor() -> TensorMap {
        let block_1 = TensorBlock::new(
            ndarray::Array::from_elem(vec![2, 3], 1.0),
            &Labels::new(["samples"], [[0], [1]]),
            &[],
            &Labels::new(["properties"], [[-2], [0], [1]]),
        ).unwrap();

        let block_2 = TensorBlock::new(
            ndarray::Array::from_elem(vec![1, 1], 3.0),
            &Labels::new(["samples"], [[1]]),
            &[],
            &Labels::new(["properties"], [[1]]),
        ).unwrap();

        let block_3 = TensorBlock::new(
            ndarray::Array::from_elem(vec![3, 2], -4.0),
            &Labels::new(["samples"], [[0], [1], [3]]),
            &[],
            &Labels::new(["properties"], [[-2], [1]]),
        ).unwrap();

        return TensorMap::new(
            Labels::new(["key", "other"], [[1, 0], [3, 1], [-4, 0]]),
            vec![block_1, block_2, block_3],
        ).unwrap();
    }

    /// Every way of accessing blocks returns the expected block shapes.
    #[test]
    fn block_access() {
        let mut tensor = test_tensor();

        let block = tensor.block_by_id(1);
        assert_eq!(block.values().shape().unwrap(), [1, 1]);

        let block = tensor.block_mut_by_id(2);
        assert_eq!(block.values().shape().unwrap(), [3, 2]);

        // selects the block with key = 1, i.e. the first block
        let selection = Labels::new(["key"], [[1]]);

        let block = tensor.block(&selection).unwrap();
        {
            // scoped so the read lock is released at the end of the block
            let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
            assert_eq!(values.shape(), [2, 3]);
        }

        let blocks = tensor.blocks();
        assert_eq!(blocks[0].values().shape().unwrap(), [2, 3]);
        assert_eq!(blocks[1].values().shape().unwrap(), [1, 1]);
        assert_eq!(blocks[2].values().shape().unwrap(), [3, 2]);

        let blocks = tensor.blocks_mut();
        assert_eq!(blocks[0].values().shape().unwrap(), [2, 3]);
        assert_eq!(blocks[1].values().shape().unwrap(), [1, 1]);
        assert_eq!(blocks[2].values().shape().unwrap(), [3, 2]);
    }

    /// Keys and blocks are paired correctly when iterating, both immutably
    /// and mutably.
    #[test]
    fn iter() {
        let mut tensor = test_tensor();

        for (key, block) in &tensor {
            // each block is filled with the value of the first key dimension
            let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
            assert_eq!(values[[0, 0]], f64::from(key[0].i32()));
        }

        for (key, mut block) in &mut tensor {
            // mutations through the iterator are visible in the block data
            let array = block.values_mut().get_ndarray_mut::<f64>();
            *array *= 2.0;
            assert_eq!(array[[0, 0]], 2.0 * f64::from(key[0].i32()));
        }
    }

    /// Same checks as `iter`, using the rayon parallel iterators.
    #[cfg(feature = "rayon")]
    #[test]
    fn par_iter() {
        use rayon::iter::ParallelIterator;

        let mut tensor = test_tensor();

        tensor.par_iter().for_each(|(key, block)| {
            // each block is filled with the value of the first key dimension
            let values = block.values().to_ndarray_lock::<f64>().read().unwrap();
            assert_eq!(values[[0, 0]], f64::from(key[0].i32()));
        });

        tensor.par_iter_mut().for_each(|(key, mut block)| {
            // mutations through the iterator are visible in the block data
            let array = block.values_mut().get_ndarray_mut::<f64>();
            *array *= 2.0;
            assert_eq!(array[[0, 0]], 2.0 * f64::from(key[0].i32()));
        });
    }

    /// Info entries can be set, looked up individually and iterated over.
    #[test]
    fn info() {
        let mut tensor = test_tensor();
        tensor.set_info("creator", "unit test");
        tensor.set_info("version", "1.0");

        assert_eq!(tensor.get_info("creator").unwrap(), "unit test");
        assert_eq!(tensor.get_info("version").unwrap(), "1.0");
        assert!(tensor.get_info("missing").is_none());

        // NOTE: insertion order and alphabetical order agree for these keys,
        // so this does not tell which ordering `info()` follows
        let mut info_iter = tensor.info();
        let (key, value) = info_iter.next().unwrap();
        assert_eq!(key, "creator");
        assert_eq!(value, "unit test");
        let (key, value) = info_iter.next().unwrap();
        assert_eq!(key, "version");
        assert_eq!(value, "1.0");
        assert!(info_iter.next().is_none());
    }

    /// The test tensor (built from f64 ndarray data) reports a CPU device and
    /// a 64-bit float dtype.
    #[test]
    fn device_and_dtype() {
        let tensor = test_tensor();

        let device = tensor.device().unwrap();
        assert_eq!(device.device_type, dlpk::sys::DLDeviceType::kDLCPU);

        let dtype = tensor.dtype().unwrap();
        assert_eq!(dtype.code, dlpk::sys::DLDataTypeCode::kDLFloat);
        assert_eq!(dtype.bits, 64);
    }

    /// A pointer extracted with `into_raw` can be wrapped again with
    /// `from_raw`, keeping the same keys.
    #[test]
    fn tensor_map_into_raw() {
        let tensor = test_tensor();
        let raw = TensorMap::into_raw(tensor);

        let recovered = unsafe { TensorMap::from_raw(raw) };
        assert_eq!(
            recovered.keys(),
            &Labels::new(["key", "other"], [[1, 0], [3, 1], [-4, 0]])
        );
    }
}