transformers
fix assisted decoding
Hi @gante. This PR fixes assisted decoding when the model and the assistant model are on different devices.
It can easily be reproduced by:
model = model.to("cuda")
model.generate(**inputs, assistant_model=assistant_model.to("cpu"))
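As an illustration only (not the actual change in this PR), the pattern the fix relies on is keeping every tensor the assistant model consumes on the assistant's own device; to_assistant_device below is a hypothetical helper sketching that idea:

import torch

def to_assistant_device(tensors: dict, assistant_model) -> dict:
    # Move every tensor kwarg onto the assistant model's device before calling
    # assistant_model.generate; non-tensor values are passed through unchanged.
    device = assistant_model.device
    return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in tensors.items()}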
The failed CI runs seem unrelated to my changes.
Hi @gante. Sorry for not making it clear. Could you run this script:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_id = "meta-llama/Llama-2-7b-chat-hf"
assistant_model_id = "Felladrin/Llama-68M-Chat-v1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_model_id, torch_dtype=torch.bfloat16).to("cpu")
prompt = "Assisted decoding is"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
model.generate(**inputs, assistant_model=assistant_model, max_new_tokens=8, min_new_tokens=8, do_sample=False)
It raises the error: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Full traceback
Traceback (most recent call last):
File "/workspace/jiqing/hete_specdecode/test_assisted.py", line 16, in <module>
model.generate(**inputs, assistant_model=assistant_model, max_new_tokens=8, min_new_tokens=8, do_sample=False)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/workspace/jiqing/transformers/src/transformers/generation/utils.py", line 1853, in generate
result = self._assisted_decoding(
File "/workspace/jiqing/transformers/src/transformers/generation/utils.py", line 3698, in _assisted_decoding
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
File "/workspace/jiqing/transformers/src/transformers/generation/candidate_generator.py", line 229, in get_candidates
assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/workspace/jiqing/transformers/src/transformers/generation/utils.py", line 1896, in generate
result = self._sample(
File "/workspace/jiqing/transformers/src/transformers/generation/utils.py", line 2648, in _sample
next_token_scores = logits_processor(input_ids, next_token_logits)
File "/workspace/jiqing/transformers/src/transformers/generation/logits_process.py", line 98, in __call__
scores = processor(input_ids, scores)
File "/workspace/jiqing/transformers/src/transformers/generation/logits_process.py", line 157, in __call__
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument test_elements in method wrapper_CUDA_isin_Tensor_Tensor)
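For context, here is a minimal standalone reproduction of the device mismatch behind that last frame (assuming a CUDA device is available): torch.isin requires both arguments to live on the same device, and moving the eos_token_id tensor onto the same device as vocab_tensor avoids the error.

import torch

if torch.cuda.is_available():
    vocab_tensor = torch.arange(10, device="cpu")       # on the assistant's device (CPU)
    eos_token_id = torch.tensor([2], device="cuda:0")   # copied over from the main model (GPU)
    try:
        torch.isin(vocab_tensor, eos_token_id)
    except RuntimeError as err:
        print(err)  # Expected all tensors to be on the same device ...
    # Aligning the devices resolves it:
    print(torch.isin(vocab_tensor, eos_token_id.to(vocab_tensor.device)))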
Hi @gante. I just found that the real issue happens here, please take a look. Thanks!
I would like to add a test for this. Do you know where I should add it? Thanks!
This makes sense, thank you for digging deeper and iterating @jiqing-feng ! 💛
Regarding tests: it's a bit tricky to test two devices on our CI AFAIK 🤔 @amyeroberts do you have suggestions on how to test it? [TL;DR @jiqing-feng found that assisted generation fails if the two models are on different devices, because the special tokens are copied from the main model to the assistant model]
I think we can just run the test on a machine with a GPU; the CPU side is not a limitation, since we can run a very tiny model on CPU just to check functionality.
@gante There are certain tests in our suite which require multiple devices, e.g. test_model_parallelization, which we denote with the require_torch_multi_accelerator and require_torch_multi_gpu decorators.
In this case, I'd suggest having two tests: one for the single-accelerator case, and another which only runs in the multi-device case.
derp, ofc a GPU is enough (which has a CPU paired up), what a brain fart on my end :D
@jiqing-feng could you add two tests like the script in this comment of yours to this file? More precisely:
- Inside GenerationIntegrationTests;
- Using the @slow decorator;
- One of the tests with the @require_torch_multi_gpu decorator with each model on a different GPU, another with @require_torch_gpu with the assistant on CPU (see the sketch after this list);
- Let's use one of our tiny test models like hf-internal-testing/tiny-random-MistralForCausalLM (as both main model and assistant)
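A rough sketch of what those two tests could look like, following the checklist above (written here as standalone functions for brevity; in the PR they live as methods on GenerationIntegrationTests, and the exact prompts and assertions may differ):

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_torch_gpu, require_torch_multi_gpu, slow

CHECKPOINT = "hf-internal-testing/tiny-random-MistralForCausalLM"  # tiny model, used as both main and assistant

def _assisted_generate(main_device, assistant_device):
    # Load the tiny checkpoint twice, place the copies on different devices,
    # and run greedy assisted generation for a few new tokens.
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
    model = AutoModelForCausalLM.from_pretrained(CHECKPOINT).to(main_device)
    assistant = AutoModelForCausalLM.from_pretrained(CHECKPOINT).to(assistant_device)
    inputs = tokenizer("Assisted decoding is", return_tensors="pt").to(model.device)
    output = model.generate(**inputs, assistant_model=assistant, max_new_tokens=8, do_sample=False)
    return inputs["input_ids"], output

@slow
@require_torch_multi_gpu
def test_assisted_decoding_in_different_gpu():
    # Main model and assistant on two different GPUs.
    input_ids, output = _assisted_generate("cuda:0", "cuda:1")
    assert output.shape[-1] > input_ids.shape[-1]  # new tokens generated, no device error

@slow
@require_torch_gpu
def test_assisted_decoding_in_gpu_cpu():
    # Main model on GPU, assistant on CPU.
    input_ids, output = _assisted_generate("cuda", "cpu")
    assert output.shape[-1] > input_ids.shape[-1]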
Hi @gante. I have added the tests, could you please take a look? Thanks!
BTW, the failed CI runs seem unrelated to my changes.
Hi @amyeroberts. Could you please take a look? The failed CI runs are not related to my changes :)
@jiqing-feng Regarding the failing tests, could you rebase on main to include upstream changes? This should resolve the failures on CI
Could you also run the following in a multi-GPU environment and share the output:
RUN_SLOW=1 pytest -k "test_assisted_decoding_in_different_gpu or test_assisted_decoding_in_different_gpu"
@jiqing-feng rebasing the PR should get CI green 🤗
Hi @amyeroberts. I ran the two tests individually and both passed, see:
I also ran your command and got the following output:
These failed tests are due to an import error:
Hi @amyeroberts. Do you need anything else from me before merging? Please let me know, thanks!
Hi @amyeroberts @gante. I think this PR should be ready to merge :)
@jiqing-feng OK, sorry, I think I messed up with the pytest command. Could you try this instead:
RUN_SLOW=1 pytest tests/generation/test_utils.py::GenerationIntegrationTests::test_assisted_decoding_in_different_gpu
RUN_SLOW=1 pytest tests/generation/test_utils.py::GenerationIntegrationTests::test_assisted_decoding_in_gpu_cpu
All passed
Hi @amyeroberts. The failed CI runs are not related to my changes, could you please review the PR?
Hi @amyeroberts @gante, would you please help merge this PR? Thanks!
Hi @jiqing-feng, we had to wait for some things to be resolved upstream and for a new CI run (which I triggered last night).