Randomness suggestion
Thanks for putting the library together. Just a suggestion for giving users control over the randomness: what do you think about the DataLoader class taking a jax.random key as a keyword argument, so that users can create the seed and split it as necessary in a calling function, based on their use case? Something like:
import jax.random as jr
import jax_dataloader as jdl

key = jr.PRNGKey(123456)
data_key, model_key, loader_key, train_key, sample_key = jr.split(key, 5)
# other code
dataloader = jdl.DataLoader(
    dataset,  # Can be a jdl.Dataset or pytorch or huggingface or tensorflow dataset
    backend='jax',  # Use 'jax' backend for loading data
    batch_size=batch_size,  # Batch size
    shuffle=True,  # Shuffle the dataloader every iteration or not
    drop_last=False,  # Drop the last batch or not
    key=loader_key,  # or None if not used
)
What are your thoughts?
Hi @aspannaus - Thanks for your suggestions. This library aims to provide a PyTorch-style DataLoader API that works across the four supported dataset types and three backends.
It seems like your proposal might not work for the PyTorch and TensorFlow backends. Plus, a key argument is not part of the PyTorch DataLoader's class definition.
We currently support deterministic dataloading by setting the manual_seed:
jdl.manual_seed(1234) # Set the random seed to 1234 for reproducibility
# dataloader definition
dl = jdl.DataLoader(dataset, 'jax', ...)
This serves the same purpose as what you proposed. And this API design is agnostic to the dataset and backend choices.
Are there use cases for which your proposal would be preferable?
Ahh, right. I'm only using JAX at the moment, and wasn't thinking about how JAX handles random number generation compared with the other backends. I can see why what I've got above doesn't easily translate.
For the work I'm presently doing, I read in the seed from a config file and split the key as needed for the model, dataloading, training, subsetting, and generating samples from some distributions. The module here allows for setting the seed manually, but what ends up happening in my workflow is that the dataloading and one of these tasks share the same key sequence. My preference would be something similar to the code I have above, where separate keys are split from the initial seed at the start, and then each task (i.e., training, dataloading, sampling) manages its own sequence.
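To make the workflow concrete, here is a minimal sketch of the splitting pattern in plain JAX. The seed value and the two example draws are illustrative assumptions, not code from my project or from jax-dataloader:

```python
import jax.numpy as jnp
import jax.random as jr

seed = 123456  # assumed value; in practice this would come from a config file
key = jr.PRNGKey(seed)

# One key per task, all derived from the single root key
data_key, model_key, loader_key, train_key, sample_key = jr.split(key, 5)

# Each task consumes only its own key, so the streams do not overlap
shuffle_order = jr.permutation(loader_key, 10)  # e.g. index order for 10 samples
samples = jr.normal(sample_key, (3,))           # e.g. draws from a distribution

# Re-deriving the loader key from the same seed reproduces the same shuffle
loader_key_again = jr.split(jr.PRNGKey(seed), 5)[2]
shuffle_order_again = jr.permutation(loader_key_again, 10)
```

Because the loader only ever sees loader_key, changing how the model or the sampler uses its key would not change the shuffle order.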
Thanks for your feedback. One workaround for now is to call manual_seed each time before initializing a dataloader. E.g.,
data_key, model_key, loader_key, train_key, sample_key = jr.split(key, 5)
# first loader
jdl.manual_seed(loader_key)
loader = jdl.DataLoader(dataset, 'jax', ...)
# second loader
jdl.manual_seed(train_key)
train_dl = jdl.DataLoader(dataset, 'jax', ...)
This should work for your use case, although I think this might not be the most elegant way of controlling the randomness.
Another idea is to introduce a Generator API; generator is already one of the parameters of the PyTorch DataLoader class. We could define jdl.DataLoader as something like this:
class DataLoader:
    def __init__(
        self,
        dataset,  # Dataset from which to load the data
        ...
        generator: jdl.Generator | jrand.PRNGKey | torch.Generator = None,  # <== Control the randomness here
        **kwargs
    ):
This proposal seems to be a better way to control the randomness, and it also aligns with the scope of jax-dataloader.
I will try to create a PR to implement this.
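As a rough illustration of the idea, a generator for the JAX backend could be a small stateful wrapper around a PRNG key. The sketch below is only an assumption about how such an object might behave: KeyGenerator, next_key, and shuffled_indices are hypothetical names, not the jdl.Generator API.

```python
import jax.random as jr


class KeyGenerator:
    """Hypothetical stateful wrapper around a JAX PRNG key."""

    def __init__(self, key):
        self._key = key

    def next_key(self):
        # Split the stored key: keep one half as the new state,
        # hand the other half to the caller
        self._key, subkey = jr.split(self._key)
        return subkey


def shuffled_indices(n, generator):
    # Hypothetical helper: draw a fresh key for every shuffle
    return jr.permutation(generator.next_key(), n)


gen = KeyGenerator(jr.PRNGKey(0))
epoch1 = shuffled_indices(8, gen)  # order for the first pass
epoch2 = shuffled_indices(8, gen)  # order for the second pass, drawn with a new subkey

# A second generator built from the same key replays the same sequence
replay = shuffled_indices(8, KeyGenerator(jr.PRNGKey(0)))
```

Keeping the key inside the generator means each dataloader can own its random stream, without relying on a global seed.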
This is addressed in #41
Thank you for your suggestion. Feel free to play around with this API and let me know if you encounter any issues.