
add rowwise-scaling to fp8 inference

Open drisspg opened this pull request 1 year ago • 1 comment

Stacked PRs:

  • #709
  • -> #707 (this PR)
  • #706


Using this script to exercise the new rowwise (axiswise) scaling path:

import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
from torchao.float8.inference import quantize_to_float8, ActivationCasting, QuantConfig, ScalingGranularity
from torchao.float8.float8_utils import compute_error
from transformer_nuggets.utils import benchmark_cuda_function_in_microseconds, profiler
from pathlib import Path
from tqdm import tqdm
from tabulate import tabulate

torch._dynamo.config.automatic_dynamic_shapes = False
# With automatic dynamic shapes disabled, each new input shape triggers a
# recompile, so raise the cache limit to cover the whole sweep
torch._dynamo.config.cache_size_limit = 1000
compile_backend = "inductor"
# compile_backend = None

class FeedForward(nn.Module):
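    """SwiGLU MLP: w2(silu(w1(x)) * w3(x)), 4096 -> 14336 -> 4096 (Llama-3-8B/Mistral-7B sized)."""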
    def __init__(self) -> None:
        super().__init__()
        self.w1 = nn.Linear(4096, 14336, bias=False)
        self.w3 = nn.Linear(4096, 14336, bias=False)
        self.w2 = nn.Linear(14336, 4096, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

def setup_model(original_model, quant_config):
    model = copy.deepcopy(original_model)
    if quant_config:
        quantize_to_float8(model, quant_config)
    if compile_backend is None:
        return model
    else:
        return torch.compile(model, backend=compile_backend)

def run_benchmark(model, input_tensor, name, num_warmup=10, profile=False):
    with torch.no_grad():
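        # Warmup calls trigger torch.compile compilation before anything is timed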
        for _ in range(num_warmup):
            model(input_tensor)
        
        if profile:
            with profiler(Path(f"/home/drisspg/meta/scripts/fp8/data/{name}")):
                model(input_tensor)
        
        time = benchmark_cuda_function_in_microseconds(model, input_tensor)
        output = model(input_tensor)
    return time, output

def run_sweep(original_mlp, variants, input_sizes, profile=False):
    results = []
    for batch_size, num_tokens in tqdm(input_sizes):
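        # torch.rand draws from [0, 1); the *5 widens the activation dynamic range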
        input_tensor = torch.rand(batch_size, num_tokens, 4096, device="cuda", dtype=torch.bfloat16) * 5

        variant_results = []
        outputs = {}
        for name, quant_config in variants:
            if name == "FP8_Static_AxisWise":
                # Update the static quantization scale for AxisWise: one scale per row
                quant_config = QuantConfig(
                    ActivationCasting.STATIC,
                    torch.full((num_tokens * batch_size, 1), 1.0, device="cuda", dtype=torch.float32),
                    scaling_granularity=ScalingGranularity.AXIS_WISE,
                )
            
            model = setup_model(original_mlp, quant_config)
            time, output = run_benchmark(model, input_tensor, f"{name}_{batch_size}_{num_tokens}", profile=profile)
            variant_results.append([name, f"{time:.2f}"])
            outputs[name] = output

        bf16_output = outputs["BF16"]
        bf16_time = float(variant_results[0][1])  # Assuming BF16 is the first variant
        
        comparison_results = [
            [name, row[1], f"{bf16_time / float(row[1]):.2f}x", f"{compute_error(variant_output, bf16_output):.6e}"]
            for row, (name, variant_output) in zip(variant_results, outputs.items())
        ]
        
        results.append((batch_size, num_tokens, comparison_results))
    
    return results

if __name__ == "__main__":
    profile = False
    original_mlp = FeedForward().to("cuda").to(torch.bfloat16)

    variants = [
        ("BF16", None),
        ("FP8_Dynamic_TensorWise", QuantConfig(ActivationCasting.DYNAMIC, scaling_granularity=ScalingGranularity.TENSOR_WISE)),
        ("FP8_Static_TensorWise", QuantConfig(ActivationCasting.STATIC, torch.tensor([1.0], device="cuda", dtype=torch.float32), scaling_granularity=ScalingGranularity.TENSOR_WISE)),
        ("FP8_Weight_Only_TensorWise", QuantConfig(ActivationCasting.WEIGHT_ONLY, scaling_granularity=ScalingGranularity.TENSOR_WISE)),
        ("FP8_Dynamic_AxisWise", QuantConfig(ActivationCasting.DYNAMIC, scaling_granularity=ScalingGranularity.AXIS_WISE)),
        ("FP8_Static_AxisWise", None),  # Placeholder; the per-row scale tensor is built per input size in run_sweep
        ("FP8_Weight_Only_AxisWise", QuantConfig(ActivationCasting.WEIGHT_ONLY, scaling_granularity=ScalingGranularity.AXIS_WISE)),
    ]

    input_sizes = [
        (1, 128),
        (1, 1024),
        (32, 128),
        (32, 1024),
        (64, 2048),
    ]

    sweep_results = run_sweep(original_mlp, variants, input_sizes, profile=profile)

    # Print results
    headers = ["Variant", "Time (μs)", "Speedup vs BF16", "SQNR vs BF16"]
    for batch_size, num_tokens, comparison_results in sweep_results:
        print(f"\nResults for batch_size={batch_size}, num_tokens={num_tokens}:")
        print(tabulate(comparison_results, headers=headers, tablefmt="grid"))
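
For context, a minimal self-contained sketch of what the two scaling granularities compute for an activation tensor. This is illustrative, not the torchao implementation: the function names and the 1e-12 amax clamp are placeholders, assuming only PyTorch's torch.float8_e4m3fn dtype.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quantize_tensorwise(x: torch.Tensor):
    # TENSOR_WISE: a single scale derived from the global amax
    scale = FP8_MAX / x.abs().amax().float().clamp(min=1e-12)
    x_fp8 = (x.float() * scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale.reciprocal()

def quantize_rowwise(x: torch.Tensor):
    # AXIS_WISE: one scale per row (per token), carried as a (rows, 1) tensor
    amax = x.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12)
    scale = FP8_MAX / amax
    x_fp8 = (x.float() * scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale.reciprocal()

x = torch.randn(128, 4096, device="cuda", dtype=torch.bfloat16) * 5
for fn in (quantize_tensorwise, quantize_rowwise):
    x_fp8, dequant_scale = fn(x)
    x_hat = x_fp8.float() * dequant_scale
    # SQNR in the same spirit as compute_error: 20 * log10(signal norm / noise norm)
    sqnr = 20 * torch.log10(x.float().norm() / (x.float() - x_hat).norm())
    print(f"{fn.__name__}: SQNR vs original = {sqnr.item():.2f} dB")

The intuition: with one tensor-wide scale, a single outlier token shrinks the representable range for every other token, while per-row scales isolate that outlier. So the AxisWise variants should hold a higher SQNR than TensorWise at similar speed, at the cost of carrying a (tokens, 1) scale tensor through the matmul.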

— drisspg, Aug 19 '24 21:08

Helpful Links

See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/707

Note: Links to docs will display an error until the docs builds have been completed.

No Failures

As of commit b569e7a037c9cf113b76be0ff165db309426d54e with merge base ac8ce4ceb548c264a2f508f7b6a1f413d8d4454c: Looks good so far! There are no failures yet.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

— pytorch-bot[bot], Aug 19 '24 21:08