Refactor transformer implementations, unifying common components
There is a massive amount of duplication and near-duplication between the individual model implementations. Not only does this complicate maintenance, but the implementations have also diverged substantially, with inconsistent naming (`Mlp` vs. `MLP`), inconsistent conventions (e.g. when to use `apply` vs. `forward`), and inconsistent formatting.
Alas, fixing this is easier said than done, because the logic is often almost-but-not-quite the same for different models. This PR introduces a unified transformer implementation that I believe can eliminate a lot of boilerplate and simplify things, while being generic enough to work for most currently supported models.
Here's an overview of the changes:
- A generic `Attention` trait is introduced that abstracts the attention step, while exposing the tensors required for ISQ.
- Using that trait, `DecoderLayer` can become generic, with common fields for norms and a standardized `forward` implementation.
- This in turn is used to create a generic `Model` struct with a standardized `forward` implementation.
- The standard `Model` is used to provide (with the help of the `ModelWrapper` trait) default implementations of `NormalModel`, parts of `IsqModel`, and parts of `AnyMoeBaseModelMixin` (the latter two through `Auto*` traits).
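To make the layering concrete, here is a minimal sketch of how these pieces compose. The trait and type names (`Attention`, `DecoderLayer`, `Model`) follow the PR, but the method signatures and the `Tensor` stand-in are illustrative assumptions, not the actual API:

```rust
/// Stand-in tensor type so the sketch is self-contained (the real code
/// would use candle's `Tensor`).
#[derive(Clone, Debug, PartialEq)]
struct Tensor(Vec<f32>);

/// Abstracts the attention step; each architecture implements only this.
/// (The real trait would also expose the tensors required for ISQ.)
trait Attention {
    fn forward(&self, xs: &Tensor) -> Tensor;
}

/// Generic decoder layer: architecture-specific attention plus the common
/// fields (norms, MLP) that all models share.
struct DecoderLayer<A: Attention> {
    attn: A,
    // input_layernorm, post_attention_layernorm, mlp, ... would live here
}

impl<A: Attention> DecoderLayer<A> {
    /// Standardized forward pass; simplified here to just the attention step
    /// (the real implementation would add norms, residuals, and the MLP).
    fn forward(&self, xs: &Tensor) -> Tensor {
        self.attn.forward(xs)
    }
}

/// Generic model: a stack of generic decoder layers with a standardized
/// forward implementation.
struct Model<A: Attention> {
    layers: Vec<DecoderLayer<A>>,
}

impl<A: Attention> Model<A> {
    fn forward(&self, xs: &Tensor) -> Tensor {
        // Thread the hidden state through every layer in order.
        self.layers.iter().fold(xs.clone(), |h, layer| layer.forward(&h))
    }
}

/// A toy architecture-specific attention (identity), for demonstration only.
struct NoopAttention;
impl Attention for NoopAttention {
    fn forward(&self, xs: &Tensor) -> Tensor {
        xs.clone()
    }
}

fn main() {
    let model = Model {
        layers: vec![
            DecoderLayer { attn: NoopAttention },
            DecoderLayer { attn: NoopAttention },
        ],
    };
    let out = model.forward(&Tensor(vec![1.0, 2.0]));
    println!("{:?}", out.0); // prints [1.0, 2.0]
}
```

The point of the design is that a new architecture only supplies its `Attention` implementation (and any architecture-specific fields), while the layer and model plumbing comes for free from the generic types.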
I have refactored Gemma2 and Phi3 according to this paradigm. I chose those two models because they are quite different architecturally, and a system generic enough to accommodate both of them can probably be adapted to most other models as well.
This refactor reduces the size of the Phi3 implementation by 27% and the size of the Gemma2 implementation by 24%. Once ported to the remaining models, I expect it to yield a net reduction of at least 1000 LoC. And this architecture can serve as a foundation for factoring out further commonality in the future.
I'm opening this PR at this stage to gather feedback. If such a refactor is of interest to you, I'll be happy to port it to the remaining models. I have not tested the new implementations yet, but everything compiles and I expect that any bugs will be minor oversights, not shortcomings of the overall approach.
Code Metrics Report
```
===============================================================================
 Language            Files        Lines         Code     Comments       Blanks
===============================================================================
 C Header                2           35           28            0            7
 Dockerfile              1           41           22           10            9
 JSON                   12          105          104            0            1
 Python                 64         2729         2359           71          299
 Shell                   1           57           22           18           17
 Plain Text              3         3723            0         2413         1310
 TOML                   18          609          542            2           65
 YAML                    2           21           19            2            0
-------------------------------------------------------------------------------
 Jupyter Notebooks       4            0            0            0            0
 |- Markdown             2           77           32           31           14
 |- Python               2          205          178            1           26
 (Total)                            282          210           32           40
-------------------------------------------------------------------------------
 Markdown               44         3439            0         2611          828
 |- BASH                 6          103          100            0            3
 |- JSON                 1           12           12            0            0
 |- Python               7          121          109            0           12
 |- Rust                13          440          373            0           67
 |- TOML                 2           75           63            0           12
 (Total)                           4190          657         2611          922
-------------------------------------------------------------------------------
 Rust                  298        91486        82139         1880         7467
 |- Markdown           144         1600           25         1454          121
 (Total)                          93086        82164         3334         7588
===============================================================================
 Total                 449       102245        85235         7007        10003
===============================================================================
```