Source code for metatrain.utils.loss

from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Optional, Type

import metatensor.torch as mts
import torch
from metatensor.torch import TensorMap
from torch.nn.modules.loss import _Loss

from metatrain.utils.data import TargetInfo


class LossInterface(ABC):
    """
    Abstract base for all loss functions.

    Subclasses must implement the ``compute`` method.
    """

    weight: float
    reduction: str
    loss_kwargs: Dict[str, Any]
    target: str
    gradient: Optional[str]

    def __init__(
        self,
        name: str,
        gradient: Optional[str],
        weight: float,
        reduction: str,
    ) -> None:
        """
        :param name: key in the predictions/targets dict to select the TensorMap.
        :param gradient: optional name of a gradient field to extract.
        :param weight: multiplicative weight (used by ScheduledLoss).
        :param reduction: reduction mode for torch losses ("mean", "sum", etc.).
        """
        self.target = name
        self.gradient = gradient
        self.weight = weight
        self.reduction = reduction
        self.loss_kwargs = {}
        super().__init__()
    @abstractmethod
    def compute(
        self,
        predictions: Dict[str, TensorMap],
        targets: Dict[str, TensorMap],
        extra_data: Optional[Any] = None,
    ) -> torch.Tensor:
        """
        Compute the loss.

        :param predictions: mapping from target names to :py:class:`TensorMap`.
        :param targets: mapping from target names to :py:class:`TensorMap`.
        :param extra_data: optional additional data (e.g., masks).
        :return: scalar torch.Tensor representing the loss.
        """
        ...
    def __call__(
        self,
        predictions: Dict[str, TensorMap],
        targets: Dict[str, TensorMap],
        extra_data: Optional[Any] = None,
    ) -> torch.Tensor:
        """
        Alias for ``compute``, so the loss can be invoked directly.
        """
        return self.compute(predictions, targets, extra_data)
    @classmethod
    def from_config(cls, cfg: Dict[str, Any]) -> "LossInterface":
        """
        Instantiate a loss from a config dict.

        :param cfg: keyword args matching the loss constructor.
        :return: instance of a LossInterface subclass.
        """
        return cls(**cfg)
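# Usage sketch (illustrative, not part of the original module): any concrete
# subclass can be built from a plain dict via ``from_config``; the keys must
# match the constructor signature. The target name "energy" is hypothetical.
#
#     loss = TensorMapMSELoss.from_config(
#         {"name": "energy", "gradient": None, "weight": 1.0, "reduction": "mean"}
#     )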
# --- scheduler interface and implementations ------------------------------------------
class WeightScheduler(ABC):
    """
    Abstract interface for scheduling a weight for a :py:class:`LossInterface`.
    """

    initialized: bool = False
    @abstractmethod
    def initialize(
        self, loss_fn: LossInterface, targets: Dict[str, TensorMap]
    ) -> float:
        """
        Compute and return the initial weight.

        :param loss_fn: the base loss to initialize.
        :param targets: mapping of target names to :py:class:`TensorMap`.
        :return: initial weight as a float.
        """
    @abstractmethod
    def update(
        self,
        loss_fn: LossInterface,
        predictions: Dict[str, TensorMap],
        targets: Dict[str, TensorMap],
    ) -> float:
        """
        Update and return the new weight after a batch.

        :param loss_fn: the base loss.
        :param predictions: mapping of target names to :py:class:`TensorMap`.
        :param targets: mapping of target names to :py:class:`TensorMap`.
        :return: updated weight as a float.
        """
class EMAScheduler(WeightScheduler):
    """
    Exponential moving average (EMA) scheduler for loss weights.
    """

    EPSILON = 1e-6

    def __init__(self, sliding_factor: Optional[float]) -> None:
        """
        :param sliding_factor: factor in [0, 1] for the EMA (0 or ``None``
            disables scheduling).
        """
        self.sliding_factor = float(sliding_factor or 0.0)
        self.current_weight = 1.0
        self.initialized = False
    def initialize(
        self, loss_fn: LossInterface, targets: Dict[str, TensorMap]
    ) -> float:
        # If scheduling is disabled, keep the weight fixed at 1.0
        if self.sliding_factor <= 0.0:
            self.current_weight = 1.0
        else:
            # Compute a baseline loss against a constant-mean or zero-gradient map
            target_name = loss_fn.target
            gradient_name = getattr(loss_fn, "gradient", None)
            tensor_map_for_target = targets[target_name]
            if gradient_name is None:
                # Create a baseline TensorMap with all values = mean over samples
                mean_tensor_map = mts.mean_over_samples(
                    tensor_map_for_target, tensor_map_for_target.sample_names
                )
                baseline_tensor_map = TensorMap(
                    keys=tensor_map_for_target.keys,
                    blocks=[
                        mts.TensorBlock(
                            samples=block.samples,
                            components=block.components,
                            properties=block.properties,
                            values=torch.ones_like(block.values) * mean_block.values,
                        )
                        for block, mean_block in zip(
                            tensor_map_for_target, mean_tensor_map
                        )
                    ],
                )
            else:
                # Zero baseline for gradient-based losses
                baseline_tensor_map = mts.zeros_like(tensor_map_for_target)
            initial_loss_value = loss_fn.compute(
                {target_name: tensor_map_for_target},
                {target_name: baseline_tensor_map},
            )
            self.current_weight = float(initial_loss_value.clamp_min(self.EPSILON))
        self.initialized = True
        return self.current_weight
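    # Note (illustrative observation, not in the original source): for an
    # unmasked MSE term with reduction="mean", this baseline is roughly the
    # variance of the target values around their per-block means, so each
    # loss term starts out normalized by the intrinsic scale of its target.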
    def update(
        self,
        loss_fn: LossInterface,
        predictions: Dict[str, TensorMap],
        targets: Dict[str, TensorMap],
    ) -> float:
        # If scheduling is disabled, return the fixed weight
        if self.sliding_factor <= 0.0:
            return self.current_weight

        # Compute the instantaneous error for this batch
        instantaneous_error = loss_fn.compute(predictions, targets).detach().item()

        # EMA update: w_new = s * w_old + (1 - s) * error
        new_weight = (
            self.sliding_factor * self.current_weight
            + (1.0 - self.sliding_factor) * instantaneous_error
        )
        self.current_weight = max(new_weight, self.EPSILON)
        return self.current_weight
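# Worked example (illustrative): with sliding_factor s = 0.9, a current
# weight of 2.0, and a batch error of 1.0, the update gives
#
#     w_new = 0.9 * 2.0 + (1 - 0.9) * 1.0 = 1.9
#
# so the normalization factor drifts slowly toward the recent error level.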
class ScheduledLoss(LossInterface):
    """
    Wrap a base :py:class:`LossInterface` with a :py:class:`WeightScheduler`.

    After each compute, the scheduler updates the loss weight.
    """

    def __init__(self, base_loss: LossInterface, weight_scheduler: WeightScheduler):
        """
        :param base_loss: underlying LossInterface to wrap.
        :param weight_scheduler: scheduler that controls the multiplier.
        """
        super().__init__(
            base_loss.target,
            base_loss.gradient,
            base_loss.weight,
            base_loss.reduction,
        )
        self.base_loss = base_loss
        self.scheduler = weight_scheduler
        self.loss_kwargs = getattr(base_loss, "loss_kwargs", {})
        # Fallback normalization; replaced on the first compute() call
        self.normalization_factor = 1.0
    def compute(
        self,
        predictions: Dict[str, TensorMap],
        targets: Dict[str, TensorMap],
        extra_data: Optional[Any] = None,
    ) -> torch.Tensor:
        # Initialize the scheduler on the first call
        if not self.scheduler.initialized:
            self.normalization_factor = self.scheduler.initialize(
                self.base_loss, targets
            )

        # Compute the raw loss using the base loss function
        raw_loss_value = self.base_loss.compute(predictions, targets, extra_data)

        # Scale by the fixed weight and divide by the sliding weight
        weighted_loss_value = raw_loss_value * (
            self.base_loss.weight / self.normalization_factor
        )

        # Update the sliding weight for the next batch
        self.normalization_factor = self.scheduler.update(
            self.base_loss, predictions, targets
        )

        return weighted_loss_value
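# Usage sketch (illustrative, not part of the original module): wrapping a
# base loss in a ScheduledLoss with an EMA scheduler. ``pred_tm`` and
# ``targ_tm`` are hypothetical TensorMap objects for an "energy" target.
#
#     base = TensorMapMSELoss(
#         name="energy", gradient=None, weight=1.0, reduction="mean"
#     )
#     scheduled = ScheduledLoss(base, EMAScheduler(sliding_factor=0.9))
#     value = scheduled({"energy": pred_tm}, {"energy": targ_tm})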
# --- specific losses ------------------------------------------------------------------
class BaseTensorMapLoss(LossInterface):
    """
    Backbone for pointwise losses on :py:class:`TensorMap` entries.

    Provides a ``compute_flattened()`` helper that extracts values or
    gradients, flattens them, applies an optional mask, and computes the
    torch loss.
    """

    def __init__(
        self,
        name: str,
        gradient: Optional[str],
        weight: float,
        reduction: str,
        *,
        loss_fn: _Loss,
    ):
        """
        :param name: key in the predictions/targets dict.
        :param gradient: optional gradient field name.
        :param weight: stored but not applied here; weighting is handled by
            :py:class:`ScheduledLoss`.
        :param reduction: reduction mode for the torch loss.
        :param loss_fn: pre-instantiated torch.nn loss (e.g. ``MSELoss``).
        """
        super().__init__(name, gradient, weight, reduction)
        self.torch_loss = loss_fn
    def compute_flattened(
        self,
        tensor_map_predictions_for_target: TensorMap,
        tensor_map_targets_for_target: TensorMap,
        tensor_map_mask_for_target: Optional[TensorMap] = None,
    ) -> torch.Tensor:
        """
        Flatten prediction and target blocks (and optional mask), then apply
        the torch loss.

        :param tensor_map_predictions_for_target: predicted :py:class:`TensorMap`.
        :param tensor_map_targets_for_target: target :py:class:`TensorMap`.
        :param tensor_map_mask_for_target: optional mask :py:class:`TensorMap`.
        :return: scalar torch.Tensor of the computed loss.
        """
        list_of_prediction_segments = []
        list_of_target_segments = []

        def extract_flattened_values_from_block(
            tensor_block: mts.TensorBlock,
        ) -> torch.Tensor:
            """Extract values or gradients from a block and flatten to 1D."""
            if self.gradient is not None:
                values = tensor_block.gradient(self.gradient).values
            else:
                values = tensor_block.values
            return values.reshape(-1)

        # Loop over each key in the TensorMap
        for single_key in tensor_map_predictions_for_target.keys:
            block_for_prediction = tensor_map_predictions_for_target.block(single_key)
            block_for_target = tensor_map_targets_for_target.block(single_key)

            flattened_prediction = extract_flattened_values_from_block(
                block_for_prediction
            )
            flattened_target = extract_flattened_values_from_block(block_for_target)

            if tensor_map_mask_for_target is not None:
                # Apply the boolean mask if provided
                block_for_mask = tensor_map_mask_for_target.block(single_key)
                flattened_mask = extract_flattened_values_from_block(
                    block_for_mask
                ).bool()
                flattened_prediction = flattened_prediction[flattened_mask]
                flattened_target = flattened_target[flattened_mask]

            list_of_prediction_segments.append(flattened_prediction)
            list_of_target_segments.append(flattened_target)

        # Concatenate all segments and apply the torch loss
        all_predictions_flattened = torch.cat(list_of_prediction_segments)
        all_targets_flattened = torch.cat(list_of_target_segments)
        return self.torch_loss(all_predictions_flattened, all_targets_flattened)
    def compute(
        self,
        predictions: Dict[str, TensorMap],
        targets: Dict[str, TensorMap],
        extra_data: Optional[Any] = None,
    ) -> torch.Tensor:
        """
        Compute the unmasked pointwise loss.

        :param predictions: mapping of names to :py:class:`TensorMap`.
        :param targets: mapping of names to :py:class:`TensorMap`.
        :param extra_data: ignored for unmasked losses.
        :return: scalar torch.Tensor loss.
        """
        tensor_map_pred = predictions[self.target]
        tensor_map_targ = targets[self.target]

        # Check that the requested gradient is present in the target TensorMap
        if self.gradient is not None:
            if self.gradient not in tensor_map_targ[0].gradients_list():
                # Skip the loss computation if the gradient is missing from the
                # dataset; the returned zero does not track gradients
                return torch.zeros(
                    (),
                    dtype=tensor_map_targ[0].values.dtype,
                    device=tensor_map_targ[0].values.device,
                )

        return self.compute_flattened(tensor_map_pred, tensor_map_targ)
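# Usage sketch (illustrative): a pointwise term evaluated on a gradient
# field rather than on the values. Energy targets in metatrain typically
# store forces as the "positions" gradient; that naming is an assumption
# here, and ``pred_tm``/``targ_tm`` are hypothetical TensorMaps.
#
#     force_term = TensorMapMSELoss(
#         name="energy", gradient="positions", weight=1.0, reduction="mean"
#     )
#     # returns a detached zero if the targets carry no "positions" gradient
#     value = force_term({"energy": pred_tm}, {"energy": targ_tm})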
class MaskedTensorMapLoss(BaseTensorMapLoss):
    """
    Pointwise masked loss on :py:class:`TensorMap` entries.

    Inherits the flattening and torch-loss logic from
    :py:class:`BaseTensorMapLoss`.
    """
    def compute(
        self,
        predictions: Dict[str, TensorMap],
        targets: Dict[str, TensorMap],
        extra_data: Optional[Dict[str, TensorMap]] = None,
    ) -> torch.Tensor:
        """
        Gather and flatten target and prediction blocks, then compute the loss.

        :param predictions: mapping from target names to TensorMaps.
        :param targets: mapping from target names to TensorMaps.
        :param extra_data: additional data for the loss computation. Assumes
            that, for the target ``name`` used in the constructor, there is a
            corresponding data field ``name + "_mask"`` that contains the
            tensor to be used for masking. It should have the same metadata
            as the target and prediction tensors.
        :return: scalar loss tensor.
        """
        mask_key = f"{self.target}_mask"
        if extra_data is None or mask_key not in extra_data:
            raise ValueError(
                f"Expected extra_data to contain a TensorMap under '{mask_key}'"
            )
        tensor_map_pred = predictions[self.target]
        tensor_map_targ = targets[self.target]
        tensor_map_mask = extra_data[mask_key]
        return self.compute_flattened(tensor_map_pred, tensor_map_targ, tensor_map_mask)
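# Usage sketch (illustrative): masked losses read the mask from
# ``extra_data`` under the key "<target>_mask". ``mask_tm`` is a
# hypothetical TensorMap of 0/1 values with the same metadata as the target.
#
#     loss = TensorMapMaskedMSELoss(
#         name="spectrum", gradient=None, weight=1.0, reduction="mean"
#     )
#     value = loss(
#         {"spectrum": pred_tm},
#         {"spectrum": targ_tm},
#         extra_data={"spectrum_mask": mask_tm},
#     )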
# --- simple explicit subclasses for common pointwise losses ---------------------------
class TensorMapMSELoss(BaseTensorMapLoss):
    """
    Unmasked mean-squared error on :py:class:`TensorMap` entries.
    """

    def __init__(
        self,
        name: str,
        gradient: Optional[str],
        weight: float,
        reduction: str,
    ):
        super().__init__(
            name,
            gradient,
            weight,
            reduction,
            loss_fn=torch.nn.MSELoss(reduction=reduction),
        )
class TensorMapMAELoss(BaseTensorMapLoss):
    """
    Unmasked mean-absolute error on :py:class:`TensorMap` entries.
    """

    def __init__(
        self,
        name: str,
        gradient: Optional[str],
        weight: float,
        reduction: str,
    ):
        super().__init__(
            name,
            gradient,
            weight,
            reduction,
            loss_fn=torch.nn.L1Loss(reduction=reduction),
        )
class TensorMapHuberLoss(BaseTensorMapLoss):
    """
    Unmasked Huber loss on :py:class:`TensorMap` entries.

    :param delta: threshold parameter for HuberLoss.
    """

    def __init__(
        self,
        name: str,
        gradient: Optional[str],
        weight: float,
        reduction: str,
        delta: float,
    ):
        super().__init__(
            name,
            gradient,
            weight,
            reduction,
            loss_fn=torch.nn.HuberLoss(reduction=reduction, delta=delta),
        )
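# Usage sketch (illustrative): unlike the MSE/MAE variants, the Huber losses
# take an extra ``delta`` hyperparameter; the ``create_loss`` factory at the
# bottom of this module forwards it through ``extra_kwargs``.
#
#     huber = TensorMapHuberLoss(
#         name="energy", gradient=None, weight=1.0, reduction="mean", delta=0.1
#     )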
class TensorMapMaskedMSELoss(MaskedTensorMapLoss):
    """
    Masked mean-squared error on :py:class:`TensorMap` entries.
    """

    def __init__(
        self,
        name: str,
        gradient: Optional[str],
        weight: float,
        reduction: str,
    ):
        super().__init__(
            name,
            gradient,
            weight,
            reduction,
            loss_fn=torch.nn.MSELoss(reduction=reduction),
        )
class TensorMapMaskedMAELoss(MaskedTensorMapLoss):
    """
    Masked mean-absolute error on :py:class:`TensorMap` entries.
    """

    def __init__(
        self,
        name: str,
        gradient: Optional[str],
        weight: float,
        reduction: str,
    ):
        super().__init__(
            name,
            gradient,
            weight,
            reduction,
            loss_fn=torch.nn.L1Loss(reduction=reduction),
        )
class TensorMapMaskedHuberLoss(MaskedTensorMapLoss):
    """
    Masked Huber loss on :py:class:`TensorMap` entries.

    :param delta: threshold parameter for HuberLoss.
    """

    def __init__(
        self,
        name: str,
        gradient: Optional[str],
        weight: float,
        reduction: str,
        delta: float,
    ):
        super().__init__(
            name,
            gradient,
            weight,
            reduction,
            loss_fn=torch.nn.HuberLoss(reduction=reduction, delta=delta),
        )
# --- aggregator -----------------------------------------------------------------------
class LossAggregator(LossInterface):
    """
    Aggregate multiple :py:class:`LossInterface` terms with scheduled weights
    and metadata.
    """

    def __init__(
        self, targets: Dict[str, TargetInfo], config: Dict[str, Dict[str, Any]]
    ):
        """
        :param targets: mapping from target names to :py:class:`TargetInfo`.
        :param config: per-target configuration dict.
        """
        super().__init__(name="", gradient=None, weight=0.0, reduction="mean")
        self.scheduled_losses: Dict[str, ScheduledLoss] = {}
        self.metadata: Dict[str, Dict[str, Any]] = {}

        for target_name, target_info in targets.items():
            target_config = config.get(
                target_name,
                {
                    "type": "mse",
                    "weight": 1.0,
                    "reduction": "mean",
                    "sliding_factor": None,
                    "gradients": {},
                },
            )

            # Create the main loss and its scheduler
            base_loss = create_loss(
                target_config["type"],
                name=target_name,
                gradient=None,
                weight=target_config["weight"],
                reduction=target_config["reduction"],
                **{
                    pname: pval
                    for pname, pval in target_config.items()
                    if pname
                    not in (
                        "type",
                        "weight",
                        "reduction",
                        "sliding_factor",
                        "gradients",
                    )
                },
            )
            ema_scheduler = EMAScheduler(target_config["sliding_factor"])
            scheduled_main_loss = ScheduledLoss(base_loss, ema_scheduler)
            self.scheduled_losses[target_name] = scheduled_main_loss

            self.metadata[target_name] = {
                "type": target_config["type"],
                "weight": base_loss.weight,
                "reduction": base_loss.reduction,
                "sliding_factor": target_config["sliding_factor"],
                "gradients": {},
            }
            for pname, pval in target_config.items():
                if pname not in (
                    "type",
                    "weight",
                    "reduction",
                    "sliding_factor",
                    "gradients",
                ):
                    self.metadata[target_name][pname] = pval

            # Create gradient-based losses
            gradient_config = target_config["gradients"]
            for gradient_name in target_info.layout[0].gradients_list():
                gradient_key = f"{target_name}_grad_{gradient_name}"
                gradient_specific_config = gradient_config.get(
                    gradient_name,
                    {
                        "type": "mse",
                        "weight": 1.0,
                        "reduction": "mean",
                        "sliding_factor": None,
                    },
                )
                grad_loss = create_loss(
                    gradient_specific_config["type"],
                    name=target_name,
                    gradient=gradient_name,
                    weight=gradient_specific_config["weight"],
                    reduction=gradient_specific_config["reduction"],
                    **{
                        pname: pval
                        for pname, pval in gradient_specific_config.items()
                        if pname
                        not in (
                            "type",
                            "weight",
                            "reduction",
                            "sliding_factor",
                            "gradients",
                        )
                    },
                )
                ema_scheduler_for_grad = EMAScheduler(target_config["sliding_factor"])
                scheduled_grad_loss = ScheduledLoss(grad_loss, ema_scheduler_for_grad)
                self.scheduled_losses[gradient_key] = scheduled_grad_loss

                self.metadata[target_name]["gradients"][gradient_name] = {
                    "type": gradient_specific_config["type"],
                    "weight": grad_loss.weight,
                    "reduction": grad_loss.reduction,
                    "sliding_factor": target_config["sliding_factor"],
                }
                for pname, pval in gradient_specific_config.items():
                    if pname not in (
                        "type",
                        "weight",
                        "reduction",
                        "sliding_factor",
                        "gradients",
                    ):
                        self.metadata[target_name]["gradients"][gradient_name][
                            pname
                        ] = pval
    def compute(
        self,
        predictions: Dict[str, TensorMap],
        targets: Dict[str, TensorMap],
        extra_data: Optional[Any] = None,
    ) -> torch.Tensor:
        """
        Sum over all scheduled losses present in the predictions.
        """
        # Initialize a zero scalar matching the dtype and device of the first block
        first_tensor_map = next(iter(predictions.values()))
        first_block = first_tensor_map.block(first_tensor_map.keys[0])
        total_loss = torch.zeros(
            (), dtype=first_block.values.dtype, device=first_block.values.device
        )

        # Add each scheduled term that has a matching prediction
        for scheduled_term in self.scheduled_losses.values():
            if scheduled_term.target not in predictions:
                continue
            total_loss = total_loss + scheduled_term.compute(
                predictions, targets, extra_data
            )
        return total_loss
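# Usage sketch (illustrative, not part of the original module): building an
# aggregator from a per-target config. ``target_info`` is a hypothetical
# TargetInfo whose layout declares a "positions" gradient, and the config
# keys mirror the defaults used in ``LossAggregator.__init__``.
#
#     config = {
#         "energy": {
#             "type": "mse",
#             "weight": 1.0,
#             "reduction": "mean",
#             "sliding_factor": 0.9,
#             "gradients": {
#                 "positions": {"type": "mse", "weight": 10.0, "reduction": "mean"}
#             },
#         }
#     }
#     aggregator = LossAggregator({"energy": target_info}, config)
#     total = aggregator(predictions, targets)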
class LossType(Enum):
    """
    Enumeration of the available loss types and their implementing classes.
    """

    MSE = ("mse", TensorMapMSELoss)
    MAE = ("mae", TensorMapMAELoss)
    HUBER = ("huber", TensorMapHuberLoss)
    MASKED_MSE = ("masked_mse", TensorMapMaskedMSELoss)
    MASKED_MAE = ("masked_mae", TensorMapMaskedMAELoss)
    MASKED_HUBER = ("masked_huber", TensorMapMaskedHuberLoss)
    POINTWISE = ("pointwise", BaseTensorMapLoss)
    MASKED_POINTWISE = ("masked_pointwise", MaskedTensorMapLoss)

    def __init__(self, key: str, cls: Type[LossInterface]):
        self._key = key
        self._cls = cls

    @property
    def key(self) -> str:
        """String key for this loss type."""
        return self._key

    @property
    def cls(self) -> Type[LossInterface]:
        """Class implementing this loss type."""
        return self._cls
    @classmethod
    def from_key(cls, key: str) -> "LossType":
        """
        Look up a :py:class:`LossType` by its string key.

        :raises ValueError: if the key is not valid.
        """
        for loss_type in cls:
            if loss_type.key == key:
                return loss_type
        valid_keys = ", ".join(loss_type.key for loss_type in cls)
        raise ValueError(f"Unknown loss '{key}'. Valid types: {valid_keys}")
def create_loss(
    loss_type: str,
    *,
    name: str,
    gradient: Optional[str],
    weight: float,
    reduction: str,
    **extra_kwargs: Any,
) -> LossInterface:
    """
    Factory to instantiate a concrete :py:class:`LossInterface` given its
    string key.

    :param loss_type: string key matching one of the members of
        :py:class:`LossType`.
    :param name: target name for the loss.
    :param gradient: gradient name, if present.
    :param weight: weight for the loss contribution.
    :param reduction: reduction mode for the torch loss.
    :param extra_kwargs: additional hyperparameters specific to the loss type.
    :return: instance of the selected loss.
    """
    loss_type_entry = LossType.from_key(loss_type)
    try:
        return loss_type_entry.cls(
            name=name,
            gradient=gradient,
            weight=weight,
            reduction=reduction,
            **extra_kwargs,
        )
    except TypeError as e:
        raise TypeError(f"Error constructing loss '{loss_type}': {e}") from e
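# Usage sketch (illustrative): creating a loss through the factory. The
# string key selects the LossType member, and ``delta`` is forwarded to the
# underlying torch.nn.HuberLoss via ``extra_kwargs``.
#
#     loss = create_loss(
#         "huber",
#         name="energy",
#         gradient=None,
#         weight=1.0,
#         reduction="mean",
#         delta=0.01,
#     )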