lightning-thunder
A mechanism for easily calling a module's traces
🚀 Feature
Motivation
When debugging programs, you can easily get a module's traces with thunder.last_traces(compiled_model).
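For context, a minimal sketch of how the traces are obtained (assuming the `thunder.jit` entry point; the jitted module has to run once before traces are recorded):

```python
import torch
import thunder

model = torch.nn.Linear(4, 4)
jitted = thunder.jit(model)

# Traces are only recorded after the jitted module has executed once.
out = jitted(torch.randn(2, 4))

# last_traces returns the sequence of transformed traces;
# the last entry is the final execution trace.
traces = thunder.last_traces(jitted)
print(traces[-1])
```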
However, calling a trace is not simple, because traces can be thousands of lines long and take a very large number of arguments. For example, with lit-gpt:

```python
def forward(cos, sin, mask_cache, transformer_wte_weight, transformer_h_0_norm_1_weight, transformer_h_0_attn_attn_weight, transformer_h_0_attn_kv_cache_k, transformer_h_0_attn_kv_cache_v, transformer_h_0_attn_proj_weight, transformer_h_0_norm_2_weight, transformer_h_0_mlp_fc_1_weight, transformer_h_0_mlp_fc_2_weight, transformer_h_0_mlp_proj_weight, transformer_h_1_norm_1_weight, transformer_h_1_attn_attn_weight, transformer_h_1_attn_kv_cache_k, transformer_h_1_attn_kv_cache_v, transformer_h_1_attn_proj_weight, transformer_h_1_norm_2_weight, transformer_h_1_mlp_fc_1_weight, transformer_h_1_mlp_fc_2_weight, transformer_h_1_mlp_proj_weight, transformer_h_2_norm_1_weight, transformer_h_2_attn_attn_weight, transformer_h_2_attn_kv_cache_k, transformer_h_2_attn_kv_cache_v, transformer_h_2_attn_proj_weight, transformer_h_2_norm_2_weight, transformer_h_2_mlp_fc_1_weight, transformer_h_2_mlp_fc_2_weight, transformer_h_2_mlp_proj_weight, transformer_h_3_norm_1_weight, transformer_h_3_attn_attn_weight, transformer_h_3_attn_kv_cache_k, transformer_h_3_attn_kv_cache_v, transformer_h_3_attn_proj_weight, transformer_h_3_norm_2_weight, transformer_h_3_mlp_fc_1_weight, transformer_h_3_mlp_fc_2_weight, transformer_h_3_mlp_proj_weight, transformer_h_4_norm_1_weight, transformer_h_4_attn_attn_weight, transformer_h_4_attn_kv_cache_k, transformer_h_4_attn_kv_cache_v, transformer_h_4_attn_proj_weight, transformer_h_4_norm_2_weight, transformer_h_4_mlp_fc_1_weight, transformer_h_4_mlp_fc_2_weight, transformer_h_4_mlp_proj_weight, transformer_h_5_norm_1_weight, transformer_h_5_attn_attn_weight, transformer_h_5_attn_kv_cache_k, transformer_h_5_attn_kv_cache_v, transformer_h_5_attn_proj_weight, transformer_h_5_norm_2_weight, transformer_h_5_mlp_fc_1_weight, transformer_h_5_mlp_fc_2_weight, transformer_h_5_mlp_proj_weight, transformer_h_6_norm_1_weight, transformer_h_6_attn_attn_weight, transformer_h_6_attn_kv_cache_k, transformer_h_6_attn_kv_cache_v, transformer_h_6_attn_proj_weight, transformer_h_6_norm_2_weight, transformer_h_6_mlp_fc_1_weight, transformer_h_6_mlp_fc_2_weight, transformer_h_6_mlp_proj_weight, transformer_h_7_norm_1_weight, transformer_h_7_attn_attn_weight, transformer_h_7_attn_kv_cache_k, transformer_h_7_attn_kv_cache_v, transformer_h_7_attn_proj_weight, transformer_h_7_norm_2_weight, transformer_h_7_mlp_fc_1_weight, transformer_h_7_mlp_fc_2_weight, transformer_h_7_mlp_proj_weight, transformer_h_8_norm_1_weight, transformer_h_8_attn_attn_weight, transformer_h_8_attn_kv_cache_k, transformer_h_8_attn_kv_cache_v, transformer_h_8_attn_proj_weight, transformer_h_8_norm_2_weight, transformer_h_8_mlp_fc_1_weight, transformer_h_8_mlp_fc_2_weight, transformer_h_8_mlp_proj_weight, transformer_h_9_norm_1_weight, transformer_h_9_attn_attn_weight, transformer_h_9_attn_kv_cache_k, transformer_h_9_attn_kv_cache_v, transformer_h_9_attn_proj_weight, transformer_h_9_norm_2_weight, transformer_h_9_mlp_fc_1_weight, transformer_h_9_mlp_fc_2_weight, transformer_h_9_mlp_proj_weight, transformer_h_10_norm_1_weight, transformer_h_10_attn_attn_weight, transformer_h_10_attn_kv_cache_k, transformer_h_10_attn_kv_cache_v, transformer_h_10_attn_proj_weight, transformer_h_10_norm_2_weight, transformer_h_10_mlp_fc_1_weight, transformer_h_10_mlp_fc_2_weight, transformer_h_10_mlp_proj_weight, transformer_h_11_norm_1_weight, transformer_h_11_attn_attn_weight, transformer_h_11_attn_kv_cache_k, transformer_h_11_attn_kv_cache_v, transformer_h_11_attn_proj_weight, transformer_h_11_norm_2_weight, transformer_h_11_mlp_fc_1_weight, 
transformer_h_11_mlp_fc_2_weight, transformer_h_11_mlp_proj_weight, transformer_h_12_norm_1_weight, transformer_h_12_attn_attn_weight, transformer_h_12_attn_kv_cache_k, transformer_h_12_attn_kv_cache_v, transformer_h_12_attn_proj_weight, transformer_h_12_norm_2_weight, transformer_h_12_mlp_fc_1_weight, transformer_h_12_mlp_fc_2_weight, transformer_h_12_mlp_proj_weight, transformer_h_13_norm_1_weight, transformer_h_13_attn_attn_weight, transformer_h_13_attn_kv_cache_k, transformer_h_13_attn_kv_cache_v, transformer_h_13_attn_proj_weight, transformer_h_13_norm_2_weight, transformer_h_13_mlp_fc_1_weight, transformer_h_13_mlp_fc_2_weight, transformer_h_13_mlp_proj_weight, transformer_h_14_norm_1_weight, transformer_h_14_attn_attn_weight, transformer_h_14_attn_kv_cache_k, transformer_h_14_attn_kv_cache_v, transformer_h_14_attn_proj_weight, transformer_h_14_norm_2_weight, transformer_h_14_mlp_fc_1_weight, transformer_h_14_mlp_fc_2_weight, transformer_h_14_mlp_proj_weight, transformer_h_15_norm_1_weight, transformer_h_15_attn_attn_weight, transformer_h_15_attn_kv_cache_k, transformer_h_15_attn_kv_cache_v, transformer_h_15_attn_proj_weight, transformer_h_15_norm_2_weight, transformer_h_15_mlp_fc_1_weight, transformer_h_15_mlp_fc_2_weight, transformer_h_15_mlp_proj_weight, transformer_h_16_norm_1_weight, transformer_h_16_attn_attn_weight, transformer_h_16_attn_kv_cache_k, transformer_h_16_attn_kv_cache_v, transformer_h_16_attn_proj_weight, transformer_h_16_norm_2_weight, transformer_h_16_mlp_fc_1_weight, transformer_h_16_mlp_fc_2_weight, transformer_h_16_mlp_proj_weight, transformer_h_17_norm_1_weight, transformer_h_17_attn_attn_weight, transformer_h_17_attn_kv_cache_k, transformer_h_17_attn_kv_cache_v, transformer_h_17_attn_proj_weight, transformer_h_17_norm_2_weight, transformer_h_17_mlp_fc_1_weight, transformer_h_17_mlp_fc_2_weight, transformer_h_17_mlp_proj_weight, transformer_h_18_norm_1_weight, transformer_h_18_attn_attn_weight, transformer_h_18_attn_kv_cache_k, transformer_h_18_attn_kv_cache_v, transformer_h_18_attn_proj_weight, transformer_h_18_norm_2_weight, transformer_h_18_mlp_fc_1_weight, transformer_h_18_mlp_fc_2_weight, transformer_h_18_mlp_proj_weight, transformer_h_19_norm_1_weight, transformer_h_19_attn_attn_weight, transformer_h_19_attn_kv_cache_k, transformer_h_19_attn_kv_cache_v, transformer_h_19_attn_proj_weight, transformer_h_19_norm_2_weight, transformer_h_19_mlp_fc_1_weight, transformer_h_19_mlp_fc_2_weight, transformer_h_19_mlp_proj_weight, transformer_h_20_norm_1_weight, transformer_h_20_attn_attn_weight, transformer_h_20_attn_kv_cache_k, transformer_h_20_attn_kv_cache_v, transformer_h_20_attn_proj_weight, transformer_h_20_norm_2_weight, transformer_h_20_mlp_fc_1_weight, transformer_h_20_mlp_fc_2_weight, transformer_h_20_mlp_proj_weight, transformer_h_21_norm_1_weight, transformer_h_21_attn_attn_weight, transformer_h_21_attn_kv_cache_k, transformer_h_21_attn_kv_cache_v, transformer_h_21_attn_proj_weight, transformer_h_21_norm_2_weight, transformer_h_21_mlp_fc_1_weight, transformer_h_21_mlp_fc_2_weight, transformer_h_21_mlp_proj_weight, transformer_h_22_norm_1_weight, transformer_h_22_attn_attn_weight, transformer_h_22_attn_kv_cache_k, transformer_h_22_attn_kv_cache_v, transformer_h_22_attn_proj_weight, transformer_h_22_norm_2_weight, transformer_h_22_mlp_fc_1_weight, transformer_h_22_mlp_fc_2_weight, transformer_h_22_mlp_proj_weight, transformer_h_23_norm_1_weight, transformer_h_23_attn_attn_weight, transformer_h_23_attn_kv_cache_k, transformer_h_23_attn_kv_cache_v, 
transformer_h_23_attn_proj_weight, transformer_h_23_norm_2_weight, transformer_h_23_mlp_fc_1_weight, transformer_h_23_mlp_fc_2_weight, transformer_h_23_mlp_proj_weight, transformer_h_24_norm_1_weight, transformer_h_24_attn_attn_weight, transformer_h_24_attn_kv_cache_k, transformer_h_24_attn_kv_cache_v, transformer_h_24_attn_proj_weight, transformer_h_24_norm_2_weight, transformer_h_24_mlp_fc_1_weight, transformer_h_24_mlp_fc_2_weight, transformer_h_24_mlp_proj_weight, transformer_h_25_norm_1_weight, transformer_h_25_attn_attn_weight, transformer_h_25_attn_kv_cache_k, transformer_h_25_attn_kv_cache_v, transformer_h_25_attn_proj_weight, transformer_h_25_norm_2_weight, transformer_h_25_mlp_fc_1_weight, transformer_h_25_mlp_fc_2_weight, transformer_h_25_mlp_proj_weight, transformer_h_26_norm_1_weight, transformer_h_26_attn_attn_weight, transformer_h_26_attn_kv_cache_k, transformer_h_26_attn_kv_cache_v, transformer_h_26_attn_proj_weight, transformer_h_26_norm_2_weight, transformer_h_26_mlp_fc_1_weight, transformer_h_26_mlp_fc_2_weight, transformer_h_26_mlp_proj_weight, transformer_h_27_norm_1_weight, transformer_h_27_attn_attn_weight, transformer_h_27_attn_kv_cache_k, transformer_h_27_attn_kv_cache_v, transformer_h_27_attn_proj_weight, transformer_h_27_norm_2_weight, transformer_h_27_mlp_fc_1_weight, transformer_h_27_mlp_fc_2_weight, transformer_h_27_mlp_proj_weight, transformer_h_28_norm_1_weight, transformer_h_28_attn_attn_weight, transformer_h_28_attn_kv_cache_k, transformer_h_28_attn_kv_cache_v, transformer_h_28_attn_proj_weight, transformer_h_28_norm_2_weight, transformer_h_28_mlp_fc_1_weight, transformer_h_28_mlp_fc_2_weight, transformer_h_28_mlp_proj_weight, transformer_h_29_norm_1_weight, transformer_h_29_attn_attn_weight, transformer_h_29_attn_kv_cache_k, transformer_h_29_attn_kv_cache_v, transformer_h_29_attn_proj_weight, transformer_h_29_norm_2_weight, transformer_h_29_mlp_fc_1_weight, transformer_h_29_mlp_fc_2_weight, transformer_h_29_mlp_proj_weight, transformer_h_30_norm_1_weight, transformer_h_30_attn_attn_weight, transformer_h_30_attn_kv_cache_k, transformer_h_30_attn_kv_cache_v, transformer_h_30_attn_proj_weight, transformer_h_30_norm_2_weight, transformer_h_30_mlp_fc_1_weight, transformer_h_30_mlp_fc_2_weight, transformer_h_30_mlp_proj_weight, transformer_h_31_norm_1_weight, transformer_h_31_attn_attn_weight, transformer_h_31_attn_kv_cache_k, transformer_h_31_attn_kv_cache_v, transformer_h_31_attn_proj_weight, transformer_h_31_norm_2_weight, transformer_h_31_mlp_fc_1_weight, transformer_h_31_mlp_fc_2_weight, transformer_h_31_mlp_proj_weight, transformer_ln_f_weight, lm_head_weight, idx, input_pos=None):
# cos: "cuda:0 f32[8, 128]"
# sin: "cuda:0 f32[8, 128]"
# mask_cache: "cuda:0 b8[1, 1, 8, 8]"
# transformer_wte_weight: "cuda:0 bf16[32000, 4096]"
# transformer_h_0_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_0_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_0_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_0_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_0_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_0_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_0_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_0_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_0_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_1_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_1_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_1_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_1_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_1_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_1_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_1_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_1_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_1_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_2_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_2_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_2_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_2_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_2_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_2_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_2_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_2_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_2_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_3_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_3_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_3_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_3_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_3_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_3_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_3_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_3_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_3_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_4_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_4_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_4_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_4_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_4_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_4_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_4_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_4_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_4_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_5_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_5_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_5_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_5_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_5_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_5_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_5_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_5_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_5_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_6_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_6_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_6_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_6_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_6_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_6_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_6_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_6_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_6_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_7_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_7_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_7_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_7_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_7_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_7_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_7_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_7_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_7_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_8_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_8_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_8_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_8_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_8_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_8_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_8_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_8_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_8_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_9_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_9_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_9_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_9_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_9_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_9_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_9_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_9_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_9_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_10_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_10_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_10_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_10_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_10_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_10_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_10_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_10_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_10_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_11_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_11_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_11_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_11_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_11_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_11_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_11_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_11_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_11_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_12_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_12_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_12_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_12_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_12_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_12_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_12_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_12_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_12_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_13_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_13_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_13_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_13_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_13_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_13_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_13_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_13_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_13_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_14_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_14_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_14_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_14_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_14_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_14_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_14_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_14_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_14_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_15_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_15_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_15_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_15_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_15_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_15_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_15_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_15_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_15_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_16_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_16_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_16_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_16_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_16_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_16_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_16_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_16_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_16_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_17_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_17_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_17_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_17_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_17_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_17_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_17_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_17_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_17_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_18_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_18_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_18_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_18_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_18_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_18_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_18_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_18_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_18_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_19_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_19_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_19_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_19_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_19_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_19_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_19_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_19_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_19_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_20_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_20_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_20_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_20_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_20_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_20_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_20_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_20_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_20_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_21_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_21_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_21_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_21_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_21_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_21_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_21_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_21_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_21_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_22_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_22_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_22_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_22_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_22_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_22_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_22_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_22_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_22_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_23_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_23_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_23_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_23_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_23_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_23_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_23_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_23_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_23_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_24_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_24_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_24_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_24_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_24_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_24_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_24_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_24_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_24_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_25_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_25_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_25_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_25_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_25_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_25_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_25_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_25_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_25_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_26_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_26_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_26_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_26_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_26_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_26_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_26_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_26_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_26_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_27_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_27_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_27_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_27_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_27_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_27_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_27_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_27_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_27_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_28_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_28_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_28_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_28_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_28_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_28_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_28_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_28_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_28_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_29_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_29_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_29_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_29_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_29_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_29_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_29_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_29_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_29_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_30_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_30_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_30_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_30_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_30_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_30_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_30_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_30_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_30_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_h_31_norm_1_weight: "cuda:0 bf16[4096]"
# transformer_h_31_attn_attn_weight: "cuda:0 bf16[12288, 4096]"
# transformer_h_31_attn_kv_cache_k: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_31_attn_kv_cache_v: "cuda:0 bf16[1, 32, 8, 128]"
# transformer_h_31_attn_proj_weight: "cuda:0 bf16[4096, 4096]"
# transformer_h_31_norm_2_weight: "cuda:0 bf16[4096]"
# transformer_h_31_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_31_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]"
# transformer_h_31_mlp_proj_weight: "cuda:0 bf16[4096, 11008]"
# transformer_ln_f_weight: "cuda:0 bf16[4096]"
# lm_head_weight: "cuda:0 bf16[32000, 4096]"
# idx: "cuda:0 i32[1, 1]"
# input_pos: "cuda:0 i64[1]"
    # actual forward starts now ...
```
Having a mechanism to call a trace easily would be useful because it would allow debugging the trace directly.
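As a rough illustration only, one possible shape for such a mechanism: a helper that materializes the trace as a callable and fills the flattened parameter and buffer arguments from the module's own state, so that only the runtime inputs (here `idx` and `input_pos`) need to be passed explicitly. Everything below is a hypothetical sketch, not thunder's actual API: `call_trace` is invented, and the `python_callable()` call and the name-flattening rule are assumptions based on the trace shown above.

```python
import inspect

import thunder


def call_trace(jitted_module, module, *inputs):
    # Hypothetical helper, not part of thunder's API.
    trace = thunder.last_traces(jitted_module)[-1]

    # Assumption: the trace can be materialized as a plain Python callable.
    fn = trace.python_callable()

    # Trace argument names flatten the module hierarchy with "_", e.g.
    # transformer.h.0.norm_1.weight -> transformer_h_0_norm_1_weight.
    state = {
        name.replace(".", "_"): t
        for name, t in [*module.named_parameters(), *module.named_buffers()]
    }

    # Bind each trace argument from the module state when possible,
    # and from the user-supplied runtime inputs otherwise.
    bound = []
    remaining = list(inputs)
    for param in inspect.signature(fn).parameters.values():
        if param.name in state:
            bound.append(state[param.name])
        elif remaining:
            bound.append(remaining.pop(0))
        else:
            bound.append(param.default)
    return fn(*bound)


# Usage: only the runtime inputs are passed explicitly.
# logits = call_trace(jitted_model, model, idx, input_pos)
```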
cc @carmocca