imagen-pytorch
Partial ignore time causes increased memory usage
Hello,
I have been training models with `ignore_time=True` randomly activated at each step, and this seems to significantly increase memory usage.
For example, on a dummy task:

- With a fixed `ignore_time=True`, I see ~4.5 GB of memory usage.
- With a fixed `ignore_time=False`, I see ~6.3 GB of memory usage (which makes sense, since more layers are used).
- If I pick `ignore_time=True` or `ignore_time=False` with a 50% chance at each step, I see ~8.4 GB of memory usage.
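For clarity, the per-step random toggle I am using looks roughly like this (a minimal sketch; `sample_ignore_time` is my own helper, and the commented-out trainer call is illustrative of how the flag is passed):

```python
import random

def sample_ignore_time(p: float = 0.5) -> bool:
    # With probability p, skip the temporal layers for this training step.
    return random.random() < p

# Inside the training loop (trainer call shown for illustration only):
# loss = trainer(videos, unet_number=1, ignore_time=sample_ignore_time())
```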
This only happens when using accelerate / multi-GPU, never with single-GPU training, and it causes OOM errors.
I tried setting `ignore_time` to the parity of the current training step, so that under accelerate all processes would train with the same value of `ignore_time`, but it did not help.
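Concretely, the parity workaround I tried was along these lines (a sketch; `ignore_time_for_step` is a hypothetical helper name):

```python
def ignore_time_for_step(step: int) -> bool:
    # Derive ignore_time from the global step's parity, so every
    # accelerate process picks the same value without communication.
    return step % 2 == 0
```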
The problem does not occur when `memory_efficient=True` is set.
A side note: the new attention-stabilizing trick seems to increase memory usage significantly (~30%). Maybe it could be deactivated with a flag? That would also help with backward compatibility for previously trained models.