Use BetterTransformer for fast inference
🚀 Feature
PyTorch 1.12 introduced a speed-up for Transformer encoder inference through BetterTransformer. Given Unbabel's efforts to improve usability and efficiency (cf. COMETINHO), such an improvement might be a useful addition to the library.
Implementation
The implementation is relatively straightforward (blog, docs). It requires accelerate and optimum, which can be optional dependencies (pip install unbabel-comet[bettertransformer]). If both of these are installed (in sufficiently recent versions) alongside PyTorch 1.13 (the minimum requirement for the Optimum API), then we can set a global constant BETTER_TRANSFORMER_AVAILABLE = True.
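As a rough sketch (the helper name and its placement are assumptions for illustration, not existing COMET code), the flag could be set like this:

    # Hypothetical availability check; names and placement are illustrative only.
    import importlib.util

    import torch
    from packaging import version  # assumes packaging is available

    def _bettertransformer_available() -> bool:
        """True if optimum and accelerate are installed and torch is new enough."""
        has_optimum = importlib.util.find_spec("optimum") is not None
        has_accelerate = importlib.util.find_spec("accelerate") is not None
        # PyTorch 1.13 is the minimum mentioned above; minimum versions of
        # optimum/accelerate could be checked similarly via importlib.metadata.
        torch_ok = version.parse(torch.__version__).release >= (1, 13)
        return has_optimum and has_accelerate and torch_ok

    BETTER_TRANSFORMER_AVAILABLE = _bettertransformer_available()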
As an example, this line:
https://github.com/Unbabel/COMET/blob/master/comet/encoders/bert.py#L38
should be followed by:
if self.use_bettertransformer and BETTER_TRANSFORMER_AVAILABLE:
    # requires `from optimum.bettertransformer import BetterTransformer` at the top of the module
    self.model = BetterTransformer.transform(self.model, keep_original_model=False)
And that's about it.
Of course, which class use_bettertransformer should be implemented in still has to be decided, but apart from that I think this is a feasible addition that can lead to significant speed improvements.
I can work on this if needed, but I would need guidance on which class self.use_bettertransformer should live in, and on whether the logic (the if-statement above) should be implemented on a per-model basis or generalized somehow; a rough sketch of the generalized option is below.
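For illustration only, the generalized option could look roughly like this (the shared helper is hypothetical; BetterTransformer.transform is the actual Optimum API):

    # Rough sketch of the "generalize it" option: one shared helper that any encoder
    # can call right after loading its pretrained model.
    from optimum.bettertransformer import BetterTransformer

    def maybe_to_bettertransformer(model, use_bettertransformer: bool):
        """Return the model converted to BetterTransformer when requested and available."""
        if use_bettertransformer and BETTER_TRANSFORMER_AVAILABLE:
            # keep_original_model=False converts without keeping a copy of the original.
            model = BetterTransformer.transform(model, keep_original_model=False)
        return model

    # e.g. in each encoder's __init__, after the pretrained model is loaded:
    # self.model = maybe_to_bettertransformer(self.model, self.use_bettertransformer)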
Yep, this seems like a good idea.
I'll try to find some time to test it! Thanks!