Adding a new loss function

This page describes the classes and files required to add a new loss function to metatrain. Defining a new loss is useful, for example, when extra data is needed to compute the loss.

Loss functions in metatrain are implemented as subclasses of metatrain.utils.loss.LossInterface. This interface defines the required method compute(), which takes the model predictions and the ground truth values as input and returns the computed loss value. In addition to predictions and targets, compute() accepts an extra_data argument, which can be used to pass any extra information needed for the loss computation.

from typing import Dict, Optional
import torch
from metatrain.utils.loss import LossInterface
from metatensor.torch import TensorMap

class NewLoss(LossInterface):
    def __init__(
        self,
        name: str,
        gradient: Optional[str],
        weight: float,
        reduction: str,
    ) -> None:
        ...

    def compute(
        self,
        predictions: Dict[str, TensorMap],
        targets: Dict[str, TensorMap],
        extra_data: Dict[str, TensorMap],
    ) -> torch.Tensor:
        ...

Examples of loss functions already implemented in metatrain are metatrain.utils.loss.TensorMapMSELoss and metatrain.utils.loss.TensorMapMAELoss. They both inherit from the metatrain.utils.loss.BaseTensorMapLoss class, which implements pointwise losses for metatensor.torch.TensorMap objects.
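
To make the role of extra_data more concrete, the following is a minimal sketch of a weighted mean squared error whose per-sample weights are read from extra_data. It is deliberately self-contained and does not import metatrain: the class name WeightedMSESketch, the "sample_weights" key, and the use of plain lists of floats in place of TensorMap objects are all assumptions made for illustration. A real loss would subclass LossInterface and operate on TensorMap values.

```python
from typing import Dict, List, Optional


class WeightedMSESketch:
    """Hypothetical stand-in for a LossInterface subclass (not metatrain code)."""

    def __init__(self, name: str, weight: float = 1.0) -> None:
        self.target = name  # key looked up in the predictions/targets dictionaries
        self.weight = weight  # global prefactor applied to the loss value

    def compute(
        self,
        predictions: Dict[str, List[float]],
        targets: Dict[str, List[float]],
        extra_data: Optional[Dict[str, List[float]]] = None,
    ) -> float:
        pred = predictions[self.target]
        true = targets[self.target]
        # Assumption: per-sample weights live under the "sample_weights" key.
        # Without them, every sample is weighted equally.
        if extra_data is not None and "sample_weights" in extra_data:
            sample_weights = extra_data["sample_weights"]
        else:
            sample_weights = [1.0] * len(pred)
        weighted_sq = sum(
            w * (p - t) ** 2 for w, p, t in zip(sample_weights, pred, true)
        )
        # Normalize by the total weight so the result is a weighted mean.
        return self.weight * weighted_sq / sum(sample_weights)


loss = WeightedMSESketch("energy")
plain = loss.compute({"energy": [1.0, 2.0]}, {"energy": [0.0, 0.0]})
masked = loss.compute(
    {"energy": [1.0, 2.0]},
    {"energy": [0.0, 0.0]},
    extra_data={"sample_weights": [1.0, 0.0]},
)
```

Here the second call gives the second sample zero weight, so only the first sample contributes to the loss. The same pattern applies to any quantity that is neither a prediction nor a target but is still needed to evaluate the loss.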

Loss weight scheduling

Currently, metatrain implements a single loss weight scheduler, metatrain.utils.loss.EMAScheduler. This class schedules the weight of a loss function based on the Exponential Moving Average (EMA) of the loss value, so that the contribution of each loss term can adapt as training progresses. New schedulers can be implemented by inheriting from the metatrain.utils.loss.WeightScheduler abstract class and implementing its initialize() and update() methods.
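
The idea of EMA-based weight scheduling can be sketched in a few lines of plain Python. This is not the EMAScheduler implementation and it does not reproduce the WeightScheduler signatures: the class name EMAWeightSketch, the decay and eps parameters, the float arguments to initialize() and update(), and the choice of returning the inverse of the EMA as the weight are all assumptions made for illustration.

```python
class EMAWeightSketch:
    """Hypothetical scheduler: tracks an EMA of a loss value (not metatrain code)."""

    def __init__(self, decay: float = 0.9, eps: float = 1e-12) -> None:
        self.decay = decay  # how strongly the previous EMA is retained
        self.eps = eps  # floor that avoids division by zero
        self.ema = 0.0
        self.initialized = False

    def _weight(self) -> float:
        # Assumed rule: weight the loss by the inverse of its running average,
        # so that the weighted term stays of order one.
        return 1.0 / max(self.ema, self.eps)

    def initialize(self, loss_value: float) -> float:
        # Seed the running average with the first observed loss value.
        self.ema = loss_value
        self.initialized = True
        return self._weight()

    def update(self, loss_value: float) -> float:
        if not self.initialized:
            return self.initialize(loss_value)
        # Standard EMA recurrence: blend the old average with the new value.
        self.ema = self.decay * self.ema + (1.0 - self.decay) * loss_value
        return self._weight()


scheduler = EMAWeightSketch(decay=0.5)
first_weight = scheduler.initialize(4.0)
second_weight = scheduler.update(2.0)
```

With decay=0.5, the running average moves from 4.0 to 3.0 after the update, so the returned weight grows as the loss falls. A scheduler written for metatrain would follow the same initialize-then-update structure while matching the argument types that WeightScheduler actually declares.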