Dataset

class metatrain.utils.data.dataset.DatasetInfo(length_unit: str, atomic_types: List[int], targets: Dict[str, TargetInfo])[source]

Bases: object

A class that contains information about datasets.

This class is used to communicate additional dataset details to the training functions of the individual models.

Parameters:
  • length_unit (str) – Unit of length used in the dataset. Examples are "angstrom" or "nanometer".

  • atomic_types (List[int]) – List containing all integer atomic types present in the dataset. atomic_types will be stored as a sorted list of unique atomic types.

  • targets (Dict[str, TargetInfo]) – Information about targets in the dataset.

property atomic_types: List[int]

Sorted list of unique integer atomic types.

copy() DatasetInfo[source]

Return a shallow copy of the DatasetInfo.

Return type:

DatasetInfo

update(other: DatasetInfo) None[source]

Update this instance with the union of itself and other.

Raises:

ValueError – If the length_units are different.

Parameters:

other (DatasetInfo)

Return type:

None

union(other: DatasetInfo) DatasetInfo[source]

Return the union of this instance with other.

Parameters:

other (DatasetInfo)

Return type:

DatasetInfo
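
The difference between update() and union() can be illustrated with a small stand-in class. This is a sketch, not metatrain's implementation: InfoSketch is a hypothetical name, targets are reduced to plain strings instead of TargetInfo objects, and the assumption that update() modifies the instance in place while union() returns a new object is inferred only from the return types listed above. How conflicting targets with the same name are handled is not covered here.

```python
from typing import Dict, List


class InfoSketch:
    """Hypothetical stand-in for DatasetInfo (not the metatrain class)."""

    def __init__(self, length_unit: str, atomic_types: List[int], targets: Dict[str, str]):
        self.length_unit = length_unit
        self.atomic_types = atomic_types  # goes through the property setter below
        self.targets = targets

    @property
    def atomic_types(self) -> List[int]:
        return self._atomic_types

    @atomic_types.setter
    def atomic_types(self, value: List[int]) -> None:
        # Store the types as a sorted list of unique integers.
        self._atomic_types = sorted(set(value))

    def copy(self) -> "InfoSketch":
        return InfoSketch(self.length_unit, list(self.atomic_types), dict(self.targets))

    def update(self, other: "InfoSketch") -> None:
        # Assumed behaviour: merge `other` into this instance in place.
        if self.length_unit != other.length_unit:
            raise ValueError("length units differ")
        self.atomic_types = self.atomic_types + other.atomic_types
        self.targets = {**self.targets, **other.targets}

    def union(self, other: "InfoSketch") -> "InfoSketch":
        # Assumed behaviour: leave both operands untouched, return the merge.
        merged = self.copy()
        merged.update(other)
        return merged


a = InfoSketch("angstrom", [8, 1, 1], {"energy": "eV"})
b = InfoSketch("angstrom", [6, 1], {"forces": "eV/angstrom"})
merged = a.union(b)
print(merged.atomic_types)  # [1, 6, 8]
print(a.atomic_types)       # [1, 8]
```

In this sketch, union() is built on top of copy() and update(), so the unit check lives in one place.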

metatrain.utils.data.dataset.get_stats(dataset: Dataset | Subset, dataset_info: DatasetInfo) str[source]

Returns the statistics of a dataset or subset as a string.

Parameters:
  • dataset (Dataset | Subset)

  • dataset_info (DatasetInfo)

Return type:

str

metatrain.utils.data.dataset.get_atomic_types(datasets: Dataset | List[Dataset]) List[int][source]

List of all atomic types present in a dataset or list of datasets.

Parameters:

datasets (Dataset | List[Dataset]) – The dataset or list of datasets.

Returns:

Sorted list of all atomic types present in the dataset(s).

Return type:

List[int]

metatrain.utils.data.dataset.get_all_targets(datasets: Dataset | List[Dataset]) List[str][source]

Sorted list of all unique targets present in a dataset or list of datasets.

Parameters:

datasets (Dataset | List[Dataset]) – the dataset(s).

Returns:

Sorted list of all targets present in the dataset(s).

Return type:

List[str]

metatrain.utils.data.dataset.collate_fn(batch: List[Dict[str, Any]]) Tuple[List, Dict[str, TensorMap]][source]

Wraps group_and_join to return the data fields as a list of systems and a dictionary of named targets.

Parameters:

batch (List[Dict[str, Any]])

Return type:

Tuple[List, Dict[str, TensorMap]]

metatrain.utils.data.dataset.check_datasets(train_datasets: List[Dataset], val_datasets: List[Dataset])[source]

Check that the training and validation sets are compatible with one another.

These checks will not fit every use case, but most models are expected to be able to use this function.

Parameters:
  • train_datasets (List[Dataset]) – A list of training datasets to check.

  • val_datasets (List[Dataset]) – A list of validation datasets to check.

Raises:
  • TypeError – If the dtypes within the datasets are inconsistent.

  • ValueError – If val_datasets has a target that is not present in train_datasets.

  • ValueError – If the validation set contains chemical species or targets that are not present in the training set.
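
The two ValueError conditions can be sketched with plain Python sets. This is only an illustration under stated assumptions: check_compatible_sketch is a hypothetical helper, it takes target names and atomic types directly rather than Dataset objects, and it leaves out the dtype check entirely.

```python
from typing import Iterable


def check_compatible_sketch(
    train_targets: Iterable[str],
    val_targets: Iterable[str],
    train_types: Iterable[int],
    val_types: Iterable[int],
) -> None:
    """Hypothetical helper mirroring the ValueError conditions listed above."""
    # Targets that appear only in the validation data.
    missing_targets = set(val_targets) - set(train_targets)
    if missing_targets:
        raise ValueError(f"validation targets not in training set: {sorted(missing_targets)}")

    # Atomic types (chemical species) that appear only in the validation data.
    missing_types = set(val_types) - set(train_types)
    if missing_types:
        raise ValueError(f"validation atomic types not in training set: {sorted(missing_types)}")


# Compatible: every validation target and species also occurs in training.
check_compatible_sketch(["energy", "forces"], ["energy"], [1, 6, 8], [1, 8])

# Incompatible: atomic type 7 is only present in the validation data.
try:
    check_compatible_sketch(["energy"], ["energy"], [1, 6], [1, 7])
except ValueError as error:
    message = str(error)
```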

class metatrain.utils.data.dataset.DiskDataset(path: str | Path)[source]

Bases: Dataset

A class representing a dataset stored on disk.

The dataset is stored in a zip file, where each sample is stored in a separate directory. The directory’s name is the index of the sample (e.g. 0/), and the files in the directory are the system (system.mta) and the targets (each named <target_name>.mts). These are metatensor.torch.atomistic.System and metatensor.torch.TensorMap objects, respectively.

Such a dataset can be created conveniently using the DiskDatasetWriter class.

Parameters:

path (str | Path) – Path to the zip file containing the dataset.
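
The archive layout described above can be illustrated with the standard library alone. The sketch below builds an in-memory zip file with the same directory and file naming and then recovers the sample indices and target names from the entry names. The file contents are placeholder bytes, not valid serialized System or TensorMap objects, the target name "energy" is an arbitrary example, and none of this is how DiskDataset itself reads the file.

```python
import io
import zipfile

# Build an archive with two samples, each in its own directory ("0/", "1/").
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as zf:
    for index in range(2):
        zf.writestr(f"{index}/system.mta", b"placeholder")  # the system
        zf.writestr(f"{index}/energy.mts", b"placeholder")  # one target

# List the entries and recover the structure from their names.
with zipfile.ZipFile(buffer) as zf:
    names = zf.namelist()

sample_indices = sorted({int(name.split("/")[0]) for name in names})
target_names = sorted(
    {name.split("/")[1][: -len(".mts")] for name in names if name.endswith(".mts")}
)

print(sample_indices)  # [0, 1]
print(target_names)    # ['energy']
```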

get_target_info(target_config: DictConfig) Dict[str, TargetInfo][source]

Get information about the targets in the dataset.

Parameters:

target_config (DictConfig) – The target configuration provided by the user through the YAML file.

Return type:

Dict[str, TargetInfo]

class metatrain.utils.data.dataset.DiskDatasetWriter(path: str | Path)[source]

Bases: object

A class for writing a dataset to disk, to be read by the DiskDataset class.

The class is initialized with a path to a zip file, and samples can be written to the zip file using the write_sample() method.

Parameters:

path (str | Path) – Path to the zip file to write the dataset to.

write_sample(system: System, targets: Dict[str, TensorMap])[source]

Write a sample to the zip file.

Parameters:
  • system (System) – The system to write.

  • targets (Dict[str, TensorMap]) – A dictionary of targets to write, where each value is a TensorMap.