Kevin Yin
As a side note, I tried `@torch.compile` on PyTorch's un-fused AdamW after you mentioned it. TFLOPS went from 85 to 4, haha. It's because `torch.compile` keeps recompiling when the LR...
Using `AdamW(lr=torch.tensor(...))` with the scheduler active, TFLOPS went from 85 to 88. Zoom zoom! The LR scheduler is LambdaLR with regular non-tensor floats.
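(For anyone trying to reproduce this: a minimal sketch of the tensor-LR setup on a recent PyTorch, with a throwaway model and loss as placeholders.)

```
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(1024, 1024, device=device)

# Passing lr as a tensor lets the compiled optimizer step read the new value
# each iteration instead of recompiling whenever the scheduler changes it.
opt = torch.optim.AdamW(model.parameters(), lr=torch.tensor(3e-4))
sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda step: min(1.0, (step + 1) / 100))

@torch.compile(fullgraph=False)
def opt_step():
    opt.step()

for _ in range(5):
    loss = model(torch.randn(8, 1024, device=device)).square().mean()
    loss.backward()
    opt_step()                      # compiled; no recompile when the LR changes
    sched.step()
    opt.zero_grad(set_to_none=True)
```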
icy χατγιρλ: @ad8e I can reproduce the performance dropoff under torch 2.2
icy χατγιρλ: compiling only forward + loss, enabling the fused optimizer completely destroys the performance
icy χατγιρλ: but...
PyTorch nightly:
fused AdamW: 69.5 TFLOPS
unfused AdamW: 70.5 TFLOPS

One of the two fused AdamW runs has a wobbly TFLOPS line, going up and down around 69.5.
```
PyTorch...
```
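(For reference, roughly the setup being benchmarked: only forward + loss compiled, optimizer left eager, with `fused` toggling the two AdamW variants. The model and sizes below are placeholders.)

```
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024)
).to(device)

# fused=True selects the fused CUDA AdamW kernel (needs CUDA params);
# set it to False to benchmark the unfused implementation.
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=(device == "cuda"))

@torch.compile
def forward_loss(x, y):
    # only forward + loss live inside the compiled region
    return torch.nn.functional.mse_loss(model(x), y)

for _ in range(10):
    x = torch.randn(8, 1024, device=device)
    y = torch.randn(8, 1024, device=device)
    loss = forward_loss(x, y)
    loss.backward()          # backward of the compiled graph
    opt.step()               # optimizer stays eager (fused or unfused)
    opt.zero_grad(set_to_none=True)
```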
The test he wrote is actually correct, and it also shows that varlen is working correctly: no difference in accuracy between naive and flash. Tested on A40, FA 2.4.2. Flash...
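(A minimal sketch of that kind of varlen-vs-naive check, assuming the usual packed layout for `flash_attn_varlen_func`; shapes and sequence lengths here are arbitrary.)

```
import itertools
import torch
from flash_attn import flash_attn_varlen_func  # flash-attn 2.x

torch.manual_seed(0)
nheads, headdim = 4, 64
seqlens = [5, 17, 3]                                       # packed variable-length batch
total = sum(seqlens)
cu = torch.tensor([0] + list(itertools.accumulate(seqlens)),
                  dtype=torch.int32, device="cuda")

def rand():
    return torch.randn(total, nheads, headdim, dtype=torch.bfloat16, device="cuda")

q, k, v = rand(), rand(), rand()
out_flash = flash_attn_varlen_func(q, k, v, cu, cu, max(seqlens), max(seqlens), causal=True)

# naive per-sequence reference, computed in fp32
outs, scale = [], headdim ** -0.5
for b in range(len(seqlens)):
    s, e = cu[b].item(), cu[b + 1].item()
    qb, kb, vb = (t[s:e].transpose(0, 1).float() for t in (q, k, v))   # (nheads, L, D)
    L = e - s
    att = (qb @ kb.transpose(-1, -2)) * scale
    att = att.masked_fill(torch.triu(torch.ones(L, L, dtype=torch.bool, device="cuda"), 1),
                          float("-inf")).softmax(-1)
    outs.append((att @ vb).transpose(0, 1))                            # back to (L, nheads, D)
out_naive = torch.cat(outs)

print((out_flash.float() - out_naive).abs().max())   # should be ~bf16 rounding error
```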
Adding a zero key at the beginning works, and the speed penalty is
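(Guessing at what the zero-key trick means here: prepend an all-zero key/value so every query, including fully-masked rows, has at least one finite logit. Treat this as an interpretation, not the poster's actual code.)

```
import torch
import torch.nn.functional as F

def attn_with_zero_key(q, k, v, keep_mask):
    """q, k, v: (B, H, L, D); keep_mask: (B, 1, L, L) bool, True = may attend."""
    B, H, L, D = k.shape
    zero_k = torch.zeros(B, H, 1, D, dtype=k.dtype, device=k.device)
    k = torch.cat([zero_k, k], dim=2)                    # prepend an all-zero key...
    v = torch.cat([torch.zeros_like(zero_k), v], dim=2)  # ...and a zero value to match
    # the zero key is visible to every query, so even a fully-masked row has
    # one finite logit and the softmax stays well defined
    always_on = torch.ones(B, 1, L, 1, dtype=torch.bool, device=k.device)
    keep_mask = torch.cat([always_on, keep_mask], dim=-1)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=keep_mask)
```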
How do we handle fused layers with DTensor? For example, in SwiGLU, there are frequently two input matrices in the FF layer. These two matrices are fused into one big...
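(To make the fused layout concrete, an illustrative module, not any particular library's implementation.)

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedSwiGLU(nn.Module):
    """SwiGLU FFN with the gate and up projections fused into one weight."""
    def __init__(self, dim, hidden):
        super().__init__()
        # rows [0:hidden] = gate projection, rows [hidden:2*hidden] = up projection
        self.gate_up = nn.Linear(dim, 2 * hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        gate, up = self.gate_up(x).chunk(2, dim=-1)
        return self.down(F.silu(gate) * up)
```

The catch with DTensor: a plain Shard(0) placement on `gate_up.weight` hands each rank a contiguous row-slice of the stacked [gate; up] matrix, so the local `chunk(2, dim=-1)` no longer pairs gate outputs with their matching up outputs. That mismatch is exactly why fused weights need special treatment, e.g. un-fusing into two parallelized linears or sharding each half separately.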
On a transformer, with TP + Sequence Parallel (i.e. what torchtitan is doing), we should theoretically be able to hide all the TP comms inside the computation. Is this planned...
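(The idea being asked about, in miniature: overlap the TP all-gather with a matmul instead of waiting for it first. This is only an illustration of the overlap pattern, not how torchtitan's async TP is implemented.)

```
import torch
import torch.distributed as dist

def allgather_matmul_overlapped(x_shard, weight, group=None):
    """All-gather sequence-sharded activations while already multiplying the
    locally available shard, so part of the TP comm hides behind compute.
    Real async-TP pipelines every chunk this way; this only overlaps the
    local chunk, for illustration."""
    world = dist.get_world_size(group)
    rank = dist.get_rank(group)
    chunks = [torch.empty_like(x_shard) for _ in range(world)]
    handle = dist.all_gather(chunks, x_shard, group=group, async_op=True)
    outs = [None] * world
    outs[rank] = x_shard @ weight          # compute on the local shard immediately
    handle.wait()                          # comm overlapped with the matmul above
    for r in range(world):
        if outs[r] is None:
            outs[r] = chunks[r] @ weight
    return torch.cat(outs, dim=0)
```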
I'm observing the same: amp + bf16 + ParallelMLP works, but amp + bf16 + grouped ParallelDroplessMLP doesn't work.
Dropless works with MixedPrecision from FSDP2 though.
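(If it helps anyone, a sketch of the FSDP2 mixed-precision setup that works here; `build_moe_model` and `model.layers` are hypothetical stand-ins, and the import path has moved between PyTorch versions.)

```
import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
# (newer releases also expose these under torch.distributed.fsdp)

# run under torchrun; FSDP2 needs an initialized process group
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = build_moe_model().cuda()   # hypothetical constructor, e.g. wrapping ParallelDroplessMLP

mp = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,    # cast params / run compute in bf16
    reduce_dtype=torch.float32,    # keep gradient reduction in fp32
)
for layer in model.layers:         # hypothetical attribute: shard each block, then the root
    fully_shard(layer, mp_policy=mp)
fully_shard(model, mp_policy=mp)
```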