
Attention Layer Bottleneck

Open ValterFallenius opened this issue 2 years ago • 11 comments

Found a bottleneck: the attention layer

I have found a potential bottleneck that explains why bug #22 occurred: the axial attention layer seems to be acting as some kind of bottleneck. I ran the network for 1000 epochs to try to overfit a small subset of 4 samples (see the run at WandB). The network barely reduces the loss at all and does not overfit the data; it yields a very poor result that looks like some kind of mean. See the image below:

yhat against y

After removing the axial attention layer, the model behaves as expected and overfits the training data; see below after 100 epochs:

yhat against y without attention

The message from the author quoted in #19 does mention that our implementation of axial attention seems to be very different from theirs. He says: "Our (Google's) heads were small MLPs as far as I remember (I'm not at google anymore so do not have access to the source code)." I am not experienced enough to dig into the source code of the axial attention library we use to see how it differs from theirs.

  1. What are heads in the axial attention? What does the number of heads have to do with anything?
  2. Are we doing both vertical and horizontal attention passes in our implementation?

ValterFallenius avatar Mar 15 '22 10:03 ValterFallenius

So, this is the actual implementation being used for axial attention: https://github.com/lucidrains/axial-attention/blob/eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L153-L178, which looks like it is doing both vertical and horizontal passes. But I just realized that we don't actually do any position embeddings, other than the lat/lon inputs, before passing to the axial attention. So we might need to add that and see what happens. The number of heads is the number of heads for multi-headed attention, so we can probably just set it to one and be fine.
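
Roughly, the layer would be configured something like this (a minimal sketch based on the axial-attention library's README; the channel and grid sizes here are placeholders, not metnet's actual values):

```python
import torch
from axial_attention import AxialAttention

# Placeholder sizes: (batch, channels, height, width) -- not metnet's real shapes.
x = torch.randn(1, 8, 28, 28)

attn = AxialAttention(
    dim=8,               # embedding (channel) dimension of the input
    dim_index=1,         # which axis of the input holds that embedding dimension
    heads=1,             # number of heads for multi-head attention (1 while debugging)
    num_dimensions=2,    # two axial dimensions -> one pass along height, one along width
    sum_axial_out=True,  # sum the row-wise and column-wise attention outputs
)

out = attn(x)  # same shape as the input: (1, 8, 28, 28)
```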

jacobbieker avatar Mar 15 '22 12:03 jacobbieker

Okay, how do we do this?

If we have 8 channels in the RNN output with 28×28 height and width, is this embedding the information about which pixel we are in? I am struggling a bit to wrap my head around attention and axial attention...

Also, when you say set the number of heads to 1, you mean for debugging, right? We still want multi-head attention to replicate their model in the end.

ValterFallenius avatar Mar 15 '22 12:03 ValterFallenius

Yeah, set it to 1 for debugging to get it to overfit first. And yeah, the position embedding encodes which pixel we are in, and the location of that pixel relative to the other pixels in the input. The library has a function for it, so we can probably just add it where we use the axial attention: https://github.com/openclimatefix/metnet/pull/25
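
Something along these lines (a rough sketch assuming the AxialPositionalEmbedding class from the same axial-attention library; the shapes are placeholders and would need to match the actual RNN output):

```python
import torch
from axial_attention import AxialAttention, AxialPositionalEmbedding

# Placeholder RNN output: (batch, channels, height, width) -- not metnet's real shapes.
rnn_out = torch.randn(1, 8, 28, 28)

# Learned positional embeddings over the two spatial axes, added to the input,
# so each pixel's representation also encodes where it sits in the 28x28 grid.
pos_emb = AxialPositionalEmbedding(dim=8, shape=(28, 28))  # channel axis assumed at index 1

attn = AxialAttention(dim=8, dim_index=1, heads=1, num_dimensions=2)

out = attn(pos_emb(rnn_out))  # same shape as rnn_out
```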

jacobbieker avatar Mar 15 '22 13:03 jacobbieker

I just realized that we don't actually do any position embeddings, other than the lat/lon inputs, before passing to the axial attention. So we might need to add that and see what happens?

I don't know if it's relevant, but it recently occurred to me that MetNet version 1 is quite similar to the Temporal Fusion Transformer (TFT) (also from Google!), except that MetNet has 2 spatial dimensions, whilst TFT is for timeseries without any (explicit) spatial dimensions. In particular, both TFT and MetNet use an RNN followed by multi-head attention. In the TFT paper, they claim that the RNN generates a kind of learnt position encoding, so they don't bother with a "hand-crafted" position encoding.

The TFT paper says:

[The LSTM] also serves as a replacement for standard positional encoding, providing an appropriate inductive bias for the time ordering of the inputs.

JackKelly avatar Mar 15 '22 13:03 JackKelly

I can confirm that initial tests show promising results: the network seems to learn something now :) I'll be back with more results in a few days.

ValterFallenius avatar Mar 15 '22 15:03 ValterFallenius

@all-contributors please add @jacobbieker for code

peterdudfield avatar Sep 07 '22 15:09 peterdudfield

@peterdudfield

I've put up a pull request to add @jacobbieker! :tada:

allcontributors[bot] avatar Sep 07 '22 15:09 allcontributors[bot]

@all-contributors please add @JackKelly for code

peterdudfield avatar Sep 07 '22 15:09 peterdudfield

@peterdudfield

I've put up a pull request to add @JackKelly! :tada:

allcontributors[bot] avatar Sep 07 '22 15:09 allcontributors[bot]

@all-contributors please add @ValterFallenius for userTesting

peterdudfield avatar Sep 07 '22 15:09 peterdudfield

@peterdudfield

I've put up a pull request to add @ValterFallenius! :tada:

allcontributors[bot] avatar Sep 07 '22 15:09 allcontributors[bot]