Loss

class metatrain.utils.loss.TensorMapLoss(reduction: str = 'mean', weight: float = 1.0, gradient_weights: Dict[str, float] | None = None, type: str | dict = 'mse')[source]

Bases: object

A loss function that operates on two metatensor.torch.TensorMap objects.

The loss is computed as the sum of the loss on the block values and the loss on the gradients, with weights specified at initialization.

At the moment, this loss function assumes that all the gradients declared at initialization are present in both TensorMaps.

Parameters:
  • reduction (str) – The reduction to apply to the loss. See torch.nn.MSELoss.

  • weight (float) – The weight to apply to the loss on the block values.

  • gradient_weights (Dict[str, float] | None) – The weights to apply to the loss on the gradients.

  • type (str | dict) – The type of loss to use. Defaults to 'mse'.

Returns:

The loss as a zero-dimensional torch.Tensor (with one entry).
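The weighted sum described above can be sketched in plain Python. This is a minimal illustration, not the actual implementation: plain lists of floats stand in for the block values and gradients of a metatensor.torch.TensorMap, and the helper names (`mse`, `tensormap_loss`) are hypothetical.

```python
# Illustrative sketch of how TensorMapLoss combines its terms.
# The real class operates on metatensor.torch.TensorMap objects and
# torch tensors; plain lists of floats stand in here.

def mse(pred, target):
    """Mean-squared error over two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def tensormap_loss(pred_values, target_values,
                   pred_gradients, target_gradients,
                   weight=1.0, gradient_weights=None):
    """Weighted sum of the value loss and the per-gradient losses,
    mirroring TensorMapLoss(weight=..., gradient_weights=...)."""
    gradient_weights = gradient_weights or {}
    # loss on the block values, scaled by `weight`
    loss = weight * mse(pred_values, target_values)
    # plus the loss on each declared gradient, scaled by its own weight
    for name, w in gradient_weights.items():
        loss += w * mse(pred_gradients[name], target_gradients[name])
    return loss

# Example: values plus a "positions" gradient.
loss = tensormap_loss(
    [1.0, 2.0], [1.5, 2.0],           # predicted vs. target block values
    {"positions": [0.0, 1.0]},        # predicted gradients
    {"positions": [0.0, 0.0]},        # target gradients
    weight=1.0,
    gradient_weights={"positions": 0.5},
)
# value term: (0.25 + 0) / 2 = 0.125; gradient term: 0.5 * (0 + 1) / 2 = 0.25
print(loss)  # 0.375
```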

class metatrain.utils.loss.TensorMapDictLoss(weights: Dict[str, float], reduction: str = 'mean', type: str | dict = 'mse')[source]

Bases: object

A loss function that operates on two Dict[str, metatensor.torch.TensorMap] dictionaries.

At initialization, the user specifies the keys to use for the loss, along with a weight for each key.

The loss is then computed as the weighted sum of the per-key losses. Any keys that are not present in the dictionaries are ignored.

Parameters:
  • weights (Dict[str, float]) – A dictionary mapping keys to weights. This might contain gradient keys, in the form <output_name>_<gradient_name>_gradients.

  • reduction (str) – The reduction to apply to the loss. See torch.nn.MSELoss.

  • type (str | dict) – The type of loss to use. Defaults to 'mse'.

Returns:

The loss as a zero-dimensional torch.Tensor (with one entry).
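The behavior described above (weighted sum over the configured keys, skipping keys absent from the dictionaries) can be sketched as follows. This is a minimal illustration under plain-Python assumptions: lists of floats stand in for TensorMap entries, and `per_key_loss` is a hypothetical stand-in for the per-key TensorMapLoss.

```python
# Illustrative sketch of TensorMapDictLoss's weighted sum over keys.
# In the real class each dictionary value is a metatensor.torch.TensorMap;
# plain lists of floats stand in here.

def per_key_loss(pred, target):
    """Mean-squared error, standing in for the per-key TensorMapLoss."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def tensormap_dict_loss(predictions, targets, weights):
    """Weighted sum over the configured keys; keys absent from the
    dictionaries are skipped, as the documentation describes."""
    total = 0.0
    for key, w in weights.items():
        if key in predictions and key in targets:
            total += w * per_key_loss(predictions[key], targets[key])
    return total

predictions = {"energy": [1.0, 2.0]}
targets = {"energy": [1.0, 3.0]}
# The gradient key below follows the <output_name>_<gradient_name>_gradients
# form; it is absent from both dictionaries, so it is ignored.
weights = {"energy": 2.0, "energy_positions_gradients": 1.0}

print(tensormap_dict_loss(predictions, targets, weights))  # 1.0
```

Only the "energy" key contributes here: 2.0 * ((0 + 1) / 2) = 1.0.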