Composition model

class metatrain.utils.additive.composition.CompositionModel(model_hypers: Dict, dataset_info: DatasetInfo)

Bases: Module

A simple model that calculates the contributions to scalar targets based on the stoichiometry in a system.

Parameters:
  • model_hypers (Dict) – A dictionary of model hyperparameters. This parameter is ignored; it is only present for consistency with the general model API.

  • dataset_info (DatasetInfo) – An object containing information about the dataset, including target quantities and atomic types.
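The idea behind a stoichiometry-based model can be sketched with NumPy: each system is reduced to a vector counting its atoms of each type, and the scalar contribution is the dot product of that vector with per-type weights. This is a minimal sketch, not metatrain's implementation; plain integer arrays stand in for `System` objects, and `composition_features` and the weight values are hypothetical.

```python
import numpy as np


def composition_features(systems, atomic_types):
    """Count the atoms of each type in each system.

    ``systems`` is a list of integer arrays of atomic types, a stand-in for
    metatrain ``System`` objects; ``atomic_types`` fixes the column order.
    """
    type_index = {t: i for i, t in enumerate(atomic_types)}
    counts = np.zeros((len(systems), len(atomic_types)))
    for row, types in enumerate(systems):
        for t in types:
            # An unknown atomic type raises KeyError here
            counts[row, type_index[int(t)]] += 1
    return counts


# Water (H2O) and methane (CH4), atomic types given as atomic numbers
systems = [np.array([8, 1, 1]), np.array([6, 1, 1, 1, 1])]
X = composition_features(systems, atomic_types=[1, 6, 8])
# X is [[2, 0, 1], [4, 1, 0]]

# Made-up per-type weights; the contribution is one dot product per system
weights = np.array([-0.5, -38.0, -75.0])
print(X @ weights)  # [-76. -40.]
```

Because the contribution is linear in the counts, fitting the weights reduces to a linear regression over the training systems.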

weights: Dict[str, TensorMap]
outputs: Dict[str, ModelOutput]
train_model(datasets: List[Dataset | Subset], additive_models: List[Module], fixed_weights: Dict[str, Dict[int, str]] | None = None) → None

Fit the composition weights to the given datasets.

Parameters:
  • datasets (List[Dataset | Subset]) – Dataset(s) from which to fit the composition weights.

  • additive_models (List[Module]) – Additive models whose predictions are subtracted from the targets before the composition weights are fitted.

  • fixed_weights (Dict[str, Dict[int, str]] | None) – Optional fixed weights to use instead of fitted ones for one or more targets, keyed by target name and then by atomic type.

Raises:
  • ValueError – If the provided datasets contain unknown targets.

  • ValueError – If the provided datasets contain unknown atomic types.

  • RuntimeError – If the linear system for the composition weights cannot be solved.

Return type:

None
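Fitting composition weights amounts to a linear least-squares problem. The sketch below, for a single target and independent of metatrain, shows one way to honor the two optional inputs: predictions of other additive models are subtracted from the targets first, and any fixed weights are moved to the right-hand side so that only the remaining types are fitted. The function name, its arguments, per-type fixing, and the use of `numpy.linalg.lstsq` are assumptions for illustration; the actual solver and its handling of fixed weights may differ.

```python
import numpy as np


def fit_composition_weights(counts, targets, atomic_types,
                            baseline=None, fixed_weights=None):
    """Fit per-type weights w so that counts @ w approximates targets.

    counts: (n_systems, n_types) matrix of atom counts per type.
    baseline: optional predictions of other additive models, subtracted first.
    fixed_weights: optional {atomic_type: weight} entries kept as given.
    """
    y = np.asarray(targets, dtype=float).copy()
    if baseline is not None:
        y -= np.asarray(baseline, dtype=float)
    fixed_weights = fixed_weights or {}

    weights = np.zeros(len(atomic_types))
    free = []
    for i, t in enumerate(atomic_types):
        if t in fixed_weights:
            weights[i] = fixed_weights[t]
        else:
            free.append(i)

    # Move the fixed contributions to the right-hand side (free columns are 0)
    y -= counts @ weights
    if free:
        solution, *_ = np.linalg.lstsq(counts[:, free], y, rcond=None)
        weights[free] = solution
    return weights


# H2O, H2 and O2 with columns for types [1, 8]; targets built from the
# made-up weights H = -0.5, O = -75.0
counts = np.array([[2.0, 1.0], [2.0, 0.0], [0.0, 2.0]])
targets = np.array([-76.0, -1.0, -150.0])

# Both calls recover weights close to [-0.5, -75.0]
print(fit_composition_weights(counts, targets, [1, 8]))
print(fit_composition_weights(counts, targets, [1, 8], fixed_weights={1: -0.5}))
```

Moving fixed contributions to the right-hand side keeps the problem linear, so the free weights are still obtained from a single least-squares solve.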

restart(dataset_info: DatasetInfo) → CompositionModel

Restart the model with new dataset information.

Parameters:

dataset_info (DatasetInfo) – New dataset information to be used.

Return type:

CompositionModel

forward(systems: List[System], outputs: Dict[str, ModelOutput], selected_atoms: Labels | None = None) → Dict[str, TensorMap]

Compute the targets for each system based on the composition weights.

Parameters:
  • systems (List[System]) – Systems for which to compute the predictions.

  • outputs (Dict[str, ModelOutput]) – Dictionary of the requested outputs, keyed by target name.

  • selected_atoms (Labels | None) – Optional selection of atoms for which to compute the predictions.

Returns:

A dictionary with the computed predictions for each system.

Raises:

ValueError – If no weights have been computed or if outputs contains unsupported keys.

Return type:

Dict[str, TensorMap]
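The prediction step can be pictured as assigning each atom the weight of its atomic type and summing over the atoms of a system. A hedged sketch of that logic follows; the `per_atom` switch and the representation of atom selection as a set of `(system, atom)` index pairs are simplifications assumed here, whereas the real method takes metatensor `Labels` and returns `TensorMap` objects. `composition_predict` is a hypothetical name.

```python
import numpy as np


def composition_predict(systems, atomic_types, weights,
                        per_atom=False, selected_atoms=None):
    """Return composition contributions for each system.

    systems: list of integer arrays of atomic types (stand-in for ``System``).
    per_atom: if True, return one value per atom instead of a per-system sum.
    selected_atoms: optional set of (system_index, atom_index) pairs; other
        atoms are left out, loosely mirroring ``selected_atoms``.
    """
    type_index = {t: i for i, t in enumerate(atomic_types)}
    results = []
    for s, types in enumerate(systems):
        # Each atom contributes the weight of its atomic type
        values = np.array([weights[type_index[int(t)]] for t in types], dtype=float)
        if selected_atoms is not None:
            keep = np.array([(s, a) in selected_atoms for a in range(len(types))],
                            dtype=bool)
            values = values[keep]
        results.append(values if per_atom else float(values.sum()))
    return results


systems = [np.array([8, 1, 1]), np.array([6, 1, 1, 1, 1])]
weights = np.array([-0.5, -38.0, -75.0])  # made-up, columns [1, 6, 8]

print(composition_predict(systems, [1, 6, 8], weights))  # [-76.0, -40.0]
# Only the first atom (O in water, C in methane) of each system
print(composition_predict(systems, [1, 6, 8], weights,
                          selected_atoms={(0, 0), (1, 0)}))  # [-75.0, -38.0]
```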

static is_valid_target(target_name: str, target_info: TargetInfo) → bool

Check whether a TargetInfo object is compatible with a composition model.

Parameters:
  • target_name (str) – The name of the target.

  • target_info (TargetInfo) – The TargetInfo object to be checked.

Return type:

bool