axlearn

Improve pathways checkpoint load times

Open · samos123 opened this issue 3 months ago · 0 comments

  • Use shared memory between the JAX client and the Pathways proxy for data-heavy transfers, e.g. `device_put` calls.
  • Increase the number of `ThreadPoolExecutor` threads from 32 (Python's default cap) to 192.
  • Remove the memory limit from the Pathways head's main container.
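The thread-count change can be sketched as follows. When `max_workers` is omitted, `ThreadPoolExecutor` uses `min(32, os.cpu_count() + 4)` workers (Python 3.8 to 3.12; 3.13+ uses `os.process_cpu_count()` instead), so 32 is an upper bound rather than a fixed value. Checkpoint reads are I/O-bound and release the GIL, so more threads than cores can help. This is a minimal sketch, assuming a blocking per-shard read; `read_shard`, `restore_shards`, and `RESTORE_MAX_WORKERS` are hypothetical names, not axlearn's API.

```python
import os
from concurrent.futures import ThreadPoolExecutor

# Approximates the stdlib default when max_workers is None (3.8-3.12 formula).
DEFAULT_MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)

# Thread count proposed in this issue; hypothetical constant name.
RESTORE_MAX_WORKERS = 192


def read_shard(shard_id: int) -> bytes:
    """Hypothetical stand-in for a blocking, I/O-bound GCS read of one shard."""
    return f"shard-{shard_id}".encode()


def restore_shards(shard_ids, max_workers: int = RESTORE_MAX_WORKERS) -> list:
    # Threads are spawned lazily, up to max_workers; blocking I/O releases the
    # GIL, so many reads can be outstanding at once. map() preserves order.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(read_shard, shard_ids))
```

For example, `restore_shards(range(4))` returns `[b"shard-0", b"shard-1", b"shard-2", b"shard-3"]`.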

Callers of `deserialize` should set `concurrent_restore_gb` as high as possible without running out of memory (OOM). Otherwise, GCS reads and `device_put` won't overlap. The default of 32 GB is too low to achieve optimal performance with Pathways.
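The sketch below shows why the byte budget matters. A limiter caps the bytes that are read but not yet transferred to device. With a budget of only one shard, each shard's read and `device_put` run strictly one after another. With a larger budget, one shard's `device_put` can overlap another shard's GCS read. This is a hypothetical illustration in plain `asyncio`, not axlearn's `deserialize` implementation. `InFlightByteLimiter`, `restore_shard`, `restore_all`, and the 8 GiB `SHARD_BYTES` are assumed names and sizes, and the sleeps stand in for real I/O and transfers.

```python
import asyncio

SHARD_BYTES = 8 * 2**30  # hypothetical shard size: 8 GiB


class InFlightByteLimiter:
    """Hypothetical limiter: caps bytes read but not yet placed on device."""

    def __init__(self, max_bytes: int):
        self.max_bytes = max_bytes
        self.in_flight = 0
        self.peak = 0  # highest number of bytes in flight at once
        self._cond = asyncio.Condition()

    async def acquire(self, nbytes: int) -> None:
        if nbytes > self.max_bytes:
            raise ValueError("a single shard exceeds the concurrency budget")
        async with self._cond:
            # Block until this shard fits within the remaining budget.
            await self._cond.wait_for(lambda: self.in_flight + nbytes <= self.max_bytes)
            self.in_flight += nbytes
            self.peak = max(self.peak, self.in_flight)

    async def release(self, nbytes: int) -> None:
        async with self._cond:
            self.in_flight -= nbytes
            self._cond.notify_all()  # wake waiters that may now fit


async def restore_shard(limiter: InFlightByteLimiter, nbytes: int) -> None:
    await limiter.acquire(nbytes)
    try:
        await asyncio.sleep(0.01)  # stand-in for the GCS read
        await asyncio.sleep(0.01)  # stand-in for device_put
    finally:
        await limiter.release(nbytes)


async def restore_all(num_shards: int, concurrent_restore_gb: int) -> int:
    # Create the limiter inside the running event loop.
    limiter = InFlightByteLimiter(concurrent_restore_gb * 2**30)
    await asyncio.gather(*(restore_shard(limiter, SHARD_BYTES) for _ in range(num_shards)))
    return limiter.peak // SHARD_BYTES  # max shards in flight at once
```

Under these assumed sizes, `asyncio.run(restore_all(8, 32))` returns 4: the 32 GiB budget admits four 8 GiB shards at a time. `asyncio.run(restore_all(8, 8))` returns 1, which means no overlap at all.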

samos123, Sep 23 '25 16:09