
[BUG] CPU Memory Leakage on TPU

Open zw615 opened this issue 3 years ago • 3 comments

Describe the bug
I see strange memory behaviour with the bits_and_tpu timm code on TPU. If layer-wise learning rate decay is enabled (--layer-decay), CPU memory usage just keeps going up. For small models this is tolerable, but for large models it eventually leads to an out-of-memory error and the program crashes. With --layer-decay turned off, everything looks fine.

This behaviour is also confirmed by commenting out this line and leaving only the branch without lr_scale; in that case, even with --layer-decay set, memory usage is stable. But I have checked: all of the variables in param_group[self.param_group_field] = value * param_group['lr_scale'] are of type float, and self.param_group_field is just 'lr', a key already present in the dict, so there is no risk of the dict growing every step.
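For reference, the update in question lives in a loop that looks roughly like this (a paraphrase of timm's Scheduler.update_groups; the exact code on the bits_and_tpu branch may differ slightly):

def update_groups(self, values):
    # one scheduled value per param group; broadcast a single value if needed
    if not isinstance(values, (list, tuple)):
        values = [values] * len(self.optimizer.param_groups)
    for param_group, value in zip(self.optimizer.param_groups, values):
        if 'lr_scale' in param_group:
            # layer-decay path: scale the scheduled value by this group's lr_scale
            param_group[self.param_group_field] = value * param_group['lr_scale']
        else:
            param_group[self.param_group_field] = value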

Another interesting thing is that even after the program has finished, CPU memory usage stays above 100GB. I have to use kill -9 to release the memory.

I have also found that turning on layer decay roughly triples the time taken by the first few steps on TPU pods. I think this has something to do with graph compilation, but I fail to understand why updating the optimizer's learning rate should affect the computational graph.

To Reproduce
Steps to reproduce the behavior:

  1. Pull the latest bits_and_tpu branch of the timm code.
  2. Run the script below; it should reproduce the issue. Arguments other than --layer-decay don't really matter.
python3 /home/user/Code/pytorch-image-models/launch_xla.py --num-devices 8 /home/user/Code/pytorch-image-models/train.py \
/path/to/gcs/bucket --dataset tfds/imagenet2012:5.1.0 --val-split validation \
--model vit_large_patch16_224 --gp "avg" \
--initial-checkpoint "" \
--resume "" \
--num-classes 1000 --img-size 224 --crop-pct 0.875 --train-interpolation bicubic \
--batch-size 32 \
--opt adamw --opt-eps 1e-8 --weight-decay 0.05 \
--sched cosine --lr 2e-3 --lr-cycle-decay 1.0 --warmup-lr 1e-6 --min-lr 1e-6 \
--layer-decay 0.75 \
--epochs 50 --warmup-epochs 5 --cooldown-epochs 0 \
--aa rand-m9-mstd0.5-inc1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 --smoothing 0.1 --drop-path 0.2 \
--workers 4 --seed 0 \
--log-interval 200 --pin-mem

Expected behavior
This behaviour of --layer-decay not only greatly slows down training (on pods the first few steps take almost as much time as the remaining 99.9% of steps), but also crashes the program after a long time of training. At the moment I have to kill the program manually twice a day. It would be nice to be able to train on TPU as smoothly as on GPU.

Screenshots
[cpu_memory: plot of CPU memory usage over time]

The first half of the plot is the run without --layer-decay, and the second half is the run with --layer-decay. Without --layer-decay the memory usage is clearly stable; with --layer-decay it keeps growing until the memory limit is reached and the program crashes.

Desktop (please complete the following information):
This issue occurs on multiple TPU VM setups:

  • TPU v3-8, with TPU software version tpu-vm-pt-1.12
  • TPU v3 pods, with TPU software version tpu-vm-pt-1.12
  • TPU v4 pods, with TPU software version v2-alpha-tpuv4

Does anyone know how to fix this? Thanks a lot!

zw615 avatar Sep 17 '22 01:09 zw615

And by the way, I think this issue is more severe when the LR scheduler steps every iteration instead of every epoch. I know timm's default behaviour is to step per epoch, which might hide the issue to some extent.
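For context, the two stepping modes sit roughly like this in a timm-style training loop (a paraphrase of the relevant calls; setup code is omitted, and the exact call sites and argument names on the bits_and_tpu branch may differ):

for epoch in range(num_epochs):
    for batch in loader_train:
        ...  # forward / backward / optimizer.step()
        num_updates += 1
        # per-iteration stepping: when the scheduler is configured in updates,
        # every param group's lr gets rescaled here, once per batch
        lr_scheduler.step_update(num_updates=num_updates)
    # per-epoch stepping (timm's default): the per-group rescaling happens
    # only here, once per epoch
    lr_scheduler.step(epoch + 1)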

zw615 avatar Sep 17 '22 04:09 zw615

@zeyuwang615 I believe this might be a problematic use case for TPU right now. With Python scalars being used, updating per layer and per step like this might be triggering a recompile, or at least frequent transfers from CPU -> TPU. With once-per-epoch stepping it isn't a concern, but here it might be. In JAX / TF I believe these values are held in device tensors, which might be a possible solution, but none of the defaults for PyTorch optimizers and LR schedulers use tensors... hmmm.
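As a toy illustration of that point (not the timm code; whether a given Python scalar ends up baked into the traced graph or shipped to the device as data can depend on the torch_xla version):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
param = torch.zeros(4, device=device)
grad = torch.ones(4, device=device)

for step in range(3):
    # a Python float that changes every step, like a per-group scaled lr
    lr = 2e-3 * (0.75 ** step)
    # the scalar becomes part of the traced update, so a new value per step can
    # mean a recompile, or at least an extra CPU -> TPU transfer
    param = param - lr * grad
    xm.mark_step()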

You could possibly see this if you turn on the PyTorch XLA debugging output.

rwightman avatar Sep 23 '22 17:09 rwightman

Do you mean this?

import torch_xla.debug.metrics as met

print(met.metrics_report())
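# in that report, growth in the compile-time metric and the host-to-device
# transfer metrics over training would point at recompiles / per-step transfers
# (exact metric names vary by torch_xla version)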

It seems there is no plausible solution to this bug for now...

zw615 avatar Sep 25 '22 18:09 zw615