Removing additive contributions

metatrain.utils.additive.remove.remove_additive(systems: List[System], targets: Dict[str, TensorMap], additive_model: Module, target_info_dict: Dict[str, TargetInfo]) → Dict[str, TensorMap]

Remove an additive contribution from the training targets.

Parameters:
  • systems (List[System]) – List of systems for which the additive contribution is computed.

  • targets (Dict[str, TensorMap]) – Dictionary containing the targets corresponding to the systems.

  • additive_model (Module) – The model used to calculate the additive contribution to be removed.

  • target_info_dict (Dict[str, TargetInfo]) – Dictionary containing information about the targets.

Returns:

The updated targets, with the additive contribution removed.

Return type:

Dict[str, TensorMap]
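The subtraction this function performs can be illustrated with a minimal pure-Python sketch. It is not the metatrain implementation. It assumes that targets are plain lists of per-system floats rather than `TensorMap`s, that the additive model is a callable returning per-system contributions keyed by target name, and it omits `target_info_dict`. The names `remove_additive_sketch` and `composition_model` are hypothetical, and the per-species offsets are arbitrary numbers.

```python
from typing import Callable, Dict, List

# Hypothetical simplified types (NOT metatrain's): a "system" is a list of
# atomic numbers, and each target holds one float per system.
System = List[int]
Targets = Dict[str, List[float]]
# An additive model maps systems to per-system contributions, keyed by target name.
AdditiveModel = Callable[[List[System]], Dict[str, List[float]]]


def remove_additive_sketch(
    systems: List[System], targets: Targets, additive_model: AdditiveModel
) -> Targets:
    """Return a new targets dict with the model's contributions subtracted."""
    contributions = additive_model(systems)
    updated: Targets = {}
    for name, values in targets.items():
        if name in contributions:
            # Subtract the additive contribution system by system.
            updated[name] = [v - c for v, c in zip(values, contributions[name])]
        else:
            # Targets the additive model does not cover are copied unchanged.
            updated[name] = list(values)
    return updated


def composition_model(systems: List[System]) -> Dict[str, List[float]]:
    # Hypothetical per-species energy offsets; the numbers are arbitrary.
    offsets = {1: -0.5, 8: -75.0}
    return {"energy": [sum(offsets[z] for z in s) for s in systems]}


systems = [[8, 1, 1], [1, 1]]  # H2O and H2
targets = {"energy": [-76.5, -1.2], "band_gap": [7.0, 10.0]}
residual = remove_additive_sketch(systems, targets, composition_model)
print(residual)
```

After the subtraction, the remaining "energy" values are what a model would be trained on, while "band_gap", which the additive model does not cover, is left as it was.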

metatrain.utils.additive.remove.get_remove_additive_transform(additive_models: List[Module], target_info_dict: Dict[str, TargetInfo]) → Callable

Get a function that removes the additive contributions from the targets.

Parameters:
  • additive_models (List[Module]) – The additive models whose contributions are removed from the targets.

  • target_info_dict (Dict[str, TargetInfo]) – A dictionary containing information about the targets.

Returns:

A function that takes systems, targets, and extra data, and returns the systems, the updated targets (with the additive contributions removed), and the extra data.

Return type:

Callable
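The returned callable can be pictured as a closure over the list of additive models. The following pure-Python sketch is a hedged illustration, not metatrain's code. It uses simplified stand-in types (lists of floats instead of `TensorMap`s) and omits `target_info_dict`. The names `get_remove_additive_transform_sketch`, `constant_offset`, and `per_atom_offset` are hypothetical.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical simplified types (NOT metatrain's).
System = List[int]
Targets = Dict[str, List[float]]
Extra = Dict[str, Any]
AdditiveModel = Callable[[List[System]], Dict[str, List[float]]]
Transform = Callable[[List[System], Targets, Extra], Tuple[List[System], Targets, Extra]]


def get_remove_additive_transform_sketch(additive_models: List[AdditiveModel]) -> Transform:
    """Build a function that subtracts every additive model's contribution."""

    def transform(systems: List[System], targets: Targets, extra_data: Extra):
        updated = dict(targets)  # shallow copy so the caller's dict is untouched
        for model in additive_models:
            for name, contribution in model(systems).items():
                if name in updated:
                    updated[name] = [v - c for v, c in zip(updated[name], contribution)]
        # Systems and extra data pass through unchanged.
        return systems, updated, extra_data

    return transform


def constant_offset(systems: List[System]) -> Dict[str, List[float]]:
    return {"energy": [1.0 for _ in systems]}


def per_atom_offset(systems: List[System]) -> Dict[str, List[float]]:
    return {"energy": [0.5 * len(s) for s in systems]}


transform = get_remove_additive_transform_sketch([constant_offset, per_atom_offset])
systems, targets, extra = transform([[1, 1], [8, 1, 1]], {"energy": [3.0, 4.0]}, {"batch": 0})
print(targets["energy"])  # [1.0, 1.5]
```

Applying the models one after another inside a single closure removes all contributions in one pass over each batch, while the systems and extra data are passed through untouched, matching the return description above.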