CPU Overhead of te.Linear FP8 Layers
Hi,
We are looking into training some transformer models in FP8, and we see a lot of CPU-side overhead when te.Linear layers are scheduled in the forward pass of the network.
I'm using the following:
- H100 GPUs with CUDA 12.2 (V12.2.140)
- TE version git+https://github.com/NVIDIA/TransformerEngine.git@cf6fc898286e4ad347ff88925c88663324e2b87d
- PyTorch 2.1.0 with cuDNN 8906
Concretely, running a toy model, we see the FP8 model being slightly faster overall, at around 300 ms per iteration vs. 320 ms per iteration for the BF16 model. We use te.Linear layers in both cases, regardless of whether we run FP8 or BF16.
However, looking at the profiles, the forward pass of the FP8 model (wall duration roughly 140 ms) is much slower than the forward pass of the BF16 model (wall duration roughly 77 ms), and the GPU is idle much of the time during the FP8 forward pass. GPU utilization is near 100% during the backward pass for both models.
On the CPU side, scheduling a te.Linear layer in FP8 appears to take more than twice as long as scheduling the same layer in BF16.
Attached is a screenshot of part of the forward pass of the FP8 model:
And the corresponding part of the forward pass of the BF16 model:
I think this is related to #445, which observed similar behavior. Do you have any suggestions on how to reduce this overhead?
Code to reproduce; call with:
python fp8_minimal_example.py --dtype bf16
python fp8_minimal_example.py --dtype fp8
Add --profile to generate a PyTorch profile.
import argparse
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
class TEBlock(nn.Module):
def __init__(self, hidden_size: int, mlp_ratio: float):
super().__init__()
linear = te.Linear
# timestep modulation predicts several parameters conditioned on the timestep
self.timestep_modulation = linear(hidden_size, 6 * hidden_size, bias=True)
# simulate self attention layer for getting qkv embedding
self.self_attn = linear(hidden_size, 3 * hidden_size, bias=False)
# simulate cross attention layer for getting qkv embedding
self.cross_attn_q = linear(hidden_size, hidden_size, bias=False)
self.cross_attn_kv = linear(hidden_size, 2 * hidden_size, bias=False)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = nn.Sequential(
linear(hidden_size, mlp_hidden_dim),
linear(mlp_hidden_dim, hidden_size),
)
def forward(self, x):
# simulating predicting parameters for timestep modulation
shift, scale, _, _, _, _ = self.timestep_modulation(x).chunk(6, dim=-1)
x = shift * x / scale
# simulating self attention
sa_q, _, _ = self.self_attn(x).chunk(3, dim=-1)
x = x + sa_q
# simulating cross attention
ca_q = self.cross_attn_q(x)
ca_k, _ = self.cross_attn_kv(x).chunk(2, dim=-1)
x = x + ca_q + ca_k
# run MLP
x = x + self.mlp(x)
return x
class TEModel(nn.Module):
def __init__(
self,
num_blocks: int,
hidden_size: int,
        mlp_ratio: float,
):
super().__init__()
self.blocks = torch.nn.ModuleList()
for _ in range(num_blocks):
self.blocks.append(TEBlock(hidden_size, mlp_ratio))
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
for block_idx, block in enumerate(self.blocks):
with torch.autograd.profiler.record_function(f"block_{block_idx}"):
x = block(x)
return x
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Command line arguments")
parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
parser.add_argument("--dtype", type=str, default="bf16", help="Data type")
parser.add_argument("--hidden_size", type=int, default=2048, help="Hidden size")
parser.add_argument("--depth", type=int, default=32, help="Depth")
parser.add_argument("--seq_length", type=int, default=1024, help="Sequence length")
parser.add_argument("--mlp_ratio", type=float, default=4, help="mlp_ratio")
parser.add_argument("--profile", action="store_true", help="Run PyTorch profiler")
args = parser.parse_args()
if args.dtype == "bf16":
dtype = torch.bfloat16
cast_type = "bf16"
elif args.dtype == "fp8":
dtype = torch.float32
cast_type = "fp8"
else:
print("Invalid data type, must be either bf16 or fp8")
        exit(1)
# Define FP8 recipe
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(
fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
)
# Generate random model input and target for MSE loss
model_input = (
torch.rand(args.batch_size, args.seq_length, args.hidden_size)
.cuda()
.to(dtype=dtype)
)
target = (
torch.rand(args.batch_size, args.seq_length, args.hidden_size)
.cuda()
.to(dtype=dtype)
)
criterion = torch.nn.MSELoss()
# Define the model and optimizer
model = TEModel(args.depth, args.hidden_size, args.mlp_ratio)
model.to(dtype=torch.float32).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Define autocast kwargs
if cast_type == "fp8":
autocast_args = {"enabled": True, "fp8_recipe": fp8_recipe}
autocast = te.fp8_autocast
elif cast_type == "bf16":
autocast_args = {
"device_type": "cuda",
"enabled": True,
"dtype": torch.bfloat16,
}
autocast = torch.autocast
# Run PyTorch profile
if args.profile:
with autocast(**autocast_args):
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(
skip_first=5, wait=10, warmup=5, active=3
),
) as prof:
for _ in range(25):
with autocast(**autocast_args):
output = model(model_input)
loss = criterion(output, target)
loss.backward()
prof.step()
profile_name = cast_type + "_bs_" + str(args.batch_size)
profile_name = f"_profile_{profile_name}.json"
prof.export_chrome_trace(profile_name)
print(f"Saved profile as {profile_name}")
# Time model iterations
else:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
timing_iters = 50
# warmup iterations
for _ in range(10):
with autocast(**autocast_args):
output = model(model_input)
loss = criterion(output, target)
loss.backward()
optimizer.step()
# estimate memory usage
free, total = torch.cuda.mem_get_info()
        memory = (total - free) / 1024**3  # used device memory in GiB
# benchmark
torch.cuda.synchronize()
start.record()
for _ in range(timing_iters):
with autocast(**autocast_args):
output = model(model_input)
loss = criterion(output, target)
loss.backward()
optimizer.step()
        end.record()  # record before syncing so elapsed_time() sees a completed event
        torch.cuda.synchronize()
mean_time = start.elapsed_time(end) / timing_iters
print(f"Mean time {mean_time} ms per iteration ({memory} GB used)")
Hi @tohinz, could you give this a try? I don't have easy access to an H100, but it had a positive impact on an L4 (which is a pretty different GPU :D). Compare:
python fp8_minimal_example.py --dtype fp8 vs. python fp8_minimal_example.py --dtype fp8 --cuda_graph
For reference: https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/
import argparse
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
class TEBlock(nn.Module):
def __init__(self, hidden_size: int, mlp_ratio: float):
super().__init__()
linear = te.Linear
# timestep modulation predicts several parameters conditioned on the timestep
self.timestep_modulation = linear(hidden_size, 6 * hidden_size, bias=True)
# simulate self attention layer for getting qkv embedding
self.self_attn = linear(hidden_size, 3 * hidden_size, bias=False)
# simulate cross attention layer for getting qkv embedding
self.cross_attn_q = linear(hidden_size, hidden_size, bias=False)
self.cross_attn_kv = linear(hidden_size, 2 * hidden_size, bias=False)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = nn.Sequential(
linear(hidden_size, mlp_hidden_dim),
linear(mlp_hidden_dim, hidden_size),
)
def forward(self, x):
# simulating predicting parameters for timestep modulation
shift, scale, _, _, _, _ = self.timestep_modulation(x).chunk(6, dim=-1)
x = shift * x / scale
# simulating self attention
sa_q, _, _ = self.self_attn(x).chunk(3, dim=-1)
x = x + sa_q
# simulating cross attention
ca_q = self.cross_attn_q(x)
ca_k, _ = self.cross_attn_kv(x).chunk(2, dim=-1)
x = x + ca_q + ca_k
# run MLP
x = x + self.mlp(x)
return x
class TEModel(nn.Module):
def __init__(
self,
num_blocks: int,
hidden_size: int,
        mlp_ratio: float,
):
super().__init__()
self.blocks = torch.nn.ModuleList()
for _ in range(num_blocks):
self.blocks.append(TEBlock(hidden_size, mlp_ratio))
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
for block_idx, block in enumerate(self.blocks):
with torch.autograd.profiler.record_function(f"block_{block_idx}"):
x = block(x)
return x
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Command line arguments")
parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
parser.add_argument("--dtype", type=str, default="bf16", help="Data type")
parser.add_argument("--hidden_size", type=int, default=2048, help="Hidden size")
parser.add_argument("--depth", type=int, default=32, help="Depth")
parser.add_argument("--seq_length", type=int, default=1024, help="Sequence length")
parser.add_argument("--mlp_ratio", type=float, default=4, help="mlp_ratio")
parser.add_argument("--profile", action="store_true", help="Run PyTorch profiler")
parser.add_argument("--cuda_graph", action="store_true", help="Run with cuda graph capture")
args = parser.parse_args()
if args.dtype == "bf16":
dtype = torch.bfloat16
cast_type = "bf16"
elif args.dtype == "fp8":
dtype = torch.float32
cast_type = "fp8"
else:
print("Invalid data type, must be either bf16 or fp8")
        exit(1)
# Define FP8 recipe
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(
fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
)
# Generate random model input and target for MSE loss
model_input = (
torch.rand(args.batch_size, args.seq_length, args.hidden_size)
.cuda()
.to(dtype=dtype)
)
target = (
torch.rand(args.batch_size, args.seq_length, args.hidden_size)
.cuda()
.to(dtype=dtype)
)
criterion = torch.nn.MSELoss()
# Define the model and optimizer
model = TEModel(args.depth, args.hidden_size, args.mlp_ratio)
model.to(dtype=torch.float32).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, capturable=True)
# Define autocast kwargs
if cast_type == "fp8":
autocast_args = {"enabled": True, "fp8_recipe": fp8_recipe}
autocast = te.fp8_autocast
elif cast_type == "bf16":
autocast_args = {
"device_type": "cuda",
"enabled": True,
"dtype": torch.bfloat16,
}
autocast = torch.autocast
# Run PyTorch profile
if args.profile:
with autocast(**autocast_args):
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(
skip_first=5, wait=10, warmup=5, active=3
),
) as prof:
for _ in range(25):
with autocast(**autocast_args):
output = model(model_input)
loss = criterion(output, target)
loss.backward()
prof.step()
profile_name = cast_type + "_bs_" + str(args.batch_size)
profile_name = f"_profile_{profile_name}.json"
prof.export_chrome_trace(profile_name)
print(f"Saved profile as {profile_name}")
# Time model iterations
else:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
timing_iters = 50
# estimate memory usage
free, total = torch.cuda.mem_get_info()
        memory = (total - free) / 1024**3  # used device memory in GiB
def inner():
with autocast(**autocast_args):
output = model(model_input)
loss = criterion(output, target)
loss.backward()
optimizer.step()
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
# warmup iterations
for _ in range(10):
inner()
torch.cuda.current_stream().wait_stream(s)
if args.cuda_graph:
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
inner()
# benchmark
torch.cuda.synchronize()
start.record()
for _ in range(timing_iters):
if args.cuda_graph:
g.replay()
else:
with autocast(**autocast_args):
output = model(model_input)
loss = criterion(output, target)
loss.backward()
optimizer.step()
        end.record()  # record before syncing so elapsed_time() sees a completed event
        torch.cuda.synchronize()
mean_time = start.elapsed_time(end) / timing_iters
print(f"Mean time {mean_time} ms per iteration ({memory} GB used)")
Related to #575
Thanks, I tested this briefly and it does seem to reduce the CPU overhead. A quick look at a generated profile shows close to 100% GPU utilization in the forward pass as well now.
I had a brief look at the CUDA Graphs documentation and it states: "Dynamic shapes are prohibited. The graph assumes every tensor in the captured op sequence has the same size and layout in every replay."
In our case the shapes differ from batch to batch, since the inputs are images at different aspect ratios or resolutions, so the sequence length changes between batches. Is there any way to make that scenario work with CUDA Graphs? If not, do you see other ways to reduce the CPU overhead?
Could you agree on a fixed set of aspect ratios and do a "switch" between multiple captured graphs (rough sketch below)? I know this is ugly, and I am also not really happy with the proposal.
Otherwise, the linked PR will apparently reduce the overhead, though likely not to the same extent as CUDA graphs did. I cannot comment on further optimizations as of now; we will have to see what the TE team recommends.
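Something along these lines, completely untested and just to illustrate the switch: it captures one graph per expected input shape and replays whichever matches the incoming batch. The tiny te.Linear model, optimizer, and `train_step` closure below are placeholders for the real training step (compare `inner()` in the script above).

```python
# Rough, untested sketch of the "switch between multiple graphs" idea: capture
# one CUDA graph per fixed input shape and replay whichever matches the batch.
# The tiny model / optimizer / train_step below are placeholders for the real
# training step (compare `inner()` in the script above).
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

hidden = 2048
model = te.Linear(hidden, hidden).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, capturable=True)
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")

def train_step(x):
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        out = model(x)
    loss = out.float().pow(2).mean()
    loss.backward()
    optimizer.step()

shape_buckets = [(8, 256, hidden), (8, 240, hidden)]  # expected (batch, seq, hidden) shapes
graphs, static_inputs = {}, {}

for shape in shape_buckets:
    static_x = torch.zeros(*shape, device="cuda")

    # warm up this shape on a side stream before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            train_step(static_x)
    torch.cuda.current_stream().wait_stream(s)

    optimizer.zero_grad(set_to_none=True)
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        train_step(static_x)
    graphs[shape], static_inputs[shape] = g, static_x

def run_iteration(batch):
    # copy the new data into the captured static buffer, then replay that graph
    shape = tuple(batch.shape)
    static_inputs[shape].copy_(batch)
    graphs[shape].replay()
```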
How do you currently create your batch? Do you pad the samples to create a bshd tensor, or do you just concatenate the inputs without padding (the format we call thd in our DotProductAttention API)? If you already pad, the static-shape requirement would not necessarily be as problematic. In the case of concatenated inputs, as long as your container has a static size and is large enough, you should also be able to make that work (you would just need to pass the sequence lengths properly).
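To make the concatenated (thd) variant concrete, here is a rough sketch; the format and argument names (qkv_format="thd", cu_seqlens_q, max_seqlen_q, etc.) follow the current DotProductAttention docs and may differ between TE versions, and the shapes are just example values.

```python
# Rough sketch of the "thd" path: concatenate the samples along the token
# dimension and pass cumulative sequence lengths to DotProductAttention.
# Argument names follow current TE docs and may differ between TE versions.
import torch
import transformer_engine.pytorch as te

num_heads, head_dim = 16, 128  # 16 * 128 = 2048 hidden, as in the toy model
attn = te.DotProductAttention(num_heads, head_dim, qkv_format="thd",
                              attn_mask_type="padding")

seq_lens = [256, 240, 240]  # per-sample sequence lengths in this batch
total = sum(seq_lens)       # total tokens in the packed batch
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(
    torch.tensor(seq_lens, dtype=torch.int32, device="cuda"), dim=0)

# packed q/k/v: [total_tokens, num_heads, head_dim]
q = torch.randn(total, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = attn(q, k, v,
           cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
           max_seqlen_q=max(seq_lens), max_seqlen_kv=max(seq_lens))
```

For CUDA graph capture, the packed buffer would additionally be kept at a fixed maximum size so that all tensor shapes stay static.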
Are you using the DotProductAttention from TE (or the full TransformerLayer) as well, or are you just replacing the Linear layers? In general, the full TransformerLayer (or its smaller components like LayerNormLinear and LayerNormMLP) should have less overhead than bare Linear layers, since some of the casts to FP8 are then fused into preceding kernels and fewer kernel launches are needed.
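For example, a LayerNorm followed by a projection, or the LayerNorm plus MLP pair, can be collapsed into a single fused module. A rough sketch using the toy sizes from the repro script; it only illustrates the fused modules, not a drop-in replacement for the block above:

```python
# Sketch: fusing LayerNorm + Linear and LayerNorm + MLP into single TE modules
# so the FP8 casts happen inside the fused kernels and fewer launches are needed.
# Sizes are the toy values from the repro script above.
import torch
import transformer_engine.pytorch as te

hidden_size, mlp_ratio = 2048, 4

# LayerNorm + qkv projection in one module (instead of nn.LayerNorm + te.Linear)
qkv_proj = te.LayerNormLinear(hidden_size, 3 * hidden_size, bias=False).cuda()

# LayerNorm + fc1 + activation + fc2 in one module (instead of two te.Linear)
mlp = te.LayerNormMLP(hidden_size, int(hidden_size * mlp_ratio)).cuda()

x = torch.randn(8, 1024, hidden_size, device="cuda")
with te.fp8_autocast(enabled=True):
    qkv = qkv_proj(x)
    y = x + mlp(x)
```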
> Can you agree on a fixed set of aspect ratios and do a "switch" between the multiple graphs?
In theory yes, but we have a lot of different aspect ratios/shapes, so this would get very ugly, and it is not clear how (or whether) it would generalize to future requirements.
> How do you currently create your batch?
We don't do any padding currently, since within a batch all elements have the same shape; only the resolution changes between batches. So one batch might have a sequence length of 16x16=256 and another a sequence length of 12x20=240, etc.
Right now the aspect ratios are designed such that the total sequence length stays roughly similar across aspect ratios, to keep the compute per batch similar. We could of course pad each batch, but that would waste compute: without padding, the sequence length in this concrete example is roughly 256, while with padding it would be 512, which would probably cost us more than any savings we could get here.
> Are you using the DotProductAttention from TE (or the full TransformerLayer) as well, or are you just replacing the Linear layers? In general, the full TransformerLayer (or its smaller components like LayerNormLinear and LayerNormMLP) should have less overhead than bare Linear layers, since some of the casts to FP8 are then fused into preceding kernels and fewer kernel launches are needed.
We are using the TE DotProductAttention, but otherwise we are limited to replacing individual Linear and LayerNorm layers with TE layers and can't make use of any of the fused layers, since our network architecture doesn't follow the default LLM transformer layout. Overall our network is similar to a Diffusion Transformer, with additional terms for adaptive normalization based on timestep embeddings, cross-attention for extra conditioning, etc.
@tohinz not sure if you saw the linked PR. Feedback on whether it improves or solves your problem would be very welcome.