Loss

class metatrain.utils.loss.LossInterface(name: str, gradient: str | None, weight: float, reduction: str)

Bases: ABC

Abstract base for all loss functions.

Subclasses must implement the compute method.

Parameters:
  • name (str) – key in the predictions/targets dict to select the TensorMap.

  • gradient (str | None) – optional name of a gradient field to extract.

  • weight (float) – multiplicative weight (used by ScheduledLoss).

  • reduction (str) – reduction mode for torch losses (“mean”, “sum”, etc.).

target: str
gradient: str | None
weight: float
reduction: str
loss_kwargs: Dict[str, Any]
abstractmethod compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Any | None = None) → Tensor

Compute the loss.

Parameters:
  • predictions (Dict[str, TensorMap]) – mapping from target names to TensorMap.

  • targets (Dict[str, TensorMap]) – mapping from target names to TensorMap.

  • extra_data (Any | None) – optional additional data (e.g., masks).

Returns:

scalar torch.Tensor representing the loss.

Return type:

Tensor

classmethod from_config(cfg: Dict[str, Any]) → LossInterface

Instantiate a loss from a config dict.

Parameters:

cfg (Dict[str, Any]) – keyword args matching the loss constructor.

Returns:

instance of a LossInterface subclass.

Return type:

LossInterface
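
The constructor, compute and from_config pattern described above can be sketched without metatrain installed. ToyLossInterface and ToyAbsLoss are hypothetical stand-ins, plain floats replace TensorMap objects, and the assumption that the config keys are exactly the constructor keyword arguments comes from the parameter description; the real from_config may do more than unpack the dict.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional


class ToyLossInterface(ABC):
    """Hypothetical stand-in mirroring the documented constructor arguments."""

    def __init__(self, name: str, gradient: Optional[str], weight: float, reduction: str):
        self.target = name
        self.gradient = gradient
        self.weight = weight
        self.reduction = reduction

    @abstractmethod
    def compute(
        self,
        predictions: Dict[str, float],
        targets: Dict[str, float],
        extra_data: Optional[Any] = None,
    ) -> float:
        """Return a scalar loss."""

    @classmethod
    def from_config(cls, cfg: Dict[str, Any]) -> "ToyLossInterface":
        # Assumption: the config keys are exactly the constructor keyword arguments.
        return cls(**cfg)


class ToyAbsLoss(ToyLossInterface):
    def compute(self, predictions, targets, extra_data=None):
        # Select the entry for this loss's target name, then compare.
        return abs(predictions[self.target] - targets[self.target])


loss = ToyAbsLoss.from_config(
    {"name": "energy", "gradient": None, "weight": 1.0, "reduction": "mean"}
)
value = loss.compute({"energy": 1.5}, {"energy": 1.0})
```

Calling from_config on the subclass returns an instance of that subclass, which is why the method is a classmethod on the abstract base.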

class metatrain.utils.loss.WeightScheduler

Bases: ABC

Abstract interface for scheduling a weight for a LossInterface.

initialized: bool = False
abstractmethod initialize(loss_fn: LossInterface, targets: Dict[str, TensorMap]) → float

Compute and return the initial weight.

Parameters:
  • loss_fn (LossInterface)

  • targets (Dict[str, TensorMap])

Returns:

initial weight as a float.

Return type:

float

abstractmethod update(loss_fn: LossInterface, predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap]) → float

Update and return the new weight after a batch.

Parameters:
  • loss_fn (LossInterface)

  • predictions (Dict[str, TensorMap])

  • targets (Dict[str, TensorMap])

Returns:

updated weight as a float.

Return type:

float

class metatrain.utils.loss.EMAScheduler(sliding_factor: float | None)

Bases: WeightScheduler

Exponential moving average scheduler for loss weights.

Parameters:

sliding_factor (float | None) – factor in [0,1] for EMA (0 disables scheduling).

EPSILON = 1e-06
initialize(loss_fn: LossInterface, targets: Dict[str, TensorMap]) → float

Compute and return the initial weight.

Parameters:
  • loss_fn (LossInterface)

  • targets (Dict[str, TensorMap])

Returns:

initial weight as a float.

Return type:

float

update(loss_fn: LossInterface, predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap]) → float

Update and return the new weight after a batch.

Parameters:
  • loss_fn (LossInterface)

  • predictions (Dict[str, TensorMap])

  • targets (Dict[str, TensorMap])

Returns:

updated weight as a float.

Return type:

float
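
This page does not state the update rule, so the following is only a sketch of a textbook exponential moving average, not metatrain's EMAScheduler. ToyEMAWeight is a hypothetical class; which term sliding_factor multiplies, the starting weight of 1.0, and treating None like 0 are all assumptions. Only the EPSILON value and the "0 disables scheduling" behaviour are taken from the text above.

```python
from typing import Optional


class ToyEMAWeight:
    """Hypothetical EMA tracker; not metatrain's EMAScheduler."""

    EPSILON = 1e-6  # floor that keeps the weight strictly positive

    def __init__(self, sliding_factor: Optional[float]) -> None:
        # Assumption: None is treated like 0, i.e. scheduling disabled.
        self.sliding_factor = float(sliding_factor or 0.0)
        self.current_weight = 1.0

    def update(self, batch_loss: float) -> float:
        if self.sliding_factor == 0.0:
            return self.current_weight  # disabled: the weight never changes
        a = self.sliding_factor
        # Blend the previous weight with the newest batch loss.
        blended = a * self.current_weight + (1.0 - a) * batch_loss
        self.current_weight = max(blended, self.EPSILON)
        return self.current_weight


ema = ToyEMAWeight(sliding_factor=0.5)
first = ema.update(3.0)   # 0.5 * 1.0 + 0.5 * 3.0
second = ema.update(1.0)  # 0.5 * 2.0 + 0.5 * 1.0
frozen = ToyEMAWeight(sliding_factor=0.0).update(3.0)
```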

class metatrain.utils.loss.ScheduledLoss(base_loss: LossInterface, weight_scheduler: WeightScheduler)

Bases: LossInterface

Wrap a base LossInterface with a WeightScheduler. After each compute, the scheduler updates the loss weight.

Parameters:
  • base_loss (LossInterface) – underlying LossInterface to wrap.

  • weight_scheduler (WeightScheduler) – scheduler that controls the multiplier.

compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Any | None = None) → Tensor

Compute the loss.

Parameters:
  • predictions (Dict[str, TensorMap]) – mapping from target names to TensorMap.

  • targets (Dict[str, TensorMap]) – mapping from target names to TensorMap.

  • extra_data (Any | None) – optional additional data (e.g., masks).

Returns:

scalar torch.Tensor representing the loss.

Return type:

Tensor
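
The wrap-then-update behaviour of ScheduledLoss can be illustrated with a small sketch. ToyScheduledLoss, squared_error and halve are hypothetical names, floats replace TensorMap objects, and applying the weight as a plain multiplier is an assumption; the source only says the scheduler "controls the multiplier" and runs after each compute.

```python
from typing import Callable


class ToyScheduledLoss:
    """Hypothetical wrapper: scale a base loss, then let a scheduler pick the next weight."""

    def __init__(
        self,
        base_loss: Callable[[float, float], float],
        update_weight: Callable[[float, float], float],
    ) -> None:
        self.base_loss = base_loss
        self.update_weight = update_weight
        self.weight = 1.0

    def compute(self, prediction: float, target: float) -> float:
        raw = self.base_loss(prediction, target)
        scaled = self.weight * raw  # assumption: the weight acts as a plain multiplier
        # After the loss is computed, the scheduler sets the weight for the next call.
        self.weight = self.update_weight(self.weight, raw)
        return scaled


def squared_error(prediction: float, target: float) -> float:
    return (prediction - target) ** 2


def halve(weight: float, raw_loss: float) -> float:
    # Toy schedule: ignore the loss value and halve the weight each step.
    return weight * 0.5


wrapped = ToyScheduledLoss(squared_error, halve)
values = [wrapped.compute(3.0, 1.0) for _ in range(3)]
```

The first call uses the initial weight; each later call sees the weight produced by the previous update.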

class metatrain.utils.loss.BaseTensorMapLoss(name: str, gradient: str | None, weight: float, reduction: str, *, loss_fn: _Loss)

Bases: LossInterface

Backbone for pointwise losses on TensorMap entries.

Provides a compute_flattened() helper that extracts values or gradients, flattens them, applies an optional mask, and computes the torch loss.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – placeholder in this class; the effective weighting is applied by ScheduledLoss.

  • reduction (str) – reduction mode for torch loss.

  • loss_fn (_Loss) – pre-instantiated torch.nn loss (e.g. MSELoss).

compute_flattened(tensor_map_predictions_for_target: TensorMap, tensor_map_targets_for_target: TensorMap, tensor_map_mask_for_target: TensorMap | None = None) → Tensor

Flatten prediction and target blocks (and optional mask), then apply the torch loss.

Parameters:
  • tensor_map_predictions_for_target (TensorMap) – predicted TensorMap.

  • tensor_map_targets_for_target (TensorMap) – target TensorMap.

  • tensor_map_mask_for_target (TensorMap | None) – optional mask TensorMap.

Returns:

scalar torch.Tensor of the computed loss.

Return type:

Tensor
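
The flatten, mask, then reduce sequence can be sketched in plain Python. Nested lists stand in for the blocks of a TensorMap and a hand-written mean-squared error stands in for the torch loss; flatten and masked_mse are hypothetical helpers, not part of metatrain.

```python
from typing import List, Optional, Sequence

Blocks = Sequence[Sequence[float]]


def flatten(blocks: Blocks) -> List[float]:
    """Concatenate every block into one flat list of values."""
    return [value for block in blocks for value in block]


def masked_mse(pred_blocks: Blocks, target_blocks: Blocks,
               mask_blocks: Optional[Blocks] = None) -> float:
    preds = flatten(pred_blocks)
    targets = flatten(target_blocks)
    if mask_blocks is not None:
        # Keep only the entries where the flattened mask is truthy.
        keep = [bool(m) for m in flatten(mask_blocks)]
        preds = [p for p, k in zip(preds, keep) if k]
        targets = [t for t, k in zip(targets, keep) if k]
    # "mean" reduction over the surviving entries (assumes at least one survives).
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


pred_blocks = [[1.0, 2.0], [3.0]]
target_blocks = [[0.0, 2.0], [5.0]]
unmasked = masked_mse(pred_blocks, target_blocks)
masked = masked_mse(pred_blocks, target_blocks, mask_blocks=[[1, 1], [0]])
```

With the mask, the third entry is dropped before the reduction, so the mean is taken over two values instead of three.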

compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Any | None = None) → Tensor

Compute the unmasked pointwise loss.

Parameters:
  • predictions (Dict[str, TensorMap]) – mapping of names to TensorMap.

  • targets (Dict[str, TensorMap]) – mapping of names to TensorMap.

  • extra_data (Any | None) – ignored for unmasked losses.

Returns:

scalar torch.Tensor loss.

Return type:

Tensor

class metatrain.utils.loss.MaskedTensorMapLoss(name: str, gradient: str | None, weight: float, reduction: str, *, loss_fn: _Loss)

Bases: BaseTensorMapLoss

Pointwise masked loss on TensorMap entries.

Inherits flattening and torch-loss logic from BaseTensorMapLoss.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – placeholder in this class; the effective weighting is applied by ScheduledLoss.

  • reduction (str) – reduction mode for torch loss.

  • loss_fn (_Loss) – pre-instantiated torch.nn loss (e.g. MSELoss).

compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Dict[str, TensorMap] | None = None) → Tensor

Gather and flatten target and prediction blocks, then compute loss.

Parameters:
  • predictions (Dict[str, TensorMap]) – Mapping from target names to TensorMaps.

  • targets (Dict[str, TensorMap]) – Mapping from target names to TensorMaps.

  • extra_data (Dict[str, TensorMap] | None) – Additional data for loss computation. Assumes that, for the target name used in the constructor, there is a corresponding data field name + "_mask" that contains the tensor to be used for masking. It should have the same metadata as the target and prediction tensors.

Returns:

Scalar loss tensor.

Return type:

Tensor
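
The mask naming convention described for extra_data (the target name followed by "_mask") can be sketched as a lookup. select_mask is a hypothetical helper, a list stands in for the mask TensorMap, and raising KeyError when the mask is missing is an assumption of this sketch rather than documented behaviour.

```python
from typing import Any, Dict, Optional


def select_mask(name: str, extra_data: Optional[Dict[str, Any]]) -> Any:
    """Return the mask stored under name + "_mask"."""
    mask_key = name + "_mask"
    if extra_data is None or mask_key not in extra_data:
        # Assumption: a missing mask is treated as an error in this sketch.
        raise KeyError(f"expected extra_data[{mask_key!r}] for target {name!r}")
    return extra_data[mask_key]


extra_data = {"density_mask": [True, False, True]}
mask = select_mask("density", extra_data)
```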

class metatrain.utils.loss.TensorMapMSELoss(name: str, gradient: str | None, weight: float, reduction: str)

Bases: BaseTensorMapLoss

Unmasked mean-squared error on TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – placeholder in this class; the effective weighting is applied by ScheduledLoss.

  • reduction (str) – reduction mode for torch loss.

class metatrain.utils.loss.TensorMapMAELoss(name: str, gradient: str | None, weight: float, reduction: str)

Bases: BaseTensorMapLoss

Unmasked mean-absolute error on TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – placeholder in this class; the effective weighting is applied by ScheduledLoss.

  • reduction (str) – reduction mode for torch loss.

class metatrain.utils.loss.TensorMapHuberLoss(name: str, gradient: str | None, weight: float, reduction: str, delta: float)

Bases: BaseTensorMapLoss

Unmasked Huber loss on TensorMap entries.

Parameters:
  • delta (float) – threshold parameter for HuberLoss.

  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – placeholder in this class; the effective weighting is applied by ScheduledLoss.

  • reduction (str) – reduction mode for torch loss.
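
To make the role of delta concrete, here is a pointwise sketch of the Huber formula in the convention used by torch.nn.HuberLoss: quadratic for errors up to delta, linear beyond it. The function huber is a hypothetical scalar helper, not the class documented here, and it omits the reduction step.

```python
def huber(prediction: float, target: float, delta: float) -> float:
    """Pointwise Huber loss in the torch.nn.HuberLoss convention."""
    error = abs(prediction - target)
    if error <= delta:
        return 0.5 * error ** 2           # quadratic near zero
    return delta * (error - 0.5 * delta)  # linear for large errors


small = huber(1.5, 1.0, delta=1.0)  # |error| = 0.5, quadratic branch
large = huber(4.0, 1.0, delta=1.0)  # |error| = 3.0, linear branch
```

The two branches agree at |error| = delta, so the loss is continuous across the threshold.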

class metatrain.utils.loss.TensorMapMaskedMSELoss(name: str, gradient: str | None, weight: float, reduction: str)

Bases: MaskedTensorMapLoss

Masked mean-squared error on TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – placeholder in this class; the effective weighting is applied by ScheduledLoss.

  • reduction (str) – reduction mode for torch loss.

class metatrain.utils.loss.TensorMapMaskedMAELoss(name: str, gradient: str | None, weight: float, reduction: str)

Bases: MaskedTensorMapLoss

Masked mean-absolute error on TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – placeholder in this class; the effective weighting is applied by ScheduledLoss.

  • reduction (str) – reduction mode for torch loss.

class metatrain.utils.loss.TensorMapMaskedHuberLoss(name: str, gradient: str | None, weight: float, reduction: str, delta: float)

Bases: MaskedTensorMapLoss

Masked Huber loss on TensorMap entries.

Parameters:
  • delta (float) – threshold parameter for HuberLoss.

  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – placeholder in this class; the effective weighting is applied by ScheduledLoss.

  • reduction (str) – reduction mode for torch loss.

class metatrain.utils.loss.LossAggregator(targets: Dict[str, TargetInfo], config: Dict[str, Dict[str, Any]])

Bases: LossInterface

Aggregate multiple LossInterface terms with scheduled weights and metadata.

Parameters:
  • targets (Dict[str, TargetInfo]) – mapping from target names to TargetInfo.

  • config (Dict[str, Dict[str, Any]]) – per-target configuration dict.

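compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Any | None = None) → Tensor

Sum over all scheduled losses present in the predictions.

Parameters:
  • predictions (Dict[str, TensorMap]) – mapping from target names to TensorMap.

  • targets (Dict[str, TensorMap]) – mapping from target names to TensorMap.

  • extra_data (Any | None) – optional additional data (e.g., masks).

Return type: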

Tensor
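
The "sum over all scheduled losses present in the predictions" rule can be sketched as follows. The aggregate function and the lambda losses are hypothetical, floats replace TensorMap objects, and silently skipping targets that have no prediction is this sketch's reading of "present in the predictions".

```python
from typing import Callable, Dict

LossFn = Callable[[float, float], float]


def aggregate(losses: Dict[str, LossFn], predictions: Dict[str, float],
              targets: Dict[str, float]) -> float:
    """Sum the loss terms whose target name appears in the predictions."""
    total = 0.0
    for name, loss_fn in losses.items():
        if name not in predictions:
            continue  # term skipped: nothing was predicted for this target
        total += loss_fn(predictions[name], targets[name])
    return total


losses = {
    "energy": lambda p, t: (p - t) ** 2,
    "dipole": lambda p, t: abs(p - t),
}
# Only "energy" is predicted, so the "dipole" term does not contribute.
total = aggregate(losses, {"energy": 3.0}, {"energy": 1.0, "dipole": 0.5})
```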

class metatrain.utils.loss.LossType(*values)

Bases: Enum

Enumeration of available loss types and their implementing classes.

MSE = ('mse', <class 'metatrain.utils.loss.TensorMapMSELoss'>)
MAE = ('mae', <class 'metatrain.utils.loss.TensorMapMAELoss'>)
HUBER = ('huber', <class 'metatrain.utils.loss.TensorMapHuberLoss'>)
MASKED_MSE = ('masked_mse', <class 'metatrain.utils.loss.TensorMapMaskedMSELoss'>)
MASKED_MAE = ('masked_mae', <class 'metatrain.utils.loss.TensorMapMaskedMAELoss'>)
MASKED_HUBER = ('masked_huber', <class 'metatrain.utils.loss.TensorMapMaskedHuberLoss'>)
POINTWISE = ('pointwise', <class 'metatrain.utils.loss.BaseTensorMapLoss'>)
MASKED_POINTWISE = ('masked_pointwise', <class 'metatrain.utils.loss.MaskedTensorMapLoss'>)
property key: str

String key for this loss type.

property cls: Type[LossInterface]

Class implementing this loss type.

classmethod from_key(key: str) → LossType

Look up a LossType by its string key.

Raises:

ValueError – if the key is not valid.

Parameters:

key (str)

Return type:

LossType
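
The enum layout shown above, with (key, class) tuples as values plus key, cls and from_key accessors, can be reproduced with the standard library. ToyLossType, ToyMSE and ToyMAE are hypothetical stand-ins, and the error message wording is an assumption; only the ValueError on an unknown key comes from the text above.

```python
from enum import Enum
from typing import Type


class ToyMSE:
    """Hypothetical loss class."""


class ToyMAE:
    """Hypothetical loss class."""


class ToyLossType(Enum):
    # Each member's value is a (string key, implementing class) tuple.
    MSE = ("mse", ToyMSE)
    MAE = ("mae", ToyMAE)

    @property
    def key(self) -> str:
        return self.value[0]

    @property
    def cls(self) -> Type:
        return self.value[1]

    @classmethod
    def from_key(cls, key: str) -> "ToyLossType":
        for member in cls:
            if member.key == key:
                return member
        valid = ", ".join(m.key for m in cls)
        raise ValueError(f"Unknown loss {key!r}; valid keys: {valid}")


selected = ToyLossType.from_key("mae")
```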

metatrain.utils.loss.create_loss(loss_type: str, *, name: str, gradient: str | None, weight: float, reduction: str, **extra_kwargs: Any) → LossInterface

Factory to instantiate a concrete LossInterface given its string key.

Parameters:
  • loss_type (str) – string key matching one of the members of LossType.

  • name (str) – target name for the loss.

  • gradient (str | None) – gradient name, if present.

  • weight (float) – weight for the loss contribution.

  • reduction (str) – reduction mode for the torch loss.

  • extra_kwargs (Any) – additional hyperparameters specific to the loss type.

Returns:

instance of the selected loss.

Return type:

LossInterface
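
A sketch of how such a factory can forward loss-specific hyperparameters through extra_kwargs, assuming a simple key-to-class registry. ToyLoss, ToyHuber, REGISTRY and create_toy_loss are hypothetical; the real create_loss resolves the key through LossType and may validate arguments differently.

```python
from typing import Any, Dict, Optional, Type


class ToyLoss:
    def __init__(self, name, gradient, weight, reduction):
        self.name, self.gradient = name, gradient
        self.weight, self.reduction = weight, reduction


class ToyHuber(ToyLoss):
    def __init__(self, name, gradient, weight, reduction, delta):
        super().__init__(name, gradient, weight, reduction)
        self.delta = delta  # hyperparameter specific to this loss type


REGISTRY: Dict[str, Type[ToyLoss]] = {"mse": ToyLoss, "huber": ToyHuber}


def create_toy_loss(loss_type: str, *, name: str, gradient: Optional[str],
                    weight: float, reduction: str, **extra_kwargs: Any) -> ToyLoss:
    try:
        loss_cls = REGISTRY[loss_type]
    except KeyError:
        raise ValueError(f"Unknown loss type {loss_type!r}") from None
    # Common arguments are passed explicitly; anything else is forwarded as-is.
    return loss_cls(name=name, gradient=gradient, weight=weight,
                    reduction=reduction, **extra_kwargs)


huber_loss = create_toy_loss("huber", name="energy", gradient=None,
                             weight=1.0, reduction="mean", delta=0.1)
```

In this sketch, passing delta to a loss type that does not accept it fails with a TypeError from the constructor, which is one reason to keep extra_kwargs limited to what the chosen loss expects.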