# Dtype mismatch with LitGPT and autocast

## 🐛 Bug

### To Reproduce
```python
import thunder
from thunder.tests.litgpt_model import GPT
import torch

device = torch.device("cuda")
with device:
    model = GPT.from_name("llama2-like")
    x = torch.randint(0, 100, (2, 5))
model = thunder.jit(model)
with torch.autocast("cuda", torch.float16):
    model(x)
print("Good!")
```
I believe the problem is in this cat:
https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/tests/lit_gpt_model.py#L484
where `q_roped` is in float32 and the non-roped bits are in float16.
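As a minimal sketch of how that dtype split can arise under autocast (illustrative shapes and names, not the ones from the model; assumes a CUDA device):

```python
import torch

with torch.autocast("cuda", torch.float16):
    x = torch.randn(2, 8, device="cuda")
    w = torch.randn(8, 8, device="cuda")
    q = torch.nn.functional.linear(x, w)      # autocast runs linear in float16
    cos = torch.randn(2, 4, device="cuda")    # rope cache stays float32
    q_roped = q[..., :4] * cos                # float16 * float32 promotes to float32
    rest = q[..., 4:]                         # still float16
    print(q_roped.dtype, rest.dtype)          # torch.float32 torch.float16
    out = torch.cat((q_roped, rest), dim=-1)  # eager cat promotes to float32; this is the cat Thunder trips over
```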
cc @crcrpar
Debugging shows that `qkv` is float32 in https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/tests/lit_gpt_model.py#L460, which is unexpected.
I think this might not be an autocast problem but more of a `cat` problem: PyTorch will happily `cat` fp16 and fp32:
```python
In [1]: import torch
In [2]: a = torch.randn(2, 2)
In [3]: b = torch.randn(2, 2, dtype=torch.float16)
In [4]: torch.cat([a, b])
Out[4]:
tensor([[ 0.1820, -1.1422],
        [-0.5954,  0.8083],
        [ 0.5239, -0.6465],
        [-1.3232, -1.3340]])
```
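Checking the result's dtype (not shown above) would confirm the promotion:

```python
torch.cat([a, b]).dtype  # torch.float32: cat promotes to the wider dtype
```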
Are you sure? If I print the dtype in the model
```diff
$ git diff
diff --git a/thunder/tests/lit_gpt_model.py b/thunder/tests/lit_gpt_model.py
index 04d7cfa8..5cc44e46 100644
--- a/thunder/tests/lit_gpt_model.py
+++ b/thunder/tests/lit_gpt_model.py
@@ -458,6 +458,7 @@ class CausalSelfAttention(nn.Module):
         B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
         qkv = self.attn(x)
+        print(qkv.dtype)
         # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
         q_per_kv = self.config.n_head // self.config.n_query_groups
```
I get `torch.float16` without Thunder and `torch.float32` with Thunder.
So modifying `cat` to do the upcasting as PyTorch does lets us run the trace, but it would insert casting ops that we would then need to undo in the autocast pass...
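As a rough sketch of the kind of promotion this would amount to (plain PyTorch; `promoted_cat` is a made-up helper, not Thunder's actual implementation):

```python
import torch

def promoted_cat(tensors, dim=0):
    # Compute the common dtype the way PyTorch's type promotion does,
    # then cast each input before concatenating. These .to() calls are
    # the extra casting ops the autocast pass would then have to undo.
    common = tensors[0].dtype
    for t in tensors[1:]:
        common = torch.promote_types(common, t.dtype)
    return torch.cat([t.to(common) for t in tensors], dim=dim)
```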
> I get `torch.float16` without Thunder and `torch.float32` with Thunder.
So I think we expect the tracing itself to run without autocast, with the autocast transform inserting casts as appropriate.
So letting `cat` (and `stack`) upcast is the correct thing to do at any rate. However, the autocast transform would need to undo this. (I don't know whether we can know that `a.to(fp32).to(fp16)` is a no-op for `a` of dtype fp16.)
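For the parenthetical, a hedged sketch of the kind of peephole check this would need, assuming the transform can see the input's dtype at that point (the function name and the whitelist here are made up):

```python
import torch

def elide_cast_round_trip(a: torch.Tensor, up: torch.dtype, down: torch.dtype) -> torch.Tensor:
    # a.to(up).to(down) is only a no-op if `a` already has dtype `down` and the
    # round trip through `up` is exact (e.g. float16 -> float32 -> float16).
    exact_round_trips = {(torch.float16, torch.float32), (torch.bfloat16, torch.float32)}
    if a.dtype == down and (down, up) in exact_round_trips:
        return a
    return a.to(up).to(down)
```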