RF PT could improve memory consumption (in general, but see attention with relative positional encoding as an example)
1. Download the attached zipped pickle file: dlm_traindata_debug_snapshot_002.pickle.zip (via @dorian-K).
2. Unzip it.
3. Drag & drop the pickle file here: https://docs.pytorch.org/memory_viz
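For reference, a minimal sketch of how such a snapshot can be recorded yourself (I assume the attached one was produced along these lines; the workload below is just a placeholder):

```python
import torch

# Record allocation history in the CUDA caching allocator
# (PyTorch >= 2.1; private but documented API).
torch.cuda.memory._record_memory_history(max_entries=100_000)

# Placeholder workload -- substitute the actual training step here.
x = torch.randn(4096, 4096, device="cuda", requires_grad=True)
(x @ x).relu().sum().backward()

# Dump the snapshot (this is the pickle you drop into memory_viz)
# and stop recording.
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
```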
Loaded into memory_viz, that looks like this:
None of the functions/ops in RF are done in-place. As you can see, this can result in quite heavy memory allocations. E.g. in the attention, `scores *= self.key_dim_per_head.dimension ** -0.5` makes a full copy of the scores. Or the `pad` in `_rel_pos_enc_shift` allocates a big 4.5 GB tensor here.
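To illustrate the effect on the plain PyTorch level (shapes made up; note that on a raw `torch.Tensor`, `*=` already dispatches to the in-place `mul_`, while the RF ops currently always take the out-of-place route):

```python
import torch

scores = torch.randn(16, 8, 1024, 1024, device="cuda")  # ~512 MiB

before = torch.cuda.memory_allocated()
scaled = scores * 0.125  # out-of-place: allocates a second buffer of the same size
print((torch.cuda.memory_allocated() - before) // 2**20, "MiB extra")  # ~512

del scaled
before = torch.cuda.memory_allocated()
scores.mul_(0.125)  # in-place: no extra allocation
print((torch.cuda.memory_allocated() - before) // 2**20, "MiB extra")  # 0
```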
Unfortunately, I think any solution to this problem will be quite involved/complex. We somehow need to figure out where we can safely operate in-place. Sometimes we can know this in advance (e.g. by checking whether grad is required, or similar; still not totally trivial, and it will add some overhead), but often we cannot, and things like `pad` also need further work. Maybe something like `torch.compile` (or the equivalent in JAX, or something equivalent on the RF side) could do that, but that would also be a huge effort. Graph-based frameworks like TF or JAX probably solve part of this, but I am not sure whether they solve all of it (and if they do not, it will also be difficult there).
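As a minimal sketch of what the "check in advance whether in-place is safe" direction could look like, just for the scalar-scaling case (the helper name and the exact conditions are hypothetical, and deliberately conservative):

```python
import torch

def _scale_maybe_inplace(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Hypothetical helper: scale ``x``, in-place only when it looks safe.

    Only a sketch. Checking ``is_grad_enabled``/``requires_grad`` covers the
    simple cases, but is not sufficient in general: ``x`` could share storage
    with another tensor, or already be saved for the backward of another op.
    That aliasing question is the part that makes a general solution involved.
    """
    if not torch.is_grad_enabled() or not x.requires_grad:
        return x.mul_(scale)  # in-place, avoids allocating a second buffer
    return x * scale  # out-of-place fallback, as RF does now
```

And even such conservative checks add some per-op overhead, as noted above.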