enable Pipeline to get device from model
What does this PR do?
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "meta-llama/Llama-2-7b-chat-hf"
# The model is explicitly loaded onto CUDA.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
print(model.device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# No `device` argument is passed to the pipeline.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(pipe.model.device)
results = pipe("He's a dreadful magician and")
Currently, the code above gives the following output:
cuda:0
cpu
But this is not OK: when a user has already moved the model to CUDA, Pipeline should not silently move it back to CPU without any message. This PR lets the model stay on its original device. Below is the output after this PR:
cuda:0
cuda:0
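For context, here is a minimal sketch of the kind of device-inference logic this PR is about. The function name and placement are illustrative only, not the actual change inside Pipeline:

# Illustrative sketch only, not the actual diff: the idea is to derive the
# pipeline device from the model instead of silently defaulting to CPU.
import torch


def infer_pipeline_device(model, requested_device=None):
    """Return the device the pipeline should run on.

    If the caller did not request a device explicitly, fall back to the
    device the model already lives on (PyTorch models in transformers
    expose `.device` via ModuleUtilsMixin).
    """
    if requested_device is not None:
        return torch.device(requested_device)
    return getattr(model, "device", torch.device("cpu"))

With this behaviour, the second print above reports cuda:0 instead of cpu.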
@Narsil and @muellerzr
@yao-matrix
@faaany are we sure that model.device is a thing across all these frameworks?
At most I see ModuleUtilsMixin has device, which is PyTorch-specific (it gets added to AutoModel), but I'd like to verify whether the TF and Flax backends have an equivalent way to grab the model device. Otherwise we don't really want just None here IMO.
Thanks for adding this!
Could you add a test?
Sure, in which test file should I put this test?
@faaany are we sure that model.device is a thing across all these frameworks?
At most I see ModuleUtilsMixin has device, which is PyTorch-specific (it gets added to AutoModel), but I'd like to verify whether the TF and Flax backends have an equivalent way to grab the model device. Otherwise we don't really want just None here IMO.
Good point! Yes, I know that Flax models don't have device. How about moving it inside if is_torch_available() and self.framework == "pt":? I have updated my code.
Furthermore, I removed the self.device is not None check, because it will never be None. I also added logic so that the model is not moved if it is already on the target device.
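A rough sketch of what that guarded logic could look like, written here as a standalone function for clarity; the actual change inside Pipeline may be structured differently:

# Simplified sketch, assuming a PyTorch model with a `.device` attribute;
# not the exact code in the pipeline implementation.
import torch
from transformers.utils import is_torch_available


def resolve_device(model, framework, device=None):
    """Pick the pipeline device without silently moving the model off its GPU."""
    if is_torch_available() and framework == "pt":
        if device is None:
            # No explicit device requested: keep the model where it already is.
            return model.device
        target = torch.device(device)
        # Only move the model if it is not already on the target device.
        if model.device != target:
            model.to(target)
        return target
    # TF/Flax models have no `.device` attribute, so fall back to the old behaviour.
    return device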
Hi @amyeroberts, sorry for the late response. We had a long holiday here in China. Unit tests are added. Let me explain in more detail:
There are 3 possibilities for model.device:
a1. user passes device_map to from_pretrained
a2. user doesn't pass device_map to from_pretrained
a3. user manually moves the model to a certain device with to(device) after model is loaded with from_pretrained
There are 2 possibilities for pipeline.device:
b1. user passes device to pipeline
b2. user doesn't pass device to pipeline
Since a2 & b2 is trivial, my unit tests cover the cases a1 & b1, a1 & b2, a3 & b1, and a3 & b2. Please have a review, thanks!
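For illustration, a test for the a3 & b2 case (model moved with .to("cuda"), no device passed to pipeline) could look roughly like this; the model id, class name, and placement are placeholders, not the tests actually added in this PR:

# Hypothetical test sketch for the a3 & b2 case; not the PR's actual tests.
import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers.testing_utils import require_torch_gpu


class PipelineDeviceFromModelTest(unittest.TestCase):
    @require_torch_gpu
    def test_pipeline_keeps_model_device(self):
        # Any small causal LM works here; this id is just a placeholder.
        model_id = "hf-internal-testing/tiny-random-gpt2"
        model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # No `device` argument: the pipeline should adopt the model's device.
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        self.assertEqual(pipe.model.device.type, "cuda")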
Thanks for the review! @amyeroberts @muellerzr