Print(tensor) takes a very long time!
🐛 Bug
When I create a tensor on the TPU, I can run operations on it and print it without delay. After running inference, however, printing the output tensor takes a very long time, sometimes more than 2 minutes.
To Reproduce
run this script
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
class InstanceNorm(nn.Module):
    """Parameter-free instance normalization.

    Every ``(sample, channel)`` slice of the input is shifted to zero mean and
    multiplied by ``1 / sqrt(mean_of_squares + epsilon)``, where the statistics
    are taken over all dimensions that follow the channel dimension.

    @notice: avoid in-place ops.
    https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3

    Args:
        epsilon: Constant added to the mean squared deviation before the
            reciprocal square root is taken.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        """Normalize ``x`` over every dimension after the first two.

        Args:
            x: Tensor with at least three dimensions; dimensions 0 and 1 are
                kept, all later dimensions are reduced over.

        Returns:
            A new tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` has fewer than three dimensions.
        """
        if x.dim() < 3:
            raise ValueError(
                'InstanceNorm expects at least 3 dimensions, got %d' % x.dim())
        # For 4-D input this is (2, 3), i.e. the axes the block always used.
        reduce_dims = tuple(range(2, x.dim()))
        centered = x - torch.mean(x, reduce_dims, True)
        mean_sq = torch.mean(torch.mul(centered, centered), reduce_dims, True)
        return centered * torch.rsqrt(mean_sq + self.epsilon)
class ApplyStyle(nn.Module):
    """Scale and shift each channel of a feature map using a latent vector.

    A single linear layer maps the latent to ``channels * 2`` values, which
    are split into a per-channel scale offset and a per-channel shift.

    @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb

    Args:
        latent_size: Size of the last dimension of the latent input.
        channels: Number of channels of the feature map to modulate.
    """

    def __init__(self, latent_size, channels):
        super().__init__()
        self.linear = nn.Linear(latent_size, channels * 2)

    def forward(self, x, latent):
        """Return ``x * (scale + 1) + shift`` with per-channel scale/shift.

        Args:
            x: Feature tensor whose dimension 1 is the channel dimension.
            latent: Tensor whose last dimension is ``latent_size``.

        Returns:
            A new tensor with the same shape as ``x``.
        """
        style = self.linear(latent)  # style => [batch_size, n_channels*2]
        # One singleton per dimension of x after the channel one, so the
        # style broadcasts over them; for 4-D x this is [-1, 2, C, 1, 1].
        trailing_ones = [1] * (x.dim() - 2)
        style = style.view([-1, 2, x.size(1)] + trailing_ones)
        scale, shift = style[:, 0], style[:, 1]
        return x * (scale + 1.) + shift
class ResnetBlock_Adain(nn.Module):
    """Residual block with two styled convolution stages.

    Each stage is ``padding -> 3x3 conv -> InstanceNorm`` followed by an
    ``ApplyStyle`` modulation; an activation sits between the two stages and
    the block input is added to the result.

    Args:
        dim: Number of input and output channels of both convolutions.
        latent_size: Size of the latent passed to the ``ApplyStyle`` layers.
        padding_type: One of ``'reflect'``, ``'replicate'`` or ``'zero'``.
        activation: Module applied between the two stages. Note that the
            default instance is created once, when the class is defined, and
            is shared by every block that does not pass its own.

    Raises:
        NotImplementedError: If ``padding_type`` is not one of the supported
            values.
    """

    def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
        super().__init__()
        # Attribute order is kept as conv1, style1, act1, conv2, style2 so
        # that module registration order does not change.
        self.conv1 = self._build_conv(dim, padding_type)
        self.style1 = ApplyStyle(latent_size, dim)
        self.act1 = activation
        self.conv2 = self._build_conv(dim, padding_type)
        self.style2 = ApplyStyle(latent_size, dim)

    @staticmethod
    def _build_conv(dim, padding_type):
        """Build one ``padding -> 3x3 conv -> InstanceNorm`` stage."""
        layers = []
        conv_padding = 0
        if padding_type == 'reflect':
            layers.append(nn.ReflectionPad2d(1))
        elif padding_type == 'replicate':
            layers.append(nn.ReplicationPad2d(1))
        elif padding_type == 'zero':
            # Zero padding is done by the convolution itself.
            conv_padding = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        layers += [nn.Conv2d(dim, dim, kernel_size=3, padding=conv_padding), InstanceNorm()]
        return nn.Sequential(*layers)

    def forward(self, x, dlatents_in_slice):
        """Apply both styled stages and add the input back.

        Args:
            x: Input feature tensor with ``dim`` channels.
            dlatents_in_slice: Latent forwarded to both ``ApplyStyle`` layers.

        Returns:
            ``x`` plus the output of the second styled stage.
        """
        y = self.conv1(x)
        y = self.style1(y, dlatents_in_slice)
        y = self.act1(y)
        y = self.conv2(y)
        y = self.style2(y, dlatents_in_slice)
        return x + y
class myModel(nn.Module):
    """Convolutional encoder, latent-styled bottleneck and upsampling decoder.

    The encoder is a 7x7 convolution followed by three (four when ``deep``)
    stride-2 convolutions, the bottleneck is ``n_blocks`` ``ResnetBlock_Adain``
    modules at 512 channels, and the decoder mirrors the encoder with
    bilinear upsampling stages before a final 7x7 convolution.

    Args:
        input_nc: Number of channels of the input image.
        output_nc: Number of channels of the output image.
        latent_size: Size of the latent passed to every bottleneck block.
        n_blocks: Number of bottleneck blocks; must be non-negative.
        deep: If true, add one more 512-channel down/up stage pair.
        norm_layer: Normalization layer constructor used in the encoder.
        padding_type: Padding mode forwarded to the bottleneck blocks.
    """

    def __init__(self, input_nc, output_nc, latent_size, n_blocks=6, deep=False,
                 norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        assert (n_blocks >= 0)
        super().__init__()
        activation = nn.ReLU(True)
        self.deep = deep
        self.first_layer = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
            norm_layer(64), activation)

        ### downsample
        self.down1 = self._down_stage(64, 128, norm_layer, activation)
        self.down2 = self._down_stage(128, 256, norm_layer, activation)
        self.down3 = self._down_stage(256, 512, norm_layer, activation)
        if self.deep:
            self.down4 = self._down_stage(512, 512, norm_layer, activation)

        ### resnet blocks
        self.BottleNeck = nn.Sequential(*[
            ResnetBlock_Adain(512, latent_size=latent_size,
                              padding_type=padding_type, activation=activation)
            for _ in range(n_blocks)
        ])

        ### upsample
        # NOTE(review): the decoder always uses nn.BatchNorm2d and ignores
        # ``norm_layer``; kept as-is so existing behaviour does not change.
        if self.deep:
            self.up4 = self._up_stage(512, 512, activation)
        self.up3 = self._up_stage(512, 256, activation)
        self.up2 = self._up_stage(256, 128, activation)
        self.up1 = self._up_stage(128, 64, activation)
        self.last_layer = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, kernel_size=7, padding=0))

    @staticmethod
    def _down_stage(in_channels, out_channels, norm_layer, activation):
        """Build a stride-2 ``3x3 conv -> norm -> activation`` stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            norm_layer(out_channels), activation)

    @staticmethod
    def _up_stage(in_channels, out_channels, activation):
        """Build a ``2x bilinear upsample -> 3x3 conv -> BatchNorm -> activation`` stage."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels), activation)

    def forward(self, input, dlatents):
        """Run the image through encoder, styled bottleneck and decoder.

        Args:
            input: Image tensor with ``input_nc`` channels (the original
                comment gives 3*224*224 as an example).
            dlatents: Latent passed unchanged to every bottleneck block.

        Returns:
            The output of the last convolution, with ``output_nc`` channels.
        """
        # Intermediate activations are not kept: nothing but the final
        # tensor is returned, so holding them would only cost memory.
        x = self.first_layer(input)
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        if self.deep:
            x = self.down4(x)

        # The blocks take two arguments, so the Sequential container is
        # iterated by hand rather than called.
        for block in self.BottleNeck:
            x = block(x, dlatents)

        if self.deep:
            x = self.up4(x)
        x = self.up3(x)
        x = self.up2(x)
        x = self.up1(x)
        return self.last_layer(x)
# Reproduction driver: build the model on the XLA device, then repeatedly
# print fresh inputs and the model output to compare how long each takes.
TPU = xm.xla_device()
model = myModel(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=False)
model.to(TPU)

for i in range(1, 11):
    latent = torch.rand((1, 1, 512), device=TPU)
    img = torch.rand((1, 3, 224, 224), device=TPU)

    # Freshly created tensors print without any noticeable delay.
    print(latent)
    print(img)

    # Now try printing after inference.
    output = model(img, latent)
    print('ran inference')

    # This print is the slow step: it takes a minute or two.
    # Other tensor operations on the output, such as addition or
    # tensor.item(), take the same amount of time.
    print(output)
Steps to reproduce the behavior:
- open colab with TPU v2
- run this script
- notice the behaviour
Expected behavior
The print statement should take a fraction of a second, but it takes a long time.
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
- torch_xla version: I tried many versions, including the latest one
Update: adding `xm.mark_step()` before and after inference seems to make it work, and printing is then fast.
Hi @baqt99, sorry for the late reply.
Yes, adding `xm.mark_step()` should help.
The reason should be the following:
PyTorch/XLA builds on PyTorch's lazy tensors, so code is not compiled or run until something actually triggers it, such as `print(output)` or `xm.mark_step()`. `xm.mark_step()` is the command that makes all of the TPU code above it compile and run.
Without `xm.mark_step()` (as in the code above), printing the tensor has to wait until `output = model(img, latent)` has finished, so we observe a very long time in each loop iteration:
- tried above locally on v4-8: https://gist.github.com/ManfeiBai/c3a89ac4a3497f999195ee8a38aeab61
If `xm.mark_step()` is added, printing the tensor takes only a short time: the code has already been compiled when `xm.mark_step()` ran, so the print only needs to print the tensor:
- tried locally on v4-8, calling `xm.mark_step()` after inference: https://gist.github.com/ManfeiBai/6fa4eda66f1fa45c56c0707db1a635f6
cc @JackCaoG
Closing the issue since the root cause and solution are given in the comments above.