
Running batch transforms (e.g. torch.nn.functional.grid_sample) is slower on TPU than on CPU

Open · butchland opened this issue 4 years ago · 17 comments

🐛 Bug

Batch transforms that use the torch.nn.functional.grid_sample function seem to run slower on a single TPU core than on the CPU.

To Reproduce

We encountered this weird bug where the batch transforms run slower on a single TPU core than on a GPU (which we kind of expected), but we also found that they run even slower than on the CPU!

Here are some notebooks showing the results for a single transform (Flip):

  • GPU (fastest): avg time 0.021 secs
  • CPU (middle): avg time 1.227 secs
  • TPU (slowest): avg time 7.341 secs

For the F.grid_sample method (torch.nn.functional) alone, the times are:

  • GPU: avg time 0.000 secs (not measurable with a time.time() difference)
  • CPU: avg time 0.821 secs
  • TPU: avg time 4.247 secs

This is not even using gradients, just pure parallel tensor computations...
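A note on the GPU figure of 0.000: CUDA kernels are launched asynchronously, so a bare time.time() difference can end before the GPU work has finished, and on PyTorch/XLA the first call can also include graph tracing and compilation. A minimal timing sketch, assuming you supply a device-specific `sync` hook (the helper `avg_time` is hypothetical, not part of fastai or torch_xla), that uses a higher-resolution clock, runs a warm-up call, and waits for the device after each call:

```python
import statistics
import time
from typing import Callable, Optional


def avg_time(fn: Callable[[], object], n: int = 10,
             sync: Optional[Callable[[], None]] = None,
             warmup: int = 1) -> float:
    """Return the mean wall-clock time in seconds of `fn` over `n` runs.

    `sync` is an optional (hypothetical, caller-supplied) hook that blocks
    until queued device work has finished. Without it, asynchronous backends
    may only measure the time needed to *enqueue* the work.
    """
    # Warm-up runs are excluded from the average (e.g. one-off compilation).
    for _ in range(warmup):
        fn()
        if sync is not None:
            sync()
    times = []
    for _ in range(n):
        start = time.perf_counter()  # higher resolution than time.time()
        fn()
        if sync is not None:
            sync()
        times.append(time.perf_counter() - start)
    return statistics.mean(times)


if __name__ == "__main__":
    print(f"{avg_time(lambda: sum(range(100_000)), n=5):.6f} s")
```

On a GPU one would pass `sync=torch.cuda.synchronize`. On PyTorch/XLA, calling `xm.mark_step()` followed by `xm.wait_device_ops()` (from `torch_xla.core.xla_model`) is one plausible hook, but that is an assumption worth checking against your torch_xla version.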

Each notebook below has a link to run it on Colab, so you can validate the stats above.

https://github.com/butchland/fastai_xla_extensions/blob/master/archive_nbs/09_vision_augment_experiments_profile_CPU.ipynb

https://github.com/butchland/fastai_xla_extensions/blob/master/archive_nbs/09_vision_augment_experiments_profile_GPU.ipynb

https://github.com/butchland/fastai_xla_extensions/blob/master/archive_nbs/09_vision_augment_experiments_profile_TPU.ipynb

Expected behavior

We expect that the TPU should run the transforms much faster than a CPU.

Environment

Colab

  • Reproducible on XLA backend [CPU/TPU]: TPU runtime - pytorch-dev20200707
  • torch_xla version: torch-xla==1.6+5430aca

Additional context

Lastly, we noticed that other data augmentations that run on a batch (brightness and contrast) don't slow the TPU down like this; they actually run faster on a TPU than on a CPU.

We (@butchland and @tyoc213) are building an extension library to enable the fastai library to run on TPUs.

If you have suggestions to speed it up (e.g. alternative algos for batch transforms for data augmentations), we'd appreciate it!

butchland, Aug 05 '20 17:08

Hi @butchland, thanks for reporting! Could you follow the instructions here to do a debug run? That way we can see exactly what happened. My guess is that XLA currently does not lower the grid_sampler_2d and grid_sampler_3d nodes, so they are being forwarded to the CPU, which causes the slowdown.

JackCaoG, Aug 05 '20 18:08

debug_run.tar.gz (attached). If you need anything else or extra parameters, we can send them.

(I deleted the log from this comment because I just saw that it is in the archive.)

The final Python code executed is this:

import fastai_xla_extensions.core
from fastai2.vision.all import *
from my_timesaver_utils.profiling import *

path = untar_data(URLs.PETS)/'images'
Path.BASE_PATH = path
print(f'running on {default_device()} & cuda is {torch.cuda.is_available()}')

# Build a single 200px image tensor (C, H, W) scaled to [0, 1]
img = PILImage.create(path/'Abyssinian_1.jpg')
resize = Resize(size=200)
img2 = resize(img, split_idx=0)
timg2 = TensorImage(array(img2)).permute(2,0,1).float()/255.

# Expand the single image into a batch of `bs` copies on `device`
def batch_ex(bs, device): return TensorImage(timg2[None].to(device).expand(bs, *timg2.shape))

b768_img = batch_ex(768, default_device())
print(b768_img.shape, b768_img.device)

flip_tfm = Flip(p=1.0)
# Set to False to run without profiling F.grid_sample
run_with_profile = True
F.grid_sample = profile_call(F.grid_sample) if run_with_profile else F.grid_sample

@profile_call
def mtest(b_img):
    # Apply the flip transform to the whole batch
    new_b_img = flip_tfm(b_img)
    return new_b_img

clear_prof_data()
print("--- 10 image tensor loops:")
for i in range(10):
    print("--- ---------------------------------")
    new_b768_img = mtest(b768_img)
print("--- ")
print_prof_data()

tyoc213, Aug 05 '20 20:08

Oh, it looks like you are running a small code snippet that does not finish a full step, so the metric report is not generated. Do you mind running it again with

import torch_xla.debug.metrics as met

print(met.metrics_report())

at the end? More details can be found here. This report will be very helpful in telling us where the slowness is coming from.

JackCaoG, Aug 05 '20 22:08

Oh yes, we tried to remove interference from extra code and limit the snippet as much as possible to what is causing the slowness.

debug_run_stats.tar.gz

By the way, I see I can add --hlo. Would that generate something like what grab_graph.py produces?


Also, what was missing from our sample for this report to be generated automatically, so that we don't have to print it manually?


I found that the aten:: counters are the calls forwarded to the CPU because they are not implemented on TPU, so I'm pasting them from the tgz for easy access:

Counter: aten::_local_scalar_dense
  Value: 10
Counter: aten::affine_grid_generator
  Value: 10
Counter: aten::grid_sampler_2d
  Value: 10
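Since the aten:: counters are what point to CPU fallbacks, a small helper can pull them out of the report text instead of searching by hand. This is a sketch that assumes the plain-text layout shown above (a `Counter:` line followed by an indented `Value:` line); `cpu_fallback_counters` is a hypothetical name, and in practice you would pass it the string returned by `met.metrics_report()`.

```python
import re
from typing import Dict, Optional

# Assumed report layout: "Counter: <name>" followed by "  Value: <n>".
_COUNTER = re.compile(r"^\s*Counter:\s*(\S+)\s*$")
_VALUE = re.compile(r"^\s*Value:\s*(\d+)\s*$")


def cpu_fallback_counters(report: str) -> Dict[str, int]:
    """Return {op_name: count} for the `aten::` counters in a metrics report."""
    counters: Dict[str, int] = {}
    current: Optional[str] = None
    for line in report.splitlines():
        m = _COUNTER.match(line)
        if m:
            current = m.group(1)
            continue
        m = _VALUE.match(line)
        if m and current is not None:
            counters[current] = int(m.group(1))
            current = None
    # Keep only aten:: ops, i.e. the ones reported as forwarded to the CPU.
    return {name: n for name, n in counters.items() if name.startswith("aten::")}


if __name__ == "__main__":
    report = """Counter: aten::_local_scalar_dense
  Value: 10
Counter: aten::affine_grid_generator
  Value: 10
Counter: aten::grid_sampler_2d
  Value: 10
"""
    print(cpu_fallback_counters(report))
    # {'aten::_local_scalar_dense': 10, 'aten::affine_grid_generator': 10, 'aten::grid_sampler_2d': 10}
```

Parsing text is brittle across releases; torch_xla.debug.metrics also exposes `counter_names()` and `counter_value(name)`, which may be a sturdier route if your version provides them.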

tyoc213, Aug 05 '20 23:08

Yup, you are right. _local_scalar_dense most likely comes from a PyTorch item() call. For the other two, it looks like we need to add a lowering. We are a bit busy with the upcoming release right now, but we will add this to our todo list.
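For context, the two fallback ops correspond to the two halves of an affine transform such as Flip: aten::affine_grid_generator builds a sampling grid, and aten::grid_sampler_2d reads the image at those grid points. The grid half is plain arithmetic: each output pixel's normalized coordinate (x, y) in [-1, 1] is multiplied by a 2×3 matrix theta. The NumPy sketch below is our illustration of the documented F.affine_grid semantics (including the align_corners=False pixel-centre convention), not the PyTorch or PyTorch/XLA source; the flip matrix in the usage note is an example, not a claim about fastai's internals.

```python
import numpy as np


def base_coords(size: int, align_corners: bool = False) -> np.ndarray:
    """Normalized coordinates in [-1, 1] along one axis of `size` pixels."""
    if size <= 1:
        return np.zeros(size)
    coords = np.linspace(-1.0, 1.0, size)
    if not align_corners:
        # Use pixel centres instead of pixel corners.
        coords = coords * (size - 1) / size
    return coords


def affine_grid(theta: np.ndarray, h: int, w: int,
                align_corners: bool = False) -> np.ndarray:
    """Sampling grid of shape (N, H, W, 2); the last axis holds (x, y).

    theta has shape (N, 2, 3) and maps output coordinates (x, y, 1)
    to input coordinates.
    """
    gx, gy = np.meshgrid(base_coords(w, align_corners),
                         base_coords(h, align_corners))   # each (H, W)
    base = np.stack([gx, gy, np.ones_like(gx)], axis=-1)  # (H, W, 3)
    # grid[n, i, j, c] = sum_k theta[n, c, k] * base[i, j, k]
    return np.einsum("hwk,nck->nhwc", base, theta)
```

With `theta = [[[-1, 0, 0], [0, 1, 0]]]` the x coordinates are mirrored, which is how a horizontal flip can be written as an affine grid.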

As for your other questions: yes, you can set XLA_SAVE_TENSORS_FILE and XLA_SAVE_TENSORS_FMT to dump the HLO text for the debug run as well. I think we dump the metric report every time mark_step is called (see here). You can also call the API manually to print the metrics to the output.
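The two environment variables are described in the PyTorch/XLA troubleshooting guide; XLA_SAVE_TENSORS_FMT accepts `text` (the default), `dot`, or `hlo`. A minimal sketch with a hypothetical helper `enable_graph_dump`; it assumes the safest ordering, setting the variables before torch_xla is imported, since we have not checked exactly when the runtime reads them.

```python
import os

# Formats accepted by XLA_SAVE_TENSORS_FMT per the troubleshooting guide.
VALID_FORMATS = ("text", "dot", "hlo")


def enable_graph_dump(path: str, fmt: str = "hlo") -> None:
    """Ask PyTorch/XLA to append the traced graphs to `path` in format `fmt`.

    Hypothetical helper: call it before `import torch_xla` to be safe.
    The dump file grows with every executed graph, so enable it only for
    short debug runs.
    """
    if fmt not in VALID_FORMATS:
        raise ValueError(f"fmt must be one of {VALID_FORMATS}, got {fmt!r}")
    os.environ["XLA_SAVE_TENSORS_FILE"] = path
    os.environ["XLA_SAVE_TENSORS_FMT"] = fmt


if __name__ == "__main__":
    enable_graph_dump("/tmp/xla_graphs.txt", fmt="hlo")
    # ...then import torch_xla and run the debug snippet.
```

The same effect can be had by exporting the two variables in the shell before launching the script.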

JackCaoG, Aug 05 '20 23:08

Good! I see, thanks.

So we should wait for the lowering of these two calls. But what about the item() call: is there a way we can optimize it, or will it always show up?

tyoc213, Aug 06 '20 00:08

Yup, I will update this thread when I make progress on lowering these two ops. We have a section here about the item() call; the takeaway is: don't use it unless necessary.

JackCaoG, Aug 06 '20 00:08

Hi there @JackCaoG, I'm back on this issue, so I will give it a try!

tyoc213, Apr 13 '21 00:04


Hi, is there a solution now? I just ran into the same issue when training a GPT-neo model on TPUs on Colab.

Pwang001, Jul 19 '21 23:07

I am using Resize inside a PyTorch Lightning training step, and it makes my code terribly slow. Is there a solution for this?

dhruvrnaik, Mar 26 '22 03:03

affine_grid should be supported now. To get a better understanding of the problem, printing a metric report with

import torch_xla.debug.metrics as met

print(met.metrics_report())

after a step would help.

JackCaoG, Mar 26 '22 22:03

@JackCaoG I am using Resize in my training step, more specifically transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC, max_size=None, antialias=None).

I am using the metric report function, but the training seems to be stuck at the resize(tensor) operation, so the code doesn't reach that step.

dhruvrnaik, Mar 26 '22 23:03


Hi @JackCaoG @butchland @tyoc213, I'm wondering if you have found a solution for speeding up the F.grid_sample method. I'm running into the same issue. Any help will be much appreciated.

honglin-chen, Jul 11 '22 23:07

@dhruvrnaik if you have a small repro, I might be able to take a look. It depends on which ops transforms.Resize gets decomposed into by PyTorch and passed to us.

JackCaoG, Jul 14 '22 23:07

Taking another look at the CPU grid_sampler_2d implementation, it seems to play with strides quite a bit. This kind of op is pretty difficult to lower for XLA, since we can't manipulate strides (XLA is a functional compiler and does not expose its memory space to the user). We would have to use a bunch of reshapes and convs to fake the strides. Which models use these batch transforms? I don't have anyone who can immediately work on this lowering.
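As an illustration (not from this thread, and not the PyTorch or PyTorch/XLA implementation), bilinear grid sampling can also be written without any stride manipulation: unnormalize the grid coordinates, take the floor, gather the four neighbouring pixels with index arrays, and blend them. The NumPy sketch below assumes mode='bilinear', padding_mode='zeros', and the unnormalization formulas documented for F.grid_sample. Whether the equivalent torch ops (advanced indexing or torch.gather) stay on the TPU and run fast there is untested.

```python
import numpy as np


def grid_sample_bilinear(inp: np.ndarray, grid: np.ndarray,
                         align_corners: bool = False) -> np.ndarray:
    """Bilinear sampling with zero padding, written only with gathers.

    inp: (N, C, H, W); grid: (N, Ho, Wo, 2) holding normalized (x, y).
    Returns an array of shape (N, C, Ho, Wo).
    """
    n, _, h, w = inp.shape
    x, y = grid[..., 0], grid[..., 1]
    # Map normalized [-1, 1] coordinates to fractional pixel indices.
    if align_corners:
        ix = (x + 1) / 2 * (w - 1)
        iy = (y + 1) / 2 * (h - 1)
    else:
        ix = ((x + 1) * w - 1) / 2
        iy = ((y + 1) * h - 1) / 2
    x0 = np.floor(ix).astype(np.int64)
    y0 = np.floor(iy).astype(np.int64)
    # Weights of the right/bottom neighbours and their complements.
    wx1, wy1 = ix - x0, iy - y0
    wx0, wy0 = 1 - wx1, 1 - wy1
    batch = np.arange(n)[:, None, None]  # (N, 1, 1), broadcasts with indices

    def corner(yi, xi):
        # Gather inp[b, :, yi, xi] -> (N, Ho, Wo, C); out-of-range taps read 0.
        valid = (xi >= 0) & (xi < w) & (yi >= 0) & (yi < h)
        vals = inp[batch, :, np.clip(yi, 0, h - 1), np.clip(xi, 0, w - 1)]
        return np.where(valid[..., None], vals, 0.0)

    out = (corner(y0, x0) * (wy0 * wx0)[..., None]
           + corner(y0, x0 + 1) * (wy0 * wx1)[..., None]
           + corner(y0 + 1, x0) * (wy1 * wx0)[..., None]
           + corner(y0 + 1, x0 + 1) * (wy1 * wx1)[..., None])
    return out.transpose(0, 3, 1, 2)  # back to (N, C, Ho, Wo)
```

Feeding it a grid whose x coordinates are negated reproduces a horizontal flip, which for the Flip case alone could also be done with a plain index reversal.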

JackCaoG, Jul 14 '22 23:07

Thank you for looking into it, @JackCaoG. I'm working with a model for learning pixel-conditioned Neural Radiance Field (paper, code). Many radiance field models heavily rely on nn.functional.grid_sample to sample features based on pixel/voxel locations, so having XLA support for it will be tremendously beneficial to the community. If this function will not be supported soon, do you have recommendations for alternative approaches that might serve the same functionality, while running much faster on TPU?

honglin-chen, Jul 15 '22 00:07

I don't think we have anything similar to grid_sample. PyTorch/XLA supports upsample_nearest2d, but I feel that's not what you want. I tried to find a TensorFlow implementation of this op (TF ops usually have an XLA lowering), but I only found this workaround. We might be able to use some of the techniques from this TF workaround when lowering this op.

JackCaoG, Jul 15 '22 05:07