maxtext
Simplify the multihost data loading code
See if this is an improvement for your purposes. This PR modifies the multihost data-put code to infer the global shapes and build `NamedSharding`s lazily at load time. This is cheap because all the relevant quantities are cached by JAX. I'm not sure whether `jax.local_process_count()` is cached, so I'm using `len(global_mesh.local_devices)` instead.
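A minimal sketch of the idea, not the code in this PR: it assumes each loader batch is a dict of process-local NumPy arrays whose leading dimension is the per-process batch, that every process contributes the same number of devices, and that the batch dimension is sharded over all mesh axes. `form_global_array` and `put_batch` are hypothetical names used only for illustration.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def form_global_array(local_array: np.ndarray, mesh: Mesh) -> jax.Array:
    """Hypothetical helper: turn one process-local column into a global jax.Array."""
    local_devices = mesh.local_devices
    num_local = len(local_devices)
    # Process count inferred from the mesh (assumes equal devices per process).
    num_processes = mesh.devices.size // num_local
    # Global shape is inferred lazily from the local array itself.
    global_shape = (local_array.shape[0] * num_processes,) + local_array.shape[1:]
    # Shard the leading (batch) dimension over every mesh axis.
    sharding = NamedSharding(mesh, PartitionSpec(mesh.axis_names))
    # One slice per local device; np.split raises if the batch does not divide evenly.
    shards = np.split(local_array, num_local, axis=0)
    buffers = [jax.device_put(s, d) for s, d in zip(shards, local_devices)]
    return jax.make_array_from_single_device_arrays(global_shape, sharding, buffers)


def put_batch(local_batch, mesh: Mesh):
    """Hypothetical helper: apply form_global_array to whatever columns the loader yields."""
    return jax.tree_util.tree_map(
        lambda x: form_global_array(np.asarray(x), mesh), local_batch
    )
```

Because the shape and sharding are derived per leaf, adding or dropping a column in the loader needs no matching change on the device-put side.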
I made this change in my personal fork because it lets me change which columns come from the data loader without modifying a bunch of code in `input_pipeline.py`.