
Collect popular models' `GraphModule`s containing distributed collective communication operators

crcrpar opened this issue on Oct 03 '25 · 5 comments

I applied ThunderFX in sglang.srt.model_executor.cuda_graph_runner.patch_model instead of torch.compile and collected the compiled GraphModules for Qwen/Qwen3-30B-A3B-Instruct-2507. See outputs and details HERE.
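
For context, the swap looks roughly like this. This is a sketch only: the toy model stands in for the actual sglang model, and patch_model's wiring is simplified away.

  import torch
  import torch.nn as nn
  from thunder.dynamo import ThunderCompiler

  model = nn.Linear(8, 8)  # stand-in for the real sglang model
  backend = ThunderCompiler()

  # patch_model would normally call torch.compile(model); swap in the Thunder backend:
  compiled_model = torch.compile(model, backend=backend)
  compiled_model(torch.randn(2, 8))

  # The split GraphModules collected for this issue live on the backend object
  # (attribute name as of recent lightning-thunder versions).
  print(backend.subgraph_infos)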

The ops that need support are:

  1. t0.is_cpu
  2. torch.ops.sglang.inplace_all_reduce(t0, group_name = 'tp:0')
  3. torch.cuda.get_device_capability(0)
  4. torch.ops.sglang.reg_all_gather_into_tensor(t0, t1, group_name = 'tp:0')

I will do this on other models and collect unsupported ops.
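
As a rough illustration, covering an op like torch.ops.sglang.inplace_all_reduce could go through a Thunder OperatorExecutor. This is a minimal sketch, not a committed design; the executor name and the meta function's assumption that the op preserves shape, dtype, and device are mine.

  import torch
  from thunder.core.proxies import TensorProxy
  from thunder.extend import OperatorExecutor, register_executor

  sglang_ex = OperatorExecutor("sglang_collectives", version="0.1")
  register_executor(sglang_ex)

  def _inplace_all_reduce_meta(t, group_name):
      # Assumption: an in-place all-reduce preserves shape, dtype, and device.
      return TensorProxy(like=t)

  def _inplace_all_reduce_impl(t, group_name):
      torch.ops.sglang.inplace_all_reduce(t, group_name=group_name)
      return t

  inplace_all_reduce = sglang_ex.register_operator(
      "sglang_inplace_all_reduce",
      like=_inplace_all_reduce_meta,
      fn=_inplace_all_reduce_impl,
  )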

shino16 · Oct 15 '25

I added results on some HF models to my fork.

I did not find any new unsupported ops. For now, we can assume all we need to cover is the following (a note on the two host-side queries follows the model list):

  1. t0.is_cpu
  2. torch.ops.sglang.inplace_all_reduce(t0, group_name = 'tp:0')
  3. torch.cuda.get_device_capability(0)
  4. torch.ops.sglang.reg_all_gather_into_tensor(t0, t1, group_name = 'tp:0')

Note: I tested on

  • Qwen/Qwen3-30B-A3B-Instruct-2507
  • Qwen/Qwen3-0.6B
  • openai/gpt-oss-20b
  • mistralai/Magistral-Small-2509
  • mistralai/Ministral-8B-Instruct-2410
  • mistralai/Mistral-Large-Instruct-2411
  • mistralai/Mixtral-8x7B-Instruct-v0.1
  • microsoft/Phi-4-mini-reasoning
  • huihui-ai/Qwen2.5-32B-Instruct-abliterated
  • unsloth/Qwen3-30B-A3B-GGUF
  • unsloth/gpt-oss-20b-BF16
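
Side note on the ops list above: t0.is_cpu and torch.cuda.get_device_capability(0) are host-side queries that return plain Python values, so one plausible option (an assumption, not a decided approach) is to fold them into constants at trace time:

  import torch

  t0 = torch.randn(2, device="cuda")
  print(t0.is_cpu)                            # False: a plain Python bool, fixed once the device is known
  print(torch.cuda.get_device_capability(0))  # e.g. (9, 0) on H100; also fixed per device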

shino16 · Oct 16 '25

Can you check Llama 4, DeepSeek V3.1, and Qwen3 Next? I understand the models are quite heavy, though.

crcrpar · Oct 17 '25

I couldn't run meta-llama/Llama-4-Maverick-17B-128E-Instruct and meta-llama/Llama-4-Scout-17B-16E-Instruct. They gave:

  File "/usr/local/lib/python3.12/dist-packages/triton_kernels/numerics_details/flexpoint.py", line 55, in <module>
    @tl.constexpr_function
     ^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'triton.language' has no attribute 'constexpr_function'

And I only got partial results from Qwen3 Next models:

  File "/opt/sglang/sglang-src/python/sglang/srt/model_executor/model_runner.py", line 1917, in init_device_graphs
    self.graph_runner = graph_runners[self.device](self)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/sglang/sglang-src/python/sglang/srt/model_executor/cuda_graph_runner.py", line 421, in __init__
    raise Exception(
Exception: Capture cuda graph failed: backend='<thunder.dynamo.compiler.ThunderCompiler object at 0xec7f602bb9e0>' raised:
AssertionError: Failed to find method wait_stream

See my fork for the updated results.

Now we need to support the following (a background sketch of the triton higher-order op follows these lists):

  1. torch.ops.higher_order.triton_kernel_wrapper_mutation(kernel_idx = 6, constant_args_idx = 5, grid = [(56, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'y_ptr': input_2d, 'y_q_ptr': x_q, 'y_s_ptr': x_s})
  2. l_a_.new_empty((1, 2112), dtype = torch.bfloat16)
  3. torch.ops.higher_order.triton_kernel_wrapper_mutation(kernel_idx = 8, constant_args_idx = 7, grid = [(17, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'A': l_a_, 'B': l_b_, 'C': l_c_, 'As': l_as_, 'Bs': l_bs_})
  4. torch.ops.higher_order.triton_kernel_wrapper_mutation(kernel_idx = 15, constant_args_idx = 23, grid = [(1, 5, 1)], tma_descriptor_metadata = {}, kwargs = {'kv_buffer_ptr': l_forward_batch_token_to_kv_pool_kv_buffer_0_, 'cache_k_nope_ptr': l_k_, 'cache_k_rope_ptr': l_k_rope_, 'loc_ptr': l_forward_batch_out_cache_loc})
  5. torch.bmm(transpose, l_self_w_vc, out = transpose_1)
  6. torch.ops.sgl_kernel.dsv3_router_gemm(output, l_hidden_states_, l_self_modules_gate_parameters_weight_)
  7. torch.ops.sglang.inplace_fused_experts(l_hidden_states_, l_self_modules_experts_parameters_w13_weight_, l_self_modules_experts_parameters_w2_weight_, l_stack0_topk_weights, l_stack0_topk_ids, None, None, 'silu', False, True, False, False, False, False, l_self_modules_experts_parameters_w13_weight_scale_inv_, l_self_modules_experts_parameters_w2_weight_scale_inv_, None, None, None, None, [128, 128], 2.5, None, None)

for DeepSeek v3.1, and

  1. torch.ops.higher_order.triton_kernel_wrapper_mutation(kernel_idx = 4, constant_args_idx = 10, grid = [(129, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'GatherIndx': topk_indx, 'ScatterIndx': gate_indx, 'GateScal': gate_scal, 'ExptScal': l_stack0_0_, 'ExptIndx': l_stack0_1_, 'PartialOffs': out_partials_2, 'TokensStart': expt_offs, 'Hist': hist, 'MDTileStarts': token_offs_pad_1, 'MDTileInfo': block_pid_map_1})
  2. torch.ops.sgl_kernel.rmsnorm.default(out, l_hidden_states_, _get_data_attr, 1e-05, True)
  3. torch._C._autograd._get_data_attr(l_self_layer_communicator_input_layernorm_parameters_weight_)

for openai/gpt-oss-20b and unsloth/gpt-oss-20b-BF16, and

  1. torch.ops.sglang.flashinfer_allreduce_residual_rmsnorm(input_tensor = l_stack0_, residual = l_residual_, weight = l_self_layer_communicator_post_attention_layernorm_parameters_weight_, eps = 1e-05)
  2. transpose[slice(None, 1, None)] (indexing)
  3. torch.empty((1, 32), dtype = torch.uint32, device = device(type='cuda', index=0))
  4. torch.transpose(bitmatrix, 0, 1)

specifically for unsloth/gpt-oss-20b-BF16.
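
Background on the torch.ops.higher_order.triton_kernel_wrapper_mutation entries: Dynamo wraps user-defined Triton kernels in this higher-order op, so any toy kernel launched under torch.compile reproduces such a node. A minimal sketch (the kernel contents are illustrative only):

  import torch
  import triton
  import triton.language as tl

  @triton.jit
  def add_one_kernel(x_ptr, n, BLOCK: tl.constexpr):
      offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
      mask = offs < n
      tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) + 1, mask=mask)

  def f(x):
      n = x.numel()
      add_one_kernel[(triton.cdiv(n, 1024),)](x, n, BLOCK=1024)
      return x

  # The resulting FX graph contains a triton_kernel_wrapper_mutation node wrapping the kernel.
  torch.compile(f)(torch.randn(4096, device="cuda"))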

shino16 · Oct 21 '25

Update on Llama-4: the error comes from a version mismatch between triton and triton_kernels. Installing triton==3.4.0 fixes this, but the triton==3.4.0 wheels on PyPI are available only for x86_64 platforms (a quick probe for the missing attribute is sketched after this list). Now I need either

  1. a node with x86_64 CPU and strong GPUs, or
  2. a Docker container with a correct installation of triton and triton_kernels.
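
A quick probe for the missing attribute (a sanity check, not a fix):

  import triton
  import triton.language as tl

  print(triton.__version__)
  # triton_kernels assumes this decorator exists; on the mismatched install it does
  # not, which produces the AttributeError above.
  print(hasattr(tl, "constexpr_function"))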

Update on Qwen3 Next: the error can be fixed by the workaround in https://github.com/Lightning-AI/lightning-thunder/issues/2332, but the models still cannot be compiled even with that workaround, because Dynamo does not support receiving a Stream as a model parameter. See the error and the GraphModule being compiled before splitting (note how current_stream_2 is passed around).
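
For reference, a minimal sketch of the reported limitation, assuming it reproduces as described (exact behavior may vary across PyTorch versions):

  import torch

  def f(x, stream):
      with torch.cuda.stream(stream):
          return x * 2

  s = torch.cuda.Stream()
  x = torch.randn(4, device="cuda")
  # The stream becomes a graph input here, which is where compilation of the
  # Qwen3 Next GraphModules fails.
  torch.compile(f)(x, s)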

shino16 · Oct 21 '25