nanoGPT
Cannot run train due to flash
I noticed the comment that if you're using torch 2.0 and encounter warnings, you should set `--compile=False`. The problem I'm running into is that flash attention is auto-detected:
```python
# flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if not self.flash:
    print("WARNING: using slow attention, install PyTorch nightly for fast Flash Attention")
    # causal mask to ensure that attention is only applied to the left in the input sequence
    self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                .view(1, 1, config.block_size, config.block_size))
```
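One way around this would be to fold the dropout constraint into the detection itself, so flash is only selected when it can actually be used. This is a hypothetical sketch (the helper name `select_flash` is mine, not from model.py), assuming the guard would live where `self.flash` is set:

```python
def select_flash(has_sdpa: bool, dropout: float) -> bool:
    # hypothetical guard: prefer the fused kernel only when it exists
    # AND dropout is disabled, since SDPA's flash path rejects dropout > 0
    return has_sdpa and dropout == 0.0

# e.g. in __init__:
# self.flash = select_flash(
#     hasattr(torch.nn.functional, 'scaled_dot_product_attention'),
#     config.dropout)
```

With a guard like this, training with dropout > 0 would silently fall back to the manual attention path instead of tripping the assert.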
But then dropout must be zero, i.e. model.py#68
```python
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
if self.flash:
    # efficient attention using Flash Attention CUDA kernels
    assert self.dropout == 0.0, "need dropout=0.0 for now, PyTorch team is working on fix in #92917"
    y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
else:
    # manual implementation of attention
```
I set `dropout=0.0` in the toy example as a check and it worked, but wouldn't it be better to disable flash rather than dropout until flash supports dropout?
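For what it's worth, with dropout disabled the two branches should agree numerically, which is a quick way to check that falling back to the manual path changes nothing. A minimal standalone sketch (shapes `B, nh, T, hs` are arbitrary here; the manual branch mirrors the masked-softmax attention quoted above):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, nh, T, hs = 2, 4, 8, 16
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))

# fused path, dropout disabled
y_fused = F.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)

# manual path with an explicit causal mask
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)
att = att.masked_fill(mask == 0, float('-inf'))
att = F.softmax(att, dim=-1)
y_manual = att @ v

# the two should match up to floating-point tolerance
print(torch.allclose(y_fused, y_manual, atol=1e-5))
```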