Source code for metatrain.utils.loss

from typing import Dict, Optional, Tuple, Union

import torch
from metatensor.torch import TensorMap
from omegaconf import DictConfig
from torch.nn.modules.loss import _Loss

from metatrain.utils.external_naming import to_internal_name


class TensorMapLoss:
    """A loss function that operates on two ``metatensor.torch.TensorMap``.

    The loss is computed as the sum of the loss on the block values and
    the loss on the gradients, with weights specified at initialization.

    At the moment, this loss function assumes that all the gradients
    declared at initialization are present in both TensorMaps.

    :param reduction: The reduction to apply to the loss.
        See :py:class:`torch.nn.MSELoss`.
    :param weight: The weight to apply to the loss on the block values.
    :param gradient_weights: The weights to apply to the loss on the gradients.
    :param sliding_factor: The factor to apply to the exponential moving average
        of the "sliding" weights. These are weights that act on different
        components of the loss (for example, energies and forces), based on
        their individual recent history. If ``None``, no sliding weights are
        used in the computation of the loss.
    :param type: The type of loss to use. This can be either "mse" or "mae".
        A Huber loss can also be requested by passing a dictionary with the
        key "huber", whose value must be a dictionary with the key "deltas",
        which in turn maps "values" and the gradient keys to the deltas to
        use for the Huber loss.

    :returns: The loss as a zero-dimensional :py:class:`torch.Tensor`
        (with one entry).
    """

    def __init__(
        self,
        reduction: str = "mean",
        weight: float = 1.0,
        gradient_weights: Optional[Dict[str, float]] = None,
        sliding_factor: Optional[float] = None,
        type: Union[str, dict] = "mse",
    ):
        if gradient_weights is None:
            gradient_weights = {}

        losses = {}
        if type == "mse":
            losses["values"] = torch.nn.MSELoss(reduction=reduction)
            for key in gradient_weights.keys():
                losses[key] = torch.nn.MSELoss(reduction=reduction)
        elif type == "mae":
            losses["values"] = torch.nn.L1Loss(reduction=reduction)
            for key in gradient_weights.keys():
                losses[key] = torch.nn.L1Loss(reduction=reduction)
        elif isinstance(type, dict) and "huber" in type:
            # Huber loss
            deltas = type["huber"]["deltas"]
            losses["values"] = torch.nn.HuberLoss(
                reduction=reduction, delta=deltas["values"]
            )
            for key in gradient_weights.keys():
                losses[key] = torch.nn.HuberLoss(
                    reduction=reduction, delta=deltas[key]
                )
        else:
            raise ValueError(f"Unknown loss type: {type}")

        self.losses = losses
        self.weight = weight
        self.gradient_weights = gradient_weights
        self.sliding_factor = sliding_factor
        self.sliding_weights: Optional[Dict[str, torch.Tensor]] = None

    def __call__(
        self,
        predictions_tensor_map: TensorMap,
        targets_tensor_map: TensorMap,
    ) -> torch.Tensor:
        # Check that the two have the same metadata, except for the samples,
        # which can be different due to batching, but must have the same size:
        if predictions_tensor_map.keys != targets_tensor_map.keys:
            raise ValueError(
                "TensorMapLoss requires the two TensorMaps to have the same keys."
            )
        for block_1, block_2 in zip(
            predictions_tensor_map.blocks(), targets_tensor_map.blocks()
        ):
            if block_1.properties != block_2.properties:
                raise ValueError(
                    "TensorMapLoss requires the two TensorMaps to have the same "
                    "properties."
                )
            if block_1.components != block_2.components:
                raise ValueError(
                    "TensorMapLoss requires the two TensorMaps to have the same "
                    "components."
                )
            if len(block_1.samples) != len(block_2.samples):
                raise ValueError(
                    "TensorMapLoss requires the two TensorMaps "
                    "to have the same number of samples."
                )
            for gradient_name in block_2.gradients_list():
                if len(block_1.gradient(gradient_name).samples) != len(
                    block_2.gradient(gradient_name).samples
                ):
                    raise ValueError(
                        "TensorMapLoss requires the two TensorMaps "
                        "to have the same number of gradient samples."
                    )
                if (
                    block_1.gradient(gradient_name).properties
                    != block_2.gradient(gradient_name).properties
                ):
                    raise ValueError(
                        "TensorMapLoss requires the two TensorMaps "
                        "to have the same gradient properties."
                    )
                if (
                    block_1.gradient(gradient_name).components
                    != block_2.gradient(gradient_name).components
                ):
                    raise ValueError(
                        "TensorMapLoss requires the two TensorMaps "
                        "to have the same gradient components."
                    )

        # First time the function is called: compute the sliding weights only
        # from the targets (if they are enabled)
        if self.sliding_factor is not None and self.sliding_weights is None:
            self.sliding_weights = get_sliding_weights(
                self.losses,
                self.sliding_factor,
                targets_tensor_map,
            )

        # Compute the loss:
        loss = torch.zeros(
            (),
            dtype=predictions_tensor_map.block(0).values.dtype,
            device=predictions_tensor_map.block(0).values.device,
        )
        for key in targets_tensor_map.keys:
            block_1 = predictions_tensor_map.block(key)
            block_2 = targets_tensor_map.block(key)
            values_1 = block_1.values
            values_2 = block_2.values
            # sliding weights: default to 1.0 if not used/provided for this target
            sliding_weight = (
                1.0
                if self.sliding_weights is None
                else self.sliding_weights.get("values", 1.0)
            )
            loss += (
                self.weight
                * self.losses["values"](values_1, values_2)
                / sliding_weight
            )
            for gradient_name in block_2.gradients_list():
                gradient_weight = self.gradient_weights[gradient_name]
                values_1 = block_1.gradient(gradient_name).values
                values_2 = block_2.gradient(gradient_name).values
                # sliding weights: default to 1.0 if not used/provided for this target
                sliding_weight = (
                    1.0
                    if self.sliding_weights is None
                    else self.sliding_weights.get(gradient_name, 1.0)
                )
                loss += (
                    gradient_weight
                    * self.losses[gradient_name](values_1, values_2)
                    / sliding_weight
                )

        if self.sliding_factor is not None:
            self.sliding_weights = get_sliding_weights(
                self.losses,
                self.sliding_factor,
                targets_tensor_map,
                predictions_tensor_map,
                self.sliding_weights,
            )

        return loss
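

# A minimal usage sketch for ``TensorMapLoss`` (the shapes, label names and
# numbers below are illustrative, not taken from metatrain itself): two
# single-block, gradient-free TensorMaps with identical metadata are compared
# with a mean-squared-error loss on the block values.
def _example_tensor_map_loss() -> None:
    from metatensor.torch import Labels, TensorBlock

    def make_tensor_map(values: torch.Tensor) -> TensorMap:
        # one block of shape (n_samples, 1), no components, no gradients
        block = TensorBlock(
            values=values,
            samples=Labels.range("system", len(values)),
            components=[],
            properties=Labels.range("energy", 1),
        )
        return TensorMap(keys=Labels.single(), blocks=[block])

    predictions = make_tensor_map(torch.tensor([[1.0], [2.0], [3.0]]))
    targets = make_tensor_map(torch.tensor([[1.5], [2.0], [2.5]]))

    loss_fn = TensorMapLoss(reduction="mean", weight=1.0, type="mse")
    loss = loss_fn(predictions, targets)  # zero-dimensional tensor
    print(float(loss))  # MSE of the residuals [-0.5, 0.0, 0.5] -> about 0.1667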


class TensorMapDictLoss:
    """A loss function that operates on two ``Dict[str, metatensor.torch.TensorMap]``.

    At initialization, the user specifies a list of keys to use for the loss,
    along with a weight for each key. The loss is then computed as a weighted
    sum. Any keys that are not present in the dictionaries are ignored.

    :param weights: A dictionary mapping keys to weights. This might contain
        gradient keys, in the form ``<output_name>_<gradient_name>_gradients``.
    :param sliding_factor: The factor to apply to the exponential moving average
        of the "sliding" weights. These are weights that act on different
        components of the loss (for example, energies and forces), based on
        their individual recent history. If ``None``, no sliding weights are
        used in the computation of the loss.
    :param reduction: The reduction to apply to the loss.
        See :py:class:`torch.nn.MSELoss`.
    :param type: The type of loss to use, with the same options as in
        :py:class:`TensorMapLoss`.

    :returns: The loss as a zero-dimensional :py:class:`torch.Tensor`
        (with one entry).
    """

    def __init__(
        self,
        weights: Dict[str, float],
        sliding_factor: Optional[float] = None,
        reduction: str = "mean",
        type: Union[str, dict] = "mse",
    ):
        outputs = [key for key in weights.keys() if "gradients" not in key]
        self.losses = {}
        for output in outputs:
            value_weight = weights[output]
            gradient_weights = {}
            for key, weight in weights.items():
                if key.startswith(output) and key.endswith("_gradients"):
                    gradient_name = key.replace(f"{output}_", "").replace(
                        "_gradients", ""
                    )
                    gradient_weights[gradient_name] = weight
            type_output = _process_type(type, output)
            if output == "energy" and sliding_factor is not None:
                self.losses[output] = TensorMapLoss(
                    reduction=reduction,
                    weight=value_weight,
                    gradient_weights=gradient_weights,
                    sliding_factor=sliding_factor,
                    type=type_output,
                )
            else:
                self.losses[output] = TensorMapLoss(
                    reduction=reduction,
                    weight=value_weight,
                    gradient_weights=gradient_weights,
                    type=type_output,
                )

    def __call__(
        self,
        tensor_map_dict_1: Dict[str, TensorMap],
        tensor_map_dict_2: Dict[str, TensorMap],
    ) -> torch.Tensor:
        # Assert that the two dictionaries have the same keys:
        assert set(tensor_map_dict_1.keys()) == set(tensor_map_dict_2.keys())

        # Initialize the loss:
        first_values = next(iter(tensor_map_dict_1.values())).block(0).values
        loss = torch.zeros((), dtype=first_values.dtype, device=first_values.device)

        # Compute the loss:
        for target in tensor_map_dict_1.keys():
            target_loss = self.losses[target](
                tensor_map_dict_1[target], tensor_map_dict_2[target]
            )
            loss += target_loss

        return loss
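

# A sketch of the ``weights`` dictionary expected by ``TensorMapDictLoss``:
# plain keys weight the block values of a target, while keys of the form
# ``<output_name>_<gradient_name>_gradients`` weight that target's gradients.
# The target names and weights below are illustrative.
def _example_tensor_map_dict_loss() -> TensorMapDictLoss:
    weights = {
        "energy": 1.0,  # loss on the energy values
        "energy_positions_gradients": 10.0,  # loss on the position gradients
        "mtt::dipole": 0.1,  # a second, gradient-free target
    }
    # Internally, each output gets its own TensorMapLoss; here "energy" is
    # built with gradient_weights={"positions": 10.0}.
    return TensorMapDictLoss(weights, reduction="mean", type="mse")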


def get_sliding_weights(
    losses: Dict[str, _Loss],
    sliding_factor: float,
    targets: TensorMap,
    predictions: Optional[TensorMap] = None,
    previous_sliding_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
    """
    Compute the sliding weights for the loss function.

    The sliding weights are an exponential moving average of the loss between
    the predictions and the targets, computed separately for the block values
    and for each gradient.

    :param losses: The per-quantity loss functions (one for ``"values"`` and
        one for each gradient).
    :param sliding_factor: The factor of the exponential moving average.
    :param targets: The targets.
    :param predictions: The predictions. If ``None``, the sliding weights are
        initialized from the targets alone.
    :param previous_sliding_weights: The sliding weights from the previous
        update; required when ``predictions`` is provided.
    :return: The sliding weights.
    """
    sliding_weights = {}
    if predictions is None:
        for block in targets.blocks():
            values = block.values
            sliding_weights["values"] = (
                losses["values"](values, values.mean() * torch.ones_like(values))
                + 1e-6
            )
            for gradient_name, gradient_block in block.gradients():
                values = gradient_block.values
                sliding_weights[gradient_name] = losses[gradient_name](
                    values, torch.zeros_like(values)
                )
    else:
        if previous_sliding_weights is None:
            raise RuntimeError(
                "previous_sliding_weights must be provided if predictions is not None"
            )
        for predictions_block, target_block in zip(
            predictions.blocks(), targets.blocks()
        ):
            target_values = target_block.values
            predictions_values = predictions_block.values
            sliding_weights["values"] = (
                sliding_factor * previous_sliding_weights["values"]
                + (1 - sliding_factor)
                * losses["values"](predictions_values, target_values).detach()
            )
            for gradient_name, gradient_block in target_block.gradients():
                target_values = gradient_block.values
                predictions_values = predictions_block.gradient(gradient_name).values
                sliding_weights[gradient_name] = (
                    sliding_factor * previous_sliding_weights[gradient_name]
                    + (1 - sliding_factor)
                    * losses[gradient_name](
                        predictions_values, target_values
                    ).detach()
                )
    return sliding_weights
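

# The update rule applied by ``get_sliding_weights`` to every tracked quantity
# is an exponential moving average of the per-batch loss; a small numeric
# sketch of that rule with made-up numbers:
def _example_sliding_weight_update() -> None:
    sliding_factor = 0.9
    previous_weight = torch.tensor(0.50)  # sliding weight after the last batch
    current_loss = torch.tensor(0.20)  # loss measured on the current batch
    new_weight = sliding_factor * previous_weight + (1 - sliding_factor) * current_loss
    # 0.9 * 0.50 + 0.1 * 0.20 = 0.47; in TensorMapLoss.__call__ the
    # corresponding loss term is divided by this weight.
    print(float(new_weight))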


def _process_type(type: Union[str, DictConfig], output: str) -> Union[str, dict]:
    if not isinstance(type, str):
        assert "huber" in type
        # we process the Huber loss delta dict to make it similar to the
        # `weights` dict
        type_output = {"huber": {"deltas": {}}}  # type: ignore
        for key, delta in type["huber"]["deltas"].items():
            key_internal = to_internal_name(key)
            if key_internal == output:
                type_output["huber"]["deltas"]["values"] = delta
            elif key_internal.startswith(output) and key_internal.endswith(
                "_gradients"
            ):
                gradient_name = key_internal.replace(f"{output}_", "").replace(
                    "_gradients", ""
                )
                type_output["huber"]["deltas"][gradient_name] = delta
            else:
                pass
    else:
        type_output = type  # type: ignore
    return type_output
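

# A sketch of the Huber configuration handled by ``_process_type``: one delta
# per target (and per gradient) is provided by the user, and the function
# extracts the entries belonging to a single output. The keys below are assumed
# to already be internal names, so ``to_internal_name`` should leave them
# unchanged; a plain dict is used in place of a DictConfig, which behaves the
# same at runtime here.
def _example_process_type() -> None:
    type_config = {
        "huber": {
            "deltas": {
                "energy": 0.1,
                "energy_positions_gradients": 0.01,
            }
        }
    }
    type_for_energy = _process_type(type_config, "energy")
    # expected: {"huber": {"deltas": {"values": 0.1, "positions": 0.01}}},
    # which can be passed directly as the ``type`` of a TensorMapLoss
    print(type_for_energy)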