Composition model
- class metatrain.utils.additive.composition.CompositionModel(hypers: Dict, dataset_info: DatasetInfo)
Bases: Module
A simple model that calculates the per-species contributions to targets based on the stoichiometry in a system.
- Parameters:
hypers (Dict) – A dictionary of model hyperparameters. This parameter is ignored and is only present to be consistent with the general model API.
dataset_info (DatasetInfo) – An object containing information about the dataset, including target quantities and atomic types.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
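As a rough illustration of what "per-species contributions based on stoichiometry" means, the sketch below sums one weight per atom according to its atomic type. It is plain Python and does not use the metatrain API; the `weights` values and the `composition_prediction` helper are hypothetical names chosen for this example only.

```python
from collections import Counter

# Hypothetical per-species weights, keyed by atomic number (illustrative values).
weights = {1: -0.5, 8: -2.0}


def composition_prediction(atomic_types, weights):
    """Sum the per-species weight over every atom in one system."""
    # Stoichiometry of the system: atomic type -> number of atoms of that type.
    counts = Counter(atomic_types)
    return sum(n * weights[t] for t, n in counts.items())


water = [8, 1, 1]  # one O atom and two H atoms
energy = composition_prediction(water, weights)
```

The prediction depends only on how many atoms of each type are present, not on their positions.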
- dataset_info
A DatasetInfo object containing information about the dataset, including target quantities and atomic types.
- atomic_types
The list of atomic types used in the composition model.
- target_infos
A dictionary with a TargetInfo for each target that can be predicted by the model.
- model
The underlying composition model that handles the accumulation and fitting of the weights.
- outputs: Dict[str, ModelOutput]
A dictionary with a metatomic.torch.ModelOutput for each target that can be predicted by the model.
- train_model(datasets: List[Dataset | Subset], additive_models: List[Module], batch_size: int, is_distributed: bool, fixed_weights: Dict[str, Dict[int, float]] | None = None) -> None
Train the composition model on the provided training data in the datasets.
Assumes the systems are stored in the system attribute of each sample. Each sample is also expected to store the targets, under keys matching the target names defined in the dataset info.
Any additive contributions from the provided additive_models are removed from the targets before training. The fixed_weights argument specifies, per target, weights for atomic types that are held fixed rather than fitted during training.
- Parameters:
datasets (List[Dataset | Subset]) – A list of datasets to use for training.
additive_models (List[Module]) – A list of additive models whose contributions will be removed from the targets before training.
batch_size (int) – The batch size to use for training.
is_distributed (bool) – Whether to use distributed sampling for the dataloader.
fixed_weights (Dict[str, Dict[int, float]] | None) – A dictionary specifying which targets should be treated as fixed weights during training. The keys are target names, and the values are dictionaries mapping atomic types to their fixed weights. If None, no weights are treated as fixed.
- Return type:
None
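The training steps described above (remove additive contributions, hold some weights fixed, fit the rest) can be sketched with an ordinary least-squares fit on atom counts. This is an assumption about one reasonable fitting approach, not a statement of how metatrain implements it. The sketch handles a single target, so `fixed` here corresponds to one inner `Dict[int, float]` of `fixed_weights`. The helpers `solve_linear` and `fit_composition` are hypothetical.

```python
def solve_linear(a, b):
    """Solve a @ x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [rhs] for row, rhs in zip(a, b)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for i in reversed(range(n)):
        tail = sum(m[i][j] * x[j] for j in range(i + 1, n))
        x[i] = (m[i][n] - tail) / m[i][i]
    return x


def fit_composition(systems, targets, atomic_types, additive=None, fixed=None):
    """Least-squares fit of one weight per atomic type (single target).

    additive: per-system contributions subtracted from the targets first.
    fixed: {atomic_type: weight} entries that are kept as given, not fitted.
    """
    additive = additive or [0.0] * len(systems)
    fixed = fixed or {}
    free = [t for t in atomic_types if t not in fixed]
    rows, rhs = [], []
    for atoms, y, extra in zip(systems, targets, additive):
        # Remove additive and fixed-weight contributions from the target.
        y = y - extra - sum(fixed[t] * atoms.count(t) for t in fixed)
        rows.append([float(atoms.count(t)) for t in free])
        rhs.append(y)
    # Normal equations: (X^T X) w = X^T y, with X the atom-count matrix.
    k = len(free)
    xtx = [[sum(r[i] * r[j] for r in rows) for j in range(k)] for i in range(k)]
    xty = [sum(r[i] * y for r, y in zip(rows, rhs)) for i in range(k)]
    fitted = dict(zip(free, solve_linear(xtx, xty)))
    return {**fixed, **fitted}


# Toy data: H2, H2O and O2 with illustrative target values.
systems = [[1, 1], [8, 1, 1], [8, 8]]
targets = [-1.0, -3.0, -4.0]
weights = fit_composition(systems, targets, atomic_types=[1, 8])
pinned = fit_composition(systems, targets, atomic_types=[1, 8], fixed={1: -0.5})
```

In the `pinned` call only the weight for type 8 is fitted; the weight for type 1 is passed through unchanged.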
- restart(dataset_info: DatasetInfo) -> CompositionModel
Restart the model with a new dataset info.
- Parameters:
dataset_info (DatasetInfo) – New dataset information to be used.
- Returns:
An instance of the restarted model.
- Return type:
CompositionModel
- forward(systems: List[System], outputs: Dict[str, ModelOutput], selected_atoms: Labels | None = None) -> Dict[str, TensorMap]
Compute the targets for each system based on the composition weights.
- Parameters:
- Returns:
A dictionary with the computed predictions for each system.
- Raises:
ValueError – If no weights have been computed or if outputs contains unsupported keys.
- Return type:
Dict[str, TensorMap]
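The two error conditions listed for forward can be illustrated with a small stand-in. This is a minimal sketch under the assumption that weights are stored per target name; `predict` is a hypothetical function in plain Python, not the library's forward method, and it returns lists of floats rather than TensorMap objects.

```python
def predict(weights, systems, requested):
    """Return {target: [prediction per system]} from per-species weights."""
    if not weights:
        # Mirrors the "no weights have been computed" condition.
        raise ValueError("no weights have been computed; fit the model first")
    unsupported = sorted(set(requested) - set(weights))
    if unsupported:
        # Mirrors the "unsupported keys in outputs" condition.
        raise ValueError(f"unsupported outputs requested: {unsupported}")
    return {
        name: [sum(weights[name][t] for t in atoms) for atoms in systems]
        for name in requested
    }


# Illustrative weights for a single target called "energy".
weights = {"energy": {1: -0.5, 8: -2.0}}
result = predict(weights, [[8, 1, 1], [1, 1]], ["energy"])
```

Calling `predict` with an empty `weights` dictionary, or with a target name that has no weights, raises ValueError before any prediction is made.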
- supported_outputs() -> Dict[str, ModelOutput]
- Return type:
Dict[str, ModelOutput]