Single Device prefetch_to_device
`prefetch_to_device` is great when doing distributed training, but can be a hassle for non-pmap use cases since you need to add and then remove a dummy sharding dimension at the front. A convenience function to prefetch data onto the device in single-GPU training would be helpful.
Also, `prefetch_to_device` returns `ShardedDeviceArray`s instead of `DeviceArray`s. Is there a difference, or are the two interchangeable? Thanks in advance!
Hi @n2cholas -- very sorry for the late reply! I'll rope in @jheek who has thought about this much more carefully.
@jheek is this something we are still planning on doing?
I'll close this for now; users can already "prefetch" something to a single device using `jax.device_put`. The use case seems somewhat niche, and prefetching is mostly handled automatically or manually through JAX APIs these days.
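To illustrate the suggested workaround: a single-device "prefetch" can be built from `jax.device_put` alone, relying on JAX's asynchronous dispatch so the host-to-device copy of the next batch overlaps with computation on the current one. This is a minimal sketch, not flax API; the helper name `device_put_prefetch` is made up for illustration:

```python
import jax
import numpy as np

def device_put_prefetch(iterator, device=None):
    """Yield batches from `iterator`, always keeping one batch in flight.

    `jax.device_put` dispatches the transfer asynchronously, so the copy of
    the next batch can overlap with work done on the current one.
    """
    device = device or jax.devices()[0]
    iterator = iter(iterator)
    try:
        ahead = jax.device_put(next(iterator), device)
    except StopIteration:
        return
    for batch in iterator:
        # Start transferring the next batch before yielding the current one.
        current, ahead = ahead, jax.device_put(batch, device)
        yield current
    yield ahead

# Usage: wrap any iterator of (pytrees of) numpy arrays.
batches = ({'x': np.full((8, 4), i, np.float32)} for i in range(3))
for batch in device_put_prefetch(batches):
    pass  # batch['x'] already lives on the device
```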
Hi @n2cholas,
Do you have any working code that uses `prefetch_to_device` with a single GPU? My dataset comes from TFDS and I am struggling with how to add the dummy dimension for `prefetch_to_device`.
I created a thread here. I'd much appreciate it if you could share some experience over there.
Thanks!