burn
Move transformer prenorm on the residual path
Pull Request Template
Checklist
- [X] Confirmed that `run-checks all` script has been executed.
- [X] Made sure the book is up to date with changes in this PR.
Related Issues/PRs
Changes
With the prenorm setting, the transformer module applied normalization before the residual path branched off, so the skip connection carried the normalized input.
According to https://arxiv.org/pdf/2002.04745.pdf, normalization should be applied after branching, inside the residual path.
Instead of `x = norm(x); x += FF(x)`, we should do this: `x += FF(norm(x))`.
This is the recommended way to do normalization in transformers today, AFAIU.
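To make the difference between the two orderings concrete, here is a minimal, dependency-free sketch. It is not the Burn implementation: `layer_norm` omits the learned scale and shift, `feed_forward` is a hypothetical stand-in for the real sublayer, and both function names for the orderings are made up for illustration.

```rust
/// Layer normalization over a single vector, without learned scale/shift
/// (a simplification; the real module has learnable parameters).
fn layer_norm(x: &[f32]) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let denom = (var + 1e-5).sqrt();
    x.iter().map(|v| (v - mean) / denom).collect()
}

/// Hypothetical stand-in for the feed-forward sublayer FF.
fn feed_forward(x: &[f32]) -> Vec<f32> {
    x.iter().map(|v| 0.5 * v).collect()
}

/// Previous ordering: x = norm(x); x += FF(x).
/// The skip connection carries the normalized input.
fn prenorm_before_residual(x: &[f32]) -> Vec<f32> {
    let normed = layer_norm(x);
    let ff = feed_forward(&normed);
    normed.iter().zip(&ff).map(|(a, b)| a + b).collect()
}

/// Proposed ordering: x += FF(norm(x)).
/// The skip connection carries the original, un-normalized input.
fn prenorm_on_residual_path(x: &[f32]) -> Vec<f32> {
    let ff = feed_forward(&layer_norm(x));
    x.iter().zip(&ff).map(|(a, b)| a + b).collect()
}
```

Calling both functions on an input such as `[10.0, 12.0, 14.0]` shows the effect: the previous ordering returns values centered on zero because the original input is discarded, while the proposed ordering returns the original input plus a small correction.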
Testing
I ran the tests.
Hi @Philonoist, the changes look good to me, although I'm not an expert in transformers. It's good that you've documented the steps! However, the `test_autoregressive_norm_first` test seems to fail in the CI. Is it expected that values change with your new implementation?
Hey. It only fails on the torch backend for some reason. It passes on the ndarray backend. Do you have any clue why that might happen?
It's possible that the implementation of some op for the Torch backend has a bug.
Indeed, see @nathanielsimard's comment: https://github.com/tracel-ai/burn/issues/1056
@spapinistarkware I fixed a bug in tch; make sure your branch is up to date with the new changes.
@Philonoist can you merge main into your branch so it retriggers the CI and we can see whether the tch tests pass?
Merged.
Codecov Report
Attention: 2 lines in your changes are missing coverage. Please review.
Comparison is base (535458e) 85.61% compared to head (e4dba4d) 85.68%.
Files | Patch % | Lines |
---|---|---|
burn-core/src/nn/transformer/decoder.rs | 97.43% | 2 Missing :warning: |
Additional details and impacted files
Coverage Diff
Metric | main | #1054 | +/- |
---|---|---|---|
Coverage | 85.61% | 85.68% | +0.07% |
Files | 513 | 513 | |
Lines | 56832 | 56893 | +61 |
Hits | 48658 | 48751 | +93 |
Misses | 8174 | 8142 | -32 |