After applying commit 6a3a9da, the training job runs out of GPU memory (OOM).
Bug description
My training job (Flux-dev model) on AMD GPUs encounters an out-of-memory (OOM) issue after merging the following commit from the main branch:
commit id: 6a3a9da9564d82a1120c7639ef6236bb4cffa049 (Refactor attention and make attention mask an argument to the model)
related PR: Refactor attention and make attention mask an argument to the model
Reverting this commit resolves the problem, so it seems possible that this change introduces additional GPU memory usage.
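A minimal sketch of how the per-step peak memory can be compared between the two commits; `model`, `optimizer`, and `batch` are placeholders for the actual Flux-dev training setup, and the `torch.cuda.*` memory APIs below also cover ROCm builds:

```python
# Sketch: report the peak allocated GPU memory for one training step so the
# two commits can be compared. `model`, `optimizer`, `batch` are placeholders.
import torch

def train_step_with_peak_memory(model, optimizer, batch, device="cuda"):
    torch.cuda.reset_peak_memory_stats(device)

    loss = model(**batch).mean()   # placeholder forward/loss computation
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # High-water mark of allocated memory since the reset above.
    peak_gib = torch.cuda.max_memory_allocated(device) / 2**30
    print(f"peak allocated this step: {peak_gib:.2f} GiB")
    return loss.item()
```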
Versions
torchtitan commit id: 6a3a9da9564d82a1120c7639ef6236bb4cffa049
torch: 2.10.0.dev20250914+rocm6.4
I'm surprised Flux is affected, as it doesn't use FlexAttention. Can I get the command you use? Is this specific to AMD GPUs? Also, how many steps have you run? Or can you not even run the first step?
I checked the code; Flux doesn't seem to use attention.py and has its own train.py, so Flux shouldn't be affected by the refactor. @wwwjn is my understanding correct? Or am I missing anything?
Oh, I think the bug was introduced here -- now with wrong indentation: https://github.com/pytorch/torchtitan/pull/1776/files#diff-83b7868cc3b5fde38ae75ccd8346675495ed27207bc75c422cf8c2ef4d8096d3L210-L218
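To make the suspicion concrete, here is a hypothetical sketch (not the actual torchtitan code; `need_mask` and the function names are made up) of how a one-level indentation slip can defeat a guard so that a dense mask gets built on every call:

```python
# Hypothetical illustration of the failure mode; not the torchtitan source.
import torch

def attn_mask_guarded(seq_len: int, need_mask: bool, device: str = "cuda"):
    # Intended behavior: the dense O(seq_len^2) mask is only allocated when
    # the model actually needs an explicit mask; otherwise None is passed
    # through and the attention kernel runs mask-free.
    if need_mask:
        return torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()
    return None

def attn_mask_misindented(seq_len: int, need_mask: bool, device: str = "cuda"):
    # If the body slips one indentation level out of the `if`, the guard
    # becomes a no-op and a full [seq_len, seq_len] mask is built and handed
    # to every attention call, whether or not it is needed.
    if need_mask:
        pass  # guard no longer protects the allocation below
    mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()
    return mask
```

An always-present explicit mask could also steer the attention call away from its most memory-efficient kernels, which would compound the extra allocation, but that part is speculation without looking at the actual diff.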
Can you elaborate more on this? Why is this causing a memory usage increase?
I think it is a different issue. Flux does not support CP (context parallel) yet.