Composition model
- class metatrain.utils.additive.composition.CompositionModel(model_hypers: Dict, dataset_info: DatasetInfo)
Bases: torch.nn.Module
A simple model that calculates the per-species contributions to targets based on the stoichiometry in a system.
- Parameters:
model_hypers (Dict) – A dictionary of model hyperparameters. This parameter is ignored and is only present for consistency with the general model API.
dataset_info (DatasetInfo) – An object containing information about the dataset, including target quantities and atomic types.
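As a rough illustration of the stoichiometric model described above, the sketch below assumes that a per-system target is the sum of one weight per atom, so that predictions reduce to an atom-count matrix multiplied by a vector of per-type weights. It uses plain NumPy arrays rather than metatrain or metatensor types; the helper names `composition_matrix` and `composition_prediction` and the weight values are hypothetical.

```python
import numpy as np


def composition_matrix(systems_types, atomic_types):
    """Hypothetical helper: count atoms of each type in each system.

    systems_types: one 1D integer array per system with the atomic type of
        every atom (e.g. atomic numbers).
    atomic_types: sorted list of all atomic types known to the model.
    Returns an array of shape (n_systems, n_types).
    """
    column = {t: i for i, t in enumerate(atomic_types)}
    counts = np.zeros((len(systems_types), len(atomic_types)))
    for row, types in enumerate(systems_types):
        for t in types:
            counts[row, column[int(t)]] += 1
    return counts


def composition_prediction(counts, weights):
    """Hypothetical helper: sum one weight per atom, i.e. counts @ weights."""
    return counts @ weights


# Water (H2O) and methane (CH4) with made-up per-type weights for H, C, O.
atomic_types = [1, 6, 8]
systems_types = [np.array([8, 1, 1]), np.array([6, 1, 1, 1, 1])]
counts = composition_matrix(systems_types, atomic_types)
weights = np.array([-0.5, -37.8, -75.0])
predictions = composition_prediction(counts, weights)
print(counts)
print(predictions)
```

Because the model is linear in the atom counts, two systems with the same stoichiometry always receive the same prediction, regardless of their geometry.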
- outputs: Dict[str, ModelOutput]
- train_model(dataloader: DataLoader, additive_models: List[Module], fixed_weights: Dict[str, Dict[int, float]] | None = None) → None
Train the composition model on the provided training data in the dataloader. Assumes the systems are stored in the system attribute of the batch. Targets are expected to be in the batch as well, with keys corresponding to the target names defined in the dataset info. Any additive contributions from the provided additive_models are removed from the targets before training. The fixed_weights argument can be used to fix the composition weights of selected targets instead of fitting them; it maps each target name to a dictionary of weights keyed by atomic type.
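A minimal sketch of the fitting step, assuming a single per-system target and a linear least-squares fit of the per-type weights (via `numpy.linalg.lstsq`) after the contributions of other additive models have been subtracted. It also assumes that a target listed in `fixed_weights` supplies a weight for every atomic type and is then not fitted at all. The function `fit_composition_weights` is hypothetical and is not the metatrain implementation.

```python
import numpy as np


def fit_composition_weights(counts, targets, atomic_types,
                            additive_predictions=(), fixed=None):
    """Hypothetical sketch: fit per-type weights for one per-system target.

    counts: (n_systems, n_types) atom-count matrix.
    targets: (n_systems,) reference values for this target.
    atomic_types: atomic types in the same order as the columns of counts.
    additive_predictions: outputs of other additive models, each of shape
        (n_systems,); they are subtracted from the targets before fitting.
    fixed: optional {atomic_type: weight} mapping; assumed to cover every
        type, in which case it is returned as-is and nothing is fitted.
    """
    if fixed is not None:
        return np.array([fixed[t] for t in atomic_types], dtype=float)
    residual = np.asarray(targets, dtype=float).copy()
    for prediction in additive_predictions:
        residual -= np.asarray(prediction, dtype=float)
    # Solve counts @ weights ~= residual in the least-squares sense.
    weights, *_ = np.linalg.lstsq(counts, residual, rcond=None)
    return weights


# H2, O2 and H2O with columns (H, O); another additive model contributes
# a small offset that is removed before fitting.
atomic_types = [1, 8]
counts = np.array([[2.0, 0.0], [0.0, 2.0], [2.0, 1.0]])
other_model = np.array([0.1, 0.2, 0.3])
energies = np.array([-1.0, -150.0, -76.0]) + other_model

fixed_weights = {"energy": {1: -0.4, 8: -74.0}}  # keyed by target name
fitted = fit_composition_weights(counts, energies, atomic_types, [other_model])
fixed = fit_composition_weights(
    counts, energies, atomic_types, fixed=fixed_weights["energy"]
)
print(fitted, fixed)
```

Subtracting the other additive models first means the composition weights only absorb what those models do not already explain.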
- restart(dataset_info: DatasetInfo) → CompositionModel
Restart the model with a new dataset info.
- Parameters:
dataset_info (DatasetInfo) – New dataset information to be used.
- Return type:
CompositionModel
- forward(systems: List[System], outputs: Dict[str, ModelOutput], selected_atoms: Labels | None = None) → Dict[str, TensorMap]
Compute the targets for each system based on the composition weights.
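- Parameters:
systems (List[System]) – The systems for which to compute the targets.
outputs (Dict[str, ModelOutput]) – The outputs to compute, keyed by target name.
selected_atoms (Labels | None) – Optional subset of atoms for which to compute the predictions.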
- Returns:
A dictionary with the computed predictions for each system.
- Raises:
ValueError – If no weights have been computed or if the outputs dictionary contains unsupported keys.
- Return type:
Dict[str, TensorMap]
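The following sketch mirrors what the forward pass computes under a linear per-type model: each atom contributes the weight of its type, and per-system outputs sum these contributions. Plain dictionaries and NumPy arrays stand in for System, ModelOutput and TensorMap, selected_atoms is ignored, and `composition_forward` is a hypothetical name; the two `ValueError` cases correspond to the ones documented above.

```python
import numpy as np


def composition_forward(systems_types, requested, weights, atomic_types):
    """Hypothetical sketch of evaluating fitted composition weights.

    systems_types: one 1D integer array of atomic types per system.
    requested: {output_name: per_atom} mapping, a stand-in for the
        Dict[str, ModelOutput] argument (only the per-atom flag is used).
    weights: {output_name: per-type weight array}; empty before training.
    atomic_types: atomic types in the same order as each weight array.
    """
    if not weights:
        raise ValueError("no composition weights have been computed")
    unsupported = sorted(set(requested) - set(weights))
    if unsupported:
        raise ValueError(f"unsupported outputs: {unsupported}")
    column = {t: i for i, t in enumerate(atomic_types)}
    results = {}
    for name, per_atom in requested.items():
        # Each atom contributes the weight associated with its type.
        atom_values = [
            np.array([weights[name][column[int(t)]] for t in types])
            for types in systems_types
        ]
        if per_atom:
            results[name] = atom_values
        else:
            # Per-system output: sum the atomic contributions.
            results[name] = np.array([values.sum() for values in atom_values])
    return results


weights = {"energy": np.array([-0.5, -75.0])}  # columns (H, O)
water = [np.array([8, 1, 1])]
print(composition_forward(water, {"energy": False}, weights, [1, 8]))
print(composition_forward(water, {"energy": True}, weights, [1, 8]))
```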
- supported_outputs() → Dict[str, ModelOutput]
- Return type:
Dict[str, ModelOutput]