
Add profiler traces to `PjRtComputationClient`

will-cromar opened this issue 2 years ago • 6 comments

The existing profiler works out of the box with PJRT because the traces work the same way as under XRT. Tested manually with TensorBoard.


will-cromar, Aug 10 '22

The new test passed on CPU but flaked on GPU:

2022-08-10 23:44:46.343597: W  562435 tensorflow/core/profiler/lib/profiler_session.cc:107] Profiling is late by 1594605 nanoseconds and will start immediately.
2022-08-10 23:44:48.736212: W  562436 tensorflow/core/profiler/rpc/client/capture_profile.cc:133] No trace event is collected from localhost:9012

will-cromar, Aug 11 '22

The GPU test logged "Profiling is late by 1594605 nanoseconds", whereas the CPU test logged "Profiling is late by 760266 nanoseconds". Maybe the XLA execution had already finished by the time tracing started?

will-cromar, Aug 11 '22

There is some lag between when the tracer thread starts and when it actually begins tracing, long enough that XLA execution can finish before tracing begins. This test only passed when xm.xla_device() and torch.ones() were slow enough to give the tracer a head start, and I was able to reproduce the error consistently by moving them before the point where the tracer thread starts.

Adding a half-second delay makes the test pass consistently on my machine. Hopefully that holds up on the CI 🤞
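The race can be illustrated with a small simulation. This is a hypothetical sketch using only Python's standard library: FakeTracer and run_test are illustrative names, not torch_xla APIs, and the 0.2 s startup lag and 1 s window are made-up numbers. The simulated tracer only records events that arrive after its startup lag, so a workload that finishes too quickly is missed unless the main thread sleeps first, mirroring the half-second delay described above.

```python
import threading
import time


class FakeTracer:
    """Hypothetical stand-in for a profiler: it only records events that
    arrive after its startup lag and before its trace window closes."""

    def __init__(self, startup_lag_s, window_s):
        self.startup_lag_s = startup_lag_s
        self.window_s = window_s
        self.events = []  # names of events captured while tracing
        self._active = False
        self._lock = threading.Lock()

    def run(self):
        time.sleep(self.startup_lag_s)  # lag between thread start and tracing
        with self._lock:
            self._active = True
        time.sleep(self.window_s)  # the tracing window
        with self._lock:
            self._active = False

    def record(self, name):
        # Events outside the tracing window are silently dropped.
        with self._lock:
            if self._active:
                self.events.append(name)


def run_test(delay_s, startup_lag_s=0.2, window_s=1.0):
    tracer = FakeTracer(startup_lag_s, window_s)
    thread = threading.Thread(target=tracer.run)
    thread.start()
    time.sleep(delay_s)  # the workaround: give the tracer a head start
    tracer.record("xla_execution")  # stands in for a fast XLA computation
    thread.join()
    return tracer.events


if __name__ == "__main__":
    print(run_test(delay_s=0.0))  # runs before tracing starts, so it is missed
    print(run_test(delay_s=0.5))  # runs inside the tracing window
```

A fixed sleep only narrows the race rather than closing it: if the real startup lag on a given machine exceeds the delay, the event is still missed, which may explain why the test kept flaking on CI.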

will-cromar, Aug 11 '22

The test still flaked on the GPU CI job (even though the test itself runs on CPU), and I can't reproduce the error locally. I'm removing it from the CI tests, since I'll have to add a TPU version outside of CI anyway.

will-cromar, Aug 12 '22

How is this test different from https://github.com/pytorch/xla/blob/master/test/test_profiler.py, which we already have? It seems like that test runs MNIST. If you switch the runtime from XRT to PJRT, would test_profiler.py still pass?

For the TPU test, can you make sure the device-side traces also show up?

JackCaoG, Aug 12 '22

This test specifically checks for the PJRT traces I added in this PR. This test is also more isolated (no dependency on another test) and uses a much smaller graph, so it's actually possible to read what is in the trace file if there's an error. The other test is more of an end-to-end test, while I tried to write a minimal integration test here.

I tried running the other test with PJRT and it did not work. I would need to add a call to pjrt.run_multiprocess to set up the TPU environment variables, set the world size, etc.

I did manually check that the TPU device traces are in the output.

will-cromar, Aug 12 '22

@will-cromar Thanks for adding this feature! We tried using the profiler under PJRT in our applications and observed a few issues:

  1. Under PJRT, when we start a profiling server with xp.start_server on a v3-8 running the PJRT ImageNet example (by adding server = xp.start_server(9012) at the beginning of train_imagenet), it only captures the device traces of 2 TPU cores. On XRT, by contrast, we could get the device traces of all 8 TPU cores.

A possible workaround is to start multiple profiler servers, one per process, with server = xp.start_server(9012 + xm.get_local_ordinal(), only_on_master=False), to capture traces on more TPU cores.

We capture traces from the command line as follows (in a separate tmux window on the TPU VM host):

TRACE_DIR=~/workspace/pt_xla_tracing/tracing_pjrt
PORT=9012

sudo mkdir -p ${TRACE_DIR} && sudo chmod -R a+w ${TRACE_DIR}
python_cmd="import torch_xla.debug.profiler as xp; xp.trace('localhost:${PORT}', '${TRACE_DIR}')"
echo "$python_cmd" && python -c "$python_cmd"

We then view the captured traces in a TensorBoard session on a remote coordinator VM.

  2. Under XRT, we can propagate the xp.Trace annotations to the TPU device traces by setting XLA_HLO_DEBUG=1 (as described in Annotation propagation). Under PJRT, the annotations do not seem to reach the TPU device traces, even with XLA_HLO_DEBUG=1.

  3. We are happy to find that the profiler works on a TPU pod (unlike XRT, where the profiler doesn't work on a pod, as we reported in https://github.com/pytorch/xla/issues/3446). However, on a v3-128 pod it still only captures the device traces of 2 TPU cores.

  4. A related bug: on TPU v3, xm.get_local_ordinal only returns a value in 0,1,2,3, so two TPU v3 cores on different chips of the same TPU v3 board could have the same local ordinal. We expect it to return a value in 0,1,2,3,4,5,6,7 on TPU v3, as it does under XRT. I believe we should take the modulus by 8 instead of 4, i.e. use 8 as local_world_size on TPU v3 in https://github.com/pytorch/xla/blob/5059c3b13b5c190d2871ccd004caff89117ae2af/torch_xla/experimental/pjrt.py#L160
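The local-ordinal collision described above can be illustrated with a hypothetical sketch. It assumes, for illustration only, that each TPU v3 host has 8 cores with consecutive global indices, that cores 2k and 2k+1 share chip k, and that the local ordinal is computed as global_ordinal % local_world_size. local_ordinal and find_collisions are illustrative helpers, not the torch_xla implementation.

```python
def local_ordinal(global_ordinal, local_world_size):
    # Hypothetical: derive a per-host core index from a global core index.
    return global_ordinal % local_world_size


def find_collisions(host, cores_per_host, local_world_size):
    """Group the cores on one host by computed local ordinal and return
    only the ordinals that more than one core shares."""
    first = host * cores_per_host
    groups = {}
    for core in range(first, first + cores_per_host):
        groups.setdefault(local_ordinal(core, local_world_size), []).append(core)
    return {o: cores for o, cores in groups.items() if len(cores) > 1}


if __name__ == "__main__":
    # Assumed TPU v3 layout: 4 chips x 2 cores = 8 cores per host.
    print(find_collisions(host=0, cores_per_host=8, local_world_size=4))
    # {0: [0, 4], 1: [1, 5], 2: [2, 6], 3: [3, 7]}
    print(find_collisions(host=0, cores_per_host=8, local_world_size=8))
    # {}
```

Under these assumptions, a modulus of 4 maps cores on different chips (for example cores 0 and 4, on chips 0 and 2) to the same local ordinal, while a modulus of 8 gives every core on the host a distinct value in 0..7.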

Below is a screenshot illustrating the traces we captured with XLA_HLO_DEBUG=1 under PJRT on a v3-128 TPU, running our ImageNet FSDP test case.

[Screenshot: pjrt_on_pod]

ronghanghu, Sep 01 '22

I have a feeling that only seeing one chip (2 cores) is due to the following:

  1. In XRT, there is a single process (XRT_SERVER) driving execution on the whole TPU device.
  2. In PJRT on v3, each process only drives TPU execution for 1 chip (2 cores).

There should be a way around this, though.

JackCaoG, Sep 01 '22

  1. @JackCaoG is correct. Each process only has two cores, so this is expected. I also expect starting multiple profiler servers and tracing all of them to show more cores, although I think we need a better API to make this use case easier in practice.
  2. I wasn't aware of this use case. Can you file an issue for me and I can follow up?
  3. Awesome! I'm glad this works without having to implement anything extra (except a better profiler API that I alluded to in 1) 😁
  4. Good catch. I'm mainly doing my testing on TPU v4 so this slipped by. I need to refactor how we store these ordinals in #3949 anyway, so I'll make sure to fix this issue as part of that change.
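Until a better API exists, capturing from every per-process server (the xp.start_server(9012 + xm.get_local_ordinal(), only_on_master=False) workaround ronghanghu describes) could be scripted along these lines. This is a sketch under assumptions: it assumes server i listens on localhost:(9012 + i), that a v3-8 runs 4 processes (one per chip), and that the trace function takes the same (service_addr, logdir) arguments as the xp.trace command shown earlier. The tracing function is passed in as a parameter so the sketch runs without torch_xla; trace_all_servers and the per-port subdirectories are hypothetical.

```python
import os
import threading


def trace_all_servers(trace_fn, base_port, num_servers, logdir):
    """Capture from every per-process profiler server at the same time.

    Assumes server i listens on localhost:(base_port + i). Each capture is
    written to its own subdirectory of logdir so the runs stay separate.
    """
    addrs = [f"localhost:{base_port + i}" for i in range(num_servers)]
    threads = []
    for i, addr in enumerate(addrs):
        subdir = os.path.join(logdir, f"port_{base_port + i}")
        os.makedirs(subdir, exist_ok=True)  # like the `mkdir -p` step above
        # Each trace call is assumed to block for its capture duration, so
        # run them in parallel threads to cover the same time window.
        thread = threading.Thread(target=trace_fn, args=(addr, subdir))
        thread.start()
        threads.append(thread)
    for thread in threads:
        thread.join()
    return addrs


if __name__ == "__main__":
    # On a TPU VM with torch_xla installed, one could pass the real client:
    #   import torch_xla.debug.profiler as xp
    #   trace_all_servers(xp.trace, 9012, 4, os.path.expanduser(
    #       "~/workspace/pt_xla_tracing/tracing_pjrt"))
    # Here a stub stands in for xp.trace so the sketch runs anywhere.
    import tempfile

    def fake_trace(service_addr, logdir):
        print(f"tracing {service_addr} -> {logdir}")

    with tempfile.TemporaryDirectory() as tmp:
        trace_all_servers(fake_trace, 9012, 4, tmp)
```

Writing each server's capture to its own subdirectory is a design choice here, made to avoid concurrent captures sharing one directory; whether TensorBoard should then be pointed at one subdirectory or at the parent depends on how you want to browse the per-process traces.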

will-cromar, Sep 01 '22

> @JackCaoG is correct. Each process only has two cores, so this is expected. I also expect starting multiple profiler servers and tracing all of them to show more cores, although I think we need a better API to make this use case easier in practice.

I see, this makes sense to me. (I was thinking of JAX on PJRT, where jax.profiler.trace could capture the device traces of all 8 TPU v3 cores, but I guess the difference is that JAX runs in an SPMD manner :)

> I wasn't aware of this use case. Can you file an issue for me and I can follow up?

Sounds good, I'll file a new issue on this. The use case is mainly propagating the xp.Trace annotations (similar to jax.profiler.TraceAnnotation) onto the "TensorFlow Name Scope" annotations in the TPU device trace. This worked under XRT when we set XLA_HLO_DEBUG=1.

Below is a TPU trace that we captured earlier with the XRT runtime in PT/XLA, where the xp.Trace annotations such as "Forward model" can be propagated to TPU device traces under "TensorFlow Name Scope".

[Screenshots: Screen Shot 2022-09-01 at 12.54.51 PM; Screen Shot 2022-09-01 at 12.55.03 PM]

> I need to refactor how we store these ordinals in #3949 anyway, so I'll make sure to fix this issue as part of that change.

Awesome, thank you!

ronghanghu, Sep 01 '22