
RF PT could improve memory consumption (in general, but see attention with rel.pos.enc. as example)

Open · albertz opened this issue 6 months ago · 0 comments

Download the attached zipped pickle file: dlm_traindata_debug_snapshot_002.pickle.zip (via @dorian-K). Unzip it, then drag & drop the pickle file onto https://docs.pytorch.org/memory_viz. That looks like this: [image: PyTorch memory snapshot visualization]

None of the functions/ops in RF are done in place. As you can see, this can result in quite heavy memory allocations. E.g. in the attention, `scores *= self.key_dim_per_head.dimension ** -0.5` will make a copy. Or the `pad` in `_rel_pos_enc_shift` allocates a big 4.5 GB tensor here.
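To illustrate the effect (a minimal plain PyTorch sketch, not RF code; the tensor shape and scale factor are made-up examples):

```python
import torch

# Hypothetical attention score tensor [batch, heads, query_time, key_time],
# ~1 GB in float32.
scores = torch.randn(32, 8, 1024, 1024)

# Out-of-place scaling: a second ~1 GB tensor is allocated before the old
# one can be freed, so peak memory is roughly doubled. This is what the
# non-inplace RF ops effectively do.
scaled = scores * (64 ** -0.5)

# In-place scaling: reuses the existing buffer, no extra allocation.
# Only safe if `scores` is not needed elsewhere (e.g. saved for backward).
scores.mul_(64 ** -0.5)

# Ops like padding always allocate a new (here even larger) tensor.
padded = torch.nn.functional.pad(scores, (0, 1024))
```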

Unfortunately, I think any solution to this problem will be quite involved/complex. We somehow need to figure out where we can safely operate in place. Sometimes we can know this in advance (e.g. by checking whether grad is required, or similar; still not totally trivial, and it adds some overhead), but often we cannot, and ops like `pad` also need further work. Maybe something like `torch.compile` (or the equivalent in JAX, or something equivalent on the RF side) could do that, but that would also be a huge effort. Graph-based frameworks like TF or JAX probably solve parts of this, but I am not sure whether they solve all of it (and if they don't, it will also be difficult there).
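As a rough illustration of the "know it in advance" case, here is a minimal sketch in plain PyTorch. `maybe_mul_` is a hypothetical helper, not an existing RF or PyTorch API, and the check shown is not sufficient on its own (it does not know whether the tensor is aliased or saved for backward elsewhere):

```python
import torch

def maybe_mul_(x: torch.Tensor, factor: float) -> torch.Tensor:
    """Multiply in place when it looks safe, otherwise fall back to a copy.

    Hypothetical sketch of the idea: only do the in-place op when autograd
    cannot be affected by overwriting `x`.
    """
    if torch.is_grad_enabled() and x.requires_grad:
        # Autograd may need the original value of `x` for backward, and
        # in-place ops on leaf tensors requiring grad would error anyway,
        # so make a copy here.
        return x * factor
    # No grad tracking involved: reuse the existing buffer.
    return x.mul_(factor)
```

Even with such a check, the harder part remains knowing whether the tensor is still referenced somewhere else in the computation, which is exactly what a graph-level analysis (torch.compile, JAX, graph-based TF) would be needed for.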

albertz · Jun 19 '25 09:06