
jax_dataloader.loaders.torch.DataLoaderPytorch

Bases: BaseDataLoader

PyTorch DataLoader that collates batches into NumPy arrays for use with JAX.

Source code in jax_dataloader/loaders/torch.py
class DataLoaderPytorch(BaseDataLoader):
    """Pytorch Dataloader"""

    @typecheck
    def __init__(
        self, 
        dataset: Union[JAXDataset, TorchDataset, HFDataset],
        batch_size: int = 1,  # Batch size
        shuffle: bool = False,  # If true, dataloader shuffles before sampling each batch
        drop_last: bool = False, # Drop last batch or not
        **kwargs
    ):
        super().__init__(dataset, batch_size, shuffle, drop_last)
        check_pytorch_installed()
        from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler
        import torch

        if 'sampler' in kwargs:
            warnings.warn("`sampler` is currently not supported. We will ignore it and use `shuffle` instead.")
            del kwargs['sampler']

        # convert to torch dataset
        dataset = to_torch_dataset(dataset)
        # init batch sampler
        generator = torch.Generator().manual_seed(get_config().global_seed)
        if shuffle: 
            sampler = RandomSampler(dataset, generator=generator)
        else:       
            sampler = SequentialSampler(dataset)
        batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)

        self.dataloader = torch_data.DataLoader(
            dataset, 
            batch_sampler=batch_sampler,
            # batch_size=batch_size, 
            # shuffle=shuffle, 
            # drop_last=drop_last,
            collate_fn=_numpy_collate,
            **kwargs
        )

    def __len__(self):
        return len(self.dataloader)

    def __next__(self):
        return next(self.dataloader)

    def __iter__(self):
        return self.dataloader.__iter__()
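
For orientation, a minimal usage sketch; `jdl.ArrayDataset` is assumed to be the array-backed dataset exported at the package root (any `JAXDataset`, `TorchDataset`, or `HFDataset` satisfies the signature):

import jax.numpy as jnp
import jax_dataloader as jdl
from jax_dataloader.loaders.torch import DataLoaderPytorch

# Wrap two JAX arrays as a dataset (assumption: ArrayDataset is exported at the root).
X = jnp.arange(20.0).reshape(10, 2)
y = jnp.arange(10)
dataset = jdl.ArrayDataset(X, y)

loader = DataLoaderPytorch(dataset, batch_size=4, shuffle=True, drop_last=False)
for xb, yb in loader:
    # `_numpy_collate` yields NumPy arrays, ready for jitted JAX functions.
    print(xb.shape, yb.shape)  # (4, 2) (4,) for full batches, (2, 2) (2,) for the last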

Attributes

dataloader = torch_data.DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=_numpy_collate, **kwargs) (instance attribute)
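
Because `dataloader` is a plain `torch.utils.data.DataLoader`, torch-level attributes remain reachable; continuing the sketch above:

underlying = loader.dataloader          # a standard torch DataLoader
print(type(underlying).__name__)        # DataLoader
print(underlying.collate_fn.__name__)   # _numpy_collate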

Functions

__init__(dataset: Union[JAXDataset, TorchDataset, HFDataset], batch_size: int = 1, shuffle: bool = False, drop_last: bool = False, **kwargs)

Source code in jax_dataloader/loaders/torch.py
@typecheck
def __init__(
    self, 
    dataset: Union[JAXDataset, TorchDataset, HFDataset],
    batch_size: int = 1,  # Batch size
    shuffle: bool = False,  # If true, dataloader shuffles before sampling each batch
    drop_last: bool = False, # Drop last batch or not
    **kwargs
):
    super().__init__(dataset, batch_size, shuffle, drop_last)
    check_pytorch_installed()
    from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler
    import torch

    if 'sampler' in kwargs:
        warnings.warn("`sampler` is currently not supported. We will ignore it and use `shuffle` instead.")
        del kwargs['sampler']

    # convert to torch dataset
    dataset = to_torch_dataset(dataset)
    # init batch sampler
    generator = torch.Generator().manual_seed(get_config().global_seed)
    if shuffle: 
        sampler = RandomSampler(dataset, generator=generator)
    else:       
        sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)

    self.dataloader = torch_data.DataLoader(
        dataset, 
        batch_sampler=batch_sampler,
        # batch_size=batch_size, 
        # shuffle=shuffle, 
        # drop_last=drop_last,
        collate_fn=_numpy_collate,
        **kwargs
    )
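
Two behaviors worth noting from the body above: a `sampler` passed through `**kwargs` is dropped with a warning (use `shuffle` instead), and the shuffle order is seeded from `get_config().global_seed`, making runs reproducible. A sketch of the latter, assuming `jdl.manual_seed` is the public setter for the global seed (treat that name as an assumption; otherwise set the config field directly):

import jax_dataloader as jdl

jdl.manual_seed(1234)  # assumption: updates get_config().global_seed
loader_a = DataLoaderPytorch(dataset, batch_size=4, shuffle=True)

jdl.manual_seed(1234)
loader_b = DataLoaderPytorch(dataset, batch_size=4, shuffle=True)
# loader_a and loader_b now yield batches in the same shuffled order.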

__iter__()

Source code in jax_dataloader/loaders/torch.py
def __iter__(self):
    return self.dataloader.__iter__()
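
Each `iter()` call obtains a fresh iterator from the wrapped torch DataLoader, so multi-epoch loops behave as expected; with `shuffle=True`, the `RandomSampler` draws a new permutation from the seeded generator every epoch:

for epoch in range(3):
    for xb, yb in loader:  # fresh iterator (and shuffle order) each epoch
        ...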

__len__()

Source code in jax_dataloader/loaders/torch.py
def __len__(self):
    return len(self.dataloader)
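
The length is delegated to the torch DataLoader, which under a `BatchSampler` reports the number of batches: `ceil(N / batch_size)` with `drop_last=False`, `floor(N / batch_size)` with `drop_last=True`. For the 10-sample dataset from the sketch above:

assert len(DataLoaderPytorch(dataset, batch_size=4, drop_last=False)) == 3  # ceil(10 / 4)
assert len(DataLoaderPytorch(dataset, batch_size=4, drop_last=True)) == 2   # floor(10 / 4)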

__next__()

Source code in jax_dataloader/loaders/torch.py
def __next__(self):
    return next(self.dataloader)
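
Note that this delegates `next()` to the wrapped `torch.utils.data.DataLoader` itself, which is iterable but not an iterator, so calling `next(loader)` directly may raise a `TypeError` under current torch versions. Obtaining an iterator first is the safe pattern:

it = iter(loader)       # torch's per-epoch iterator
first_batch = next(it)  # a tuple of NumPy arrays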