Composition model
- class metatrain.utils.additive.composition.CompositionModel(model_hypers: Dict, dataset_info: DatasetInfo)
Bases: torch.nn.Module
A simple model that calculates the contributions to scalar targets from the stoichiometry of a system, i.e. from the number of atoms of each atomic type it contains.
- Parameters:
model_hypers (Dict) – A dictionary of model hyperparameters. This parameter is ignored; it is only present for consistency with the general model API.
dataset_info (DatasetInfo) – An object containing information about the dataset, including target quantities and atomic types.
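The idea behind a composition model is that a scalar target y of a system is approximated as y ≈ Σ_t n_t w_t, where n_t is the number of atoms of atomic type t and w_t is one weight per type. A minimal plain-Python sketch of that sum, assuming a single scalar target and a `{atomic_type: weight}` mapping, could look like the following; the function name `composition_contribution` and the weights are hypothetical and not part of the metatrain API.

```python
from collections import Counter
from typing import Dict, List


def composition_contribution(atomic_types: List[int], weights: Dict[int, float]) -> float:
    """Hypothetical helper: sum of per-type weights over all atoms of one system.

    Equivalent to sum_t n_t * w_t, where n_t counts the atoms of type t.
    """
    counts = Counter(atomic_types)
    # Types without a weight cannot be handled, mirroring an "unknown atomic types" error.
    unknown = set(counts) - set(weights)
    if unknown:
        raise ValueError(f"unknown atomic types: {sorted(unknown)}")
    return sum(n * weights[t] for t, n in counts.items())


# Water: one oxygen (type 8) and two hydrogens (type 1), with made-up weights.
weights = {1: -0.5, 8: -75.0}
print(composition_contribution([8, 1, 1], weights))  # -76.0
```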
- outputs: Dict[str, ModelOutput]
- train_model(datasets: List[Dataset | Subset], additive_models: List[Module], fixed_weights: Dict[str, Dict[int, str]] | None = None) → None
Fit the composition weights on the given datasets.
- Parameters:
datasets (List[Dataset | Subset]) – Dataset(s) on which to fit the composition weights.
additive_models (List[Module]) – Additive models whose contributions are subtracted from the targets before the composition weights are fitted.
fixed_weights (Dict[str, Dict[int, str]] | None) – Optional fixed weights, keyed by target name and then by atomic type, for one or more target quantities. When given, they are used instead of fitted weights for the corresponding targets.
- Raises:
ValueError – If the provided datasets contain unknown targets.
ValueError – If the provided datasets contain unknown atomic types.
RuntimeError – If the linear system for the composition weights cannot be solved.
- Return type:
None
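The fitting step can be sketched as a least-squares problem: build a matrix of per-type atom counts, subtract the predictions of any other additive models from the targets, and solve for one weight per type. This is a minimal sketch under several assumptions: a single scalar target, toy systems given as `(atomic types, target value)` pairs, additive models represented as plain callables, and `fixed_weights` simplified to one `{atomic_type: weight}` mapping for that target. The names `ToySystem`, `_solve` and `fit_composition_weights` are hypothetical, and solving the normal equations is one possible choice, not necessarily the solver metatrain uses.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

# Hypothetical toy format: (list of atomic types, scalar target value).
ToySystem = Tuple[List[int], float]


def _solve(a: List[List[float]], b: List[float]) -> List[float]:
    """Gauss-Jordan elimination with partial pivoting; raises if singular."""
    n = len(b)
    m = [row[:] + [rhs] for row, rhs in zip(a, b)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if abs(m[pivot][col]) < 1e-12:
            raise RuntimeError("linear system for composition weights is singular")
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(n):
            if r != col:
                f = m[r][col] / m[col][col]
                m[r] = [x - f * y for x, y in zip(m[r], m[col])]
    return [m[i][n] / m[i][i] for i in range(n)]


def fit_composition_weights(
    systems: Sequence[ToySystem],
    atomic_types: List[int],
    additive_models: Sequence[Callable[[List[int]], float]] = (),
    fixed_weights: Optional[Dict[int, float]] = None,
) -> Dict[int, float]:
    # Fixed weights bypass the fit entirely for this (single) target.
    if fixed_weights is not None:
        return dict(fixed_weights)
    index = {t: i for i, t in enumerate(atomic_types)}
    rows, ys = [], []
    for types, target in systems:
        counts = [0.0] * len(atomic_types)
        for t in types:
            if t not in index:
                raise ValueError(f"unknown atomic type {t}")
            counts[index[t]] += 1.0
        # Remove the contributions of the other additive models before fitting.
        ys.append(target - sum(model(types) for model in additive_models))
        rows.append(counts)
    k = len(atomic_types)
    # Normal equations: (X^T X) w = X^T y
    xtx = [[sum(r[i] * r[j] for r in rows) for j in range(k)] for i in range(k)]
    xty = [sum(r[i] * y for r, y in zip(rows, ys)) for i in range(k)]
    w = _solve(xtx, xty)
    return {t: w[index[t]] for t in atomic_types}


# Made-up targets for H2, O2 and H2O.
data = [([1, 1], -1.0), ([8, 8], -150.0), ([8, 1, 1], -76.0)]
print(fit_composition_weights(data, atomic_types=[1, 8]))  # w_H ≈ -0.5, w_O ≈ -75.0 (up to rounding)
```

If an atomic type listed in `atomic_types` never occurs in the data, its column of counts is zero, the normal matrix becomes singular, and the sketch raises `RuntimeError`, analogous to the unsolvable-system case listed above.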
- restart(dataset_info: DatasetInfo) → CompositionModel
Restart the model with new dataset information.
- Parameters:
dataset_info (DatasetInfo) – New dataset information to be used.
- Return type:
CompositionModel
- forward(systems: List[System], outputs: Dict[str, ModelOutput], selected_atoms: Labels | None = None) → Dict[str, TensorMap]
Compute the targets for each system based on the composition weights.
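- Parameters:
systems (List[System]) – Systems for which to compute the targets.
outputs (Dict[str, ModelOutput]) – The outputs (targets) to compute.
selected_atoms (Labels | None) – Optional selection of atoms to which the computation is restricted.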
- Returns:
A dictionary with the computed predictions for each system.
- Raises:
ValueError – If no weights have been computed or if outputs contains unsupported keys.
- Return type:
Dict[str, TensorMap]
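The behaviour of `forward`, including its `ValueError` cases, can be mimicked with plain Python containers. In this hypothetical sketch, systems are lists of atomic types, `outputs` is a list of target names rather than a `Dict[str, ModelOutput]`, predictions are plain lists rather than `TensorMap` objects, and `selected_atoms` is assumed to be a set of `(system index, atom index)` pairs standing in for the `Labels` argument. The function `composition_forward` is not part of metatrain, and the real handling of `selected_atoms` may differ.

```python
from typing import Dict, List, Optional, Set, Tuple


def composition_forward(
    systems: List[List[int]],
    outputs: List[str],
    weights: Dict[str, Dict[int, float]],
    selected_atoms: Optional[Set[Tuple[int, int]]] = None,
) -> Dict[str, List[float]]:
    # No fitted weights at all: nothing can be predicted.
    if not weights:
        raise ValueError("no composition weights have been computed")
    # Every requested output must have its own set of weights.
    unsupported = [name for name in outputs if name not in weights]
    if unsupported:
        raise ValueError(f"unsupported outputs: {unsupported}")
    results: Dict[str, List[float]] = {}
    for name in outputs:
        per_system = []
        for i_system, types in enumerate(systems):
            total = 0.0
            for i_atom, t in enumerate(types):
                # Skip atoms outside the (assumed) selection.
                if selected_atoms is not None and (i_system, i_atom) not in selected_atoms:
                    continue
                total += weights[name][t]
            per_system.append(total)
        results[name] = per_system
    return results


weights = {"energy": {1: -0.5, 8: -75.0}}
systems = [[8, 1, 1], [1, 1]]
print(composition_forward(systems, ["energy"], weights))  # {'energy': [-76.0, -1.0]}
print(composition_forward(systems, ["energy"], weights, selected_atoms={(0, 0)}))  # {'energy': [-75.0, 0.0]}
```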