class DataLoaderPytorch(BaseDataLoader):
    """PyTorch-backed dataloader: wraps `torch.utils.data.DataLoader` and collates batches to NumPy arrays."""

    @typecheck
    def __init__(
        self,
        dataset: Union[JAXDataset, TorchDataset, HFDataset],
        batch_size: int = 1,  # Batch size
        shuffle: bool = False,  # If true, dataloader shuffles before sampling each batch
        drop_last: bool = False,  # If true, drop the last incomplete batch
        **kwargs
    ):
        super().__init__(dataset, batch_size, shuffle, drop_last)
        check_pytorch_installed()
        # Import torch lazily so the module can be imported without PyTorch installed.
        import torch
        import torch.utils.data as torch_data
        from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler

        if 'sampler' in kwargs:
            warnings.warn("`sampler` is currently not supported. We will ignore it and use `shuffle` instead.")
            del kwargs['sampler']

        # Convert the input dataset to a torch-compatible dataset.
        dataset = to_torch_dataset(dataset)
        # Init the batch sampler with a generator seeded from the global config for reproducibility.
        generator = torch.Generator().manual_seed(get_config().global_seed)
        if shuffle:
            sampler = RandomSampler(dataset, generator=generator)
        else:
            sampler = SequentialSampler(dataset)
        batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)

        self.dataloader = torch_data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            # `batch_size`, `shuffle`, and `drop_last` are handled by `batch_sampler`
            # and must not be passed to `DataLoader` again (the arguments are mutually exclusive).
            collate_fn=_numpy_collate,
            **kwargs
        )
        self._iterator = None

    def __len__(self):
        return len(self.dataloader)

    def __iter__(self):
        # Keep a handle on the current epoch's iterator so that `__next__` can advance it.
        self._iterator = iter(self.dataloader)
        return self._iterator

    def __next__(self):
        # A `DataLoader` is not itself an iterator; advance the iterator created in `__iter__`.
        if self._iterator is None:
            self._iterator = iter(self.dataloader)
        return next(self._iterator)
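
# A minimal usage sketch (illustrative only, not part of the class). It assumes a plain
# `torch.utils.data.TensorDataset` satisfies the `TorchDataset` type accepted above;
# the array shapes and batch size are arbitrary.
#
#     import torch
#     from torch.utils.data import TensorDataset
#
#     ds = TensorDataset(torch.arange(20).float().reshape(10, 2))
#     loader = DataLoaderPytorch(ds, batch_size=4, shuffle=True, drop_last=False)
#     for batch in loader:
#         ...  # each `batch` is collated to NumPy arrays by `_numpy_collate`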