Add CLI for profiling JAX scripts with output and backend options
For any module that calls start_trace() from _src/profiler.py, this allows the profiler trace destination to be specified from the command line, e.g. JAX_PROFILE = output.pb.
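For context, here is a minimal sketch of a script that reaches start_trace() through the public jax.profiler API. The temporary log directory is only for illustration. With the change proposed in this PR, setting JAX_PROFILE would redirect the trace away from that directory. The sketch unsets JAX_PROFILE first, so it behaves the same whether or not the change is installed.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Ensure the proposed JAX_PROFILE override is inactive, so the trace lands in
# log_dir regardless of whether this PR's change is installed.
os.environ.pop("JAX_PROFILE", None)

# Illustrative destination; any writable directory works.
log_dir = tempfile.mkdtemp(prefix="jax-trace-")

jax.profiler.start_trace(log_dir)
x = jnp.ones((512, 512))
(x @ x).block_until_ready()  # force the computation to run while tracing
jax.profiler.stop_trace()     # exports the collected trace under log_dir

print(sorted(os.listdir(log_dir)))
```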
Thanks for the contribution!
I think #20293 had a slightly different idea in mind, where profiling is enabled via an environment variable. The version in this PR might work fine for open-source users, but unfortunately it is not compatible as-is with Google-internal infra.
@superbobry, would something like this suffice?
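import os

from jax._src.lib import xla_client

def start_trace(log_dir, create_perfetto_link: bool = False,
                create_perfetto_trace: bool = False) -> xla_client.profiler.ProfilerSession:
  # ... existing code above ...
  # If JAX_PROFILE is set to a non-empty value, it overrides log_dir.
  jax_profile_output = os.getenv('JAX_PROFILE')
  if jax_profile_output:
    log_dir = jax_profile_output
  # ... existing code below ...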
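The precedence rule in the snippet above can be checked without starting a profiler. The sketch below pulls it into a standalone function. The name resolve_log_dir and the injectable environ argument are hypothetical, added only so the logic can be exercised without touching the real process environment. Because of the truthiness check, an empty JAX_PROFILE is treated the same as an unset one.

```python
import os
from typing import Mapping, Optional


# Hypothetical helper; the proposal above inlines this logic in start_trace().
def resolve_log_dir(log_dir: str,
                    environ: Optional[Mapping[str, str]] = None) -> str:
  """Returns the trace destination, letting a non-empty JAX_PROFILE win."""
  env = os.environ if environ is None else environ
  override = env.get("JAX_PROFILE")
  # Mirrors `if jax_profile_output:` so "" falls back to log_dir.
  return override if override else log_dir


print(resolve_log_dir("/tmp/default", {}))                           # /tmp/default
print(resolve_log_dir("/tmp/default", {"JAX_PROFILE": ""}))          # /tmp/default
print(resolve_log_dir("/tmp/default", {"JAX_PROFILE": "/tmp/out"}))  # /tmp/out
```

Passing the mapping explicitly keeps tests hermetic; inside start_trace() the default os.environ would be used.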
Yep, something like this.
Cool, thanks for your patience! I've updated the PR with these changes, so please feel free to review/merge.
Hello! I was looking for a first issue to work on, and it seems this one has already been completed. I don't want to be annoying; this is just a reminder for @superbobry to review it and either merge or close it :)