Add profiler traces to `PjRtComputationClient`
The existing profiler works out of the box because the traces work the same way. Tested manually with TensorBoard:
The new test passed on CPU but flaked on GPU:
```
2022-08-10 23:44:46.343597: W 562435 tensorflow/core/profiler/lib/profiler_session.cc:107] Profiling is late by 1594605 nanoseconds and will start immediately.
2022-08-10 23:44:48.736212: W 562436 tensorflow/core/profiler/rpc/client/capture_profile.cc:133] No trace event is collected from localhost:9012
```
The GPU test has `Profiling is late by 1594605 nanoseconds` vs CPU with `Profiling is late by 760266 nanoseconds`. Maybe the XLA execution had already finished by the time tracing started?
There is some lag between when the tracer thread starts and when it actually starts tracing, long enough that XLA execution can finish before tracing begins. This test was only passing when `xm.xla_device()` and `torch.ones()` were sufficiently slow to give the tracer a head start, and I was able to reproduce the error consistently by moving them before the tracer thread started.
Adding a half-second delay makes the test pass consistently on my machine. Hopefully that holds up on the CI 🤞
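For reference, a minimal sketch of the shape of such a test (illustrative only, not the exact test in this PR; the port, tensor size, and 0.5 s delay are arbitrary choices): the trace is captured from a background thread, and the short sleep gives the tracer time to actually start before the XLA work is dispatched.

```python
import tempfile
import threading
import time

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

if __name__ == '__main__':
  logdir = tempfile.mkdtemp()
  server = xp.start_server(9012)  # keep a reference so the server stays alive

  # Capture a short trace from a background thread while the main thread
  # runs the XLA computation.
  tracer = threading.Thread(
      target=xp.trace,
      args=('localhost:9012', logdir),
      kwargs={'duration_ms': 2000})
  tracer.start()
  time.sleep(0.5)  # give the tracer a head start; it begins tracing with some lag

  device = xm.xla_device()
  ones = torch.ones([32, 32], device=device)
  xm.mark_step()  # force execution so the PJRT trace events are emitted

  tracer.join()
  # ...then assert that the expected PJRT trace events appear under `logdir`.
```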
Test still flaked on GPU (even though it's using CPU) and I can't reproduce the error locally. Removing it from the CI tests since I'll have to add a TPU version outside of the CI anyway.
How is this test different from https://github.com/pytorch/xla/blob/master/test/test_profiler.py we already have? It seems like that test runs MNIST. If you switch the runtime from XRT to PJRT, would `test_profiler.py` still pass?
For the TPU test, can you make sure the device-side traces also show up?
This test specifically checks for the PJRT traces I added in this PR. This test is also more isolated (no dependency on another test) and uses a much smaller graph, so it's actually possible to read what is in the trace file if there's an error. The other test is more of an end-to-end test, while I tried to write a minimal integration test here.
I tried running the other test with PJRT and it did not work. I would need to add a call to `pjrt.run_multiprocess` to set up the TPU environment variables, set the world size, etc.
I did manually check that the TPU device traces are in the output.
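For illustration, a rough sketch of what that wrapping might look like (hypothetical, not code from this PR; `_test_body` is a stand-in for the existing test's body, and the exact `run_multiprocess` calling convention here is an assumption):

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt


def _test_body():
  # The existing test body would go here; run_multiprocess runs it in child
  # processes with the PJRT TPU environment variables already configured.
  device = xm.xla_device()
  torch.ones([4, 4], device=device)
  xm.mark_step()


if __name__ == '__main__':
  # Spawns the workers and sets up the per-process TPU environment and world size.
  pjrt.run_multiprocess(_test_body)
```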
@will-cromar Thanks for adding this feature! We tried using the profiler under PJRT in our applications and observed a few issues:
1. Here in PJRT, when starting a profiling server with `xp.start_server` on a v3-8 running the PJRT ImageNet example (by adding `server = xp.start_server(9012)` at the beginning of `train_imagenet`), it only captures the device traces of 2 TPU cores. This is unlike XRT, where we could get the device traces of all 8 TPU cores. A possible workaround is to start multiple profiler servers with `server = xp.start_server(9012 + xm.get_local_ordinal(), only_on_master=False)` to capture traces on more TPU cores (see the sketch after this list).
   We are capturing traces from the command line as follows (in a separate tmux window on the TPU VM host):

   ```bash
   TRACE_DIR=~/workspace/pt_xla_tracing/tracing_pjrt
   PORT=9012
   sudo mkdir -p ${TRACE_DIR} && sudo chmod -R a+w ${TRACE_DIR}
   python_cmd="import torch_xla.debug.profiler as xp; xp.trace('localhost:${PORT}', '${TRACE_DIR}')"
   echo "$python_cmd" && python -c "$python_cmd"
   ```

   and then viewing the captured traces from a TensorBoard session on a remote coordinator VM.
2. Unlike in the XRT case, where we can propagate the `xp.Trace` annotations to TPU device traces with `XLA_HLO_DEBUG=1` (as described in Annotation propagation), the annotations do not seem to propagate to the TPU device traces captured under PJRT, even with `XLA_HLO_DEBUG=1`.
3. We are happy to find that the profiler works on a TPU pod (unlike the XRT case, where the profiler doesn't work on a pod, as we reported in https://github.com/pytorch/xla/issues/3446). Meanwhile, on a v3-128 pod, it still only captures the device traces of 2 TPU cores.
4. A related bug is that `xm.get_local_ordinal` only returns a value in 0, 1, 2, 3 on TPU v3 (and two TPU v3 cores that are on different chips in the TPU v3 board could have the same `xm.get_local_ordinal`), whereas it is expected to return a value in 0, 1, 2, 3, 4, 5, 6, 7 on TPU v3, as in XRT. I believe we should take the modulo by 8 instead of 4 as `local_world_size` on TPU v3 in https://github.com/pytorch/xla/blob/5059c3b13b5c190d2871ccd004caff89117ae2af/torch_xla/experimental/pjrt.py#L160
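A minimal sketch of the per-process workaround mentioned in point 1 (illustrative; the port offset and `only_on_master=False` come from the snippet above, the helper name and everything else are assumptions):

```python
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp


def start_per_process_profiler(base_port=9012):
  # Under PJRT on v3, each process drives one chip (2 cores), so start one
  # profiler server per process, offset by the local ordinal, and trace each
  # port separately to cover more cores.
  port = base_port + xm.get_local_ordinal()
  return xp.start_server(port, only_on_master=False)


# Inside each training process:
server = start_per_process_profiler()
# Each port can then be captured separately, e.g.:
#   python -c "import torch_xla.debug.profiler as xp; xp.trace('localhost:<port>', '<logdir>')"
```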
Below is a screenshot illustrating the traces we captured with `XLA_HLO_DEBUG=1` under PJRT on a v3-128 TPU, running our ImageNet FSDP test case.
I have a feeling that only seeing one chip (2 cores) is due to:
- In XRT, there is a single process (XRT_SERVER) driving the whole TPU device execution.
- In PJRT on v3, each process only drives the TPU execution for 1 chip (2 cores).

There should be a way around this, though.
1. @JackCaoG is correct. Each process only has two cores, so this is expected. I also expect starting multiple profiler servers and tracing all of them to show more cores, although I think we need a better API to make this use case easier in practice.
2. I wasn't aware of this use case. Can you file an issue for me and I can follow up?
3. Awesome! I'm glad this works without having to implement anything extra (except a better profiler API that I alluded to in 1) 😁
4. Good catch. I'm mainly doing my testing on TPU v4, so this slipped by. I need to refactor how we store these ordinals in #3949 anyway, so I'll make sure to fix this issue as part of that change.
> @JackCaoG is correct. Each process only has two cores, so this is expected. I also expect starting multiple profiler servers and tracing all of them to show more cores, although I think we need a better API to make this use case easier in practice.

I see, this makes sense to me. (I was thinking about the case of JAX on PJRT, where its `jax.profiler.trace` could capture the device traces of all 8 TPU v3 cores, but I guess the difference is that JAX runs in an SPMD manner :)
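For comparison, the JAX pattern referred to here is roughly the following (a sketch; the log directory and the computation are arbitrary):

```python
import jax
import jax.numpy as jnp

# jax.profiler.trace writes a TensorBoard-compatible profile that, per the
# observation above, includes device traces for all local TPU cores.
with jax.profiler.trace('/tmp/jax-trace'):
  x = jnp.ones((1024, 1024))
  (x @ x).block_until_ready()
```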
> I wasn't aware of this use case. Can you file an issue for me and I can follow up?

Sounds good, I'll file a new issue on this. This is mostly to propagate the `xp.Trace` annotations (similar to `jax.profiler.TraceAnnotation`) onto the "TensorFlow Name Scope" annotations on the TPU device trace. It was working under XRT when we set `XLA_HLO_DEBUG=1`.
Below is a TPU trace that we captured earlier with the XRT runtime in PT/XLA, where the `xp.Trace` annotations such as "Forward model" can be propagated to TPU device traces under "TensorFlow Name Scope".
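For context, the annotation pattern in question looks roughly like this (a sketch assuming a generic training step; only `xp.Trace`, the "Forward model" label, and `XLA_HLO_DEBUG=1` come from the discussion above):

```python
# Launched with: XLA_HLO_DEBUG=1 python train.py
import torch_xla.debug.profiler as xp


def train_step(model, batch):
  # Under XRT with XLA_HLO_DEBUG=1, these scope names appeared in the device
  # trace under "TensorFlow Name Scope"; under PJRT they currently do not.
  with xp.Trace('Forward model'):
    output = model(batch)
  with xp.Trace('Backward'):
    loss = output.sum()
    loss.backward()
  return loss
```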
> I need to refactor how we store these ordinals in #3949 anyway, so I'll make sure to fix this issue as part of that change.

Awesome, thank you!