convert_to_float8_training and torch.compile make model slow
NOTE: I retested the timings.
| Configuration | Training Time 1 (s) | Training Time 2 (s) | Training Time 3 (s) | Mean (s) |
|---|---|---|---|---|
| Without convert_to_float8_training and without torch.compile | 0.1294 | 0.1314 | 0.1303 | 0.1303 |
| With convert_to_float8_training and without torch.compile | 0.1839 | 0.1856 | 0.1827 | 0.1844 |
| Without convert_to_float8_training and with torch.compile | 0.9767 | 0.9438 | 1.0092 | 0.9766 |
| With convert_to_float8_training and with torch.compile | 1.0895 | 1.2346 | 1.2095 | 1.1779 |
My env: H100, torch 2.5.0+cu124, torchao 0.11.0

My test code (copied from https://github.com/pytorch/ao/tree/main/torchao/float8#float8-linear-with-dynamic-tensorwise-scaling):
```python
import time

import torch
import torch.nn as nn

from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# enable torch.compile for competitive performance
m = torch.compile(m)

for __ in range(1):
    torch.cuda.synchronize()  # ensure all previous GPU work is complete
    start_time = time.time()

    # toy training loop
    for _ in range(10):
        optimizer.zero_grad()
        y = m(x)
        y.sum().backward()
        optimizer.step()

    torch.cuda.synchronize()  # ensure all previous GPU work is complete
    end_time = time.time()
    print("Training time:", end_time - start_time)
```
These shapes are too small to benefit from float8 training. The benefit of using FP8 tensor cores is outweighed by the dynamic quantization overhead in this regime. Check out this table in the docs to determine whether the GEMMs in your model will benefit from dynamic float8 quantization or not.

Try adjusting your script to use bigger shapes and you should see a performance improvement.
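If it helps, here is a minimal sketch (not from this thread; the shapes, iteration counts, and helper name are illustrative) for timing a single compiled linear layer in bf16 vs. after `convert_to_float8_training`, to check whether a given GEMM shape benefits on your hardware:

```python
import time

import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training


def bench_linear(M, K, N, use_fp8, warmup=10, iters=50):
    """Average forward + backward time for one Linear(K, N) on an (M, K) input."""
    m = nn.Sequential(nn.Linear(K, N, bias=False)).bfloat16().cuda()
    if use_fp8:
        convert_to_float8_training(m)
    m = torch.compile(m)
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    for _ in range(warmup):  # warmup also triggers torch.compile compilation
        m(x).sum().backward()
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        m(x).sum().backward()
    torch.cuda.synchronize()
    return (time.time() - t0) / iters


for M, K, N in [(4096, 2048, 4096), (16384, 8192, 8192)]:
    bf16_t = bench_linear(M, K, N, use_fp8=False)
    fp8_t = bench_linear(M, K, N, use_fp8=True)
    print(f"M={M} K={K} N={N}: bf16 {bf16_t*1e3:.2f} ms, "
          f"fp8 {fp8_t*1e3:.2f} ms, speedup {bf16_t/fp8_t:.2f}x")
```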
Hello @danielvegamyhre, I also tested a bigger linear shape and model size:
```python
# create model and sample input
m = nn.Sequential(
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 128),
).bfloat16().cuda()
x = torch.randn(16384, 16384, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
```
| Configuration | Training Time 1 (s) | Training Time 2 (s) | Training Time 3 (s) | Mean (s) |
|---|---|---|---|---|
| Without convert_to_float8_training and without torch.compile | 6.6669 | 6.6856 | 6.6797 | 6.6774 |
| With convert_to_float8_training and without torch.compile | 8.1664 | 8.1776 | 8.1725 | 8.1722 |
| Without convert_to_float8_training and with torch.compile | 7.9422 | 7.9149 | 7.9176 | 7.9249 |
| With convert_to_float8_training and with torch.compile | 7.7129 | 7.7049 | 7.6724 | 7.6967 |
With fp8 training and torch.compile, training is only slightly faster than without fp8 training (still with torch.compile). According to the table, it should show a 1.80x speedup.
However, the fastest time is observed without fp8 training and without torch.compile.
Try filtering out the last linear `nn.Linear(16384, 128)` in your `module_filter_fn`; it has such a small N dim that it will have a substantial slowdown with float8, which is probably tanking the overall net speedup from float8.
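For example, a hypothetical variant of the `module_filter_fn` from the snippets above that skips narrow output projections like the final `nn.Linear(16384, 128)` (the `1024` threshold is just illustrative):

```python
import torch


def module_filter_fn(mod: torch.nn.Module, fqn: str):
    if isinstance(mod, torch.nn.Linear):
        # skip linears with dims not divisible by 16 (required for float8 GEMMs)
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
        # skip very narrow output projections (e.g. the 16384 -> 128 head),
        # which tend to be slower in float8 than in bf16
        if mod.out_features < 1024:
            return False
    return True
```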
I am facing a similar issue with fsdp2 enabled:
```python
m = nn.Sequential(
    nn.Linear(4096, 4096 * 3, bias=False),
    nn.Linear(4096 * 3, 4096, bias=False),
).to(device=device, dtype=torch.bfloat16)
x = torch.randn(32000, 4096, device="cuda", dtype=torch.bfloat16)
```
With FP8:

```
[GPU-0] Training Time: 18.083s
[GPU-0] Avg. Iter. Time: 0.181s
[GPU-0] Peak Memory Use: 2294.3MBs
[GPU-0] FP8 enabled: True
```

Without FP8:

```
[GPU-0] Training Time: 16.955s
[GPU-0] Avg. Iter. Time: 0.170s
[GPU-0] Peak Memory Use: 2753.6MBs
[GPU-0] FP8 enabled: False
```
Notice that even the memory savings are not substantial with FP8.
```python
import os
import argparse

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

from torchao.float8 import convert_to_float8_training, Float8MMConfig, Float8LinearConfig
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

# FSDP setup
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))


def dist_print(text):
    if LOCAL_RANK == 0:
        print(f"[GPU-{LOCAL_RANK}] " + text)


def parse_args():
    parser = argparse.ArgumentParser(description="TorchAO Float8 Training with FSDP2")
    parser.add_argument(
        "--fp8",
        action="store_true",
        default=False,
        help="Enable FP8 training using torchao.float8.convert_to_float8_training",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32000,
        help="Batch size for training",
    )
    parser.add_argument(
        "--num-iters",
        type=int,
        default=10,
        help="Number of training iterations",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.1,
        help="Learning rate",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    # Initialize distributed training
    if WORLD_SIZE > 1:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(LOCAL_RANK)
        dist_print(f"WORLD_SIZE = {WORLD_SIZE}")

    # Set seed for reproducibility
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)

    # create model on cuda device directly (FSDP2 doesn't require meta device)
    device = "cuda"
    m = nn.Sequential(
        nn.Linear(4096, 4096 * 3, bias=False),
        nn.Linear(4096 * 3, 4096, bias=False),
    ).to(device=device, dtype=torch.bfloat16)

    # Print memory usage before FSDP
    if WORLD_SIZE > 1:
        pre_mem_use = torch.cuda.memory_allocated(device=f"cuda:{LOCAL_RANK}") * 1e-6
        dist_print(f"Pre-FSDP memory use = {pre_mem_use}MiB")

    # optional: filter modules from being eligible for float8 conversion
    def module_filter_fn(mod: torch.nn.Module, fqn: str):
        # don't convert the last module
        # if fqn == "1":
        #     return False
        # don't convert linear modules with weight dimensions not divisible by 16
        if isinstance(mod, torch.nn.Linear):
            if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                return False
        return True

    # convert specified `torch.nn.Linear` modules to `Float8Linear` if fp8 is enabled
    if args.fp8:
        dist_print("Applying FP8 conversion to model...")
        config = Float8LinearConfig(
            enable_fsdp_float8_all_gather=True,
            force_recompute_fp8_weight_in_bwd=True,
        )
        convert_to_float8_training(m, module_filter_fn=module_filter_fn, config=config)
    else:
        dist_print("FP8 conversion disabled, using standard precision")

    # Apply FSDP2 if distributed
    if WORLD_SIZE > 1:
        # Create mixed precision policy for FSDP2
        mp_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
        )
        # Apply fully_shard to each layer for better control
        for i, layer in enumerate(m):
            fully_shard(layer, mp_policy=mp_policy)
        # Apply fully_shard to the entire model
        fully_shard(m, mp_policy=mp_policy)

        # Print memory usage after FSDP
        post_mem_use = torch.cuda.memory_allocated(device=f"cuda:{LOCAL_RANK}") * 1e-6
        dist_print(f"Post-FSDP2 memory use = {post_mem_use}MiB")
        dist_print(f"FSDP2-Wrapped Model:\n{m}")

    # enable torch.compile for competitive performance
    m = torch.compile(m)

    # create sample input
    x = torch.randn(args.batch_size, 4096, device="cuda", dtype=torch.bfloat16)

    # optimizer must be created after FSDP wrapping
    optimizer = torch.optim.SGD(m.parameters(), lr=args.lr)

    # Memory profiling setup
    torch.cuda.reset_peak_memory_stats()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    torch.cuda.synchronize()
    start.record()

    # warmup
    for i in range(20):
        optimizer.zero_grad()
        y = m(x)
        loss = y.sum()
        loss.backward()
        optimizer.step()

    # toy training loop
    for i in range(args.num_iters):
        optimizer.zero_grad()
        y = m(x)
        loss = y.sum()
        loss.backward()
        optimizer.step()
        if LOCAL_RANK == 0 and i % 5 == 0:
            print(f"Iteration {i}, Loss: {loss.item():.4f}")

    # Performance metrics
    end.record()
    torch.cuda.synchronize()

    peak_mem = torch.cuda.max_memory_allocated()
    train_time = start.elapsed_time(end) / 1000.0

    dist_print(f"Training Time: {train_time:.3f}s")
    dist_print(f"Avg. Iter. Time: {train_time / args.num_iters:.3f}s")
    dist_print(f"Peak Memory Use: {peak_mem * 1e-6:.1f}MBs")
    dist_print(f"FP8 enabled: {args.fp8}")

    # Cleanup
    if WORLD_SIZE > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
```
Env: CUDA 12.8, PyTorch 2.7.0+cu128, torchao 0.11.0, Hardware: H100
cc: @danielvegamyhre
@ajWithNucleus 2 things:

- Hmm, we should see improved throughput in your example. I'll try to repro with your code snippet and take a look at the trace. It might have to do with how your benchmarking code is set up. Also, it seems like you're doing fully_shard both on individual layers and on the full model afterwards; is that intentional?
- As far as peak memory goes, with this simple 2-linear-layer model there will be no memory savings from fp8. The memory savings come from torch.compile doing optimizations like CSE, cross-layer fusions, re-using buffers, etc., not from using fp8. Dynamic float8 quantization on its own will actually use some additional memory, but will increase throughput. This is because in training we use dynamic float8 quantization: the weights and activations stay in bf16, and we dynamically compute fp8 versions of each just to perform fp8 GEMMs, which are faster than bf16 GEMMs. The output of the fp8 GEMM is in bf16.
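To make that concrete, here is a minimal sketch (using only the `convert_to_float8_training` API from the snippets above, with an arbitrary single-layer model) showing that parameters stay in bf16 after conversion, i.e. the fp8 copies are transient per-GEMM tensors rather than a steady-state memory saving:

```python
import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training

m = nn.Sequential(nn.Linear(4096, 4096, bias=False)).bfloat16().cuda()
convert_to_float8_training(m)

for name, p in m.named_parameters():
    # expected: torch.bfloat16 -- fp8 tensors are computed on the fly for each GEMM
    print(name, p.dtype)
```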
cc @vkuzo as well
Got it, thanks a lot for the clarification, @danielvegamyhre. So if my understanding is correct, this is different from Transformer Engine's implementation, where activations might be stored in FP8?
> Try filtering out the last linear `nn.Linear(16384, 128)` in your `module_filter_fn`, which has such a small N dim that it will have a substantial slowdown with float8. That is probably tanking the overall net speedup from float8.
Hello @danielvegamyhre, I tried filtering out the last linear layer, but it didn't work as expected.
| Configuration | Training Time 1 (s) | Training Time 2 (s) | Training Time 3 (s) |
|---|---|---|---|
| Without convert_to_float8_training and without torch.compile | 6.6573 | 6.6762 | 6.6804 |
| With convert_to_float8_training and with torch.compile | 7.5983 | 7.5895 | 7.6059 |
Could you please share your test results for convert_to_float8_training with torch.compile?
I want to know whether this is a bug or an issue with my environment.
@Yzx835, do you see better results if you use linears without bias in your benchmark? `torch.nn.Linear(M, K, bias=False)`

We currently add bias in a separate kernel from the float8 matrix multiply. If there are multiple matrix multiplies in a row (as in the benchmark above, but rarely occurring in modern models), there isn't anything to fuse the bias to, so there is a performance penalty.
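If you want to try that on the 21-layer benchmark above, a hedged sketch (not from the thread) of the same model rebuilt with `bias=False` on every layer:

```python
import torch.nn as nn

m = nn.Sequential(
    *[nn.Linear(16384, 16384, bias=False) for _ in range(20)],  # the 20 square layers
    nn.Linear(16384, 128, bias=False),                          # final projection
).bfloat16().cuda()
```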
By the way, we should update the README.md snippet to something less toy and more representative, and mention the expected performance of the snippet.
> @Yzx835, do you see better results if you use linears without bias in your benchmark? `torch.nn.Linear(M, K, bias=False)`
>
> We currently add bias in a separate kernel from the float8 matrix multiply. If there are multiple matrix multiplies in a row (as in the benchmark above, but rarely occurring in modern models), there isn't anything to fuse the bias to, so there is a performance penalty.
I attempted to remove the bias, but the results were quite similar to those with the bias present.
I see, thank you. I am unavailable today/tomorrow but let me take a further look early next week. Thank you for writing this up!
If you modify your benchmarking code to ignore the first iteration (which spends a non-trivial amount of time on torch.compile warmup) and slightly increase the shapes, you should see a speedup. Here is a diff modifying the README.md example to demonstrate this: https://gist.github.com/vkuzo/f7e642f52096f8873edbee15a065236f
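The core of the suggested change, as a minimal self-contained sketch (this is not the gist itself; it just illustrates the idea with a tiny single-layer model): run some untimed warmup iterations so torch.compile compilation is excluded from the measured region.

```python
import time

import torch
import torch.nn as nn

m = torch.compile(nn.Linear(2048, 2048, bias=False).bfloat16().cuda())
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
opt = torch.optim.SGD(m.parameters(), lr=0.1)


def step():
    opt.zero_grad()
    m(x).sum().backward()
    opt.step()


for _ in range(10):  # warmup: compilation/autotuning happens here, not timed
    step()
torch.cuda.synchronize()

t0 = time.time()
for _ in range(100):  # timed region
    step()
torch.cuda.synchronize()
print("Training time:", time.time() - t0)
```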
I'm having some issues with internet connectivity on my dev machine at the moment, as soon as I resolve them I'll put up a PR to modify the README.md to make this easier.
https://github.com/pytorch/ao/pull/2291
@vkuzo Thank you!
I tested the new code, which ignores the first few iterations. I did see a speedup.
- fp8 training, torch.compile: Training time: 22.640191793441772
- without fp8 training, torch.compile: Training time: 34.18098497390747
- fp8 training, without torch.compile: Training time: 47.06538534164429
- without fp8 training, without torch.compile: Training time: 34.08416795730591
This is my test code (with a larger model size and more training iterations):
```python
import time

import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

# create model and sample input
M, K, N = 8192, 16384, 16384
m = nn.Sequential(
    nn.Linear(K, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, 128, bias=False),
).bfloat16().cuda()
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
print('fp8 training')
convert_to_float8_training(m, module_filter_fn=module_filter_fn)
# print('without fp8 training')

# enable torch.compile for competitive performance
# print('torch compile')
# m = torch.compile(m)
print('without torch compile')

# warm up torch.compile for a clean training time measurement
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()

torch.cuda.synchronize()
start_time = time.time()

# toy training loop
for _ in range(100):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()

torch.cuda.synchronize()
end_time = time.time()
print("Training time:", end_time - start_time)
```