Skip to content

Get backend compatibilities

jax_dataloader.core.get_backend_compatibilities() -> dict[str, list[type]]

Return, for each dataloader backend, the list of dataset types it supports.

Source code in `jax_dataloader/core.py` (lines 67–86):
def get_backend_compatibilities() -> dict[str, list[type]]: # { backend: [supported datasets] }
    """Return, for each dataloader backend, the dataset types it supports.

    A small sample instance of every supported dataset type is built and
    probed against each backend returned by ``_get_backends()`` using
    ``_check_backend_compatibility``. A dataset type is recorded under a
    backend when that check completes without raising.

    Returns:
        A dict mapping each backend (as returned by ``_get_backends()``) to the
        list of dataset types (``JAXDataset``, ``TorchDataset``, ``TFDataset``,
        ``HFDataset``) that passed the compatibility check for it, in probe
        order.
    """
    # One representative instance per dataset type: the key is the dataset
    # *type* reported in the result, the value is the object actually probed.
    samples = {
        JAXDataset: ArrayDataset(np.array([1, 2, 3])),
        TorchDataset: torch_data.Dataset(),
        TFDataset: tf.data.Dataset.from_tensor_slices(np.array([1, 2, 3])),
        HFDataset: hf_datasets.Dataset.from_dict({'a': [1, 2, 3]}),
    }
    # Guard against SUPPORTED_DATASETS gaining a type that is not probed here.
    assert len(samples) == len(SUPPORTED_DATASETS), (
        f"get_backend_compatibilities probes {len(samples)} dataset types but "
        f"SUPPORTED_DATASETS has {len(SUPPORTED_DATASETS)}"
    )

    # Call _get_backends() once; iterate the dict keys afterwards so this is
    # safe even if _get_backends() returns a one-shot iterator.
    compat = {backend: [] for backend in _get_backends()}
    for backend in compat:
        for ds_type, sample in samples.items():
            try:
                _check_backend_compatibility(sample, backend)
            except Exception:
                # Any failure means this (dataset, backend) pair is treated as
                # unsupported. Catching Exception rather than using a bare
                # ``except`` lets KeyboardInterrupt / SystemExit propagate.
                continue
            compat[backend].append(ds_type)

    return compat