Scaler
- class metatrain.utils.scaler.Scaler(model_hypers: Dict, dataset_info: DatasetInfo)
Bases: Module
A class that scales the targets of regression problems to unit standard deviation.
In most cases, this should be used in conjunction with a composition model (which removes the multi-dimensional “mean” across the composition space) and/or other additive models. See the train_model method for more details.
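The composition model mentioned above can be pictured as a least-squares fit of one additive contribution per atomic type. The sketch below illustrates that idea under this assumption only; it is not metatrain's composition model, and `fit_composition_baseline` and the toy data are hypothetical.

```python
import numpy as np


def fit_composition_baseline(counts: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """Fit one additive contribution per atomic type by least squares.

    counts[i, t] is the number of atoms of type t in structure i;
    targets[i] is the per-structure target (e.g. an energy).
    """
    weights, *_ = np.linalg.lstsq(counts, targets, rcond=None)
    return weights


# Hypothetical toy data: three structures built from two atomic types.
counts = np.array([[2.0, 1.0], [4.0, 2.0], [2.0, 2.0]])
targets = np.array([-10.1, -20.3, -13.9])

weights = fit_composition_baseline(counts, targets)
# What remains after removing the composition "mean" is what a scaler would see.
residuals = targets - counts @ weights
```

Because the fit is a least-squares projection, the residuals carry no component that a per-type constant could still explain, which is why a scaler applied afterwards measures only the remaining spread.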
The scaling is performed per-atom, i.e., in cases where the targets are per-structure, the standard deviation is calculated on the targets divided by the number of atoms in each structure.
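A minimal sketch of the per-atom convention follows. It assumes per-structure targets stored as a NumPy array and uses the population standard deviation (`ddof=0`). The function name `per_atom_scale` is hypothetical and not part of metatrain.

```python
import numpy as np


def per_atom_scale(structure_targets: np.ndarray, n_atoms: np.ndarray) -> float:
    """Standard deviation of per-structure targets after dividing by atom counts."""
    per_atom = structure_targets / n_atoms
    return float(np.std(per_atom))  # population std (ddof=0)


# Two structures with the same per-atom value give a scale of zero...
print(per_atom_scale(np.array([2.0, 6.0]), np.array([1, 3])))  # 0.0
# ...while per-atom values 1.0 and 3.0 give a standard deviation of 1.0.
print(per_atom_scale(np.array([1.0, 6.0]), np.array([1, 2])))  # 1.0
```

Dividing by the atom count keeps large structures from dominating the statistic, since extensive targets such as energies grow roughly linearly with system size.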
- Parameters:
model_hypers (Dict) – A dictionary of model hyperparameters. This parameter is ignored; it is present only for consistency with the general model API.
dataset_info (DatasetInfo) – An object containing information about the dataset, including target quantities and atomic types.
- outputs: Dict[str, ModelOutput]
- train_model(datasets: List[Dataset | Subset], additive_models: List[Module], treat_as_additive: bool) → None
Calculate the scaling weights for all the targets in the datasets.
- Parameters:
datasets (List[Dataset | Subset]) – Dataset(s) to calculate the scaling weights for.
additive_models (List[Module]) – Additive models to be removed from the targets before calculating the statistics.
treat_as_additive (bool) – If True, all per-structure targets (i.e. those that do not contain an "atom" label name) are treated as additive.
- Raises:
ValueError – If the provided datasets contain targets unknown to the scaler or if the targets are not treated as additive.
- Return type:
None
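The steps of train_model can be sketched with plain NumPy arrays in place of metatrain Dataset and TensorMap objects. This is an illustrative assumption, not the library's implementation. `compute_scales`, the `AdditiveModel` signature and the constant per-atom baseline are hypothetical, and the scale is again taken as a population standard deviation.

```python
from typing import Callable, Dict, List, Sequence

import numpy as np

# Hypothetical additive-model interface: (target name, atoms per structure)
# -> per-structure prediction to subtract from the target.
AdditiveModel = Callable[[str, np.ndarray], np.ndarray]


def compute_scales(
    targets: Dict[str, np.ndarray],
    n_atoms: np.ndarray,
    known_targets: Sequence[str],
    additive_models: List[AdditiveModel],
    treat_as_additive: bool,
) -> Dict[str, float]:
    """Return one per-atom scale per per-structure target."""
    if not treat_as_additive:
        # Mirrors the documented ValueError for non-additive targets.
        raise ValueError("per-structure targets must be treated as additive")
    scales: Dict[str, float] = {}
    for name, values in targets.items():
        if name not in known_targets:
            # Mirrors the documented ValueError for unknown targets.
            raise ValueError(f"unknown target: {name!r}")
        residual = values.astype(float)
        for model in additive_models:
            residual = residual - model(name, n_atoms)
        # Per-atom normalization before taking the standard deviation.
        scales[name] = float(np.std(residual / n_atoms))
    return scales


def constant_per_atom(name: str, n_atoms: np.ndarray) -> np.ndarray:
    """Toy additive model: every atom contributes -1.0."""
    return -1.0 * n_atoms


n_atoms = np.array([1, 2, 4])
targets = {"energy": np.array([-1.5, -1.5, -5.0])}
scales = compute_scales(targets, n_atoms, ["energy"], [constant_per_atom], True)
```

Here the residuals after removing the toy baseline are [-0.5, 0.5, -1.0]. Dividing them by [1, 2, 4] atoms gives the per-atom values [-0.5, 0.25, -0.25], and their standard deviation becomes the scale for "energy".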
- restart(dataset_info: DatasetInfo) → Scaler
- Parameters:
dataset_info (DatasetInfo)
- Return type:
Scaler
- forward(outputs: Dict[str, TensorMap]) → Dict[str, TensorMap]
Scales all the targets in the outputs dictionary back to their original scale.
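Undoing the scaling in forward amounts to multiplying each output by the scale computed for its target. The sketch below uses plain NumPy arrays keyed by target name in place of TensorMap objects. `scale_targets` and `unscale_outputs` are hypothetical helpers, and the assumption that targets are divided by the same scale during training is labeled as such in the comments.

```python
from typing import Dict

import numpy as np


def scale_targets(
    targets: Dict[str, np.ndarray], scales: Dict[str, float]
) -> Dict[str, np.ndarray]:
    """Assumed training-side step: divide each target by its scale."""
    return {name: values / scales[name] for name, values in targets.items()}


def unscale_outputs(
    outputs: Dict[str, np.ndarray], scales: Dict[str, float]
) -> Dict[str, np.ndarray]:
    """Restore model outputs to the original units of each target."""
    return {name: values * scales[name] for name, values in outputs.items()}


scales = {"energy": 0.5}
targets = {"energy": np.array([1.0, -2.0, 3.0])}
restored = unscale_outputs(scale_targets(targets, scales), scales)
```

The two helpers are exact inverses as long as every scale is nonzero. A practical implementation would also need a guard for a target that is constant after the additive models are removed.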