acme

Avoid inlining large arrays in JaxInMemoryRandomSampleIterator

Open: ethanluoyc opened this issue 2 years ago · 0 comments

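JaxInMemoryRandomSampleIterator currently inlines the in-memory dataset: the jitted sampling function closes over the dataset arrays, so they are embedded in the compiled computation as constants. See https://github.com/deepmind/acme/blob/master/acme/datasets/tfds.py#L199-L200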

This causes out-of-memory (OOM) errors because of an underlying XLA issue, and when running on GPU the process may also hang. I have filed a more detailed issue in the JAX repository (https://github.com/google/jax/issues/14080), and the JAX developers recommend not inlining the array. I can open a PR if the maintainers would like this fixed.
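For illustration, here is a minimal sketch contrasting the two patterns. It assumes a toy dict-of-arrays dataset and a simplified single-device sampler; `sample`, `dataset`, `batch_size` and `dataset_size` are hypothetical names and this is not Acme's actual implementation. The first jitted function closes over the device arrays (the inlining pattern); the second takes the dataset as an explicit argument, so the arrays are passed to the compiled computation at call time instead of being baked in.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical toy dataset: a dict of arrays sharing a leading dimension.
dataset_size = 1000
batch_size = 8
dataset = {
    "obs": np.arange(dataset_size * 4, dtype=np.float32).reshape(dataset_size, 4),
    "reward": np.arange(dataset_size, dtype=np.float32),
}


def sample(data, key):
  """Draws a random batch of rows from `data` and returns a fresh key."""
  key1, key2 = jax.random.split(key)
  indices = jax.random.randint(
      key1, (batch_size,), minval=0, maxval=dataset_size)
  batch = jax.tree_util.tree_map(lambda x: jnp.take(x, indices, axis=0), data)
  return batch, key2


# Move the whole dataset to the default device once.
device_data = jax.device_put(dataset)

# Inlining pattern: the lambda closes over `device_data`, so jit treats the
# arrays as constants of the traced function.
sample_inlined = jax.jit(lambda key: sample(device_data, key))

# Suggested alternative: the dataset is an explicit argument to the jitted
# function, so it is passed in as an input buffer on every call.
sample_as_arg = jax.jit(sample)

key = jax.random.PRNGKey(0)
batch, next_key = sample_as_arg(device_data, key)
```

Both versions compute the same batch for the same key; only how the dataset reaches the compiled computation differs. In an iterator class, the second form would mean storing the device arrays on the instance and passing them to the jitted function in `__next__`, rather than capturing them in a lambda.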

ethanluoyc, Jan 25 '23 12:01