.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/programmatic/llpr/llpr.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_programmatic_llpr_llpr.py:


Computing LLPR uncertainties
============================

This tutorial demonstrates how to use an already trained and exported model from
Python. It computes the local prediction rigidity (`LPR`_) for every atom of a
single ethanol molecule, using the last-layer prediction rigidity (`LLPR`_)
approximation.

.. _LPR: https://pubs.acs.org/doi/10.1021/acs.jctc.3c00704

.. _LLPR: https://arxiv.org/html/2403.02251v1

The model was trained using the following training options.

.. literalinclude:: options.yaml
   :language: yaml

You can train the same model yourself with:

.. literalinclude:: train.sh
   :language: bash

A detailed step-by-step introduction on how to train a model is provided in the
:ref:`label_basic_usage` tutorial.

.. GENERATED FROM PYTHON SOURCE LINES 29-35

.. code-block:: Python


    import torch

    from metatrain.utils.io import load_model


.. GENERATED FROM PYTHON SOURCE LINES 36-40

Models can be loaded using the :func:`metatrain.utils.io.load_model` function.
For already exported models, the function requires the path to the exported
model and, for many models, also the path to the respective extensions
directory. Both are produced during the training process.

.. GENERATED FROM PYTHON SOURCE LINES 41-45

.. code-block:: Python


    model = load_model("model.pt", extensions_directory="extensions/")


.. GENERATED FROM PYTHON SOURCE LINES 46-49

In metatrain, a Dataset is composed of a list of systems and a dictionary of
targets. The following lines illustrate how to read systems and targets from
xyz files, and how to create a Dataset object from them.
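To make the "list of systems plus dictionary of targets" structure concrete, the
sketch below mimics it in plain Python: sample ``i`` gathers the ``i``-th entry of
every field, and a collate function turns a list of samples back into per-field
lists for one batch. ``ToyDataset``, ``toy_collate`` and ``batches`` are
hypothetical helpers written only for illustration. They are not metatrain's
``Dataset`` or ``collate_fn``, and the string "systems" are placeholders.

.. code-block:: python

    from typing import Any, Dict, Iterator, List


    class ToyDataset:
        """Hypothetical dataset built from a dict of equal-length lists.

        Loosely mirrors the idea of ``Dataset.from_dict({"system": ..., **targets})``:
        sample ``i`` collects the ``i``-th entry of every field.
        """

        def __init__(self, fields: Dict[str, List[Any]]):
            lengths = {len(values) for values in fields.values()}
            if len(lengths) != 1:
                raise ValueError("all fields must have the same, non-zero count")
            self._fields = fields
            self._length = lengths.pop()

        def __len__(self) -> int:
            return self._length

        def __getitem__(self, index: int) -> Dict[str, Any]:
            return {name: values[index] for name, values in self._fields.items()}


    def toy_collate(samples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
        """Hypothetical collate function: one list per field for the whole batch."""
        return {name: [sample[name] for sample in samples] for name in samples[0]}


    def batches(dataset: ToyDataset, batch_size: int) -> Iterator[Dict[str, List[Any]]]:
        """Yield collated batches in dataset order (like ``shuffle=False``)."""
        for start in range(0, len(dataset), batch_size):
            stop = min(start + batch_size, len(dataset))
            yield toy_collate([dataset[i] for i in range(start, stop)])


    # Placeholder systems and energies, purely illustrative.
    systems = [f"system_{i}" for i in range(25)]
    energies = [float(i) for i in range(25)]
    dataset = ToyDataset({"system": systems, "energy": energies})

    batch_sizes = [len(batch["system"]) for batch in batches(dataset, batch_size=10)]
    print(batch_sizes)  # prints [10, 10, 5]

With 25 samples and a batch size of 10, the last batch is simply shorter, which
is also the default behaviour of a PyTorch ``DataLoader`` (``drop_last=False``).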
.. GENERATED FROM PYTHON SOURCE LINES 50-92

.. code-block:: Python


    from metatrain.utils.data import Dataset, read_systems, read_targets  # noqa: E402
    from metatrain.utils.neighbor_lists import (  # noqa: E402
        get_requested_neighbor_lists,
        get_system_with_neighbor_lists,
    )


    qm9_systems = read_systems("qm9_reduced_100.xyz")

    target_config = {
        "energy": {
            "quantity": "energy",
            "read_from": "ethanol_reduced_100.xyz",
            "reader": "ase",
            "key": "energy",
            "unit": "kcal/mol",
            "type": "scalar",
            "per_atom": False,
            "num_subtargets": 1,
            "forces": False,
            "stress": False,
            "virial": False,
        },
    }
    targets, _ = read_targets(target_config)

    requested_neighbor_lists = get_requested_neighbor_lists(model)
    qm9_systems = [
        get_system_with_neighbor_lists(system, requested_neighbor_lists)
        for system in qm9_systems
    ]
    dataset = Dataset.from_dict({"system": qm9_systems, **targets})

    # We also load a single ethanol molecule on which we will compute properties.
    # This system is loaded without targets, as we are only interested in the LPR
    # values.
    ethanol_system = read_systems("ethanol_reduced_100.xyz")[0]
    ethanol_system = get_system_with_neighbor_lists(
        ethanol_system, requested_neighbor_lists
    )


.. GENERATED FROM PYTHON SOURCE LINES 93-95

The dataset works with the standard ``torch.utils.data`` tools. For example, it
can be used to create a ``DataLoader`` object.

.. GENERATED FROM PYTHON SOURCE LINES 96-108

.. code-block:: Python


    from metatrain.utils.data import collate_fn  # noqa: E402


    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=10,
        shuffle=False,
        collate_fn=collate_fn,
    )


.. GENERATED FROM PYTHON SOURCE LINES 109-112

We now wrap the model in an LLPRUncertaintyModel object, which allows us to
compute prediction rigidity metrics. These metrics are useful for uncertainty
quantification and model introspection.
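Conceptually, the LLPR treats the last layer of the model as a linear model on
top of fixed features. The numpy sketch below illustrates, under simplifying
assumptions, the ideas behind ``compute_covariance``,
``compute_inverse_covariance`` and ``calibrate`` used in the next code block: it
builds ``F^T F`` from a synthetic feature matrix ``F`` (one row per training
structure), inverts it after adding a small diagonal regularizer, and takes the
rigidity of a prediction with features ``f`` as the inverse of
``f^T (F^T F + reg * I)^-1 f`` (up to constant factors). For calibration it
assumes a single global scale factor fitted by maximum likelihood on a
calibration set. All arrays are random stand-ins rather than model outputs, and
metatrain's ``LLPRUncertaintyModel`` may differ in normalization and in the
exact calibration procedure.

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)

    # Synthetic stand-ins (assumption): one row of last-layer features per
    # training structure, e.g. atomic features summed over each structure.
    n_structures, n_features = 100, 8
    train_features = rng.normal(size=(n_structures, n_features))

    # Analogue of compute_covariance: accumulate F^T F over the training set.
    covariance = train_features.T @ train_features

    # Analogue of compute_inverse_covariance(regularizer=...): add a small value
    # to the diagonal before inverting to keep the matrix well conditioned.
    regularizer = 1e-4
    inv_covariance = np.linalg.inv(covariance + regularizer * np.eye(n_features))


    def raw_variance(features: np.ndarray) -> np.ndarray:
        """Compute f^T (F^T F + reg * I)^-1 f for each row f of ``features``."""
        return np.einsum("ij,jk,ik->i", features, inv_covariance, features)


    # Prediction rigidity is the inverse of this quantity: per-atom features give
    # a local rigidity (LPR), summed per-structure features a total one (TPR).
    atom_features = rng.normal(size=(9, n_features))  # hypothetical 9-atom molecule
    lpr = 1.0 / raw_variance(atom_features)
    tpr = 1.0 / raw_variance(atom_features.sum(axis=0, keepdims=True))

    # Analogue of calibrate (assumption: one global scale factor alpha^2).
    # For Gaussian errors with variance alpha^2 * s^2, the maximum-likelihood
    # estimate is alpha^2 = mean(residual^2 / s^2) over a calibration set.
    calib_features = rng.normal(size=(50, n_features))
    calib_residuals = rng.normal(scale=0.1, size=50)  # synthetic prediction errors
    alpha_sq = np.mean(calib_residuals**2 / raw_variance(calib_features))
    calibrated_std = np.sqrt(alpha_sq * raw_variance(atom_features))

Because per-atom rows and the summed row go through the same inverse matrix, a
single inversion serves both the local and the total rigidity, which loosely
mirrors the ``per_atom`` switch of ``ModelOutput`` used later in this example.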
.. GENERATED FROM PYTHON SOURCE LINES 113-136

.. code-block:: Python


    from metatensor.torch.atomistic import (  # noqa: E402
        MetatensorAtomisticModel,
        ModelMetadata,
    )

    from metatrain.utils.llpr import LLPRUncertaintyModel  # noqa: E402


    llpr_model = LLPRUncertaintyModel(model)
    llpr_model.compute_covariance(dataloader)
    llpr_model.compute_inverse_covariance(regularizer=1e-4)

    # calibrate on the same dataset for simplicity. In reality, a separate
    # calibration/validation dataset should be used.
    llpr_model.calibrate(dataloader)

    exported_model = MetatensorAtomisticModel(
        llpr_model.eval(),
        ModelMetadata(),
        llpr_model.capabilities,
    )


.. GENERATED FROM PYTHON SOURCE LINES 137-141

We can now use the model to compute the LPR for every atom in the ethanol
molecule. To do so, we create a ``ModelEvaluationOptions`` object, which is used
to request specific outputs from the model. In this case, we request the
uncertainty in the atomic energy predictions.

.. GENERATED FROM PYTHON SOURCE LINES 142-163

.. code-block:: Python


    from metatensor.torch.atomistic import ModelEvaluationOptions, ModelOutput  # noqa: E402


    evaluation_options = ModelEvaluationOptions(
        length_unit="angstrom",
        outputs={
            # request the uncertainty in the atomic energy predictions
            "energy": ModelOutput(per_atom=True),  # needed to request the uncertainties
            "mtt::aux::energy_uncertainty": ModelOutput(per_atom=True),
            # `per_atom=False` would return the total uncertainty for the system,
            # or (the inverse of) the TPR (total prediction rigidity)

            # You can also request other outputs from the model here, for example:
            # "mtt::aux::energy_last_layer_features": ModelOutput(per_atom=True),
        },
        selected_atoms=None,
    )

    outputs = exported_model([ethanol_system], evaluation_options, check_consistency=False)
    lpr = outputs["mtt::aux::energy_uncertainty"].block().values.detach().cpu().numpy()


.. GENERATED FROM PYTHON SOURCE LINES 164-166

We can now visualize the LPR values using the ``plot_atoms`` function from
``ase.visualize.plot``.
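The plotting code colours the atoms on a logarithmic scale, since the LPR values
can span more than an order of magnitude. As a minimal illustration, the
pure-Python function below reproduces the mapping that
``matplotlib.colors.LogNorm`` applies to values inside ``[vmin, vmax]`` (clipping
and masking of out-of-range values are ignored here). ``log_normalize`` is a
hypothetical helper, and the input values are illustrative, not taken from the
figure.

.. code-block:: python

    import math


    def log_normalize(value: float, vmin: float, vmax: float) -> float:
        """Map a value in [vmin, vmax] to [0, 1] on a logarithmic scale.

        Computes (log x - log vmin) / (log vmax - log vmin); the base of the
        logarithm cancels, so this matches a log10-based normalization.
        """
        if not 0 < vmin < vmax:
            raise ValueError("need 0 < vmin < vmax")
        return (math.log(value) - math.log(vmin)) / (math.log(vmax) - math.log(vmin))


    # Illustrative values: both ends of a decade and their geometric mean.
    lpr_values = [1e10, 1e11, math.sqrt(1e10 * 1e11)]
    vmin, vmax = min(lpr_values), max(lpr_values)
    print([round(log_normalize(x, vmin, vmax), 3) for x in lpr_values])
    # prints [0.0, 1.0, 0.5]

On a log scale the geometric mean, not the arithmetic mean, lands in the middle
of the colormap, so atoms whose rigidities differ by the same factor are
separated by the same colour distance.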
.. GENERATED FROM PYTHON SOURCE LINES 167-189

.. code-block:: Python


    import ase.io  # noqa: E402
    import matplotlib.pyplot as plt  # noqa: E402
    from ase.visualize.plot import plot_atoms  # noqa: E402
    from matplotlib.colors import LogNorm  # noqa: E402


    structure = ase.io.read("ethanol_reduced_100.xyz")
    norm = LogNorm(vmin=min(lpr), vmax=max(lpr))
    colormap = plt.get_cmap("viridis")
    colors = colormap(norm(lpr))
    ax = plot_atoms(structure, colors=colors, rotation="180x,0y,0z")
    custom_ticks = [1e10, 2e10, 5e10, 1e11, 2e11]
    cbar = plt.colorbar(
        plt.cm.ScalarMappable(norm=norm, cmap=colormap),
        ax=ax,
        label="LPR",
        ticks=custom_ticks,
    )
    cbar.ax.set_yticklabels([f"{tick:.0e}" for tick in custom_ticks])
    cbar.minorticks_off()
    plt.show()


.. image-sg:: /examples/programmatic/llpr/images/sphx_glr_llpr_001.png
   :alt: llpr
   :srcset: /examples/programmatic/llpr/images/sphx_glr_llpr_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 8.710 seconds)