
Excessive GPU Consumption

Open MerHS opened this issue 5 years ago • 3 comments

Probably caused by a matmul or view/expand somewhere...

should try einsum
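As a rough illustration of the suggestion (not this repository's actual code; names and shapes are assumptions), computing the local attention logits with `torch.einsum` instead of broadcasting the query against the unfolded keys and reducing with a sum avoids materializing the expanded intermediate:

```python
# Illustrative only: tensor names and sizes are assumptions, not the repo's code.
import torch

b, heads, d, fh, fw, win = 8, 8, 16, 32, 32, 49   # win = kh * kw of the local window

q = torch.randn(b, heads, d, fh, fw)               # query per output position
k_unf = torch.randn(b, heads, d, fh, fw, win)      # unfolded key windows

# Broadcast-multiply then sum: the product materializes a full
# (b, heads, d, fh, fw, win) tensor before reducing over d.
logits_bcast = (q.unsqueeze(-1) * k_unf).sum(dim=2)

# einsum fuses the multiply and the reduction over d, so the expanded
# intermediate is typically not materialized.
logits_einsum = torch.einsum('bhdxy,bhdxyk->bhxyk', q, k_unf)

assert torch.allclose(logits_bcast, logits_einsum, atol=1e-5)
```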

MerHS avatar Jun 28 '19 10:06 MerHS

With this commit (6474c17db3201421f30f53dfe342906f1277a2db), GPU memory usage goes from ~20x to ~5x that of ResNet-50.

current status (CIFAR-10, batch=8)

| Network | CStem-SA-ResNet-50 | ResNet-50 |
| --- | --- | --- |
| GPU usage | 8.2 GB | 1.6 GB |
| Param count | 12686130 | 23538522 |
| Training time per batch | 700 ms | 100 ms |

If we switch from sa-conv7x7 to sa-conv3x3:

| Network | sa-conv7x7 | sa-conv3x3 |
| --- | --- | --- |
| GPU usage | 8.2 GB | 3.1 GB |
| Param count | 12686130 | 12684242 |
| Training time per batch | 700 ms | 270 ms |

MerHS avatar Jul 01 '19 04:07 MerHS

Changing the batched multiplication to einsum (89a27c2c5674c7384cb09fd4af55fbb88fc3d6ae): GPU memory usage goes from ~5x to ~2.7x that of ResNet-50.

current status (CIFAR-10, batch=8)

| Network | CStem-SA-ResNet-50 | ResNet-50 |
| --- | --- | --- |
| GPU usage | 4.3 GB | 1.6 GB |
| Param count | 12686130 | 23538522 |
| Training time per batch | 462 ms | 100 ms |

It is now possible to train with batch size 16.

| Network | sa-conv7x7 | sa-conv3x3 |
| --- | --- | --- |
| GPU usage | 4.3 GB | 2.2 GB |
| Param count | 12686130 | 12684242 |
| Training time per batch | 462 ms | 231 ms |
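For reference, a minimal sketch (a hypothetical helper, not from this repo) of one way to obtain per-step memory numbers comparable to the "GPU usage" column above, using PyTorch's built-in peak-memory counters:

```python
# Hypothetical helper: measures peak GPU memory of a single training step.
import torch

def peak_memory_one_step(model, criterion, optimizer, images, labels):
    torch.cuda.reset_peak_memory_stats()
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024 ** 3   # GiB
```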

MerHS avatar Jul 01 '19 08:07 MerHS

I think this implementation is incorrect: it ignores the effect of "multi-head". In fact, the implementation is single-head, which makes it more efficient but less effective. For example, when you apply softmax to q*k, the shape of q*k should be (b, n_head, d/head, fh, fw, kw*kh). You then sum over dim 2 and get attention maps of shape (b, n_head, fh, fw, kw*kh) for the subsequent operations. Note that the shape of the attention maps in your code is (b, 1, fh, fw, kw*kh), which in fact results in single-head attention.
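A shape-only sketch of the multi-head behaviour described above; all names and sizes are illustrative assumptions, not this repository's code:

```python
# Illustrative sketch: keep the head dimension through the softmax so each
# head gets its own attention map, instead of collapsing to (b, 1, fh, fw, kw*kh).
import torch
import torch.nn.functional as F

b, n_head, d_head, fh, fw, win = 2, 8, 16, 14, 14, 49   # win = kw * kh

q = torch.randn(b, n_head, d_head, fh, fw)
k = torch.randn(b, n_head, d_head, fh, fw, win)          # unfolded key windows
v = torch.randn(b, n_head, d_head, fh, fw, win)          # unfolded value windows

# q*k keeps the head dimension: (b, n_head, d_head, fh, fw, win);
# summing over dim 2 gives one attention map per head: (b, n_head, fh, fw, win).
logits = torch.einsum('bndxy,bndxyk->bnxyk', q, k)
attn = F.softmax(logits, dim=-1)                         # per-head softmax over the window

# Per-head weighted sum of values -> (b, n_head, d_head, fh, fw),
# then merge heads back into (b, n_head * d_head, fh, fw).
out = torch.einsum('bnxyk,bndxyk->bndxy', attn, v)
out = out.reshape(b, n_head * d_head, fh, fw)

# Collapsing the heads before the softmax would instead yield attention maps
# of shape (b, 1, fh, fw, win), i.e. the single-head behaviour noted above.
print(attn.shape, out.shape)
```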

luogen1996 avatar Dec 04 '19 12:12 luogen1996