IO
Functions for handling the serialization of models.
- metatrain.utils.io.check_file_extension(filename: str | Path, extension: str) → str | Path
Check the file extension of a file name and add the extension if it is not present.
If filename does not end with extension, the extension is appended and a warning is issued.
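The behavior described above can be sketched in plain Python. The helper ensure_extension below is hypothetical and only mirrors the documented behavior (append the missing extension and warn). It is not metatrain's implementation, and the warning text is an assumption.

```python
import warnings
from pathlib import Path
from typing import Union


def ensure_extension(filename: Union[str, Path], extension: str) -> Union[str, Path]:
    """Append `extension` to `filename` if missing, warning when it does so."""
    name = str(filename)
    if name.endswith(extension):
        return filename  # already correct: return the input unchanged
    fixed = name + extension
    warnings.warn(
        f"{name!r} does not end with {extension!r}; using {fixed!r} instead",
        stacklevel=2,
    )
    # Preserve the input type: Path in, Path out; str in, str out.
    return Path(fixed) if isinstance(filename, Path) else fixed
```

Returning the same type as the input lets callers pass either a str or a Path without extra conversions.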
- metatrain.utils.io.is_exported_file(path: str) → bool
Check whether a saved model file has been exported to a metatomic AtomisticModel. The function uses metatomic.torch.check_atomistic_model() to verify this.
- Parameters:
path (str) – model path
- Returns:
True if the file contains an exported model, False otherwise
- Return type:
bool
See also
metatomic.torch.is_atomistic_model() to check whether an already loaded model is exported.
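A typical use of is_exported_file is to sort local model files before deciding how to load them. The helper partition_model_files below is hypothetical; it assumes metatrain is installed and that every path points to a saved model file, either an exported model or a checkpoint.

```python
from pathlib import Path
from typing import Iterable, List, Tuple, Union


def partition_model_files(
    paths: Iterable[Union[str, Path]],
) -> Tuple[List[str], List[str]]:
    """Split local model files into (exported models, checkpoints)."""
    # Imported lazily so that this helper can be defined without metatrain installed.
    from metatrain.utils.io import is_exported_file

    exported: List[str] = []
    checkpoints: List[str] = []
    for path in paths:
        path_str = str(path)  # is_exported_file is documented as taking a str
        if is_exported_file(path_str):
            exported.append(path_str)
        else:
            checkpoints.append(path_str)
    return exported, checkpoints
```

For example, partition_model_files(["model.pt", "model.ckpt"]) would return the exported files first and the remaining (checkpoint) files second.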
- metatrain.utils.io.load_model(path: str | Path, extensions_directory: str | Path | None = None, hf_token: str | None = None) → Any
Load checkpoints and exported models from a URL or a local file for inference.
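Distinguishing a URL from a local path can be done with the standard library. The sketch below is a hypothetical illustration, not metatrain's own logic; the accepted schemes are assumed to be a subset of those that urllib.request can open by default (http, https, ftp, file).

```python
from pathlib import Path
from typing import Union
from urllib.parse import urlparse

# Assumed subset of the schemes that urllib.request can open by default.
URL_SCHEMES = {"http", "https", "ftp", "file"}


def looks_like_url(path: Union[str, Path]) -> bool:
    """Return True if `path` should be treated as a URL instead of a local file."""
    if isinstance(path, Path):
        return False  # Path objects always refer to the local filesystem
    # urlparse lowercases the scheme. A Windows drive letter such as "C:" parses
    # as the one-letter scheme "c", which is not in URL_SCHEMES.
    return urlparse(path).scheme in URL_SCHEMES
```

Checking against an explicit set of schemes avoids misclassifying Windows paths, which urlparse would otherwise report as having a scheme.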
If the exported model to be loaded requires compiled extensions, pass their location with the extensions_directory parameter.
After reading a checkpoint, the returned model can be exported with the model's own export() method.
Note
This function is intended for loading models for inference in Python. To continue training or to fine-tune a model, use metatrain's command-line interface.
- Parameters:
path (str | Path) – local or remote path to a model. For supported URL schemes, see urllib.request.
extensions_directory (str | Path | None) – path to a directory containing all extensions required by an exported model
hf_token (str | None) – HuggingFace API token to download (private) models from HuggingFace
- Raises:
ValueError – if path is a YAML options file rather than a model
ValueError – if no architecture_name is found in the checkpoint
ValueError – if the architecture_name is not found among the available architectures
- Return type:
Any
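Combining load_model with is_exported_file gives a small helper that always returns an exported model, converting checkpoints with their own export() method as described above. load_exported_model is hypothetical; it assumes a local path (is_exported_file takes a file path) and that the architecture's export() can be called without arguments, which may not hold for every architecture.

```python
from pathlib import Path
from typing import Any, Optional, Union


def load_exported_model(
    path: Union[str, Path],
    extensions_directory: Optional[Union[str, Path]] = None,
) -> Any:
    """Load a local model file and return it in exported form."""
    # Imported lazily so that this helper can be defined without metatrain installed.
    from metatrain.utils.io import is_exported_file, load_model

    model = load_model(path, extensions_directory=extensions_directory)
    if not is_exported_file(str(path)):
        # The file was a checkpoint: convert the model with its own export() method.
        # Assumption: export() needs no arguments for this architecture.
        model = model.export()
    return model
```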
- metatrain.utils.io.model_from_checkpoint(path: str | Path, context=typing.Literal['restart', 'finetune', 'export']) → Module
Load the checkpoint at the given path and create the corresponding model instance. The model architecture is determined from information stored inside the checkpoint.
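The architecture lookup can be pictured as a registry keyed by the architecture name stored in the checkpoint. In this sketch the checkpoint is an already-loaded dict with an assumed architecture_name key, and registry maps names to hypothetical builder callables; the two ValueError cases mirror those listed for load_model. It is an illustration only, not metatrain's code.

```python
from typing import Any, Callable, Dict, Literal

Context = Literal["restart", "finetune", "export"]
# Hypothetical builder signature: (checkpoint, context) -> model instance.
Builder = Callable[[Dict[str, Any], str], Any]


def model_from_checkpoint_dict(
    checkpoint: Dict[str, Any],
    registry: Dict[str, Builder],
    context: Context,
) -> Any:
    """Pick the builder named in the checkpoint and let it create the model."""
    name = checkpoint.get("architecture_name")
    if name is None:
        raise ValueError("no architecture_name found in the checkpoint")
    if name not in registry:
        raise ValueError(
            f"architecture {name!r} is not available; known: {sorted(registry)}"
        )
    return registry[name](checkpoint, context)
```

Keeping the architecture name inside the checkpoint means the caller never has to know which model class produced the file.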
- metatrain.utils.io.trainer_from_checkpoint(path: str | Path, context: Literal['restart', 'finetune', 'export'], hypers: Dict[str, Any]) → Any
Load the checkpoint at the given path and create the corresponding trainer instance. The architecture is determined from information stored inside the checkpoint.
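The context argument is typed as Literal['restart', 'finetune', 'export'], which static type checkers enforce but Python does not check at runtime. A minimal sketch of an explicit runtime check using typing.get_args follows; validate_context is a hypothetical helper, not part of metatrain.utils.io.

```python
from typing import Literal, get_args

Context = Literal["restart", "finetune", "export"]
# get_args() returns the allowed literal values as a tuple.
VALID_CONTEXTS = get_args(Context)


def validate_context(context: str) -> str:
    """Return `context` unchanged if it is allowed, otherwise raise ValueError."""
    if context not in VALID_CONTEXTS:
        raise ValueError(
            f"context must be one of {VALID_CONTEXTS}, got {context!r}"
        )
    return context
```

Deriving VALID_CONTEXTS from the Literal type keeps the runtime check and the type annotation from drifting apart.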