axlearn
Improve Pathways checkpoint load times
- Use shared memory between the JAX client and the Pathways proxy for data-heavy transfers, e.g. device_put calls.
- Increase the number of ThreadPoolExecutor threads from 32 (Python default) to 192.
- Remove the memory limit from the Pathways head main container.
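
As a rough illustration of the thread-count change, the sketch below sizes a pool explicitly instead of relying on the Python default. It is not the code from this change: `load_shard` is a hypothetical stand-in for an I/O-bound read, and only the `max_workers=192` value comes from the description above.

```python
import os
from concurrent.futures import ThreadPoolExecutor

# Python's default pool size is min(32, cpu_count + 4), so it never exceeds 32.
default_workers = min(32, (os.cpu_count() or 1) + 4)


def load_shard(index: int) -> int:
    """Hypothetical placeholder for an I/O-bound shard read."""
    return index * 2


# Pass max_workers explicitly to allow more concurrent I/O-bound tasks.
with ThreadPoolExecutor(max_workers=192) as pool:
    results = list(pool.map(load_shard, range(1000)))
```

A larger pool only helps when the tasks spend most of their time waiting on I/O; CPU-bound work would still be serialized by the GIL.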
Callers of deserialize should set concurrent_restore_gb as high as possible without running out of memory (OOM). Otherwise, GCS reads and device_put calls won't happen in parallel. The default of 32 GB is too low to achieve optimal performance with Pathways.
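
To see why a small budget limits parallelism, here is a minimal sketch of a byte-budget limiter. It assumes the restore path gates each shard on an in-flight byte budget, which is an assumption about the mechanism rather than the actual deserialize implementation. `ByteBudget`, `restore`, and the shard sizes are hypothetical names and values chosen for illustration.

```python
import asyncio

GB = 1 << 30


class ByteBudget:
    """Hypothetical limiter: callers wait until enough bytes are available."""

    def __init__(self, max_bytes: int):
        self._max_bytes = max_bytes
        self._available = max_bytes
        self._cv = asyncio.Condition()

    async def acquire(self, nbytes: int) -> None:
        if nbytes > self._max_bytes:
            raise ValueError("request exceeds the total budget")
        async with self._cv:
            await self._cv.wait_for(lambda: self._available >= nbytes)
            self._available -= nbytes

    async def release(self, nbytes: int) -> None:
        async with self._cv:
            self._available += nbytes
            self._cv.notify_all()


async def restore(shard_sizes: list[int], budget_bytes: int) -> int:
    """Restores all shards and returns the peak number of shards in flight."""
    budget = ByteBudget(budget_bytes)
    in_flight = 0
    peak = 0

    async def restore_one(nbytes: int) -> None:
        nonlocal in_flight, peak
        await budget.acquire(nbytes)
        in_flight += 1
        peak = max(peak, in_flight)
        try:
            # Stands in for the storage read followed by the device transfer.
            await asyncio.sleep(0.01)
        finally:
            in_flight -= 1
            await budget.release(nbytes)

    await asyncio.gather(*(restore_one(n) for n in shard_sizes))
    return peak


shards = [4 * GB] * 8
peak_low = asyncio.run(restore(shards, 4 * GB))    # budget fits one shard
peak_high = asyncio.run(restore(shards, 32 * GB))  # budget fits all shards
```

With the small budget only one shard is in flight at a time, so reads and transfers are serialized. With the larger budget all eight shards overlap.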