class DataLoaderTensorflow(BaseDataLoader):
    """Dataloader backed by a `tf.data.Dataset` pipeline.

    The input dataset is converted with `to_tf_dataset`, optionally shuffled,
    then batched and prefetched. Batches are produced through
    `as_numpy_iterator()`.
    """

    @typecheck
    def __init__(
        self,
        dataset: Union[JAXDataset, TFDataset, HFDataset],
        batch_size: int = 1,  # Batch size
        shuffle: bool = False,  # If true, dataloader shuffles before sampling each batch
        drop_last: bool = False,  # Drop last batch or not
        **kwargs
    ):
        """Build the batched, prefetched tf.data pipeline.

        Args:
            dataset: Source dataset; converted via `to_tf_dataset`.
            batch_size: Number of samples per batch.
            shuffle: If true, shuffle with a buffer covering the whole
                dataset, seeded from `get_config().global_seed`.
            drop_last: Forwarded as `drop_remainder` to `Dataset.batch`.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        super().__init__(dataset, batch_size, shuffle, drop_last)
        # NOTE(review): presumably raises if TensorFlow is missing - confirm.
        check_tf_installed()
        # Convert to tf dataset
        ds = to_tf_dataset(dataset)
        if shuffle:
            # Buffer as large as the dataset so every element can be drawn.
            ds = ds.shuffle(
                buffer_size=len(dataset),
                seed=get_config().global_seed,
            )
        ds = ds.batch(batch_size, drop_remainder=drop_last)
        ds = ds.prefetch(tf.data.AUTOTUNE)
        self.dataloader = ds
        # Cursor used by `__next__`; created lazily on first use.
        self._iterator = None

    def __len__(self):
        """Return `len()` of the batched tf dataset."""
        return len(self.dataloader)

    def __next__(self):
        """Return the next batch from an internal numpy iterator.

        A `tf.data.Dataset` is iterable but is not itself an iterator, so
        `next()` cannot be called on it directly. An iterator is created on
        first use and discarded on exhaustion, so the call following a
        `StopIteration` begins a new pass over the data.
        """
        if self._iterator is None:
            self._iterator = self.dataloader.as_numpy_iterator()
        try:
            return next(self._iterator)
        except StopIteration:
            # Reset so the loader can be consumed again via `next()`.
            self._iterator = None
            raise

    def __iter__(self):
        """Return a fresh numpy iterator over the batches.

        Each call starts a new pass; it is independent of the cursor used by
        `__next__`.
        """
        return self.dataloader.as_numpy_iterator()