
convert_to_float8_training and torch.compile make model slow

[Open] Yzx835 opened this issue 7 months ago • 12 comments

NOTE: I retested the times.

| Configuration | Training Time 1 (s) | Training Time 2 (s) | Training Time 3 (s) | Mean (s) |
| --- | --- | --- | --- | --- |
| Without convert_to_float8_training and without torch.compile | 0.1294 | 0.1314 | 0.1303 | 0.1303 |
| With convert_to_float8_training and without torch.compile | 0.1839 | 0.1856 | 0.1827 | 0.1844 |
| Without convert_to_float8_training and with torch.compile | 0.9767 | 0.9438 | 1.0092 | 0.9766 |
| With convert_to_float8_training and with torch.compile | 1.0895 | 1.2346 | 1.2095 | 1.1779 |

my env: H100

torch                                    2.5.0+cu124
torchao                                  0.11.0

My test code (copied from https://github.com/pytorch/ao/tree/main/torchao/float8#float8-linear-with-dynamic-tensorwise-scaling):

import time

import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)


# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True


# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# enable torch.compile for competitive performance
m = torch.compile(m)

for __ in range(1):
    torch.cuda.synchronize()  # Ensure all previous GPU work is complete
    start_time = time.time()

    # toy training loop
    for _ in range(10):
        optimizer.zero_grad()
        y = m(x)
        y.sum().backward()
        optimizer.step()

    torch.cuda.synchronize()  # Ensure all previous GPU work is complete
    end_time = time.time()

    print("Training time:", end_time - start_time)

Yzx835 commented on May 26, 2025

These shapes are too small to benefit from float8 training. The benefit of using FP8 tensorcores is outweighed by the dynamic quant overhead in this regime. Check out this table in the docs to determine if the GEMMs in your model will benefit from dynamic float8 quantization or not.

Try adjusting your script to use bigger shapes and you should see a perf improvement.
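For illustration, "bigger shapes" here means pushing M, K, and N into the multi-thousand range. A rough sketch of such a setup (the exact sizes are arbitrary and only meant as an example, not a recommendation):

import torch
import torch.nn as nn

# illustrative larger shapes; pick sizes representative of your real model
M, K, N = 8192, 8192, 8192
m = nn.Sequential(
    nn.Linear(K, N, bias=False),
    nn.Linear(N, N, bias=False),
).bfloat16().cuda()
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)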

danielvegamyhre commented on May 26, 2025

Hello @danielvegamyhre, I also tested a bigger linear shape and model size:

# create model and sample input
m = nn.Sequential(
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 16384),
    nn.Linear(16384, 128),
).bfloat16().cuda()
x = torch.randn(16384, 16384, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

| Configuration | Training Time 1 (s) | Training Time 2 (s) | Training Time 3 (s) | Mean (s) |
| --- | --- | --- | --- | --- |
| Without convert_to_float8_training and without torch.compile | 6.6669 | 6.6856 | 6.6797 | 6.6774 |
| With convert_to_float8_training and without torch.compile | 8.1664 | 8.1776 | 8.1725 | 8.1722 |
| Without convert_to_float8_training and with torch.compile | 7.9422 | 7.9149 | 7.9176 | 7.9249 |
| With convert_to_float8_training and with torch.compile | 7.7129 | 7.7049 | 7.6724 | 7.6967 |

With fp8 training and torch.compile, training is only slightly faster than without fp8 training but with torch.compile. According to the table, these shapes should show a 1.80x speedup. However, the fastest time is observed without fp8 and without torch.compile.

Yzx835 commented on May 27, 2025

Try filtering out the last linear, nn.Linear(16384, 128), in your module_filter_fn; it has such a small N dim that it will see a substantial slowdown with float8, which is probably tanking the overall net speedup from float8.
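For example, a filter along these lines would skip it (a sketch; the out_features < 1024 cutoff is just an illustrative threshold, not a value from the torchao docs):

import torch

def module_filter_fn(mod: torch.nn.Module, fqn: str):
    if isinstance(mod, torch.nn.Linear):
        # skip the small-N head (e.g. nn.Linear(16384, 128))
        if mod.out_features < 1024:
            return False
        # skip linears with weight dimensions not divisible by 16
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True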

danielvegamyhre commented on May 27, 2025

I am facing a similar issue with FSDP2 enabled:

m = nn.Sequential(
    nn.Linear(4096, 4096*3, bias=False),
    nn.Linear(4096*3, 4096, bias=False),
).to(device=device, dtype=torch.bfloat16)

x = torch.randn(32000, 4096, device="cuda", dtype=torch.bfloat16)

With FP8:

[GPU-0] Training Time: 18.083s
[GPU-0] Avg. Iter. Time: 0.181s
[GPU-0] Peak Memory Use: 2294.3MBs
[GPU-0] FP8 enabled: True

Without FP8:

[GPU-0] Training Time: 16.955s
[GPU-0] Avg. Iter. Time: 0.170s
[GPU-0] Peak Memory Use: 2753.6MBs
[GPU-0] FP8 enabled: False

Notice that even the memory savings with FP8 are not substantial.

import os
import argparse
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torchao.float8 import convert_to_float8_training, Float8MMConfig, Float8LinearConfig
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

# FSDP setup
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))

def dist_print(text):
    if LOCAL_RANK == 0:
        print(f"[GPU-{LOCAL_RANK}] " + text)

def parse_args():
    parser = argparse.ArgumentParser(description="TorchAO Float8 Training with FSDP2")
    parser.add_argument(
        "--fp8",
        action="store_true",
        default=False,
        help="Enable FP8 training using torchao.float8.convert_to_float8_training"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32000,
        help="Batch size for training"
    )
    parser.add_argument(
        "--num-iters",
        type=int,
        default=10,
        help="Number of training iterations"
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.1,
        help="Learning rate"
    )
    return parser.parse_args()

def main():
    args = parse_args()
    
    # Initialize distributed training
    if WORLD_SIZE > 1:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(LOCAL_RANK)
        dist_print(f"WORLD_SIZE = {WORLD_SIZE}")
    
    # Set seed for reproducibility
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)

    # create model on cuda device directly (FSDP2 doesn't require meta device)
    device = "cuda"
    m = nn.Sequential(
        nn.Linear(4096, 4096*3, bias=False),
        nn.Linear(4096*3, 4096, bias=False),
    ).to(device=device, dtype=torch.bfloat16)

    # Print memory usage before FSDP
    if WORLD_SIZE > 1:
        pre_mem_use = torch.cuda.memory_allocated(device=f"cuda:{LOCAL_RANK}") * 1e-6
        dist_print(f"Pre-FSDP memory use = {pre_mem_use}MiB")

    # optional: filter modules from being eligible for float8 conversion
    def module_filter_fn(mod: torch.nn.Module, fqn: str):
        # don't convert the last module
        # if fqn == "1":
        #     return False
        # don't convert linear modules with weight dimensions not divisible by 16
        if isinstance(mod, torch.nn.Linear):
            if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                return False
        return True

    # convert specified `torch.nn.Linear` modules to `Float8Linear` if fp8 is enabled
    if args.fp8:
        dist_print("Applying FP8 conversion to model...")
        config = Float8LinearConfig(
            enable_fsdp_float8_all_gather=True,
            force_recompute_fp8_weight_in_bwd=True,

        )
        convert_to_float8_training(m, module_filter_fn=module_filter_fn, config=config)
    else:
        dist_print("FP8 conversion disabled, using standard precision")
    
    # Apply FSDP2 if distributed
    if WORLD_SIZE > 1:
        # Create mixed precision policy for FSDP2
        mp_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
        )
        
        # Apply fully_shard to each layer for better control
        for i, layer in enumerate(m):
            fully_shard(layer, mp_policy=mp_policy)
        
        # Apply fully_shard to the entire model
        fully_shard(m, mp_policy=mp_policy)
        
        # Print memory usage after FSDP
        post_mem_use = torch.cuda.memory_allocated(device=f"cuda:{LOCAL_RANK}") * 1e-6
        dist_print(f"Post-FSDP2 memory use = {post_mem_use}MiB")
        dist_print(f"FSDP2-Wrapped Model:\n{m}")

    # enable torch.compile for competitive performance
    m = torch.compile(m)

    # create sample input
    x = torch.randn(args.batch_size, 4096, device="cuda", dtype=torch.bfloat16)
    
    # optimizer must be created after FSDP wrapping
    optimizer = torch.optim.SGD(m.parameters(), lr=args.lr)

    # Memory profiling setup
    torch.cuda.reset_peak_memory_stats()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()

    # warmup
    for i in range(20):
        optimizer.zero_grad()
        y = m(x)
        loss = y.sum()
        loss.backward()
        optimizer.step()

    # toy training loop
    for i in range(args.num_iters):
        optimizer.zero_grad()
        y = m(x)
        loss = y.sum()
        loss.backward()
        optimizer.step()
        
        if LOCAL_RANK == 0 and i % 5 == 0:
            print(f"Iteration {i}, Loss: {loss.item():.4f}")

    # Performance metrics
    end.record()
    torch.cuda.synchronize()
    peak_mem = torch.cuda.max_memory_allocated()
    train_time = start.elapsed_time(end) / 1000.0
    
    dist_print(f"Training Time: {train_time:.3f}s")
    dist_print(f"Avg. Iter. Time: {train_time / args.num_iters:.3f}s")
    dist_print(f"Peak Memory Use: {peak_mem * 1e-6:.1f}MBs")
    dist_print(f"FP8 enabled: {args.fp8}")

    # Cleanup
    if WORLD_SIZE > 1:
        dist.destroy_process_group()

if __name__ == "__main__":
    main()

Env: CUDA: 12.8 Pytorch Version: 2.7.0+cu128 torchao 0.11.0 Hardware: H100

cc: @danielvegamyhre

ajWithNucleus commented on May 27, 2025

@ajWithNucleus 2 things:

  1. Hmm, we should see improved throughput in your example. I'll try to repro with your code snippet and take a look at the trace. It might have to do with how your benchmarking code is set up. Also, it seems like you're doing fully_shard both on individual layers and on the full model afterwards; is that intentional?

  2. As far as peak memory goes, with this simple 2-linear-layer model there will be no memory savings with fp8. The memory savings come from torch.compile doing optimizations like CSE, cross-layer fusions, re-using buffers, etc., not from using fp8. Dynamic float8 quant on its own will actually use some additional memory, but will increase throughput. This is because in training we use dynamic float8 quantization: the weights and activations stay in bf16, and we dynamically compute fp8 versions of each just to perform fp8 GEMMs, which are faster than bf16 GEMMs. The output of the fp8 GEMM is in bf16 (see the conceptual sketch after this list).
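To make that concrete, here is a conceptual sketch of per-tensor dynamic fp8 quantization around a single matmul. The helper name is made up for illustration, the shapes are borrowed from the example above, and the GEMM itself is emulated in bf16; torchao's Float8Linear dispatches to a real fused fp8 GEMM kernel instead.

import torch

def dynamic_fp8_linear_emulated(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # x (activations) and w (weights) stay in bf16; fp8 copies are computed on the fly,
    # which is why dynamic quant alone adds a little memory instead of saving any
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    x_scale = fp8_max / x.abs().max().clamp(min=1e-12)
    w_scale = fp8_max / w.abs().max().clamp(min=1e-12)
    x_fp8 = (x.float() * x_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    w_fp8 = (w.float() * w_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # a real fp8 GEMM consumes x_fp8 / w_fp8 directly on the fp8 tensorcores;
    # here we upcast and use a bf16 matmul purely to emulate the dataflow
    y = (x_fp8.to(torch.bfloat16) @ w_fp8.to(torch.bfloat16).t()) / (x_scale * w_scale)
    return y.to(torch.bfloat16)  # output of the (emulated) fp8 GEMM is bf16

x = torch.randn(32000, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(12288, 4096, device="cuda", dtype=torch.bfloat16)
y = dynamic_fp8_linear_emulated(x, w)

The transient x_fp8 / w_fp8 buffers are the small extra memory cost of dynamic quantization; the throughput win comes entirely from the fp8 tensorcore GEMM that this sketch only emulates.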

danielvegamyhre commented on May 27, 2025

cc @vkuzo as well

danielvegamyhre commented on May 27, 2025

Got it, thanks a lot for the clarification, @danielvegamyhre. So, if my understanding is correct, this is different from Transformer Engine's implementation, where activations might be stored in FP8?

ajWithNucleus commented on May 28, 2025

> Try filtering out the last linear, nn.Linear(16384, 128), in your module_filter_fn; it has such a small N dim that it will see a substantial slowdown with float8, which is probably tanking the overall net speedup from float8.

Hello @danielvegamyhre, I tried filtering out the last linear layer, but it didn't work as expected.

| Configuration | Training Time 1 (s) | Training Time 2 (s) | Training Time 3 (s) |
| --- | --- | --- | --- |
| Without convert_to_float8_training and without torch.compile | 6.6573 | 6.6762 | 6.6804 |
| With convert_to_float8_training and with torch.compile | 7.5983 | 7.5895 | 7.6059 |

Could you please share your test results for convert_to_float8_training with torch.compile? I want to know whether this is a bug or an issue with my environment.

Yzx835 commented on May 28, 2025

@Yzx835, do you see better results if you use linears without bias in your benchmark? torch.nn.Linear(M, K, bias=False)

We currently add bias in a separate kernel from the float8 matrix multiply. If there are multiple matrix multiplies in a row (such as in the benchmark above, but rarely occurring in modern models), there isn't anything to fuse the bias to, so there is a performance penalty.
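Concretely, that just means constructing the benchmark's layers bias-free, e.g. (a sketch reusing the shapes from the example above):

import torch.nn as nn

# bias-free linears: no separate bias-add kernel follows the float8 matmul
m = nn.Sequential(
    nn.Linear(16384, 16384, bias=False),
    nn.Linear(16384, 128, bias=False),
).bfloat16().cuda()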

vkuzo commented on May 28, 2025

By the way, we should update the README.md snippet to something less toy and more representative, and mention the expected performance of the snippet.

vkuzo commented on May 28, 2025

> @Yzx835, do you see better results if you use linears without bias in your benchmark? torch.nn.Linear(M, K, bias=False)
>
> We currently add bias in a separate kernel from the float8 matrix multiply. If there are multiple matrix multiplies in a row (such as in the benchmark above, but rarely occurring in modern models), there isn't anything to fuse the bias to, so there is a performance penalty.

I attempted to remove the bias, but the results were quite similar to those with the bias present.

yqyao commented on May 29, 2025

I see, thank you. I am unavailable today/tomorrow but let me take a further look early next week. Thank you for writing this up!

vkuzo commented on May 29, 2025

If you modify your benchmarking code to ignore the first iteration (which spends a non-trivial amount of time on torch.compile warmup) and slightly increase the shapes, then you should see a speedup. Here is a diff to modify the README.md example to demonstrate this: https://gist.github.com/vkuzo/f7e642f52096f8873edbee15a065236f

I'm having some issues with internet connectivity on my dev machine at the moment, as soon as I resolve them I'll put up a PR to modify the README.md to make this easier.

vkuzo commented on Jun 3, 2025

https://github.com/pytorch/ao/pull/2291

vkuzo commented on Jun 3, 2025

@vkuzo Thank you!

I tested the new code, which ignores the first few iterations. I did see a speedup.

fp8 training, torch compile, Training time: 22.640191793441772
without fp8 training, torch compile, Training time: 34.18098497390747
fp8 training, without torch compile, Training time: 47.06538534164429
without fp8 training, without torch compile, Training time: 34.08416795730591

This is my test code (with a larger model size and more training iterations):

import time

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

# create model and sample input
M, K, N = 8192, 16384, 16384
m = nn.Sequential(
    nn.Linear(K, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, N, bias=False),
    nn.Linear(N, 128, bias=False),
).bfloat16().cuda()
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)


# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # skip the module with fqn "1" (note: in this 22-layer model the last layer,
    # nn.Linear(N, 128), has fqn "21", so it is still converted here)
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True


# convert specified `torch.nn.Linear` modules to `Float8Linear`
print('fp8 training')
convert_to_float8_training(m, module_filter_fn=module_filter_fn)
# print('without fp8 training')

# enable torch.compile for competitive performance
# print('torch compile')
# m = torch.compile(m)
print('without torch compile')

# warm up torch.compile for a clean training time measurement
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()

torch.cuda.synchronize()
start_time = time.time()

# toy training loop
for _ in range(100):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()

torch.cuda.synchronize()
end_time = time.time()
print("Training time:", end_time - start_time)

Yzx835 commented on Jun 5, 2025