Dataset

class metatrain.utils.data.dataset.DatasetInfo(length_unit: str | None, atomic_types: List[int], targets: Dict[str, TargetInfo], extra_data: Dict[str, TargetInfo] | None = None)

Bases: object

A class that contains information about datasets.

This class is used to communicate additional dataset details to the training functions of the individual models.

Parameters:
  • length_unit (str | None) – Unit of length used in the dataset. Examples are "angstrom" or "nanometer". If None, the unit will be set to the empty string.

  • atomic_types (List[int]) – List containing all integer atomic types present in the dataset. atomic_types will be stored as a sorted list of unique atomic types.

  • targets (Dict[str, TargetInfo]) – Information about targets in the dataset.

  • extra_data (Dict[str, TargetInfo] | None) – Optional dictionary containing additional data that is not used as a target, but is still relevant to the dataset.
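The normalization described above (a None length unit becomes the empty string, and atomic types are stored sorted and deduplicated) can be illustrated with a small stand-in. This is a minimal sketch, not metatrain's implementation: the class name DatasetInfoSketch is hypothetical, targets are placeholder strings instead of real TargetInfo objects, and treating a missing extra_data as an empty dict is an assumption.

```python
from typing import Any, Dict, List, Optional


class DatasetInfoSketch:
    """Hypothetical stand-in mirroring the documented constructor behaviour."""

    def __init__(
        self,
        length_unit: Optional[str],
        atomic_types: List[int],
        targets: Dict[str, Any],
        extra_data: Optional[Dict[str, Any]] = None,
    ):
        # documented: None is replaced by the empty string
        self.length_unit = length_unit if length_unit is not None else ""
        # documented: stored as a sorted list of unique atomic types
        self.atomic_types = sorted(set(atomic_types))
        self.targets = targets
        # assumption: a missing extra_data is treated as an empty dict
        self.extra_data = extra_data if extra_data is not None else {}


info = DatasetInfoSketch(None, [8, 1, 1, 6], {"energy": "TargetInfo placeholder"})
print(repr(info.length_unit), info.atomic_types)  # '' [1, 6, 8]
```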

property atomic_types: List[int]

Sorted list of unique integer atomic types.

copy() → DatasetInfo

Return a shallow copy of the DatasetInfo.

Return type:

DatasetInfo

update(other: DatasetInfo) → None

Update this instance with the union of itself and other.

Raises:

ValueError – If the length units of the two instances differ.

Parameters:

other (DatasetInfo)

Return type:

None

union(other: DatasetInfo) → DatasetInfo

Return the union of this instance with other.

Parameters:

other (DatasetInfo)

Return type:

DatasetInfo
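The relationship between copy(), update() and union() can be sketched as follows: update() merges in place and rejects mismatched length units, while union() returns a new object and leaves the original untouched. This is a hypothetical stand-in (InfoSketch is not a metatrain class); expressing union() as "update a shallow copy" is an assumption, and the simple key-wise merge of targets does not show how the real class handles conflicting entries.

```python
from typing import Any, Dict, List


class InfoSketch:
    """Hypothetical stand-in; not metatrain's DatasetInfo."""

    def __init__(self, length_unit: str, atomic_types: List[int], targets: Dict[str, Any]):
        self.length_unit = length_unit
        self.atomic_types = sorted(set(atomic_types))
        self.targets = targets

    def copy(self) -> "InfoSketch":
        # shallow copy: new list/dict containers, but the target values are shared
        return InfoSketch(self.length_unit, list(self.atomic_types), dict(self.targets))

    def update(self, other: "InfoSketch") -> None:
        # documented behaviour: differing length units raise ValueError
        if self.length_unit != other.length_unit:
            raise ValueError(
                f"length units differ: {self.length_unit!r} vs {other.length_unit!r}"
            )
        self.atomic_types = sorted(set(self.atomic_types) | set(other.atomic_types))
        # assumption: targets are merged key-wise (conflict handling not shown)
        self.targets = {**self.targets, **other.targets}

    def union(self, other: "InfoSketch") -> "InfoSketch":
        # assumption: union updates a copy, so self is left unchanged
        merged = self.copy()
        merged.update(other)
        return merged


a = InfoSketch("angstrom", [1, 8], {"energy": "e-info"})
b = InfoSketch("angstrom", [6, 1], {"forces": "f-info"})
c = a.union(b)
print(c.atomic_types, sorted(c.targets))  # [1, 6, 8] ['energy', 'forces']
print(a.atomic_types)  # [1, 8]
```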

metatrain.utils.data.dataset.get_stats(dataset: Dataset | Subset, dataset_info: DatasetInfo) → str

Returns the statistics of a dataset or subset as a string.

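Parameters:
  • dataset (Dataset | Subset) – The dataset or subset to describe.

  • dataset_info (DatasetInfo) – Information about the dataset.

Return type: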

str

metatrain.utils.data.dataset.get_atomic_types(datasets: Dataset | List[Dataset]) → List[int]

List of all atomic types present in a dataset or list of datasets.

Parameters:

datasets (Dataset | List[Dataset]) – the dataset, or list of datasets

Returns:

sorted list of all atomic types present in the datasets

Return type:

List[int]
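The documented result (the sorted union of atomic types over one dataset or a list of datasets) can be sketched with toy data. The ToyDataset class and its layout (each sample is simply a list of atomic types) are hypothetical stand-ins for real metatrain datasets, whose samples hold metatomic System objects.

```python
from typing import List, Union


class ToyDataset(list):
    """Hypothetical dataset: a list of samples, each a list of atomic types."""


def get_atomic_types_sketch(datasets: Union[ToyDataset, List[ToyDataset]]) -> List[int]:
    # accept a single dataset as well as a list of datasets
    if isinstance(datasets, ToyDataset):
        datasets = [datasets]
    types = set()
    for dataset in datasets:
        for sample_types in dataset:
            types.update(sample_types)
    # the documented result is a sorted list
    return sorted(types)


water = ToyDataset([[8, 1, 1], [8, 1, 1]])
methane = ToyDataset([[6, 1, 1, 1, 1]])
print(get_atomic_types_sketch(water))  # [1, 8]
print(get_atomic_types_sketch([water, methane]))  # [1, 6, 8]
```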

metatrain.utils.data.dataset.get_all_targets(datasets: Dataset | List[Dataset]) → List[str]

Sorted list of all unique targets present in a dataset or list of datasets.

Parameters:

datasets (Dataset | List[Dataset]) – the dataset(s).

Returns:

Sorted list of all targets present in the dataset(s).

Return type:

List[str]
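A similar sketch for target names, again under a hypothetical sample layout: each sample is a dict with a "system" entry plus one entry per target, and every key other than "system" is counted as a target. The real function may obtain target names differently; only the sorted, deduplicated result follows the documentation.

```python
from typing import Any, Dict, List

Sample = Dict[str, Any]  # hypothetical: {"system": ..., <target_name>: ...}


def get_all_targets_sketch(datasets) -> List[str]:
    # a single dataset is a list of sample dicts; wrap it so both cases look alike
    if datasets and isinstance(datasets[0], dict):
        datasets = [datasets]
    names = set()
    for dataset in datasets:
        for sample in dataset:
            # every key except the hypothetical "system" entry is a target
            names.update(key for key in sample if key != "system")
    return sorted(names)


train = [{"system": "H2O", "energy": 1.0}, {"system": "H2O", "energy": 2.0}]
extra = [{"system": "CH4", "energy": 3.0, "forces": [[0.0, 0.0, 0.0]]}]
print(get_all_targets_sketch(train))  # ['energy']
print(get_all_targets_sketch([train, extra]))  # ['energy', 'forces']
```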

class metatrain.utils.data.dataset.CollateFn(target_keys: List[str], join_kwargs: Dict[str, Any] | None = None)

Bases: object

Parameters:
  • target_keys (List[str])

  • join_kwargs (Dict[str, Any] | None)

metatrain.utils.data.dataset.check_datasets(train_datasets: List[Dataset], val_datasets: List[Dataset])

Check that the training and validation sets are compatible with one another.

These checks will not fit every use case, but most models should be able to use this function.

Parameters:
  • train_datasets (List[Dataset]) – A list of training datasets to check.

  • val_datasets (List[Dataset]) – A list of validation datasets to check.

Raises:
  • TypeError – If the dtypes within the datasets are inconsistent.

  • ValueError – If val_datasets contain a target that is not present in train_datasets.

  • ValueError – If the validation set contains chemical species that are not present in the training set.
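The kinds of checks listed above can be restated on toy summaries. This is a minimal sketch, not metatrain's code: each dataset is represented by a hypothetical dict with a "dtype" string, a set of atomic "types", and a set of "targets", whereas the real function inspects the samples of actual Dataset objects.

```python
from typing import Dict, List


def check_datasets_sketch(train: List[Dict], val: List[Dict]) -> None:
    """Hypothetical re-statement of the documented checks on toy dataset summaries."""
    # all datasets must share a single dtype
    dtypes = {d["dtype"] for d in train + val}
    if len(dtypes) > 1:
        raise TypeError(f"inconsistent dtypes: {sorted(dtypes)}")
    train_targets = set().union(*(d["targets"] for d in train))
    train_types = set().union(*(d["types"] for d in train))
    for d in val:
        # validation targets must also appear in the training set
        missing_targets = d["targets"] - train_targets
        if missing_targets:
            raise ValueError(f"validation targets not in training set: {sorted(missing_targets)}")
        # validation species must also appear in the training set
        missing_types = d["types"] - train_types
        if missing_types:
            raise ValueError(f"validation species not in training set: {sorted(missing_types)}")


train = [{"dtype": "float64", "types": {1, 8}, "targets": {"energy"}}]
ok_val = [{"dtype": "float64", "types": {1}, "targets": {"energy"}}]
bad_val = [{"dtype": "float64", "types": {1, 6}, "targets": {"energy"}}]
check_datasets_sketch(train, ok_val)  # passes silently
try:
    check_datasets_sketch(train, bad_val)
except ValueError as err:
    print(err)  # validation species not in training set: [6]
```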

class metatrain.utils.data.dataset.DiskDataset(path: str | Path, fields: List[str] | None = None)

Bases: Dataset

A class representing a dataset stored on disk.

The dataset is stored in a zip file, where each sample is stored in a separate directory. The directory’s name is the index of the sample (e.g. 0/), and the files in the directory are the system (system.mta) and the targets (each named <target_name>.mts). These are metatomic.torch.System and metatensor.torch.TensorMap objects, respectively.

Such a dataset can be created conveniently using the DiskDatasetWriter class.

Parameters:
  • path (str | Path) – Path to the zip file containing the dataset.

  • fields (List[str] | None) – List of fields to read from the dataset. If None, all fields will be read.
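The on-disk layout described above (one numbered directory per sample, holding system.mta and one <target_name>.mts file per target) can be illustrated with the standard zipfile module. This sketch writes placeholder bytes instead of serialized metatomic System and metatensor TensorMap objects, and summarize_layout is a hypothetical helper, not part of metatrain; real archives should be produced with DiskDatasetWriter.

```python
import io
import zipfile
from typing import Dict


def summarize_layout(zip_bytes: bytes) -> Dict[str, object]:
    """Hypothetical helper: infer sample count and target names from the archive layout."""
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
        names = zf.namelist()
    # the first path component is the sample index, e.g. "0/system.mta" -> "0"
    samples = {name.split("/", 1)[0] for name in names}
    # every "<index>/<target_name>.mts" entry names a target
    targets = {
        name.split("/", 1)[1][: -len(".mts")]
        for name in names
        if name.endswith(".mts")
    }
    return {"n_samples": len(samples), "targets": sorted(targets)}


# toy archive with placeholder bytes in place of the real serialized objects
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as zf:
    for index in range(3):
        zf.writestr(f"{index}/system.mta", b"placeholder")
        zf.writestr(f"{index}/energy.mts", b"placeholder")

print(summarize_layout(buffer.getvalue()))  # {'n_samples': 3, 'targets': ['energy']}
```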

get_target_info(target_config: DictConfig) → Dict[str, TargetInfo]

Get information about the targets in the dataset.

Parameters:

target_config (DictConfig) – The target configuration provided by the user through the YAML file.

Return type:

Dict[str, TargetInfo]