How to benchmark PyTorch XLA code properly
❓ Questions and Help
Hi! I'm trying to benchmark some PyTorch/XLA code and can't figure out how to do it correctly.
For simplicity, what I'm benchmarking is `torch.matmul(a, b)`. First I wrote the most straightforward version, inspired by the CUDA & Triton benchmarking code:
```python
# create tensors
a = torch.randn((N, K), device=device, dtype=dtype)
b = torch.randn((K, M), device=device, dtype=dtype)

def fn():
    torch.matmul(a, b)

benchmark(fn)  # here I'm doing warmup runs / multiple fn runs
```
This didn't work: the benchmark returned effectively immediately. I realized that no work is actually happening since the tensors are lazy, so I added `xm.unlazy` calls on the matmul result tensor after each `fn` run. However, I was still getting numbers that look like no work is being done.
My theory was that since the structure of the computation is not changing, the backend is reusing results. So I tried to regenerate the inputs on each iteration. I tried different approaches, either regenerating them fully or in faster ways, such as:
```python
def prepare():
    a[0, 0] += 1
    b[0, 0] += 1
    return [a.clone().detach(), b.clone().detach()]
```
But with none of my attempts was I able to get a proper measurement of the `matmul` call. I feel like I'm either measuring compilation speed or no-op speed. Any tips on how to write this benchmark, or on building a better mental model of when/how to avoid recompilation of the code while still executing it?
Thanks in advance!
I think what you want to do is:
```python
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# create tensors
a = torch.randn((N, K), device=device, dtype=dtype)
b = torch.randn((K, M), device=device, dtype=dtype)

def fn(a, b):
    return torch.matmul(a, b)

timer_begin()
fn(a, b)
xm.mark_step()        # trigger the async computation
xm.wait_device_ops()  # wait for all async ops to finish
timer_end()
```
Though I would suggest using a really big matmul. There is a fixed overhead to launching any XLA computation; if the computation itself is tiny, it will be dominated by the runtime overhead.
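To make the timing structure concrete, here is a minimal, generic timing helper along these lines. The `benchmark`/`sync` names are illustrative, not an XLA API; the workload here is a CPU stand-in so the sketch runs anywhere, and on XLA the `sync` callback would call `xm.mark_step()` followed by `xm.wait_device_ops()`:

```python
import time

def benchmark(fn, sync=lambda: None, warmup=3, iters=10):
    # Warm up first so one-time costs (e.g. XLA compilation) are excluded.
    for _ in range(warmup):
        fn()
        sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()  # flush pending work and wait, so the timer covers the real execution
    return (time.perf_counter() - start) / iters

# CPU stand-in workload; on XLA, fn would enqueue lazy matmuls and
# sync would be: lambda: (xm.mark_step(), xm.wait_device_ops())
avg_s = benchmark(lambda: sum(range(10_000)))
```

Without the `sync` inside the timed region, a lazy backend would report near-zero times, which matches the symptom described above.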
@JackCaoG thanks for the response!
A couple of follow-up questions:

- This way, the time of the `randn` creation gets included in the measured time; is there a way to avoid it? Put `xm.mark_step()` and `xm.wait_device_ops()` after the creation of `a` and `b` but before the timer starts?
- If I do this and run `fn()` twice, will it actually run twice, or only once in reality? Is there a way to measure repeated execution of it?
- What's the magnitude of the time it takes to launch an XLA computation?
Thank you again!
Oh, my bad, yeah, you should do a `mark_step` after the `randn`.

You can execute `fn` multiple times; it will just keep accumulating the computations. `xm.mark_step` is the call that takes all of the pending computations and runs them. As for the overhead, I think it is on the order of ms.
For more details you can take a look at https://github.com/pytorch/xla/blob/master/docs/pytorch_xla_overview.md
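Putting the advice together, a sketch of the full measurement flow. The `mark_step`/`wait_device_ops` stubs below stand in for the real `xm.mark_step()` and `xm.wait_device_ops()` from `torch_xla.core.xla_model`, and the dot-product workload stands in for the XLA matmul, so the structure runs anywhere; everything else is illustrative:

```python
import time

# Stand-ins so this sketch runs without a TPU; on XLA these would be
# xm.mark_step() and xm.wait_device_ops().
def mark_step():        # flush the pending lazy graph for execution
    pass

def wait_device_ops():  # block until all async device work completes
    pass

# Inputs; on XLA: a = torch.randn(..., device=device), likewise b.
a = list(range(1000))
b = list(range(1000))
mark_step()          # materialize the inputs...
wait_device_ops()    # ...before the timer starts, so randn isn't measured

iters = 5
start = time.perf_counter()
for _ in range(iters):
    # On XLA: torch.matmul(a, b) — each call adds to the pending graph.
    result = sum(x * y for x, y in zip(a, b))
mark_step()          # execute all accumulated computations
wait_device_ops()
elapsed = (time.perf_counter() - start) / iters
```

The key points are that the sync pair brackets only the work being measured, and that repeated `fn` calls accumulate into one pending graph that the final `mark_step` executes.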
@ttim is there anything we can do to help further or are you unblocked?
@miladm I was able to benchmark my code and get consistent and probably more or less correct results.
I've been trying to get access to a TPU v5e for a week now, and tbh the process of getting quota is very frustrating. I'd appreciate it if you could help with this. Thank you!
Great to hear!
cc @ultrons to help further if feasible.
@ultrons I would really appreciate it if you could help me. I'm trying to create v5e nodes to try it out and got some quota approved, but it doesn't seem to be working. I've been iterating with support for over a week now. I can share more details if needed. Thank you!