
Output logits do not match: question about decoding when the draft model and target model are the same

Open 66RING opened this issue 10 months ago • 4 comments

In my opinion, the generation should be identical when the draft model and the target model are the same and the temperature is 0.

In this case, however, the output logits of the draft model and the target model differ slightly, although the argmax result is the same.

THE QUESTION: why do the output logits differ when the draft and target are the same model?
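
One possible cause, not confirmed anywhere in this issue, is that the draft pass and the target pass do not perform exactly the same floating-point operations in the same order (for example, if one pass processes tokens one at a time and the other processes several at once). Floating-point addition is not associative, so the same numbers summed in a different order can differ in the last bits. The pure-Python sketch below only illustrates that effect; it does not use the repository's code, and `sum_in_order` is a hypothetical helper.

```python
import math

def sum_in_order(values):
    """Accumulate left to right, as a simple loop would."""
    total = 0.0
    for v in values:
        total += v
    return total

values = [0.1, 0.2, 0.3]
forward = sum_in_order(values)             # (0.1 + 0.2) + 0.3
backward = sum_in_order(reversed(values))  # (0.3 + 0.2) + 0.1

# Same inputs, different order: the results are close but not bit-identical.
print(forward == backward)
print(math.isclose(forward, backward, rel_tol=1e-9))
```

If something like this is happening inside the model, small differences in the probabilities would be expected even though the largest entry (the argmax) stays the same.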

To reproduce:

  1. Change the code as shown below so that the output logits are compared directly: p[:, prefix_len + i - 1, j] == q[:, prefix_len + i - 1, j]
diff --git a/sampling/speculative_sampling.py b/sampling/speculative_sampling.py
index 48e1f8d..c2eed70 100644
--- a/sampling/speculative_sampling.py
+++ b/sampling/speculative_sampling.py
@@ -164,10 +164,12 @@ def speculative_sampling_v2(prefix : torch.Tensor, approx_model : torch.nn.Modul
                 r = torch.rand(1, device = p.device)
                 j = x[:, prefix_len + i]
                 
-                if r < torch.min(torch.tensor([1], device=q.device), p[:, prefix_len + i - 1, j] / q[:, prefix_len + i - 1, j]):
+                if p[:, prefix_len + i - 1, j] == q[:, prefix_len + i - 1, j]:
                     # accept, and update n
                     n += 1
                 else:
+                    print(p[:, prefix_len + i - 1, j] - q[:, prefix_len + i - 1, j])
+                    print("unexpected reject!")
                     # reject
                     t = sample(max_fn(p[:, n, :] - q[:, n, :]))
                     is_all_accept = False

  2. Run the launch script:
python main.py \
    --input "One day, Lily met a Shoggoth." \
    --max_tokens 128 \
    --benchmark \
    --target_model_name nickypro/tinyllama-110M \
    --approx_model_name nickypro/tinyllama-110M
  3. "unexpected reject!" is printed:
speculative sampling:   0%|                                                                          | 0/141 [00:00<?, ?it/s]
tensor([[-5.2199e-05]], device='cuda:0')
unexpected reject!
tensor([[0.0019]], device='cuda:0')
unexpected reject!
speculative sampling:  19%|████████████▎                                                   | 27/141 [00:00<00:00, 205.92it/s]
tensor([[0.0006]], device='cuda:0')
unexpected reject!
tensor([[-0.0019]], device='cuda:0')
unexpected reject!
speculative sampling:  35%|██████████████████████▏                                         | 49/141 [00:00<00:00, 132.50it/s]
tensor([[0.0006]], device='cuda:0')
unexpected reject!
speculative sampling:  99%|███████████████████████████████████████████████████████████████▌| 140/141 [00:01<00:00, 98.67it/s]
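
The modified check in step 1 uses exact equality on floating-point probabilities, so it fails on differences as small as those printed above. As a sketch under stated assumptions, the snippet below contrasts exact equality with a tolerance-based comparison and with the acceptance probability min(1, p/q) implied by the removed line in the diff. The function names, the example probabilities, and the tolerance of 5e-3 are hypothetical and are not taken from the repository.

```python
import math

def accept_probability(p_j, q_j):
    """Chance that r < min(1, p_j / q_j) for r uniform in [0, 1)."""
    return min(1.0, p_j / q_j)

def nearly_equal(p_j, q_j, abs_tol=5e-3):
    """Tolerance-based stand-in for the exact `==` check; abs_tol is a guess."""
    return math.isclose(p_j, q_j, rel_tol=0.0, abs_tol=abs_tol)

# Hypothetical probabilities whose gap matches one printed difference (0.0019).
q_j = 0.9000
p_j = q_j + 0.0019

print(p_j == q_j)                    # exact equality fails
print(nearly_equal(p_j, q_j))        # tolerance-based comparison passes
print(accept_probability(p_j, q_j))  # p slightly above q: always accepted
print(accept_probability(q_j, p_j))  # p slightly below q: accepted almost always
```

Under the original rule a small gap between p and q only matters when p is below q, and even then the rejection chance is about the size of the relative gap, whereas the exact `==` check rejects on any difference at all.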

66RING, Apr 08 '24 13:04