flax

Single Device prefetch_to_device

Open n2cholas opened this issue 4 years ago • 1 comment

prefetch_to_device is great for distributed training, but it can be a hassle for non-pmap use cases, since you need to add and then remove a dummy leading sharding dimension. A convenience function for prefetching data onto the device in single-GPU training would be helpful.

Also, prefetch_to_device returns ShardedDeviceArrays instead of DeviceArrays. Is there a difference, or are the two interchangeable? Thanks in advance!
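On current JAX versions the question has largely resolved itself: starting with JAX 0.4.1, DeviceArray and ShardedDeviceArray were unified into a single jax.Array type. A quick check, assuming a recent JAX install (the pmap call also works on a single CPU device):

```python
import jax
import jax.numpy as jnp
import numpy as np

# Result of a plain device_put (formerly a DeviceArray).
single = jax.device_put(np.ones(3, np.float32))

# Result of pmap over all local devices (formerly a ShardedDeviceArray).
n = jax.local_device_count()
sharded = jax.pmap(lambda v: v * 2)(np.ones((n, 3), np.float32))

# Both are instances of the same jax.Array type.
print(isinstance(single, jax.Array), isinstance(sharded, jax.Array))  # True True

# They can be mixed freely in ordinary jnp operations.
total = jnp.sum(single) + jnp.sum(sharded[0])
```

On older JAX releases the two classes were distinct, so this check only describes post-0.4.1 behavior.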

n2cholas, Dec 05 '20

Hi @n2cholas, very sorry for the late reply! I'll rope in @jheek, who has thought about this much more carefully.

avital, Dec 17 '20

@jheek, is this something we are still planning to do?

marcvanzee, Sep 06 '22

I'll close this for now; users can already "prefetch" data to a single device using jax.device_put. The use case seems somewhat niche, and these days prefetching is mostly handled either automatically or manually through JAX APIs.
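A minimal sketch of what manual single-device prefetching with jax.device_put can look like. The helper name `prefetch_to_single_device` is hypothetical; it keeps a small queue of batches already placed on the device, and relies on JAX dispatching the host-to-device copies asynchronously so that, on an accelerator, upcoming transfers can overlap with work on the current batch.

```python
import collections
import itertools

import jax
import numpy as np


def prefetch_to_single_device(iterator, size=2, device=None):
    """Hypothetical helper: keep up to `size` batches queued on one device."""
    iterator = iter(iterator)  # islice on a list would restart from the start
    device = jax.devices()[0] if device is None else device
    queue = collections.deque()

    def enqueue(n):
        # device_put accepts a whole pytree (e.g. a dict of NumPy arrays)
        # and typically returns before the transfer has finished.
        for batch in itertools.islice(iterator, n):
            queue.append(jax.device_put(batch, device))

    enqueue(size)
    while queue:
        yield queue.popleft()
        enqueue(1)


batches = ({"image": np.full((4, 3), i, np.float32), "label": np.arange(4)} for i in range(5))
for batch in prefetch_to_single_device(batches, size=2):
    print(batch["image"].shape, float(batch["image"][0, 0]))
```

No dummy leading axis is needed here, since device_put targets a single device directly.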

jheek, Sep 06 '22

Hi @n2cholas ,

Do you have any working code that uses prefetch_to_device with a single GPU? My dataset comes from TFDS, and I am struggling with how to add the dummy dimension for prefetch_to_device.

I created a thread here. I'd much appreciate it if you could share your experience over there.

Thanks!

davidshen84, Apr 19 '24