Adding a new architecture

This page describes the classes and files required to add a new architecture to metatrain, either as an experimental or as a stable architecture, as described on the Life Cycle of an Architecture page. For examples, refer to the existing architectures inside the source tree.

To work with metatrain, any architecture has to follow the same public API so that it can be called correctly within the metatrain.cli.train() function to process the user’s options. In brief, the core of the train function looks similar to these lines:

from architecture import __model__ as Model
from architecture import __trainer__ as Trainer

hypers = {}  # merged default and user-provided hyperparameters
dataset_info = DatasetInfo()  # metadata describing the systems and targets

if continue_from is not None:
    # restart from an existing checkpoint
    trainer = Trainer.load_checkpoint(continue_from, hypers["training"])
    model = Model.load_checkpoint(continue_from)
    model = model.restart(dataset_info)
else:
    # start a fresh training run
    model = Model(hypers["model"], dataset_info)
    trainer = Trainer(hypers["training"])

trainer.train(
    model=model,
    dtype=dtype,
    devices=[],          # torch devices selected by metatrain
    train_datasets=[],   # training datasets built from the user options
    val_datasets=[],     # validation datasets built from the user options
    checkpoint_dir="path",
)

model.save_checkpoint("model.ckpt")

mts_atomistic_model = model.export()
mts_atomistic_model.export("model.pt", collect_extensions="extensions/")

To follow this, a new architecture has to define two classes:

  • a Model class, defining the core of the architecture. This class must implement the interface documented below in ModelInterface.

  • a Trainer class, used to train an architecture and produce a model that can be evaluated and exported. This class must implement the interface documented below in TrainerInterface.

Note

metatrain does not know the types and numbers of targets/datasets an architecture can handle. As a result, it cannot generate useful error messages when a user attempts to train an architecture with unsupported target and dataset combinations. Therefore, it is the responsibility of the architecture developer to verify that the model and the trainer support the provided train_datasets and val_datasets passed to the Trainer, as well as the dataset_info passed to the model, as sketched below.
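
For instance, a model that only fits energies could reject other targets early with a check along these lines. This is a minimal sketch: check_targets is a hypothetical helper (it could equally live in the model’s __init__), and it assumes that DatasetInfo exposes a targets dictionary whose values carry a quantity attribute; check this against the current metatrain API.

from metatrain.utils.data import DatasetInfo


def check_targets(dataset_info: DatasetInfo) -> None:
    # hypothetical validation, e.g. called from the model's __init__();
    # assumes dataset_info.targets maps target names to objects carrying
    # a `quantity` attribute
    for name, target in dataset_info.targets.items():
        if target.quantity != "energy":
            raise ValueError(
                f"target {name!r} has quantity {target.quantity!r}, but this "
                "architecture only supports energy targets"
            )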

To comply with this design, each architecture has to provide a few files inside a new architecture directory, either inside the experimental subdirectory or in the root of the Python source if the new architecture already complies with all requirements to be stable. The usual structure of an architecture looks like this:

myarchitecture
    ├── model.py
    ├── trainer.py
    ├── __init__.py
    ├── default-hypers.yaml
    └── schema-hypers.json

Note

A new architecture doesn’t have to be registered anywhere in the file tree of metatrain. Once a new architecture folder with the required files is created, metatrain will include the architecture automatically.

Model class (model.py)

The ModelInterface, recommended to be located in a file called model.py inside the architecture folder, is the main model class and must implement a save_checkpoint(), a load_checkpoint(), a restart() and an export() method.

from pathlib import Path
from typing import Dict, Optional, Union

import torch
from metatensor.torch.atomistic import MetatensorAtomisticModel, ModelMetadata

from metatrain.utils.data import DatasetInfo


class ModelInterface:

    __supported_devices__ = ["cuda", "cpu"]
    __supported_dtypes__ = [torch.float64, torch.float32]
    __default_metadata__ = ModelMetadata(
        references={"implementation": ["ref1"], "architecture": ["ref2"]}
    )

    def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo):
        self.hypers = model_hypers
        self.dataset_info = dataset_info

    @classmethod
    def load_checkpoint(cls, path: Union[str, Path]) -> "ModelInterface":
        pass

    def restart(self, dataset_info: DatasetInfo) -> "ModelInterface":
        """Restart training.

        This function is called whenever training restarts, with the same or a
        different dataset.

        It enables transfer learning (changing the targets) and fine-tuning
        (same targets, different datasets).
        """
        pass

    def export(
        self, metadata: Optional[ModelMetadata] = None
    ) -> MetatensorAtomisticModel:
        pass

Note that the ModelInterface does not necessarily inherit from torch.nn.Module, since training can be performed in any way. __supported_devices__ and __supported_dtypes__ can be defined to set the capabilities of the model. These two lists should be sorted in order of preference, since metatrain will use them to determine, based on the user request and the machine’s availability, the optimal dtype and device for training.

The __default_metadata__ is a class attribute that can be used to provide references that will be stored in the exported model. The references are stored in a dictionary with keys implementation and architecture. The implementation key should contain references to the software used in the implementation of the architecture, while the architecture key should contain references about the general architecture.

The export() method is required to transform a trained model into a standalone file that can be used in combination with molecular dynamics engines to run simulations. We provide a helper function metatrain.utils.export.export() to export a torch model to a MetatensorAtomisticModel.
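
If the architecture is a torch.nn.Module, export() can also build the exported model directly from metatensor’s public classes. The following is a sketch under that assumption; the capabilities values below are hypothetical placeholders that a real architecture would derive from its hyperparameters and from self.dataset_info.

from typing import Optional

from metatensor.torch.atomistic import (
    MetatensorAtomisticModel,
    ModelCapabilities,
    ModelMetadata,
    ModelOutput,
)


def export(
    self, metadata: Optional[ModelMetadata] = None
) -> MetatensorAtomisticModel:
    if metadata is None:
        metadata = self.__default_metadata__

    # hypothetical capabilities; real values come from the trained model
    capabilities = ModelCapabilities(
        outputs={"energy": ModelOutput(quantity="energy", unit="eV", per_atom=False)},
        atomic_types=[1, 6, 8],
        interaction_range=5.0,
        length_unit="angstrom",
        supported_devices=self.__supported_devices__,
        dtype="float64",
    )
    # assumes the model itself is a torch.nn.Module
    return MetatensorAtomisticModel(self.eval(), metadata, capabilities)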

Trainer class (trainer.py)

The TrainerInterface class should have the following signature, with the required methods train(), save_checkpoint() and load_checkpoint().

from pathlib import Path
from typing import Dict, List, Union

import torch

from metatrain.utils.data import Dataset

from .model import ModelInterface


class TrainerInterface:
    def __init__(self, train_hypers):
        self.hypers = train_hypers

    def train(
        self,
        model: ModelInterface,
        dtype: torch.dtype,
        devices: List[torch.device],
        train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
        val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
        checkpoint_dir: str,
    ) -> None: ...

    def save_checkpoint(self, path: Union[str, Path]) -> None: ...

    @classmethod
    def load_checkpoint(
        cls, path: Union[str, Path], train_hypers: Dict
    ) -> "TrainerInterface":
        pass

The format of checkpoints is not defined by metatrain and can be any format that can be loaded by the trainer (to restart training) and by the model (to export the checkpoint). The only requirements are that the checkpoint must be loadable with torch.load(), it must be a dictionary, and it must contain the name of the architecture under the architecture_name key.
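
For example, a save_checkpoint() implementation could store a plain dictionary via torch.save(). This is a minimal sketch: apart from the required architecture_name key, the payload keys are hypothetical, and it assumes the trainer kept a reference to the model during train().

from pathlib import Path
from typing import Union

import torch


def save_checkpoint(self, path: Union[str, Path]) -> None:
    torch.save(
        {
            "architecture_name": "myarchitecture",  # required key
            # hypothetical payload; the format is otherwise free
            "model_hypers": self.model.hypers,
            "model_state_dict": self.model.state_dict(),
            "trainer_hypers": self.hypers,
        },
        path,
    )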

Init file (__init__.py)

The names of the ModelInterface and TrainerInterface classes are free to choose, but they must be linked to constants in the __init__.py of each architecture. On top of these two constants, the __init__.py must contain constants for the original __authors__ and current __maintainers__ of the architecture.

from .model import ModelInterface
from .trainer import TrainerInterface

__model__ = ModelInterface
__trainer__ = TrainerInterface

__authors__ = [
    ("Jane Roe <jane.roe@myuniversity.org>", "@janeroe"),
    ("John Doe <john.doe@otheruniversity.edu>", "@johndoe"),
]

__maintainers__ = [("Joe Bloggs <joe.bloggs@sotacompany.com>", "@joebloggs")]

  • __model__: Mapping of the custom ModelInterface to a general one to be loaded by metatrain.

  • __trainer__: Same as __model__, but for the Trainer class.

  • __authors__: List of tuples, each denoting an original author of the architecture with an email address and a GitHub handle. The authors are not necessarily in charge of maintaining the architecture.

  • __maintainers__: List of tuples denoting the current maintainers of the architecture, using the same style as the __authors__ constant.

Default Hyperparameters (default-hypers.yaml)

The default hyperparameters for each architecture should be stored in a YAML file called default-hypers.yaml inside the architecture directory. Reasonable defaults are required to improve usability. The default hypers must follow the structure:

name: myarchitecture

model:
    ...

training:
    ...

metatrain will parse this file, overwrite the default hypers with the user-provided parameters, and pass the merged model section as a Python dictionary to the ModelInterface and the training section to the TrainerInterface.
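
As an illustration, a filled-in default-hypers.yaml could look like this (all hyperparameter names below are hypothetical):

name: myarchitecture

model:
    cutoff: 5.0
    num_hidden_layers: 2

training:
    batch_size: 8
    num_epochs: 100
    learning_rate: 0.001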

JSON schema (schema-hypers.json)

To validate the user’s input hyperparameters we use JSON schemas stored in a schema file called schema-hypers.json. For an experimental architecture it is not required to provide such a schema along with the default hypers, but it is highly recommended to reduce possible user input errors such as typos in parameter names or wrong sections. If no schema-hypers.json is provided, no validation is performed and the user’s hypers are passed to the architecture’s model and trainer as-is.

To create such a schema, start with online tools that convert the default-hypers.yaml into a JSON schema. Besides online tools, we have also had success using LLMs such as ChatGPT for this conversion.
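
A minimal schema matching the hypothetical default hypers above could look like this; setting additionalProperties to false is what turns typos in parameter names into validation errors:

{
    "$schema": "http://json-schema.org/draft-07/schema#",
    "type": "object",
    "properties": {
        "name": {"const": "myarchitecture"},
        "model": {
            "type": "object",
            "properties": {
                "cutoff": {"type": "number"},
                "num_hidden_layers": {"type": "integer"}
            },
            "additionalProperties": false
        },
        "training": {
            "type": "object",
            "properties": {
                "batch_size": {"type": "integer"},
                "num_epochs": {"type": "integer"},
                "learning_rate": {"type": "number"}
            },
            "additionalProperties": false
        }
    },
    "additionalProperties": false
}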

Documentation

Each new architecture should be added to metatrain’s documentation. A short page describing the architecture and its default hyperparameters will be sufficient. You can take inspiration from existing architectures. The various targets that the architecture can fit should be added to the table in the “Fitting generic targets” section.