HuggingFace BertForMaskedLM - Log Softmax Fusion with Autocast has bad perf
🐛 Describe the bug
Benchmark command line:
PYTORCH_NVFUSER_DUMP=python_definition,fusion_args python -u benchmarks/huggingface.py --training -d cuda --fast --backend nvprims_nvfuser --skip-accuracy-check --performance --only BertForMaskedLM --amp
Fusion Repro:
import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType
def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Half)
    T1 = fd.define_tensor(symbolic_sizes=[-1, 1, 1, -1], contiguous=[True, True, True, True], dtype=DataType.Float)
    T2 = fd.ops.view(T0, original_shape=[768, 128, 128], new_shape=[64, 12, 128, 128])
    S3 = fd.define_constant(1)
    T4 = fd.ops.mul(T1, S3)
    T5 = fd.ops.cast(T2, dtype=DataType.Float)
    S6 = fd.define_constant(1.00000)
    T7 = fd.ops.sub(S6, T4)
    S8 = fd.define_constant(8.00000)
    T9 = fd.ops.div(T5, S8)
    S10 = fd.define_constant(-10000.0)
    T11 = fd.ops.mul(T7, S10)
    T12 = fd.ops.cast(T9, dtype=DataType.Half)
    T13 = fd.ops.broadcast_in_dim(T11, output_shape=[64, 12, 128, 128], broadcast_dims=[0, 1, 2, 3])
    T14 = fd.ops.cast(T12, dtype=DataType.Float)
    T15 = fd.ops.add(T14, T13)
    T16 = fd.ops.max(T15, axes=[3], keepdim=False, dtype=DataType.Null)
    T17 = fd.ops.broadcast_in_dim(T16, output_shape=[64, 12, 128, 1], broadcast_dims=[0, 1, 2])
    T18 = fd.ops.broadcast_in_dim(T17, output_shape=[64, 12, 128, 128], broadcast_dims=[0, 1, 2, 3])
    T19 = fd.ops.sub(T15, T18)
    T20 = fd.ops.exp(T19)
    T21 = fd.ops.sum(T20, axes=[3], keepdim=False, dtype=DataType.Null)
    T22 = fd.ops.broadcast_in_dim(T21, output_shape=[64, 12, 128, 1], broadcast_dims=[0, 1, 2])
    T23 = fd.ops.broadcast_in_dim(T22, output_shape=[64, 12, 128, 128], broadcast_dims=[0, 1, 2, 3])
    T24 = fd.ops.div(T20, T23)
    T25 = fd.ops.rand_like(T24)
    S26 = fd.define_constant(0.100000)
    T27 = fd.ops.gt(T25, S26)
    T28 = fd.ops.cast(T27, dtype=DataType.Float)
    T29 = fd.ops.mul(T28, T24)
    S30 = fd.define_constant(1.11111)
    T31 = fd.ops.mul(T29, S30)
    T32 = fd.ops.cast(T31, dtype=DataType.Half)
    T33 = fd.ops.broadcast_in_dim(T32, output_shape=[64, 12, 128, 128], broadcast_dims=[0, 1, 2, 3])
    fd.add_output(T11)
    fd.add_output(T24)
    fd.add_output(T27)
    fd.add_output(T33)
fs = Fusion()
with FusionDefinition(fs) as fd:
    nvfuser_fusion_id0(fd)
inputs = [
    torch.randn(768, 128, 128, device='cuda', dtype=torch.float16),
    torch.randn(64, 1, 1, 128, device='cuda', dtype=torch.float32),
]
for _ in range(5):
    out = fs.execute(inputs)
Performance output on A100:
Arguments for fusion1:
Inputs:
tensor dtype: __half sizes: (768, 128, 128, ) stride: (16384, 128, 1, ) pointer: 0x7fd5fa000000
tensor dtype: float sizes: (64, 1, 1, 128, ) stride: (128, 128, 128, 1, ) pointer: 0x7fd5fba00000
Outputs:
Launch Parameters: BlockDim.x = -1, BlockDim.y = 16, BlockDim.z = -1, GridDim.x = -1, GridDim.y = -1, GridDim.z = -1, Smem Size = 0
kernel1 run in 0.254976 ms, achieved: 444.402 GB/s
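For context, the ops in this fusion appear to correspond to the masked softmax plus dropout step of BERT self-attention under autocast. A rough eager-mode sketch of the same computation (hypothetical helper, with the shapes, the 8.0 scale, the -10000 mask value, and dropout p=0.1 hard-coded from the repro inputs):

import torch

def bert_attention_probs(scores, attention_mask, p=0.1):
    # scores: [768, 128, 128] half; attention_mask: [64, 1, 1, 128] float
    scores = scores.view(64, 12, 128, 128).float() / 8.0                   # T2 / T5 / T9
    extended_mask = (1.0 - attention_mask) * -10000.0                      # T7 / T11
    probs = torch.softmax(scores.half().float() + extended_mask, dim=-1)   # T15 .. T24
    keep = torch.rand_like(probs) > p                                      # T25 / T27
    dropped = (keep.float() * probs * (1.0 / (1.0 - p))).half()            # T29 / T31 / T33
    return extended_mask, probs, keep, dropped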
Versions
csarofeen/torchbenchPerf
A couple of things to try:
- Remove the top view from the fusion.
- Remove fd.add_output(T11). It's not clear to me that saving this tensor is necessary.

Removing the view clearly matters for perf.
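A hedged sketch of what those two changes might look like, purely for illustration (the function name is made up, the scores input is assumed to arrive directly as [64, 12, 128, 128] so no view is needed, and T11 is not saved; otherwise the ops mirror the dumped fusion above):

import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType
def nvfuser_fusion_no_view(fd : FusionDefinition) -> None :
    # Hypothetical variant: 4-D scores input, no leading view, T11 not an output.
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1, -1, -1], contiguous=[True, True, True, True], dtype=DataType.Half)
    T1 = fd.define_tensor(symbolic_sizes=[-1, 1, 1, -1], contiguous=[True, True, True, True], dtype=DataType.Float)
    S3 = fd.define_constant(1)
    T4 = fd.ops.mul(T1, S3)
    T5 = fd.ops.cast(T0, dtype=DataType.Float)
    S6 = fd.define_constant(1.00000)
    T7 = fd.ops.sub(S6, T4)
    S8 = fd.define_constant(8.00000)
    T9 = fd.ops.div(T5, S8)
    S10 = fd.define_constant(-10000.0)
    T11 = fd.ops.mul(T7, S10)
    T12 = fd.ops.cast(T9, dtype=DataType.Half)
    T13 = fd.ops.broadcast_in_dim(T11, output_shape=[64, 12, 128, 128], broadcast_dims=[0, 1, 2, 3])
    T14 = fd.ops.cast(T12, dtype=DataType.Float)
    T15 = fd.ops.add(T14, T13)
    T16 = fd.ops.max(T15, axes=[3], keepdim=False, dtype=DataType.Null)
    T17 = fd.ops.broadcast_in_dim(T16, output_shape=[64, 12, 128, 1], broadcast_dims=[0, 1, 2])
    T18 = fd.ops.broadcast_in_dim(T17, output_shape=[64, 12, 128, 128], broadcast_dims=[0, 1, 2, 3])
    T19 = fd.ops.sub(T15, T18)
    T20 = fd.ops.exp(T19)
    T21 = fd.ops.sum(T20, axes=[3], keepdim=False, dtype=DataType.Null)
    T22 = fd.ops.broadcast_in_dim(T21, output_shape=[64, 12, 128, 1], broadcast_dims=[0, 1, 2])
    T23 = fd.ops.broadcast_in_dim(T22, output_shape=[64, 12, 128, 128], broadcast_dims=[0, 1, 2, 3])
    T24 = fd.ops.div(T20, T23)
    T25 = fd.ops.rand_like(T24)
    S26 = fd.define_constant(0.100000)
    T27 = fd.ops.gt(T25, S26)
    T28 = fd.ops.cast(T27, dtype=DataType.Float)
    T29 = fd.ops.mul(T28, T24)
    S30 = fd.define_constant(1.11111)
    T31 = fd.ops.mul(T29, S30)
    T32 = fd.ops.cast(T31, dtype=DataType.Half)
    T33 = fd.ops.broadcast_in_dim(T32, output_shape=[64, 12, 128, 128], broadcast_dims=[0, 1, 2, 3])
    fd.add_output(T24)
    fd.add_output(T27)
    fd.add_output(T33)
fs = Fusion()
with FusionDefinition(fs) as fd:
    nvfuser_fusion_no_view(fd)
inputs = [
    torch.randn(64, 12, 128, 128, device='cuda', dtype=torch.float16),
    torch.randn(64, 1, 1, 128, device='cuda', dtype=torch.float32),
]
for _ in range(5):
    out = fs.execute(inputs)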
Another example to think about:
import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType
def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1], contiguous=[True, True], dtype=DataType.Half)
    T1 = fd.ops.view(T0, original_shape=[8192, 30522], new_shape=[64, 128, 30522])
    T2 = fd.ops.view(T1, original_shape=[64, 128, 30522], new_shape=[-1, 30522])
    T3 = fd.ops.cast(T2, dtype=DataType.Float)
    T4 = fd.ops.max(T3, axes=[1], keepdim=False, dtype=DataType.Null)
    T5 = fd.ops.broadcast_in_dim(T4, output_shape=[8192, 1], broadcast_dims=[0])
    T6 = fd.ops.broadcast_in_dim(T5, output_shape=[8192, 30522], broadcast_dims=[0, 1])
    T7 = fd.ops.sub(T3, T6)
    T8 = fd.ops.exp(T7)
    T9 = fd.ops.sum(T8, axes=[1], keepdim=False, dtype=DataType.Null)
    T10 = fd.ops.broadcast_in_dim(T9, output_shape=[8192, 1], broadcast_dims=[0])
    T11 = fd.ops.log(T10)
    T12 = fd.ops.broadcast_in_dim(T11, output_shape=[8192, 30522], broadcast_dims=[0, 1])
    T13 = fd.ops.sub(T7, T12)
    T14 = fd.ops.cast(T13, dtype=DataType.Half)
    T15 = fd.ops.cast(T14, dtype=DataType.Float)
    fd.add_output(T1)
    fd.add_output(T14)
    fd.add_output(T15)
fs = Fusion()
with FusionDefinition(fs) as fd:
    nvfuser_fusion_id0(fd)
inputs = [
    #torch.randn(64, 12, 128, 128, device='cuda', dtype=torch.float16),
    torch.randn(8192, 30522, device='cuda', dtype=torch.float16),
]
for _ in range(5):
    out = fs.execute(inputs)
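For reference, this fusion looks like the log softmax over the vocabulary dimension from the masked-LM head. A rough eager-mode sketch of the same math (hypothetical helper, shapes hard-coded to the repro input):

import torch

def log_softmax_reference(logits):
    # logits: [8192, 30522] half
    logits_3d = logits.view(64, 128, 30522)                    # T1
    log_probs = torch.log_softmax(logits.float(), dim=-1)      # T3 .. T13
    half_out = log_probs.half()                                # T14
    return logits_3d, half_out, half_out.float()               # outputs T1, T14, T15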
It looks like bad perf happens at this point:
T13 = fd.ops.sub(T7, T12)
Omitting the final subtract seems to reduce the registers needed from 87 to 50, which is enough to change the kernel time from 3.5 ms to ~1 ms.
import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType
def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1], contiguous=[True, True], dtype=DataType.Half)
    #T1 = fd.ops.view(T0, original_shape=[8192, 30522], new_shape=[64, 128, 30522])
    #T2 = fd.ops.view(T1, original_shape=[64, 128, 30522], new_shape=[-1, 30522])
    T3 = fd.ops.cast(T0, dtype=DataType.Float)
    T4 = fd.ops.max(T3, axes=[1], keepdim=False, dtype=DataType.Null)
    T5 = fd.ops.broadcast_in_dim(T4, output_shape=[8192, 1], broadcast_dims=[0])
    T6 = fd.ops.broadcast_in_dim(T5, output_shape=[8192, 30522], broadcast_dims=[0, 1])
    T7 = fd.ops.sub(T3, T6)
    T8 = fd.ops.exp(T7)
    T9 = fd.ops.sum(T8, axes=[1], keepdim=False, dtype=DataType.Null)
    T10 = fd.ops.broadcast_in_dim(T9, output_shape=[8192, 1], broadcast_dims=[0])
    T11 = fd.ops.log(T10)
    T12 = fd.ops.broadcast_in_dim(T11, output_shape=[8192, 30522], broadcast_dims=[0, 1])
    #T13 = fd.ops.sub(T7, T12)
    #T14 = fd.ops.cast(T12, dtype=DataType.Half)
    #T15 = fd.ops.cast(T12, dtype=DataType.Float)
    #fd.add_output(T1)
    #fd.add_output(T14)
    fd.add_output(T13)
fs = Fusion()
with FusionDefinition(fs) as fd:
    nvfuser_fusion_id0(fd)
inputs = [
    #torch.randn(64, 12, 128, 128, device='cuda', dtype=torch.float16),
    torch.randn(8192, 30522, device='cuda', dtype=torch.float16),
]
for _ in range(5):
    out = fs.execute(inputs)
@kevinstephano How exactly did you modify the fusion? Commenting out the line for T13 wouldn't work as T13 is added as an output.
I meant to show that this version has much better performance without the sub at the end of the fusion:
import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType
def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1], contiguous=[True, True], dtype=DataType.Half)
    #T1 = fd.ops.view(T0, original_shape=[8192, 30522], new_shape=[64, 128, 30522])
    #T2 = fd.ops.view(T1, original_shape=[64, 128, 30522], new_shape=[-1, 30522])
    T3 = fd.ops.cast(T0, dtype=DataType.Float)
    T4 = fd.ops.max(T3, axes=[1], keepdim=False, dtype=DataType.Null)
    T5 = fd.ops.broadcast_in_dim(T4, output_shape=[8192, 1], broadcast_dims=[0])
    T6 = fd.ops.broadcast_in_dim(T5, output_shape=[8192, 30522], broadcast_dims=[0, 1])
    T7 = fd.ops.sub(T3, T6)
    T8 = fd.ops.exp(T7)
    T9 = fd.ops.sum(T8, axes=[1], keepdim=False, dtype=DataType.Null)
    T10 = fd.ops.broadcast_in_dim(T9, output_shape=[8192, 1], broadcast_dims=[0])
    T11 = fd.ops.log(T10)
    T12 = fd.ops.broadcast_in_dim(T11, output_shape=[8192, 30522], broadcast_dims=[0, 1])
    #T13 = fd.ops.sub(T7, T12)
    fd.add_output(T12)
fs = Fusion()
with FusionDefinition(fs) as fd:
    nvfuser_fusion_id0(fd)
inputs = [
    #torch.randn(64, 12, 128, 128, device='cuda', dtype=torch.float16),
    torch.randn(8192, 30522, device='cuda', dtype=torch.float16),
]
for _ in range(5):
    out = fs.execute(inputs)
Whereas this version is significantly worse:
import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType
def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1], contiguous=[True, True], dtype=DataType.Half)
    #T1 = fd.ops.view(T0, original_shape=[8192, 30522], new_shape=[64, 128, 30522])
    #T2 = fd.ops.view(T1, original_shape=[64, 128, 30522], new_shape=[-1, 30522])
    T3 = fd.ops.cast(T0, dtype=DataType.Float)
    T4 = fd.ops.max(T3, axes=[1], keepdim=False, dtype=DataType.Null)
    T5 = fd.ops.broadcast_in_dim(T4, output_shape=[8192, 1], broadcast_dims=[0])
    T6 = fd.ops.broadcast_in_dim(T5, output_shape=[8192, 30522], broadcast_dims=[0, 1])
    T7 = fd.ops.sub(T3, T6)
    T8 = fd.ops.exp(T7)
    T9 = fd.ops.sum(T8, axes=[1], keepdim=False, dtype=DataType.Null)
    T10 = fd.ops.broadcast_in_dim(T9, output_shape=[8192, 1], broadcast_dims=[0])
    T11 = fd.ops.log(T10)
    T12 = fd.ops.broadcast_in_dim(T11, output_shape=[8192, 30522], broadcast_dims=[0, 1])
    T13 = fd.ops.sub(T7, T12)
    fd.add_output(T13)
fs = Fusion()
with FusionDefinition(fs) as fd:
    nvfuser_fusion_id0(fd)
inputs = [
    #torch.randn(64, 12, 128, 128, device='cuda', dtype=torch.float16),
    torch.randn(8192, 30522, device='cuda', dtype=torch.float16),
]
for _ in range(5):
    out = fs.execute(inputs)