
Add CLI for profiling JAX scripts with output and backend options

Open nirmalmuppiri opened this issue 1 year ago • 5 comments

For any module that calls start_trace() from _src/profiler.py, this change allows the profiler trace destination to be specified from the command line, e.g. JAX_PROFILE=output.pb.

nirmalmuppiri, Jun 05 '24 09:06

Thanks for the contribution!

I think #20293 implied a slightly different idea, where profiling is enabled via an environment variable. The version in this PR might work fine for open-source users, but unfortunately it is not compatible as-is with Google-internal infrastructure.

superbobry, Jun 06 '24 18:06

@superbobry, do you mean something like this? Would it suffice?

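import os

def start_trace(log_dir, create_perfetto_link: bool = False,
                create_perfetto_trace: bool = False) -> xla_client.profiler.ProfilerSession:
    # ... existing code above ...

    # If the JAX_PROFILE environment variable is set (and non-empty),
    # it overrides the log_dir argument as the trace destination.
    jax_profile_output = os.getenv('JAX_PROFILE')
    if jax_profile_output:
        log_dir = jax_profile_output

    # ... existing code below ...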
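For illustration, the override logic can be pulled out into a small pure function that is easy to unit-test. This is a minimal sketch, not the code in this PR: `resolve_trace_dir` and its `env` parameter are hypothetical names, and it assumes an empty `JAX_PROFILE` should behave like an unset one, as the `if jax_profile_output:` check does.

```python
import os
from typing import Mapping, Optional, Union

# Name of the environment variable discussed in this PR.
JAX_PROFILE_ENV = "JAX_PROFILE"


def resolve_trace_dir(log_dir: Union[str, os.PathLike],
                      env: Optional[Mapping[str, str]] = None) -> str:
    """Return the trace destination, letting $JAX_PROFILE override log_dir.

    Hypothetical helper mirroring the start_trace() override sketch;
    `env` defaults to os.environ and can be a plain dict in tests.
    """
    environ = os.environ if env is None else env
    override = environ.get(JAX_PROFILE_ENV)
    # An unset or empty variable leaves the caller's log_dir untouched.
    if override:
        return override
    return os.fspath(log_dir)


if __name__ == "__main__":
    print(resolve_trace_dir("/tmp/trace", env={}))
    print(resolve_trace_dir("/tmp/trace", env={"JAX_PROFILE": "/tmp/env-trace"}))
    print(resolve_trace_dir("/tmp/trace", env={"JAX_PROFILE": ""}))
```

With JAX installed, a caller could pass the resolved path to `jax.profiler.start_trace(...)` and finish with `jax.profiler.stop_trace()`. Injecting `env` keeps the helper testable without mutating the real process environment.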


nirmalmuppiri, Jun 08 '24 16:06

Yep, something like this.

superbobry, Jun 12 '24 10:06

Cool, thanks for your patience! I've pushed these changes, so please feel free to review/merge.

nirmalmuppiri, Jun 12 '24 13:06

Hello! I was looking for a first issue to work on, and it seems this one has already been completed. I don't want to be annoying, just reminding @superbobry to either merge it or decline it :)

pedrochans, Aug 21 '24 11:08