
Refactor PyTorch `model.generate` method to work on TPU

Open mikcnt opened this issue 1 year ago • 3 comments

Feature request

Refactor the PyTorch version of `model.generate` for text-generating models to make it XLA-compatible and speed up inference on TPU.
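
For context, here is a minimal sketch of what calling `generate` on a TPU looks like today via `torch_xla` (model name and prompt are illustrative):

```python
# Minimal sketch: running generate on a TPU device through torch_xla.
# Today this path is very slow, reportedly because parts of generate
# fall back to CPU instead of running as a compiled XLA graph.
import torch_xla.core.xla_model as xm
from transformers import AutoModelForCausalLM, AutoTokenizer

device = xm.xla_device()
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(device)
output = model.generate(**inputs, max_new_tokens=20)  # slow on TPU today
print(tokenizer.decode(output[0], skip_special_tokens=True))
```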

Motivation

Right now, `model.generate` in PyTorch is extremely slow on TPU compared to CPU and GPU. This is likely because some operations in the PyTorch version of `model.generate` are not XLA-compatible, so the generation process falls back to CPU, which makes inference on TPU infeasible. A major refactoring effort has already been completed for its TF counterpart, so it would be nice to have the PyTorch version working as well.
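
One way to check the CPU-fallback hypothesis is to inspect the `torch_xla` metrics report after a call to `generate`: counters named `aten::<op>` mark ops that could not be lowered to XLA and ran on CPU instead. A sketch:

```python
# Sketch: detecting CPU fallbacks after running generate on an XLA device.
import torch_xla.debug.metrics as met

# ... run model.generate(...) on the XLA device first ...

# Counters prefixed with "aten::" are ops torch_xla could not lower and
# had to execute on CPU, forcing device<->host transfers each step.
fallback_ops = [name for name in met.counter_names() if name.startswith("aten::")]
print("Ops that fell back to CPU:", fallback_ops)
print(met.metrics_report())  # full report: compile counts, transfer times, etc.
```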

A more in-depth discussion with @gante took place in #12322 and on this huggingface discussion.

Your contribution

If there is some interest from the HF team, I can definitely assist during the work.

mikcnt avatar Aug 17 '22 09:08 mikcnt

cc @patrickvonplaten

gante avatar Aug 17 '22 09:08 gante

Hey @mikcnt,

This sounds like a very cool project, and I think we should focus on it sooner or later. I currently don't have the time to take a closer look here, but my advice would be:

  • I think you're totally right that PyTorch/XLA often falls back to CPU, which is why it is very slow. We're luckier with JAX and TF because when things would fall back to CPU, the code fails instead
  • It'll take some time to get this fully working, so we should start with the easiest example -> see what code changes are necessary to make PyTorch/XLA work with greedy(...) (see the sketch after this list)
  • To set expectations: PyTorch's generate method is one of Transformers' most-used functions - it's extremely important, and we're trying very hard to keep the code readable and easy to understand. If making PyTorch XLA-compatible requires too many changes or makes the code too unreadable, we might conclude that it's just not worth it and instead add it as an "experimental" additional function, but not in the "main" generate. Also @michaelbenayoun @mfuntowicz, is this maybe something we want to have only in optimum, but not in Transformers?
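
To illustrate the greedy(...) point above: the usual decoding loop grows the sequence tensor every step, and each new shape forces XLA to recompile. A rough conceptual sketch of an XLA-friendly variant (illustrative only, not the actual transformers implementation; no KV cache, for brevity) pre-allocates a fixed-length buffer and writes into it:

```python
import torch

def greedy_fixed_shape(model, input_ids, max_length, pad_token_id, eos_token_id):
    # Illustrative sketch: fixed-shape greedy decoding. Every tensor keeps a
    # constant (batch, max_length) shape so XLA can trace a single graph
    # instead of recompiling each time the sequence grows.
    batch_size, prompt_len = input_ids.shape
    device = input_ids.device

    generated = torch.full((batch_size, max_length), pad_token_id,
                           dtype=torch.long, device=device)
    generated[:, :prompt_len] = input_ids
    attention_mask = torch.zeros((batch_size, max_length),
                                 dtype=torch.long, device=device)
    attention_mask[:, :prompt_len] = 1
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

    for step in range(prompt_len, max_length):
        # Always feed the full fixed-length buffer; padding is hidden by the mask.
        logits = model(input_ids=generated, attention_mask=attention_mask).logits
        next_token = logits[:, step - 1, :].argmax(dim=-1)
        # Finished sequences keep emitting pad instead of breaking out of the
        # loop -- data-dependent control flow would break XLA tracing.
        next_token = torch.where(finished,
                                 torch.full_like(next_token, pad_token_id),
                                 next_token)
        generated[:, step] = next_token
        attention_mask[:, step] = (~finished).long()
        finished = finished | next_token.eq(eos_token_id)

    return generated
```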

patrickvonplaten avatar Aug 17 '22 15:08 patrickvonplaten

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Sep 16 '22 15:09 github-actions[bot]

Hi,

Any updates on this? When can we expect the generate function to work on TPUs? Also, will it be part of transformers or optimum, as mentioned by @patrickvonplaten above?

divyanshuaggarwal avatar Sep 24 '22 18:09 divyanshuaggarwal

Sadly, I won't have time to look into this anytime soon. @gante maybe?

patrickvonplaten avatar Sep 27 '22 11:09 patrickvonplaten

Added to my generate task queue 👍

@divyanshuaggarwal it would be part of transformers!

gante avatar Sep 28 '22 12:09 gante

Thanks @gante!

divyanshuaggarwal avatar Sep 28 '22 16:09 divyanshuaggarwal