I am facing a similar issue with FSDP2 enabled:

```python
import torch
import torch.nn as nn

device = "cuda"

m = nn.Sequential(
    nn.Linear(4096, 4096 * 3, bias=False),
    nn.Linear(4096 * 3, 4096, bias=False),
).to(device=device, dtype=torch.bfloat16)
x = torch.randn(32000, 4096, device="cuda", dtype=torch.bfloat16)
```

With FP8:...
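(The FP8 portion of the snippet above is truncated. For readers following along, a minimal sketch of how a model like this is typically converted for FP8 training with torchao and wrapped with FSDP2 is shown below; the exact config flags and sharding layout are assumptions, not the original poster's code, and it requires an initialized distributed process group.)

```python
# Sketch only: assumes torchao float8 training + FSDP2 (torch >= 2.6),
# launched under torchrun so torch.distributed is already initialized.
import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # FSDP2 entry point
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

device = "cuda"
m = nn.Sequential(
    nn.Linear(4096, 4096 * 3, bias=False),
    nn.Linear(4096 * 3, 4096, bias=False),
).to(device=device, dtype=torch.bfloat16)

# Swap eligible nn.Linear modules for float8 training variants.
# enable_fsdp_float8_all_gather casts weights to FP8 before the
# FSDP all-gather (assumed configuration, adjust to your setup).
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
convert_to_float8_training(m, config=config)

# Apply FSDP2 per layer, then once to the root module.
for layer in m:
    fully_shard(layer)
fully_shard(m)

x = torch.randn(32000, 4096, device=device, dtype=torch.bfloat16)
out = m(x)
```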
Got it, thanks a lot for the clarification. @danielvegamyhre, so if my understanding is correct, this is different from Transformer Engine's implementation, where activations might be stored in FP8?
Hi, is there any update on this?