Running batch transforms (e.g. torch.nn.functional.grid_sample) is slower on TPU vs CPU
🐛 Bug
Executing batch transforms that use the torch.nn.functional.grid_sample function seems to run slower on a single TPU core than on the CPU.
To Reproduce
We encountered this weird bug where the batch transforms run slower on a single TPU core compared to a GPU (which we somewhat expected), but we also found that they run even slower than on the CPU!
Here are some notebooks showing the results for a single transform (Flip):
- GPU (fastest) - avg time: 0.021 secs
- CPU (middle) - avg time: 1.227 secs
- TPU (slowest) - avg time: 7.341 secs

For the torch.nn.functional F.grid_sample method, the times were:

- GPU - avg time: 0.000 secs (not measurable with a time.time() diff)
- CPU - avg time: 0.821 secs
- TPU - avg time: 4.247 secs
This is not even using gradients, just pure parallel tensor computations...
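For context, here is a minimal sketch of the kind of timing comparison involved (the batch size, the identity affine grid, and the device handling below are illustrative assumptions rather than the exact notebook code):
import time
import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
def time_grid_sample(device, bs=768, size=200, iters=10):
    # Build a batch of random images and an identity sampling grid on the target device.
    imgs = torch.rand(bs, 3, size, size, device=device)
    theta = torch.eye(2, 3, device=device).repeat(bs, 1, 1)
    grid = F.affine_grid(theta, list(imgs.shape), align_corners=False)
    start = time.time()
    for _ in range(iters):
        out = F.grid_sample(imgs, grid, align_corners=False)
        if device.type == 'xla':
            xm.mark_step()  # force the lazily-built XLA graph to execute
    if device.type == 'xla':
        xm.wait_device_ops()  # wait for async TPU work before stopping the clock
    return (time.time() - start) / iters
print('cpu avg secs:', time_grid_sample(torch.device('cpu')))
print('tpu avg secs:', time_grid_sample(xm.xla_device()))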
The notebooks here have a "Run on Colab" link in them so you can validate the stats produced above.
https://github.com/butchland/fastai_xla_extensions/blob/master/archive_nbs/09_vision_augment_experiments_profile_CPU.ipynb
https://github.com/butchland/fastai_xla_extensions/blob/master/archive_nbs/09_vision_augment_experiments_profile_GPU.ipynb
https://github.com/butchland/fastai_xla_extensions/blob/master/archive_nbs/09_vision_augment_experiments_profile_TPU.ipynb
Expected behavior
We expect that the TPU should run the transforms much faster than a CPU.
Environment
Colab
- Reproducible on XLA backend [CPU/TPU]: TPU runtime - pytorch-dev20200707
- torch_xla version: torch-xla==1.6+5430aca
Additional context
Lastly, we noticed that other data augmentations that run on the batch (brightness and contrast) don't slow down nearly as much; they actually run faster on a TPU than on a CPU...
We (@butchland and @tyoc213) are building an extension library to enable the fastai library to run on TPUs.
If you have suggestions to speed it up (e.g. alternative algos for batch transforms for data augmentations), we'd appreciate it!
Hi @butchland, thanks for reporting! Could you follow the instructions here to do a debug run? That way we can know what exactly happened. My guess would be that xla currently does not lower the grid_sampler_2d and grid_sampler_3d nodes, so they are being forwarded to the CPU, which causes the slowdown.
debug_run.tar.gz - If you need anything else or extra parameters, we can send them.
Deleted the log because I just saw that it is already in the zip.
The final Python code executed is this:
import fastai_xla_extensions.core
from fastai2.vision.all import *
from my_timesaver_utils.profiling import *
path = untar_data(URLs.PETS)/'images'
Path.BASE_PATH = path; path.ls()
print(f'running on default_device() & cuda is {torch.cuda.is_available()}')
img = PILImage.create(path/'Abyssinian_1.jpg')
resize = Resize(size=200)
img2 = resize(img,split_idx=0)
timg2 = TensorImage(array(img2)).permute(2,0,1).float()/255.
def batch_ex(bs, device): return TensorImage(timg2[None].to(device).expand(bs, *timg2.shape))
b768_img = batch_ex(768, default_device()); (b768_img.shape, b768_img.device)
flip_tfm = Flip(p=1.0)
# run without profile
run_with_profile = True
F.grid_sample = profile_call(F.grid_sample) if run_with_profile else F.grid_sample
@profile_call
def mtest(b_img):
    # set_trace()
    new_b_img = flip_tfm(b_img)
    return new_b_img
clear_prof_data()
print("--- 10 image tensor loops:")
for i in range(10):
    print("--- ---------------------------------")
    new_b768_img = mtest(b768_img)
print("--- ")
print_prof_data()
Oh, so it looks like you are running a small code snippet that does not finish a full step, so the metric report is not generated. Do you mind running it again with
import torch_xla.debug.metrics as met
print(met.metrics_report())
at the end? More detail can be found here. This report will be super helpful in telling us where the slowness is coming from.
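For example, a sketch of what the end of the loop above could look like (the mark_step placement is just a suggestion so that a full step gets recorded):
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
for i in range(10):
    new_b768_img = mtest(b768_img)
    xm.mark_step()  # finish a full step so counters and metrics get recorded
print(met.metrics_report())  # dump the report once the loop is done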
Oh yes, we tried to remove all the interference from extra code and limit it as much as possible to just what is causing the slowness.
By the way, I see I can add --hlo to generate something, maybe like grab_graph.py does?
Also, what was missing from our sample to generate this report, so that we don't have to print it manually?
Found that the aten:: counters are the calls forwarded to the CPU because they are not implemented on the TPU, so I'm pasting them from the tgz for easy access.
Counter: aten::_local_scalar_dense
Value: 10
Counter: aten::affine_grid_generator
Value: 10
Counter: aten::grid_sampler_2d
Value: 10
Yup, you are right. _local_scalar_dense most likely comes from a PyTorch item() call. The other two look like we need to add a lowering. We are a bit busy with the upcoming release now but will add this to our todo list.
For your other questions: yes, you can set XLA_SAVE_TENSORS_FILE and XLA_SAVE_TENSORS_FMT to dump the HLO text for the debug run as well. I think we dump the metric report every time mark_step is called (see here). You can also manually call the API to print the metrics to the output.
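For example (a minimal sketch; the output path and the 'hlo' format value are just assumptions, not required settings):
import os
# Set these before importing torch_xla so the dump settings get picked up.
os.environ['XLA_SAVE_TENSORS_FILE'] = '/tmp/xla_debug_graphs.txt'  # hypothetical output path
os.environ['XLA_SAVE_TENSORS_FMT'] = 'hlo'  # dump the graphs as HLO text
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
device = xm.xla_device()
# ... run the workload on `device` ...
xm.mark_step()               # the graph gets saved to the file above on each step
print(met.metrics_report())  # the metric report can also be printed manually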
Good! I see, thanks.
So we should wait for the lowering of these 2 calls, but what about the item() call: is there a way we can optimize it, or will it always show up?
Yup, I will update this thread when I make any progress on lowering these two ops. We have a section here talking about the item() call; the takeaway is: don't use it unless necessary.
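A common pattern, sketched below under the assumption that the item() call comes from logging a scalar every iteration (loader and compute_loss are placeholders for your own code): keep the value as a device tensor and only materialize it on the host occasionally.
import torch_xla.core.xla_model as xm
for step, batch in enumerate(loader):   # `loader` is a placeholder data source
    loss = compute_loss(batch)          # `compute_loss` is a placeholder for your model code
    loss.backward()
    xm.mark_step()
    # Calling loss.item() every step forces a device sync (aten::_local_scalar_dense);
    # pulling the scalar to the host only occasionally keeps most steps fully on-device.
    if step % 100 == 0:
        print(f'step {step} loss {loss.item():.4f}')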
Hi there @JackCaoG, I'm back on this issue, so I will give it a try!
Hi, is there a solution now? I just ran into the same issue when training a GPT-Neo model using TPUs on Colab.
I am using Resize inside a PyTorch Lightning training step and it makes my code terribly slow. Is there a solution for this?
affine_grid should be supported now. To get a better understanding of the problem, doing a metric report
import torch_xla.debug.metrics as met
print(met.metrics_report())
after a step will help
@JackCaoG I am using Resize in my training step, more specifically transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC, max_size=None, antialias=None).
I am using the metric report function, but the training seems to be stuck at the resize(tensor) operation, so the code doesn't reach that step.
Hi @JackCaoG @butchland @tyoc213, I'm wondering if you have found a solution to speed up the F.grid_sample method. I'm also running into the same issue. Any help will be much appreciated.
@dhruvrnaik if you have a small repro I might be able to take a look. It depends on what ops transforms.Resize gets decomposed into by PyTorch and passed to us.
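For example, one way to see the decomposition is to run the transform on the XLA device and then look at the aten:: fallback counters (a sketch; the input shape is just an example):
import torch
import torchvision.transforms as transforms
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
device = xm.xla_device()
resize = transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC)
x = torch.rand(1, 3, 512, 512, device=device)  # example input
y = resize(x)
xm.mark_step()
# Any counter starting with "aten::" is an op that fell back to the CPU.
for name in met.counter_names():
    if name.startswith('aten::'):
        print(name, met.counter_value(name))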
Taking another look at the CPU grid_sampler_2d implementation, it seems to play with the stride quite a bit. This kind of op is pretty difficult to lower for XLA since we can't play with the stride (XLA is a functional compiler and does not expose its memory space to the user). We would have to use a bunch of reshape and conv ops to fake the stride. Which models use these batch transforms? I don't have anyone who can immediately work on this lowering.
Thank you for looking into it, @JackCaoG. I'm working with a model for learning pixel-conditioned Neural Radiance Fields (paper, code). Many radiance field models heavily rely on nn.functional.grid_sample to sample features based on pixel/voxel locations, so having XLA support for it would be tremendously beneficial to the community. If this function will not be supported soon, do you have recommendations for alternative approaches that might serve the same functionality while running much faster on a TPU?
I don't think we have anything similar to grid_sample. PyTorch/XLA supports upsample_nearest2d, but I feel like that's not what you want. I tried to find a TensorFlow implementation of this op (TF ops usually have XLA lowerings) but I only found this workaround. We might be able to use some of the techniques in this TF workaround when lowering this op.
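For anyone needing an interim workaround, below is a rough sketch (not the TF workaround referenced above) of how nearest-mode grid_sample with align_corners=True and border padding can be emulated with gather, which XLA can lower without a CPU fallback; bilinear mode would need four such gathers plus interpolation weights:
import torch
def nearest_grid_sample(inp, grid):
    # Approximates F.grid_sample(inp, grid, mode='nearest', padding_mode='border',
    # align_corners=True) using gather, avoiding the aten::grid_sampler_2d fallback.
    N, C, H, W = inp.shape
    # Map normalized coords in [-1, 1] to pixel indices (align_corners=True convention).
    x = ((grid[..., 0] + 1) * (W - 1) / 2).round().clamp(0, W - 1).long()
    y = ((grid[..., 1] + 1) * (H - 1) / 2).round().clamp(0, H - 1).long()
    idx = (y * W + x).view(N, 1, -1).expand(-1, C, -1)  # (N, C, Ho*Wo)
    out = torch.gather(inp.reshape(N, C, H * W), 2, idx)
    return out.view(N, C, grid.shape[1], grid.shape[2])
Note that this only matches mode='nearest'; results will differ from grid_sample's default bilinear interpolation.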