1use std::ffi::CString;
2use std::iter::FusedIterator;
3
4use crate::block::TensorBlockRefMut;
5use crate::c_api::{mts_tensormap_t, mts_labels_t};
6
7use crate::errors::{check_status, check_ptr};
8use crate::{Error, TensorBlock, TensorBlockRef, Labels, LabelValue};
9
/// Owning handle to a tensor map allocated by the metatensor C library.
///
/// The underlying `mts_tensormap_t` is released with `mts_tensormap_free`
/// when this value is dropped (see the `Drop` implementation).
pub struct TensorMap {
    /// Raw pointer to the C-side tensor map, owned by this struct. It is set
    /// to NULL by `TensorMap::into_raw` when ownership is handed out.
    pub(crate) ptr: *mut mts_tensormap_t,
    /// Keys of the map, fetched from the C library in `TensorMap::from_raw`
    /// and cached here so `keys()` can hand out a plain reference.
    keys: Labels,
}
24
// SAFETY: `TensorMap` is the unique owner of its `mts_tensormap_t`, so moving
// the value to another thread moves that ownership along with it.
// NOTE(review): this assumes the C library allows a tensor map to be freed
// from a thread other than the one that created it — confirm.
unsafe impl Send for TensorMap {}
// SAFETY: every method that hands out mutable access to blocks takes
// `&mut self`, so shared references only reach read-only paths.
// NOTE(review): this relies on concurrent read-only calls into the C library
// (e.g. `mts_tensormap_block_by_id`) being thread-safe — confirm.
unsafe impl Sync for TensorMap {}
29
30impl std::fmt::Debug for TensorMap {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 use crate::labels::pretty_print_labels;
33 writeln!(f, "Tensormap @ {:p} {{", self.ptr)?;
34
35 write!(f, " keys: ")?;
36 pretty_print_labels(self.keys(), " ", f)?;
37 writeln!(f, "}}")
38 }
39}
40
41impl std::ops::Drop for TensorMap {
42 #[allow(unused_must_use)]
43 fn drop(&mut self) {
44 unsafe {
45 crate::c_api::mts_tensormap_free(self.ptr);
46 }
47 }
48}
49
50impl TensorMap {
51 #[allow(clippy::needless_pass_by_value)]
57 #[inline]
58 pub fn new(keys: Labels, mut blocks: Vec<TensorBlock>) -> Result<TensorMap, Error> {
59 let ptr = unsafe {
60 crate::c_api::mts_tensormap(
61 keys.as_mts_labels_t(),
62 blocks.as_mut_ptr().cast::<*mut crate::c_api::mts_block_t>(),
66 blocks.len()
67 )
68 };
69
70 for block in blocks {
71 std::mem::forget(block);
74 }
75
76 check_ptr(ptr)?;
77
78 return Ok(unsafe { TensorMap::from_raw(ptr) });
79 }
80
81 pub unsafe fn from_raw(ptr: *mut mts_tensormap_t) -> TensorMap {
91 assert!(!ptr.is_null());
92
93 let mut keys = mts_labels_t::null();
94 check_status(crate::c_api::mts_tensormap_keys(
95 ptr,
96 &mut keys
97 )).expect("failed to get the keys");
98
99 let keys = Labels::from_raw(keys);
100
101 return TensorMap {
102 ptr,
103 keys
104 };
105 }
106
107 pub fn into_raw(mut map: TensorMap) -> *mut mts_tensormap_t {
113 let ptr = map.ptr;
114 map.ptr = std::ptr::null_mut();
115 return ptr;
116 }
117
118 #[inline]
123 pub fn try_clone(&self) -> Result<TensorMap, Error> {
124 let ptr = unsafe {
125 crate::c_api::mts_tensormap_copy(self.ptr)
126 };
127 crate::errors::check_ptr(ptr)?;
128
129 return Ok(unsafe { TensorMap::from_raw(ptr) });
130 }
131
132 pub fn load(path: impl AsRef<std::path::Path>) -> Result<TensorMap, Error> {
136 return crate::io::load(path);
137 }
138
139 pub fn load_buffer(buffer: &[u8]) -> Result<TensorMap, Error> {
143 return crate::io::load_buffer(buffer);
144 }
145
146 pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
150 return crate::io::save(path, self);
151 }
152
153 pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
157 return crate::io::save_buffer(self, buffer);
158 }
159
160 #[inline]
162 pub fn keys(&self) -> &Labels {
163 &self.keys
164 }
165
166 #[inline]
172 pub fn block_by_id(&self, index: usize) -> TensorBlockRef<'_> {
173
174 let mut block = std::ptr::null_mut();
175 unsafe {
176 check_status(crate::c_api::mts_tensormap_block_by_id(
177 self.ptr,
178 &mut block,
179 index,
180 )).expect("failed to get a block");
181 }
182
183 return unsafe { TensorBlockRef::from_raw(block) }
184 }
185
186 #[inline]
192 pub fn block_mut_by_id(&mut self, index: usize) -> TensorBlockRefMut<'_> {
193 return unsafe { TensorMap::raw_block_mut_by_id(self.ptr, index) };
194 }
195
196 #[inline]
208 unsafe fn raw_block_mut_by_id<'a>(ptr: *mut mts_tensormap_t, index: usize) -> TensorBlockRefMut<'a> {
209 let mut block = std::ptr::null_mut();
210
211 check_status(crate::c_api::mts_tensormap_block_by_id(
212 ptr,
213 &mut block,
214 index,
215 )).expect("failed to get a block");
216
217 return TensorBlockRefMut::from_raw(block);
218 }
219
220 #[inline]
226 pub fn blocks_matching(&self, selection: &Labels) -> Result<Vec<usize>, Error> {
227 let mut indexes = vec![0; self.keys().count()];
228 let mut matching = indexes.len();
229 unsafe {
230 check_status(crate::c_api::mts_tensormap_blocks_matching(
231 self.ptr,
232 indexes.as_mut_ptr(),
233 &mut matching,
234 selection.as_mts_labels_t(),
235 ))?;
236 }
237 indexes.resize(matching, 0);
238
239 return Ok(indexes);
240 }
241
242 #[inline]
247 pub fn block_matching(&self, selection: &Labels) -> Result<usize, Error> {
248 let matching = self.blocks_matching(selection)?;
249 if matching.len() != 1 {
250 let selection_str = selection.names()
251 .iter().zip(&selection[0])
252 .map(|(name, value)| format!("{} = {}", name, value))
253 .collect::<Vec<_>>()
254 .join(", ");
255
256
257 if matching.is_empty() {
258 return Err(Error {
259 code: None,
260 message: format!(
261 "no blocks matched the selection ({})",
262 selection_str
263 ),
264 });
265 } else {
266 return Err(Error {
267 code: None,
268 message: format!(
269 "{} blocks matched the selection ({}), expected only one",
270 matching.len(),
271 selection_str
272 ),
273 });
274 }
275 }
276
277 return Ok(matching[0])
278 }
279
280 #[inline]
285 pub fn block(&self, selection: &Labels) -> Result<TensorBlockRef<'_>, Error> {
286 let id = self.block_matching(selection)?;
287 return Ok(self.block_by_id(id));
288 }
289
290 #[inline]
292 pub fn blocks(&self) -> Vec<TensorBlockRef<'_>> {
293 let mut blocks = Vec::new();
294 for i in 0..self.keys().count() {
295 blocks.push(self.block_by_id(i));
296 }
297 return blocks;
298 }
299
300 #[inline]
302 pub fn blocks_mut(&mut self) -> Vec<TensorBlockRefMut<'_>> {
303 let mut blocks = Vec::new();
304 for i in 0..self.keys().count() {
305 blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
306 }
307 return blocks;
308 }
309
310 #[inline]
330 pub fn keys_to_samples(&self, keys_to_move: &Labels, sort_samples: bool) -> Result<TensorMap, Error> {
331 let ptr = unsafe {
332 crate::c_api::mts_tensormap_keys_to_samples(
333 self.ptr,
334 keys_to_move.as_mts_labels_t(),
335 sort_samples,
336 )
337 };
338
339 check_ptr(ptr)?;
340 return Ok(unsafe { TensorMap::from_raw(ptr) });
341 }
342
343 #[inline]
370 pub fn keys_to_properties(&self, keys_to_move: &Labels, sort_samples: bool) -> Result<TensorMap, Error> {
371 let ptr = unsafe {
372 crate::c_api::mts_tensormap_keys_to_properties(
373 self.ptr,
374 keys_to_move.as_mts_labels_t(),
375 sort_samples,
376 )
377 };
378
379 check_ptr(ptr)?;
380 return Ok(unsafe { TensorMap::from_raw(ptr) });
381 }
382
383 #[inline]
386 pub fn components_to_properties(&self, dimensions: &[&str]) -> Result<TensorMap, Error> {
387 let dimensions_c = dimensions.iter()
388 .map(|&v| CString::new(v).expect("unexpected NULL byte"))
389 .collect::<Vec<_>>();
390
391 let dimensions_ptr = dimensions_c.iter()
392 .map(|v| v.as_ptr())
393 .collect::<Vec<_>>();
394
395
396 let ptr = unsafe {
397 crate::c_api::mts_tensormap_components_to_properties(
398 self.ptr,
399 dimensions_ptr.as_ptr(),
400 dimensions.len(),
401 )
402 };
403
404 check_ptr(ptr)?;
405 return Ok(unsafe { TensorMap::from_raw(ptr) });
406 }
407
408 #[inline]
410 pub fn iter(&self) -> TensorMapIter<'_> {
411 return TensorMapIter {
412 inner: self.keys().iter().zip(self.blocks())
413 };
414 }
415
416 #[inline]
419 pub fn iter_mut(&mut self) -> TensorMapIterMut<'_> {
420 let mut blocks = Vec::new();
423 for i in 0..self.keys().count() {
424 blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
425 }
426
427 return TensorMapIterMut {
428 inner: self.keys().into_iter().zip(blocks)
429 };
430 }
431
432 #[cfg(feature = "rayon")]
434 #[inline]
435 pub fn par_iter(&self) -> TensorMapParIter {
436 use rayon::prelude::*;
437 TensorMapParIter {
438 inner: self.keys().par_iter().zip_eq(self.blocks().into_par_iter())
439 }
440 }
441
442 #[cfg(feature = "rayon")]
445 #[inline]
446 pub fn par_iter_mut(&mut self) -> TensorMapParIterMut {
447 use rayon::prelude::*;
448
449 let mut blocks = Vec::new();
452 for i in 0..self.keys().count() {
453 blocks.push(unsafe { TensorMap::raw_block_mut_by_id(self.ptr, i) });
454 }
455
456 TensorMapParIterMut {
457 inner: self.keys().par_iter().zip_eq(blocks)
458 }
459 }
460}
461
462pub struct TensorMapIter<'a> {
466 inner: std::iter::Zip<crate::labels::LabelsIter<'a>, std::vec::IntoIter<TensorBlockRef<'a>>>
467}
468
469impl<'a> Iterator for TensorMapIter<'a> {
470 type Item = (&'a [LabelValue], TensorBlockRef<'a>);
471
472 #[inline]
473 fn next(&mut self) -> Option<Self::Item> {
474 self.inner.next()
475 }
476
477 fn size_hint(&self) -> (usize, Option<usize>) {
478 self.inner.size_hint()
479 }
480}
481
482impl ExactSizeIterator for TensorMapIter<'_> {
483 #[inline]
484 fn len(&self) -> usize {
485 self.inner.len()
486 }
487}
488
489impl FusedIterator for TensorMapIter<'_> {}
490
491impl<'a> IntoIterator for &'a TensorMap {
492 type Item = (&'a [LabelValue], TensorBlockRef<'a>);
493
494 type IntoIter = TensorMapIter<'a>;
495
496 fn into_iter(self) -> Self::IntoIter {
497 self.iter()
498 }
499}
500
501pub struct TensorMapIterMut<'a> {
506 inner: std::iter::Zip<crate::labels::LabelsIter<'a>, std::vec::IntoIter<TensorBlockRefMut<'a>>>
507}
508
509impl<'a> Iterator for TensorMapIterMut<'a> {
510 type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);
511
512 #[inline]
513 fn next(&mut self) -> Option<Self::Item> {
514 self.inner.next()
515 }
516
517 fn size_hint(&self) -> (usize, Option<usize>) {
518 self.inner.size_hint()
519 }
520}
521
522impl ExactSizeIterator for TensorMapIterMut<'_> {
523 #[inline]
524 fn len(&self) -> usize {
525 self.inner.len()
526 }
527}
528
529impl FusedIterator for TensorMapIterMut<'_> {}
530
531impl<'a> IntoIterator for &'a mut TensorMap {
532 type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);
533
534 type IntoIter = TensorMapIterMut<'a>;
535
536 fn into_iter(self) -> Self::IntoIter {
537 self.iter_mut()
538 }
539}
540
541
#[cfg(feature = "rayon")]
/// Parallel iterator over the `(key, block)` pairs of a [`TensorMap`],
/// created by `TensorMap::par_iter`.
pub struct TensorMapParIter<'a> {
    inner: rayon::iter::ZipEq<crate::labels::LabelsParIter<'a>, rayon::vec::IntoIter<TensorBlockRef<'a>>>,
}

#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for TensorMapParIter<'a> {
    type Item = (&'a [LabelValue], TensorBlockRef<'a>);

    // Trait methods are called with fully-qualified syntax so this impl does
    // not depend on the rayon traits being imported in this module.
    #[inline]
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        rayon::iter::ParallelIterator::drive_unindexed(self.inner, consumer)
    }

    /// Report the length known in advance. This iterator is indexed, so
    /// rayon can then take its indexed fast paths (e.g. collecting directly
    /// into a pre-sized `Vec`). With the default `None`, it could not.
    /// Delegating keeps this consistent with how the inner `ZipEq` drives
    /// consumers.
    #[inline]
    fn opt_len(&self) -> Option<usize> {
        rayon::iter::ParallelIterator::opt_len(&self.inner)
    }
}

#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for TensorMapParIter<'_> {
    #[inline]
    fn len(&self) -> usize {
        rayon::iter::IndexedParallelIterator::len(&self.inner)
    }

    #[inline]
    fn drive<C: rayon::iter::plumbing::Consumer<Self::Item>>(self, consumer: C) -> C::Result {
        rayon::iter::IndexedParallelIterator::drive(self.inner, consumer)
    }

    #[inline]
    fn with_producer<CB: rayon::iter::plumbing::ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
        rayon::iter::IndexedParallelIterator::with_producer(self.inner, callback)
    }
}
579
#[cfg(feature = "rayon")]
/// Parallel iterator over the `(key, mutable block)` pairs of a
/// [`TensorMap`], created by `TensorMap::par_iter_mut`.
pub struct TensorMapParIterMut<'a> {
    inner: rayon::iter::ZipEq<crate::labels::LabelsParIter<'a>, rayon::vec::IntoIter<TensorBlockRefMut<'a>>>,
}

#[cfg(feature = "rayon")]
impl<'a> rayon::iter::ParallelIterator for TensorMapParIterMut<'a> {
    type Item = (&'a [LabelValue], TensorBlockRefMut<'a>);

    // Trait methods are called with fully-qualified syntax so this impl does
    // not depend on the rayon traits being imported in this module.
    #[inline]
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        rayon::iter::ParallelIterator::drive_unindexed(self.inner, consumer)
    }

    /// Report the length known in advance. This iterator is indexed, so
    /// rayon can then take its indexed fast paths (e.g. collecting directly
    /// into a pre-sized `Vec`). With the default `None`, it could not.
    /// Delegating keeps this consistent with how the inner `ZipEq` drives
    /// consumers.
    #[inline]
    fn opt_len(&self) -> Option<usize> {
        rayon::iter::ParallelIterator::opt_len(&self.inner)
    }
}

#[cfg(feature = "rayon")]
impl rayon::iter::IndexedParallelIterator for TensorMapParIterMut<'_> {
    #[inline]
    fn len(&self) -> usize {
        rayon::iter::IndexedParallelIterator::len(&self.inner)
    }

    #[inline]
    fn drive<C: rayon::iter::plumbing::Consumer<Self::Item>>(self, consumer: C) -> C::Result {
        rayon::iter::IndexedParallelIterator::drive(self.inner, consumer)
    }

    #[inline]
    fn with_producer<CB: rayon::iter::plumbing::ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
        rayon::iter::IndexedParallelIterator::with_producer(self.inner, callback)
    }
}
618
#[cfg(test)]
mod tests {
    use crate::{Labels, TensorBlock, TensorMap};

    /// Iterating over `&TensorMap` and `&mut TensorMap` yields each key paired
    /// with its block. Every block here is filled with its own key value, so
    /// the first entry of each block can be checked against the key.
    #[test]
    #[allow(clippy::cast_lossless, clippy::float_cmp)]
    fn iter() {
        let first = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![2, 3], 1.0),
            &Labels::new(["samples"], &[[0], [1]]),
            &[],
            &Labels::new(["properties"], &[[-2], [0], [1]]),
        ).unwrap();

        let second = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![1, 1], 3.0),
            &Labels::new(["samples"], &[[1]]),
            &[],
            &Labels::new(["properties"], &[[1]]),
        ).unwrap();

        let third = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![3, 2], -4.0),
            &Labels::new(["samples"], &[[0], [1], [3]]),
            &[],
            &Labels::new(["properties"], &[[-2], [1]]),
        ).unwrap();

        let mut tensor = TensorMap::new(
            Labels::new(["key"], &[[1], [3], [-4]]),
            vec![first, second, third],
        ).unwrap();

        // shared iteration: each block holds its key's value
        for (key, block) in &tensor {
            let expected = key[0].i32() as f64;
            assert_eq!(block.values().to_array()[[0, 0]], expected);
        }

        // mutable iteration: doubling the values is visible immediately
        for (key, mut block) in &mut tensor {
            let values = block.values_mut().to_array_mut();
            *values *= 2.0;
            let expected = 2.0 * (key[0].i32() as f64);
            assert_eq!(values[[0, 0]], expected);
        }
    }
}
666}