Grads for Qwen/Qwen2.5-7B-Instruct on CUDA are not close to reference results (large relative difference)
🐛 Bug
- Install transformers==4.50.3 https://github.com/Lightning-AI/lightning-thunder/blob/6f6584c8f7d9a3483c92f68598e2fb04ca03a985/requirements/test.txt#L21
- Remove the skip added in https://github.com/Lightning-AI/lightning-thunder/commit/a6698fafc0fe652801312860d8d86bf3322f4f6b https://github.com/Lightning-AI/lightning-thunder/blob/6f6584c8f7d9a3483c92f68598e2fb04ca03a985/thunder/tests/test_networks.py#L425
- Run:
pytest thunder/tests/test_networks.py -k "test_hf_for_nemo[Qwen/Qwen2.5-7B-Instruct]" -vvv -s
The test compares gradients of the eager reference model against the Thunder-compiled model and fails with:
grads_ref = torch.autograd.grad(ref_loss, model.parameters(), grad_outputs=loss_grad)
grads_compiled = torch.autograd.grad(compiled_loss, model.parameters(), grad_outputs=loss_grad)
> torch.testing.assert_close(grads_ref, grads_compiled, rtol=1e-2, atol=1e-2)
E AssertionError: Tensor-likes are not close!
E
E Mismatched elements: 2 / 448 (0.4%)
E Greatest absolute difference: 0.01324462890625 at index (11, 14) (up to 0.01 allowed)
E Greatest relative difference: 2.09375 at index (11, 5) (up to 0.01 allowed)
E
E The failure occurred for item [0]
thunder/tests/test_networks.py:471: AssertionError
========================================================================= short test summary info ==========================================================================
FAILED thunder/tests/test_networks.py::test_hf_for_nemo[Qwen/Qwen2.5-7B-Instruct] - AssertionError: Tensor-likes are not close!
Mismatched elements: 2 / 448 (0.4%)
Greatest absolute difference: 0.01324462890625 at index (11, 14) (up to 0.01 allowed)
Greatest relative difference: 2.09375 at index (11, 5) (up to 0.01 allowed)
The failure occurred for item [0]
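To locate which parameter's gradient is off, a minimal sketch (names `model`, `grads_ref`, and `grads_compiled` are assumed from the test above) that checks each gradient individually:

# Hedged sketch: compare gradients parameter by parameter to find the
# mismatch; `model`, `grads_ref`, and `grads_compiled` are assumed to be
# the objects from the failing test. torch.autograd.grad returns gradients
# in the order of model.parameters(), so zipping with named_parameters()
# pairs each gradient with its parameter name.
for (name, _), g_ref, g_cmp in zip(model.named_parameters(), grads_ref, grads_compiled):
    try:
        torch.testing.assert_close(g_ref, g_cmp, rtol=1e-2, atol=1e-2)
    except AssertionError as e:
        print(f"grad mismatch for {name}: {e}")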
cc @borda @mruberry
@kiya00, how should I use the reproducer tools you developed recently to isolate where the problems occur?
Hi Ivan, with the following patch (I'll commit it):
wayan@3d0e6ea-lcedt:~/lightning-thunder$ git diff thunder/dynamo/report.py
diff --git a/thunder/dynamo/report.py b/thunder/dynamo/report.py
index c8af6223..a1480269 100644
--- a/thunder/dynamo/report.py
+++ b/thunder/dynamo/report.py
@@ -1351,4 +1351,4 @@ def save_failing_repros(
report.run_repro(compile_fn, check_consistency)
except Exception as e:
comment = f"Failed to run the function using {compile_fn.name} with exception: {e}"
- report.write_repro(repros_folder, compile_fn, extra_comment_str=comment)
+ report.write_repro(repros_folder, compile_fn, extra_comment_str=comment, check_consistency=check_consistency)
and with the following lines added in test_hf_for_nemo (the patch above makes write_repro also embed the consistency check in the saved repro scripts):
from thunder.dynamo.report import get_thunder_fxgraph_reports, save_failing_repros
thunder_fxgraph_reports = get_thunder_fxgraph_reports(model, fullgraph=fullgraph)(input_ids=input_ids, labels=input_ids)
# assert len(thunder_fxgraph_reports) == 1
from thunder.dynamo.benchmark_utils import ThunderCompileSpecification
save_failing_repros(thunder_fxgraph_reports[0].subgraph_reports, ThunderCompileSpecification(), "failrepros", check_consistency=True)
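For context: save_failing_repros runs each subgraph report with the given compile specification and, for every subgraph whose run or consistency check fails, writes a standalone repro script into the failrepros folder; each saved script can then be run directly with plain python.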
I've saved a script that might be useful:
graph0_thunder_0.py
from math import inf
from math import nan
NoneType = type(None)
import torch
from torch import device
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
import thunder
class DynamoModule(torch.nn.Module):
def forward(self, l_kwargs_input_ids_: "i64[1, 32][32, 1]cuda:0", l_args_0_modules_model_modules_embed_tokens_parameters_weight_: "bf16[16, 28][28, 1]cuda:0", l_args_0_modules_model_modules_rotary_emb_buffers_inv_freq_: "bf16[1][1]cuda:0", l_args_0_modules_model_modules_layers_modules_0_modules_input_layernorm_parameters_weight_: "bf16[28][1]cuda:0", l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_weight_: "bf16[28, 28][28, 1]cuda:0", l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_bias_: "bf16[28][1]cuda:0", l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_k_proj_parameters_weight_: "bf16[4, 28][28, 1]cuda:0", l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_k_proj_parameters_bias_: "bf16[4][1]cuda:0", l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_v_proj_parameters_weight_: "bf16[4, 28][28, 1]cuda:0", l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_v_proj_parameters_bias_: "bf16[4][1]cuda:0", l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_: "bf16[28, 28][28, 1]cuda:0", l_args_0_modules_model_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_: "bf16[28][1]cuda:0", l_args_0_modules_model_modules_layers_modules_0_modules_mlp_modules_gate_proj_parameters_weight_: "bf16[18944, 28][28, 1]cuda:0", l_args_0_modules_model_modules_layers_modules_0_modules_mlp_modules_up_proj_parameters_weight_: "bf16[18944, 28][28, 1]cuda:0", l_args_0_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_: "bf16[28, 18944][18944, 1]cuda:0", l_args_0_modules_model_modules_norm_parameters_weight_: "bf16[28][1]cuda:0", l_args_0_modules_lm_head_parameters_weight_: "bf16[16, 28][28, 1]cuda:0"):
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:535 in forward, code: inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds: "bf16[1, 32, 28][896, 28, 1]cuda:0" = torch.nn.functional.embedding(l_kwargs_input_ids_, l_args_0_modules_model_modules_embed_tokens_parameters_weight_, 15, None, 2.0, False, False); l_args_0_modules_model_modules_embed_tokens_parameters_weight_ = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:542 in forward, code: cache_position = torch.arange(
cache_position: "i64[32][1]cuda:0" = torch.arange(0, 32, device = device(type='cuda', index=0))
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:547 in forward, code: position_ids = cache_position.unsqueeze(0)
position_ids: "i64[1, 32][32, 1]cuda:0" = cache_position.unsqueeze(0); cache_position = None
# No stacktrace found for following nodes
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:329 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
getitem: "bf16[1, 1, 1][1, 1, 1]cuda:0" = l_args_0_modules_model_modules_rotary_emb_buffers_inv_freq_[(None, slice(None, None, None), None)]; l_args_0_modules_model_modules_rotary_emb_buffers_inv_freq_ = None
float_1: "f32[1, 1, 1][1, 1, 1]cuda:0" = getitem.float(); getitem = None
inv_freq_expanded: "f32[1, 1, 1][1, 1, 1]cuda:0" = float_1.expand(1, -1, 1); float_1 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:330 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
getitem_1: "i64[1, 1, 32][32, 32, 1]cuda:0" = position_ids[(slice(None, None, None), None, slice(None, None, None))]; position_ids = None
position_ids_expanded: "f32[1, 1, 32][32, 32, 1]cuda:0" = getitem_1.float(); getitem_1 = None
# No stacktrace found for following nodes
_enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', None, False, None)
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:335 in forward, code: freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
float_3: "f32[1, 1, 1][1, 1, 1]cuda:0" = inv_freq_expanded.float(); inv_freq_expanded = None
to: "f32[1, 1, 1][1, 1, 1]cuda:0" = float_3.to(device(type='cuda', index=0)); float_3 = None
float_4: "f32[1, 1, 32][32, 32, 1]cuda:0" = position_ids_expanded.float(); position_ids_expanded = None
matmul: "f32[1, 1, 32][32, 32, 1]cuda:0" = to @ float_4; to = float_4 = None
freqs: "f32[1, 32, 1][32, 1, 32]cuda:0" = matmul.transpose(1, 2); matmul = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:336 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
emb: "f32[1, 32, 2][64, 2, 1]cuda:0" = torch.cat((freqs, freqs), dim = -1); freqs = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:337 in forward, code: cos = emb.cos()
cos: "f32[1, 32, 2][64, 2, 1]cuda:0" = emb.cos()
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:338 in forward, code: sin = emb.sin()
sin: "f32[1, 32, 2][64, 2, 1]cuda:0" = emb.sin(); emb = None
# No stacktrace found for following nodes
_exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = _exit_autocast = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:341 in forward, code: cos = cos * self.attention_scaling
cos_1: "f32[1, 32, 2][64, 2, 1]cuda:0" = cos * 1.0; cos = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:342 in forward, code: sin = sin * self.attention_scaling
sin_1: "f32[1, 32, 2][64, 2, 1]cuda:0" = sin * 1.0; sin = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:344 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
cos_2: "bf16[1, 32, 2][64, 2, 1]cuda:0" = cos_1.to(dtype = torch.bfloat16); cos_1 = None
sin_2: "bf16[1, 32, 2][64, 2, 1]cuda:0" = sin_1.to(dtype = torch.bfloat16); sin_1 = None
# No stacktrace found for following nodes
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
# File: /usr/local/lib/python3.12/dist-packages/torch/_dynamo/polyfills/__init__.py:157 in instantiate_user_defined_class_object, code: obj.__init__(*args, **kwargs)
_log_api_usage_once = torch._C._log_api_usage_once('python.nn_module'); _log_api_usage_once = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:220 in forward, code: hidden_states = hidden_states.to(torch.float32)
hidden_states: "f32[1, 32, 28][896, 28, 1]cuda:0" = inputs_embeds.to(torch.float32)
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:221 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[1, 32, 28][896, 28, 1]cuda:0" = hidden_states.pow(2)
variance: "f32[1, 32, 1][32, 1, 1]cuda:0" = pow_1.mean(-1, keepdim = True); pow_1 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:222 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add: "f32[1, 32, 1][32, 1, 1]cuda:0" = variance + 1e-06; variance = None
rsqrt: "f32[1, 32, 1][32, 1, 1]cuda:0" = torch.rsqrt(add); add = None
hidden_states_1: "f32[1, 32, 28][896, 28, 1]cuda:0" = hidden_states * rsqrt; hidden_states = rsqrt = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:223 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_4: "bf16[1, 32, 28][896, 28, 1]cuda:0" = hidden_states_1.to(torch.bfloat16); hidden_states_1 = None
hidden_states_2: "bf16[1, 32, 28][896, 28, 1]cuda:0" = l_args_0_modules_model_modules_layers_modules_0_modules_input_layernorm_parameters_weight_ * to_4; l_args_0_modules_model_modules_layers_modules_0_modules_input_layernorm_parameters_weight_ = to_4 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:162 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
linear: "bf16[1, 32, 28][896, 28, 1]cuda:0" = torch._C._nn.linear(hidden_states_2, l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_weight_, l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_bias_); l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_weight_ = l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_bias_ = None
view: "bf16[1, 32, 28, 1][896, 28, 1, 1]cuda:0" = linear.view((1, 32, -1, 1)); linear = None
query_states: "bf16[1, 28, 32, 1][896, 1, 28, 1]cuda:0" = view.transpose(1, 2); view = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:163 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
linear_1: "bf16[1, 32, 4][128, 4, 1]cuda:0" = torch._C._nn.linear(hidden_states_2, l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_k_proj_parameters_weight_, l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_k_proj_parameters_bias_); l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_k_proj_parameters_weight_ = l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_k_proj_parameters_bias_ = None
view_1: "bf16[1, 32, 4, 1][128, 4, 1, 1]cuda:0" = linear_1.view((1, 32, -1, 1)); linear_1 = None
key_states: "bf16[1, 4, 32, 1][128, 1, 4, 1]cuda:0" = view_1.transpose(1, 2); view_1 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:164 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
linear_2: "bf16[1, 32, 4][128, 4, 1]cuda:0" = torch._C._nn.linear(hidden_states_2, l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_v_proj_parameters_weight_, l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_v_proj_parameters_bias_); hidden_states_2 = l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_v_proj_parameters_weight_ = l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_v_proj_parameters_bias_ = None
view_2: "bf16[1, 32, 4, 1][128, 4, 1, 1]cuda:0" = linear_2.view((1, 32, -1, 1)); linear_2 = None
value_states: "bf16[1, 4, 32, 1][128, 1, 4, 1]cuda:0" = view_2.transpose(1, 2); view_2 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:88 in apply_rotary_pos_emb, code: cos = cos.unsqueeze(unsqueeze_dim)
cos_3: "bf16[1, 1, 32, 2][64, 64, 2, 1]cuda:0" = cos_2.unsqueeze(1); cos_2 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:89 in apply_rotary_pos_emb, code: sin = sin.unsqueeze(unsqueeze_dim)
sin_3: "bf16[1, 1, 32, 2][64, 64, 2, 1]cuda:0" = sin_2.unsqueeze(1); sin_2 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:90 in apply_rotary_pos_emb, code: q_embed = (q * cos) + (rotate_half(q) * sin)
mul_4: "bf16[1, 28, 32, 2][1792, 2, 56, 1]cuda:0" = query_states * cos_3
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:63 in rotate_half, code: x1 = x[..., : x.shape[-1] // 2]
x1: "bf16[1, 28, 32, 0][896, 1, 28, 1]cuda:0" = query_states[(Ellipsis, slice(None, 0, None))]
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:64 in rotate_half, code: x2 = x[..., x.shape[-1] // 2 :]
x2: "bf16[1, 28, 32, 1][896, 1, 28, 1]cuda:0" = query_states[(Ellipsis, slice(0, None, None))]; query_states = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:65 in rotate_half, code: return torch.cat((-x2, x1), dim=-1)
neg: "bf16[1, 28, 32, 1][896, 1, 28, 28]cuda:0" = -x2; x2 = None
cat_1: "bf16[1, 28, 32, 1][896, 32, 1, 1]cuda:0" = torch.cat((neg, x1), dim = -1); neg = x1 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:90 in apply_rotary_pos_emb, code: q_embed = (q * cos) + (rotate_half(q) * sin)
mul_5: "bf16[1, 28, 32, 2][1792, 64, 2, 1]cuda:0" = cat_1 * sin_3; cat_1 = None
q_embed: "bf16[1, 28, 32, 2][1792, 2, 56, 1]cuda:0" = mul_4 + mul_5; mul_4 = mul_5 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:91 in apply_rotary_pos_emb, code: k_embed = (k * cos) + (rotate_half(k) * sin)
mul_6: "bf16[1, 4, 32, 2][256, 2, 8, 1]cuda:0" = key_states * cos_3; cos_3 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:63 in rotate_half, code: x1 = x[..., : x.shape[-1] // 2]
x1_1: "bf16[1, 4, 32, 0][128, 1, 4, 1]cuda:0" = key_states[(Ellipsis, slice(None, 0, None))]
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:64 in rotate_half, code: x2 = x[..., x.shape[-1] // 2 :]
x2_1: "bf16[1, 4, 32, 1][128, 1, 4, 1]cuda:0" = key_states[(Ellipsis, slice(0, None, None))]; key_states = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:65 in rotate_half, code: return torch.cat((-x2, x1), dim=-1)
neg_1: "bf16[1, 4, 32, 1][128, 1, 4, 4]cuda:0" = -x2_1; x2_1 = None
cat_2: "bf16[1, 4, 32, 1][128, 32, 1, 1]cuda:0" = torch.cat((neg_1, x1_1), dim = -1); neg_1 = x1_1 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:91 in apply_rotary_pos_emb, code: k_embed = (k * cos) + (rotate_half(k) * sin)
mul_7: "bf16[1, 4, 32, 2][256, 64, 2, 1]cuda:0" = cat_2 * sin_3; cat_2 = sin_3 = None
k_embed: "bf16[1, 4, 32, 2][256, 2, 8, 1]cuda:0" = mul_6 + mul_7; mul_6 = mul_7 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/integrations/sdpa_attention.py:14 in repeat_kv, code: hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
getitem_6: "bf16[1, 4, 1, 32, 2][256, 2, 256, 8, 1]cuda:0" = k_embed[(slice(None, None, None), slice(None, None, None), None, slice(None, None, None), slice(None, None, None))]
hidden_states_3: "bf16[1, 4, 7, 32, 2][256, 2, 0, 8, 1]cuda:0" = getitem_6.expand(1, 4, 7, 32, 2); getitem_6 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/integrations/sdpa_attention.py:15 in repeat_kv, code: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
key: "bf16[1, 28, 32, 2][1792, 64, 2, 1]cuda:0" = hidden_states_3.reshape(1, 28, 32, 2); hidden_states_3 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/integrations/sdpa_attention.py:14 in repeat_kv, code: hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
getitem_7: "bf16[1, 4, 1, 32, 1][128, 1, 128, 4, 1]cuda:0" = value_states[(slice(None, None, None), slice(None, None, None), None, slice(None, None, None), slice(None, None, None))]
hidden_states_4: "bf16[1, 4, 7, 32, 1][128, 1, 0, 4, 1]cuda:0" = getitem_7.expand(1, 4, 7, 32, 1); getitem_7 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/integrations/sdpa_attention.py:15 in repeat_kv, code: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
value: "bf16[1, 28, 32, 1][896, 32, 1, 1]cuda:0" = hidden_states_4.reshape(1, 28, 32, 1); hidden_states_4 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/integrations/sdpa_attention.py:39 in sdpa_attention_forward, code: query = query.contiguous()
query: "bf16[1, 28, 32, 2][1792, 64, 2, 1]cuda:0" = q_embed.contiguous(); q_embed = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/integrations/sdpa_attention.py:40 in sdpa_attention_forward, code: key = key.contiguous()
key_1: "bf16[1, 28, 32, 2][1792, 64, 2, 1]cuda:0" = key.contiguous(); key = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/integrations/sdpa_attention.py:41 in sdpa_attention_forward, code: value = value.contiguous()
value_1: "bf16[1, 28, 32, 1][896, 32, 1, 1]cuda:0" = value.contiguous(); value = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/integrations/sdpa_attention.py:54 in sdpa_attention_forward, code: attn_output = torch.nn.functional.scaled_dot_product_attention(
attn_output: "bf16[1, 28, 32, 1][896, 32, 1, 1]cuda:0" = torch._C._nn.scaled_dot_product_attention(query, key_1, value_1, attn_mask = None, dropout_p = 0.0, scale = 1.0, is_causal = True); query = key_1 = value_1 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/integrations/sdpa_attention.py:63 in sdpa_attention_forward, code: attn_output = attn_output.transpose(1, 2).contiguous()
transpose_4: "bf16[1, 32, 28, 1][896, 1, 32, 1]cuda:0" = attn_output.transpose(1, 2); attn_output = None
attn_output_1: "bf16[1, 32, 28, 1][896, 28, 1, 1]cuda:0" = transpose_4.contiguous(); transpose_4 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:204 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_2: "bf16[1, 32, 28][896, 28, 1]cuda:0" = attn_output_1.reshape(1, 32, -1); attn_output_1 = None
attn_output_2: "bf16[1, 32, 28][896, 28, 1]cuda:0" = reshape_2.contiguous(); reshape_2 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:205 in forward, code: attn_output = self.o_proj(attn_output)
attn_output_3: "bf16[1, 32, 28][896, 28, 1]cuda:0" = torch._C._nn.linear(attn_output_2, l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_, None); attn_output_2 = l_args_0_modules_model_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_ = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:271 in forward, code: hidden_states = residual + hidden_states
hidden_states_5: "bf16[1, 32, 28][896, 28, 1]cuda:0" = inputs_embeds + attn_output_3; inputs_embeds = attn_output_3 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:220 in forward, code: hidden_states = hidden_states.to(torch.float32)
hidden_states_6: "f32[1, 32, 28][896, 28, 1]cuda:0" = hidden_states_5.to(torch.float32)
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:221 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[1, 32, 28][896, 28, 1]cuda:0" = hidden_states_6.pow(2)
variance_1: "f32[1, 32, 1][32, 1, 1]cuda:0" = pow_2.mean(-1, keepdim = True); pow_2 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:222 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_4: "f32[1, 32, 1][32, 1, 1]cuda:0" = variance_1 + 1e-06; variance_1 = None
rsqrt_1: "f32[1, 32, 1][32, 1, 1]cuda:0" = torch.rsqrt(add_4); add_4 = None
hidden_states_7: "f32[1, 32, 28][896, 28, 1]cuda:0" = hidden_states_6 * rsqrt_1; hidden_states_6 = rsqrt_1 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:223 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_6: "bf16[1, 32, 28][896, 28, 1]cuda:0" = hidden_states_7.to(torch.bfloat16); hidden_states_7 = None
hidden_states_8: "bf16[1, 32, 28][896, 28, 1]cuda:0" = l_args_0_modules_model_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_ * to_6; l_args_0_modules_model_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_ = to_6 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:57 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
linear_4: "bf16[1, 32, 18944][606208, 18944, 1]cuda:0" = torch._C._nn.linear(hidden_states_8, l_args_0_modules_model_modules_layers_modules_0_modules_mlp_modules_gate_proj_parameters_weight_, None); l_args_0_modules_model_modules_layers_modules_0_modules_mlp_modules_gate_proj_parameters_weight_ = None
silu: "bf16[1, 32, 18944][606208, 18944, 1]cuda:0" = torch.nn.functional.silu(linear_4, inplace = False); linear_4 = None
linear_5: "bf16[1, 32, 18944][606208, 18944, 1]cuda:0" = torch._C._nn.linear(hidden_states_8, l_args_0_modules_model_modules_layers_modules_0_modules_mlp_modules_up_proj_parameters_weight_, None); hidden_states_8 = l_args_0_modules_model_modules_layers_modules_0_modules_mlp_modules_up_proj_parameters_weight_ = None
mul_10: "bf16[1, 32, 18944][606208, 18944, 1]cuda:0" = silu * linear_5; silu = linear_5 = None
down_proj: "bf16[1, 32, 28][896, 28, 1]cuda:0" = torch._C._nn.linear(mul_10, l_args_0_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_, None); mul_10 = l_args_0_modules_model_modules_layers_modules_0_modules_mlp_modules_down_proj_parameters_weight_ = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:277 in forward, code: hidden_states = residual + hidden_states
hidden_states_9: "bf16[1, 32, 28][896, 28, 1]cuda:0" = hidden_states_5 + down_proj; hidden_states_5 = down_proj = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:220 in forward, code: hidden_states = hidden_states.to(torch.float32)
hidden_states_10: "f32[1, 32, 28][896, 28, 1]cuda:0" = hidden_states_9.to(torch.float32); hidden_states_9 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:221 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[1, 32, 28][896, 28, 1]cuda:0" = hidden_states_10.pow(2)
variance_2: "f32[1, 32, 1][32, 1, 1]cuda:0" = pow_3.mean(-1, keepdim = True); pow_3 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:222 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_6: "f32[1, 32, 1][32, 1, 1]cuda:0" = variance_2 + 1e-06; variance_2 = None
rsqrt_2: "f32[1, 32, 1][32, 1, 1]cuda:0" = torch.rsqrt(add_6); add_6 = None
hidden_states_11: "f32[1, 32, 28][896, 28, 1]cuda:0" = hidden_states_10 * rsqrt_2; hidden_states_10 = rsqrt_2 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:223 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_8: "bf16[1, 32, 28][896, 28, 1]cuda:0" = hidden_states_11.to(torch.bfloat16); hidden_states_11 = None
hidden_states_12: "bf16[1, 32, 28][896, 28, 1]cuda:0" = l_args_0_modules_model_modules_norm_parameters_weight_ * to_8; l_args_0_modules_model_modules_norm_parameters_weight_ = to_8 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/qwen2/modeling_qwen2.py:872 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
getitem_8: "bf16[1, 32, 28][896, 28, 1]cuda:0" = hidden_states_12[(slice(None, None, None), slice(0, None, None), slice(None, None, None))]; hidden_states_12 = None
logits: "bf16[1, 32, 16][512, 16, 1]cuda:0" = torch._C._nn.linear(getitem_8, l_args_0_modules_lm_head_parameters_weight_, None); getitem_8 = l_args_0_modules_lm_head_parameters_weight_ = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/loss/loss_utils.py:43 in ForCausalLMLoss, code: logits = logits.float()
logits_1: "f32[1, 32, 16][512, 16, 1]cuda:0" = logits.float()
# File: /usr/local/lib/python3.12/dist-packages/transformers/loss/loss_utils.py:46 in ForCausalLMLoss, code: labels = labels.to(logits.device)
labels: "i64[1, 32][32, 1]cuda:0" = l_kwargs_input_ids_.to(device(type='cuda', index=0)); l_kwargs_input_ids_ = None
return (labels, logits_1, value_states, k_embed, logits)
def test_graph0_thunder_0():
inputs = [
torch.testing.make_tensor((1, 32), dtype=torch.int64, device='cuda:0', requires_grad=False, low=0, high=15,),
torch.testing.make_tensor((16, 28), dtype=torch.bfloat16, device='cuda:0', requires_grad=True, low=-0.06982421875, high=0.05908203125,),
torch.full((1,), 1.0, dtype=torch.bfloat16, device='cuda:0', requires_grad=False, layout=torch.strided),
torch.full((28,), 1.0, dtype=torch.bfloat16, device='cuda:0', requires_grad=True, layout=torch.strided),
torch.testing.make_tensor((28, 28), dtype=torch.bfloat16, device='cuda:0', requires_grad=True, low=-0.059326171875, high=0.0693359375,),
torch.full((28,), 0.0, dtype=torch.bfloat16, device='cuda:0', requires_grad=True, layout=torch.strided),
torch.testing.make_tensor((4, 28), dtype=torch.bfloat16, device='cuda:0', requires_grad=True, low=-0.047119140625, high=0.048583984375,),
torch.full((4,), 0.0, dtype=torch.bfloat16, device='cuda:0', requires_grad=True, layout=torch.strided),
torch.testing.make_tensor((4, 28), dtype=torch.bfloat16, device='cuda:0', requires_grad=True, low=-0.06201171875, high=0.055419921875,),
torch.full((4,), 0.0, dtype=torch.bfloat16, device='cuda:0', requires_grad=True, layout=torch.strided),
torch.testing.make_tensor((28, 28), dtype=torch.bfloat16, device='cuda:0', requires_grad=True, low=-0.055419921875, high=0.0625,),
torch.full((28,), 1.0, dtype=torch.bfloat16, device='cuda:0', requires_grad=True, layout=torch.strided),
torch.testing.make_tensor((18944, 28), dtype=torch.bfloat16, device='cuda:0', requires_grad=True, low=-0.09130859375, high=0.0966796875,),
torch.testing.make_tensor((18944, 28), dtype=torch.bfloat16, device='cuda:0', requires_grad=True, low=-0.1015625, high=0.0927734375,),
torch.testing.make_tensor((28, 18944), dtype=torch.bfloat16, device='cuda:0', requires_grad=True, low=-0.09130859375, high=0.0869140625,),
torch.full((28,), 1.0, dtype=torch.bfloat16, device='cuda:0', requires_grad=True, layout=torch.strided),
torch.testing.make_tensor((16, 28), dtype=torch.bfloat16, device='cuda:0', requires_grad=True, low=-0.0576171875, high=0.05712890625,),
]
model = DynamoModule()
compiled_model = thunder.jit(model)
from thunder.dynamo.report import run_forward_backward
fwd_result, grads = run_forward_backward(compiled_model, *inputs)
eager_fwd_result, eager_grads = run_forward_backward(model, *inputs)
torch.testing.assert_close(fwd_result, eager_fwd_result)
torch.testing.assert_close(grads, eager_grads)
if __name__ == "__main__":
test_graph0_thunder_0()
"""
Environment information obtained from `torch.utils.collect_env.get_pretty_env_info()`:
CUDA devices:
0: NVIDIA RTX 6000 Ada Generation
1: NVIDIA RTX 6000 Ada Generation
CUDA version: 12.9
numpy==1.26.4
nvidia-cudnn-frontend==1.11.0
optree==0.15.0
pytorch-lightning==2.5.1.post0
pytorch-triton==3.3.0+git96316ce5.nvinternal
torch==2.8.0a0+5228986c39.nvinternal
torchmetrics==1.7.1
torchvision==0.22.0a0
Versions of Thunder related libraries:
lightning-thunder==0.2.3.dev0
nvfuser==0.2.27+git07effe8
Failed to run the function using thunder with exception: Tensor-likes are not close!
Mismatched elements: 390 / 512 (76.2%)
Greatest absolute difference: 0.00390625 at index (0, 23, 12) (up to 1e-05 allowed)
Greatest relative difference: 3.0 at index (0, 11, 4) (up to 1.3e-06 allowed)
The failure occurred for item [0][1]
"""
Note that you may hit the following random error, which has a cause similar to #2025:
/opt/pytorch/pytorch/aten/src/ATen/native/cuda/Loss.cu:245: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [30,0,0] Assertion `t >= 0 && t < n_classes` failed.
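This device-side assert fires when a target index falls outside [0, n_classes). A hedged workaround sketch for randomly generated inputs (`vocab_size` is an assumption, taken from the model config) is to draw the token ids with an explicit upper bound:

# Hedged sketch: keep randomly generated token ids strictly inside the
# vocabulary so nll_loss never sees an out-of-range target; `vocab_size`
# is assumed to come from the model config.
input_ids = torch.randint(0, vocab_size, (1, 32), device="cuda:0")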
The test_hf_for_nemo test has been updated. When running `pytest thunder/tests/test_networks.py -k "test_hf_for_nemo[qwen2]" -vs` with both transformers==4.50.3 and the currently required version 4.52.4, no accuracy issues are observed.
Given this, do we still need to revisit the old environment to investigate the previous problem? @IvanYashchuk