
HuggingFace BertForMaskedLM - Log Softmax Fusion with Autocast has bad perf

Open · kevinstephano opened this issue 3 years ago • 6 comments

🐛 Describe the bug

Benchmark commandline:

PYTORCH_NVFUSER_DUMP=python_definition,fusion_args python -u benchmarks/huggingface.py --training -d cuda --fast --backend nvprims_nvfuser --skip-accuracy-check --performance --only BertForMaskedLM --amp

Fusion Repro:

import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Half)
    T1 = fd.define_tensor(symbolic_sizes=[-1, 1, 1, -1], contiguous=[True, True, True, True], dtype=DataType.Float)
    T2 = fd.ops.view(T0, original_shape=[768, 128, 128], new_shape=[64, 12, 128, 128])
    S3 = fd.define_constant(1)
    T4 = fd.ops.mul(T1, S3)
    T5 = fd.ops.cast(T2, dtype=DataType.Float)
    S6 = fd.define_constant(1.00000)
    T7 = fd.ops.sub(S6, T4)
    S8 = fd.define_constant(8.00000)
    T9 = fd.ops.div(T5, S8)
    S10 = fd.define_constant(-10000.0)
    T11 = fd.ops.mul(T7, S10)
    T12 = fd.ops.cast(T9, dtype=DataType.Half)
    T13 = fd.ops.broadcast_in_dim(T11, output_shape=[64, 12, 128, 128], broadcast_dims=[0, 1, 2, 3])
    T14 = fd.ops.cast(T12, dtype=DataType.Float)
    T15 = fd.ops.add(T14, T13)
    T16 = fd.ops.max(T15, axes=[3], keepdim=False, dtype=DataType.Null)
    T17 = fd.ops.broadcast_in_dim(T16, output_shape=[64, 12, 128, 1], broadcast_dims=[0, 1, 2])
    T18 = fd.ops.broadcast_in_dim(T17, output_shape=[64, 12, 128, 128], broadcast_dims=[0, 1, 2, 3])
    T19 = fd.ops.sub(T15, T18)
    T20 = fd.ops.exp(T19)
    T21 = fd.ops.sum(T20, axes=[3], keepdim=False, dtype=DataType.Null)
    T22 = fd.ops.broadcast_in_dim(T21, output_shape=[64, 12, 128, 1], broadcast_dims=[0, 1, 2])
    T23 = fd.ops.broadcast_in_dim(T22, output_shape=[64, 12, 128, 128], broadcast_dims=[0, 1, 2, 3])
    T24 = fd.ops.div(T20, T23)
    T25 = fd.ops.rand_like(T24)
    S26 = fd.define_constant(0.100000)
    T27 = fd.ops.gt(T25, S26)
    T28 = fd.ops.cast(T27, dtype=DataType.Float)
    T29 = fd.ops.mul(T28, T24)
    S30 = fd.define_constant(1.11111)
    T31 = fd.ops.mul(T29, S30)
    T32 = fd.ops.cast(T31, dtype=DataType.Half)
    T33 = fd.ops.broadcast_in_dim(T32, output_shape=[64, 12, 128, 128], broadcast_dims=[0, 1, 2, 3])
    fd.add_output(T11)
    fd.add_output(T24)
    fd.add_output(T27)
    fd.add_output(T33)

fs = Fusion()

with FusionDefinition(fs) as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn(768, 128, 128, device='cuda', dtype=torch.float16),
    torch.randn(64, 1, 1, 128, device='cuda', dtype=torch.float32),
]

for _ in range(5):
    out = fs.execute(inputs)

Performance output on A100:

Arguments for fusion1:
Inputs:
tensor dtype: __half sizes: (768, 128, 128, ) stride: (16384, 128, 1, ) pointer: 0x7fd5fa000000
tensor dtype: float sizes: (64, 1, 1, 128, ) stride: (128, 128, 128, 1, ) pointer: 0x7fd5fba00000
Outputs:
Launch Parameters: BlockDim.x = -1, BlockDim.y = 16, BlockDim.z = -1, GridDim.x = -1, GridDim.y = -1, GridDim.z = -1, Smem Size = 0
kernel1 run in 0.254976 ms, achieved: 444.402 GB/s
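
That bandwidth number can be sanity-checked by hand. A rough back-of-envelope sketch (it counts only the fusion's global inputs and outputs and assumes the boolean mask output is stored as 1 byte per element):

elems = 64 * 12 * 128 * 128                     # 12,582,912 elements per full-size tensor
bytes_in = elems * 2 + 64 * 128 * 4             # half scores + float attention mask
bytes_out = 64 * 128 * 4 + elems * (4 + 1 + 2)  # T11 (float) + T24 (float) + T27 (bool) + T33 (half)
total = bytes_in + bytes_out                    # ~113 MB
print(total / 0.254976e-3 / 1e9)                # ~444 GB/s, far below the ~1.5-2 TB/s HBM peak of an A100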

Versions

csarofeen/torchbenchPerf

kevinstephano · Oct 18 '22 07:10

A couple things to try:

  1. Remove the top view from the Fusion.
  2. Remove fd.add_output(T11). It's not clear to me that saving this tensor is necessary. (A sketch with both changes applied follows.)
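
To make both suggestions concrete, here is a sketch of the repro with the view removed and fd.add_output(T11) dropped; the input is generated directly in the post-view [64, 12, 128, 128] shape and everything else is copied from the definition above:

import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType

def nvfuser_fusion_no_view(fd : FusionDefinition) -> None :
    # Input already comes in as [64, 12, 128, 128], so the leading view is gone.
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1, -1, -1], contiguous=[True, True, True, True], dtype=DataType.Half)
    T1 = fd.define_tensor(symbolic_sizes=[-1, 1, 1, -1], contiguous=[True, True, True, True], dtype=DataType.Float)
    S3 = fd.define_constant(1)
    T4 = fd.ops.mul(T1, S3)
    T5 = fd.ops.cast(T0, dtype=DataType.Float)
    S6 = fd.define_constant(1.00000)
    T7 = fd.ops.sub(S6, T4)
    S8 = fd.define_constant(8.00000)
    T9 = fd.ops.div(T5, S8)
    S10 = fd.define_constant(-10000.0)
    T11 = fd.ops.mul(T7, S10)
    T12 = fd.ops.cast(T9, dtype=DataType.Half)
    T13 = fd.ops.broadcast_in_dim(T11, output_shape=[64, 12, 128, 128], broadcast_dims=[0, 1, 2, 3])
    T14 = fd.ops.cast(T12, dtype=DataType.Float)
    T15 = fd.ops.add(T14, T13)
    T16 = fd.ops.max(T15, axes=[3], keepdim=False, dtype=DataType.Null)
    T17 = fd.ops.broadcast_in_dim(T16, output_shape=[64, 12, 128, 1], broadcast_dims=[0, 1, 2])
    T18 = fd.ops.broadcast_in_dim(T17, output_shape=[64, 12, 128, 128], broadcast_dims=[0, 1, 2, 3])
    T19 = fd.ops.sub(T15, T18)
    T20 = fd.ops.exp(T19)
    T21 = fd.ops.sum(T20, axes=[3], keepdim=False, dtype=DataType.Null)
    T22 = fd.ops.broadcast_in_dim(T21, output_shape=[64, 12, 128, 1], broadcast_dims=[0, 1, 2])
    T23 = fd.ops.broadcast_in_dim(T22, output_shape=[64, 12, 128, 128], broadcast_dims=[0, 1, 2, 3])
    T24 = fd.ops.div(T20, T23)
    T25 = fd.ops.rand_like(T24)
    S26 = fd.define_constant(0.100000)
    T27 = fd.ops.gt(T25, S26)
    T28 = fd.ops.cast(T27, dtype=DataType.Float)
    T29 = fd.ops.mul(T28, T24)
    S30 = fd.define_constant(1.11111)
    T31 = fd.ops.mul(T29, S30)
    T32 = fd.ops.cast(T31, dtype=DataType.Half)
    T33 = fd.ops.broadcast_in_dim(T32, output_shape=[64, 12, 128, 128], broadcast_dims=[0, 1, 2, 3])
    # fd.add_output(T11) dropped per suggestion 2.
    fd.add_output(T24)
    fd.add_output(T27)
    fd.add_output(T33)

fs = Fusion()

with FusionDefinition(fs) as fd:
    nvfuser_fusion_no_view(fd)

inputs = [
    torch.randn(64, 12, 128, 128, device='cuda', dtype=torch.float16),
    torch.randn(64, 1, 1, 128, device='cuda', dtype=torch.float32),
]

for _ in range(5):
    out = fs.execute(inputs)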

kevinstephano · Oct 18 '22 08:10

Removing the view clearly matters for perf.

kevinstephano · Oct 19 '22 04:10

Another example to think about:

import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1], contiguous=[True, True], dtype=DataType.Half)
    T1 = fd.ops.view(T0, original_shape=[8192, 30522], new_shape=[64, 128, 30522])
    T2 = fd.ops.view(T1, original_shape=[64, 128, 30522], new_shape=[-1, 30522])
    T3 = fd.ops.cast(T2, dtype=DataType.Float)
    T4 = fd.ops.max(T3, axes=[1], keepdim=False, dtype=DataType.Null)
    T5 = fd.ops.broadcast_in_dim(T4, output_shape=[8192, 1], broadcast_dims=[0])
    T6 = fd.ops.broadcast_in_dim(T5, output_shape=[8192, 30522], broadcast_dims=[0, 1])
    T7 = fd.ops.sub(T3, T6)
    T8 = fd.ops.exp(T7)
    T9 = fd.ops.sum(T8, axes=[1], keepdim=False, dtype=DataType.Null)
    T10 = fd.ops.broadcast_in_dim(T9, output_shape=[8192, 1], broadcast_dims=[0])
    T11 = fd.ops.log(T10)
    T12 = fd.ops.broadcast_in_dim(T11, output_shape=[8192, 30522], broadcast_dims=[0, 1])
    T13 = fd.ops.sub(T7, T12)
    T14 = fd.ops.cast(T13, dtype=DataType.Half)
    T15 = fd.ops.cast(T14, dtype=DataType.Float)
    fd.add_output(T1)
    fd.add_output(T14)
    fd.add_output(T15)

fs = Fusion()

with FusionDefinition(fs) as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    #torch.randn(64, 12, 128, 128, device='cuda', dtype=torch.float16),
    torch.randn(8192, 30522, device='cuda', dtype=torch.float16),
]

for _ in range(5):
    out = fs.execute(inputs)
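
For context, the compute in this fusion is just a log-softmax over the 30522-wide vocabulary dimension (plus the extra view that gets saved); a sketch of an eager-mode equivalent for the same problem size:

import torch
import torch.nn.functional as F

x = torch.randn(8192, 30522, device='cuda', dtype=torch.float16)
# Mirrors the fusion: cast to float, log_softmax over the last axis, cast back to half.
ref = F.log_softmax(x.float(), dim=-1).half()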

It looks like bad perf happens at this point:

T13 = fd.ops.sub(T7, T12)

kevinstephano · Oct 19 '22 04:10

Omitting the final subtract seems to change the registers needed from 87 to 50, a difference large enough to change the kernel time from 3.5 ms to ~1 ms.

import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1], contiguous=[True, True], dtype=DataType.Half)
    #T1 = fd.ops.view(T0, original_shape=[8192, 30522], new_shape=[64, 128, 30522])
    #T2 = fd.ops.view(T1, original_shape=[64, 128, 30522], new_shape=[-1, 30522])
    T3 = fd.ops.cast(T0, dtype=DataType.Float)
    T4 = fd.ops.max(T3, axes=[1], keepdim=False, dtype=DataType.Null)
    T5 = fd.ops.broadcast_in_dim(T4, output_shape=[8192, 1], broadcast_dims=[0])
    T6 = fd.ops.broadcast_in_dim(T5, output_shape=[8192, 30522], broadcast_dims=[0, 1])
    T7 = fd.ops.sub(T3, T6)
    T8 = fd.ops.exp(T7)
    T9 = fd.ops.sum(T8, axes=[1], keepdim=False, dtype=DataType.Null)
    T10 = fd.ops.broadcast_in_dim(T9, output_shape=[8192, 1], broadcast_dims=[0])
    T11 = fd.ops.log(T10)
    T12 = fd.ops.broadcast_in_dim(T11, output_shape=[8192, 30522], broadcast_dims=[0, 1])
    #T13 = fd.ops.sub(T7, T12)
    #T14 = fd.ops.cast(T12, dtype=DataType.Half)
    #T15 = fd.ops.cast(T12, dtype=DataType.Float)
    #fd.add_output(T1)
    #fd.add_output(T14)
    fd.add_output(T13)

fs = Fusion()

with FusionDefinition(fs) as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    #torch.randn(64, 12, 128, 128, device='cuda', dtype=torch.float16),
    torch.randn(8192, 30522, device='cuda', dtype=torch.float16),
]

for _ in range(5):
    out = fs.execute(inputs)

kevinstephano · Oct 19 '22 08:10

@kevinstephano How exactly did you modify the fusion? Commenting out the line for T13 wouldn't work as T13 is added as an output.

naoyam · Oct 19 '22 17:10

I meant to show that this version has much better performance without the sub at the end of the fusion:

import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1], contiguous=[True, True], dtype=DataType.Half)
    #T1 = fd.ops.view(T0, original_shape=[8192, 30522], new_shape=[64, 128, 30522])
    #T2 = fd.ops.view(T1, original_shape=[64, 128, 30522], new_shape=[-1, 30522])
    T3 = fd.ops.cast(T0, dtype=DataType.Float)
    T4 = fd.ops.max(T3, axes=[1], keepdim=False, dtype=DataType.Null)
    T5 = fd.ops.broadcast_in_dim(T4, output_shape=[8192, 1], broadcast_dims=[0])
    T6 = fd.ops.broadcast_in_dim(T5, output_shape=[8192, 30522], broadcast_dims=[0, 1])
    T7 = fd.ops.sub(T3, T6)
    T8 = fd.ops.exp(T7)
    T9 = fd.ops.sum(T8, axes=[1], keepdim=False, dtype=DataType.Null)
    T10 = fd.ops.broadcast_in_dim(T9, output_shape=[8192, 1], broadcast_dims=[0])
    T11 = fd.ops.log(T10)
    T12 = fd.ops.broadcast_in_dim(T11, output_shape=[8192, 30522], broadcast_dims=[0, 1])
    #T13 = fd.ops.sub(T7, T12)
    fd.add_output(T12)

fs = Fusion()

with FusionDefinition(fs) as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    #torch.randn(64, 12, 128, 128, device='cuda', dtype=torch.float16),
    torch.randn(8192, 30522, device='cuda', dtype=torch.float16),
]

for _ in range(5):
    out = fs.execute(inputs)

Whereas this version is significantly worse:

import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1], contiguous=[True, True], dtype=DataType.Half)
    #T1 = fd.ops.view(T0, original_shape=[8192, 30522], new_shape=[64, 128, 30522])
    #T2 = fd.ops.view(T1, original_shape=[64, 128, 30522], new_shape=[-1, 30522])
    T3 = fd.ops.cast(T0, dtype=DataType.Float)
    T4 = fd.ops.max(T3, axes=[1], keepdim=False, dtype=DataType.Null)
    T5 = fd.ops.broadcast_in_dim(T4, output_shape=[8192, 1], broadcast_dims=[0])
    T6 = fd.ops.broadcast_in_dim(T5, output_shape=[8192, 30522], broadcast_dims=[0, 1])
    T7 = fd.ops.sub(T3, T6)
    T8 = fd.ops.exp(T7)
    T9 = fd.ops.sum(T8, axes=[1], keepdim=False, dtype=DataType.Null)
    T10 = fd.ops.broadcast_in_dim(T9, output_shape=[8192, 1], broadcast_dims=[0])
    T11 = fd.ops.log(T10)
    T12 = fd.ops.broadcast_in_dim(T11, output_shape=[8192, 30522], broadcast_dims=[0, 1])
    T13 = fd.ops.sub(T7, T12)
    fd.add_output(T13)

fs = Fusion()

with FusionDefinition(fs) as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    #torch.randn(64, 12, 128, 128, device='cuda', dtype=torch.float16),
    torch.randn(8192, 30522, device='cuda', dtype=torch.float16),
]

for _ in range(5):
    out = fs.execute(inputs)
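
As a rough way to compare the two variants end to end, a wall-clock timing sketch (fs_fast and fs_slow are placeholder names for the two Fusion objects above; this uses torch.cuda.Event rather than nvFuser's internal kernel timer, so the numbers will not match the dump exactly):

import torch

def time_fusion(fs, inputs, iters=20):
    # Warm up so compilation is excluded from the measurement.
    for _ in range(5):
        fs.execute(inputs)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fs.execute(inputs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per execute

# print(time_fusion(fs_fast, inputs), time_fusion(fs_slow, inputs))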

kevinstephano · Oct 26 '22 17:10