Dataset¶
- class metatrain.utils.data.dataset.DatasetInfo(length_unit: str, atomic_types: List[int], targets: Dict[str, TargetInfo])[source]¶
Bases:
object
A class that contains information about datasets.
This class is used to communicate additional dataset details to the training functions of the individual models.
- Parameters:
length_unit (str) – Unit of length used in the dataset. Examples are "angstrom" or "nanometer".
atomic_types (List[int]) – List containing all integer atomic types present in the dataset. atomic_types will be stored as a sorted list of unique atomic types.
targets (Dict[str, TargetInfo]) – Information about targets in the dataset.
- copy() DatasetInfo [source]¶
Return a shallow copy of the DatasetInfo.
- Return type:
DatasetInfo
- update(other: DatasetInfo) None [source]¶
Update this instance with the union of itself and other.
- Raises:
ValueError – If the length_units are different.
- Parameters:
other (DatasetInfo)
- Return type:
None
- union(other: DatasetInfo) DatasetInfo [source]¶
Return the union of this instance with other.
- Parameters:
other (DatasetInfo)
- Return type:
DatasetInfo
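The behaviour described for DatasetInfo (atomic types stored sorted and de-duplicated, and a union that refuses mismatched length units) can be sketched in plain Python. The class below, `SketchDatasetInfo`, is a hypothetical stand-in written for illustration and is not metatrain's implementation. Targets are reduced to plain strings, and keeping this instance's entry when both sides define the same target name is an assumption.

```python
import copy
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class SketchDatasetInfo:
    """Hypothetical stand-in for DatasetInfo, for illustration only."""

    length_unit: str
    atomic_types: List[int]
    # Assumption: targets are simplified to name -> unit strings.
    targets: Dict[str, str] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Store the atomic types as a sorted list of unique values.
        self.atomic_types = sorted(set(self.atomic_types))

    def copy(self) -> "SketchDatasetInfo":
        # Shallow copy: the new object references the same attribute values.
        return copy.copy(self)

    def update(self, other: "SketchDatasetInfo") -> None:
        # Refuse to merge datasets expressed in different length units.
        if self.length_unit != other.length_unit:
            raise ValueError(
                f"length units differ: {self.length_unit!r} vs {other.length_unit!r}"
            )
        # Rebind (rather than mutate) so shallow copies are not affected.
        self.atomic_types = sorted(set(self.atomic_types) | set(other.atomic_types))
        # Assumption: on a name clash, this instance's entry is kept.
        self.targets = {**other.targets, **self.targets}

    def union(self, other: "SketchDatasetInfo") -> "SketchDatasetInfo":
        # Leave both inputs untouched and return a merged copy.
        merged = self.copy()
        merged.update(other)
        return merged


a = SketchDatasetInfo("angstrom", [8, 1, 1], {"energy": "eV"})
b = SketchDatasetInfo("angstrom", [6, 1], {"dipole": "e*angstrom"})
merged = a.union(b)
print(merged.atomic_types)
print(sorted(merged.targets))
```

In this sketch `union` is built from `copy` followed by `update`, so the in-place and the copying variants share one merging rule.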
- metatrain.utils.data.dataset.get_stats(dataset: Dataset | Subset, dataset_info: DatasetInfo) str [source]¶
Returns the statistics of a dataset or subset as a string.
- Parameters:
dataset (Dataset | Subset)
dataset_info (DatasetInfo)
- Return type:
str
- metatrain.utils.data.dataset.get_atomic_types(datasets: Dataset | List[Dataset]) List[int] [source]¶
List of all atomic types present in a dataset or list of datasets.
- metatrain.utils.data.dataset.get_all_targets(datasets: Dataset | List[Dataset]) List[str] [source]¶
Sorted list of all unique targets present in a dataset or list of datasets.
- metatrain.utils.data.dataset.collate_fn(batch: List[Dict[str, Any]]) Tuple[List, Dict[str, TensorMap]] [source]¶
Wraps group_and_join to return the data fields as a list of systems and a dictionary of named targets.
- metatrain.utils.data.dataset.check_datasets(train_datasets: List[Dataset], val_datasets: List[Dataset])[source]¶
Check that the training and validation sets are compatible with one another.
These checks will not fit every use case, but most models are expected to be able to use this function.
- Parameters:
train_datasets (List[Dataset])
val_datasets (List[Dataset])
- Raises:
TypeError – If the dtypes within the datasets are inconsistent.
ValueError – If the val_datasets has a target that is not present in the train_datasets.
ValueError – If the validation set contains chemical species or targets that are not present in the training set.
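The ValueError conditions listed above amount to subset checks between the validation and training data. The following is a minimal sketch, assuming the target names and atomic types of each split have already been collected into plain sets. `check_subsets` and its arguments are hypothetical names, the dtype check is omitted, and this is not the library's implementation.

```python
from typing import Set


def check_subsets(
    train_targets: Set[str],
    val_targets: Set[str],
    train_types: Set[int],
    val_types: Set[int],
) -> None:
    """Raise ValueError if validation data goes beyond the training data."""
    # Targets seen only in validation cannot have been learned in training.
    extra_targets = val_targets - train_targets
    if extra_targets:
        raise ValueError(
            f"validation targets not in the training set: {sorted(extra_targets)}"
        )
    # The same reasoning applies to chemical species (atomic types).
    extra_types = val_types - train_types
    if extra_types:
        raise ValueError(
            f"validation atomic types not in the training set: {sorted(extra_types)}"
        )


# Validation data that is a subset of the training data passes silently.
check_subsets({"energy", "forces"}, {"energy"}, {1, 6, 8}, {1, 8})
```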
- class metatrain.utils.data.dataset.DiskDataset(path: str | Path)[source]¶
Bases:
Dataset
A class representing a dataset stored on disk.
The dataset is stored in a zip file, where each sample is stored in a separate directory. The directory’s name is the index of the sample (e.g. 0/), and the files in the directory are the system (system.mta) and the targets (each named <target_name>.mts). These are metatensor.torch.atomistic.System and metatensor.torch.TensorMap objects, respectively.
Such a dataset can be created conveniently using the DiskDatasetWriter class.
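The archive layout described above can be inspected with Python's standard zipfile module alone. The sketch below builds a tiny in-memory archive with that layout and groups its members by sample index. The helper `group_members_by_sample` and the target name `energy` are hypothetical, and the payloads are placeholder bytes rather than real serialized System or TensorMap objects.

```python
import io
import zipfile
from collections import defaultdict
from typing import Dict, List


def group_members_by_sample(archive: zipfile.ZipFile) -> Dict[int, List[str]]:
    """Map each sample index to the file names stored in its directory."""
    samples = defaultdict(list)
    for name in archive.namelist():
        if name.endswith("/"):
            continue  # skip explicit directory entries
        index, _, filename = name.partition("/")
        samples[int(index)].append(filename)
    return {index: sorted(files) for index, files in sorted(samples.items())}


# Write two samples, each in a directory named after its index.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    for index in range(2):
        archive.writestr(f"{index}/system.mta", b"placeholder")
        archive.writestr(f"{index}/energy.mts", b"placeholder")

# Read the archive back and recover the per-sample layout.
with zipfile.ZipFile(buffer) as archive:
    layout = group_members_by_sample(archive)

print(layout)
```

Reading and writing actual samples should go through DiskDataset and DiskDatasetWriter; the sketch only illustrates the directory naming.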
- class metatrain.utils.data.dataset.DiskDatasetWriter(path: str | Path)[source]¶
Bases:
object
A class for writing a dataset to disk, to be read by the DiskDataset class.
The class is initialized with a path to a zip file, and samples can be written to the zip file using the write_sample() method.