Composition model

class metatrain.utils.additive.composition.CompositionModel(model_hypers: Dict, dataset_info: DatasetInfo)

Bases: Module

A simple model that predicts each target as a sum of per-species contributions, based on the stoichiometry (the number of atoms of each type) of a system.

Parameters:
  • model_hypers (Dict) – A dictionary of model hyperparameters. This parameter is ignored and is present only for consistency with the general model API.

  • dataset_info (DatasetInfo) – An object containing information about the dataset, including target quantities and atomic types.
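To make the idea concrete, here is a minimal NumPy sketch of a composition fit. It is not the metatrain implementation. It assumes each system is reduced to an array of atomic numbers and that the per-type weights come from an ordinary least-squares fit of target_i ~ sum over types t of count_{i,t} * w_t. The actual model may fit differently (for example with regularization), and the helper names `composition_matrix` and `fit_composition_weights` are hypothetical.

```python
import numpy as np


def composition_matrix(systems, atomic_types):
    """Count the atoms of each type in each system.

    `systems` is a list of 1-D integer arrays of atomic numbers (a stand-in
    for real System objects); `atomic_types` fixes the column order.
    """
    index = {t: i for i, t in enumerate(atomic_types)}
    counts = np.zeros((len(systems), len(atomic_types)))
    for row, types in enumerate(systems):
        for t in types:
            counts[row, index[int(t)]] += 1
    return counts


def fit_composition_weights(systems, targets, atomic_types):
    """Least-squares fit of one weight per atomic type (hypothetical helper)."""
    X = composition_matrix(systems, atomic_types)
    y = np.asarray(targets, dtype=float)
    w, *_ = np.linalg.lstsq(X, y, rcond=None)
    return {t: float(v) for t, v in zip(atomic_types, w)}


# Toy data: H2O, CH4 and CO with made-up per-type contributions.
systems = [np.array([1, 1, 8]), np.array([1, 1, 1, 1, 6]), np.array([6, 8])]
true = {1: -0.5, 6: -37.8, 8: -75.0}
targets = [sum(true[int(t)] for t in s) for s in systems]
weights = fit_composition_weights(systems, targets, atomic_types=[1, 6, 8])
```

With three independent compositions and three atomic types, the linear system is square and non-singular, so the fit recovers the per-type values used to generate the toy targets.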


outputs: Dict[str, ModelOutput]
train_model(dataloader: DataLoader, additive_models: List[Module], fixed_weights: Dict[str, Dict[int, float]] | None = None) → None

Train the composition model on the provided training data in the dataloader.

Assumes the systems are stored in the system attribute of the batch. Targets are expected to be in the batch as well, with keys corresponding to the target names defined in the dataset info.

Any additive contributions from the provided additive_models are subtracted from the targets before training. The fixed_weights argument maps target names to per-atomic-type weights ({atomic_type: weight}); for the targets it lists, these weights are used as given instead of being fitted.

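Parameters:
  • dataloader (DataLoader) – Data loader providing the training batches of systems and targets.

  • additive_models (List[Module]) – Additive models whose contributions are subtracted from the targets before fitting.

  • fixed_weights (Dict[str, Dict[int, float]] | None) – Optional mapping from target names to per-atomic-type weights to use instead of fitted ones.
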
Return type:

None
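The training procedure described above can be sketched as follows. This is a hypothetical NumPy illustration, not the library's code. It assumes that the contributions of the other additive models are already available as per-system arrays, and that a target listed in fixed_weights reuses the given weights without any fitting. The names `train_composition` and `additive_predictions`, as well as the toy target name `charge`, are illustrative only.

```python
import numpy as np


def composition_matrix(systems, atomic_types):
    """Count the atoms of each type in each system (rows: systems)."""
    index = {t: i for i, t in enumerate(atomic_types)}
    counts = np.zeros((len(systems), len(atomic_types)))
    for row, types in enumerate(systems):
        for t in types:
            counts[row, index[int(t)]] += 1
    return counts


def train_composition(systems, targets, atomic_types,
                      additive_predictions=None, fixed_weights=None):
    """Hypothetical sketch of fitting per-type weights for several targets.

    targets: {target_name: per-system values}
    additive_predictions: list of {target_name: per-system values}, one dict
        per other additive model (an assumed format)
    fixed_weights: {target_name: {atomic_type: weight}}
    """
    X = composition_matrix(systems, atomic_types)
    weights = {}
    for name, values in targets.items():
        y = np.asarray(values, dtype=float).copy()
        # Remove what the other additive models already account for.
        for preds in additive_predictions or []:
            if name in preds:
                y -= np.asarray(preds[name], dtype=float)
        # Targets with fixed weights are not fitted at all.
        if fixed_weights is not None and name in fixed_weights:
            weights[name] = dict(fixed_weights[name])
            continue
        w, *_ = np.linalg.lstsq(X, y, rcond=None)
        weights[name] = {t: float(v) for t, v in zip(atomic_types, w)}
    return weights


systems = [np.array([1, 1, 8]), np.array([1, 1, 1, 1, 6]), np.array([6, 8])]
true = {1: -0.5, 6: -37.8, 8: -75.0}
extra = np.array([0.1, 0.2, 0.3])  # stand-in for another additive model's output
energy = np.array([sum(true[int(t)] for t in s) for s in systems]) + extra
result = train_composition(
    systems,
    {"energy": energy, "charge": np.zeros(3)},
    atomic_types=[1, 6, 8],
    additive_predictions=[{"energy": extra}],
    fixed_weights={"charge": {1: 0.0, 6: 0.0, 8: 0.0}},
)
```

Subtracting the other additive contributions first means the composition weights only have to explain what those models leave over. This is why the fitted energy weights match the toy per-type values.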

restart(dataset_info: DatasetInfo) → CompositionModel

Restart the model with new dataset information.

Parameters:

dataset_info (DatasetInfo) – New dataset information to be used.

Return type:

CompositionModel

forward(systems: List[System], outputs: Dict[str, ModelOutput], selected_atoms: Labels | None = None) → Dict[str, TensorMap]

Compute the targets for each system based on the composition weights.

Parameters:
  • systems (List[System]) – List of systems for which to compute the predictions.

  • outputs (Dict[str, ModelOutput]) – Dictionary of the requested outputs, mapping output names to their ModelOutput specifications.

  • selected_atoms (Labels | None) – Optional selection of samples for which to compute the predictions.

Returns:

A dictionary with the computed predictions for each system.

Raises:

ValueError – If no weights have been computed, or if outputs contains unsupported keys.

Return type:

Dict[str, TensorMap]
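The prediction step amounts to looking up one weight per atom and, for per-system outputs, summing over the atoms of each system. The sketch below is hypothetical. Systems are plain arrays of atomic numbers, and `selected_atoms` is modelled as a set of (system index, atom index) pairs rather than a metatensor Labels object. It assumes that unselected atoms contribute nothing to the per-system totals. The function name `predict` is illustrative.

```python
import numpy as np


def predict(systems, weights, per_atom=False, selected_atoms=None):
    """Hypothetical composition prediction.

    systems: list of integer arrays of atomic numbers
    weights: {atomic_type: weight}
    per_atom: if True, return (system, atom, value) rows instead of totals
    selected_atoms: optional set of (system_index, atom_index) pairs to keep
    """
    rows = []
    for i, types in enumerate(systems):
        for j, t in enumerate(types):
            if selected_atoms is not None and (i, j) not in selected_atoms:
                continue  # skip atoms outside the selection
            rows.append((i, j, weights[int(t)]))
    if per_atom:
        return rows
    # Per-system output: sum the contributions of the kept atoms.
    totals = np.zeros(len(systems))
    for i, _, value in rows:
        totals[i] += value
    return totals


systems = [np.array([1, 1, 8]), np.array([1, 1, 1, 1, 6]), np.array([6, 8])]
weights = {1: -0.5, 6: -37.8, 8: -75.0}
totals = predict(systems, weights)
only_oxygen = predict(systems, weights, selected_atoms={(0, 2)})
atom_rows = predict(systems, weights, per_atom=True)
```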

supported_outputs() → Dict[str, ModelOutput]
Return type:

Dict[str, ModelOutput]

static is_valid_target(target_name: str, target_info: TargetInfo) → bool

Check whether a TargetInfo object is compatible with a composition model.

Parameters:
  • target_name (str) – Name of the target.

  • target_info (TargetInfo) – The TargetInfo object to be checked.

Return type:

bool

sync_tensor_maps()