Augmentation
- metatrain.utils.augmentation.get_random_rotation() → Rotation
Return a random 3D rotation, sampled uniformly from SO(3) with respect to the Haar measure.
- Returns:
a random 3D rotation
- Return type:
Rotation
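The docstring does not say how the Haar-uniform sample is drawn. One standard method is shown below as a minimal pure-Python sketch, which is not necessarily what metatrain does. A 4D standard Gaussian vector is normalized into a unit quaternion, which makes it uniform on the 3-sphere, and then converted to a 3×3 rotation matrix. The helper name `random_rotation_matrix` is hypothetical. If the `Rotation` return type is SciPy's `scipy.spatial.transform.Rotation`, the same distribution is available directly as `Rotation.random()`.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def random_rotation_matrix(rng: random.Random) -> Matrix:
    """Sample a 3x3 rotation matrix uniformly (Haar measure) over SO(3).

    A normalized 4D standard Gaussian vector is uniform on the unit 3-sphere.
    The unit quaternions q and -q give the same rotation, so the induced
    rotation is Haar-uniform on SO(3).
    """
    w, x, y, z = (rng.gauss(0.0, 1.0) for _ in range(4))
    norm = math.sqrt(w * w + x * x + y * y + z * z)
    w, x, y, z = w / norm, x / norm, y / norm, z / norm
    # Standard conversion from a unit quaternion (w, x, y, z) to a rotation matrix
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]
```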
- metatrain.utils.augmentation.get_random_inversion() → int
Randomly choose an inversion factor -1 or 1.
- Returns:
either -1 or 1
- Return type:
int
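Below is a minimal sketch of the inversion factor and its effect on Cartesian positions. The helper names `random_inversion` and `invert_positions` are hypothetical. When combined with a rotation matrix R, the factor s gives sR. This is an orthogonal matrix with determinant s³·det R = s, so s = -1 yields an improper rotation.

```python
import random
from typing import List


def random_inversion(rng: random.Random) -> int:
    """Return -1 (apply spatial inversion) or 1 (identity) with equal probability."""
    return rng.choice((-1, 1))


def invert_positions(positions: List[List[float]], factor: int) -> List[List[float]]:
    """Multiply every Cartesian coordinate by the inversion factor."""
    return [[factor * c for c in p] for p in positions]
```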
- class metatrain.utils.augmentation.RotationalAugmenter(target_info_dict: Dict[str, TargetInfo], extra_data_info_dict: Dict[str, TargetInfo] | None = None)
Bases: object
A class to apply random rotations and inversions to a set of systems and their targets.
- Parameters:
  - target_info_dict (Dict[str, TargetInfo]) – A dictionary mapping target names to their corresponding TargetInfo objects. It is used to determine the type of each target and how to apply the augmentations to it.
  - extra_data_info_dict (Dict[str, TargetInfo] | None) – An optional dictionary mapping extra data names to their corresponding TargetInfo objects. It is used to determine the type of each extra data entry and how to apply the augmentations to it.
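The constructor only says that TargetInfo determines "the type of targets and how to apply the augmentations". The sketch below illustrates the usual Cartesian transformation rules under a rotation R and inversion factor s:

- Scalars such as energies are invariant.
- Polar vectors such as forces map to sRv.
- Rank-2 tensors such as stress map to RTRᵀ, because s² = 1.

The string `kind` is a hypothetical stand-in for whatever TargetInfo actually encodes. Spherical and pseudotensor targets are deliberately left out.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(R: Matrix, v: Vector) -> Vector:
    """3x3 matrix times 3-vector."""
    return [sum(R[i][k] * v[k] for k in range(3)) for i in range(3)]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    """3x3 matrix product A @ B."""
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def transpose(A: Matrix) -> Matrix:
    """Transpose of a 3x3 matrix."""
    return [[A[j][i] for j in range(3)] for i in range(3)]


def transform_target(kind: str, value, R: Matrix, s: int):
    """Apply rotation R and inversion factor s (+1 or -1) to a single target value.

    `kind` is a hypothetical label standing in for the information that
    TargetInfo provides; it is not part of metatrain's API.
    """
    if kind == "scalar":
        # e.g. an energy: invariant under rotations and inversion
        return value
    if kind == "vector":
        # e.g. a force or dipole (polar vector): v -> s R v
        return [s * c for c in matvec(R, value)]
    if kind == "rank2":
        # e.g. a stress tensor: (sR) T (sR)^T = R T R^T, since s * s == 1
        return matmul(matmul(R, value), transpose(R))
    raise ValueError(f"unsupported target kind: {kind!r}")
```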
- apply_random_augmentations(systems: List[System], targets: Dict[str, TensorMap], extra_data: Dict[str, TensorMap] | None = None) → Tuple[List[System], Dict[str, TensorMap], Dict[str, TensorMap]]
Apply a random augmentation to a list of System objects and their targets.
- Parameters:
  - systems (List[System]) – A list of System objects to be augmented.
  - targets (Dict[str, TensorMap]) – A dictionary mapping target names to their corresponding TensorMap objects. These are the targets to be augmented.
  - extra_data (Dict[str, TensorMap] | None) – An optional dictionary mapping extra data names to their corresponding TensorMap objects. If provided, the extra data is augmented as well.
- Returns:
A tuple containing the augmented systems, targets, and extra data.
- Return type:
Tuple[List[System], Dict[str, TensorMap], Dict[str, TensorMap]]
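The following pure-Python sketch mimics the shape of `apply_random_augmentations`. It uses plain lists and dicts in place of System and TensorMap objects, so it runs without metatrain installed. Each system receives its own rotation and inversion factor. This is an assumption, since the docstring does not say whether one transformation is shared across the batch. Positions, cell vectors, and per-atom vector targets are all transformed by the same matrix sR. The names `augment_batch` and `kinds`, and the dict layout, are hypothetical.

```python
import math
import random
from typing import Dict, List, Optional, Tuple

Vector = List[float]
Matrix = List[List[float]]


def random_rotation(rng: random.Random) -> Matrix:
    """Haar-uniform rotation matrix from a normalized 4D Gaussian (unit quaternion)."""
    w, x, y, z = (rng.gauss(0.0, 1.0) for _ in range(4))
    n = math.sqrt(w * w + x * x + y * y + z * z)
    w, x, y, z = w / n, x / n, y / n, z / n
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]


def apply(M: Matrix, v: Vector) -> Vector:
    """Matrix-vector product M @ v for 3x3 matrices."""
    return [sum(M[i][k] * v[k] for k in range(3)) for i in range(3)]


def augment_batch(
    systems: List[dict],
    targets: Dict[str, list],
    kinds: Dict[str, str],
    extra_data: Optional[Dict[str, list]] = None,
    rng: Optional[random.Random] = None,
) -> Tuple[List[dict], Dict[str, list], Dict[str, list]]:
    """Hypothetical stand-in for apply_random_augmentations on plain Python data.

    targets/extra_data map a name to a list with one entry per system;
    kinds maps each name to "scalar" or "per_atom_vector".
    """
    rng = rng if rng is not None else random.Random()
    extra_data = extra_data if extra_data is not None else {}
    new_systems: List[dict] = []
    new_targets: Dict[str, list] = {name: [] for name in targets}
    new_extra: Dict[str, list] = {name: [] for name in extra_data}

    for idx, system in enumerate(systems):
        # One independent rotation and inversion factor per system (assumption)
        R = random_rotation(rng)
        s = rng.choice((-1, 1))
        M = [[s * R[i][j] for j in range(3)] for i in range(3)]  # element of O(3)

        new_systems.append({
            "positions": [apply(M, p) for p in system["positions"]],
            # Each row of the cell is a lattice vector and transforms like a position
            "cell": [apply(M, a) for a in system["cell"]],
        })

        for source, dest in ((targets, new_targets), (extra_data, new_extra)):
            for name, per_system in source.items():
                value = per_system[idx]
                if kinds[name] == "scalar":
                    dest[name].append(value)  # invariant under O(3)
                elif kinds[name] == "per_atom_vector":
                    dest[name].append([apply(M, f) for f in value])
                else:
                    raise ValueError(f"unsupported kind for {name!r}: {kinds[name]}")

    return new_systems, new_targets, new_extra
```

Positions and forces share the same orthogonal matrix, so interatomic distances, force magnitudes, and dot products such as p·f are unchanged. An invariant energy label therefore stays consistent with the transformed structure.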