
Simplify the multihost data loading code

Open · ZacCranko opened this issue 1 year ago • 0 comments

See if this is an improvement for your purposes. This PR changes the code that puts multihost data onto devices so that it infers the global array shapes and builds the NamedShardings lazily, at load time, instead of fixing them up front. This is cheap because JAX caches all the relevant quantities. I'm not sure whether jax.local_device_count() is cached, so I use len(global_mesh.local_devices) instead.
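A minimal sketch of the lazy approach, assuming a mesh whose batch axis is named `"data"`, that every process feeds the same number of rows, and that each process's local devices own one contiguous block of the batch. `put_local_batch` is a hypothetical helper written for illustration, not the code in this PR. It walks whatever pytree the loader yields, derives each array's global shape from its local shape and `len(mesh.local_devices)`, builds a `NamedSharding`, and assembles a global `jax.Array` from per-device shards:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def put_local_batch(local_batch, mesh, data_axis="data"):
    """Hypothetical helper: turn a pytree of process-local NumPy arrays into
    globally sharded jax.Arrays, inferring shapes/shardings at call time.

    Assumes every process supplies the same number of rows and that each
    process's local devices hold one contiguous block of the batch axis.
    """
    num_local = len(mesh.local_devices)          # devices owned by this process
    num_procs = mesh.devices.size // num_local   # processes spanning the mesh
    # Shard only the leading (batch) dimension over the data axis.
    sharding = NamedSharding(mesh, PartitionSpec(data_axis))

    def put_one(x):
        x = np.asarray(x)
        global_shape = (x.shape[0] * num_procs,) + x.shape[1:]
        # Global row range each addressable (local) device must hold.
        index_map = sharding.addressable_devices_indices_map(global_shape)
        bounds = {
            d: (idx[0].start or 0,
                global_shape[0] if idx[0].stop is None else idx[0].stop)
            for d, idx in index_map.items()
        }
        # First global row owned by this process; maps global rows to local ones.
        offset = min(start for start, _ in bounds.values())
        shards = [
            jax.device_put(x[start - offset:stop - offset], d)
            for d, (start, stop) in bounds.items()
        ]
        return jax.make_array_from_single_device_arrays(global_shape, sharding, shards)

    # tree_map means new or removed columns need no other code changes.
    return jax.tree_util.tree_map(put_one, local_batch)


mesh = Mesh(np.array(jax.devices()[:1]), ("data",))
batch = {
    "inputs": np.arange(24, dtype=np.float32).reshape(8, 3),
    "targets": np.arange(8, dtype=np.int32),
}
out = put_local_batch(batch, mesh)
```

Because shapes and shardings are computed per leaf inside `tree_map`, adding a column to `batch` changes nothing else in the sketch.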

I made this change in my personal fork because it lets me change which columns the data loader yields without editing a lot of code in input_pipeline.py.

ZacCranko, Nov 13 '23 04:11