Augmentation

metatrain.utils.augmentation.get_random_rotation() → Rotation

Generate a random 3D rotation that is uniformly distributed over SO(3) with respect to the Haar measure.

Returns:

a random 3D rotation

Return type:

Rotation
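Haar-uniform means that every orientation is equally likely. The pure-Python sketch below illustrates that distribution only and is not metatrain's own implementation. It draws a uniform random unit quaternion with Shoemake's method and converts it to a 3×3 rotation matrix. The helper names `random_rotation_matrix` and `det3` are hypothetical. If the returned Rotation is SciPy's `scipy.spatial.transform.Rotation` (an assumption here), `Rotation.random()` samples from the same distribution.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def random_rotation_matrix(rng: random.Random) -> Matrix:
    """Haar-uniform rotation via Shoemake's uniform random unit quaternion.

    Illustrative sketch only, not metatrain's implementation.
    """
    u1, u2, u3 = rng.random(), rng.random(), rng.random()
    # Unit quaternion (x, y, z, w) distributed uniformly on the 3-sphere
    a, b = math.sqrt(1.0 - u1), math.sqrt(u1)
    x, y = a * math.sin(2.0 * math.pi * u2), a * math.cos(2.0 * math.pi * u2)
    z, w = b * math.sin(2.0 * math.pi * u3), b * math.cos(2.0 * math.pi * u3)
    # Standard conversion of a unit quaternion to a 3x3 rotation matrix
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ]


def det3(m: Matrix) -> float:
    """Determinant of a 3x3 matrix by cofactor expansion along the first row."""
    return (
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
        - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
        + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    )


R = random_rotation_matrix(random.Random(0))
# A proper rotation is orthogonal (R R^T = I) with determinant +1
print(round(det3(R), 6))  # prints 1.0
```

Sampling quaternions rather than, for example, three independent uniform Euler angles matters here: uniform Euler angles do not give a uniform distribution over orientations.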

metatrain.utils.augmentation.get_random_inversion() → int

Randomly choose an inversion factor, either -1 (apply an inversion) or 1 (no inversion).

Returns:

either -1 or 1

Return type:

int
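Multiplying a proper rotation R by the inversion factor gives factor·R, whose determinant equals the factor. Pairing a random rotation with a random inversion factor therefore reaches both proper and improper elements of the orthogonal group O(3). The following is a minimal pure-Python illustration. `random_inversion`, `rotation_about_z`, `scale`, and `det3` are hypothetical helpers, and the equal probability of -1 and 1 is an assumption of this sketch.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def random_inversion(rng: random.Random) -> int:
    """Hypothetical helper: pick -1 or 1, assumed equally likely here."""
    return rng.choice([-1, 1])


def rotation_about_z(angle: float) -> Matrix:
    """A proper rotation (determinant +1) about the z axis."""
    c, s = math.cos(angle), math.sin(angle)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def scale(m: Matrix, factor: int) -> Matrix:
    """Return factor * m, i.e. the rotation combined with the inversion factor."""
    return [[factor * v for v in row] for row in m]


def det3(m: Matrix) -> float:
    """Determinant of a 3x3 matrix by cofactor expansion along the first row."""
    return (
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
        - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
        + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    )


rng = random.Random(1)
factor = random_inversion(rng)
Q = scale(rotation_about_z(0.3), factor)
# det(factor * R) = factor**3 * det(R) = factor, so factor = -1 yields an
# improper transformation (a rotation followed by an inversion)
print(factor, round(det3(Q), 6))
```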

class metatrain.utils.augmentation.RotationalAugmenter(target_info_dict: Dict[str, TargetInfo], extra_data_info_dict: Dict[str, TargetInfo] | None = None)

Bases: object

A class to apply random rotations and inversions to a set of systems, their targets, and, optionally, extra data.

Parameters:
  • target_info_dict (Dict[str, TargetInfo]) – A dictionary mapping target names to their corresponding TargetInfo objects. It determines the type of each target and, therefore, how the augmentation is applied to it.

  • extra_data_info_dict (Dict[str, TargetInfo] | None) – An optional dictionary mapping extra data names to their corresponding TargetInfo objects. It determines the type of each extra data entry and, therefore, how the augmentation is applied to it.
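The type information matters because different kinds of quantities transform differently under the same operation. A rotation-invariant scalar such as an energy is left unchanged, while a polar Cartesian vector such as a force maps to factor·R·v. The sketch below is a minimal pure-Python illustration of that dispatch under stated assumptions. The `Kind` enum is a hypothetical stand-in for the information carried by TargetInfo, and `transform_target` is likewise hypothetical; neither is part of metatrain's API.

```python
from enum import Enum
from typing import Dict, List

Matrix = List[List[float]]
Vector = List[float]


class Kind(Enum):
    # Hypothetical stand-in for the target type information in TargetInfo
    SCALAR = "scalar"  # e.g. an energy: invariant under rotation and inversion
    VECTOR = "vector"  # e.g. per-atom forces: polar vectors, v -> factor * R v


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(m[i][j] * v[j] for j in range(3)) for i in range(3)]


def transform_target(kind: Kind, value, rotation: Matrix, factor: int):
    """Apply the O(3) element factor * R to one target according to its kind."""
    if kind is Kind.SCALAR:
        return value
    if kind is Kind.VECTOR:
        # value is a list of 3-vectors, one per atom
        return [[factor * x for x in matvec(rotation, v)] for v in value]
    raise ValueError(f"unsupported target kind: {kind}")


info: Dict[str, Kind] = {"energy": Kind.SCALAR, "forces": Kind.VECTOR}
targets = {"energy": -3.5, "forces": [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]}

# 90-degree rotation about z, combined with an inversion (factor = -1)
R = [[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
augmented = {
    name: transform_target(info[name], value, R, -1)
    for name, value in targets.items()
}
# The energy is unchanged; each force is rotated, then negated
print(augmented)
```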

apply_random_augmentations(systems: List[System], targets: Dict[str, TensorMap], extra_data: Dict[str, TensorMap] | None = None) → Tuple[List[System], Dict[str, TensorMap], Dict[str, TensorMap]]

Apply a random augmentation to a list of System objects and their targets (and to the extra data, if provided).

Parameters:
  • systems (List[System]) – A list of System objects to be augmented.

  • targets (Dict[str, TensorMap]) – A dictionary mapping target names to their corresponding TensorMap objects. These are the targets to be augmented.

  • extra_data (Dict[str, TensorMap] | None) – An optional dictionary mapping extra data names to their corresponding TensorMap objects. This extra data will also be augmented if provided.

Returns:

A tuple containing the augmented systems, targets, and extra data.

Return type:

Tuple[List[System], Dict[str, TensorMap], Dict[str, TensorMap]]
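To make the shape of this call concrete, here is a minimal pure-Python sketch of an augmentation step that returns the same three-part tuple. Everything in it is hypothetical and simplified. `ToySystem` stands in for System and only holds positions and a cell. Targets are plain per-system lists rather than TensorMap objects, and their type is inferred from the Python type instead of TargetInfo. Drawing one independent transformation per system is an assumption. For brevity, rotations are restricted to the z axis, whereas a real augmenter would draw a Haar-uniform rotation.

```python
import math
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

Matrix = List[List[float]]
Vector = List[float]


@dataclass
class ToySystem:
    # Hypothetical stand-in for System: positions and cell only
    positions: List[Vector]
    cell: List[Vector]  # three lattice vectors, one per row


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(m[i][j] * v[j] for j in range(3)) for i in range(3)]


def random_o3(rng: random.Random) -> Tuple[Matrix, int]:
    """Simplified: random rotation about z plus a random inversion factor."""
    angle = rng.uniform(0.0, 2.0 * math.pi)
    c, s = math.cos(angle), math.sin(angle)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], rng.choice([-1, 1])


def apply_op(rotation: Matrix, factor: int, vectors: List[Vector]) -> List[Vector]:
    """Map every vector v to factor * R v."""
    return [[factor * x for x in matvec(rotation, v)] for v in vectors]


def apply_random_augmentations_sketch(
    systems: List[ToySystem],
    targets: Dict[str, List],
    extra_data: Optional[Dict[str, List]] = None,
    rng: Optional[random.Random] = None,
) -> Tuple[List[ToySystem], Dict[str, List], Dict[str, List]]:
    rng = rng if rng is not None else random.Random()
    extra_data = extra_data if extra_data is not None else {}
    new_systems: List[ToySystem] = []
    new_targets: Dict[str, List] = {name: [] for name in targets}
    new_extra: Dict[str, List] = {name: [] for name in extra_data}
    for i, system in enumerate(systems):
        # Assumption: one independent transformation per system
        rotation, factor = random_o3(rng)
        new_systems.append(
            ToySystem(
                positions=apply_op(rotation, factor, system.positions),
                cell=apply_op(rotation, factor, system.cell),
            )
        )
        for src, dst in ((targets, new_targets), (extra_data, new_extra)):
            for name, per_system in src.items():
                value = per_system[i]
                # Scalars (e.g. energies) are invariant; lists of 3-vectors
                # (e.g. forces) transform like the positions
                if isinstance(value, float):
                    dst[name].append(value)
                else:
                    dst[name].append(apply_op(rotation, factor, value))
    return new_systems, new_targets, new_extra


systems = [
    ToySystem(
        positions=[[0.0, 0.0, 0.0], [1.0, 2.0, 2.0]],
        cell=[[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0]],
    )
]
targets = {"energy": [-1.25], "forces": [[[0.5, 0.0, 0.0], [-0.5, 0.0, 0.0]]]}
new_systems, new_targets, new_extra = apply_random_augmentations_sketch(
    systems, targets, rng=random.Random(7)
)
```

Because the positions, the cell, and the vector targets all go through the same orthogonal matrix, interatomic distances and the invariant energy are preserved, so each augmented system remains consistent with its labels.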