
Tensor Subclass + VLLM Compile

Open drisspg opened this issue 5 months ago • 0 comments

VLLM Torch.compile Issue Tracker

Summary

This issue tracks a problem with the way VLLM uses torch.compile together with tensor subclasses.

TLDR: VLLM doesn't set up AOTDispatch correctly, so tensor-subclass flattening never takes place.

@zou3519 has, in theory, fixed this with https://github.com/vllm-project/vllm/pull/17057, which enables standalone compile.
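
For context, "subclass flattening" refers to the traceable-subclass contract that AOTDispatch relies on: before Inductor compiles a graph, each subclass has to be desugared into its plain inner tensors via __tensor_flatten__ / __tensor_unflatten__. A minimal sketch of that contract (a toy ScaledTensor, not torchao's actual subclasses):

import torch
from torch.utils._pytree import tree_map


class ScaledTensor(torch.Tensor):
    """Toy quantized-style wrapper: a plain data tensor plus a scale tensor."""

    @staticmethod
    def __new__(cls, data, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data, scale):
        self._data = data
        self._scale = scale

    def __tensor_flatten__(self):
        # Names of the inner tensor attributes, plus any non-tensor context.
        return ["_data", "_scale"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        return ScaledTensor(inner_tensors["_data"], inner_tensors["_scale"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap to the plain inner tensor and run the op
        # (scale propagation elided for brevity).
        unwrap = lambda t: t._data if isinstance(t, ScaledTensor) else t
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))


x = ScaledTensor(torch.randn(4, 4), torch.tensor(1.0))
compiled = torch.compile(lambda t: t + t)
out = compiled(x)  # with AOTDispatch set up, the subclass is flattened here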

Testing Results

MXFP4 Model Testing

Command:

python vllm/sample_output.py --model_name "data/mxfp4-Qwen2-7B-Instruct" --compile True

Issue: The swizzle kernel needs to be enabled: without it, the dynamic control flow in to_blocked errors out against the way VLLM bakes out separate graphs per shape.
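
For reference, a sketch of the kind of shape-dependent Python branch in to_blocked that specializes per shape (block sizes assumed; this is not the actual torchao implementation):

import torch
import torch.nn.functional as F


def to_blocked_sketch(scales: torch.Tensor) -> torch.Tensor:
    # Round each dim up to the block size the swizzled layout expects.
    rows, cols = scales.shape
    padded_rows = ((rows + 127) // 128) * 128
    padded_cols = ((cols + 3) // 4) * 4
    if rows != padded_rows or cols != padded_cols:
        # Shape-dependent Python control flow: under torch.compile each
        # concrete shape traces a different branch, so baking out graphs
        # for fixed shapes ahead of time breaks down here.
        scales = F.pad(scales, (0, padded_cols - cols, 0, padded_rows - rows))
    return scales.reshape(padded_rows // 128, 128, padded_cols // 4, 4)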

Error encountered:

torch._inductor.exc.InductorError: RuntimeError: Failed to import /tmp/torchinductor_drisspg/zw/czwhsx3gfghqtfk5wbhit2p2xpgr3yxil54pbsi2mtnirmrnrths.py
IndentationError: unexpected indent (czwhsx3gfghqtfk5wbhit2p2xpgr3yxil54pbsi2mtnirmrnrths.py, line 173)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Note: I suspect Inductor isn't inlining the user-defined Triton kernel correctly. cc @oulgen if you have any ideas here
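
A minimal standalone version of that path, in case it helps reproduce outside VLLM (kernel and sizes are illustrative, not the failing kernel):

import torch
import triton
import triton.language as tl


@triton.jit
def add_one_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask) + 1.0, mask=mask)


def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    add_one_kernel[grid](x, out, x.numel(), BLOCK=1024)
    return out


@torch.compile
def f(x):
    # Inductor inlines the user-defined @triton.jit kernel into the generated
    # /tmp/torchinductor_*/... output file; a bad inline would show up there
    # as an IndentationError like the one above.
    return add_one(x) * 2


f(torch.randn(4096, device="cuda"))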

FP8 Model Testing

Command:

python vllm/sample_output.py --model_name "data/fp8-Qwen2-7B-Instruct" --compile True

Results: ✅ Working

First run compilation time:

INFO 05-21 22:29:11 [monitor.py:33] torch.compile takes 136.26 s in total

Sample outputs:

  • Prompt: 'Why is Pytorch 2.0 the best machine learning compiler?'
    Generated: ' PyTorch 2.0, currently not released officially, is anticipated to be a significant upgrade in several areas including performance, features, and usability...'

  • Prompt: 'Hello, my name is'
    Generated: ' Mandy Lowry, and I am a certified financial planner and a long-time friend of the Girls' Town...'

  • Prompt: 'The president of the United States is'
    Generated: " the leader of the government of the United States. The president is also the commander-in-chief of the United States Armed Forces..."

  • Prompt: 'The capital of France is'
    Generated: '__.\nEdinburgh\nGeneva\nParis\nLondon\n答案:\n\nC...'

  • Prompt: 'The future of AI is'
    Generated: ' moving closer to reality with the launch of a revolutionary new AI-powered software platform called Hummingbird...'

Second run compilation time (cached):

INFO 05-21 22:30:32 [backends.py:134] Directly load the compiled graph(s) for shape None from the cache, took 8.395 s
INFO 05-21 22:30:42 [monitor.py:33] torch.compile takes 7.26 s in total

Test Script

import os
import random
import numpy as np
import torch

from vllm import LLM, SamplingParams
from rich import print


def set_seed(seed):
    """Set seeds for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def main(
    model_name: str = "Qwen/Qwen2-7B-Instruct",
    max_tokens: int = 64,
    tp_size: int = 1,
    compile: bool = True,
):
    # Set seed before creating the LLM
    set_seed(42)

    # Environment variables for VLLM configuration
    # os.environ["VLLM_TORCH_PROFILER_DIR"] = "data/flex_profile"  # Enable torch profiler
    os.environ["VLLM_USE_V1"] = "1"
    # os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION_VLLM_V1"
    # os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
    # os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
    os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"

    # Create sampling params
    sampling_params = SamplingParams(
        temperature=0.8, 
        top_p=0.95, 
        seed=42, 
        max_tokens=max_tokens
    )
    
    # Create LLM instance
    print(f"Using Model name: {model_name}")
    llm = LLM(
        model=model_name, 
        tensor_parallel_size=tp_size, 
        enforce_eager=not compile
    )
    
    # Test prompts
    prompts = [
        "Why is Pytorch 2.0 the best machine learning compiler?",
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    
    # Generate outputs
    outputs = llm.generate(prompts, sampling_params)

    # Print results
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    from jsonargparse import CLI
    CLI(main)
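
Note: jsonargparse's CLI exposes the keyword arguments of main as command-line flags, which is how the --model_name and --compile invocations above are parsed.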

drisspg · May 22 '25