xla

[JAX] Automatically share PGO data for GPU latency-hiding scheduler.

Open: copybara-service[bot] opened this issue 1 year ago (0 comments)


Overall, the idea is to collect profile data for each module a given (configurable) number of times, then recompile the module with the aggregated profile data.
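The collect-then-aggregate idea can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the ProfileSessionRunner from profile.py: the class name `PgoCollector`, its `profile_runs` parameter, the profile format (a dict from instruction name to latency in microseconds), and averaging as the aggregation rule are all hypothetical.

```python
from collections import defaultdict
from typing import Callable, Dict

Profile = Dict[str, float]  # instruction name -> latency in microseconds (assumed format)


class PgoCollector:
    """Hypothetical helper: profile a module a fixed number of times, then aggregate."""

    def __init__(self, profile_runs: int):
        if profile_runs < 1:
            raise ValueError("profile_runs must be at least 1")
        self.profile_runs = profile_runs  # the configurable number of profiling passes
        self.calls = 0
        self._sums: Dict[str, float] = defaultdict(float)
        self._counts: Dict[str, int] = defaultdict(int)

    @property
    def done(self) -> bool:
        return self.calls >= self.profile_runs

    def run(self, execute_and_profile: Callable[[], Profile]) -> None:
        """Run one profiled execution and fold its latencies into the totals."""
        if self.done:
            return  # enough samples collected; stop profiling
        profile = execute_and_profile()
        self.calls += 1
        for name, latency in profile.items():
            self._sums[name] += latency
            self._counts[name] += 1

    def aggregate(self) -> Profile:
        """Mean latency per instruction over all collected runs, keys sorted."""
        return {name: self._sums[name] / self._counts[name] for name in sorted(self._sums)}


# Usage: two fake profiling runs, then the aggregated profile that would feed recompilation.
samples = iter([
    {"fusion.1": 10.0, "all-reduce.2": 30.0},
    {"fusion.1": 14.0, "all-reduce.2": 26.0},
])
collector = PgoCollector(profile_runs=2)
while not collector.done:
    collector.run(lambda: next(samples))
print(collector.aggregate())  # {'all-reduce.2': 28.0, 'fusion.1': 12.0}
```

Once `done` is true, further `run` calls are no-ops, so the module is never profiled more than the configured number of times.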

  1. We need to track how many times each module was profiled and collect the profiling results. For this, I added a ProfileSessionRunner class in profile.py. The class tracks how many times an instance of it has been called to profile a session, and it can also aggregate the profile results.

  2. We need to associate a profiling session with the module at the interpreter level. To do this, I added a dictionary to pjit.py that maps each Jaxpr to its profile session runner.

  3. The profile session runner is then passed to pxla.py and called there.

  4. We need to handle the fast path correctly at the interpreter level, so that JAX does not use the HLO directly while PGLE data still needs to be collected, but also does not recompile the module solely for PGLE. See the changes in pjit.py and lru_cache.h.

  5. Once the FDO profile is collected, we need to share it between hosts to keep compilation deterministic.

copybara-service[bot], Mar 28 '24 10:03