LLMSpeculativeSampling
Output logits do not match: question about decoding when the draft model and the target model are the same
In my opinion, the generation should be identical when the draft model and the target model are the same and the temperature is 0.
In this case, however, the output logits of the draft model and the target model differ slightly, although the argmax result is the same.
THE QUESTION: why do the output logits differ when the draft and the target are the same model?
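One general property that may be relevant here, though I have not confirmed it is the cause in this repository, is that floating-point addition is not associative. If the draft pass and the target pass reach the same position through different code paths (for example different input shapes or different kernels), the same sums can be accumulated in a different order and give slightly different values. A minimal sketch in plain Python, independent of any model:

```python
# Floating-point addition is not associative: summing the same numbers
# in a different order can give a slightly different result.
values = [0.1, 0.2, 0.3]

left_to_right = (values[0] + values[1]) + values[2]
right_to_left = values[0] + (values[1] + values[2])

orders_agree = left_to_right == right_to_left
difference = abs(left_to_right - right_to_left)

print("exactly equal:", orders_agree)
print("absolute difference:", difference)
```

The two results agree to many digits but fail an exact `==` comparison, which is the same kind of check the modified accept condition performs.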
Steps to reproduce:
- Change the code as shown below to compare the outputs' logits directly:
p[:, prefix_len + i - 1, j] == q[:, prefix_len + i - 1, j]
diff --git a/sampling/speculative_sampling.py b/sampling/speculative_sampling.py
index 48e1f8d..c2eed70 100644
--- a/sampling/speculative_sampling.py
+++ b/sampling/speculative_sampling.py
@@ -164,10 +164,12 @@ def speculative_sampling_v2(prefix : torch.Tensor, approx_model : torch.nn.Modul
r = torch.rand(1, device = p.device)
j = x[:, prefix_len + i]
- if r < torch.min(torch.tensor([1], device=q.device), p[:, prefix_len + i - 1, j] / q[:, prefix_len + i - 1, j]):
+ if p[:, prefix_len + i - 1, j] == q[:, prefix_len + i - 1, j]:
# accept, and update n
n += 1
else:
+ print(p[:, prefix_len + i - 1, j] - q[:, prefix_len + i - 1, j])
+ print("unexpected reject!")
# reject
t = sample(max_fn(p[:, n, :] - q[:, n, :]))
is_all_accept = False
- Run the launch script:
python main.py \
--input "One day, Lily met a Shoggoth." \
--max_tokens 128 \
--benchmark \
--target_model_name nickypro/tinyllama-110M \
--approx_model_name nickypro/tinyllama-110M \
- "unexpected reject!" gets printed:
speculative sampling: 0%| | 0/141 [00:00<?, ?it/s]
tensor([[-5.2199e-05]], device='cuda:0')
unexpected reject!
tensor([[0.0019]], device='cuda:0')
unexpected reject!
speculative sampling: 19%|████████████▎ | 27/141 [00:00<00:00, 205.92it/s]
tensor([[0.0006]], device='cuda:0')
unexpected reject!
tensor([[-0.0019]], device='cuda:0')
unexpected reject!
speculative sampling: 35%|██████████████████████▏ | 49/141 [00:00<00:00, 132.50it/s]
tensor([[0.0006]], device='cuda:0')
unexpected reject!
speculative sampling: 99%|███████████████████████████████████████████████████████████████▌| 140/141 [00:01<00:00, 98.67it/s]
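Since the printed differences are small, one way to tell numerical noise from a real mismatch is to compare with a tolerance instead of `==`. The sketch below is only an illustration: `probs_match` is a hypothetical helper, not part of the repository, and both the tolerance `5e-3` and the draft value `0.5` are assumptions picked so the example runs on the differences from the log above.

```python
import math

def probs_match(p_j: float, q_j: float, abs_tol: float = 5e-3) -> bool:
    """Return True when two values differ by at most abs_tol (hypothetical helper)."""
    return math.isclose(p_j, q_j, rel_tol=0.0, abs_tol=abs_tol)

# Differences (p - q for the drafted token) copied from the log above.
observed_diffs = [-5.2199e-05, 0.0019, 0.0006, -0.0019, 0.0006]

q_j = 0.5  # hypothetical draft value, chosen only for illustration
exact = [(q_j + d) == q_j for d in observed_diffs]
tolerant = [probs_match(q_j + d, q_j) for d in observed_diffs]

print("exact comparison:   ", exact)
print("tolerant comparison:", tolerant)
```

With tensors, `torch.isclose` accepts analogous `rtol` and `atol` arguments. A suitable tolerance depends on the dtype and the model, so the value used here should not be read as a recommendation.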