
Enable traced model for text-generation task

Open jiqing-feng opened this issue 1 year ago • 5 comments

@sywangyi Enable traced model for the text-generation task. I changed beam_search and greedy_search in generation to support traced models: if a traced model has been set on the "trace_graph" attribute, we use model.trace_graph for the forward pass. I also changed the text-generation example and found that a model optimized by jit trace performs better on the text-generation task. The data from a run on an A100 is as follows:

- model: gptj-6b
- beam search: input_tokens=32, output_tokens=32, num_beam=4
- data type: bf16
- original model's latency: 0.96s
- jit trace model's latency: 0.72s
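(For readers unfamiliar with the setup, here is a minimal sketch of how a causal LM could be traced and exposed on a trace_graph attribute. The checkpoint name and tracing arguments are illustrative, not the PR's actual code; the PR also prepares past_key_values and beam-search inputs, which this sketch omits, so the traced graph below only covers the simple no-cache forward signature.)

```python
# Illustrative sketch only: trace a causal LM's forward with torch.jit.trace
# and attach the result as a `trace_graph` attribute, the hook the PR proposes
# generate() could dispatch to.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/gpt-j-6B"  # the gptj-6b checkpoint used in the benchmark
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    torchscript=True,   # make the model return tuples so it is traceable
    return_dict=False,
)
model.eval()

# jit.trace records one concrete execution path, so the traced graph is
# specialized to this input signature (input_ids only, no KV cache here).
example = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    traced = torch.jit.trace(model, (example["input_ids"],), strict=False)
    traced = torch.jit.freeze(traced)

# The convention proposed in the PR: expose the traced module on the model so
# the generation code (or a custom loop) can pick it up.
model.trace_graph = traced
```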

jiqing-feng avatar Mar 10 '23 06:03 jiqing-feng

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@sgugger please help review

sywangyi avatar Mar 13 '23 08:03 sywangyi

@gante Could you have a first look at the changes in generate?

sgugger avatar Mar 13 '23 13:03 sgugger

@gante Hi, Gante. Thanks for your detailed comment; it's reasonable and I agree with it. Here I have two solutions:

  1. For trace_graph in the main body of generate, we can add a doc explaining trace_graph in detail: what it is, how to implement it, and how it helps accelerate inference. As for tensor manipulation, the method of preparing input tensors for trace_graph is general for the text-generation task across all kinds of models. It can also be adapted to any other task easily with a few changes (this is in progress), rather than being a specific use case. We can put this method in utils as a general utility.
  2. As you said, we could redefine prepare_inputs_for_generation for both the inputs and the model.trace_graph outputs. However, redefining model.prepare_inputs_for_generation() is not a general approach, since different model classes have different implementations of prepare_inputs_for_generation(), and it is not convenient to subclass a different model class every time we change the model type.

I strongly recommend the first way. There are many ways to optimize model.forward; if we can support the trace_graph attribute in the main body of generate, it will be convenient for users to pass in their custom models.
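(For concreteness, here is a hypothetical sketch of what such a dispatch could look like in a decoding loop. The greedy_decode helper is illustrative and not part of the PR's actual utils.py diff; it omits past_key_values handling and beam search, and assumes the traced forward accepts input_ids alone and returns a tuple whose first element is the logits.)

```python
# Illustrative sketch: dispatch to a `trace_graph` attribute when present,
# otherwise fall back to the regular eager forward.
import torch

@torch.no_grad()
def greedy_decode(model, input_ids, max_new_tokens=32):
    for _ in range(max_new_tokens):
        if hasattr(model, "trace_graph"):
            # Use the user-attached traced graph for the forward pass.
            outputs = model.trace_graph(input_ids)
            logits = outputs[0]  # traced models return tuples, not ModelOutput
        else:
            # Regular eager path.
            logits = model(input_ids, return_dict=True).logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids
```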

BTW, you set return_dict=True in the main body of generate, so it does not work if I set return_dict=False in .from_pretrained. Could I remove this so users can decide for themselves whether or not to return a dictionary?

Thanks!

jiqing-feng avatar Mar 14 '23 03:03 jiqing-feng

@jiqing-feng Thank you for your comment.

To clarify my position further, in an attempt to find a solution that pleases us all: from the transformers perspective, our current priority is the ease of use and experimentation. We also welcome performance-enhancing solutions like the one in the PR, but they must fulfill one of three requirements: (i) they are commonly requested by the community; (ii) they require minimal changes to existing functionality; (iii) the benefits of the new technique are very big, like int8 quantization. If we don't adhere to these principles, the codebase will quickly be unusable and hard to maintain, as there are many possible strategies to improve the code.

From my perspective, I haven't seen any request for torch.jit support in .generate(), and I get tagged in pretty much everything .generate()-related. This PR also includes a diff of 50 lines to existing functions in utils.py, and the benefit is up to a 20% speedup. This means that, according to the principles stated above, I'm afraid I can't support the changes as they are 🤗

This doesn't mean that my perspective is static on the subject! I've suggested above what can be done to showcase torch.jit in the example. That is a way to increase the visibility of the technique, which may increase the community demand for it -- and, if this demand does materialize, I'd be more than happy to include the additional logic in utils.py.

I apologize if this is not the answer you'd like to read, but we do have to be picky with the changes we introduce in actively maintained cross-model functionality. I'm also working towards increasing the modularity of .generate(), so that use cases like yours can be more easily added!

gante avatar Mar 14 '23 12:03 gante

Just my +1: generation speed improvement, especially with torch 2.0, is something very nice for making the model production-ready.

bratao avatar Mar 20 '23 11:03 bratao

Yes, echoing that. With PyTorch 2.0 introduced, I suppose we will see more and more performance benefits from jit for deployment.

yao-matrix avatar Mar 20 '23 23:03 yao-matrix