Attention-Augmented-Conv2d
Memory blow up issue
Hi,
Thanks for the open impl of AAConv/AAWRN :smile:
I have access to a 16 GB GPU to do a few experiments on AA-WideResNet, but memory grows out of bounds at the start of training. An AAWRN28-10 requires approximately 2 GB of memory for its 206,229,580 parameters. At that point, running the model with a batch of 128 images from CIFAR-100 causes RuntimeError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 15.90 GiB total capacity; 11.70 GiB already allocated; 1.26 GiB free; 2.24 GiB cached).
The error output points at line 113 of the AAConv class: rel_logits = rel_logits.repeat((1, 1, 1, H, 1, 1)).
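For reference, a rough back-of-the-envelope estimate of what that repeat allocates. This assumes the repeated tensor has shape (B, Nh, H, H, W, W) (one relative logit per query/key pixel pair), Nh = 8 heads, float32, and CIFAR's 32×32 resolution; the head count and dtype are guesses about the config here:

```python
def rel_logits_bytes(batch, heads, h, w, dtype_bytes=4):
    # After rel_logits.repeat((1, 1, 1, H, 1, 1)) the tensor has shape
    # (batch, heads, H, H, W, W): one relative logit per query/key pixel pair.
    return batch * heads * h * h * w * w * dtype_bytes

# Batch of 128 CIFAR images (32x32), 8 heads, float32:
print(rel_logits_bytes(128, 8, 32, 32) / 2**30, "GiB")  # 4.0 GiB
```

That single intermediate tensor already rivals the allocation reported in the traceback, and it scales with H²·W², so it dominates at full resolution.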
Moreover, it happens in the first convolution of the first layer. I tried switching every AAConv to relative=False, which performs a bit better: the OOM then moves to the second convolution of the first layer.
I had to downscale the model to either batch size 16 or a terribly low widen factor. If you have any idea or plan for how to improve the memory efficiency, that would be neat! :laughing:
I noticed the same. Going into this I thought the positional encodings were a pretty simple affair, but looking closer at the einsum I think it's really a spatially separable 2H×2W convolution kernel applied to Q, padded by H×W. So although a faster implementation could probably be made by diving into CUDA, I think a fairly high computational and memory cost is unavoidable.
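One possible (untested) mitigation: the H copies created by .repeat are identical, so a broadcasting view could stand in for the materialised tensor. Sketched with NumPy here to keep it self-contained; in PyTorch the analogous call would be Tensor.expand, though later ops that consume these logits may still materialise the broadcast result:

```python
import numpy as np

# Stand-in for rel_logits after unsqueeze(dim=3): shape (B, Nh, H, 1, W, W).
B, Nh, H, W = 2, 4, 8, 8
rel = np.arange(B * Nh * H * W * W, dtype=np.float32).reshape(B, Nh, H, 1, W, W)

tiled = np.tile(rel, (1, 1, 1, H, 1, 1))          # materialises H copies
view = np.broadcast_to(rel, (B, Nh, H, H, W, W))  # stride-0 view, no new allocation

assert np.array_equal(tiled, view)      # same values either way
assert tiled.nbytes == H * rel.nbytes   # the copy costs H x the memory
```

Whether this actually helps end to end depends on what the downstream reshape/softmax does with the view, but it might shave off the single biggest allocation.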
In fact, I would not be surprised if one could get pretty good results from this model by removing the keys K altogether and using only these positional encodings.
Why does this model consume so much memory?