
Cannot run train due to flash

Open david-waterworth opened this issue 1 year ago • 0 comments

I noticed the comment that you're using torch 2.0, and that if you encounter warnings you should set --compile=False.
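For reference, the way I read that suggestion (assuming nanoGPT's train.py command-line overrides; not quoting the README verbatim) is:

    # assumed usage: pass the override on the command line so torch.compile is skipped
    python train.py --compile=False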

The problem I'm running into is that Flash Attention is auto-detected:

    # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
    self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
    if not self.flash:
        print("WARNING: using slow attention, install PyTorch nightly for fast Flash Attention")
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                    .view(1, 1, config.block_size, config.block_size))
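That detection is just a feature check, so it can be reproduced standalone (nothing repo-specific here, only the same hasattr call):

    import torch

    # True on builds that expose the fused kernel (PyTorch 2.0 / recent nightlies),
    # False otherwise -- this is exactly what selects the flash path above
    print(hasattr(torch.nn.functional, 'scaled_dot_product_attention'))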

But then dropout must be zero, see model.py#68:

    # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
    if self.flash:
        # efficient attention using Flash Attention CUDA kernels
        assert self.dropout == 0.0, "need dropout=0.0 for now, PyTorch team is working on fix in #92917"
        y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
    else:
        # manual implementation of attention
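For context, the dropout that the assert guards against lives in that manual path: the attention probabilities are dropped out before being applied to v. A rough sketch of what the fallback does (not a verbatim quote of model.py; self.attn_dropout and the math / F = torch.nn.functional names are assumed from the rest of the file):

    # sketch of the non-flash path: standard causal attention with dropout
    # applied to the attention probabilities
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
    att = F.softmax(att, dim=-1)
    att = self.attn_dropout(att)  # the dropout the flash kernel can't apply yet
    y = att @ v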

I set dropout=0.0 in the toy example as a check and it worked, but it seems it would be better to disable flash rather than dropout until non-zero dropout is supported by flash?
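One way to express that (just a sketch of the suggestion, not a patch from the repository) would be to fold the dropout check into the auto-detection, so the manual path is taken whenever dropout is non-zero:

    # sketch: only enable flash when dropout is zero; otherwise fall back to the
    # manual masked implementation, which supports dropout on the attention weights
    self.flash = (
        hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        and config.dropout == 0.0
    )
    if not self.flash:
        print("WARNING: using slow attention (no flash kernel, or dropout > 0.0)")
        # causal mask for the manual implementation
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                    .view(1, 1, config.block_size, config.block_size))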

david-waterworth avatar Jan 31 '23 04:01 david-waterworth