Attention mask in TransformerDecoderBlock?

Open ifsheldon opened this issue 3 years ago • 11 comments

Hi! I've been trying to port nanoGPT to Rust with dfdx. The transformer module is awesome! But it seems an important trick is missing: the attention mask in TransformerDecoderBlock. I looked at the lines below and didn't find anything about an attention mask. Did I miss anything?

https://github.com/coreylowman/dfdx/blob/cbe38a54fad2f58023cbceb0ea9d9e889a34e7f2/src/nn/transformer/decoder.rs#L187

https://github.com/coreylowman/dfdx/blob/cbe38a54fad2f58023cbceb0ea9d9e889a34e7f2/src/nn/transformer/mha.rs#L130

For attention masks, you can refer to the "Let's build GPT: from scratch, in code, spelled out" video from Neural Networks: Zero to Hero, and to the documentation of torch.nn.MultiheadAttention.forward and torch.nn.functional.scaled_dot_product_attention.

ifsheldon · Mar 20 '23 07:03
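As a concrete picture of what those references describe: a causal (look-ahead) mask stops query position i from attending to any key position j > i. One common encoding is an additive mask, 0 where attention is allowed and -inf where it is not, added to the scaled scores before the softmax. Below is a minimal std-only Rust sketch; the name causal_mask and the row-major (s1, s2) layout are illustrative assumptions, not dfdx API.

```rust
/// Additive causal mask for `s1` queries and `s2` keys, stored row-major
/// with shape (s1, s2): 0.0 where query i may attend to key j (j <= i),
/// -inf where key j lies in the future (j > i).
fn causal_mask(s1: usize, s2: usize) -> Vec<f32> {
    let mut mask = vec![0.0f32; s1 * s2];
    for i in 0..s1 {
        for j in (i + 1)..s2 {
            // Row i starts at i * s2 because each row holds s2 keys.
            mask[i * s2 + j] = f32::NEG_INFINITY;
        }
    }
    mask
}
```

For example, causal_mask(3, 3) is 0 on and below the diagonal and -inf above it. Using the key count s2 as the row stride keeps the layout correct even when s1 and s2 differ.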

I think the choose function may do what you're asking: it lets you use a boolean tensor to choose element-wise between two given tensors.

opfromthestart · Mar 20 '23 20:03
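The choose approach replaces masked scores by selection instead of addition. Since the exact dfdx signature isn't quoted in this thread, the sketch below mimics the idea on plain slices with a hypothetical choose helper (not the dfdx function): a boolean condition picks, element by element, between the real score and -inf.

```rust
/// Element-wise select: out[i] = if cond[i] { on_true[i] } else { on_false[i] }.
/// Hypothetical plain-slice stand-in for a tensor "choose" op.
fn choose(cond: &[bool], on_true: &[f32], on_false: &[f32]) -> Vec<f32> {
    assert!(cond.len() == on_true.len() && cond.len() == on_false.len());
    cond.iter()
        .zip(on_true.iter().zip(on_false))
        .map(|(&c, (&t, &f))| if c { t } else { f })
        .collect()
}

/// Causally mask a row-major (n, n) score matrix: keep scores where the key
/// index j is at most the query index i, replace the rest with -inf.
fn mask_scores_causal(scores: &[f32], n: usize) -> Vec<f32> {
    // idx / n is the query row i, idx % n is the key column j.
    let allowed: Vec<bool> = (0..n * n).map(|idx| idx % n <= idx / n).collect();
    let neg_inf = vec![f32::NEG_INFINITY; n * n];
    choose(&allowed, scores, &neg_inf)
}
```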

Fairly certain the default impl doesn't have a causal attention mask; you'll need to add it yourself. Here's what I did in the forward function:

```rust
// Project v, k, q and split them into H heads.
assert_eq!(k.shape().0, v.shape().0);
let s1 = q.shape().0;
let s2 = k.shape().0;
let v = self.w_v.try_forward(v.retaped::<T>())?;
let v = v.try_reshape_like(&(s2, H, V / H)).unwrap()?;
let v = v.try_permute::<_, Axes3<1, 0, 2>>()?;

let k = self.w_k.try_forward(k.retaped::<T>())?;
let k = k.try_reshape_like(&(s2, H, K / H)).unwrap()?;
let k = k.try_permute::<_, Axes3<1, 2, 0>>()?;

let q = self.w_q.try_forward(q)?;
let q = q.try_reshape_like(&(s1, H, K / H)).unwrap()?;
let q = q.try_permute::<_, Axes3<1, 0, 2>>()?;

// Get weights: scaled dot products with shape (H, s1, s2)
let scalar: E = E::ONE / E::from_usize(K / H).unwrap().sqrt();
let weights = q.try_matmul(k)?.try_mul(scalar)?;

// Causal mask with shape (s1, s2), row-major: -inf where key j comes after query i.
// Row i starts at i * s2 because each row holds s2 keys.
let mut mask = vec![E::zero(); s1.size() * s2.size()];
for i in 0..s1.size() {
    for j in i + 1..s2.size() {
        mask[i * s2.size() + j] = -E::infinity();
    }
}
let mask: Tensor<(S1, S2), _, _> = weights.device.try_tensor_from_vec(mask, (s1, s2))?;
// Broadcast the mask over the heads, add it, then softmax over the key axis.
let weights = weights.try_add(mask.try_broadcast_like(&(H, s1, s2))?)?;
let weights = weights.try_softmax::<Axis<2>>()?;

// Get new tokens: weighted sum of values, then merge the heads back together
let tokens = weights.try_matmul(v)?;
let tokens = tokens.try_permute::<_, Axes3<1, 0, 2>>()?;
let tokens = tokens.try_reshape_like(&(s1, Const::<V>)).unwrap()?;

self.w_o.try_forward(tokens)
```

jafioti · Mar 20 '23 20:03
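To sanity-check the masking logic in the snippet above without pulling in dfdx, here is a std-only, single-head reference that follows the same steps: scaled q kᵀ, an additive -inf mask above the diagonal, a softmax over the key axis, and a weighted sum of v. It is a sketch under assumptions (row-major Vec<f32> storage, one head, self-attention so queries and keys share the length s); the function names are hypothetical.

```rust
/// Row-wise softmax of a row-major (rows, cols) matrix, in place,
/// subtracting each row's max for numerical stability.
fn softmax_rows(x: &mut [f32], cols: usize) {
    for row in x.chunks_mut(cols) {
        let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0f32;
        for v in row.iter_mut() {
            *v = (*v - max).exp();
            sum += *v;
        }
        for v in row.iter_mut() {
            *v /= sum;
        }
    }
}

/// Single-head causal self-attention on row-major matrices:
/// q, k: (s, d); v: (s, dv). Returns (s, dv).
fn causal_attention(q: &[f32], k: &[f32], v: &[f32], s: usize, d: usize, dv: usize) -> Vec<f32> {
    let scale = 1.0 / (d as f32).sqrt();
    // Scaled scores q k^T plus the additive causal mask.
    let mut w = vec![0.0f32; s * s];
    for i in 0..s {
        for j in 0..s {
            let dot: f32 = (0..d).map(|t| q[i * d + t] * k[j * d + t]).sum();
            let mask = if j > i { f32::NEG_INFINITY } else { 0.0 };
            w[i * s + j] = dot * scale + mask;
        }
    }
    softmax_rows(&mut w, s);
    // Weighted sum of the value rows.
    let mut out = vec![0.0f32; s * dv];
    for i in 0..s {
        for j in 0..s {
            for c in 0..dv {
                out[i * dv + c] += w[i * s + j] * v[j * dv + c];
            }
        }
    }
    out
}
```

With all-zero keys every visible score is equal, so the first query can only copy the first value, while the second averages the first two.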

> let weights = weights.try_add(mask.try_broadcast_like(&(H, s1, s2))?)?;

I don't know how dfdx handles adding infinity, but in theory this is not sufficient since addition doesn't block gradient flow in backprop although it blocks attention in the forward pass.

ifsheldon · Mar 21 '23 02:03

@ifsheldon It shouldn't block gradient flow, but the gradients will be subtracted by inf so in practice they should go to zero. Proper masking would be better.

The best would be to build it directly into an MHA Cuda kernel.

jafioti · Mar 21 '23 14:03
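The backward pass of softmax makes the point in this exchange concrete. For p = softmax(z) and an upstream gradient g = dL/dp, the gradient with respect to the logits is dL/dz_j = p_j (g_j - Σ_k p_k g_k). A masked logit has p_j = 0, so its gradient is exactly zero whenever g_j is finite: the addition does not cut the edge from the graph, but the value flowing along it vanishes. A small std-only Rust sketch (function names are illustrative):

```rust
/// Numerically stable softmax over one row.
fn softmax(z: &[f32]) -> Vec<f32> {
    let max = z.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let e: Vec<f32> = z.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = e.iter().sum();
    e.iter().map(|&x| x / sum).collect()
}

/// Softmax backward: given p = softmax(z) and upstream g = dL/dp,
/// return dL/dz_j = p_j * (g_j - sum_k p_k * g_k).
fn softmax_backward(p: &[f32], g: &[f32]) -> Vec<f32> {
    let dot: f32 = p.iter().zip(g).map(|(&pi, &gi)| pi * gi).sum();
    p.iter().zip(g).map(|(&pi, &gi)| pi * (gi - dot)).collect()
}
```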

Nope you didn't miss anything, mask isn't currently supported. Luckily we can add it in a non-breaking way by just adding more impl Module for both MultiHeadAttention/Decoder/Transformer that accept an additional tensor input.
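A rough illustration of why an extra impl can be non-breaking: in Rust a type can implement the same generic trait for several input types, so adding an impl for (q, k, v, mask) leaves the existing (q, k, v) impl untouched. The Module trait and ToyAttention type below are simplified, hypothetical stand-ins, not dfdx's actual definitions.

```rust
/// Simplified, hypothetical module trait parameterized by its input type.
trait Module<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

/// Hypothetical attention layer; each impl just reports which path ran.
struct ToyAttention;

/// Existing path: (q, k, v) with no mask.
impl Module<(Vec<f32>, Vec<f32>, Vec<f32>)> for ToyAttention {
    type Output = &'static str;
    fn forward(&self, _qkv: (Vec<f32>, Vec<f32>, Vec<f32>)) -> Self::Output {
        "unmasked"
    }
}

/// Additional path: (q, k, v, mask). It sits alongside the impl above.
impl Module<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<bool>)> for ToyAttention {
    type Output = &'static str;
    fn forward(&self, _qkvm: (Vec<f32>, Vec<f32>, Vec<f32>, Vec<bool>)) -> Self::Output {
        "masked"
    }
}
```

Call sites that pass three inputs keep resolving to the original impl; only callers that pass a mask pick up the new one.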

Regarding infinity, I know Hugging Face usually uses the float min value (https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py#L166). I'm not sure whether this makes any difference in practice compared to using infinity.

Related to this is #436, which is being worked on right now.

chelsea0x3b · Mar 21 '23 15:03
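One practical difference between -inf and the float minimum appears when an entire row is masked, for example a query position that is pure padding. With -inf, a max-subtracted softmax computes -inf - (-inf) = NaN and the whole row becomes NaN; with f32::MIN the row degrades to a uniform distribution instead. For partially masked rows the two give the same result. A std-only sketch (the softmax helper is illustrative):

```rust
/// Numerically stable softmax over one row (subtracts the row max first).
fn softmax(z: &[f32]) -> Vec<f32> {
    let max = z.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    // If every entry is -inf, max is -inf and x - max is NaN.
    let e: Vec<f32> = z.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = e.iter().sum();
    e.iter().map(|&x| x / sum).collect()
}
```

One caveat of the float-minimum approach: adding it twice, for example combining a causal mask and a padding mask that both use f32::MIN, overflows to -inf anyway, since f32::MIN + f32::MIN is -inf in f32.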

> It shouldn't block gradient flow, but the gradients will be subtracted by inf so in practice they should go to zero.

> Regarding infinity, I know Hugging Face usually uses the float min value (https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py#L166). I'm not sure whether this makes any difference in practice compared to using infinity.

The reason subtracting inf works is that it is immediately followed by a softmax, and e^(x - inf) = e^x / e^inf = 0. However, I suspect that, compared to select or masking, subtracting inf makes (a naive) autograd track a lot of unnecessary compute nodes, since subtraction and softmax do not block compute flow.

Perhaps we can have a functional like torch.nn.functional.scaled_dot_product_attention?

ifsheldon · Mar 22 '23 02:03
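Written out for a single query position i, with scaled scores z and an additive causal mask m, the identity above reads:

$$
p_j = \frac{e^{z_j + m_j}}{\sum_{k} e^{z_k + m_k}}, \qquad
m_j = \begin{cases} 0 & j \le i \\ -\infty & j > i \end{cases}
\qquad\Longrightarrow\qquad
p_j = \begin{cases} \dfrac{e^{z_j}}{\sum_{k \le i} e^{z_k}} & j \le i \\[2ex] 0 & j > i \end{cases}
$$

The masked entries contribute nothing to either the numerator or the denominator, so what remains is an ordinary softmax over the visible keys only.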

> autograd track a lot of unnecessary compute nodes, since subtraction and softmax do not block compute flow.

Since we are doing softmax regardless, the only extra computation would be the sub op forward/backward, right? Or am I missing something? I think actually masking and subtracting inf should result in the same amount of computation? Either way, the mask or sub(inf) doesn't need a gradient, so the only extra operation is the forward pass.

chelsea0x3b · Mar 22 '23 13:03

> I think actually masking and subtracting inf should result in the same amount of computation?

Note that the following is purely theoretical, on paper, in terms of tracking compute flow. An autograd system with special handling of inf and operator fusion should be able to get around the issue.

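[Figure: comparison of full attention (left), causal attention with masking (middle), and causal attention with subtracting inf (right)]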

On the left is full attention, in the middle causal attention with masking, and on the right causal attention with inf subtracted. The light green box is the softmax. The black lines in the grid represent gradient flow. You can see that the black circles are disconnected by masking, while the red circles have inf subtracted from them. In theory, although the forward results and backward gradients are the same for the two methods, the number of gradient flow routes in the middle case should be roughly halved in implementation, so the computation on paper can be roughly halved as well.

But as I said, if autograd detects the combination of subtracting inf and softmax and fuses the two ops, then the two cases may actually be the same in implementation.

ifsheldon · Mar 23 '23 04:03
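The halving argument can be sketched as a fused mask-plus-softmax routine that never materializes a mask and never touches future positions: row i reads only the scores with j ≤ i, so for n positions it evaluates n(n+1)/2 exponentials instead of n². This is a hypothetical, std-only illustration of the idea, not how dfdx or any particular fused kernel is implemented.

```rust
/// Fused causal softmax over a row-major (n, n) score matrix (sketch).
/// Row i reads only scores[i][0..=i]; entries with j > i stay 0.0, so no
/// mask is built and no exp() is evaluated for them.
/// Returns the probabilities and the number of exp() calls performed.
fn fused_causal_softmax(scores: &[f32], n: usize) -> (Vec<f32>, usize) {
    assert_eq!(scores.len(), n * n);
    let mut out = vec![0.0f32; n * n];
    let mut exp_calls = 0;
    for i in 0..n {
        let row = &scores[i * n..i * n + i + 1]; // visible keys only
        let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0f32;
        for (j, &s) in row.iter().enumerate() {
            let e = (s - max).exp();
            exp_calls += 1;
            out[i * n + j] = e;
            sum += e;
        }
        for j in 0..=i {
            out[i * n + j] /= sum;
        }
    }
    (out, exp_calls)
}
```

For n = 4 that is 10 exponentials instead of 16; as n grows the ratio approaches one half, which matches the "halved on paper" estimate.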

Ahh I see, thanks for the graphic. At the moment dfdx does not support operator fusion, so they would both be the same. This is an interesting direction to go in though, I've been thinking about fusion a lot lately with optimization on my mind.

chelsea0x3b · Mar 24 '23 13:03

@coreylowman Fusion would be such a huge win for transformer MHA; just looking at the speed difference between a fused flash attention kernel and a naïve one, it's staggering.

How were you thinking of approaching this, though? One of the downsides (and upsides) of Rust is that it's much less dynamic. In PyTorch, I think they can parse the whole tree of a module and rewrite it at runtime.

jafioti · Mar 24 '23 14:03

Let's move discussion of that into the issue I just made.

chelsea0x3b · Mar 24 '23 16:03