Skip to content

ArrayDataset

jax_dataloader.datasets.ArrayDataset ¤

Bases: Dataset

Dataset wrapping one or more arrays that share the same first dimension (stored as numpy arrays by default).

Source code in jax_dataloader/datasets.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class ArrayDataset(Dataset):
    """Dataset wrapping one or more arrays that share their first dimension.

    The arrays are kept in the ``arrays`` tuple. Indexing the dataset applies
    the same index to every stored array and returns the results in a tuple,
    one entry per array.
    """

    def __init__(
        self,
        *arrays: jax.Array, # Numpy array with same first dimension
        asnumpy: bool = True, # Store arrays as numpy arrays if True; otherwise store as array type of *arrays
    ) -> None:
        """Validate and store ``arrays``, optionally converting them to numpy.

        Raises:
            AssertionError: If the arrays differ in the size of their first
                dimension.
        """
        # Raise explicitly rather than via an `assert` statement: asserts are
        # stripped under `python -O`, which would silently accept mismatched
        # arrays. Exception type and message are unchanged for callers.
        # With no arrays at all the check passes vacuously.
        if not all(arrays[0].shape[0] == arr.shape[0] for arr in arrays):
            raise AssertionError("All arrays must have the same dimension.")
        self.arrays = tuple(arrays)
        # Here `asnumpy` is the boolean flag; `self.asnumpy()` is the method.
        if asnumpy:
            self.asnumpy()

    def asnumpy(self) -> None:
        """Convert all arrays to numpy arrays."""
        # Inside this body the bare name `asnumpy` resolves to the
        # module-level helper (defined outside this class), not this method.
        self.arrays = tuple(asnumpy(arr) for arr in self.arrays)

    def __len__(self) -> int:
        """Return the size of the first dimension of the first stored array."""
        return self.arrays[0].shape[0]

    def __getitem__(self, index):
        """Return a tuple holding ``arr[index]`` for every stored array."""
        return jax.tree_util.tree_map(lambda x: x[index], self.arrays)

Attributes¤

arrays = tuple(arrays) instance-attribute ¤

Functions¤

__getitem__(index) ¤

Source code in jax_dataloader/datasets.py
43
44
def __getitem__(self, index):
    """Return a tuple holding ``arr[index]`` for every array in ``self.arrays``."""

    def _select(leaf):
        # Apply the caller's index to a single stored array.
        return leaf[index]

    return jax.tree_util.tree_map(_select, self.arrays)

__init__(*arrays: jax.Array, asnumpy: bool = True) ¤

Source code in jax_dataloader/datasets.py
25
26
27
28
29
30
31
32
33
34
def __init__(
    self,
    *arrays: jax.Array, # Numpy array with same first dimension
    asnumpy: bool = True, # Store arrays as numpy arrays if True; otherwise store as array type of *arrays
):
    """Validate and store ``arrays``, optionally converting them to numpy.

    Raises:
        AssertionError: If the arrays differ in the size of their first
            dimension.
    """
    # Raise explicitly rather than via an `assert` statement: asserts are
    # stripped under `python -O`, which would silently accept mismatched
    # arrays. Exception type and message are unchanged for callers.
    # With no arrays at all the check passes vacuously.
    if not all(arrays[0].shape[0] == arr.shape[0] for arr in arrays):
        raise AssertionError("All arrays must have the same dimension.")
    self.arrays = tuple(arrays)
    # Here `asnumpy` is the boolean flag; `self.asnumpy()` is the method.
    if asnumpy:
        self.asnumpy()

__len__() ¤

Source code in jax_dataloader/datasets.py
40
41
def __len__(self):
    """Return the size of the first dimension of the first stored array."""
    leading = self.arrays[0]
    return leading.shape[0]

asnumpy() ¤

Convert all arrays to numpy arrays.

Source code in jax_dataloader/datasets.py
36
37
38
def asnumpy(self):
    """Replace ``self.arrays`` with a tuple of its entries converted to numpy."""
    # As a method of the class, the bare name `asnumpy` below refers to the
    # module-level conversion helper (defined elsewhere), not to this method.
    converted = map(asnumpy, self.arrays)
    self.arrays = tuple(converted)