transformers
Refactor PyTorch `model.generate` method to work on TPU
Feature request
Refactor the PyTorch version of `model.generate` for text-generating models to make it XLA-compatible and speed up inference on TPU.
Motivation
Right now, `model.generate` in PyTorch is extremely slow on TPU compared to CPU and GPU. This is probably because some operations in the PyTorch version of `model.generate` are not XLA-compatible, so the generation process falls back on CPU, which makes inference on TPU infeasible. A major refactoring has already been done on the TF counterpart, so it would be nice to have the PyTorch version working as well.
A more in-depth discussion with @gante took place in #12322 and in this Hugging Face discussion.
Your contribution
If there is some interest from the HF team, I can definitely assist during the work.
cc @patrickvonplaten
Hey @mikcnt,
This sounds like a very cool project and I think we should sooner or later focus on it. Currently I won't have the time to take a closer look here, but my advice would be:
- I think you're totally right in that PyTorch/XLA often falls back on CPU which is why it is very slow. We're luckier here with Jax and TF because if things fall back on CPU the code fails
- It'll take some time to get this fully working, so we should start with the easiest example -> see what code changes are necessary to make PyTorch/XLA work with `greedy(...)` (a rough sketch of such a test is shown after this list)
- To set expectations: PyTorch's `generate` method is one of Transformers' most used functions - it's extremely important and we're trying very hard to keep the code readable and easy to understand. If making PyTorch XLA-compatible requires too many changes or makes the code too unreadable, we might come to the conclusion that it's just not worth it and instead add it as an "experimental" additional function but not in the "main" `generate`. Also @michaelbenayoun @mfuntowicz, is that maybe something we want to have only in optimum but not in Transformers?
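For reference, a minimal sketch of what such a greedy-decoding test on TPU could look like today with PyTorch/XLA. The `torch_xla` usage and the checkpoint name are assumptions for illustration, not code from this thread:

```python
# Minimal sketch: run greedy decoding on an XLA device to observe the current
# (slow) behavior. Requires the torch_xla package; "gpt2" is just an example checkpoint.
import torch_xla.core.xla_model as xm
from transformers import AutoModelForCausalLM, AutoTokenizer

device = xm.xla_device()  # the TPU device exposed by PyTorch/XLA

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(device)

# do_sample=False and num_beams=1 selects greedy decoding inside generate()
outputs = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```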
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi,
Any updates on this? When can we expect the `generate` function to work on TPUs? Also, will it be part of transformers or optimum, as mentioned by @patrickvonplaten above?
I won't have time to look into this sadly anytime soon. @gante maybe?
Added to my generate task queue 👍
@divyanshuaggarwal it would be part of transformers!
Thanks @gante!
Hi @gante, I just noticed it has been marked WIP - any ETA on when we can expect this feature?
This is not a prioritized feature as you can already use TPUs for generation in Flax and TensorFlow. Since you can easily convert a model from one framework to the other, there is an easy workaround :-)
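As an illustration of that workaround, here is a minimal sketch that loads the PyTorch weights into the Flax model class and generates there. The checkpoint name and generation settings are illustrative; on a TPU host, JAX picks up the TPU backend automatically:

```python
# Minimal sketch of the framework-conversion workaround: load PT weights into the
# Flax class (from_pt=True) and use the Flax generate, which is already XLA-friendly.
from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-small", from_pt=True)

inputs = tokenizer("translate English to German: Hello, world!", return_tensors="np")
outputs = model.generate(**inputs, max_length=40)
print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))
```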
Is there any update on this PR?
@deveworld we are atm exploring PT-level optimizations, which include the static shapes needed for XLA (TPU). A significant upgrade in this direction is likely in the next releases (keep an eye there :) )
@gante folks from Meta were able to do llama inference on TPU using pytorch XLA. Might be helpful for this issue.
https://pytorch.org/blog/path-achieve-low-inference-latency/?utm_content=254892693&utm_medium=social&utm_source=linkedin&hss_channel=lcp-78618366
Has there been any update on this? When is the next release expected?
We have some code ready, which makes the generation loop friendly with compiled forward passes (e.g. with torch.compile). Pretty much the same algorithm we use with TF/FLAX + XLA.
However, there are performance regressions on some devices, and the PyTorch team is having a look. We will include these changes when the performance bump is consistent across devices.
Meanwhile, feel free to adapt code from this repo/PR.
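To make the idea concrete, here is a hedged sketch of the pattern being described (a compiled forward pass plus a fixed-size cache). It is not the code from the PR mentioned above: the checkpoint name is a placeholder, and `cache_implementation="static"` assumes a transformers release that already ships the static KV cache, whose support is model-dependent:

```python
# Sketch only: compile the per-step forward pass and pre-allocate a fixed-size
# ("static") KV cache so the compiled graph keeps constant shapes across decoding steps.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "meta-llama/Llama-2-7b-hf"  # placeholder; needs a model class with static-cache support
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map="auto")

# Compile the decoding-step forward pass once; "reduce-overhead" targets small-step loops.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("The quick brown fox", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```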
I see. Will this work on TPU then / are TPUs one of the device that are experiencing performance regressions?
I also looked into the Optimum Neuron greedy decode implementation. While it no longer requires moving computations to CPU, running inference on TPU with it seems significantly slower than on GPU.
@verityw I can't confirm. We are aiming at having models that are fully compatible and efficient to use with torch.compile(), there may be additional issues when selecting the XLA backend :)
Any update on this? I'm trying to work with trl and peft on a TPU slice (to run tests on yet another HF-aspiring lib), but these newer parts of the ecosystem currently seem to only support PyTorch, which the underlying transformers does not yet support in an XLA-friendly way.
I looked into it a bit, and it seems that both mostly wrap transformers' generate(), so maybe an XLA-friendly version of that would help throughout? I also expect to run into other XLA awkwardness in the backward step of trl, but I don't have a good intuition for that. I would love any pointers on what it takes to make them XLA-friendly and how far the stack is from that.
Not far from seeing the light, actually!
Our current major endeavor in generate is the possibility of using different types of caches. By default, caches grow with the input length, but XLA needs a fixed-size cache -- we will be adding one as part of this task. In turn, this should make the forward pass of most models XLA-compatible (or close to it).
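To illustrate why a fixed-size cache matters for XLA, here is a toy sketch (not transformers' actual cache class): the cache tensors are allocated once at their maximum size, so their shapes never change during generation and the compiled graph can be reused.

```python
import torch

class ToyStaticKVCache:
    """Toy fixed-size key/value cache: allocated once, updated in place."""

    def __init__(self, batch, num_heads, max_len, head_dim, dtype=torch.float32):
        # Shapes stay constant for the whole generation, which is what XLA wants.
        self.k = torch.zeros(batch, num_heads, max_len, head_dim, dtype=dtype)
        self.v = torch.zeros(batch, num_heads, max_len, head_dim, dtype=dtype)

    def update(self, key_states, value_states, position):
        # Write the new entries at `position` in place: no torch.cat, no growing shapes.
        length = key_states.shape[2]
        self.k[:, :, position : position + length] = key_states
        self.v[:, :, position : position + length] = value_states
        return self.k, self.v

# A dynamic cache, by contrast, does torch.cat([past_k, new_k], dim=2) every step,
# so the shape changes each iteration and XLA has to recompile the graph.
```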
Any updates on this @gante ?
Yes: https://github.com/huggingface/transformers/pull/27931 (it is a prerequisite :) )