performer-pytorch

SelfAttention layer seems to have large error relative to nn.MultiheadAttention?

Open · jueseph opened this issue 3 years ago · 8 comments

Hi,

I started trying to use this, and the first thing I did was compare the output of the performer_pytorch.SelfAttention layer against torch.nn.MultiheadAttention for different sizes of the random feature map. I was a little surprised to see that the relative error never went below 100%. [image: relative error vs. random feature map size] Am I doing something wrong? This analysis was done using this Colab notebook: https://colab.research.google.com/drive/1vemlPOySWtDdB2Xfm7YCE7--PYtIbelS?usp=sharing

Thanks!
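For readers who want to reproduce this kind of comparison without a GPU or the library, here is a minimal NumPy sketch (single head, no learned projections) of FAVOR+-style positive random features next to exact softmax attention. It reports a norm-based relative error for several feature counts m. This illustrates the technique; it is not performer_pytorch's code. The projections are iid Gaussian (no orthogonalization), and `input_scale` is a hypothetical knob that shrinks q and k, because the variance of positive random features grows quickly with the norms of q and k.

```python
import numpy as np

def softmax_attention(q, k, v):
    """Exact single-head softmax attention; q, k, v have shape (n, d)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def positive_features(x, w):
    """phi(x) = exp(W x - |x|^2 / 2) / sqrt(m), so E[phi(q) . phi(k)] = exp(q . k)."""
    sq_norm = 0.5 * np.sum(x ** 2, axis=-1, keepdims=True)
    return np.exp(x @ w.T - sq_norm) / np.sqrt(w.shape[0])

def favor_attention(q, k, v, w):
    """Linear-time approximation that never forms the n x n attention matrix."""
    scale = q.shape[-1] ** -0.25                  # split the 1/sqrt(d) temperature between q and k
    q_prime = positive_features(q * scale, w)     # (n, m)
    k_prime = positive_features(k * scale, w)     # (n, m)
    numerator = q_prime @ (k_prime.T @ v)         # (n, d_v)
    denominator = q_prime @ k_prime.sum(axis=0)   # (n,)
    return numerator / denominator[:, None]

def relative_error(approx, exact):
    """Norm-based relative error ||approx - exact|| / ||exact||."""
    return np.linalg.norm(approx - exact) / np.linalg.norm(exact)

def error_vs_features(feature_counts, n=256, d=64, input_scale=0.5, seed=0):
    rng = np.random.default_rng(seed)
    q, k = (input_scale * rng.standard_normal((n, d)) for _ in range(2))
    v = rng.standard_normal((n, d))
    exact = softmax_attention(q, k, v)
    errors = {}
    for m in feature_counts:
        w = rng.standard_normal((m, d))  # iid Gaussian projections (no orthogonalization)
        errors[m] = relative_error(favor_attention(q, k, v, w), exact)
    return errors

for m, err in error_vs_features([16, 64, 256, 1024, 4096]).items():
    print(f"m={m:5d}  relative error={err:.3f}")
```

Raising `input_scale` toward 1.0 is a quick way to see how sensitive the estimate is to the magnitude of q and k.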

jueseph · Dec 12 '20 03:12

[screenshot: training curves] I'm also having trouble getting good performance with Performers (using this library), just training a small-scale BERT model. Red is Performer, orange is Transformer. I'm using 256 random features for a sequence length of 256. The Performer doesn't converge anywhere close to the Transformer.

calclavia · Dec 17 '20 01:12

Also, a quick note, @jueseph: replicating your notebook, I got the following MSE: [screenshot]

PyTorch's MHA (https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) expects input tensors of shape [sequence, batch, features] by default, so I believe your notebook's implementation should have been:

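    import torch
    attn1 = torch.nn.MultiheadAttention(embed_dim=512, num_heads=8).cuda()  # hypothetical config; the notebook's is not shown
    x = torch.randn(1, 1024, 512).cuda()  # (batch, seq, features)

    # nn.MultiheadAttention defaults to (seq, batch, features), so move seq to dim 0
    x_t = x.transpose(0, 1)  # (1024, 1, 512)
    y1 = attn1(query=x_t, key=x_t, value=x_t)[0].transpose(0, 1)  # back to (1, 1024, 512)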

which gives a much more reasonable MSE: [screenshot]
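Why the layout matters: without the transpose, nn.MultiheadAttention reads a (1, 1024, 512) tensor as 1024 independent sequences of length 1. Softmax over a single key always yields weight 1, so each token only attends to itself and no mixing across tokens happens. Here is a minimal NumPy sketch of that effect, using a hand-written attention function with no learned projections. In the real layer the output would additionally pass through MHA's value and output projections, which makes it a per-token linear map of the input.

```python
import numpy as np

def softmax_attention(q, k, v):
    """q, k, v: (batch, seq, d); attention is computed over the seq axis."""
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])  # (batch, seq, seq)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 1024, 16))  # intended layout: (batch=1, seq=1024, d=16)

# Correct reading: the 1024 tokens attend to each other.
mixed = softmax_attention(x, x, x)

# Misread layout: (seq=1, batch=1024, d) means 1024 sequences of one token each.
x_wrong = np.swapaxes(x, 0, 1)  # (1024, 1, 16) in (batch, seq, d) order
unmixed = np.swapaxes(softmax_attention(x_wrong, x_wrong, x_wrong), 0, 1)

print(np.allclose(unmixed, x))  # True: every softmax weight is 1, output == value
print(np.allclose(mixed, x))    # False: real attention mixes tokens
```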

calclavia · Dec 17 '20 02:12

@calclavia thanks for catching that!

Yes, it looks like I can now get low MSEs, but the mean relative error is still fairly large: [image]

If I look at the individual elements of the attention output matrix, the relative error gets large when values are close to zero, which makes sense. However, even when the outputs have relatively large values, the relative error is mostly between 10% and 100% (dotted lines at -1 and 0 on the log10-scaled y-axis). I'm not sure whether this is a problem with the implementation or a fundamental issue with the random features approximation, and the authors never reported relative error in the paper, so we can't know either way.
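On the question of whether the error is inherent to random features: much of it comes from the variance of the estimator, and FAVOR+ reduces that variance by drawing orthogonal rather than iid projection directions. Below is a minimal NumPy sketch of one common construction (sign-corrected QR of Gaussian blocks, with row norms resampled from a chi(d) distribution) and a Monte Carlo estimate of the kernel estimator's MSE. The helper names are hypothetical, and the details may differ from performer_pytorch's implementation. No expected output is given because the numbers depend on the sample.

```python
import numpy as np

def orthogonal_gaussian_projections(m, d, rng):
    """Hypothetical helper: m x d projections, rows orthogonal within blocks of d."""
    blocks = []
    for _ in range(-(-m // d)):  # ceil(m / d) blocks
        q_mat, r = np.linalg.qr(rng.standard_normal((d, d)))
        q_mat *= np.sign(np.diag(r))  # sign correction makes q_mat Haar-distributed
        blocks.append(q_mat.T)        # rows are orthonormal directions
    w = np.concatenate(blocks, axis=0)[:m]
    # Rescale each row by the norm of a fresh Gaussian vector (chi(d)-distributed)
    # so every row is marginally N(0, I_d), as in the iid case.
    norms = np.linalg.norm(rng.standard_normal((m, d)), axis=1)
    return w * norms[:, None]

def iid_gaussian_projections(m, d, rng):
    return rng.standard_normal((m, d))

def estimate_kernel(q, k, w):
    """Positive-random-feature estimate of exp(q . k) for 1-D vectors q, k."""
    m = w.shape[0]
    phi_q = np.exp(w @ q - q @ q / 2) / np.sqrt(m)
    phi_k = np.exp(w @ k - k @ k / 2) / np.sqrt(m)
    return phi_q @ phi_k

def kernel_mse(sampler, q, k, m, trials=2000, seed=0):
    """Monte Carlo MSE of the kernel estimate for a given projection sampler."""
    rng = np.random.default_rng(seed)
    target = np.exp(q @ k)
    estimates = np.array(
        [estimate_kernel(q, k, sampler(m, q.size, rng)) for _ in range(trials)]
    )
    return np.mean((estimates - target) ** 2)

if __name__ == "__main__":
    rng = np.random.default_rng(1)
    d, m = 16, 16
    q, k = 0.4 * rng.standard_normal(d), 0.4 * rng.standard_normal(d)
    print("iid        MSE:", kernel_mse(iid_gaussian_projections, q, k, m))
    print("orthogonal MSE:", kernel_mse(orthogonal_gaussian_projections, q, k, m))
```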

[image: per-element relative error]

Here is the notebook again, in case anyone can see something wrong with how I'm doing things. https://colab.research.google.com/drive/1vemlPOySWtDdB2Xfm7YCE7--PYtIbelS?usp=sharing

jueseph · Dec 17 '20 04:12

@lucidrains Have you tested your FastAttention implementation against the JAX implementation (perhaps with a unit test that feeds both the same input tensor), like the one I've written for my incomplete implementation here: https://github.com/calclavia/Performer-Pytorch/blob/main/test.py#L11
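Fixing the random projection matrix makes tests like this deterministic, because every code path then sees identical features. Below is a minimal pytest-style sketch in NumPy. The function names are hypothetical and this is not the API of performer_pytorch or the JAX code. The test checks that the linear-time FAVOR computation matches an explicit quadratic-time evaluation of the same random-feature kernel. A comparison against the JAX implementation would also require passing it the same projection matrix.

```python
import numpy as np

def random_features(x, w):
    """Positive random features with the 1/sqrt(d) softmax temperature folded in."""
    x = x * x.shape[-1] ** -0.25
    sq_norm = 0.5 * np.sum(x ** 2, axis=-1, keepdims=True)
    return np.exp(x @ w.T - sq_norm) / np.sqrt(w.shape[0])

def favor_linear(q, k, v, w):
    """O(n * m * d): contracts phi(K)^T V before touching phi(Q)."""
    q_prime, k_prime = random_features(q, w), random_features(k, w)
    return (q_prime @ (k_prime.T @ v)) / (q_prime @ k_prime.sum(axis=0))[:, None]

def favor_quadratic_reference(q, k, v, w):
    """O(n^2 * m): materializes the approximate n x n attention matrix."""
    a = random_features(q, w) @ random_features(k, w).T
    return (a @ v) / a.sum(axis=1, keepdims=True)

def test_linear_matches_quadratic_reference():
    rng = np.random.default_rng(42)
    n, d, m = 128, 32, 64
    q, k, v = (0.5 * rng.standard_normal((n, d)) for _ in range(3))
    w = rng.standard_normal((m, d))  # shared projection matrix: same features on both paths
    np.testing.assert_allclose(
        favor_linear(q, k, v, w), favor_quadratic_reference(q, k, v, w), rtol=1e-9, atol=1e-12
    )

def test_rows_are_normalized():
    rng = np.random.default_rng(7)
    q, k = (0.5 * rng.standard_normal((16, 8)) for _ in range(2))
    w = rng.standard_normal((32, 8))
    # Constant values must pass through unchanged if each attention row sums to 1.
    np.testing.assert_allclose(favor_linear(q, k, np.ones((16, 2)), w), 1.0)
```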

calclavia · Dec 17 '20 07:12

@jueseph I tried running your notebook's test against a different implementation: https://github.com/r0mainK/outperformer/blob/main/src/performer.py

[screenshot]

[screenshot]

Similar result, which suggests this could be correct behavior. Absolute relative error probably isn't the best metric: near zero you're dividing by y1, which is close to zero, so you get a large number.
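The effect is easy to reproduce with made-up numbers. Give every entry the same absolute error, and the mean elementwise relative error is dominated by the single near-zero entry, while a norm-based relative error (or MSE) stays small. The values below are illustrative and do not come from the notebook.

```python
import numpy as np

def mean_elementwise_relative_error(approx, exact):
    return np.mean(np.abs(approx - exact) / np.abs(exact))

def normwise_relative_error(approx, exact):
    return np.linalg.norm(approx - exact) / np.linalg.norm(exact)

def mse(approx, exact):
    return np.mean((approx - exact) ** 2)

exact = np.array([1.0, -2.0, 0.5, 1e-4])  # one entry close to zero
approx = exact + 0.01                     # identical absolute error everywhere

print(mean_elementwise_relative_error(approx, exact))  # ≈ 25.0, dominated by the 1e-4 entry
print(normwise_relative_error(approx, exact))          # ≈ 0.0087
print(mse(approx, exact))                              # ≈ 1e-4
```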

Here is the same graph using MSE: [screenshot]

calclavia · Dec 17 '20 09:12

Yeah, unfortunately I'd say Performer still needs to prove itself. I haven't had much luck training Performer at context lengths greater than 2048, and others have told me the same. I built this repository in case someone finds a solid use for it, because I believe Performer is the pinnacle of what linear attention could be.

lucidrains · Dec 18 '20 19:12

@lucidrains FWIW, when training on sequences of images, Performer has enabled me to scale to much larger spatio-temporal sequences than quadratic attention would allow. Unfortunately, I don't have a robust quality-metric comparison yet, but I'm not even sure how I would fit these volumes into a quadratic attention mechanism on my available hardware. The FAVOR+ algorithm is the best theoretically justified linear attention mechanism I've seen, so I'm very interested in finding the tweaks needed to make it work well. Or at least, I'll try to run experiments to demonstrate where improvement is needed.
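To make the scaling argument concrete, here is a back-of-the-envelope sketch. The clip size, head count, head dimension, feature count, and fp32 storage are assumptions chosen for illustration, not Erotemic's actual setup. The sketch counts only the dominant attention intermediates in the forward pass and ignores everything else the model stores.

```python
def dominant_attention_bytes(seq_len, nb_features, dim_head=64, heads=8, batch=1, bytes_per_el=4):
    """Rough size of the largest per-layer attention intermediates (forward pass only).

    Quadratic attention materializes an n x n matrix per head. The linear
    (Performer-style) path keeps phi(Q) and phi(K), each n x m, plus
    phi(K)^T V of shape m x dim_head.
    """
    quadratic = seq_len * seq_len
    linear = 2 * seq_len * nb_features + nb_features * dim_head
    per_element = batch * heads * bytes_per_el
    return quadratic * per_element, linear * per_element

# Hypothetical video clip: 16 frames of 32 x 32 patches -> 16,384 tokens.
quad, lin = dominant_attention_bytes(seq_len=16 * 32 * 32, nb_features=256)
print(f"quadratic: {quad / 2**30:.1f} GiB, linear: {lin / 2**20:.1f} MiB")
# quadratic: 8.0 GiB, linear: 256.5 MiB
```

Because the quadratic term grows with the square of the token count, doubling the number of frames quadruples it, while the linear term only doubles.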

Erotemic · Oct 03 '21 15:10

I should mention that although my initial experiments above showed some inaccuracy in FAVOR+'s direct approximation of attention, it works reasonably well as part of a trained model. In particular, in our application to protein structure prediction (RoseTTAFold), we found that Performer was worse than vanilla attention for a given model size (in number of parameters), but it allows a larger model (more transformer blocks) for a given amount of GPU memory, and therefore better performance.


jueseph · Oct 04 '21 17:10