Fine-tuning
Warning
This section of the documentation is only relevant for the PET model so far.
This section describes the process of fine-tuning a pre-trained model to adapt it to new tasks or datasets. Fine-tuning is a common technique in machine learning, where a model trained on a large dataset is further trained on a smaller dataset to improve its performance on specific tasks. So far, fine-tuning capabilities are only available for the PET model.
Note
Please note that the fine-tuning recommendations in this section are not universal and require testing on your specific dataset to achieve the best results. You might need to experiment with different fine-tuning strategies depending on your needs.
Basic Fine-tuning
The basic way to fine-tune a model is to use the mtt train command with the available pre-trained model defined in an options.yaml file. In this case, all the weights of the model will be adapted to the new dataset. This approach is similar to the training continuation described in the Checkpoints section, but in contrast to training continuation, the optimizer and scheduler states will be reset. You can still adjust the training hyperparameters in the options.yaml file, but the model architecture will be taken from the checkpoint.
We recommend using a lower learning rate than the one used for the original training, as this helps stabilize the training process. For example, if the default learning rate is 1e-4, you can set it to 1e-5 or even lower, using the following in the options.yaml file:
architecture:
  training:
    learning_rate: 1e-5
Please note that in the case of basic fine-tuning, the composition model weights will be taken from the checkpoint and not adapted to the new dataset.
The basic fine-tuning strategy is a good choice when the level of theory used for the original training is the same as, or at least similar to, the one used for the new dataset. However, since this is not always the case, we also provide more advanced fine-tuning strategies, described below.
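For reference, a complete minimal options.yaml for basic fine-tuning might look like the sketch below. The dataset path and target settings are placeholders, and the finetune block with method: "full" is an assumption about how your metatrain version expects the checkpoint to be specified, so check it against the Checkpoints section of the documentation:

architecture:
  name: pet  # assumption: the architecture name as registered in your metatrain version
  training:
    learning_rate: 1e-5  # lower than the default 1e-4 for stability
    finetune:
      method: "full"  # assumption: adapt all model weights from the checkpoint
      read_from: path/to/checkpoint.ckpt

training_set:
  systems:
    read_from: new_dataset.xyz  # placeholder path to the new dataset
  targets:
    energy:
      key: "energy"  # placeholder key under which the energies are stored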
Fine-tuning Model Heads
Adapting all the model weights to a new dataset is not always the best approach. If the new dataset consists of the same or similar data computed with a slightly different level of theory compared to the pre-trained model's dataset, you might want to keep the learned representations of the crystal structures and only adapt the readout layers (i.e. the model heads) to the new dataset.
In this case, the mtt train command needs to be accompanied by specific training options in the options.yaml file. The following options need to be set:
architecture:
  training:
    finetune:
      method: "heads"
      read_from: path/to/checkpoint.ckpt
      config:
        head_modules: ['node_heads', 'edge_heads']
        last_layer_modules: ['node_last_layers', 'edge_last_layers']
The method parameter specifies the fine-tuning method to be used, and the read_from parameter specifies the path to the pre-trained model checkpoint. The head_modules and last_layer_modules parameters specify the modules to be fine-tuned. Here, the node_* and edge_* modules represent the parts of the model readout layers related to atom-based and bond-based features, respectively. The *_last_layers modules are the last layers of the corresponding heads, implemented as multi-layer perceptrons (MLPs). You can select different combinations of the node and edge heads and last layers to be fine-tuned.
We recommend first running the fine-tuning with all the modules listed above included, and experimenting with different combinations if needed; one such variant is sketched below. You might also consider using a lower learning rate, e.g. 1e-5 or even lower, to stabilize the training process.
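For example, a more conservative variant that keeps the full heads frozen and adapts only the last layers might look like the following. This is a sketch: whether an empty list is an accepted value for head_modules is an assumption to verify with your metatrain version:

architecture:
  training:
    learning_rate: 1e-5
    finetune:
      method: "heads"
      read_from: path/to/checkpoint.ckpt
      config:
        head_modules: []  # assumption: an empty list leaves the full heads frozen
        last_layer_modules: ['node_last_layers', 'edge_last_layers']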
LoRA Fine-tuning
If a conceptually new type of structure is introduced in the new dataset, tuning only the model heads might not be sufficient. In this case, you might need to adapt the internal representations of the crystal structures, which can be done using the LoRA technique. However, with this approach the model heads will not be adapted to the new dataset, so the level of theory should be consistent with the one used for the pre-trained model.
What is LoRA?
LoRA (Low-Rank Adaptation) is a Parameter-Efficient Fine-Tuning (PEFT) technique used to adapt pre-trained models to new tasks by introducing low-rank matrices into the model's architecture.
Given a pre-trained model with the weight matrix $W_0 \in \mathbb{R}^{d \times k}$, LoRA keeps $W_0$ frozen and adds a trainable low-rank update:

$$W = W_0 + \frac{\alpha}{r} B A$$

where $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$ are the low-rank matrices, $r \ll \min(d, k)$ is the rank, and $\alpha$ is a scaling factor controlling the strength of the adaptation.
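To see why this is parameter-efficient, consider a square weight matrix with $d = k = 256$: updating it fully would train $d \cdot k = 65{,}536$ parameters, while a rank $r = 4$ LoRA update trains only $r(d + k) = 2{,}048$ parameters, roughly 3% of the full count. These dimensions are illustrative, not the actual PET layer sizes.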
To use LoRA for fine-tuning, you need to provide the pre-trained model checkpoint with the mtt train command and specify the LoRA parameters in the options.yaml file:
architecture:
  training:
    finetune:
      method: "lora"
      read_from: path/to/pre-trained-model.ckpt
      config:
        alpha: 0.1
        rank: 4
These parameters control the rank of the low-rank matrices introduced by LoRA (rank) and the regularization factor for the low-rank matrices (alpha). By selecting the LoRA rank and the regularization factor, you can control the amount of adaptation to the new dataset. Using lower values of the rank and the regularization factor leads to a more conservative adaptation, which can help balance the performance of the model on the original and new datasets.
We recommend starting with the LoRA parameters listed above and experimenting with different values if needed; an example of a stronger setting is sketched below. You might also consider using a lower learning rate, e.g. 1e-5 or even lower, to stabilize the training process.
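For instance, if the model underfits the new dataset, you could allow a stronger adaptation by increasing both parameters. The values below are illustrative starting points for such an experiment, not tested defaults:

architecture:
  training:
    learning_rate: 1e-5
    finetune:
      method: "lora"
      read_from: path/to/pre-trained-model.ckpt
      config:
        alpha: 0.5  # illustrative: larger scaling factor for a stronger update
        rank: 8     # illustrative: higher rank allows a more expressive update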
Fine-tuning on a new level of theory
If the new dataset is computed with a completely different level of theory compared to the pre-trained model, which implies, for instance, different composition energies, or if you want to fine-tune the model on a completely new target, you might need to consider the transfer learning approach and introduce a new target in the options.yaml file, as sketched below. More details about this approach can be found in the Transfer Learning section of the documentation.
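As a rough illustration, introducing a new target could look like the snippet below. The mtt:: prefix for custom (non-energy) target names follows the convention used elsewhere in the metatrain documentation, but the target name, dataset path, key, and unit here are placeholders; see the Transfer Learning section for the authoritative options:

training_set:
  systems:
    read_from: new_dataset.xyz  # placeholder path to the new dataset
  targets:
    mtt::new_target:  # hypothetical custom target name
      key: "new_target"  # placeholder key under which the target is stored
      unit: "eV"  # placeholder unit of the new target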