def read_systems(filename: str) -> List[System]:
    """Placeholder reader for systems stored in metatensor format.

    No file is opened; the call always fails.

    :param filename: path of the file that would be read
    :raises NotImplementedError: always, because serialization of systems is
        not yet available in metatensor.
    """
    message = "Reading metatensor systems is not yet implemented."
    raise NotImplementedError(message)
def _wrapped_metatensor_read(filename) -> TensorMap:
    """Load a ``TensorMap`` from ``filename`` with ``metatensor.torch.load``.

    :param filename: path of the file to load
    :returns: the loaded ``TensorMap``
    :raises ValueError: if loading raises any exception; the message names the
        file and the original exception is chained as the cause.
    """
    try:
        loaded = metatensor.torch.load(filename)
    except Exception as err:
        raise ValueError(f"Failed to read '{filename}' with torch: {err}") from err
    return loaded
def read_energy(target: DictConfig) -> Tuple[List[TensorMap], TargetInfo]:
    """Read an energy target stored as a metatensor ``TensorMap``.

    The file named by ``target["read_from"]`` is loaded, its metadata is
    checked against the layout returned by ``get_energy_target_info``, and the
    map is split into one ``TensorMap`` per system.

    :param target: target configuration; the ``read_from``, ``forces``,
        ``stress`` and ``virial`` entries are read here.
    :returns: a tuple with the list of per-system ``TensorMap`` objects
        (ordered by system index, as produced by ``torch.unique``) and the
        corresponding ``TargetInfo``.
    :raises ValueError: if the file cannot be read, if the ``TensorMap`` does
        not contain exactly one block, or if its metadata does not match the
        expected layout.
    """
    tensor_map = _wrapped_metatensor_read(target["read_from"])

    if len(tensor_map) != 1:
        raise ValueError("Energy TensorMaps should have exactly one block.")

    # NOTE(review): these config entries are used as truthy flags for whether
    # position/strain gradients are expected -- confirm against config schema.
    add_position_gradients = target["forces"]
    add_strain_gradients = target["stress"] or target["virial"]
    target_info = get_energy_target_info(
        target, add_position_gradients, add_strain_gradients
    )

    # now check all the expected metadata (from target_info.layout) matches
    # the actual metadata in the tensor maps
    _check_tensor_map_metadata(tensor_map, target_info.layout)

    # one selection per distinct system index found in the samples
    selections = [
        Labels(
            names=["system"],
            values=torch.tensor([[int(i)]]),
        )
        for i in torch.unique(
            torch.concatenate(
                [block.samples.column("system") for block in tensor_map.blocks()]
            )
        )
    ]
    tensor_maps = metatensor.torch.split(tensor_map, "samples", selections)
    return tensor_maps, target_info
def read_generic(target: DictConfig) -> Tuple[List[TensorMap], TargetInfo]:
    """Read a generic (non-energy) target stored as a metatensor ``TensorMap``.

    The file named by ``target["read_from"]`` is loaded, its metadata is
    checked against the layout returned by ``get_generic_target_info``, and
    the map is split into one ``TensorMap`` per system.

    :param target: target configuration; the ``read_from`` entry is read here.
    :returns: a tuple with the list of per-system ``TensorMap`` objects
        (ordered by system index, as produced by ``torch.unique``) and the
        corresponding ``TargetInfo``, whose ``layout`` is replaced by an empty
        copy of the loaded map's metadata.
    :raises ValueError: if the file cannot be read, if any block carries
        gradients, or if the metadata does not match the expected layout.
    """
    tensor_map = _wrapped_metatensor_read(target["read_from"])

    for block in tensor_map.blocks():
        if len(block.gradients_list()) > 0:
            raise ValueError("Only energy targets can have gradient blocks.")

    target_info = get_generic_target_info(target)
    _check_tensor_map_metadata(tensor_map, target_info.layout)

    # make sure that the properties of the target_info.layout also match the
    # actual properties of the tensor maps
    target_info.layout = _empty_tensor_map_like(tensor_map)

    # Collect system indices from *all* blocks: a system may be missing from
    # some blocks (e.g. blocks keyed by atom type), and looking only at the
    # first block would silently drop such systems from the split.
    selections = [
        Labels(
            names=["system"],
            values=torch.tensor([[int(i)]]),
        )
        for i in torch.unique(
            torch.concatenate(
                [block.samples.column("system") for block in tensor_map.blocks()]
            )
        )
    ]
    tensor_maps = metatensor.torch.split(tensor_map, "samples", selections)
    return tensor_maps, target_info
def _check_tensor_map_metadata(tensor_map: TensorMap, layout: TensorMap):
    """Check that the metadata of ``tensor_map`` matches the reference ``layout``.

    Keys must be equal. For every key, the block's sample names, components
    and set of gradient names must match the layout block; for every gradient,
    the gradient block's sample names and components must match as well.
    Properties are deliberately not compared.

    :param tensor_map: the ``TensorMap`` read from file.
    :param layout: the reference ``TensorMap`` holding the expected metadata.
    :raises ValueError: on the first mismatch found.
    """
    if tensor_map.keys != layout.keys:
        raise ValueError(
            f"Unexpected keys in metatensor targets: "
            f"expected: {layout.keys} "
            f"actual: {tensor_map.keys}"
        )
    for key in layout.keys:
        block = tensor_map.block(key)
        block_from_layout = layout.block(key)
        if block.samples.names != block_from_layout.samples.names:
            raise ValueError(
                f"Unexpected samples in metatensor targets: "
                f"expected: {block_from_layout.samples.names} "
                f"actual: {block.samples.names}"
            )
        if block.components != block_from_layout.components:
            raise ValueError(
                f"Unexpected components in metatensor targets: "
                f"expected: {block_from_layout.components} "
                f"actual: {block.components}"
            )
        # the properties can be different from those of the default `TensorMap`
        # given by `get_generic_target_info`, so we don't check them
        if set(block.gradients_list()) != set(block_from_layout.gradients_list()):
            raise ValueError(
                f"Unexpected gradients in metatensor targets: "
                f"expected: {block_from_layout.gradients_list()} "
                f"actual: {block.gradients_list()}"
            )
        for name in block_from_layout.gradients_list():
            gradient_block = block.gradient(name)
            gradient_block_from_layout = block_from_layout.gradient(name)
            # Gradient blocks are TensorBlocks themselves, so their sample
            # metadata is exposed through `.samples` like any other block.
            if (
                gradient_block.samples.names
                != gradient_block_from_layout.samples.names
            ):
                raise ValueError(
                    f"Unexpected samples in metatensor targets "
                    f"for `{name}` gradient block: "
                    f"expected: {gradient_block_from_layout.samples.names} "
                    f"actual: {gradient_block.samples.names}"
                )
            if gradient_block.components != gradient_block_from_layout.components:
                raise ValueError(
                    f"Unexpected components in metatensor targets "
                    f"for `{name}` gradient block: "
                    f"expected: {gradient_block_from_layout.components} "
                    f"actual: {gradient_block.components}"
                )


def _empty_tensor_map_like(tensor_map: TensorMap) -> TensorMap:
    """Return a ``TensorMap`` with the same keys and per-block metadata as
    ``tensor_map`` but with zero samples in every block.

    :param tensor_map: the map whose metadata is copied.
    :returns: a new ``TensorMap`` built from ``_empty_tensor_block_like``
        applied to each block.
    """
    new_keys = tensor_map.keys
    new_blocks: List[TensorBlock] = []
    for block in tensor_map.blocks():
        new_block = _empty_tensor_block_like(block)
        new_blocks.append(new_block)
    return TensorMap(keys=new_keys, blocks=new_blocks)


def _empty_tensor_block_like(tensor_block: TensorBlock) -> TensorBlock:
    """Return an empty copy of ``tensor_block`` that keeps its metadata.

    The new block has zero samples (same sample names), the same components
    and properties, float64 values on the original block's device, and an
    empty copy of each of the original block's gradients.

    :param tensor_block: the block whose metadata is copied.
    :returns: the new, sample-less ``TensorBlock``.
    """
    new_block = TensorBlock(
        values=torch.empty(
            (0,) + tensor_block.values.shape[1:],
            dtype=torch.float64,  # metatensor can't serialize otherwise
            device=tensor_block.values.device,
        ),
        samples=Labels(
            names=tensor_block.samples.names,
            values=torch.empty(
                (0, tensor_block.samples.values.shape[1]),
                dtype=tensor_block.samples.values.dtype,
                device=tensor_block.samples.values.device,
            ),
        ),
        components=tensor_block.components,
        properties=tensor_block.properties,
    )
    # gradients are blocks too, so they are emptied recursively
    for gradient_name, gradient in tensor_block.gradients():
        new_block.add_gradient(gradient_name, _empty_tensor_block_like(gradient))
    return new_block