XLA profiler issues: 1) TPU device trace does not show trace annotations, and 2) PJRT only captures 2 cores
🐛 Bug
The XLA profiler has been a crucial tool for performance analysis on TPUs. In https://github.com/pytorch/xla/pull/3866, the profiler was also integrated into the PJRT runtime. However, the problems described below still hinder applying the XLA profiler to real-world machine learning models for performance debugging. (This issue is submitted as a follow-up to the discussion in https://github.com/pytorch/xla/pull/3866#issuecomment-1234716910, and is part of the challenges we are facing with TPU profiling and performance debugging in benchmark models and our internal use cases.)
Specifically:
- The XLA profiler is expected to propagate the xp.Trace annotations to "TensorFlow Name Scope" on the TPU device traces when running with XLA_HLO_DEBUG=1 (as described in Annotation propagation). This feature is important for performance analysis of a neural network model (e.g. to show the time consumption of different submodules), and was working well under PT/XLA 1.10 (with the tpu-vm-pt-1.10 TPU VM environment); a minimal sketch of this kind of annotation is included right after this list.
  - However, with the nightly PyTorch/XLA wheels, under both the PJRT and XRT runtimes, the profiler now fails to propagate the trace annotations from xp.Trace to the TPU device traces with XLA_HLO_DEBUG=1. This suggests that something related to tracing and HLO annotation was broken by changes between PT/XLA 1.10 and the current nightly version.
- Under the PJRT runtime, the XLA profiler can only capture the device traces of 2 cores (instead of 8) on TPU v3, whereas under XRT the traces of all 8 cores can be captured on a v3-8 TPU VM.
  - Meanwhile, it's good to see that under the PJRT runtime the XLA profiler can now also capture TPU device traces on a pod (although it still captures only 2 TPU v3 cores on e.g. a v3-128 pod).
(Apart from the two problems described above, there is also a long-standing problem that, under the XRT runtime, the XLA profiler cannot capture TPU device traces when running on a pod, as mentioned in https://github.com/pytorch/xla/issues/3446. But I guess that problem is less important if we are committed to transitioning to PJRT :)
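For reference, here is a minimal sketch of the kind of annotation this is about (the model, shapes, and scope names are just illustrative; the actual test case is the reproduction script below). With XLA_HLO_DEBUG=1, the nested xp.Trace scopes are expected to show up as nested "TensorFlow Name Scope" rows in the TPU device trace:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

device = xm.xla_device()
linear = torch.nn.Linear(128, 128).to(device)
x = torch.zeros(4, 128, device=device)

# With XLA_HLO_DEBUG=1, the scopes below are expected to appear as
# "block_0/linear" under "TensorFlow Name Scope" in the device trace.
with xp.Trace("block_0"):
    with xp.Trace("linear"):
        y = linear(x)
xm.mark_step()  # materialize the traced graph on the TPU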
To Reproduce
The example below contains a simple XLA profiler test case for both XRT and PJRT on a v3-8 TPU VM. The observations are:
- Under the XRT runtime with PT/XLA 1.10 (TPU VM runtime tpu-vm-pt-1.10), the XLA profiler can capture traces of all 8 TPU cores, and can show the trace annotations under "TensorFlow Name Scope" on the TPU device traces.
- Under the XRT runtime with PT/XLA nightly, the XLA profiler can capture traces of all 8 TPU cores, but cannot show the trace annotations under "TensorFlow Name Scope" on the TPU device traces.
- Under the PJRT runtime with PT/XLA nightly, the XLA profiler can only capture traces of 2 TPU cores, and cannot show the trace annotations under "TensorFlow Name Scope" on the TPU device traces.
To reproduce, save the following content to a file ./test_profile_pjrt_vs_xrt.py. Then one can reproduce each case as follows.
import argparse

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp


class DummyModelForProfiling(torch.nn.Module):

    def __init__(self, num_layers=64, dim=4096):
        super().__init__()
        self.num_layers = num_layers
        self.dim = dim
        self.fc = torch.nn.Linear(self.dim, self.dim, bias=False)

    def forward(self, x):
        with xp.Trace("forward_pass"):
            for n_layer in range(self.num_layers):
                with xp.Trace(f"layer_{n_layer}"):
                    x = self.fc(x)
        return x


def run_dummy_model(*args, **kwargs):
    server = xp.start_server(9012)  # start a profiling server

    # a dummy model with a dummy input
    device = xm.xla_device()
    model = DummyModelForProfiling().to(device)
    x = torch.zeros(256, model.dim, device=device)
    xm.mark_step()

    # run 1 million steps to capture profile
    for step in range(1000000):
        print(f"step {step}")
        with xp.StepTrace("run_dummy_model", step_num=step):
            loss = model(x).sum()
            loss = xm.all_reduce(xm.REDUCE_SUM, loss)  # force a sync between TPUs
            # note that `xp.StepTrace` implicitly calls `xm.mark_step` upon its __exit__
        step += 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--runtime", type=str, required=True, choices=("xrt", "pjrt"))
    args = parser.parse_args()
    print(f"running w/ {args.runtime}")
    if args.runtime == "pjrt":
        from torch_xla.experimental import pjrt
        pjrt.run_multiprocess(run_dummy_model)
    else:
        import torch_xla.distributed.xla_multiprocessing as xmp
        xmp.spawn(run_dummy_model)
XRT runtime with PT/XLA 1.10
- Allocate a new v3-8 TPU VM with the tpu-vm-pt-1.10 runtime environment.
- Run this file in a tmux window with XLA_HLO_DEBUG=1:
XLA_HLO_DEBUG=1 \
python3 ./test_profile_pjrt_vs_xrt.py --runtime xrt
- Capture a profile from localhost:9012 in another tmux window on the TPU VM (saved to TRACE_DIR below):
TRACE_DIR=~/workspace/pt_xla_tracing/dummy_trace/xrt110
PORT=9012
sudo mkdir -p ${TRACE_DIR} && sudo chmod -R a+w ${TRACE_DIR}
python_cmd="import torch_xla.debug.profiler as xp; xp.trace('localhost:${PORT}', '${TRACE_DIR}')"
echo "$python_cmd" && python3 -c "$python_cmd"
- Finally, view the captured trace files in tensorboard:
TRACE_DIR=~/workspace/pt_xla_tracing/dummy_trace/xrt110
tensorboard --logdir ${TRACE_DIR} --port 6016
Below are the screenshots from tensorboard in this case.
And below is a zoom-in view of "TensorFlow Name Scope", which corresponds to our xp.Trace annotations in the PyTorch model:
As can be seen from the screenshot above, under the XRT runtime with PT/XLA 1.10 (TPU VM runtime tpu-vm-pt-1.10), the XLA profiler can capture traces of all 8 TPU cores, and can show the trace annotations under "TensorFlow Name Scope" on the TPU device traces.
XRT runtime with PT/XLA nightly
- Allocate a new v3-8 TPU VM with the tpu-vm-pt-1.12 runtime environment and install the nightly 20220829 wheels as follows:
# torch, torchvision and torch_xla
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220829-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220829-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220829-cp38-cp38-linux_x86_64.whl
# libtpu
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220623-py3-none-any.whl
# no longer needed for tpu-vm-pt-1.12 TPU VM runtime
# sudo pip3 install numpy==1.23.0
- Run this file in a tmux window with XLA_HLO_DEBUG=1:
XLA_HLO_DEBUG=1 \
python3 ./test_profile_pjrt_vs_xrt.py --runtime xrt
- Capture a profile from localhost:9012 in another tmux window on the TPU VM (saved to TRACE_DIR below):
TRACE_DIR=~/workspace/pt_xla_tracing/dummy_trace/xrt
PORT=9012
sudo mkdir -p ${TRACE_DIR} && sudo chmod -R a+w ${TRACE_DIR}
python_cmd="import torch_xla.debug.profiler as xp; xp.trace('localhost:${PORT}', '${TRACE_DIR}')"
echo "$python_cmd" && python3 -c "$python_cmd"
- Finally, view the captured trace files in tensorboard:
TRACE_DIR=~/workspace/pt_xla_tracing/dummy_trace/xrt
tensorboard --logdir ${TRACE_DIR} --port 6006
Below are the screenshots from tensorboard in this case.
As can be seen from the screenshot above, under XRT runtime with PT/XLA nightly, the XLA profiler can capture traces of all 8 TPU cores, but cannot show the trace annotations under "TensorFlow Name Scope" on the TPU device traces.
PJRT runtime with PT/XLA nightly
- Allocate a new v3-8 TPU VM with the tpu-vm-pt-1.12 runtime environment and install the nightly 20220829 wheels as follows (this part is the same as in "XRT runtime with PT/XLA nightly" above):
# torch, torchvision and torch_xla
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220829-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220829-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220829-cp38-cp38-linux_x86_64.whl
# libtpu
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220623-py3-none-any.whl
# no longer needed for tpu-vm-pt-1.12 TPU VM runtime
# sudo pip3 install numpy==1.23.0
- Run this file in a tmux window with PJRT_DEVICE=TPU and XLA_HLO_DEBUG=1:
PJRT_DEVICE=TPU XLA_HLO_DEBUG=1 \
python3 ./test_profile_pjrt_vs_xrt.py --runtime pjrt
- Capture a profile from localhost:9012 in another tmux window on the TPU VM (saved to TRACE_DIR below):
TRACE_DIR=~/workspace/pt_xla_tracing/dummy_trace/pjrt
PORT=9012
sudo mkdir -p ${TRACE_DIR} && sudo chmod -R a+w ${TRACE_DIR}
python_cmd="import torch_xla.debug.profiler as xp; xp.trace('localhost:${PORT}', '${TRACE_DIR}')"
echo "$python_cmd" && python3 -c "$python_cmd"
- Finally, view the captured trace files in tensorboard:
TRACE_DIR=~/workspace/pt_xla_tracing/dummy_trace/pjrt
tensorboard --logdir ${TRACE_DIR} --port 6026
Below are the screenshots from tensorboard in this case.
As can be seen from the screenshot above, under PJRT runtime with PT/XLA nightly, the XLA profiler can only capture traces of 2 TPU cores, and cannot show the trace annotations under "TensorFlow Name Scope" on the TPU device traces.
Expected behavior
- Under both PJRT and XRT, the XLA profiler should show the traces (annotated via xp.Trace) under "TensorFlow Name Scope" on the TPU device traces when running with XLA_HLO_DEBUG=1, as described in Annotation propagation. This feature is quite important for matching the PyTorch modules with TPU traces in performance analysis.
- Under the PJRT runtime, the XLA profiler should ideally capture the TPU traces of all 8 TPU v3 cores on a v3-8 VM.
Environment
- Reproducible on XLA backend [CPU/TPU]: v3-8 TPU VM with tpu-vm-pt-1.10 or tpu-vm-pt-1.12 runtime (see details above)
- torch_xla version: PT/XLA 1.10 or nightly 20220829 (see details above)
cc: @will-cromar @JackCaoG
@ronghanghu I think I figured out what went wrong with xp.Trace. It used to be that you could set either XLA_IR_DEBUG or XLA_HLO_DEBUG and you would get the scope information. However, in one of the recent changes for LTC, we moved scope_pusher to upstream. Upstream controls the metadata in https://github.com/pytorch/pytorch/blob/30fb2c4abaaaa966999eab11674f25b18460e609/torch/csrc/lazy/core/ir_metadata.cpp#L94 and in pytorch/xla we do the env var mapping in https://github.com/pytorch/xla/blob/28ea3758e9586ef8cc22270a16e1ddb4a21aa6f7/torch_xla/csrc/init_python_bindings.cpp#L747. During this process, we only mapped XLA_IR_DEBUG and forgot XLA_HLO_DEBUG. A quick workaround is to also set XLA_IR_DEBUG=1 during profiling. I will work on a fix right away.
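In the launch setup used in this issue, that workaround just means exporting both variables for the training process (e.g. prefixing the python3 command with XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1). A minimal in-process sketch, assuming the flags are read when torch_xla is first imported:

import os

# Workaround sketch: set both debug flags before torch_xla is imported, so the
# xp.Trace scope metadata is recorded (assumption: the flags are read at import
# or first use, so exporting them in the launching shell works the same way).
os.environ["XLA_IR_DEBUG"] = "1"
os.environ["XLA_HLO_DEBUG"] = "1"

import torch_xla.core.xla_model as xm  # noqa: E402
import torch_xla.debug.profiler as xp  # noqa: E402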
Hmm, even with this issue fixed, something else is still going on. Let me keep looking into this.
@JackCaoG Thanks for looking into this!
A quick workaround is to also set XLA_IR_DEBUG=1 during profiling.
Earlier I tried setting both XLA_IR_DEBUG=1 and XLA_HLO_DEBUG=1, and it still cannot propagate the xp.Trace annotations to the TPU device traces. (I also tried adding PT_XLA_DEBUG=1, but it hangs the program.)
This is weird. Now with XLA_IR_DEBUG=1 and XLA_HLO_DEBUG=1, I do see the scope information being pushed to the profiler.
However, it seems like the profiler cannot understand it. If we look at the 1.10 profile I see
and if we click on the scope above I can see
I will look into why the profiler cannot resolve the tf_op field into a scope.
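One way to double-check that the scope information survives lowering, independent of the profiler UI, is to dump the lazy IR / HLO text of a traced tensor and look for the scope / op_name metadata. A sketch (this assumes the _get_xla_tensors_text / _get_xla_tensors_hlo debug helpers exposed by torch_xla._XLAC, which is what I believe the debugging docs use):

import os

os.environ["XLA_IR_DEBUG"] = "1"
os.environ["XLA_HLO_DEBUG"] = "1"

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

device = xm.xla_device()
linear = torch.nn.Linear(32, 32).to(device)
x = torch.zeros(2, 32, device=device)

with xp.Trace("block_0"):
    with xp.Trace("linear"):
        y = linear(x)

# With the debug flags set, the xp.Trace scopes should show up in the per-op
# metadata of the dumped IR/HLO text (e.g. scope=... / op_name=...).
print(torch_xla._XLAC._get_xla_tensors_text([y]))
print(torch_xla._XLAC._get_xla_tensors_hlo([y]))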
https://github.com/tensorflow/tensorflow/commit/306904197c95cc01cdcd30462fd62984329f5cef might be the cause of this issue; I will keep looking tomorrow. The information is there, so I am guessing we just need to tweak the format a bit for it to display correctly.
Update on this: it seems that the PJRT profiling no longer works on a pod in the torch_xla 20220916 wheel, although it was working in the 20220829 torch_xla wheel.
I was trying to run the profiling example above on a v3-128 pod by installing the nightly wheels following "PJRT runtime with PT/XLA nightly" and capturing profile as follows:
- Run the example script above (downloaded to ~/workspace/pt_xla_tracing/test_profile_pjrt_vs_xrt.py) on all VMs:
TPU_NAME=ronghang-v3-128
ZONE=europe-west4-a
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} \
--worker all --command "
PJRT_DEVICE=TPU XLA_HLO_DEBUG=1 \
python3 ~/workspace/pt_xla_tracing/test_profile_pjrt_vs_xrt.py --runtime pjrt
"
- Then capture a profile trace on the first TPU VM host in the pod:
TRACE_DIR=~/workspace/pt_xla_tracing/dummy_trace/pjrt_pod_v3-128
PORT=9012
sudo mkdir -p ${TRACE_DIR} && sudo chmod -R a+w ${TRACE_DIR}
python_cmd="import torch_xla.debug.profiler as xp; xp.trace('localhost:${PORT}', '${TRACE_DIR}')"
echo "$python_cmd" && python3 -c "$python_cmd"
The procedure above could successfully capture a TPU profile using the nightly 20220829 wheels, but after switching to the nightly 20220916 wheels (and still keeping "libtpu_nightly-0.1.dev20220623"), it cannot capture a TPU profile in the second step above. It prints:
No trace event is collected after 3 attempt(s). Perhaps, you want to try again (with more attempts?).
I guess there is a chance that a few recent changes could have broken the profiler in PJRT on a pod.
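(In case it helps to rule out a too-short capture window when retrying on the pod, the capture helper can also be driven with an explicit duration and more attempts. duration_ms and num_tracing_attempts are the keyword names I believe torch_xla.debug.profiler.trace accepts, so treat this as a sketch rather than an exact signature.)

import torch_xla.debug.profiler as xp

# Sketch: capture ~10 s of trace from the profiling server on this host,
# retrying a few more times than the default if no events are collected.
xp.trace(
    "localhost:9012",
    "/tmp/pt_xla_tracing/dummy_trace/pjrt_pod",  # example log directory
    duration_ms=10000,
    num_tracing_attempts=5,
)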
@ronghanghu do you know if this still works on donut?
I walked through the changes merged in the last 3 weeks; the only one that seems relevant is https://github.com/pytorch/xla/commit/0aed7b648fa4ca19f9f2ec5916da216eaeec5837
@ronghanghu do you know if this still works on donut?
Yes, the PJRT profiling still works on a donut -- I can capture the device trace with the script above on a v3-8 VM, following the commands in this issue, using the 20220916 wheel. But the pod case (as mentioned above) seems broken in the nightly 20220916 wheels (although it was working with the 20220829 wheels).
Hmm, this is really weird and we should fix it before the release cut. @will-cromar can you take a look when you have time?