
XLA profiler issues: 1) TPU device trace does not show trace annotations, and 2) PJRT only captures 2 cores

Status: Open • ronghanghu opened this issue 2 years ago • 10 comments

🐛 Bug

The XLA profiler has been a crucial tool for performance analysis on TPUs. In https://github.com/pytorch/xla/pull/3866, the profiler was also integrated into the PJRT runtime. However, the problems described below still hinder applying the XLA profiler to real-world machine learning models for performance debugging. (This issue is submitted as a follow-up to the discussion in https://github.com/pytorch/xla/pull/3866#issuecomment-1234716910, and is part of the challenges we're facing with TPU profiling and performance debugging in benchmark models and our internal use cases.)

Specifically:

  • The XLA profiler is expected to propagate the xp.Trace annotations to "TensorFlow Name Scope" on the TPU device traces when running with XLA_HLO_DEBUG=1 (as described in Annotation propagation) -- This feature is important for performance analysis of a neural network model (e.g. to show the time consumption of different submodules), and was working well under PT/XLA 1.10 (with tpu-vm-pt-1.10 TPU VM environment).
    • However, with the nightly PyTorch/XLA wheels, under both PJRT and XRT runtime, the profiler now fails to propagate the trace annotations from xp.Trace to TPU device traces with XLA_HLO_DEBUG=1. This suggests that something related to tracing and HLO annotation could be broken due to changes between PT/XLA 1.10 and the current nightly version.
  • Under the PJRT runtime, the XLA profiler can only capture the device traces of 2 cores (instead of 8 cores) on TPU v3, whereas earlier in XRT the traces of all 8 cores can be captured on a v3-8 TPU VM.
    • Meanwhile, it's really good to see that under PJRT runtime the XLA profiler can also capture TPU device traces on a pod now (although it still captures only 2 TPU v3 cores on e.g. a v3-128 pod).
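For scale: on TPU v2/v3 the number in the accelerator type counts TensorCores, and each host has 4 chips with 2 TensorCores each, so a v3-8 is a single host with 8 cores and a v3-128 is 16 such hosts. The sketch below uses hypothetical helper names (nothing here is part of torch_xla) and assumes that a profile captured from one host should cover that host's 8 cores; under that assumption, the 2-core PJRT traces cover a quarter of what is expected.

```python
CORES_PER_HOST = 8  # assumption: a TPU v2/v3 host has 4 chips x 2 TensorCores


def tensorcores(accelerator_type: str) -> int:
    """Total TensorCores in a v2/v3 slice, e.g. 'v3-8' -> 8 (hypothetical helper)."""
    generation, sep, count = accelerator_type.partition("-")
    if generation not in ("v2", "v3") or not sep or not count.isdigit():
        raise ValueError(f"unsupported accelerator type: {accelerator_type!r}")
    return int(count)


def num_hosts(accelerator_type: str) -> int:
    """Number of TPU VM hosts in the slice."""
    return max(1, tensorcores(accelerator_type) // CORES_PER_HOST)


def host_trace_coverage(captured_cores: int, accelerator_type: str) -> float:
    """Fraction of one host's cores that appear in a single-host device trace."""
    expected = min(tensorcores(accelerator_type), CORES_PER_HOST)
    return captured_cores / expected


for acc in ("v3-8", "v3-128"):
    print(acc, num_hosts(acc), host_trace_coverage(2, acc))
# v3-8 1 0.25
# v3-128 16 0.25
```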

(Apart from the two problems described above, there is also a long-standing problem that under the XRT runtime the XLA profiler cannot capture TPU device traces when running on a pod, as mentioned in https://github.com/pytorch/xla/issues/3446. But I guess that problem is less important if we are committed to transitioning to PJRT :)

To Reproduce

The example below contains a simple XLA profiler test case for both XRT and PJRT on a v3-8 TPU VM. The observations are:

  1. Under XRT runtime with PT/XLA 1.10 (TPU VM runtime tpu-vm-pt-1.10), the XLA profiler can capture traces of all 8 TPU cores, and can show the trace annotations under "TensorFlow Name Scope" on the TPU device traces.
  2. Under XRT runtime with PT/XLA nightly, the XLA profiler can capture traces of all 8 TPU cores, but cannot show the trace annotations under "TensorFlow Name Scope" on the TPU device traces.
  3. Under PJRT runtime with PT/XLA nightly, the XLA profiler can only capture traces of 2 TPU cores, and cannot show the trace annotations under "TensorFlow Name Scope" on the TPU device traces.

To reproduce, save the following content to a file ./test_profile_pjrt_vs_xrt.py. Then one can reproduce each case as follows.

import argparse
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp


class DummyModelForProfiling(torch.nn.Module):
    def __init__(self, num_layers=64, dim=4096):
        super().__init__()
        self.num_layers = num_layers
        self.dim = dim
        self.fc = torch.nn.Linear(self.dim, self.dim, bias=False)

    def forward(self, x):
        with xp.Trace("forward_pass"):
            for n_layer in range(self.num_layers):
                with xp.Trace(f"layer_{n_layer}"):
                    x = self.fc(x)
        return x


def run_dummy_model(*args, **kwargs):
    server = xp.start_server(9012)  # start a profiling server
    # a dummy model with a dummy input
    device = xm.xla_device()
    model = DummyModelForProfiling().to(device)
    x = torch.zeros(256, model.dim, device=device)
    xm.mark_step()

    # run 1 million steps to capture profile
    for step in range(1000000):
        print(f"step {step}")
        with xp.StepTrace("run_dummy_model", step_num=step):
            loss = model(x).sum()
            loss = xm.all_reduce(xm.REDUCE_SUM, loss)  # force a sync between TPUs
            # note that `xp.StepTrace` implicitly calls `xm.mark_step` upon its __exit__


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--runtime", type=str, required=True, choices=("xrt", "pjrt"))
    args = parser.parse_args()
    print(f"running w/ {args.runtime}")
    if args.runtime == "pjrt":
        from torch_xla.experimental import pjrt

        pjrt.run_multiprocess(run_dummy_model)
    else:
        import torch_xla.distributed.xla_multiprocessing as xmp

        xmp.spawn(run_dummy_model)
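Conceptually, the nested xp.Trace blocks in DummyModelForProfiling.forward should tag each traced op with a scope path such as forward_pass/layer_3, which is what is expected to surface under "TensorFlow Name Scope". The pure-Python model below only illustrates that nesting; the class and method names are hypothetical, and the real scope tracking (in C++, now upstream in PyTorch's lazy tensor core) may decorate scope names, e.g. with counters.

```python
from contextlib import contextmanager
from typing import Iterator, List, Tuple


class ScopeTracker:
    """Hypothetical stand-in for scope tracking: nested names form a '/'-joined path."""

    def __init__(self) -> None:
        self._stack: List[str] = []
        self.ops: List[Tuple[str, str]] = []  # (scope path, op name)

    @contextmanager
    def trace(self, name: str) -> Iterator[None]:
        self._stack.append(name)
        try:
            yield
        finally:
            self._stack.pop()  # leave the scope even if the body raises

    def record_op(self, op_name: str) -> None:
        # Each op is tagged with the scope path active when it is traced.
        self.ops.append(("/".join(self._stack), op_name))


def fake_forward(tracker: ScopeTracker, num_layers: int = 3) -> None:
    # Mirrors the nesting in DummyModelForProfiling.forward, with fewer layers.
    with tracker.trace("forward_pass"):
        for n_layer in range(num_layers):
            with tracker.trace(f"layer_{n_layer}"):
                tracker.record_op("fc")


tracker = ScopeTracker()
fake_forward(tracker)
for scope, op in tracker.ops:
    print(scope, op)
# forward_pass/layer_0 fc
# forward_pass/layer_1 fc
# forward_pass/layer_2 fc
```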

XRT runtime with PT/XLA 1.10

  1. Allocate a new v3-8 TPU VM with tpu-vm-pt-1.10 runtime environment.

  2. Run this file in a tmux window with XLA_HLO_DEBUG=1:

XLA_HLO_DEBUG=1 \
python3 ./test_profile_pjrt_vs_xrt.py --runtime xrt
  3. Capture a profile from localhost:9012 in another tmux window on the TPU VM (saved to TRACE_DIR below):
TRACE_DIR=~/workspace/pt_xla_tracing/dummy_trace/xrt110
PORT=9012

sudo mkdir -p ${TRACE_DIR} && sudo chmod -R a+w ${TRACE_DIR}
python_cmd="import torch_xla.debug.profiler as xp; xp.trace('localhost:${PORT}', '${TRACE_DIR}')"
echo "$python_cmd" && python3 -c "$python_cmd"
  4. Finally, view the captured trace files in TensorBoard:
TRACE_DIR=~/workspace/pt_xla_tracing/dummy_trace/xrt110
tensorboard --logdir ${TRACE_DIR} --port 6016
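Step 3 can also be done from Python. The sketch below is a hypothetical helper, not part of torch_xla: it creates the trace directory with os.makedirs (assuming the user can write there, so no sudo is needed) and then calls xp.trace. The duration_ms and num_tracing_attempts keywords reflect our understanding of the xp.trace signature and should be checked with help(xp.trace) on the installed version. The trace function is injectable so the helper can be exercised without a TPU.

```python
import os
from typing import Callable, Optional


def capture_profile(trace_dir: str,
                    port: int = 9012,
                    duration_ms: int = 1000,
                    num_tracing_attempts: int = 3,
                    trace_fn: Optional[Callable[..., None]] = None) -> str:
    """Hypothetical helper: capture a profile from the server started by xp.start_server."""
    trace_dir = os.path.expanduser(trace_dir)  # expand ~ like the shell version does
    os.makedirs(trace_dir, exist_ok=True)
    if trace_fn is None:
        # Imported lazily: torch_xla is only available on the TPU VM.
        import torch_xla.debug.profiler as xp
        trace_fn = xp.trace
    trace_fn(f"localhost:{port}", trace_dir,
             duration_ms=duration_ms,
             num_tracing_attempts=num_tracing_attempts)
    return trace_dir


# On the TPU VM, while the training script is running in another window:
# capture_profile("~/workspace/pt_xla_tracing/dummy_trace/xrt110")
```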

Below are the screenshots from TensorBoard in this case.
[Screenshot: xrt_110_1]
And below is a zoom-in view of "TensorFlow Name Scope", which corresponds to our xp.Trace annotations in the PyTorch model:
[Screenshot: xrt_110_2]

As can be seen from the screenshot above, under XRT runtime with PT/XLA 1.10 (TPU VM runtime tpu-vm-pt-1.10), the XLA profiler can capture traces of all 8 TPU cores, and can show the trace annotations under "TensorFlow Name Scope" on the TPU device traces.

XRT runtime with PT/XLA nightly

  1. Allocate a new v3-8 TPU VM with tpu-vm-pt-1.12 runtime environment and install the nightly 20220829 wheels as follows:
# torch, torchvision and torch_xla
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220829-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220829-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220829-cp38-cp38-linux_x86_64.whl

# libtpu
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220623-py3-none-any.whl

# no longer needed for tpu-vm-pt-1.12 TPU VM runtime
# sudo pip3 install numpy==1.23.0
  2. Run this file in a tmux window with XLA_HLO_DEBUG=1:
XLA_HLO_DEBUG=1 \
python3 ./test_profile_pjrt_vs_xrt.py --runtime xrt
  3. Capture a profile from localhost:9012 in another tmux window on the TPU VM (saved to TRACE_DIR below):
TRACE_DIR=~/workspace/pt_xla_tracing/dummy_trace/xrt
PORT=9012

sudo mkdir -p ${TRACE_DIR} && sudo chmod -R a+w ${TRACE_DIR}
python_cmd="import torch_xla.debug.profiler as xp; xp.trace('localhost:${PORT}', '${TRACE_DIR}')"
echo "$python_cmd" && python3 -c "$python_cmd"
  4. Finally, view the captured trace files in TensorBoard:
TRACE_DIR=~/workspace/pt_xla_tracing/dummy_trace/xrt
tensorboard --logdir ${TRACE_DIR} --port 6006
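The three nightly wheels in step 1 differ only in package name and date, which matters when comparing builds such as 20220829 and 20220916 later in this thread. A sketch with a hypothetical helper, assuming every nightly date follows the same URL pattern as the 20220829 links above (not verified for other dates) and that the libtpu wheel is pinned separately:

```python
from typing import List

WHEEL_BASE = "https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm"


def nightly_wheel_urls(date: str, python_tag: str = "cp38") -> List[str]:
    """Hypothetical helper: torch, torchvision and torch_xla nightly wheel URLs for YYYYMMDD."""
    if len(date) != 8 or not date.isdigit():
        raise ValueError(f"expected a YYYYMMDD date, got {date!r}")
    return [
        f"{WHEEL_BASE}/{pkg}-nightly+{date}-{python_tag}-{python_tag}-linux_x86_64.whl"
        for pkg in ("torch", "torchvision", "torch_xla")
    ]


# Prints the same three install commands as step 1.
for url in nightly_wheel_urls("20220829"):
    print(f"sudo pip3 install {url}")
```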

Below is the screenshot from TensorBoard in this case.
[Screenshot: xrt]

As can be seen from the screenshot above, under XRT runtime with PT/XLA nightly, the XLA profiler can capture traces of all 8 TPU cores, but cannot show the trace annotations under "TensorFlow Name Scope" on the TPU device traces.

PJRT runtime with PT/XLA nightly

  1. Allocate a new v3-8 TPU VM with tpu-vm-pt-1.12 runtime environment and install the nightly 20220829 wheels as follows (this part is the same as in "XRT runtime with PT/XLA nightly" above):
# torch, torchvision and torch_xla
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220829-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220829-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220829-cp38-cp38-linux_x86_64.whl

# libtpu
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220623-py3-none-any.whl

# no longer needed for tpu-vm-pt-1.12 TPU VM runtime
# sudo pip3 install numpy==1.23.0
  2. Run this file in a tmux window with PJRT_DEVICE=TPU and XLA_HLO_DEBUG=1:
PJRT_DEVICE=TPU XLA_HLO_DEBUG=1 \
python3 ./test_profile_pjrt_vs_xrt.py --runtime pjrt
  3. Capture a profile from localhost:9012 in another tmux window on the TPU VM (saved to TRACE_DIR below):
TRACE_DIR=~/workspace/pt_xla_tracing/dummy_trace/pjrt
PORT=9012

sudo mkdir -p ${TRACE_DIR} && sudo chmod -R a+w ${TRACE_DIR}
python_cmd="import torch_xla.debug.profiler as xp; xp.trace('localhost:${PORT}', '${TRACE_DIR}')"
echo "$python_cmd" && python3 -c "$python_cmd"
  4. Finally, view the captured trace files in TensorBoard:
TRACE_DIR=~/workspace/pt_xla_tracing/dummy_trace/pjrt
tensorboard --logdir ${TRACE_DIR} --port 6026
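The three cases differ only in environment variables and the --runtime flag. A sketch of that difference as a hypothetical Python launcher; it assumes (our understanding, not stated in this issue) that in these builds setting PJRT_DEVICE is what selects PJRT, so the XRT case removes it in case it is inherited from the shell.

```python
import os
from typing import Dict, List, Mapping, Optional, Tuple

SCRIPT = "./test_profile_pjrt_vs_xrt.py"


def repro_command(runtime: str,
                  base_env: Optional[Mapping[str, str]] = None
                  ) -> Tuple[Dict[str, str], List[str]]:
    """Hypothetical helper: environment and argv for one reproduction case."""
    if runtime not in ("xrt", "pjrt"):
        raise ValueError(f"unknown runtime: {runtime!r}")
    env = dict(os.environ if base_env is None else base_env)
    env["XLA_HLO_DEBUG"] = "1"  # request HLO metadata for annotation propagation
    if runtime == "pjrt":
        env["PJRT_DEVICE"] = "TPU"
    else:
        env.pop("PJRT_DEVICE", None)  # assumption: a stray PJRT_DEVICE would select PJRT
    return env, ["python3", SCRIPT, "--runtime", runtime]


# Usage on the TPU VM:
#   env, argv = repro_command("pjrt")
#   subprocess.run(argv, env=env, check=True)
```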

Below is the screenshot from TensorBoard in this case.
[Screenshot: pjrt]

As can be seen from the screenshot above, under PJRT runtime with PT/XLA nightly, the XLA profiler can only capture traces of 2 TPU cores, and cannot show the trace annotations under "TensorFlow Name Scope" on the TPU device traces.

Expected behavior

  1. Under both PJRT and XRT, the XLA profiler should propagate the trace annotations (from xp.Trace) to "TensorFlow Name Scope" on the TPU device traces when running with XLA_HLO_DEBUG=1, as described in Annotation propagation. This feature is quite important for matching PyTorch modules to TPU traces in performance analysis.
  2. Under PJRT runtime, the XLA profiler should ideally capture the TPU traces of all 8 TPU v3 cores on a v3-8 VM.

Environment

  • Reproducible on XLA backend [CPU/TPU]: v3-8 TPU VM with tpu-vm-pt-1.10 or tpu-vm-pt-1.12 runtime (see details above)
  • torch_xla version: PT/XLA 1.10 or nightly 20220829 (see details above)

cc: @will-cromar @JackCaoG

(ronghanghu, Sep 02 '22 18:09)

@ronghanghu I think I figured out what went wrong with xp.trace. It used to be that setting either XLA_IR_DEBUG or XLA_HLO_DEBUG would give you the scope information. However, in one of the recent changes for LTC, we moved scope_pusher upstream. Upstream controls metadata in https://github.com/pytorch/pytorch/blob/30fb2c4abaaaa966999eab11674f25b18460e609/torch/csrc/lazy/core/ir_metadata.cpp#L94

In pytorch/xla, we do the env var mapping in https://github.com/pytorch/xla/blob/28ea3758e9586ef8cc22270a16e1ddb4a21aa6f7/torch_xla/csrc/init_python_bindings.cpp#L747

In this mapping, we only mapped XLA_IR_DEBUG and forgot XLA_HLO_DEBUG. A quick workaround is to also set XLA_IR_DEBUG=1 during profiling. I will work on a fix right away.
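A pure-Python model of the mapping bug described above, with hypothetical function names: the actual mapping is C++ in init_python_bindings.cpp (linked above), and its exact flag parsing may differ from the truthy values assumed here. Only XLA_IR_DEBUG was forwarded to the upstream setting that controls IR metadata, so XLA_HLO_DEBUG=1 alone no longer produced scope information, while setting both flags (the workaround) did. As the follow-up comments note, fixing this mapping alone did not make the annotations appear in the trace viewer.

```python
from typing import Mapping

TRUTHY = ("1", "true", "yes")  # assumption: how a boolean env var is read


def env_flag(env: Mapping[str, str], name: str) -> bool:
    return env.get(name, "").strip().lower() in TRUTHY


def scope_metadata_enabled_buggy(env: Mapping[str, str]) -> bool:
    # Only XLA_IR_DEBUG was mapped to the upstream setting.
    return env_flag(env, "XLA_IR_DEBUG")


def scope_metadata_enabled_fixed(env: Mapping[str, str]) -> bool:
    # Either flag enables IR metadata (and thus scope info), as before the LTC change.
    return env_flag(env, "XLA_IR_DEBUG") or env_flag(env, "XLA_HLO_DEBUG")


hlo_only = {"XLA_HLO_DEBUG": "1"}
workaround = {"XLA_HLO_DEBUG": "1", "XLA_IR_DEBUG": "1"}
print(scope_metadata_enabled_buggy(hlo_only), scope_metadata_enabled_fixed(hlo_only))
# False True
print(scope_metadata_enabled_buggy(workaround))
# True
```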

(JackCaoG, Sep 09 '22 01:09)

Hmm, even with this issue fixed, something else is still going on. Let me keep looking into this.

(JackCaoG, Sep 09 '22 01:09)

@JackCaoG Thanks for looking into this!

A quick workaround is to also set XLA_IR_DEBUG=1 during profiling.

Earlier I tried setting both XLA_IR_DEBUG=1 and XLA_HLO_DEBUG=1, and it still could not propagate the xp.Trace annotations to the TPU device traces. (I also tried adding PT_XLA_DEBUG=1, but it hung the program.)

(ronghanghu, Sep 09 '22 01:09)

This is weird: now, with XLA_IR_DEBUG=1 and XLA_HLO_DEBUG=1, I do see scope information being pushed to the profiler:
[Image]

However, it seems the profiler cannot understand it. If we look at the 1.10 profile, I see:
[Image]

And if we click on the scope above, I can see:

[Image]

I will look into why the profiler cannot resolve the tf_op field into a scope.

(JackCaoG, Sep 09 '22 01:09)

https://github.com/tensorflow/tensorflow/commit/306904197c95cc01cdcd30462fd62984329f5cef might be the cause of this issue; I will keep looking tomorrow. The information is there, and I am guessing we just need to tweak the format a bit for it to display correctly.
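For context on the tf_op field: our understanding (not confirmed in this thread) is that the TensorFlow profiler reads tf_op as "<op_name>:<op_type>" and derives name scopes by splitting op_name on "/" and dropping the last component. The simplified Python model below uses hypothetical function names and a made-up example string; it only illustrates how a tf_op string in a different shape could leave "TensorFlow Name Scope" with nothing to draw, and does not model what the linked commit changed.

```python
from typing import List, NamedTuple


class TfOp(NamedTuple):
    name: str     # e.g. "forward_pass/layer_0/fc"
    op_type: str  # e.g. "fc"; empty if the string has no ':'


def parse_tf_op_fullname(fullname: str) -> TfOp:
    # Split on the first ':' only, into "<op_name>:<op_type>".
    name, sep, op_type = fullname.partition(":")
    return TfOp(name, op_type) if sep else TfOp(fullname, "")


def parse_name_scopes(op_name: str) -> List[str]:
    # Every '/'-separated component except the last is a name scope.
    return op_name.split("/")[:-1]


op = parse_tf_op_fullname("forward_pass/layer_0/fc:fc")  # made-up example
print(op, parse_name_scopes(op.name))
# TfOp(name='forward_pass/layer_0/fc', op_type='fc') ['forward_pass', 'layer_0']
```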

(JackCaoG, Sep 09 '22 01:09)

Update on this: it seems that the PJRT profiling no longer works on a pod in the torch_xla 20220916 wheel, although it was working in the 20220829 torch_xla wheel.


I was trying to run the profiling example above on a v3-128 pod by installing the nightly wheels following "PJRT runtime with PT/XLA nightly" and capturing a profile as follows:

  1. Run the example script above (downloaded to ~/workspace/pt_xla_tracing/test_profile_pjrt_vs_xrt.py) on all VMs:
TPU_NAME=ronghang-v3-128
ZONE=europe-west4-a
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} \
  --worker all --command "
PJRT_DEVICE=TPU XLA_HLO_DEBUG=1 \
python3 ~/workspace/pt_xla_tracing/test_profile_pjrt_vs_xrt.py --runtime pjrt
"
  2. Then capture a profile trace on the first TPU VM host in the pod:
TRACE_DIR=~/workspace/pt_xla_tracing/dummy_trace/pjrt_pod_v3-128
PORT=9012

sudo mkdir -p ${TRACE_DIR} && sudo chmod -R a+w ${TRACE_DIR}
python_cmd="import torch_xla.debug.profiler as xp; xp.trace('localhost:${PORT}', '${TRACE_DIR}')"
echo "$python_cmd" && python3 -c "$python_cmd"
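Step 1 can also be scripted from Python. In this sketch (hypothetical helper; the gcloud flags are copied from the command above), the remote command is passed as a single argv element, so no local shell quoting is involved and the ~ in the path is left for the remote shell to expand.

```python
import shlex
from typing import List


def pod_ssh_argv(tpu_name: str, zone: str, remote_cmd: str) -> List[str]:
    """Hypothetical helper: run remote_cmd on every worker of a TPU pod via gcloud."""
    return ["gcloud", "alpha", "compute", "tpus", "tpu-vm", "ssh", tpu_name,
            "--zone", zone, "--worker", "all", "--command", remote_cmd]


REMOTE_CMD = ("PJRT_DEVICE=TPU XLA_HLO_DEBUG=1 "
              "python3 ~/workspace/pt_xla_tracing/test_profile_pjrt_vs_xrt.py --runtime pjrt")
argv = pod_ssh_argv("ronghang-v3-128", "europe-west4-a", REMOTE_CMD)
print(shlex.join(argv))  # display only; to run: subprocess.run(argv, check=True)
```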

The procedure above successfully captured a TPU profile with the nightly 20220829 wheels, but after switching to the nightly 20220916 wheels (still keeping "libtpu_nightly-0.1.dev20220623"), the second step above fails to capture a TPU profile and prints:

No trace event is collected after 3 attempt(s). Perhaps, you want to try again (with more attempts?).

I suspect that some recent changes may have broken the profiler under PJRT on a pod.

(ronghanghu, Sep 19 '22 16:09)

@ronghanghu do you know if this still works on donut?

(JackCaoG, Sep 19 '22 17:09)

I walked through the changes merged in the last 3 weeks; the only one that seems relevant is https://github.com/pytorch/xla/commit/0aed7b648fa4ca19f9f2ec5916da216eaeec5837

(JackCaoG, Sep 19 '22 17:09)

@ronghanghu do you know if this still works on donut?

Yes, PJRT profiling still works on a donut: I can capture the device trace with the script above on a v3-8 VM, following the commands in this issue, using the 20220916 wheel. But the pod case (as mentioned above) seems broken with the nightly 20220916 wheels (although it was working with the 20220829 wheels).

(ronghanghu, Sep 19 '22 17:09)

Hmm, this is really weird, and we should fix it before the release cut. @will-cromar can you take a look when you have time?

(JackCaoG, Sep 19 '22 17:09)