burn
Transformer prenorm location
Feature description
Currently, the transformer module has a norm_first flag, but I think it is not used as intended.
Currently, it does this:
x=norm(x); x += FF(x)
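According to https://arxiv.org/pdf/2002.04745.pdf, the normalization should instead be applied inside the residual branch, after the input has split off into the residual path, so that the residual itself carries the un-normalized x: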
x += FF(norm(x))
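A minimal sketch in plain Rust contrasting the current ordering with the two orderings discussed in the paper (Pre-LN, x + FF(norm(x)), and Post-LN, norm(x + FF(x))). It does not use burn at all: `layer_norm` is a simplified layer norm without learnable gain or bias, and `ff`, `norm_then_residual`, `pre_ln`, `post_ln` and `sublayer` are hypothetical names for illustration, not burn's API. The only difference between the current behaviour and Pre-LN is what the residual path carries.

```rust
/// Simplified layer normalization over one feature vector
/// (no learnable gain/bias; illustration only).
fn layer_norm(x: &[f64]) -> Vec<f64> {
    let n = x.len() as f64;
    let mean = x.iter().sum::<f64>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
    let eps = 1e-5;
    x.iter().map(|v| (v - mean) / (var + eps).sqrt()).collect()
}

/// Hypothetical stand-in for the feed-forward sublayer FF (elementwise 2v + 1).
fn ff(x: &[f64]) -> Vec<f64> {
    x.iter().map(|v| 2.0 * v + 1.0).collect()
}

/// Elementwise addition, used for the residual connection.
fn add(a: &[f64], b: &[f64]) -> Vec<f64> {
    a.iter().zip(b).map(|(p, q)| p + q).collect()
}

/// Current behaviour described above: x = norm(x); x += FF(x).
/// The residual path carries the normalized input.
fn norm_then_residual(x: &[f64]) -> Vec<f64> {
    let x = layer_norm(x);
    add(&x, &ff(&x))
}

/// Pre-LN: x += FF(norm(x)). The residual path carries the raw input.
fn pre_ln(x: &[f64]) -> Vec<f64> {
    add(x, &ff(&layer_norm(x)))
}

/// Post-LN (original Transformer): x = norm(x + FF(x)).
fn post_ln(x: &[f64]) -> Vec<f64> {
    layer_norm(&add(x, &ff(x)))
}

/// Hypothetical switch mirroring a `norm_first` flag.
fn sublayer(x: &[f64], norm_first: bool) -> Vec<f64> {
    if norm_first {
        pre_ln(x)
    } else {
        post_ln(x)
    }
}

fn main() {
    let x = vec![1.0, 2.0, 3.0, 4.0];
    println!("norm then residual: {:?}", norm_then_residual(&x));
    println!("pre-LN:             {:?}", sublayer(&x, true));
    println!("post-LN:            {:?}", sublayer(&x, false));
    // Pre-LN and the current variant differ by x - norm(x)
    // (up to floating-point rounding), i.e. by what the residual carries.
}
```

With Pre-LN the un-normalized x flows unchanged along the residual path, which is the property the paper relies on; normalizing x before the split throws that away.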
This is the recommended way to do normalization in transformers today AFAIU.
I tried making a PR https://github.com/tracel-ai/burn/pull/1054, but for some reason I don't understand, the tests fail only on the torch backend...
@Philonoist there is a bug in the tch backend; I'm going to work on it very soon, so you can ignore those test failures for now.
Fixed in #1054