
Use BetterTransformer for fast inference

Open • BramVanroy opened this issue 2 years ago • 1 comment

🚀 Feature

PyTorch 1.12 introduced a substantial speed-up for Transformer encoder inference through BetterTransformer. Given Unbabel's efforts to improve usability and efficiency (cf. COMETINHO), such an improvement could be a useful addition to the library.

Implementation

The implementation is relatively straightforward (blog, docs). It requires accelerate and optimum, which can be optional dependencies (pip install unbabel-comet[bettertransformer]). If both of these are installed (at sufficient versions) alongside PyTorch 1.13 (the minimum requirement for the Optimum API), we can set a global constant BETTER_TRANSFORMER_AVAILABLE = True, for example as in the sketch below.
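A minimal sketch of how that constant could be set; the exact checks are my assumption, not settled API:

# Sketch: detect the optional dependencies and a sufficient torch version once,
# at import time. Only the constant name comes from the proposal above.
import importlib.util

import torch
from packaging import version

BETTER_TRANSFORMER_AVAILABLE = (
    importlib.util.find_spec("optimum") is not None
    and importlib.util.find_spec("accelerate") is not None
    and version.parse(torch.__version__.split("+")[0]) >= version.parse("1.13")
)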

As an example, this line:

https://github.com/Unbabel/COMET/blob/master/comet/encoders/bert.py#L38

should be followed by:

from optimum.bettertransformer import BetterTransformer  # import needed at the top of the module

if self.use_bettertransformer and BETTER_TRANSFORMER_AVAILABLE:
    self.model = BetterTransformer.transform(self.model, keep_original_model=False)

And that's about it.
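From the user side, enabling it might then look like this; download_model and load_from_checkpoint are COMET's existing entry points, while the use_bettertransformer flag is the hypothetical new option:

from comet import download_model, load_from_checkpoint

model_path = download_model("Unbabel/wmt22-comet-da")
model = load_from_checkpoint(model_path)  # hypothetically: use_bettertransformer=True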

Of course, which class use_bettertransformer should live in still has to be decided, but apart from that I think this is a feasible addition that can lead to significant speed improvements.

I can work on this if needed, but I would need guidance on which class to implement self.use_bettertransformer in, and on whether the logic (if-statement) above should be implemented per model or generalized somehow; one possible generalization is sketched below.
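For the generalized variant, one option would be a small helper next to the encoder base class that every subclass calls after loading its model. The helper name below is hypothetical, purely to illustrate the idea:

from optimum.bettertransformer import BetterTransformer


def maybe_to_bettertransformer(model, use_bettertransformer: bool):
    # Hypothetical shared helper: convert the encoder's Hugging Face model to
    # its BetterTransformer equivalent when requested and available.
    if use_bettertransformer and BETTER_TRANSFORMER_AVAILABLE:
        return BetterTransformer.transform(model, keep_original_model=False)
    return model

Each encoder subclass could then end its __init__ with self.model = maybe_to_bettertransformer(self.model, use_bettertransformer), keeping the if-statement in one place.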

BramVanroy • Mar 03 '23 10:03

Yep, this seems like a good idea.

I'll try to find some time to test it! Thanks!

ricardorei • Mar 06 '23 08:03