Loss
- class metatrain.utils.loss.LossInterface(name: str, gradient: str | None, weight: float, reduction: str)
  Bases: ABC
  Abstract base class for all loss functions. Subclasses must implement the compute() method.
  - Parameters:
    name (str) – key in the predictions/targets dict.
    gradient (str | None) – optional gradient field name.
    weight (float) – weight for the loss contribution.
    reduction (str) – reduction mode for torch loss.
  - abstractmethod compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Any | None = None) → Tensor
    Compute the loss.
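
A minimal sketch of a custom subclass is shown below. It is illustrative only: the class name L1TensorMapLoss is hypothetical, and it assumes that the base __init__ accepts the four constructor arguments shown above and stores the target name as self.name.

    from typing import Any, Dict, Optional

    import torch
    from metatensor.torch import TensorMap

    from metatrain.utils.loss import LossInterface


    class L1TensorMapLoss(LossInterface):
        """Hypothetical example: sum of absolute errors over all blocks of one target."""

        def compute(
            self,
            predictions: Dict[str, TensorMap],
            targets: Dict[str, TensorMap],
            extra_data: Optional[Any] = None,
        ) -> torch.Tensor:
            # ``self.name`` is assumed to hold the ``name`` constructor argument.
            prediction = predictions[self.name]
            target = targets[self.name]
            loss = torch.zeros(())
            for key, target_block in target.items():
                prediction_block = prediction.block(key)
                loss = loss + (prediction_block.values - target_block.values).abs().sum()
            return loss
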
- class metatrain.utils.loss.WeightScheduler
  Bases: ABC
  Abstract interface for scheduling the weight of a LossInterface.
  - abstractmethod initialize(loss_fn: LossInterface, targets: Dict[str, TensorMap]) → float
    Compute and return the initial weight.
    - Parameters:
      loss_fn (LossInterface) – the base loss to initialize.
      targets (Dict[str, TensorMap]) – mapping of target names to TensorMap.
    - Returns: initial weight as a float.
    - Return type: float
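
For orientation, a trivial scheduler could look as follows. This is only a sketch: the name ConstantScheduler is hypothetical, and it assumes initialize is the only hook that must be implemented (the per-step weight update mentioned under ScheduledLoss below is not documented in this section).

    from typing import Dict

    from metatensor.torch import TensorMap

    from metatrain.utils.loss import LossInterface, WeightScheduler


    class ConstantScheduler(WeightScheduler):
        """Hypothetical scheduler that always returns the same weight."""

        def __init__(self, value: float = 1.0):
            self.value = value

        def initialize(self, loss_fn: LossInterface, targets: Dict[str, TensorMap]) -> float:
            # Ignore the loss and targets and return a fixed weight.
            return self.value
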
- class metatrain.utils.loss.EMAScheduler(sliding_factor: float | None)
  Bases: WeightScheduler
  Exponential moving average scheduler for loss weights.
  - Parameters:
    sliding_factor (float | None) – factor in [0,1] for EMA (0 disables scheduling).
  - EPSILON = 1e-06
  - initialize(loss_fn: LossInterface, targets: Dict[str, TensorMap]) → float
    Compute and return the initial weight.
    - Parameters:
      loss_fn (LossInterface) – the base loss to initialize.
      targets (Dict[str, TensorMap]) – mapping of target names to TensorMap.
    - Returns: initial weight as a float.
    - Return type: float
- class metatrain.utils.loss.ScheduledLoss(base_loss: LossInterface, weight_scheduler: WeightScheduler)
  Bases: LossInterface
  Wrap a base LossInterface with a WeightScheduler. After each compute, the scheduler updates the loss weight.
  - Parameters:
    base_loss (LossInterface) – underlying LossInterface to wrap.
    weight_scheduler (WeightScheduler) – scheduler that controls the multiplier.
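
Putting the two together, a weight-scheduled loss term might be assembled as in the sketch below. It relies only on the constructor signatures documented here; how ScheduledLoss invokes the scheduler internally is not shown in this section.

    from metatrain.utils.loss import EMAScheduler, ScheduledLoss, TensorMapMSELoss

    # Base term: unmasked MSE on the "energy" target (see TensorMapMSELoss below).
    base = TensorMapMSELoss(name="energy", gradient=None, weight=1.0, reduction="mean")

    # Wrap it so that an exponential moving average controls the effective weight.
    scheduled = ScheduledLoss(
        base_loss=base,
        weight_scheduler=EMAScheduler(sliding_factor=0.1),
    )

    # ``scheduled.compute(predictions, targets)`` then behaves like ``base.compute``,
    # with the scheduler updating the weight after each call.
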
- class metatrain.utils.loss.BaseTensorMapLoss(name: str, gradient: str | None, weight: float, reduction: str, *, loss_fn: _Loss)
  Bases: LossInterface
  Backbone for pointwise losses on TensorMap entries. Provides a compute_flattened() helper that extracts values or gradients, flattens them, applies an optional mask, and computes the torch loss.
  - Parameters:
    name (str) – key in the predictions/targets dict.
    gradient (str | None) – optional gradient field name.
    weight (float) – dummy here; real weighting in ScheduledLoss.
    reduction (str) – reduction mode for torch loss.
    loss_fn – pre-instantiated torch.nn loss (e.g. MSELoss).
  - compute_flattened(tensor_map_predictions_for_target: TensorMap, tensor_map_targets_for_target: TensorMap, tensor_map_mask_for_target: TensorMap | None = None) → Tensor
    Flatten prediction and target blocks (and optional mask), then apply the torch loss.
- class metatrain.utils.loss.MaskedTensorMapLoss(name: str, gradient: str | None, weight: float, reduction: str, *, loss_fn: _Loss)
  Bases: BaseTensorMapLoss
  Pointwise masked loss on TensorMap entries. Inherits flattening and torch-loss logic from BaseTensorMapLoss.
  - Parameters:
    name (str) – key in the predictions/targets dict.
    gradient (str | None) – optional gradient field name.
    weight (float) – dummy here; real weighting in ScheduledLoss.
    reduction (str) – reduction mode for torch loss.
    loss_fn – pre-instantiated torch.nn loss (e.g. MSELoss).
  - compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Dict[str, TensorMap] | None = None) → Tensor
    Gather and flatten target and prediction blocks, then compute the loss.
    - Parameters:
      predictions (Dict[str, TensorMap]) – mapping from target names to TensorMaps.
      targets (Dict[str, TensorMap]) – mapping from target names to TensorMaps.
      extra_data (Dict[str, TensorMap] | None) – additional data for the loss computation. Assumes that, for the target name used in the constructor, there is a corresponding data field name + "_mask" that contains the tensor to be used for masking. It should have the same metadata as the target and prediction tensors.
    - Returns: scalar loss tensor.
    - Return type: Tensor
- class metatrain.utils.loss.TensorMapMSELoss(name: str, gradient: str | None, weight: float, reduction: str)
  Bases: BaseTensorMapLoss
  Unmasked mean-squared error on TensorMap entries.
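
A self-contained sketch of an unmasked MSE evaluation on a toy scalar target is shown below. The single-block TensorMap layout (a "system" sample dimension and a single "energy" property column) is chosen for illustration only and is not prescribed by the class.

    import torch
    from metatensor.torch import Labels, TensorBlock, TensorMap

    from metatrain.utils.loss import TensorMapMSELoss


    def scalar_tensor_map(values: torch.Tensor) -> TensorMap:
        """Wrap an (n_samples, 1) tensor into a single-block TensorMap."""
        block = TensorBlock(
            values=values,
            samples=Labels.range("system", values.shape[0]),
            components=[],
            properties=Labels.range("energy", 1),
        )
        return TensorMap(keys=Labels.single(), blocks=[block])


    predicted = scalar_tensor_map(torch.tensor([[1.0], [2.0], [3.0]]))
    reference = scalar_tensor_map(torch.tensor([[1.1], [1.9], [3.2]]))

    loss = TensorMapMSELoss(name="energy", gradient=None, weight=1.0, reduction="mean")
    value = loss.compute(predictions={"energy": predicted}, targets={"energy": reference})
    print(float(value))  # mean of the squared errors over the three samples
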
- class metatrain.utils.loss.TensorMapMAELoss(name: str, gradient: str | None, weight: float, reduction: str)
  Bases: BaseTensorMapLoss
  Unmasked mean-absolute error on TensorMap entries.
- class metatrain.utils.loss.TensorMapHuberLoss(name: str, gradient: str | None, weight: float, reduction: str, delta: float)
  Bases: BaseTensorMapLoss
  Unmasked Huber loss on TensorMap entries.
  - Parameters:
    delta (float) – threshold parameter for HuberLoss.
    name (str) – key in the predictions/targets dict.
    gradient (str | None) – optional gradient field name.
    weight (float) – dummy here; real weighting in ScheduledLoss.
    reduction (str) – reduction mode for torch loss.
    loss_fn – pre-instantiated torch.nn loss (e.g. MSELoss).
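
For reference, torch.nn.HuberLoss penalizes a residual r = prediction - target quadratically when it is small and linearly when it is large, with delta setting the crossover:

    \ell_\delta(r) =
    \begin{cases}
      \tfrac{1}{2} r^2 & \text{if } |r| \le \delta, \\
      \delta \left( |r| - \tfrac{1}{2} \delta \right) & \text{otherwise.}
    \end{cases}
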
- class metatrain.utils.loss.TensorMapMaskedMSELoss(name: str, gradient: str | None, weight: float, reduction: str)
  Bases: MaskedTensorMapLoss
  Masked mean-squared error on TensorMap entries.
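
Continuing the TensorMapMSELoss sketch above (and reusing its scalar_tensor_map helper, so this snippet is not self-contained on its own), the mask is passed through extra_data under the key name + "_mask". A boolean mask with the same metadata as the target is assumed here.

    import torch

    from metatrain.utils.loss import TensorMapMaskedMSELoss

    predicted = scalar_tensor_map(torch.tensor([[1.0], [2.0], [3.0]]))
    reference = scalar_tensor_map(torch.tensor([[1.1], [0.0], [3.2]]))  # second entry has no valid data
    mask = scalar_tensor_map(torch.tensor([[True], [False], [True]]))   # mask out the invalid entry

    loss = TensorMapMaskedMSELoss(name="energy", gradient=None, weight=1.0, reduction="mean")
    value = loss.compute(
        predictions={"energy": predicted},
        targets={"energy": reference},
        extra_data={"energy_mask": mask},  # key follows the name + "_mask" convention
    )
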
- class metatrain.utils.loss.TensorMapMaskedMAELoss(name: str, gradient: str | None, weight: float, reduction: str)
  Bases: MaskedTensorMapLoss
  Masked mean-absolute error on TensorMap entries.
- class metatrain.utils.loss.TensorMapMaskedHuberLoss(name: str, gradient: str | None, weight: float, reduction: str, delta: float)
  Bases: MaskedTensorMapLoss
  Masked Huber loss on TensorMap entries.
  - Parameters:
    delta (float) – threshold parameter for HuberLoss.
    name (str) – key in the predictions/targets dict.
    gradient (str | None) – optional gradient field name.
    weight (float) – dummy here; real weighting in ScheduledLoss.
    reduction (str) – reduction mode for torch loss.
    loss_fn – pre-instantiated torch.nn loss (e.g. MSELoss).
- class metatrain.utils.loss.LossAggregator(targets: Dict[str, TargetInfo], config: Dict[str, Dict[str, Any]])
  Bases: LossInterface
  Aggregate multiple LossInterface terms with scheduled weights and metadata.
- class metatrain.utils.loss.LossType(*values)
  Bases: Enum
  Enumeration of available loss types and their implementing classes.
  - MSE = ('mse', <class 'metatrain.utils.loss.TensorMapMSELoss'>)
  - MAE = ('mae', <class 'metatrain.utils.loss.TensorMapMAELoss'>)
  - HUBER = ('huber', <class 'metatrain.utils.loss.TensorMapHuberLoss'>)
  - MASKED_MSE = ('masked_mse', <class 'metatrain.utils.loss.TensorMapMaskedMSELoss'>)
  - MASKED_MAE = ('masked_mae', <class 'metatrain.utils.loss.TensorMapMaskedMAELoss'>)
  - MASKED_HUBER = ('masked_huber', <class 'metatrain.utils.loss.TensorMapMaskedHuberLoss'>)
  - POINTWISE = ('pointwise', <class 'metatrain.utils.loss.BaseTensorMapLoss'>)
  - MASKED_POINTWISE = ('masked_pointwise', <class 'metatrain.utils.loss.MaskedTensorMapLoss'>)
  - property cls: Type[LossInterface]
    Class implementing this loss type.
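
As a quick illustration of the mapping, each member pairs a string key with its implementing class, and cls exposes that class:

    from metatrain.utils.loss import LossType, TensorMapHuberLoss

    assert LossType.HUBER.cls is TensorMapHuberLoss
    print(LossType.MASKED_MSE.cls)  # <class 'metatrain.utils.loss.TensorMapMaskedMSELoss'>
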
- metatrain.utils.loss.create_loss(loss_type: str, *, name: str, gradient: str | None, weight: float, reduction: str, **extra_kwargs: Any) → LossInterface
  Factory to instantiate a concrete LossInterface given its string key.
  - Parameters:
    loss_type (str) – string key matching one of the members of LossType.
    name (str) – target name for the loss.
    gradient (str | None) – gradient name, if present.
    weight (float) – weight for the loss contribution.
    reduction (str) – reduction mode for the torch loss.
    extra_kwargs (Any) – additional hyperparameters specific to the loss type.
  - Returns: instance of the selected loss.
  - Return type: LossInterface
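
A short usage sketch of the factory: the string key selects a LossType member, and loss-specific hyperparameters (here, Huber's delta) are forwarded through extra_kwargs. The gradient name "positions" is an assumption made for illustration.

    from metatrain.utils.loss import create_loss

    # Unmasked MSE on the "energy" target.
    energy_loss = create_loss("mse", name="energy", gradient=None, weight=1.0, reduction="mean")

    # Huber loss on the gradient of "energy" with respect to "positions";
    # ``delta`` is forwarded to TensorMapHuberLoss through ``extra_kwargs``.
    force_loss = create_loss(
        "huber",
        name="energy",
        gradient="positions",
        weight=1.0,
        reduction="mean",
        delta=0.1,
    )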