nanoGPT
non-flash attention: speedup by avoiding DDP broadcasts of the causal mask
In the manual implementation of causal self-attention, the causal mask is registered as a buffer, which causes DDP to broadcast it at every step. Excluding it from the broadcast gives some extra speedup (up to 8% in an 8x3090 setup with fake data, probably more for multinode).
In fact, the causal masks are the only buffers in this model, so you could also set broadcast_buffers=False in train.py:L201, but that could backfire if buffers that do need syncing are introduced in the future.
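A minimal sketch of the buffer-exclusion approach, using a reduced stand-in module (not nanoGPT's actual attention block) and PyTorch's `DistributedDataParallel._set_params_and_buffers_to_ignore_for_model` helper. Note the helper is a private API, so its exact behavior may vary across PyTorch versions; it records the buffer names on the module, and DDP reads them at wrap time to skip those buffers when broadcasting:

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class CausalSelfAttention(nn.Module):
    """Reduced stand-in for the attention block: just the mask buffer."""
    def __init__(self, block_size=8):
        super().__init__()
        # nanoGPT registers the causal mask as a buffer named "bias";
        # being a buffer is what makes DDP broadcast it every step.
        mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("bias", mask.view(1, 1, block_size, block_size))

model = CausalSelfAttention()

# Mark the mask buffer (by its fully-qualified name, as it appears in
# model.named_buffers()) as ignored, before wrapping the model in DDP.
DDP._set_params_and_buffers_to_ignore_for_model(model, ["bias"])

print(model._ddp_params_and_buffers_to_ignore)
```

The alternative mentioned above is simply `DDP(model, broadcast_buffers=False)`, which skips all buffers rather than named ones; it is shorter but silently stops syncing any buffer added later (e.g. BatchNorm running stats).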