CPU profiling (not tracing)

Open joaospinto opened this issue 1 year ago • 4 comments

I want to get some flamegraphs from some JIT'd JAX CPU code to understand where time is being spent (in terms of my user-defined functions).

My understanding (based on the docs) is that currently the recommended approach is to add some custom tracing events and run JAX's tracing feature.

This seems rather suboptimal. Is there a better way?

Related discussion: https://github.com/jax-ml/jax/discussions/19888

joaospinto avatar Oct 16 '24 21:10 joaospinto

One difficulty with this feature request is that after a function is compiled into HLO, information about the original Python function boundaries is lost. So we would not be able to automatically generate a profile that is broken down by user-defined function. You can inspect the compiled code yourself by running jax.jit(f).lower(*args).compiler_ir('hlo').
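
As a rough sketch of that inspection step, the snippet below prints both the lowered IR and the XLA-optimized HLO for a small function. It uses `as_text()` on the lowered and compiled objects rather than `compiler_ir('hlo')`, which is a substitution made here for illustration; `layer` and `model` are hypothetical user functions, not anything from the thread.

```python
import jax
import jax.numpy as jnp


def layer(x, w):
    # Hypothetical user-defined function, for illustration only.
    return jnp.tanh(x @ w)


def model(x, w):
    # Only `model` is jitted; `layer` is traced inline as part of it.
    return layer(x, w).sum()


x = jnp.ones((8, 4))
w = jnp.ones((4, 4))

lowered = jax.jit(model).lower(x, w)

# Text of the IR that JAX hands to XLA, before XLA's optimizations.
lowered_text = lowered.as_text()

# Text of the HLO after XLA has compiled and optimized it for this backend.
compiled_text = lowered.compile().as_text()

print(lowered_text)
print(compiled_text)
```

Reading the two dumps side by side gives a feel for how much of the Python structure survives lowering and how much XLA rearranges afterwards.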

One workaround for this could be to decorate all of your user functions using jax.named_scope. After this, they should be visible in the trace viewer (https://jax.readthedocs.io/en/latest/profiling.html#tensorboard-profiling). It's not automatic, but it shouldn't be too much of an overhead.
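
A minimal sketch of that workaround, assuming `jax.named_scope` is applied as a decorator and the trace is captured with `jax.profiler.trace`. The function names, scope names, and the `capture_trace` helper are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp


@jax.named_scope("normalize")  # hypothetical scope name
def normalize(x):
    return (x - x.mean()) / (x.std() + 1e-6)


@jax.named_scope("project")  # hypothetical scope name
def project(x, w):
    return jnp.tanh(x @ w)


@jax.jit
def step(x, w):
    return project(normalize(x), w).sum()


def capture_trace(log_dir, x, w):
    """Hypothetical helper: profile one call of `step`, writing to `log_dir`."""
    # Warm-up call so that compilation time is kept out of the trace.
    step(x, w).block_until_ready()
    with jax.profiler.trace(log_dir):
        step(x, w).block_until_ready()


x = jnp.arange(32.0).reshape(8, 4)
w = jnp.ones((4, 2))
result = step(x, w)
```

Calling `capture_trace("/tmp/jax-trace", x, w)` (the path is an arbitrary example) writes a trace directory that can then be opened in the trace viewer described in the linked docs. The scope names are attached to the operations traced inside each decorated function, so the annotation cost is one decorator per function.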

justinjfu avatar Oct 17 '24 18:10 justinjfu

> One difficulty with this feature request is that after a function is compiled into HLO, information about the original Python function boundaries is lost.

There are several ways of exporting HLO/StableHLO from JAX, and many (certainly the StableHLO MLIR bytecode portable artifacts) do export location information (which maps HLO/StableHLO ops to the Python code that created them).

joaospinto avatar Oct 17 '24 19:10 joaospinto

For example, this can be used (although it might not be the most compact representation):

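# `ir` is assumed to be the MLIR module of the lowered function,
# e.g. ir = jax.jit(f).lower(*args).compiler_ir()
with open("output.hlo", "w") as f:
  ir.operation.print(
    enable_debug_info=True,
    pretty_debug_info=True,
    use_local_scope=True,
    file=f,
  )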
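
For reference, here is a self-contained sketch of the snippet above. It assumes `ir` comes from `jax.jit(...).lower(...).compiler_ir()` and writes to an in-memory buffer instead of a file; the `project` and `model` functions and the `projection_scope` name are hypothetical, added only to show that a `jax.named_scope` name shows up in the printed location info.

```python
import io

import jax
import jax.numpy as jnp


@jax.named_scope("projection_scope")  # hypothetical scope name
def project(x, w):
    return jnp.tanh(x @ w)


def model(x, w):
    return project(x, w).sum()


x = jnp.ones((8, 4))
w = jnp.ones((4, 2))

# MLIR module for the lowered computation (assumed source of `ir`).
ir = jax.jit(model).lower(x, w).compiler_ir()

# Same print options as in the snippet above, but into a string buffer.
buf = io.StringIO()
ir.operation.print(
    enable_debug_info=True,
    pretty_debug_info=True,
    use_local_scope=True,
    file=buf,
)
asm = buf.getvalue()

# Lines whose location info mentions the named scope.
scoped_lines = [line for line in asm.splitlines() if "projection_scope" in line]
print("\n".join(scoped_lines))
```

The filtered lines are the ops traced inside `project`, which suggests the location info is enough to group ops by user-defined scope after the fact.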

joaospinto avatar Oct 17 '24 19:10 joaospinto

Related discussion: https://github.com/jax-ml/jax/issues/23251

joaospinto avatar Oct 18 '24 01:10 joaospinto

@jakevdp @gnecula @justinjfu Any thoughts?

joaospinto avatar Oct 25 '24 18:10 joaospinto