from typing import Dict, Optional, Tuple, Union
import torch
from metatensor.torch import TensorMap
from omegaconf import DictConfig
from torch.nn.modules.loss import _Loss
from metatrain.utils.external_naming import to_internal_name
class TensorMapLoss:
"""A loss function that operates on two ``metatensor.torch.TensorMap``.
The loss is computed as the sum of the loss on the block values and
the loss on the gradients, with weights specified at initialization.
At the moment, this loss function assumes that all the gradients
declared at initialization are present in both TensorMaps.
:param reduction: The reduction to apply to the loss.
See :py:class:`torch.nn.MSELoss`.
:param weight: The weight to apply to the loss on the block values.
:param gradient_weights: The weights to apply to the loss on the gradients.
:param sliding_factor: The factor to apply to the exponential moving average
of the "sliding" weights. These are weights that act on different components of
the loss (for example, energies and forces), based on their individual recent
history. If ``None``, no sliding weights are used in the computation of the
loss.
    :param type: The type of loss to use. This can be either ``"mse"`` or
        ``"mae"``. A Huber loss can also be requested by passing a dictionary
        with the single key ``"huber"``, whose value is a dictionary with the
        key ``"deltas"`` mapping ``"values"`` and each gradient key to the
        delta to use for that quantity, for example
        ``{"huber": {"deltas": {"values": 0.1, "positions": 1.0}}}``.
:returns: The loss as a zero-dimensional :py:class:`torch.Tensor`
(with one entry).
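
    Example (a minimal sketch: ``predictions`` and ``targets`` are assumed to
    be two ``TensorMap`` objects with matching metadata and a ``"positions"``
    gradient on each block)::

        loss_fn = TensorMapLoss(
            reduction="mean",
            weight=1.0,
            gradient_weights={"positions": 0.1},
            type="mse",
        )
        loss = loss_fn(predictions, targets)  # zero-dimensional torch.Tensor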
"""
def __init__(
self,
reduction: str = "mean",
weight: float = 1.0,
gradient_weights: Optional[Dict[str, float]] = None,
sliding_factor: Optional[float] = None,
type: Union[str, dict] = "mse",
):
if gradient_weights is None:
gradient_weights = {}
losses = {}
if type == "mse":
losses["values"] = torch.nn.MSELoss(reduction=reduction)
for key in gradient_weights.keys():
losses[key] = torch.nn.MSELoss(reduction=reduction)
elif type == "mae":
losses["values"] = torch.nn.L1Loss(reduction=reduction)
for key in gradient_weights.keys():
losses[key] = torch.nn.L1Loss(reduction=reduction)
elif isinstance(type, dict) and "huber" in type:
# Huber loss
deltas = type["huber"]["deltas"]
losses["values"] = torch.nn.HuberLoss(
reduction=reduction, delta=deltas["values"]
)
for key in gradient_weights.keys():
losses[key] = torch.nn.HuberLoss(reduction=reduction, delta=deltas[key])
else:
raise ValueError(f"Unknown loss type: {type}")
self.losses = losses
self.weight = weight
self.gradient_weights = gradient_weights
self.sliding_factor = sliding_factor
        self.sliding_weights: Optional[Dict[str, float]] = None
def __call__(
self,
predictions_tensor_map: TensorMap,
targets_tensor_map: TensorMap,
    ) -> torch.Tensor:
# Check that the two have the same metadata, except for the samples,
# which can be different due to batching, but must have the same size:
if predictions_tensor_map.keys != targets_tensor_map.keys:
raise ValueError(
"TensorMapLoss requires the two TensorMaps to have the same keys."
)
for block_1, block_2 in zip(
predictions_tensor_map.blocks(), targets_tensor_map.blocks()
):
if block_1.properties != block_2.properties:
raise ValueError(
"TensorMapLoss requires the two TensorMaps to have the same "
"properties."
)
if block_1.components != block_2.components:
raise ValueError(
"TensorMapLoss requires the two TensorMaps to have the same "
"components."
)
if len(block_1.samples) != len(block_2.samples):
raise ValueError(
"TensorMapLoss requires the two TensorMaps "
"to have the same number of samples."
)
for gradient_name in block_2.gradients_list():
if len(block_1.gradient(gradient_name).samples) != len(
block_2.gradient(gradient_name).samples
):
raise ValueError(
"TensorMapLoss requires the two TensorMaps "
"to have the same number of gradient samples."
)
if (
block_1.gradient(gradient_name).properties
!= block_2.gradient(gradient_name).properties
):
raise ValueError(
"TensorMapLoss requires the two TensorMaps "
"to have the same gradient properties."
)
if (
block_1.gradient(gradient_name).components
!= block_2.gradient(gradient_name).components
):
raise ValueError(
"TensorMapLoss requires the two TensorMaps "
"to have the same gradient components."
)
# First time the function is called: compute the sliding weights only
# from the targets (if they are enabled)
if self.sliding_factor is not None and self.sliding_weights is None:
self.sliding_weights = get_sliding_weights(
self.losses,
self.sliding_factor,
targets_tensor_map,
)
# Compute the loss:
loss = torch.zeros(
(),
dtype=predictions_tensor_map.block(0).values.dtype,
device=predictions_tensor_map.block(0).values.device,
)
for key in targets_tensor_map.keys:
block_1 = predictions_tensor_map.block(key)
block_2 = targets_tensor_map.block(key)
values_1 = block_1.values
values_2 = block_2.values
# sliding weights: default to 1.0 if not used/provided for this target
sliding_weight = (
1.0
if self.sliding_weights is None
else self.sliding_weights.get("values", 1.0)
)
loss += (
self.weight * self.losses["values"](values_1, values_2) / sliding_weight
)
for gradient_name in block_2.gradients_list():
gradient_weight = self.gradient_weights[gradient_name]
values_1 = block_1.gradient(gradient_name).values
values_2 = block_2.gradient(gradient_name).values
# sliding weights: default to 1.0 if not used/provided for this target
                sliding_weights_value = (
1.0
if self.sliding_weights is None
else self.sliding_weights.get(gradient_name, 1.0)
)
loss += (
gradient_weight
* self.losses[gradient_name](values_1, values_2)
                    / sliding_weights_value
)
if self.sliding_factor is not None:
self.sliding_weights = get_sliding_weights(
self.losses,
self.sliding_factor,
targets_tensor_map,
predictions_tensor_map,
self.sliding_weights,
)
return loss
class TensorMapDictLoss:
"""A loss function that operates on two ``Dict[str, metatensor.torch.TensorMap]``.
At initialization, the user specifies a list of keys to use for the loss,
along with a weight for each key.
The loss is then computed as a weighted sum. Any keys that are not present
in the dictionaries are ignored.
:param weights: A dictionary mapping keys to weights. This might contain
gradient keys, in the form ``<output_name>_<gradient_name>_gradients``.
:param sliding_factor: The factor to apply to the exponential moving average
of the "sliding" weights. These are weights that act on different components of
the loss (for example, energies and forces), based on their individual recent
history. If ``None``, no sliding weights are used in the computation of the
loss.
    :param reduction: The reduction to apply to the loss.
        See :py:class:`torch.nn.MSELoss`.
    :param type: The type of loss to use. This can be either ``"mse"`` or
        ``"mae"``. A Huber loss can also be requested as a dictionary with the
        key ``"huber"`` and, under ``"deltas"``, a dictionary mapping each
        target (and gradient target) name to the delta to use for it.
:returns: The loss as a zero-dimensional :py:class:`torch.Tensor`
(with one entry).
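
    Example (a minimal sketch: the ``"energy"`` target, its ``"positions"``
    gradient, and the two ``Dict[str, TensorMap]`` arguments are assumptions
    for illustration)::

        loss_fn = TensorMapDictLoss(
            weights={"energy": 1.0, "energy_positions_gradients": 0.1},
        )
        loss = loss_fn(predictions_dict, targets_dict)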
"""
def __init__(
self,
weights: Dict[str, float],
sliding_factor: Optional[float] = None,
reduction: str = "mean",
type: Union[str, dict] = "mse",
):
outputs = [key for key in weights.keys() if "gradients" not in key]
self.losses = {}
for output in outputs:
value_weight = weights[output]
gradient_weights = {}
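            # e.g. a weight given under "energy_positions_gradients" is stored
            # as gradient_weights["positions"] for the "energy" output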
for key, weight in weights.items():
if key.startswith(output) and key.endswith("_gradients"):
gradient_name = key.replace(f"{output}_", "").replace(
"_gradients", ""
)
gradient_weights[gradient_name] = weight
type_output = _process_type(type, output)
if output == "energy" and sliding_factor is not None:
self.losses[output] = TensorMapLoss(
reduction=reduction,
weight=value_weight,
gradient_weights=gradient_weights,
sliding_factor=sliding_factor,
type=type_output,
)
else:
self.losses[output] = TensorMapLoss(
reduction=reduction,
weight=value_weight,
gradient_weights=gradient_weights,
type=type_output,
)
def __call__(
self,
tensor_map_dict_1: Dict[str, TensorMap],
tensor_map_dict_2: Dict[str, TensorMap],
) -> torch.Tensor:
        # Assert that the two dictionaries have the same keys:
assert set(tensor_map_dict_1.keys()) == set(tensor_map_dict_2.keys())
# Initialize the loss:
first_values = next(iter(tensor_map_dict_1.values())).block(0).values
loss = torch.zeros((), dtype=first_values.dtype, device=first_values.device)
# Compute the loss:
for target in tensor_map_dict_1.keys():
target_loss = self.losses[target](
tensor_map_dict_1[target], tensor_map_dict_2[target]
)
loss += target_loss
return loss
def get_sliding_weights(
losses: Dict[str, _Loss],
sliding_factor: float,
targets: TensorMap,
predictions: Optional[TensorMap] = None,
previous_sliding_weights: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
"""
    Compute the sliding weights for the loss function.

    On the first call (when ``predictions`` is ``None``), the weights are
    initialized from the targets alone. On subsequent calls, they are updated
    as an exponential moving average of the loss between the predictions and
    the targets.

    :param losses: The per-quantity loss modules (``"values"`` plus one entry
        per gradient name).
    :param sliding_factor: The factor of the exponential moving average.
    :param targets: The targets.
    :param predictions: The predictions. If ``None``, the weights are
        initialized from the targets only.
    :param previous_sliding_weights: The sliding weights from the previous
        update. Required when ``predictions`` is given.
    :return: The sliding weights, one per quantity.
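
    .. note::
        For a sliding factor :math:`f` and per-quantity loss
        :math:`\mathcal{L}`, each weight follows the update
        :math:`w \leftarrow f \, w_{\mathrm{prev}} +
        (1 - f) \, \mathcal{L}(\mathrm{predictions}, \mathrm{targets})`.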
"""
sliding_weights = {}
if predictions is None:
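        # first call: measure the spread of the target values around their mean
        # (and of the gradients around zero); the small constant keeps the
        # "values" weight strictly positive, since it is later used as a divisor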
for block in targets.blocks():
values = block.values
sliding_weights["values"] = (
losses["values"](values, values.mean() * torch.ones_like(values)) + 1e-6
)
for gradient_name, gradient_block in block.gradients():
values = gradient_block.values
sliding_weights[gradient_name] = losses[gradient_name](
values, torch.zeros_like(values)
)
elif predictions is not None:
if previous_sliding_weights is None:
raise RuntimeError(
"previous_sliding_weights must be provided if predictions is not None"
)
else:
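            # update each weight as an exponential moving average of the
            # current loss between predictions and targets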
for predictions_block, target_block in zip(
predictions.blocks(), targets.blocks()
):
target_values = target_block.values
predictions_values = predictions_block.values
sliding_weights["values"] = (
sliding_factor * previous_sliding_weights["values"]
+ (1 - sliding_factor)
* losses["values"](predictions_values, target_values).detach()
)
for gradient_name, gradient_block in target_block.gradients():
target_values = gradient_block.values
predictions_values = predictions_block.gradient(
gradient_name
).values
sliding_weights[gradient_name] = (
sliding_factor * previous_sliding_weights[gradient_name]
+ (1 - sliding_factor)
* losses[gradient_name](
predictions_values, target_values
).detach()
)
return sliding_weights
def _process_type(type: Union[str, DictConfig], output: str) -> Union[str, dict]:
if not isinstance(type, str):
assert "huber" in type
# we process the Huber loss delta dict to make it similar to the
# `weights` dict
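        # e.g. for output "energy", a delta given under "energy" itself ends up
        # under "values", and one given under a key that maps to
        # "energy_<gradient>_gradients" ends up under "<gradient>"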
type_output = {"huber": {"deltas": {}}} # type: ignore
for key, delta in type["huber"]["deltas"].items():
key_internal = to_internal_name(key)
if key_internal == output:
type_output["huber"]["deltas"]["values"] = delta
elif key_internal.startswith(output) and key_internal.endswith(
"_gradients"
):
gradient_name = key_internal.replace(f"{output}_", "").replace(
"_gradients", ""
)
type_output["huber"]["deltas"][gradient_name] = delta
else:
pass
else:
type_output = type # type: ignore
return type_output