ao
ao copied to clipboard
add rowwise-scaling to fp8 inference
Stacked PRs:
- #709
- -> #707 (this PR)
- #706
add rowwise-scaling to fp8 inference
I used the following script to exercise the change:
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
from torchao.float8.inference import quantize_to_float8, ActivationCasting, QuantConfig, ScalingGranularity
from torchao.float8.float8_utils import compute_error
from transformer_nuggets.utils import benchmark_cuda_function_in_microseconds, profiler
from pathlib import Path
from tqdm import tqdm
from tabulate import tabulate
# Module-level benchmark settings, applied before any model is compiled.
# NOTE(review): presumably disabling automatic dynamic shapes makes each new
# input size compile as its own static-shape graph — confirm against the
# torch._dynamo config documentation.
torch._dynamo.config.automatic_dynamic_shapes = False
# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000
# Backend handed to torch.compile by setup_model; when this is None,
# setup_model returns the model uncompiled instead.
compile_backend = "inductor"
# Uncomment to benchmark without torch.compile:
# compile_backend = None
class FeedForward(nn.Module):
    """SiLU-gated feed-forward block built from three bias-free linear layers.

    Computes ``w2(silu(w1(x)) * w3(x))`` with a model width of 4096 and a
    hidden width of 14336.
    """

    def __init__(self) -> None:
        super().__init__()
        model_dim, hidden_dim = 4096, 14336
        # Creation order (w1, w3, w2) is kept so parameter registration and
        # random initialisation consume the RNG in the same sequence.
        self.w1 = nn.Linear(model_dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(model_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, model_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated projection to ``x`` (last dimension must be 4096)."""
        gate = F.silu(self.w1(x))
        up_projection = self.w3(x)
        return self.w2(gate * up_projection)
def setup_model(original_model, quant_config):
    """Return a benchmark-ready copy of ``original_model``.

    The input model is deep-copied so it is never modified. When
    ``quant_config`` is truthy the copy is passed to ``quantize_to_float8``;
    its return value is ignored, so this relies on the conversion happening
    in place. The copy is then wrapped with ``torch.compile`` using the
    module-level ``compile_backend``, unless that is ``None``, in which case
    the uncompiled copy is returned.
    """
    candidate = copy.deepcopy(original_model)
    if quant_config:
        quantize_to_float8(candidate, quant_config)

    if compile_backend is not None:
        return torch.compile(candidate, backend=compile_backend)
    return candidate
def run_benchmark(model, input_tensor, name, num_warmup=10, profile=False):
    """Warm up, optionally profile, then time ``model`` on ``input_tensor``.

    Args:
        model: Callable module to benchmark.
        input_tensor: Single positional argument passed to ``model``.
        name: Label used as the final path component of the profiler trace.
        num_warmup: Number of untimed forward passes run first.
        profile: When true, one extra forward pass runs under ``profiler``.

    Returns:
        A ``(time, output)`` tuple: the value reported by
        ``benchmark_cuda_function_in_microseconds`` and the result of one
        additional forward pass.
    """
    # Everything runs without autograd tracking.
    with torch.no_grad():
        for _ in range(num_warmup):
            model(input_tensor)

        if profile:
            trace_path = Path(f"/home/drisspg/meta/scripts/fp8/data/{name}")
            with profiler(trace_path):
                model(input_tensor)

        elapsed_us = benchmark_cuda_function_in_microseconds(model, input_tensor)
        result = model(input_tensor)

    return elapsed_us, result
def run_sweep(original_mlp, variants, input_sizes, profile=False):
    """Benchmark every variant at every input size and compare against BF16.

    Args:
        original_mlp: Reference model; each variant is built from a copy of it
            via ``setup_model``.
        variants: Iterable of ``(name, quant_config)`` pairs. One entry must be
            named ``"BF16"``; it is the baseline and may appear at any
            position. The config of the entry named ``"FP8_Static_AxisWise"``
            is ignored and rebuilt for each input size.
        input_sizes: Iterable of ``(batch_size, num_tokens)`` pairs.
        profile: Forwarded to ``run_benchmark``.

    Returns:
        A list of ``(batch_size, num_tokens, rows)`` tuples, where each row is
        ``[name, time, speedup_vs_bf16, error_vs_bf16]`` formatted as strings.

    Raises:
        KeyError: If no variant is named ``"BF16"``.
    """
    results = []
    for batch_size, num_tokens in tqdm(input_sizes):
        input_tensor = torch.rand(
            batch_size, num_tokens, 4096, device="cuda", dtype=torch.bfloat16
        ) * 5

        # Both dicts are keyed by variant name so the baseline lookup and the
        # row construction below cannot drift out of alignment. If a name is
        # repeated, the last measurement for it wins.
        timings = {}
        outputs = {}
        for name, quant_config in variants:
            if name == "FP8_Static_AxisWise":
                # The static axis-wise scale has one entry per flattened
                # input row, so it has to be rebuilt for every input size.
                quant_config = QuantConfig(
                    ActivationCasting.STATIC,
                    torch.full(
                        (num_tokens * batch_size, 1),
                        1.0,
                        device="cuda",
                        dtype=torch.float32,
                    ),
                    scaling_granularity=ScalingGranularity.AXIS_WISE,
                )
            model = setup_model(original_mlp, quant_config)
            elapsed, output = run_benchmark(
                model,
                input_tensor,
                f"{name}_{batch_size}_{num_tokens}",
                profile=profile,
            )
            timings[name] = elapsed
            outputs[name] = output

        # Look the baseline up by name rather than assuming it ran first, and
        # compute the speedup from the raw timings instead of from values
        # already rounded to two decimals for display.
        bf16_output = outputs["BF16"]
        bf16_time = timings["BF16"]
        comparison_results = [
            [
                name,
                f"{timings[name]:.2f}",
                f"{bf16_time / timings[name]:.2f}x",
                f"{compute_error(outputs[name], bf16_output):.6e}",
            ]
            for name in outputs
        ]
        results.append((batch_size, num_tokens, comparison_results))
    return results
if __name__ == "__main__":
    # Script entry point: build the reference MLP, sweep every variant over
    # every input size, then print one table per input size.
    profile = False
    original_mlp = FeedForward().to("cuda").to(torch.bfloat16)

    # Fixed scale used by the static tensor-wise variant.
    static_tensorwise_scale = torch.tensor(
        [1.0], device="cuda", dtype=torch.float32
    )

    variants = [
        ("BF16", None),
        (
            "FP8_Dynamic_TensorWise",
            QuantConfig(
                ActivationCasting.DYNAMIC,
                scaling_granularity=ScalingGranularity.TENSOR_WISE,
            ),
        ),
        (
            "FP8_Static_TensorWise",
            QuantConfig(
                ActivationCasting.STATIC,
                static_tensorwise_scale,
                scaling_granularity=ScalingGranularity.TENSOR_WISE,
            ),
        ),
        (
            "FP8_Weight_Only_TensorWise",
            QuantConfig(
                ActivationCasting.WEIGHT_ONLY,
                scaling_granularity=ScalingGranularity.TENSOR_WISE,
            ),
        ),
        (
            "FP8_Dynamic_AxisWise",
            QuantConfig(
                ActivationCasting.DYNAMIC,
                scaling_granularity=ScalingGranularity.AXIS_WISE,
            ),
        ),
        # Placeholder config: run_sweep builds the real one per input size.
        ("FP8_Static_AxisWise", None),
        (
            "FP8_Weight_Only_AxisWise",
            QuantConfig(
                ActivationCasting.WEIGHT_ONLY,
                scaling_granularity=ScalingGranularity.AXIS_WISE,
            ),
        ),
    ]

    # (batch_size, num_tokens) pairs to benchmark.
    input_sizes = [
        (1, 128),
        (1, 1024),
        (32, 128),
        (32, 1024),
        (64, 2048),
    ]

    sweep_results = run_sweep(
        original_mlp, variants, input_sizes, profile=profile
    )

    # Print results
    headers = ["Variant", "Time (μs)", "Speedup vs BF16", "SQNR vs BF16"]
    for batch_size, num_tokens, comparison_results in sweep_results:
        print(f"\nResults for batch_size={batch_size}, num_tokens={num_tokens}:")
        table = tabulate(comparison_results, headers=headers, tablefmt="grid")
        print(table)
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/707
- :page_facing_up: Preview Python docs built from this PR
Note: Links to docs will display an error until the docs builds have been completed.
:white_check_mark: No Failures
As of commit b569e7a037c9cf113b76be0ff165db309426d54e with merge base ac8ce4ceb548c264a2f508f7b6a1f413d8d4454c:
:green_heart: Looks good so far! There are no failures yet. :green_heart:
This comment was automatically generated by Dr. CI and updates every 15 minutes.