SASA-pytorch
Excessive GPU Consumption
Probably the matmul, or maybe the view; something along those lines. Should try einsum.
With commit 6474c17db3201421f30f53dfe342906f1277a2db, GPU memory usage drops from about 20x to about 5x that of ResNet-50.
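For reference, here is a minimal sketch of one way to measure the peak GPU memory and per-batch time reported below (not necessarily how these exact numbers were collected); `model`, `images`, `labels`, and `optimizer` are placeholders, not repo code:

```python
import time
import torch
import torch.nn.functional as F

def profile_step(model, images, labels, optimizer):
    # Reset the peak-memory counter so max_memory_allocated() reflects this step only.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()

    # One ordinary training step.
    optimizer.zero_grad()
    loss = F.cross_entropy(model(images), labels)
    loss.backward()
    optimizer.step()

    torch.cuda.synchronize()
    elapsed_ms = (time.time() - start) * 1000
    peak_gb = torch.cuda.max_memory_allocated() / 1024 ** 3
    return elapsed_ms, peak_gb
```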
current status (CIFAR-10, batch=8)
Network | CStem-SA-ResNet-50 | ResNet-50 |
---|---|---|
GPU usage | 8.2 GB | 1.6 GB |
Param count | 12,686,130 | 23,538,522 |
Training time per batch | 700 ms | 100 ms |
If we switch from sa-conv7x7 to sa-conv3x3 (a rough scaling estimate follows the table below):
Network | sa-conv7x7 | sa-conv3x3 |
---|---|---|
GPU usage | 8.2 GB | 3.1 GB |
Param count | 12,686,130 | 12,684,242 |
Training time per batch | 700 ms | 270 ms |
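The direction of that drop is expected: the unfolded key/value windows and the attention logits all carry a kw*kh axis, so their activation memory scales with the window area. A rough, purely illustrative estimate (the function and numbers below are assumptions, not taken from the repo):

```python
def window_elements(b, d, fh, fw, k):
    # Elements in one (b, d, fh, fw, k*k)-shaped intermediate, e.g. the
    # unfolded keys, unfolded values, or the attention logits.
    return b * d * fh * fw * k * k

# The ratio is independent of b, d, fh, fw; only the window area matters.
print(window_elements(8, 64, 32, 32, 7) / window_elements(8, 64, 32, 32, 3))  # ~5.4
```

The measured drop (8.2 GB to 3.1 GB) is smaller than 49/9 because activations without a kw*kh axis are unaffected.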
Switching the batched multiplication to einsum (commit 89a27c2c5674c7384cb09fd4af55fbb88fc3d6ae) brings GPU memory usage down from about 5x to about 2.7x that of ResNet-50. A rough sketch of the difference is below.
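To illustrate the kind of change involved (shapes and variable names below are hypothetical, not the repo's exact code): contracting q and the unfolded keys with einsum avoids materializing the full broadcast product that an elementwise multiply + sum creates.

```python
import torch

# Illustrative shapes: batch 8, 8 heads, head dim 8, 32x32 feature map, 7x7 window.
b, heads, d_head, fh, fw, k = 8, 8, 8, 32, 32, 7
q = torch.randn(b, heads, d_head, fh, fw)             # queries
k_unf = torch.randn(b, heads, d_head, fh, fw, k * k)  # unfolded key windows

# Broadcast version: stores a (b, heads, d_head, fh, fw, k*k) product before the
# sum over d_head; that intermediate is what dominates memory.
logits_a = (q.unsqueeze(-1) * k_unf).sum(dim=2)

# einsum version: contracts over d_head directly (dispatched to a batched matmul),
# so the large elementwise product is never stored.
logits_b = torch.einsum('bhdxy,bhdxyk->bhxyk', q, k_unf)

assert torch.allclose(logits_a, logits_b, atol=1e-5)
```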
current status (CIFAR-10, batch=8)
Network | CStem-SA-ResNet-50 | ResNet-50 |
---|---|---|
GPU usage | 4.3 GB | 1.6 GB |
Param count | 12,686,130 | 23,538,522 |
Training time per batch | 462 ms | 100 ms |
It is now possible to train with a batch size of 16.
Network | sa-conv7x7 | sa-conv3x3 |
---|---|---|
GPU usage | 4.3 GB | 2.2 GB |
Param count | 12,686,130 | 12,684,242 |
Training time per batch | 462 ms | 231 ms |
I think this implementation is incorrect because it ignores the effect of "multi-head" attention. In fact, the implementation is single-head, which makes it more efficient but less effective. For example, when you apply softmax to q*k: the shape of q*k should be (b, n_head, d/n_head, fh, fw, kw*kh); you then sum it over dim 2 to get attention maps of shape (b, n_head, fh, fw, kw*kh) for the subsequent operations. Note that the shape of the attention maps in your code is (b, 1, fh, fw, kw*kh), which in fact results in single-head attention.
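A minimal sketch of the multi-head variant described above, assuming unfolded keys/values and hypothetical tensor names (not the repo's code); the point is that softmax produces one attention map per head, of shape (b, n_head, fh, fw, kw*kh), instead of a single shared map:

```python
import torch
import torch.nn.functional as F

b, n_head, d, fh, fw, kh, kw = 8, 8, 64, 32, 32, 7, 7
d_head = d // n_head

q = torch.randn(b, n_head, d_head, fh, fw)               # per-head queries
k_unf = torch.randn(b, n_head, d_head, fh, fw, kh * kw)  # per-head unfolded keys
v_unf = torch.randn(b, n_head, d_head, fh, fw, kh * kw)  # per-head unfolded values

# q*k contracted over the per-head channel dim (dim 2) -> (b, n_head, fh, fw, kh*kw),
# i.e. one attention map per head rather than a single shared one.
logits = torch.einsum('bndxy,bndxyk->bnxyk', q, k_unf)
attn = F.softmax(logits, dim=-1)

# Weighted sum of values per head, then merge heads back into the d channels.
out = torch.einsum('bnxyk,bndxyk->bndxy', attn, v_unf)   # (b, n_head, d_head, fh, fw)
out = out.reshape(b, d, fh, fw)
```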