Skip to content

DataLoaderTensorflow

jax_dataloader.loaders.tensorflow.DataLoaderTensorflow

Bases: BaseDataLoader

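TensorFlow dataloader: builds a `tf.data` pipeline over the dataset (optional shuffle, then batching and prefetching) and yields batches as NumPy values.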

Source code in `jax_dataloader/loaders/tensorflow.py` (lines 29–57):
class DataLoaderTensorflow(BaseDataLoader):
    """Tensorflow Dataloader.

    Converts the dataset to a ``tf.data`` pipeline that optionally shuffles,
    then batches and prefetches. Iteration yields batches as NumPy values
    via ``as_numpy_iterator``.
    """

    @typecheck
    def __init__(
        self,
        dataset: Union[JAXDataset, TFDataset, HFDataset],
        batch_size: int = 1,  # Batch size
        shuffle: bool = False,  # If true, dataloader shuffles before sampling each batch
        drop_last: bool = False,  # Drop last batch or not
        **kwargs
    ):
        """Build the ``tf.data`` pipeline stored in ``self.dataloader``.

        Args:
            dataset: Source dataset, converted with ``to_tf_dataset``.
            batch_size: Number of elements per batch.
            shuffle: If true, shuffle with a buffer of ``len(dataset)``
                elements, seeded from ``get_config().global_seed``.
            drop_last: If true, drop a final batch smaller than ``batch_size``.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        super().__init__(dataset, batch_size, shuffle, drop_last)
        check_tf_installed()
        # Convert to tf dataset
        ds = to_tf_dataset(dataset)
        # A buffer as large as the dataset lets tf.data shuffle all of it.
        ds = ds.shuffle(buffer_size=len(dataset), seed=get_config().global_seed) if shuffle else ds
        ds = ds.batch(batch_size, drop_remainder=drop_last)
        ds = ds.prefetch(tf.data.AUTOTUNE)
        self.dataloader = ds

    def __len__(self):
        """Return the number of batches, via ``len(self.dataloader)``.

        TensorFlow raises ``TypeError`` here if the pipeline's cardinality
        is infinite or unknown.
        """
        return len(self.dataloader)

    def __next__(self):
        """Return the next batch of the pipeline as NumPy values.

        ``tf.data.Dataset`` is iterable but is not itself an iterator, so
        ``next(self.dataloader)`` would raise ``TypeError``. Instead, keep a
        NumPy iterator over ``self.dataloader`` (the same kind ``__iter__``
        returns) and advance it. When it is exhausted, ``StopIteration``
        propagates and the iterator is dropped so the next call starts a
        new pass.
        """
        batches = getattr(self, "_next_iter", None)
        if batches is None:
            batches = self.dataloader.as_numpy_iterator()
            self._next_iter = batches
        try:
            return next(batches)
        except StopIteration:
            # Reset so a subsequent call begins a fresh pass over the data.
            self._next_iter = None
            raise

    def __iter__(self):
        """Return a fresh iterator yielding each batch as NumPy values."""
        return self.dataloader.as_numpy_iterator()

Attributes

`dataloader` (instance attribute): the batched, prefetched `tf.data.Dataset` pipeline built in `__init__` (`self.dataloader = ds`).

Functions

`__init__(dataset: Union[JAXDataset, TFDataset, HFDataset], batch_size: int = 1, shuffle: bool = False, drop_last: bool = False, **kwargs)`

Source code in `jax_dataloader/loaders/tensorflow.py` (lines 32–48):
@typecheck
def __init__(
    self,
    dataset: Union[JAXDataset, TFDataset, HFDataset],
    batch_size: int = 1,  # Batch size
    shuffle: bool = False,  # If true, dataloader shuffles before sampling each batch
    drop_last: bool = False,  # Drop last batch or not
    **kwargs
):
    """Assemble the TensorFlow input pipeline and store it on ``self.dataloader``.

    Args:
        dataset: Source dataset, turned into a ``tf.data`` pipeline by
            ``to_tf_dataset``.
        batch_size: Elements per batch.
        shuffle: When true, shuffle with a ``len(dataset)``-sized buffer,
            seeded from ``get_config().global_seed``.
        drop_last: When true, discard a trailing batch shorter than
            ``batch_size``.
        **kwargs: Accepted but not used by this loader.
    """
    super().__init__(dataset, batch_size, shuffle, drop_last)
    check_tf_installed()

    pipeline = to_tf_dataset(dataset)
    if shuffle:
        buffer = len(dataset)
        seed = get_config().global_seed
        pipeline = pipeline.shuffle(buffer_size=buffer, seed=seed)

    batched = pipeline.batch(batch_size, drop_remainder=drop_last)
    self.dataloader = batched.prefetch(tf.data.AUTOTUNE)

`__iter__()`

Source code in `jax_dataloader/loaders/tensorflow.py` (lines 56–57):
def __iter__(self):
    """Produce a new iterator over the pipeline's batches as NumPy values."""
    pipeline = self.dataloader
    return pipeline.as_numpy_iterator()

`__len__()`

Source code in `jax_dataloader/loaders/tensorflow.py` (lines 50–51):
def __len__(self):
    """Report the batch count by delegating to ``len`` of the pipeline."""
    pipeline = self.dataloader
    return len(pipeline)

`__next__()`

Source code in `jax_dataloader/loaders/tensorflow.py` (lines 53–54):
def __next__(self):
    """Return the next batch of the pipeline as NumPy values.

    ``tf.data.Dataset`` is iterable but is not itself an iterator, so
    ``next(self.dataloader)`` would raise ``TypeError``. Instead, keep a
    NumPy iterator over ``self.dataloader`` (the same kind ``__iter__``
    returns) and advance it. When it is exhausted, ``StopIteration``
    propagates and the iterator is dropped so the next call starts a new pass.
    """
    batches = getattr(self, "_next_iter", None)
    if batches is None:
        batches = self.dataloader.as_numpy_iterator()
        self._next_iter = batches
    try:
        return next(batches)
    except StopIteration:
        # Reset so a subsequent call begins a fresh pass over the data.
        self._next_iter = None
        raise