Scaler

class metatrain.utils.scaler.Scaler(model_hypers: Dict, dataset_info: DatasetInfo)

Bases: Module

A class that scales the targets of regression problems to unit standard deviation.

In most cases, this should be used in conjunction with a composition model (which removes the multi-dimensional “mean” across the composition space) and/or other additive models. See the train_model method for more details.

The scaling is performed per atom: for per-structure targets, the standard deviation is calculated on the targets divided by the number of atoms in each structure.
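The per-atom convention can be illustrated with a minimal sketch. It assumes plain Python floats instead of TensorMap objects, and it assumes the additive models have already removed the mean, so the standard deviation is taken about zero (that is, as a root mean square). The function name `per_atom_scale` is hypothetical and not part of metatrain.

```python
import math
from typing import Sequence


def per_atom_scale(targets: Sequence[float], n_atoms: Sequence[int]) -> float:
    """Hypothetical helper: per-atom scale of per-structure targets.

    Each target is divided by the number of atoms in its structure, and the
    scale is the root mean square of these per-atom values, i.e. a standard
    deviation about zero (assumption: the mean was already removed).
    """
    if len(targets) != len(n_atoms) or not targets:
        raise ValueError("need one atom count per target and at least one target")
    per_atom = [y / n for y, n in zip(targets, n_atoms)]
    return math.sqrt(sum(v * v for v in per_atom) / len(per_atom))


# Residuals of 6.0 (3 atoms) and -4.0 (2 atoms) give per-atom values
# 2.0 and -2.0, hence a scale of 2.0.
print(per_atom_scale([6.0, -4.0], [3, 2]))  # → 2.0
```

Dividing by the number of atoms keeps large and small structures on a comparable footing, so a few large structures do not dominate the fitted scale.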

Parameters:
  • model_hypers (Dict) – A dictionary of model hyperparameters. This parameter is ignored; it is only present for consistency with the general model API.

  • dataset_info (DatasetInfo) – An object containing information about the dataset, including target quantities and atomic types.


scales: Tensor
outputs: Dict[str, ModelOutput]
train_model(datasets: List[Dataset | Subset], additive_models: List[Module], treat_as_additive: bool) → None

Calculate the scaling weights for all the targets in the datasets.

Parameters:
  • datasets (List[Dataset | Subset]) – Dataset(s) to calculate the scaling weights for.

  • additive_models (List[Module]) – Additive models to be removed from the targets before calculating the statistics.

  • treat_as_additive (bool) – If True, all per-structure targets (i.e. those that do not contain an atom label name) are treated as additive.

Raises:

ValueError – If the provided datasets contain targets unknown to the scaler or if the targets are not treated as additive.

Return type:

None
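The steps described for train_model can be sketched in plain Python. This is an illustration under stated assumptions, not the metatrain implementation: each structure is a hypothetical dict holding its atom count and per-structure target values, each additive model is a hypothetical callable returning its predictions for a structure, all targets are treated as additive (as with treat_as_additive=True), and the scale is taken as the root mean square of the per-atom residuals.

```python
import math
from typing import Callable, Dict, List

# Hypothetical data layout: {"n_atoms": int, "targets": {name: float}}
Structure = Dict[str, object]
# Hypothetical additive model: maps a structure to per-target predictions
AdditiveModel = Callable[[Structure], Dict[str, float]]


def fit_scales(
    datasets: List[List[Structure]],
    additive_models: List[AdditiveModel],
    known_targets: List[str],
) -> Dict[str, float]:
    sum_sq = {name: 0.0 for name in known_targets}
    count = {name: 0 for name in known_targets}
    for dataset in datasets:
        for structure in dataset:
            n_atoms = structure["n_atoms"]
            predictions = [model(structure) for model in additive_models]
            for name, value in structure["targets"].items():
                if name not in sum_sq:
                    raise ValueError(f"target {name!r} is unknown to the scaler")
                # Remove the additive contributions, then normalize per atom
                residual = value - sum(p.get(name, 0.0) for p in predictions)
                sum_sq[name] += (residual / n_atoms) ** 2
                count[name] += 1
    # Assumption: targets without any data keep a neutral scale of 1.0
    return {
        name: math.sqrt(sum_sq[name] / count[name]) if count[name] else 1.0
        for name in known_targets
    }


def composition(structure: Structure) -> Dict[str, float]:
    # Hypothetical composition model: predicts 1.0 per atom for "energy"
    return {"energy": 1.0 * structure["n_atoms"]}


data = [[{"n_atoms": 2, "targets": {"energy": 4.0}},
         {"n_atoms": 4, "targets": {"energy": 0.0}}]]
print(fit_scales(data, [composition], ["energy"]))  # → {'energy': 1.0}
```

Subtracting the additive predictions first matters: without the composition model, the raw per-atom energies here (2.0 and 0.0) would yield a different, mean-dominated scale.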

restart(dataset_info: DatasetInfo) → Scaler
Parameters:

dataset_info (DatasetInfo)

Return type:

Scaler

forward(outputs: Dict[str, TensorMap]) → Dict[str, TensorMap]

Scales all the targets in the outputs dictionary back to their original scale.

Parameters:

outputs (Dict[str, TensorMap]) – A dictionary of target quantities and their values to be scaled.

Raises:

ValueError – If an output does not have a corresponding scale in the scaler model.

Return type:

Dict[str, TensorMap]
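Conceptually, forward undoes the normalization by multiplying each output by the scale fitted for it. A minimal sketch, assuming plain lists of floats in place of TensorMap values and a plain dict of scales; the function name `apply_scales` is hypothetical:

```python
from typing import Dict, List


def apply_scales(
    outputs: Dict[str, List[float]], scales: Dict[str, float]
) -> Dict[str, List[float]]:
    scaled = {}
    for name, values in outputs.items():
        if name not in scales:
            # Mirrors the documented ValueError for outputs without a scale
            raise ValueError(f"output {name!r} has no corresponding scale")
        scaled[name] = [v * scales[name] for v in values]
    return scaled


print(apply_scales({"energy": [0.5, -1.0]}, {"energy": 2.0}))
# → {'energy': [1.0, -2.0]}
```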

get_scales_dict() → Dict[str, float]

Return a dictionary with the scales for each output and output gradient.

Returns:

A dictionary with the scales for each output and output gradient. These correspond to the standard deviation of the targets in the original dataset. The scales for each output gradient are the same as the corresponding output.

Return type:

Dict[str, float]
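Gradient scales can equal the scale of their parent output because the scaling is linear: if an output is s·y, its derivative with respect to, say, positions is s·∂y/∂x. The sketch below illustrates this sharing of scales; the `<output>_<gradient>_gradients` key format and the gradient names are hypothetical, not necessarily the keys that get_scales_dict returns.

```python
from typing import Dict, List


def scales_with_gradients(
    scales: Dict[str, float], gradients: Dict[str, List[str]]
) -> Dict[str, float]:
    result = dict(scales)
    for output, gradient_names in gradients.items():
        for gradient in gradient_names:
            # A gradient inherits the scale of the output it belongs to
            result[f"{output}_{gradient}_gradients"] = scales[output]
    return result


print(scales_with_gradients({"energy": 0.3}, {"energy": ["positions"]}))
# → {'energy': 0.3, 'energy_positions_gradients': 0.3}
```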

metatrain.utils.scaler.remove_scale(targets: Dict[str, TensorMap], scaler: Scaler)

Scale all targets to a standard deviation of one.

Parameters:
  • targets (Dict[str, TensorMap]) – Dictionary containing the targets to be scaled.

  • scaler (Scaler) – The scaler used to scale the targets.
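remove_scale can be understood as the inverse of Scaler.forward: each target is divided by its scale so that it has unit standard deviation during training. A round-trip sketch, assuming lists of floats in place of TensorMap values and a plain dict of scales; `divide_by_scales` is a hypothetical name:

```python
from typing import Dict, List


def divide_by_scales(
    targets: Dict[str, List[float]], scales: Dict[str, float]
) -> Dict[str, List[float]]:
    # Divide every target value by the scale fitted for that target
    return {name: [v / scales[name] for v in values] for name, values in targets.items()}


targets = {"energy": [3.0, -1.5]}
scales = {"energy": 1.5}
normalized = divide_by_scales(targets, scales)
print(normalized)  # → {'energy': [2.0, -1.0]}

# Multiplying back by the scale recovers the original targets
restored = {n: [v * scales[n] for v in vs] for n, vs in normalized.items()}
print(restored == targets)  # → True
```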