Adding a new architecture

This page describes the classes and files required to add a new architecture to metatrain, either as an experimental or as a stable architecture, as described on the Life Cycle of an Architecture page.

To work with metatrain, every architecture has to follow the same public API so that it can be called correctly by the metatrain.cli.train() function, which processes the user’s options. In brief, the core of the train function looks similar to these lines:

import torch

# `architecture` is a placeholder for the architecture's package
from architecture import __model__ as Model
from architecture import __trainer__ as Trainer

hypers = {...}  # merged default and user hyper-parameters
dataset_info = DatasetInfo(...)

if checkpoint_path is not None:
    # restart a stopped training run
    checkpoint = torch.load(checkpoint_path)

    trainer = Trainer.load_checkpoint(
        checkpoint, train_hypers=hypers["training"], context="restart"
    )
    model = Model.load_checkpoint(checkpoint, context="restart")
    model = model.restart(dataset_info)
else:
    trainer = Trainer(hypers["training"])

    if "finetune" in hypers["training"]:
        # start from an existing checkpoint and fine-tune it
        checkpoint = torch.load(hypers["training"]["finetune"]["read_from"])
        model = Model.load_checkpoint(checkpoint, context="finetune")
    else:
        # train a new model from scratch
        model = Model(hypers["model"], dataset_info)

trainer.train(
    model=model,
    dtype=dtype,  # chosen using the model's __supported_dtypes__
    devices=[],
    train_datasets=[],
    val_datasets=[],
    checkpoint_dir="path",
)

trainer.save_checkpoint(model, "model.ckpt")

mts_atomistic_model = model.export()
mts_atomistic_model.save("model.pt", collect_extensions="extensions/")

To follow this, a new architecture has to define two classes:

  • a Model class, defining the core of the architecture. This class must implement the interface documented below in metatrain.utils.abc.ModelInterface;

  • a Trainer class, used to train an architecture and produce a model that can be evaluated and exported. This class must implement the interface documented below in metatrain.utils.abc.TrainerInterface.

Note

metatrain does not know which types and numbers of targets and datasets an architecture can handle. As a result, it cannot generate useful error messages when a user attempts to train an architecture with an unsupported combination of targets and datasets. It is therefore the responsibility of the architecture developer to verify that the model and the trainer support the train_datasets and val_datasets passed to the trainer, as well as the dataset_info passed to the model.
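As an illustration of such a check, the sketch below rejects unsupported targets early with an explicit error message. It is a simplified, hypothetical helper: check_targets_supported is not part of metatrain, and the kind labels ("scalar", "cartesian") are plain strings standing in for the information a real implementation would read from dataset_info.

```python
from typing import Dict, Set


def check_targets_supported(target_kinds: Dict[str, str], supported_kinds: Set[str]) -> None:
    """Raise a ValueError naming every target whose kind this architecture cannot fit.

    `target_kinds` maps target names to a hypothetical kind label such as
    "scalar" or "cartesian".
    """
    unsupported = sorted(
        name for name, kind in target_kinds.items() if kind not in supported_kinds
    )
    if unsupported:
        raise ValueError(
            f"this architecture cannot fit the target(s) {unsupported}; "
            f"supported target kinds are {sorted(supported_kinds)}"
        )
```

Such a check could be called at the start of the model's __init__ and of the trainer's train method, so the user gets a clear message before any expensive work starts.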

To comply with this design, each architecture has to implement a few files inside a new architecture directory, placed either in the experimental subdirectory or in the root of the Python source if the new architecture already complies with all requirements to be stable. The usual structure of an architecture looks like this:

myarchitecture
    ├── model.py
    ├── trainer.py
    ├── __init__.py
    ├── default-hypers.yaml
    └── schema-hypers.json
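A minimal sketch for creating this skeleton with pathlib. scaffold_architecture is a hypothetical helper, not a metatrain command; it creates empty files (except for a seeded default-hypers.yaml) and leaves out schema-hypers.json, which is optional for experimental architectures.

```python
from pathlib import Path

# Files listed in the layout above; schema-hypers.json is left out because it is
# optional for experimental architectures.
REQUIRED_FILES = ["model.py", "trainer.py", "__init__.py", "default-hypers.yaml"]


def scaffold_architecture(root: Path, name: str) -> Path:
    """Create `root / name` containing empty versions of the required files."""
    arch_dir = root / name
    arch_dir.mkdir(parents=True, exist_ok=False)  # refuse to overwrite an existing folder
    for filename in REQUIRED_FILES:
        (arch_dir / filename).touch()
    # Seed default-hypers.yaml with the top-level structure described further below.
    (arch_dir / "default-hypers.yaml").write_text(
        f"name: {name}\n\nmodel: {{}}\n\ntraining: {{}}\n"
    )
    return arch_dir
```

For example, scaffold_architecture(Path("src/metatrain/experimental"), "myarchitecture") would create the experimental layout; the files still have to be filled in as described in the rest of this page.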

Note

A new architecture does not have to be registered anywhere in the metatrain file tree. Once a new architecture folder containing the required files is created, metatrain will include the architecture automatically.
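The folder-based idea can be illustrated with a small directory scan. This is only a sketch assuming that an architecture is any sub-folder containing the required files; it does not reproduce how metatrain actually discovers architectures, and find_architectures is a hypothetical name.

```python
from pathlib import Path
from typing import List

REQUIRED_FILES = ("model.py", "trainer.py", "__init__.py", "default-hypers.yaml")


def find_architectures(package_root: Path) -> List[str]:
    """List dotted names of folders under package_root that look like architectures."""
    names = []
    # Stable architectures sit in the root, the others in these subfolders.
    for prefix in ("", "experimental", "deprecated"):
        base = package_root / prefix if prefix else package_root
        if not base.is_dir():
            continue
        for candidate in sorted(base.iterdir()):
            if candidate.is_dir() and all((candidate / f).is_file() for f in REQUIRED_FILES):
                names.append(f"{prefix}.{candidate.name}" if prefix else candidate.name)
    return names
```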

Note

Because architectures can live in src/metatrain/<architecture>, src/metatrain/experimental/<architecture>, or src/metatrain/deprecated/<architecture>, the code inside them should use absolute imports when using the tools provided by metatrain.

# do not do this
from ..utils.dtype import dtype_to_str

# Do this instead
from metatrain.utils.dtype import dtype_to_str

Model class (model.py)

class metatrain.utils.abc.ModelInterface

Abstract base class for a machine learning model in metatrain.

All architectures in metatrain must be implemented as a subclass of this class and implement the corresponding methods.

abstractmethod forward(systems: List[System], outputs: Dict[str, ModelOutput], selected_atoms: Labels | None = None) -> Dict[str, TensorMap]

Execute the model for the given systems, computing the requested outputs.

See also

metatensor.torch.atomistic.ModelInterface for more explanation about the different arguments.

Parameters:
  • systems (List[System])

  • outputs (Dict[str, ModelOutput])

  • selected_atoms (Labels | None)

Return type:

Dict[str, TensorMap]

abstractmethod supported_outputs() -> Dict[str, ModelOutput]

Get the outputs currently supported by this model.

These will likely be the same outputs that are set as the model’s capabilities in ModelInterface.export().

Return type:

Dict[str, ModelOutput]

abstractmethod restart(dataset_info: DatasetInfo) -> ModelInterface

Update a model to restart training, potentially with a different dataset and/or different targets.

This function is called whenever training restarts, with the same or a different dataset. It enables transfer learning (changing the targets) and fine-tuning (same targets, different datasets).

This function should return the updated model, or a new instance of the model able to handle the new dataset.

Parameters:

dataset_info (DatasetInfo)

Return type:

ModelInterface
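A sketch of what restart() might do, written on a toy stand-in class rather than a real ModelInterface subclass: it assumes one output head per target, keeps the heads of targets the model already knows (fine-tuning), and creates fresh heads for targets seen for the first time (transfer learning). ToyModel and its list-of-floats heads are hypothetical; a real model would read the new targets from dataset_info and create torch modules.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyModel:
    """Stand-in for a ModelInterface subclass: one 'head' per target name."""

    heads: Dict[str, List[float]] = field(default_factory=dict)

    def restart(self, target_names: List[str]) -> "ToyModel":
        for name in target_names:
            if name not in self.heads:
                # New target: create a freshly initialized head.
                self.heads[name] = [0.0, 0.0, 0.0]
        # Heads of already-known targets keep their trained values.
        return self
```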

abstractmethod classmethod load_checkpoint(checkpoint: Dict[str, Any], context: Literal['restart', 'finetune', 'export']) -> ModelInterface

Create a model from a checkpoint (i.e. state dictionary).

Parameters:
  • checkpoint (Dict[str, Any]) – Checkpoint’s state dictionary.

  • context (Literal['restart', 'finetune', 'export']) – Context in which to load the model. Possible values are "restart" when restarting a stopped training run, "finetune" when loading a model for further fine-tuning or transfer learning, and "export" when loading a model for final export. When multiple checkpoints are stored together, this can be used to pick one of them depending on the context.

Return type:

ModelInterface
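To illustrate how context can be used to pick among several stored checkpoints, here is a sketch assuming a checkpoint that stores both the latest and the best-validation weights, under the hypothetical keys "model_state_dict" and "best_model_state_dict". select_model_state is likewise a hypothetical helper that a load_checkpoint implementation could call.

```python
from typing import Any, Dict, Literal


def select_model_state(
    checkpoint: Dict[str, Any], context: Literal["restart", "finetune", "export"]
) -> Dict[str, Any]:
    """Pick which stored model weights to load for the given context."""
    if context == "restart":
        # Resume exactly where the interrupted run stopped.
        return checkpoint["model_state_dict"]
    if context in ("finetune", "export"):
        # Prefer the best-validation weights, falling back to the latest ones.
        best = checkpoint.get("best_model_state_dict")
        return best if best is not None else checkpoint["model_state_dict"]
    raise ValueError(f"unknown context {context!r}")
```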

abstractmethod export(metadata: ModelMetadata | None = None) -> MetatensorAtomisticModel

Turn this model into an instance of metatensor.torch.atomistic.MetatensorAtomisticModel, containing the model itself, a definition of the model capabilities and some metadata about the model.

Parameters:

metadata (ModelMetadata | None) – additional metadata to add to the model, as specified by the user.

Return type:

MetatensorAtomisticModel

Defining a new model can then be done as follows:

from typing import Dict

import torch
from metatomic.torch import ModelMetadata

from metatrain.utils.abc import ModelInterface
from metatrain.utils.data import DatasetInfo


class MyModel(ModelInterface):

    __supported_devices__ = ["cuda", "cpu"]
    __supported_dtypes__ = [torch.float64, torch.float32]
    __default_metadata__ = ModelMetadata(
        references={"implementation": ["ref1"], "architecture": ["ref2"]}
    )

    def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo):
        ...

    ...  # implementation of all the functions from ModelInterface

In addition to subclassing ModelInterface, the model class should have the following class attributes:

  • __supported_devices__: list of the supported torch devices for running the model;

  • __supported_dtypes__: list of the supported dtypes for this model.

Both lists should be sorted in order of preference, since metatrain uses them to determine the optimal dtype and device for training, based on the user’s request and on what is available on the machine.

  • __default_metadata__ can be used to provide references that will be stored in the exported model. The references are stored in a dictionary with the keys implementation and architecture: the implementation key should contain references to the software used to implement the architecture, while the architecture key should contain references describing the general architecture.
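The effect of the order of preference can be understood with the following sketch, which returns the first supported option that is also available, unless the user explicitly requested one. Plain strings stand in for torch.device and torch.dtype objects, and pick_first_supported is a hypothetical helper, not metatrain’s actual selection logic.

```python
from typing import List, Optional


def pick_first_supported(
    preferred: List[str], available: List[str], requested: Optional[str] = None
) -> str:
    """Choose an option following the architecture's order of preference."""
    if requested is not None:
        # An explicit user request wins, as long as it is both supported and available.
        if requested not in preferred:
            raise ValueError(f"{requested!r} is not supported; choose from {preferred}")
        if requested not in available:
            raise ValueError(f"{requested!r} is not available on this machine")
        return requested
    for option in preferred:  # the first entry is the most preferred
        if option in available:
            return option
    raise ValueError(f"none of {preferred} is available (found {available})")
```

With __supported_devices__ = ["cuda", "cpu"], a machine without a GPU falls back to "cpu", which is why the lists should start with the preferred option.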

Trainer class (trainer.py)

class metatrain.utils.abc.TrainerInterface(train_hypers)

Abstract base class for a model trainer in metatrain.

All architectures in metatrain must implement such a trainer, which is responsible for training the model. The trainer must be a subclass of this class and implement the corresponding methods.

Create a trainer using the hyper-parameters in train_hypers.

abstractmethod train(model: ModelInterface, dtype: dtype, devices: List[device], train_datasets: List[Dataset | Subset], val_datasets: List[Dataset | Subset], checkpoint_dir: str)

Train the model using the train_datasets. How to train the model is left to this class, using the hyper-parameters given in __init__.

Parameters:
  • model (ModelInterface) – the model to train

  • dtype (dtype) – torch.dtype used by the data in the datasets

  • devices (List[device]) – torch.device to use for training the model. When training with more than one device (e.g. multi-GPU training), this can contain multiple devices.

  • train_datasets (List[Dataset | Subset]) – datasets to use to train the model

  • val_datasets (List[Dataset | Subset]) – datasets to use for model validation

  • checkpoint_dir (str) – directory where checkpoints should be saved

abstractmethod save_checkpoint(model, path: str | Path)

Save a checkpoint of both the model and the trainer state to the given path.

Parameters:

path (str | Path)

abstractmethod classmethod load_checkpoint(checkpoint: Dict[str, Any], train_hypers: Dict[str, Any], context: Literal['restart', 'finetune']) -> TrainerInterface

Create a trainer instance from data stored in the checkpoint.

Parameters:
  • checkpoint (Dict[str, Any]) – Checkpoint’s state dictionary.

  • train_hypers (Dict[str, Any]) – Hyper-parameters for the trainer, as specified by the user.

  • context (Literal['restart', 'finetune']) – Context in which to load the model. Possible values are "restart" when restarting a stopped training run, and "finetune" when loading a model for further fine-tuning or transfer learning. When multiple checkpoints are stored together, this can be used to pick one of them depending on the context.

Return type:

TrainerInterface
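The role of context in the trainer can be sketched with a toy class: on "restart" the epoch counter and optimizer state are restored, while on "finetune" training starts with a fresh optimizer. This reflects an assumption about sensible behavior, not metatrain’s implementation. ToyTrainer, the checkpoint keys, and the use of pickle (standing in for torch.save) are all hypothetical, and save_checkpoint here receives a model state dictionary instead of the model itself.

```python
import pickle
from pathlib import Path
from typing import Any, Dict, Literal, Union


class ToyTrainer:
    """Stand-in trainer showing how `context` can change what is restored."""

    def __init__(self, train_hypers: Dict[str, Any]):
        self.hypers = train_hypers
        self.epoch = 0
        self.optimizer_state: Dict[str, Any] = {}

    def save_checkpoint(self, model_state: Dict[str, Any], path: Union[str, Path]) -> None:
        # Model and trainer state are stored together in a single file.
        checkpoint = {
            "model_state_dict": model_state,
            "epoch": self.epoch,
            "optimizer_state_dict": self.optimizer_state,
        }
        with open(path, "wb") as f:
            pickle.dump(checkpoint, f)

    @classmethod
    def load_checkpoint(
        cls,
        checkpoint: Dict[str, Any],
        train_hypers: Dict[str, Any],
        context: Literal["restart", "finetune"],
    ) -> "ToyTrainer":
        trainer = cls(train_hypers)  # the user-provided hypers are used in both cases
        if context == "restart":
            # Continue the interrupted run with the same epoch and optimizer state.
            trainer.epoch = checkpoint["epoch"]
            trainer.optimizer_state = checkpoint["optimizer_state_dict"]
        elif context != "finetune":
            raise ValueError(f"unknown context {context!r}")
        # For "finetune", the epoch counter and optimizer start from scratch.
        return trainer
```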

Defining a new trainer can then be done like this:

from metatrain.utils.abc import TrainerInterface

class MyTrainer(TrainerInterface):

    def __init__(self, train_hypers):
        ...

    ... # implementation of all the functions from TrainerInterface

Init file (__init__.py)

You are free to name the Model and Trainer classes as you want. These classes should then be made available in __init__.py under the names __model__ and __trainer__, so that metatrain knows where to find them. __init__.py must also contain definitions of the original __authors__ and the current __maintainers__ of the architecture.

from .model import MyModel
from .trainer import MyTrainer

# class to use as the architecture's model
__model__ = MyModel
# class to use as the architecture's trainer
__trainer__ = MyTrainer

# List of the original authors of the architecture, each with an email
# address and GitHub handle.
#
# These authors are not necessarily currently in charge of maintaining the code
__authors__ = [
    ("Jane Roe <jane.roe@myuniversity.org>", "@janeroe"),
    ("John Doe <john.doe@otheruniversity.edu>", "@johndoe"),
]

# Current maintainers of the architecture code, using the same
# style as ``__authors__``
__maintainers__ = [("Joe Bloggs <joe.bloggs@sotacompany.com>", "@joebloggs")]
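A quick way to check these exports is sketched below. check_architecture_module is a hypothetical helper, and the entry format it enforces (a "Name <email>" string plus an "@handle") simply mirrors the example above.

```python
from types import ModuleType
from typing import Any, List, Union

REQUIRED_ATTRIBUTES = ("__model__", "__trainer__", "__authors__", "__maintainers__")


def check_architecture_module(module: Union[ModuleType, Any]) -> List[str]:
    """Return the problems found in an architecture's __init__ exports (empty if none)."""
    problems = [f"missing {attr}" for attr in REQUIRED_ATTRIBUTES if not hasattr(module, attr)]
    for attr in ("__authors__", "__maintainers__"):
        for entry in getattr(module, attr, []):
            # Expected entry: ("Name <email>", "@github-handle")
            well_formed = (
                isinstance(entry, tuple)
                and len(entry) == 2
                and all(isinstance(part, str) for part in entry)
                and "<" in entry[0]
                and entry[0].endswith(">")
                and entry[1].startswith("@")
            )
            if not well_formed:
                problems.append(f"malformed {attr} entry: {entry!r}")
    return problems
```

In a test suite, one could import the architecture package and assert that the returned list is empty.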

Default Hyperparameters (default-hypers.yaml)

The default hyperparameters for each architecture should be stored in a YAML file called default-hypers.yaml inside the architecture directory. Reasonable default hypers are required to improve usability. The default hypers must follow this structure:

name: myarchitecture

model:
    ...

training:
    ...

metatrain will parse this file, overwrite these default hypers with the user-provided parameters, and pass the merged model section as a Python dictionary to the model class (the ModelInterface subclass) and the merged training section to the trainer class (the TrainerInterface subclass).
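The override step can be pictured as a recursive dictionary merge. The sketch below assumes that nested sections are merged key by key and that user values win; metatrain’s actual merging code may differ (for instance in how lists are handled), merge_hypers is a hypothetical name, and the hypers cutoff and layers used in the checks are made up.

```python
import copy
from typing import Any, Dict


def merge_hypers(defaults: Dict[str, Any], user: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively overwrite `defaults` with `user`, without mutating either input."""
    merged = copy.deepcopy(defaults)
    for key, value in user.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_hypers(merged[key], value)  # merge nested sections
        else:
            merged[key] = copy.deepcopy(value)  # the user value wins
    return merged
```

With this rule, users only need to specify the hypers that differ from the defaults; merged["model"] and merged["training"] are then the dictionaries handed to the model and the trainer.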

Finetuning

If your architecture supports finetuning, you have to add a finetune subsection to the training section. The subsection must contain a read_from key that points to the checkpoint file from which finetuning starts. Any additional hyperparameters can be architecture-specific.

training:
    finetune:
        read_from: path/to/checkpoint.ckpt
        # other architecture finetune hyperparameters
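Reading this subsection can be sketched as follows; it mirrors the finetune branch of the train function shown at the top of this page. finetune_checkpoint is a hypothetical helper name, and raising an error for a missing read_from key is an assumed design choice.

```python
from typing import Any, Dict, Optional


def finetune_checkpoint(training_hypers: Dict[str, Any]) -> Optional[str]:
    """Return the checkpoint to fine-tune from, or None when training from scratch."""
    finetune = training_hypers.get("finetune")
    if finetune is None:
        return None
    if "read_from" not in finetune:
        raise ValueError("the 'finetune' section must contain a 'read_from' key")
    return finetune["read_from"]
```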

JSON schema (schema-hypers.json)

To validate the user’s input hyperparameters, we use JSON schemas stored in a file called schema-hypers.json. For an experimental architecture, providing such a schema along with the default hypers is not required, but it is highly recommended, since it catches user input errors such as typos in parameter names or parameters placed in the wrong section. If no schema-hypers.json is provided, no validation is performed and the user hypers are passed to the architecture’s model and trainer as they are.
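To see why validation helps, the sketch below flags user keys that do not appear in the default hypers, which catches typos and keys placed in the wrong section. It is a simplified, hypothetical stand-in for JSON schema validation (unknown_keys is not a metatrain function): unlike a real schema, it would also flag valid optional keys, such as a finetune subsection, that are absent from the defaults.

```python
from typing import Any, Dict, List


def unknown_keys(user: Dict[str, Any], defaults: Dict[str, Any], prefix: str = "") -> List[str]:
    """List dotted paths of user keys that do not exist in the default hypers."""
    found = []
    for key, value in user.items():
        path = f"{prefix}{key}"
        if key not in defaults:
            found.append(path)  # likely a typo or a key placed in the wrong section
        elif isinstance(value, dict) and isinstance(defaults[key], dict):
            found.extend(unknown_keys(value, defaults[key], prefix=f"{path}."))
    return found
```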

To create such a schema, you can try online tools that convert default-hypers.yaml into a JSON schema. Besides online tools, we have also had success using ChatGPT or other LLMs for this conversion.
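As an alternative to online tools, a first draft of the schema can be derived from the default hypers programmatically. This sketch, using the hypothetical helper schema_from_defaults, maps each default value to a JSON schema type and sets additionalProperties to false so that unknown keys are rejected. The result will usually need manual edits, for example for options with a null default or options that accept several types, and the exact layout metatrain expects should be checked against existing architectures.

```python
import json
from typing import Any, Dict


def schema_from_defaults(value: Any) -> Dict[str, Any]:
    """Derive a strict JSON schema fragment from a default-hypers value."""
    if isinstance(value, dict):
        return {
            "type": "object",
            "properties": {k: schema_from_defaults(v) for k, v in value.items()},
            "additionalProperties": False,  # reject unknown keys, e.g. typos
        }
    if isinstance(value, bool):  # check bool before int: bool is a subclass of int
        return {"type": "boolean"}
    if isinstance(value, int):
        return {"type": "integer"}
    if isinstance(value, float):
        return {"type": "number"}
    if isinstance(value, str):
        return {"type": "string"}
    if isinstance(value, list):
        return {"type": "array"}
    if value is None:
        return {"type": "null"}  # usually needs widening by hand
    raise TypeError(f"unsupported default value: {value!r}")
```

The defaults dictionary would come from parsing default-hypers.yaml with a YAML library, and json.dumps(schema_from_defaults(defaults), indent=2) gives text that can be written to schema-hypers.json and refined by hand.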

Documentation

Each new architecture should be added to metatrain’s documentation. A short page describing the architecture and its default hyperparameters is sufficient; you can take inspiration from existing architectures. The targets that the architecture can fit should be added to the table in the “Fitting generic targets” section.