
enable Pipeline to get device from model

Open faaany opened this issue 1 year ago • 5 comments

What does this PR do?

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "meta-llama/Llama-2-7b-chat-hf"
# Load the model and move it to CUDA explicitly.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
print(model.device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Build a pipeline from the already-placed model, without passing `device`.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(pipe.model.device)
results = pipe("He's a dreadful magician and")

Currently, the code above gives the following output:

cuda:0
cpu

But this is not OK: when a user has already moved the model to CUDA, Pipeline should not silently move it back to CPU. This PR lets the model stay on its original device. Below is the output after this PR:

cuda:0
cuda:0

@Narsil and @muellerzr

faaany avatar Apr 29 '24 08:04 faaany

@yao-matrix

faaany avatar Apr 29 '24 08:04 faaany

@faaany are we sure that model.device is a thing across all these frameworks?

At most I see that ModuleUtilsMixin has device, which is PyTorch-specific (it gets added to AutoModel), but I'd like to verify whether the TF and Flax backends have an equivalent way to grab the model device. Otherwise we don't really want just None here, IMO.

muellerzr avatar Apr 29 '24 13:04 muellerzr
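For context, the PyTorch device property on ModuleUtilsMixin works by inspecting the model's parameters, which is why it only exists on the PyTorch backend. A minimal sketch of the idea (illustrative only, not the actual transformers source):

import torch
import torch.nn as nn

def module_device(module: nn.Module) -> torch.device:
    # Infer a module's device from its first parameter, mirroring the
    # idea behind ModuleUtilsMixin.device (PyTorch only).
    try:
        return next(module.parameters()).device
    except StopIteration:
        # Parameterless module: assume CPU.
        return torch.device("cpu")

print(module_device(nn.Linear(4, 4)))  # cpu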

Thanks for adding this!

Could you add a test?

Sure, in which test file should I put this test?

faaany avatar Apr 30 '24 10:04 faaany

@faaany are we sure that model.device is a thing across all these frameworks?

At most I see that ModuleUtilsMixin has device, which is PyTorch-specific (it gets added to AutoModel), but I'd like to verify whether the TF and Flax backends have an equivalent way to grab the model device. Otherwise we don't really want just None here, IMO.

Good point! Yes, I know that the Flax model doesn't have device. How about moving it inside if is_torch_available() and self.framework == "pt":? I have updated my code.

Furthermore, I removed the self.device is not None check, because it can never be None. I also added logic so that the model isn't moved if it is already on the target device.

faaany avatar Apr 30 '24 10:04 faaany
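A rough sketch of the guarded logic described above (a hypothetical standalone helper for illustration; the actual change lives inside Pipeline.__init__):

import torch

def resolve_pipeline_device(model, device=None):
    # Hypothetical helper illustrating the post-PR behavior: if the user
    # did not pass `device` to pipeline(), fall back to the model's
    # current device instead of CPU, and only move the model when it is
    # not already on the requested device. Assumes `model` exposes
    # `.device`, as transformers PyTorch models do.
    if device is None:
        return model.device  # keep the model where the user placed it
    target = torch.device(device)
    if model.device != target:
        model.to(target)  # move only when an actual change is needed
    return target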

Hi @amyeroberts, sorry for the late response; we had a long holiday here in China. Unit tests have been added. Let me explain in more detail:

There are 3 possibilities for model.device:

a1. user passes device_map to from_pretrained
a2. user doesn't pass device_map to from_pretrained
a3. user manually moves the model to a certain device with to(device) after the model is loaded with from_pretrained

There are 2 possibilities for pipeline.device:

b1. user passes device to pipeline
b2. user doesn't pass device to pipeline

Since a2 & b2 is trivial, my unit tests cover the cases a1 & b1, a1 & b2, a3 & b1 and a3 & b2. Please have a review, thanks!

faaany avatar May 11 '24 03:05 faaany
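For illustration, a test for the a3 & b2 case could look roughly like the sketch below (the checkpoint and test name are placeholders, not the ones from the PR):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers.testing_utils import require_torch_gpu

@require_torch_gpu
def test_pipeline_keeps_manual_device():
    # a3 & b2: the user moves the model to CUDA after from_pretrained and
    # passes no device to pipeline(); the model must stay on CUDA.
    model_id = "hf-internal-testing/tiny-random-gpt2"
    model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    assert pipe.model.device == torch.device("cuda:0")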

Thanks for the review! @amyeroberts @muellerzr

faaany avatar May 13 '24 13:05 faaany