
Refactor transformer implementations, unifying common components

Open p-e-w opened this issue 1 year ago • 1 comment

There is a massive amount of duplication and near-duplication between the individual model implementations. Not only does this complicate maintenance, but the implementations have also diverged substantially, with inconsistent naming (Mlp vs MLP), inconsistent conventions (e.g. when to use apply vs forward), and inconsistent formatting.

Alas, fixing this is easier said than done, because the logic is often almost-but-not-quite the same for different models. This PR introduces a unified transformer implementation that I believe can eliminate a lot of boilerplate and simplify things, while being generic enough to work for most currently supported models.

Here's an overview of the changes:

  1. A generic Attention trait is introduced that abstracts the attention step, while exposing the tensors required for ISQ.
  2. Using that trait, DecoderLayer can become generic, with common fields for norms and a standardized forward implementation.
  3. This in turn is used to create a generic Model struct with a standardized forward implementation.
  4. The standard Model is used to provide (with the help of the ModelWrapper trait) default implementations of NormalModel, parts of IsqModel, and parts of AnyMoeBaseModelMixin (the latter two through Auto* traits).
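To make the shape of this abstraction concrete, here is a hypothetical, heavily simplified sketch of steps 1–3. The trait and struct names follow the description above, but the signatures are illustrative only: a placeholder `Tensor` type stands in for the crate's actual candle tensors, and the norm/ISQ machinery is reduced to plain fields. It is not the PR's actual code.

```rust
// Hypothetical sketch of the proposed abstraction. The real PR works with
// candle tensors and mistral.rs's existing norm and ISQ types.

/// Placeholder standing in for candle's `Tensor`.
#[derive(Clone, Debug, PartialEq)]
struct Tensor(Vec<f32>);

/// Step 1: a generic attention trait that abstracts the attention step
/// while exposing the weight tensors required for ISQ.
trait Attention {
    fn forward(&self, hidden: &Tensor) -> Tensor;
    /// Tensors that ISQ needs access to (weights to quantize).
    fn isq_tensors(&self) -> Vec<&Tensor>;
}

/// Step 2: a decoder layer generic over the attention implementation,
/// with common fields for the norms and a standardized forward pass.
struct DecoderLayer<A: Attention> {
    attn: A,
    input_norm_weight: Tensor,
    post_attn_norm_weight: Tensor,
}

impl<A: Attention> DecoderLayer<A> {
    fn forward(&self, hidden: &Tensor) -> Tensor {
        // Simplified: norm -> attention -> residual add.
        let normed = hidden.clone(); // stand-in for an RMS norm
        let attn_out = self.attn.forward(&normed);
        Tensor(hidden.0.iter().zip(attn_out.0.iter()).map(|(a, b)| a + b).collect())
    }
}

/// Step 3: a generic model built from a stack of generic decoder layers.
struct Model<A: Attention> {
    layers: Vec<DecoderLayer<A>>,
}

impl<A: Attention> Model<A> {
    fn forward(&self, input: &Tensor) -> Tensor {
        self.layers.iter().fold(input.clone(), |h, layer| layer.forward(&h))
    }
}

/// Toy attention used only to exercise the generic plumbing: it just
/// multiplies the hidden state elementwise by a stored weight tensor.
struct ScaleAttention {
    scale: Tensor,
}

impl Attention for ScaleAttention {
    fn forward(&self, hidden: &Tensor) -> Tensor {
        Tensor(hidden.0.iter().zip(self.scale.0.iter()).map(|(h, s)| h * s).collect())
    }
    fn isq_tensors(&self) -> Vec<&Tensor> {
        vec![&self.scale]
    }
}

fn demo() -> Tensor {
    let layer = |scale: f32| DecoderLayer {
        attn: ScaleAttention { scale: Tensor(vec![scale, scale]) },
        input_norm_weight: Tensor(vec![1.0, 1.0]),
        post_attn_norm_weight: Tensor(vec![1.0, 1.0]),
    };
    let model = Model { layers: vec![layer(2.0), layer(2.0)] };
    // Each layer computes h + 2h = 3h, so two layers yield 9h.
    model.forward(&Tensor(vec![1.0, 2.0]))
}

fn main() {
    println!("{:?}", demo()); // Tensor([9.0, 18.0])
}
```

Because `DecoderLayer` and `Model` are generic over `Attention`, a per-model implementation only needs to supply its attention (and, in the real design, its MLP and norms); the forward logic and the ISQ tensor enumeration come for free from the shared code.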

I have refactored Gemma2 and Phi3 according to this paradigm. I chose those two models because they are quite different architecturally, and a system generic enough to accommodate both of them can probably be adapted to most other models as well.

This refactor reduces the size of the Phi3 implementation by 27% and the size of the Gemma2 implementation by 24%. Ported to the remaining models, I expect a net reduction of at least 1000 LoC. And this architecture can serve as a foundation for factoring out further common code in the future.

I'm opening this PR at this stage to gather feedback. If such a refactor is of interest to you, I'll be happy to port it to the remaining models. I have not tested the new implementations yet, but everything compiles and I expect that any bugs will be minor oversights, not shortcomings of the overall approach.

p-e-w avatar Jan 06 '25 08:01 p-e-w

Code Metrics Report
===============================================================================
 Language            Files        Lines         Code     Comments       Blanks
===============================================================================
 C Header                2           35           28            0            7
 Dockerfile              1           41           22           10            9
 JSON                   12          105          104            0            1
 Python                 64         2729         2359           71          299
 Shell                   1           57           22           18           17
 Plain Text              3         3723            0         2413         1310
 TOML                   18          609          542            2           65
 YAML                    2           21           19            2            0
-------------------------------------------------------------------------------
 Jupyter Notebooks       4            0            0            0            0
 |- Markdown             2           77           32           31           14
 |- Python               2          205          178            1           26
 (Total)                            282          210           32           40
-------------------------------------------------------------------------------
 Markdown               44         3439            0         2611          828
 |- BASH                 6          103          100            0            3
 |- JSON                 1           12           12            0            0
 |- Python               7          121          109            0           12
 |- Rust                13          440          373            0           67
 |- TOML                 2           75           63            0           12
 (Total)                           4190          657         2611          922
-------------------------------------------------------------------------------
 Rust                  298        91486        82139         1880         7467
 |- Markdown           144         1600           25         1454          121
 (Total)                          93086        82164         3334         7588
===============================================================================
 Total                 449       102245        85235         7007        10003
===============================================================================

github-actions[bot] avatar Jan 06 '25 08:01 github-actions[bot]