`tltorch.FactorizedConv.from_conv` tries to allocate ~81 TiB of memory for a conv weight of shape (1024, 512, 3, 3, 3)
Minimal code to reproduce the error:

import torch
import tltorch

test_conv3d = torch.nn.Conv3d(1024, 512, (3, 3, 3), padding=(3, 1, 1))
print(tltorch.FactorizedConv.from_conv(test_conv3d, rank='same', factorization='cp'))
Error:
RuntimeError: [enforce fail at alloc_cpu.cpp:83] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 89060441849856 bytes. Error code 12 (Cannot allocate memory)
The error actually comes from this line in the truncated_svd function of TensorLy. The shape of the matrix passed to the svd function is torch.Size([3, 4718592]). Note that this error is not thrown when torch.svd is used directly. The matrix itself is only ~54 MiB; it's strange that tl.svd tries to allocate ~81 TiB for it.
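A back-of-the-envelope check (my own arithmetic, not taken from the traceback) shows the requested allocation matches a full float32 V matrix for that shape exactly, which suggests the SVD is being computed with full matrices rather than the reduced form that torch.svd returns by default:

```python
# Hypothetical reconstruction of where the 89060441849856 bytes could come
# from, assuming float32 (4 bytes/element) and a full -- not reduced -- SVD
# of the (3, 4718592) matrix reported above.
m, n = 3, 4718592

# Reduced SVD of a short-fat matrix is tiny: U is (m, m), S is (m,),
# and V is (n, m).
reduced_bytes = (m * m + m + n * m) * 4

# A full SVD materializes V as (n, n), which is what explodes.
full_v_bytes = n * n * 4

print(f"reduced factors: ~{reduced_bytes / 2**20:.0f} MiB")   # ~54 MiB
print(f"full (n, n) V:   {full_v_bytes} bytes = {full_v_bytes / 2**40:.0f} TiB")
```

The full-V figure, 89060441849856 bytes (exactly 81 TiB), matches the byte count in the error message.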
It's also possible that I'm making a very stupid mistake, in any case, looking forward to some solution here :pray: @JeanKossaifi
The same example works with factorization='tucker', but the from_conv function does not infer dilation from the input conv layer.
Example code:

import torch
import tltorch

test_input = torch.randn((1, 1024, 64, 7, 7))
test_conv3d = torch.nn.Conv3d(1024, 512, (3, 3, 3), padding=(3, 1, 1), dilation=(3, 1, 1))
print(test_conv3d(test_input).shape)
fact_conv3d = tltorch.FactorizedConv.from_conv(test_conv3d, rank='same', factorization='tucker')
print(fact_conv3d(test_input).shape)
prints
torch.Size([1, 512, 64, 7, 7])
torch.Size([1, 512, 68, 7, 7])
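The off-by-four depth is consistent with the dilation being silently dropped: with dilation 3 the effective kernel depth is 1 + 3*(3-1) = 7, while without it the kernel stays at 3. A quick check with the standard Conv3d output-size formula (my own helper, not tltorch code):

```python
def conv_out(size, pad, k, dilation=1, stride=1):
    # Standard PyTorch conv output-size formula:
    # floor((size + 2*pad - dilation*(k-1) - 1) / stride) + 1
    return (size + 2 * pad - dilation * (k - 1) - 1) // stride + 1

print(conv_out(64, pad=3, k=3, dilation=3))  # 64: what the original conv produces
print(conv_out(64, pad=3, k=3, dilation=1))  # 68: what you get if dilation is dropped
```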
while,

import torch
import tltorch

test_input = torch.randn((1, 1024, 64, 7, 7))
test_conv3d = torch.nn.Conv3d(1024, 512, (3, 3, 3), padding=(3, 1, 1), dilation=(3, 1, 1))
print(test_conv3d(test_input).shape)
fact_conv3d = tltorch.FactorizedConv.from_conv(test_conv3d, rank='same', factorization='tucker', dilation=(3, 1, 1))
print(fact_conv3d(test_input).shape)
prints
torch.Size([1, 512, 64, 7, 7])
torch.Size([1, 512, 68, 7, 7])
Not sure if this is intentional, but felt it worth mentioning.
Great catch @hello-fri-end, thank you for investigating and flagging! dilation is supported by the conv but the from_conv doesn't pass the argument - would you be able to open a small PR to fix the issue?
Support for dilated convs in from_conv is now added in 615fbdd