
Potentially use BetterTransformer from PyTorch

Open · lerouxrgd opened this issue 1 year ago · 1 comment

Hello,

As described on PyTorch's blog, since version 1.12 it is possible to get significantly faster transformer inference.

To benefit from it in Python, one has to use pre-built modules such as TransformerEncoder. Looking at the source code, it seems to boil down to calling _transformer_encoder_layer_fwd, which is also available in tch-rs.
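
For context, here is a rough sketch (with tch-rs) of the op-by-op attention core that such a fused kernel bundles together with the projections, layer norm and feed-forward block. The shapes and the single-head simplification are purely illustrative, not taken from any rust-bert model:

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    // Illustrative shapes only: batch of 2, sequence length 8, hidden size 16.
    let (b, s, h) = (2, 8, 16);
    let q = Tensor::randn(&[b, s, h], (Kind::Float, Device::Cpu));
    let k = Tensor::randn(&[b, s, h], (Kind::Float, Device::Cpu));
    let v = Tensor::randn(&[b, s, h], (Kind::Float, Device::Cpu));

    // The attention core spelled out op by op, as the current model
    // implementations do; the fastpath fuses this (together with the
    // projections, layer norm and the feed-forward block) into one kernel call.
    let scores = q.matmul(&k.transpose(-2, -1)) / (h as f64).sqrt();
    let weights = scores.softmax(-1, Kind::Float);
    let context = weights.matmul(&v);

    println!("{:?}", context.size()); // [2, 8, 16]
}
```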

Do you think it would be possible to make use of it in rust-bert? I can have a look at it if you think it is worth investigating.

lerouxrgd · Nov 21 '22 17:11

Hello @lerouxrgd,

Yes, the availability of BetterTransformer is an interesting development. The challenge for integrating it into the library is twofold:

  1. many of the language models in the library implement the attention mechanism from scratch, often with subtle variations that may not match the BetterTransformer module.
  2. even if the logic of the transformer block were identical between the base implementation and BetterTransformer, the submodules and parameters may have different names that would not be loaded correctly by torch.load_state_dict in Python (or VarStore::load in the Rust version). The weights may have to be re-exported with updated variable names, breaking backward compatibility if the old ones are removed (see the sketch below).
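
To illustrate the second point, here is a minimal sketch of how tch-rs ties variable names to module paths when loading a checkpoint. The path components and the .ot file name are placeholders, not the actual names used by rust-bert models:

```rust
use tch::{nn, Device};

fn main() -> Result<(), tch::TchError> {
    let mut vs = nn::VarStore::new(Device::Cpu);
    let root = vs.root();

    // Variables are registered under names derived from the module path,
    // e.g. "attention.query.weight" and "attention.query.bias".
    let attention = &root / "attention";
    let _query = nn::linear(&attention / "query", 768, 768, Default::default());

    // Loading only succeeds if the names stored in the weight file match the
    // registered paths; a module reorganised around BetterTransformer would
    // expose differently named parameters and fail here unless the weights
    // were re-exported under the new names. The file path is a placeholder.
    vs.load("path/to/rust_model.ot")?;

    Ok(())
}
```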

It may be worth keeping an eye on the related issues in the Python library (e.g. https://github.com/huggingface/transformers/issues/20372, https://github.com/huggingface/transformers/pull/19966, https://github.com/huggingface/transformers/pull/19632) and the documentation page at https://huggingface.co/docs/optimum/bettertransformer/tutorials/contribute

guillaume-be · Nov 26 '22 09:11