Source code for metatrain.utils.data.combine_dataloaders
fromtypingimportListimportnumpyasnpimporttorch
class CombinedDataLoader:
    """
    Combines multiple dataloaders into a single dataloader.

    This is useful for learning from multiple datasets at the same time, each
    of which may have different batch sizes, properties, etc.

    :param dataloaders: list of dataloaders to combine
    :param shuffle: whether to shuffle the combined dataloader (this does not
        act on the individual batches, but it shuffles the order in which they
        are returned). When enabled, a new order is drawn for every epoch.
    :return: the combined dataloader
    """

    def __init__(
        self, dataloaders: List[torch.utils.data.DataLoader], shuffle: bool
    ):
        self.dataloaders = dataloaders
        self.shuffle = shuffle
        # ``reset`` populates ``full_list`` (this epoch's batches), ``indices``
        # (the order in which to return them) and ``current_index``.
        self.reset()

    def reset(self):
        """
        Prepare a new epoch.

        Iterates every dataloader once to collect its batches, then builds the
        order in which they will be returned, shuffling it if requested.
        """
        self.current_index = 0
        self.full_list = [batch for dl in self.dataloaders for batch in dl]
        # Size the order from the batches actually produced rather than from
        # ``len(self)``: if a dataloader's ``__len__`` does not match the number
        # of batches it yields, using ``len(self)`` would either raise an
        # IndexError in ``__next__`` or silently drop batches.
        self.indices = list(range(len(self.full_list)))
        if self.shuffle:
            # Re-drawn on every reset (i.e. every epoch) so the interleaving of
            # batches from the different dataloaders changes between epochs.
            np.random.shuffle(self.indices)

    def __iter__(self):
        return self

    def __next__(self):
        if self.current_index >= len(self.indices):
            self.reset()  # Prepare the next epoch before signalling the end
            raise StopIteration

        idx = self.indices[self.current_index]
        self.current_index += 1
        return self.full_list[idx]

    def __len__(self):
        """
        Total number of batches in all dataloaders.

        This returns the total number of batches in all dataloaders (as opposed
        to the total number of samples or the number of individual
        dataloaders).

        :return: the total number of batches in all dataloaders
        """
        return sum(len(dl) for dl in self.dataloaders)