Compatibility with DataParallel
Thank you for this implementation! Have you tried using it within a network that is wrapped in a DataParallel in order to make use of multiple graphics cards? I am getting an "illegal memory access was encountered" error when replacing
with torch.cuda.device(3):
    input1 = input1.cuda()
    input2 = input2.cuda()
    start = time.time()
    out = s(input1, input2)
    print(out.size(), 'time:', time.time() - start)
    start = time.time()
    out.backward(input1.data.cuda())
    print('time:', time.time() - start)
in test.py with
s = torch.nn.DataParallel(s)
if True:
    input1 = input1.cuda()
    input2 = input2.cuda()
    start = time.time()
    out = s(input1, input2)
    print(out.size(), 'time:', time.time() - start)
    start = time.time()
    out.backward(input1.data.cuda())
    print('time:', time.time() - start)
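For context on why the number of GPUs matters: DataParallel scatters the batch across all visible devices, so each replica's forward pass receives tensors that live on a different GPU. A minimal sketch that makes this visible (the ShowDevice module is mine for illustration, not part of stn.pytorch):

import torch
import torch.nn as nn

# Hypothetical probe module: print which GPU each replica's input lives on.
class ShowDevice(nn.Module):
    def forward(self, x):
        print('replica sees input on GPU', x.get_device())
        return x

m = nn.DataParallel(ShowDevice().cuda())
out = m(torch.randn(8, 3).cuda())
# With CUDA_VISIBLE_DEVICES="0,1" this prints GPU 0 and GPU 1; a custom
# kernel that assumes device 0 would then be launched against pointers
# that live on the other card.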
Interestingly, the code works with
export CUDA_VISIBLE_DEVICES="0"
but fails with
export CUDA_VISIBLE_DEVICES="0,1"
I see that you are explicitly setting the CUDA device before launching the kernel, which might be the reason for the illegal memory access.
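If the hard-coded device index is indeed the culprit, here is a hedged sketch of the kind of fix I have in mind (the wrapper function is mine, not the repo's actual API): switch to the device that owns the input before the call instead of using a fixed index, e.g. via torch.cuda.device_of:

import torch

# Hypothetical wrapper, not stn.pytorch's actual code. torch.cuda.device_of
# temporarily switches the current CUDA device to input1's device and
# restores the previous device afterwards, so each DataParallel replica
# launches the kernel on its own card.
def forward_on_input_device(stn_fn, input1, input2):
    with torch.cuda.device_of(input1):
        return stn_fn(input1, input2)

Any ideas? Thank you!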