transformers
Enable traced model for text-generation task
@sywangyi Enable traced model for the text-generation task. I changed `beam_search` and `greedy_search` in generation to support traced models: if a traced model has been set on the `trace_graph` attribute, we use `model.trace_graph` for the forward pass. I also changed the text-generation example and found that a model optimized by jit trace performs better on the text-generation task. The data, measured on an A100, is as follows:
- model: gptj-6b
- beam search: input_tokens=32, output_tokens=32, num_beam=4
- data type: bf16
- original model's latency: 0.96s
- jit trace model's latency: 0.72s
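A minimal sketch of how the traced model could be attached, assuming the `trace_graph` attribute described in this PR and using `torch.jit.trace` with `torchscript=True`; the model id, prompt, and exact tracing inputs here are illustrative, not the PR's example script:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/gpt-j-6B"  # the model from the benchmark above
tokenizer = AutoTokenizer.from_pretrained(model_id)
# torchscript=True makes the model return tuples, which torch.jit.trace requires
model = AutoModelForCausalLM.from_pretrained(
    model_id, torchscript=True, torch_dtype=torch.bfloat16
).eval()

prompt = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    # Trace the forward pass with example input_ids only; real input preparation
    # for the generation loop (attention mask, past key values) is more involved.
    traced = torch.jit.trace(model, (prompt.input_ids,))

# Attach the traced graph under the attribute the modified generation loop checks
model.trace_graph = traced

output = model.generate(**prompt, num_beams=4, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```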
@sgugger please help review
@gante Could you have a first look at the changes in generate?
@gante Hi, Gante. Thanks for your thoughtful comment; it's reasonable and I agree with it. Here I have two solutions:
- For `trace_graph` in the main body of `generate`, we can add a doc that explains `trace_graph` in detail, including what it is, how to implement it, and how it helps accelerate inference. As for tensor manipulation, the method of preparing input tensors for `trace_graph` is general for the text-generation task across all kinds of models. It can also be adapted to any task with a few changes (this is in progress), rather than being a specific use case. We can put this method in utils as a general helper.
- As you said, we could redefine `prepare_inputs_for_generation` to handle both the inputs and the `model.trace_graph` outputs. However, redefining `model.prepare_inputs_for_generation()` is not a general approach, since different model classes have different implementations of `prepare_inputs_for_generation()`, and it is not convenient to subclass a different model class every time we change the type of model (a sketch of such a per-model override follows this list).
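A hypothetical sketch of that second option, showing why it stays tied to one model class; the subclass name and the remapping step are illustrative only:

```python
from transformers import GPTJForCausalLM


class TracedGPTJForCausalLM(GPTJForCausalLM):
    """Illustrative subclass: per-model rewiring of the generation inputs."""

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        inputs = super().prepare_inputs_for_generation(input_ids, **kwargs)
        # Here one would reorder/rename entries to match the signature the
        # traced graph was exported with. The mapping is model-specific, which
        # is exactly why this approach does not generalize across model classes.
        return inputs
```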
I strongly recommend the first way. There are many ways to optimize `model.forward`; if we support the `trace_graph` attribute in the main body of `generate`, it will be convenient for users to pass their custom models.
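For illustration, a hypothetical sketch of that first option: the generation loop dispatches to a user-supplied traced graph when one is attached. Both `prepare_trace_inputs` and `forward_step` are illustrative names, not existing `transformers` APIs:

```python
def prepare_trace_inputs(model_inputs):
    # Illustrative only: select the tensors the traced graph was exported with,
    # in positional order. A real helper would be task- and cache-aware.
    ordered = ("input_ids", "attention_mask", "past_key_values")
    return tuple(model_inputs[k] for k in ordered if model_inputs.get(k) is not None)


def forward_step(model, model_inputs):
    # Prefer the user-supplied traced graph when the attribute is present,
    # otherwise fall back to the regular forward pass.
    if hasattr(model, "trace_graph"):
        return model.trace_graph(*prepare_trace_inputs(model_inputs))
    return model(**model_inputs, return_dict=True)
```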
By the way, you set `return_dict=True` in the main body of `generate`, so it does not take effect if I set `return_dict=False` in `.from_pretrained`. Could I remove this so that users can decide for themselves whether to return a dictionary?
Thanks!
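For reference, a small sketch of what `return_dict` controls on a plain forward call, using `gpt2` purely as a small stand-in model:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", return_dict=False).eval()

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# With return_dict=False the forward pass returns a plain tuple, which is what
# a traced graph produces as well; passing return_dict=True at call time (as
# generate currently does internally) overrides the config and yields a ModelOutput.
print(type(outputs))
```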
@jiqing-feng Thank you for your comment.
To clarify my position further, in an attempt to find a solution that pleases us all: from the `transformers` perspective, our current priority is ease of use and experimentation. We also welcome performance-enhancing solutions like the one in this PR, but they must fulfill one of three requirements: (i) they are commonly requested by the community; (ii) they require minimal changes to existing functionality; (iii) the benefits of the new technique are very big, like int8 quantization. If we don't adhere to these principles, the codebase will quickly become unusable and hard to maintain, as there are many possible strategies to improve the code.
From my perspective, I haven't seen any request for `torch.jit` support in `.generate()`, and I get tagged in pretty much everything `.generate()`-related. This PR also includes a diff of 50 lines to existing functions in `utils.py`, and the benefit is up to a 20% speedup. This means that, according to the principles stated above, I'm afraid I can't support the changes as they are 🤗
This doesn't mean that my perspective is static on the subject! I've suggested above what can be done to showcase `torch.jit` in the example. That is a way to increase the visibility of the technique, which may increase the community demand for it -- and, if this demand does materialize, I'd be more than happy to include the additional logic in `utils.py`.
I apologize if this is not the answer you'd like to read, but we do have to be picky with the changes we introduce in actively maintained cross-model functionality. I'm also working towards increasing the modularity of `.generate()`, so that use cases like yours can be more easily added!
Just my +1: generation speed improvement, especially with torch 2.0, is something very nice for making the model production-ready.
Yes, I echo that. With PyTorch 2.0 introduced, I suppose we will see more and more performance benefits from jit for deployment.
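A minimal sketch of that direction, assuming PyTorch 2.0+ and compiling only the forward pass that `generate` calls; `gpt2` is used as a small stand-in model:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

# Compile just the forward pass; generate() will then call the compiled version
# on every decoding step (the first calls are slower due to compilation).
model.forward = torch.compile(model.forward)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, num_beams=4, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```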