lightning-thunder

Dtype mismatch with LitGPT and autocast

Open · carmocca opened this issue 1 year ago · 5 comments

🐛 Bug

To Reproduce

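import thunder
from thunder.tests.litgpt_model import GPT
import torch

device = torch.device("cuda")
with device:
    # Build the model and a small batch of token ids directly on the GPU.
    model = GPT.from_name("llama2-like")
    x = torch.randint(0, 100, (2, 5))
model = thunder.jit(model)

# Calling the jitted model under fp16 autocast triggers the dtype mismatch.
with torch.autocast("cuda", torch.float16):
    model(x)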

print("Good!")

I believe the problem is in this cat call:

https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/tests/lit_gpt_model.py#L484

where q_roped is float32 while the non-roped part of q is float16.

cc @crcrpar

carmocca, Mar 04 '24 19:03

Debugging shows that qkv is already float32 at https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/tests/lit_gpt_model.py#L460, which is unexpected.

carmocca, Mar 04 '24 19:03

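I think this might be less an autocast problem than a cat problem: PyTorch will cat fp16 and fp32 tensors, promoting the result to fp32 (the output below prints no dtype, i.e. the default float32):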

In [1]: import torch

In [2]: a = torch.randn(2, 2)

In [3]: b = torch.randn(2, 2, dtype=torch.float16)

In [4]: torch.cat([a, b])
Out[4]: 
tensor([[ 0.1820, -1.1422],
        [-0.5954,  0.8083],
        [ 0.5239, -0.6465],
        [-1.3232, -1.3340]])

t-vi, Mar 04 '24 21:03
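The torch.cat session above relies on PyTorch's type promotion. As a minimal sketch, assuming only the three floating dtypes relevant to this issue, the rule can be written in plain Python without torch. The string dtype names and the helpers promote_float_dtypes and cat_result_dtype are hypothetical stand-ins; the results are meant to match what torch.promote_types returns for these pairs.

```python
from functools import reduce

# Hypothetical string stand-ins for torch.float16, torch.bfloat16, torch.float32.
FLOAT_DTYPES = ("float16", "bfloat16", "float32")


def promote_float_dtypes(a, b):
    """Mirror torch.promote_types for the three float dtypes in this issue."""
    if a not in FLOAT_DTYPES or b not in FLOAT_DTYPES:
        raise ValueError(f"unsupported dtype pair: {a}, {b}")
    if a == b:
        return a
    # float16 + float32 and bfloat16 + float32 widen to float32; float16 and
    # bfloat16 cannot represent each other exactly, so they also meet at float32.
    return "float32"


def cat_result_dtype(dtypes):
    """Result dtype of torch.cat over inputs with these dtypes."""
    return reduce(promote_float_dtypes, dtypes)


print(cat_result_dtype(["float32", "float16"]))     # float32
print(cat_result_dtype(["float16", "float16"]))     # float16
print(promote_float_dtypes("float16", "bfloat16"))  # float32
```

This is why the RoPE cat succeeds in eager PyTorch even when q_roped and the rest of q disagree: the result silently becomes float32.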

Are you sure? If I print the dtype inside the model:

$ git diff
diff --git a/thunder/tests/lit_gpt_model.py b/thunder/tests/lit_gpt_model.py
index 04d7cfa8..5cc44e46 100644
--- a/thunder/tests/lit_gpt_model.py
+++ b/thunder/tests/lit_gpt_model.py
@@ -458,6 +458,7 @@ class CausalSelfAttention(nn.Module):
         B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
 
         qkv = self.attn(x)
+        print(qkv.dtype)
 
         # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
         q_per_kv = self.config.n_head // self.config.n_query_groups

I get torch.float16 without Thunder enabled and torch.float32 with Thunder.

carmocca, Mar 04 '24 21:03

So modifying cat to upcast the way PyTorch does lets us run the trace, but it would insert casting ops that we would then need to undo in the autocast pass.

t-vi, Mar 04 '24 21:03
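To illustrate what "cat with PyTorch-style upcasting" would put into a trace, here is a minimal sketch over a toy trace where each op is a small record. This is not Thunder's real trace or symbol API; Op, promote, traced_cat and the value name q_pass are hypothetical. It shows the explicit convert that appears before the cat, which is the kind of op an autocast pass would later see and might want to remove.

```python
from dataclasses import dataclass
from functools import reduce


@dataclass
class Op:
    """One op in a hypothetical toy trace (not Thunder's real IR)."""
    name: str     # "convert" or "cat"
    inputs: list  # names of the input values
    output: str   # name of the produced value
    dtype: str    # dtype of the produced value


def promote(a, b):
    # Among float16, bfloat16 and float32, any mismatch promotes to float32.
    return a if a == b else "float32"


def traced_cat(trace, dtypes, inputs, out):
    """Record a cat in `trace`, first upcasting inputs to the promoted dtype."""
    result = reduce(promote, (dtypes[x] for x in inputs))
    cat_inputs = []
    for x in inputs:
        if dtypes[x] != result:
            # Explicit cast that an autocast pass might later want to undo.
            y = f"{x}_to_{result}"
            trace.append(Op("convert", [x], y, result))
            dtypes[y] = result
            x = y
        cat_inputs.append(x)
    trace.append(Op("cat", cat_inputs, out, result))
    dtypes[out] = result
    return out


# Mirrors the RoPE cat: q_roped is float32, the pass-through slice is float16.
trace = []
dtypes = {"q_roped": "float32", "q_pass": "float16"}
traced_cat(trace, dtypes, ["q_roped", "q_pass"], "q")
for op in trace:
    print(op.name, op.inputs, "->", op.output, op.dtype)
# convert ['q_pass'] -> q_pass_to_float32 float32
# cat ['q_roped', 'q_pass_to_float32'] -> q float32
```

Making the promotion explicit keeps the cat itself single-dtype, at the cost of leaving a cast in the trace for later passes to reason about.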

Quoting carmocca: "I get torch.float16 without Thunder enabled and float32 with Thunder"

So I think we expect the tracing itself to run without autocast, with the autocast transform then inserting casts as appropriate. Letting cat (and stack) upcast is the correct thing to do in any case. However, the autocast transform would need to undo this. (I don't know whether we can tell that a.to(fp32).to(fp16) is a no-op when a is fp16.)

t-vi, Mar 04 '24 22:03
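On the open question of whether a.to(fp32).to(fp16) is a no-op: numerically it is when a is already float16, because float32 can represent every float16 value exactly (NaN payloads aside), whereas the reverse round trip float32 -> float16 -> float32 rounds and is not a no-op. A minimal sketch of a peephole pass that recognizes such widen-then-narrow pairs, assuming a toy trace made only of convert ops, is below. Convert, SAFE_WIDENINGS and drop_round_trip_casts are hypothetical names, not Thunder's transform API.

```python
from dataclasses import dataclass


@dataclass
class Convert:
    """Hypothetical toy convert op: out = inp.to(dtype)."""
    inp: str
    out: str
    dtype: str


# (source, intermediate) pairs where source -> intermediate -> source is exact:
# float32 represents every float16 and bfloat16 value exactly.
SAFE_WIDENINGS = {("float16", "float32"), ("bfloat16", "float32")}


def drop_round_trip_casts(ops, dtypes):
    """Drop x -> wide -> dtype(x) convert pairs when the widening is exact.

    `dtypes` maps trace-input names to dtypes and is updated in place.
    Returns (kept_ops, alias), where alias maps each removed output name
    to the original value that can replace it.
    """
    producer = {}  # output name -> Convert op that produced it
    alias = {}
    kept = []
    for op in ops:
        src = producer.get(op.inp)
        if src is not None:
            orig_dtype = dtypes[src.inp]
            if op.dtype == orig_dtype and (orig_dtype, src.dtype) in SAFE_WIDENINGS:
                alias[op.out] = src.inp  # narrowing undoes the widening exactly
                continue
        producer[op.out] = op
        dtypes[op.out] = op.dtype
        kept.append(op)
    return kept, alias


# fp16 -> fp32 -> fp16: the second convert is redundant.
ops = [Convert("a", "a32", "float32"), Convert("a32", "a16", "float16")]
kept, alias = drop_round_trip_casts(ops, {"a": "float16"})
print(alias)  # {'a16': 'a'}

# fp32 -> fp16 -> fp32 is lossy, so nothing is removed.
ops2 = [Convert("b", "b16", "float16"), Convert("b16", "b32", "float32")]
kept2, alias2 = drop_round_trip_casts(ops2, {"b": "float32"})
print(len(kept2), alias2)  # 2 {}
```

In a real pass, consumers of each aliased output would be rewired to the original value, and the widening convert would be removed by dead-code elimination only if nothing else uses it.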