
High GPU memory consumption

Open saareliad opened this issue 5 years ago • 4 comments

Hi, I tried to integrate the TTLayer into TransformerXL, but I found that it consumes much more memory than usual. Did you experience such problems? Do you know any way around this?

(BTW, I also applied a few fixes for multi-GPU training, e.g. the tensor-train objects are not moved to the GPU when you call model.to(device), which breaks the model in distributed training.)

saareliad avatar Aug 08 '19 16:08 saareliad

Hello,

It seems that this is a fundamental problem with the TTLayer and how optimization in autograd frameworks is done: in addition to the memory footprint of the model weights, during optimization we also store activations on the GPU.

In the case of a single FC layer of size d^3 x d^3 and a batch of size B x d^3, the weight storage footprint is d^6 elements and the activation footprint is Bd^3. For a TTLayer with 3 cores, the weight storage footprint is only d^2(r^2 + 2r), but the activation footprint grows to Bd^3(2r + 1). The batch size used in Transformers (the number of tokens, B x L) is usually quite large, so the increase in the activation footprint (2Brd^3) outweighs the savings in weight storage (roughly d^6).

To be precise, this happens when 2Br > d^3. Common values for d^3 are ~1000-4000, common batch sizes are around 5000, and common ranks are 8-32. As you can see, the activation memory overhead is large. Most likely, TTLayers are not applicable to the FC layers used in Transformers.
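For example, plugging in concrete numbers from those ranges (purely illustrative, counting float32 elements in a few lines of Python):

  d3, B, r = 2048, 5000, 16          # illustrative sizes, not from any particular model
  d2 = d3 ** (2.0 / 3.0)             # d^2, given a hidden size of d^3

  fc_weights = d3 * d3               # ~4.2e6 elements (~17 MB)
  fc_acts    = B * d3                # ~1.0e7 elements (~41 MB)
  tt_weights = d2 * (r**2 + 2*r)     # ~4.6e4 elements (~0.2 MB)
  tt_acts    = B * d3 * (2*r + 1)    # ~3.4e8 elements (~1.35 GB)

  print(fc_weights + fc_acts, tt_weights + tt_acts)

The TT cores save ~17 MB of weights but add more than 1.3 GB of activations, since 2Br = 160000 >> d^3 = 2048.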

AlexGrinch avatar Aug 13 '19 19:08 AlexGrinch

Your explanation about the activations makes sense; I also went over the math and it's correct.

However, the compressed model also consumes much more memory during inference, i.e. in eval mode and under torch.no_grad().

Peak memory consumption was 3021 MB for the compressed model versus 2132 MB for the normal model (TransformerXL: 151M params before compression, 124M after, with all position-wise FF layers compressed).

I also tried to write the forward method more efficiently (e.g. with bmm or einsum), but it didn't help either. I suspect it's due to all the reshapes happening under the surface. What do you think?
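For reference, the kind of core-by-core contraction I have been rewriting looks roughly like this (a sketch with my own index names and made-up sizes, not the repo's exact forward):

  import torch

  B, d, r = 64, 12, 16             # small sizes for the sketch; in the real model B is in the thousands
  core0 = torch.randn(d, d, r)     # (i1, j1, r1)
  core1 = torch.randn(r, d, d, r)  # (r1, i2, j2, r2)
  core2 = torch.randn(r, d, d)     # (r2, i3, j3)
  x = torch.randn(B, d, d, d)      # input reshaped to (B, j1, j2, j3)

  # each step materializes a (B, r, d, d, d) intermediate, i.e. B*d^3*r elements
  # (r times the plain FC activation size), and einsum can permute/copy tensors
  # internally to feed bmm, adding further temporaries
  h = torch.einsum('ijr,bjkl->brikl', core0, x)
  h = torch.einsum('rmks,brikl->bsiml', core1, h)
  y = torch.einsum('snl,bsiml->bimn', core2, h)   # back to (B, d, d, d)
  print(y.reshape(B, -1).shape)

The rank dimension rides along with the batch through every intermediate, so even at inference the working set carries an extra factor of r over the plain FC layer's Bd^3.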

saareliad avatar Aug 14 '19 06:08 saareliad

Hey, we have tried to implement a naive but more concise version of the forward pass for the TTLayer (so far for d=3), and it seems to fix the problem: in our example, memory usage fell from 7700 MB to 1500 MB, which is roughly what the standard fully connected layer occupies. The code is in the memory-fix branch; to enable the naive (but more memory-efficient) forward pass, pass the corresponding argument, TTLayer(..., naive=True). It can probably be optimized even further, I'll work on it.
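If you want to double-check this on your TransformerXL setup, a quick way to compare peak memory of the two forward passes (a generic sketch, nothing repo-specific; layer and x below are placeholders for the TT layer built with/without naive=True and a sample batch):

  import torch

  def peak_forward_mb(layer, x):
      # move to the GPU, reset the peak counter, run one inference-mode
      # forward pass, and report the peak allocated memory in MB
      layer, x = layer.cuda(), x.cuda()
      torch.cuda.reset_peak_memory_stats()
      with torch.no_grad():
          layer(x)
      return torch.cuda.max_memory_allocated() / 2**20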

KhrulkovV avatar Aug 14 '19 09:08 KhrulkovV

Excellent, cool! Looks very promising.

I wonder how the "naive" solution would scale in terms of compute time and memory consumption.

I can prepare code for d>3; I made a working script for that yesterday for something else. (My latest implementation does the entire TT contraction with a single einsum op, and it also had high memory consumption.)

That's your main change:

  full = torch.einsum('abcd,defg,ghij->bcefhi', core0, core1, core2)
  res = torch.einsum('abcd,bqcsdx->aqsx', input, full)

which does:

  1. a "decompress" step, restoring the full weight matrix, and
  2. essentially a normal matmul between the decompressed matrix and the input, just with more dimensions (shape check below).
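Putting concrete shapes on it to sanity-check the subscripts (my own sizes; x plays the role of input):

  import torch

  B, n, r = 64, 8, 4               # made-up sizes for the check
  core0 = torch.randn(1, n, n, r)  # 'abcd' in the first einsum
  core1 = torch.randn(r, n, n, r)  # 'defg'
  core2 = torch.randn(r, n, n, 1)  # 'ghij'
  x = torch.randn(B, n, n, n)      # 'abcd' in the second einsum

  # step 1: decompress -- n^6 elements, but no batch dimension involved
  full = torch.einsum('abcd,defg,ghij->bcefhi', core0, core1, core2)
  # step 2: one big matmul-like contraction; the only batch-sized tensors are x and res
  res = torch.einsum('abcd,bqcsdx->aqsx', x, full)
  print(full.shape, res.shape)     # (n,)*6 and (B, n, n, n)

  # the single-einsum version I mentioned above computes the same thing in one call
  one_shot = torch.einsum('abcd,pbqr,rcsu,udxv->aqsx', x, core0, core1, core2)
  print((res - one_shot).abs().max())   # ~0 up to float32 noise

The decompression step costs the full d^6 elements but carries no batch dimension, which is why it avoids the Bd^3*r intermediates.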

I need to fully understand when einsum does a reshape and whether it broadcasts efficiently before scaling this up.

There are several issues about einsum on the PyTorch repo; as I understand it, they are working on it:

https://github.com/pytorch/pytorch/issues/10661
https://github.com/pytorch/pytorch/issues/15671

saareliad avatar Aug 14 '19 11:08 saareliad