`views` in HF-Bart Self-Attention
🐛 Describe the bug
I have a horizontal fusion situation with reshapes, and I would like to understand whether it can be fused. I think we have a knob to turn this on, or a place to switch it; Jie might know. It would be good if this could be a single kernel.
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id27(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float, is_cpu=False)
    T1 = fd.define_tensor(symbolic_sizes=[-1, -1, -1, -1], contiguous=[True, True, True, True], dtype=DataType.Float, is_cpu=False)
    T2 = fd.define_tensor(symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float, is_cpu=False)
    T3 = fd.ops.reshape(T2, original_shape=[8, 1024, 1024], new_shape=[8, 1024, 16, 64])
    T4 = fd.ops.permute(T3, dims=[0, 2, 1, 3])
    T5 = fd.ops.reshape(T0, original_shape=[8, 1024, 1024], new_shape=[8, 1024, 16, 64])
    T6 = fd.ops.permute(T5, dims=[0, 2, 1, 3])
    T7 = fd.ops.reshape(T6, original_shape=[8, 16, 1024, 64], new_shape=[128, 1024, 64])
    T8 = fd.ops.reshape(T1, original_shape=[8, 16, 1024, 64], new_shape=[128, 1024, 64])
    T9 = fd.ops.reshape(T4, original_shape=[8, 16, 1024, 64], new_shape=[128, 1024, 64])
    T10 = fd.ops.permute(T8, dims=[0, 2, 1])
    fd.add_output(T9)
    fd.add_output(T7)
    fd.add_output(T10)

inputs = [
    torch.randn(8, 1024, 1024, device='cuda'),
    torch.randn(8, 1024, 16, 64, device='cuda'),
    torch.randn(8, 1024, 1024, device='cuda'),
]

with FusionDefinition() as fd:
    nvfuser_fusion_id27(fd)

for _ in range(5):
    out = fd.execute(inputs)
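For reference, the three outputs are just the usual HF-Bart-style per-head Q/K/V layout changes. Here is a minimal eager-mode PyTorch sketch of the same reshape/permute pattern; the function name and default sizes are illustrative only and not part of the repro above.

import torch

def eager_reference(t0, t1, t2, bsz=8, seq=1024, heads=16, head_dim=64):
    # T2 -> T3 -> T4 -> T9: [bsz, seq, heads*head_dim] -> [bsz*heads, seq, head_dim]
    t9 = t2.reshape(bsz, seq, heads, head_dim).permute(0, 2, 1, 3).reshape(bsz * heads, seq, head_dim)
    # T0 -> T5 -> T6 -> T7: the same layout change for the second projection
    t7 = t0.reshape(bsz, seq, heads, head_dim).permute(0, 2, 1, 3).reshape(bsz * heads, seq, head_dim)
    # T1 -> T8 -> T10: flatten to [bsz*heads, seq, head_dim], then swap the last two dims
    t10 = t1.reshape(bsz * heads, seq, head_dim).permute(0, 2, 1)
    return t9, t7, t10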
The second case looks okay. Could you just double-check that it is okay with FP16 inputs?
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id21(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float, is_cpu=False)
    T1 = fd.define_tensor(symbolic_sizes=[-1, 1, 1, -1], contiguous=[True, True, True, True], dtype=DataType.Float, is_cpu=False)
    T2 = fd.ops.reshape(T0, original_shape=[128, 1024, 1024], new_shape=[8, 16, 1024, 1024])
    T3 = fd.ops.broadcast_in_dim(T1, output_shape=[8, 16, 1024, 1024], broadcast_dims=[0, 1, 2, 3])
    T4 = fd.ops.add(T2, T3)
    T5 = fd.ops.reshape(T4, original_shape=[8, 16, 1024, 1024], new_shape=[128, 1024, 1024])
    T6 = fd.ops.max(T5, axes=[2], keepdim=False, dtype=DataType.Null)
    T7 = fd.ops.broadcast_in_dim(T6, output_shape=[128, 1024, 1], broadcast_dims=[0, 1])
    T8 = fd.ops.broadcast_in_dim(T7, output_shape=[128, 1024, 1024], broadcast_dims=[0, 1, 2])
    T9 = fd.ops.sub(T5, T8)
    T10 = fd.ops.exp(T9)
    T11 = fd.ops.sum(T10, axes=[2], keepdim=False, dtype=DataType.Null)
    T12 = fd.ops.broadcast_in_dim(T11, output_shape=[128, 1024, 1], broadcast_dims=[0, 1])
    T13 = fd.ops.broadcast_in_dim(T12, output_shape=[128, 1024, 1024], broadcast_dims=[0, 1, 2])
    T14 = fd.ops.div(T10, T13)
    fd.add_output(T14)

inputs = [
    torch.randn(128, 1024, 1024, device='cuda'),
    torch.randn(8, 1, 1, 1024, device='cuda'),
]

with FusionDefinition() as fd:
    nvfuser_fusion_id21(fd)

for _ in range(5):
    out = fd.execute(inputs)
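For the FP16 question, here is a sketch of how one might run the same definition with half-precision inputs: the inputs are declared as Half, the softmax math is kept in Float via explicit casts, and the result is cast back to Half. The cast placement and the function name are my assumptions for illustration, not taken from an actual FP16 trace.

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id21_fp16(fd: FusionDefinition) -> None:
    # Same body as nvfuser_fusion_id21 above, but with Half inputs/outputs
    # and the intermediate math in Float.
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Half, is_cpu=False)
    T1 = fd.define_tensor(symbolic_sizes=[-1, 1, 1, -1], contiguous=[True, True, True, True], dtype=DataType.Half, is_cpu=False)
    T0f = fd.ops.cast(T0, dtype=DataType.Float)
    T1f = fd.ops.cast(T1, dtype=DataType.Float)
    T2 = fd.ops.reshape(T0f, original_shape=[128, 1024, 1024], new_shape=[8, 16, 1024, 1024])
    T3 = fd.ops.broadcast_in_dim(T1f, output_shape=[8, 16, 1024, 1024], broadcast_dims=[0, 1, 2, 3])
    T4 = fd.ops.add(T2, T3)
    T5 = fd.ops.reshape(T4, original_shape=[8, 16, 1024, 1024], new_shape=[128, 1024, 1024])
    T6 = fd.ops.max(T5, axes=[2], keepdim=False, dtype=DataType.Null)
    T7 = fd.ops.broadcast_in_dim(T6, output_shape=[128, 1024, 1], broadcast_dims=[0, 1])
    T8 = fd.ops.broadcast_in_dim(T7, output_shape=[128, 1024, 1024], broadcast_dims=[0, 1, 2])
    T9 = fd.ops.sub(T5, T8)
    T10 = fd.ops.exp(T9)
    T11 = fd.ops.sum(T10, axes=[2], keepdim=False, dtype=DataType.Null)
    T12 = fd.ops.broadcast_in_dim(T11, output_shape=[128, 1024, 1], broadcast_dims=[0, 1])
    T13 = fd.ops.broadcast_in_dim(T12, output_shape=[128, 1024, 1024], broadcast_dims=[0, 1, 2])
    T14 = fd.ops.div(T10, T13)
    T15 = fd.ops.cast(T14, dtype=DataType.Half)
    fd.add_output(T15)

fp16_inputs = [
    torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float16),
    torch.randn(8, 1, 1, 1024, device='cuda', dtype=torch.float16),
]

with FusionDefinition() as fd:
    nvfuser_fusion_id21_fp16(fd)

for _ in range(5):
    out = fd.execute(fp16_inputs)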
Versions
TOT (top of tree)
This is a good test case. I think I know where the heuristic fails. This is probably related to https://github.com/csarofeen/pytorch/pull/2455
Despite having two reshapes, the second case produces a single kernel with either float or half inputs. I'm not sure how that is happening, since it matches the "comment out C" pattern from https://github.com/csarofeen/pytorch/issues/2090#issuecomment-1398665847.
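For comparison, the second case is just a masked softmax over the last dimension. A minimal eager-mode PyTorch sketch of what it computes (names are illustrative, not part of the repro):

import torch

def eager_masked_softmax(t0: torch.Tensor, t1: torch.Tensor) -> torch.Tensor:
    # [128, 1024, 1024] -> [8, 16, 1024, 1024], add the broadcast mask (t1 is [8, 1, 1, 1024]),
    # then flatten back and take a numerically stable softmax over dim 2.
    scores = t0.reshape(8, 16, 1024, 1024) + t1
    scores = scores.reshape(128, 1024, 1024)
    scores = scores - scores.max(dim=2, keepdim=True).values
    probs = scores.exp()
    return probs / probs.sum(dim=2, keepdim=True)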
In the first case the segmenter is refusing to merge across the three connected components. I'm not sure this is due to the reshapes; I don't think the segmenter ever merges disconnected components (see this comment: https://github.com/csarofeen/pytorch/blob/devel/third_party/nvfuser/csrc/fusion_segmenter.cpp#L3275-L3279). For this particular case, since the three groups are independent, wouldn't three kernels actually be preferable?