Augmentation

metatrain.utils.augmentation.get_random_rotation()
metatrain.utils.augmentation.get_random_inversion()
class metatrain.utils.augmentation.RotationalAugmenter(target_info_dict: Dict[str, TargetInfo])
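The two helpers above have no further description here. To illustrate what they are meant to produce, the following NumPy-only sketch draws a uniformly random 3D rotation matrix and a random inversion sign of +1 or -1. The names `random_rotation_matrix` and `random_inversion` are hypothetical and are not part of metatrain. The real functions may return different types, for example a rotation object instead of a matrix.

```python
import numpy as np


def random_rotation_matrix(rng: np.random.Generator) -> np.ndarray:
    """Draw a 3x3 rotation matrix uniformly from SO(3) (hypothetical helper)."""
    # The QR decomposition of a Gaussian matrix yields an orthogonal Q; fixing
    # the signs of R's diagonal makes Q uniformly (Haar) distributed over O(3).
    q, r = np.linalg.qr(rng.standard_normal((3, 3)))
    q = q * np.sign(np.diag(r))
    # In 3D, multiplying by det(q) (which is +1 or -1) maps O(3) onto SO(3).
    return q * np.linalg.det(q)


def random_inversion(rng: np.random.Generator) -> int:
    """Return +1 (no inversion) or -1 (inversion) with equal probability."""
    return int(rng.choice([1, -1]))
```

Going through QR avoids the bias that comes from naively sampling Euler angles uniformly.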

Bases: object

A class to apply random rotations and inversions to a set of systems and their targets.

Parameters:

target_info_dict (Dict[str, TargetInfo]) – A dictionary mapping target names to their corresponding TargetInfo objects. It is used to determine the type of each target and therefore how the augmentations are applied to it.
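The target type matters because the same rotation or inversion acts differently on different quantities. The following minimal sketch assumes only two kinds of targets: rotation-invariant scalars such as energies, and Cartesian polar vectors such as forces. `transform_target` and the `kind` strings are hypothetical placeholders for the information a TargetInfo encodes. They are not metatrain's API.

```python
import numpy as np


def transform_target(
    values: np.ndarray, kind: str, rotation: np.ndarray, inversion: int
) -> np.ndarray:
    """Apply a rotation (and optional inversion) to one target's values.

    `kind` is a hypothetical stand-in for what a TargetInfo says about how
    the target transforms.
    """
    if kind == "scalar":
        # Invariant quantities (e.g. energies) are unchanged.
        return values.copy()
    if kind == "cartesian_vector":
        # Polar vectors stored as rows of shape (n, 3): v -> inversion * R v.
        return inversion * values @ rotation.T
    raise ValueError(f"unsupported target kind: {kind!r}")
```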

apply_random_augmentations(systems: List[System], targets: Dict[str, TensorMap]) → Tuple[List[System], Dict[str, TensorMap]]

Apply a random augmentation to a list of System objects and their targets.

Parameters:
  • systems (List[System]) – A list of System objects to be augmented.

  • targets (Dict[str, TensorMap]) – A dictionary mapping target names to their corresponding TensorMap objects. These are the targets to be augmented.

Returns:

A tuple containing the augmented systems and targets.

Return type:

Tuple[List[System], Dict[str, TensorMap]]
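The key requirement is that each system and its targets receive the same transformation. The following sketch illustrates this with plain NumPy arrays instead of System and TensorMap objects. The dictionary layout, the `augment` function, the fixed target names "energy" and "forces", and the choice of one rotation/inversion pair per system are all assumptions made for illustration. They are not metatrain's implementation.

```python
from typing import Dict, List, Tuple

import numpy as np

# Hypothetical stand-in for a System: {"positions": (n, 3), "cell": (3, 3)}.
System = Dict[str, np.ndarray]


def augment(
    systems: List[System],
    targets: Dict[str, List[np.ndarray]],
    rotations: List[np.ndarray],
    inversions: List[int],
) -> Tuple[List[System], Dict[str, List[np.ndarray]]]:
    """Apply one (rotation, inversion) pair per system to it and its targets.

    Assumes `targets` holds exactly "energy" and "forces", one array per system.
    """
    new_systems: List[System] = []
    new_targets: Dict[str, List[np.ndarray]] = {"energy": [], "forces": []}
    for i, (system, rot, inv) in enumerate(zip(systems, rotations, inversions)):
        # Proper (inv=+1) or improper (inv=-1) orthogonal operator.
        op = inv * rot
        new_systems.append({
            "positions": system["positions"] @ op.T,
            # Lattice vectors are stored as rows, so they transform like positions.
            "cell": system["cell"] @ op.T,
        })
        # Energies are invariant; forces are polar vectors and transform with op.
        new_targets["energy"].append(targets["energy"][i].copy())
        new_targets["forces"].append(targets["forces"][i] @ op.T)
    return new_systems, new_targets
```

Folding the inversion into a single operator `inv * rot` keeps the positions, cell, and forces consistent. Every geometric quantity undergoes the same orthogonal map, so interatomic distances are preserved and the energy can be left unchanged.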