
Worse performance than ATen: aten.any.default


🐛 Describe the bug

aten.any

aten.any is used in AllenaiLongformerBase, BartForConditionalGeneration, BlenderbotSmallForConditionalGeneration, M2M100ForConditionalGeneration, MBartForConditionalGeneration, PLBartForConditionalGeneration, PegasusForConditionalGeneration.

Here are the results compared to ATen (values below 1.0 mean nvFuser is slower):

| benchmark | geomean | 20th percentile | 50th percentile | 80th percentile |
|---|---|---|---|---|
| HuggingFace | 0.92 | 0.85 | 0.93 | 0.98 |
| Torchbench | 0.96 | 0.94 | 0.96 | 0.98 |

Both the ATen and nvFuser paths use CUDA Graphs.

```bash
git clone https://gitlab-master.nvidia.com/iyashchuk/aten_ops_perf.git
cd aten_ops_perf
python aten_ops_perf.py --suite huggingface --dtype float32 --max-samples 100 --op aten.any.default
```
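
For a quick stand-alone check (a minimal sketch, not the benchmark repo above; `time_any` is a made-up helper), eager-mode `torch.any`, which dispatches to ATen's kernel, can be timed with CUDA events on the shapes from the samples below:

```python
import torch

def time_any(shape, iters=100):
    # Times eager-mode torch.any, which dispatches to ATen's CUDA
    # logical-OR reduction kernel.
    x = torch.zeros(shape, dtype=torch.bool, device="cuda")
    torch.any(x)  # warm-up
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        torch.any(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

# Shapes taken from the input samples in the logs below.
for shape in [(4, 128, 1024), (64, 128, 512), (2, 1024, 1024), (1024,)]:
    print(shape, f"{time_any(shape):.4f} ms")
```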

Here are the logs with `PYTORCH_NVFUSER_DUMP=ptxas_verbose,launch_param,scheduler_params` set (`dump_eff_bandwidth` doesn't work because of CUDA Graphs). Assuming nvFuser picks the dump options up from the environment as usual, the invocation would look like the sketch below, with the logs following it:
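
```bash
# Assumed invocation: the dump options combined with the repro command above.
PYTORCH_NVFUSER_DUMP=ptxas_verbose,launch_param,scheduler_params \
  python aten_ops_perf.py --suite huggingface --dtype float32 --max-samples 100 --op aten.any.default
```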

HuggingFace
Input sample #0
arg 0: torch.Size([4, 128, 1024]), (131072, 1024, 1)
kwargs: {}

===== Reduction Stats ========
total_reduction_numel: 1 * 524288
total_iteration_numel: 1
vectorize_factor: 16
n_tensor_inputs: 1
max_input_dtype_size: 1
block(256, 1, 1)

===== Reduction Parameters ========

Red On Fastest Dim

Iteration Domain: blockIdx.y /
Inner Reduction Domain: cross block - threadIdx.x /  pad to warp / cross grid - blockIdx.x / split grid dim / split grid dimension / vectorize / factor 16
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 8, GridDim.y = -1, GridDim.z = -1, Smem Size = 0

====================================


===== Reduction Stats ========
total_reduction_numel: 1 * 524288
total_iteration_numel: 1
vectorize_factor: 16
n_tensor_inputs: 1
max_input_dtype_size: 1
block(256, 1, 1)

===== Reduction Parameters ========

Red On Fastest Dim

Iteration Domain: blockIdx.y /
Inner Reduction Domain: cross block - threadIdx.x /  pad to warp / cross grid - blockIdx.x / split grid dim / split grid dimension / vectorize / factor 16
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 8, GridDim.y = -1, GridDim.z = -1, Smem Size = 0

====================================

ptxas info    : 3 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel1ENS_6TensorIbLi3EEENS0_IbLi0EEENS0_IxLi1EEES3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel1ENS_6TensorIbLi3EEENS0_IbLi0EEENS0_IxLi1EEES3_
ptxas         .     0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 22 registers, 16 bytes smem, 424 bytes cmem[0], 8 bytes cmem[2]

Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 8, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 8, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 8, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 8, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 8, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Input sample #1
arg 0: torch.Size([64, 128, 512]), (65536, 512, 1)
kwargs: {}

===== Reduction Stats ========
total_reduction_numel: 1 * 4194304
total_iteration_numel: 1
vectorize_factor: 16
n_tensor_inputs: 1
max_input_dtype_size: 1
block(512, 1, 1)

===== Reduction Parameters ========

Red On Fastest Dim

Iteration Domain: blockIdx.y /
Inner Reduction Domain: cross block - threadIdx.x /  pad to warp / cross grid - blockIdx.x / split grid dim / split grid dimension / vectorize / factor 16
Launch Parameters: BlockDim.x = 512, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 32, GridDim.y = -1, GridDim.z = -1, Smem Size = 0

====================================

ptxas info    : 3 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel1ENS_6TensorIbLi3EEENS0_IbLi0EEENS0_IxLi1EEES3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel1ENS_6TensorIbLi3EEENS0_IbLi0EEENS0_IxLi1EEES3_
ptxas         .     0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 22 registers, 16 bytes smem, 424 bytes cmem[0], 8 bytes cmem[2]

Launch Parameters: BlockDim.x = 512, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 32, GridDim.y = -1, GridDim.z = -1, Smem Size = 4096
Launch Parameters: BlockDim.x = 512, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 32, GridDim.y = -1, GridDim.z = -1, Smem Size = 4096
Launch Parameters: BlockDim.x = 512, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 32, GridDim.y = -1, GridDim.z = -1, Smem Size = 4096
Launch Parameters: BlockDim.x = 512, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 32, GridDim.y = -1, GridDim.z = -1, Smem Size = 4096
Launch Parameters: BlockDim.x = 512, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 32, GridDim.y = -1, GridDim.z = -1, Smem Size = 4096
Input sample #2
arg 0: torch.Size([8, 128, 1024]), (131072, 1024, 1)
kwargs: {}

===== Reduction Stats ========
total_reduction_numel: 1 * 1048576
total_iteration_numel: 1
vectorize_factor: 16
n_tensor_inputs: 1
max_input_dtype_size: 1
block(256, 1, 1)

===== Reduction Parameters ========

Red On Fastest Dim

Iteration Domain: blockIdx.y /
Inner Reduction Domain: cross block - threadIdx.x /  pad to warp / cross grid - blockIdx.x / split grid dim / split grid dimension / vectorize / factor 16
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 16, GridDim.y = -1, GridDim.z = -1, Smem Size = 0

====================================

Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 16, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 16, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 16, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 16, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 16, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Input sample #3
arg 0: torch.Size([2, 1024, 1024]), (1048576, 1024, 1)
kwargs: {}

===== Reduction Stats ========
total_reduction_numel: 1 * 2097152
total_iteration_numel: 1
vectorize_factor: 16
n_tensor_inputs: 1
max_input_dtype_size: 1
block(320, 1, 1)

===== Reduction Parameters ========

Red On Fastest Dim

Iteration Domain: blockIdx.y /
Inner Reduction Domain: cross block - threadIdx.x /  pad to warp / cross grid - blockIdx.x / split grid dim / split grid dimension / vectorize / factor 16
Launch Parameters: BlockDim.x = 320, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 27, GridDim.y = -1, GridDim.z = -1, Smem Size = 0

====================================

Launch Parameters: BlockDim.x = 320, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 27, GridDim.y = -1, GridDim.z = -1, Smem Size = 2560
Launch Parameters: BlockDim.x = 320, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 27, GridDim.y = -1, GridDim.z = -1, Smem Size = 2560
Launch Parameters: BlockDim.x = 320, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 27, GridDim.y = -1, GridDim.z = -1, Smem Size = 2560
Launch Parameters: BlockDim.x = 320, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 27, GridDim.y = -1, GridDim.z = -1, Smem Size = 2560
Launch Parameters: BlockDim.x = 320, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 27, GridDim.y = -1, GridDim.z = -1, Smem Size = 2560
Input sample #4
arg 0: torch.Size([1024]), (1,)
kwargs: {}

===== Reduction Stats ========
total_reduction_numel: 1 * 1024
total_iteration_numel: 1
vectorize_factor: 16
n_tensor_inputs: 1
max_input_dtype_size: 1
block(16, 1, 1)

===== Reduction Parameters ========

Red On Fastest Dim

Iteration Domain: blockIdx.y /
Inner Reduction Domain: cross block - threadIdx.x / cross grid - blockIdx.x / split grid dim / split grid dimension / vectorize / factor 16
Launch Parameters: BlockDim.x = 16, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4, GridDim.y = -1, GridDim.z = -1, Smem Size = 0

====================================


===== Reduction Stats ========
total_reduction_numel: 1 * 1024
total_iteration_numel: 1
vectorize_factor: 16
n_tensor_inputs: 1
max_input_dtype_size: 1
block(16, 1, 1)

===== Reduction Parameters ========

Red On Fastest Dim

Iteration Domain: blockIdx.y /
Inner Reduction Domain: cross block - threadIdx.x / cross grid - blockIdx.x / split grid dim / split grid dimension / vectorize / factor 16
Launch Parameters: BlockDim.x = 16, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4, GridDim.y = -1, GridDim.z = -1, Smem Size = 0

====================================

ptxas info    : 3 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel2ENS_6TensorIbLi1EEENS0_IbLi0EEENS0_IxLi1EEES3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel2ENS_6TensorIbLi1EEENS0_IbLi0EEENS0_IxLi1EEES3_
ptxas         .     0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 22 registers, 16 bytes smem, 408 bytes cmem[0], 8 bytes cmem[2]

Launch Parameters: BlockDim.x = 16, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4, GridDim.y = -1, GridDim.z = -1, Smem Size = 128
Launch Parameters: BlockDim.x = 16, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4, GridDim.y = -1, GridDim.z = -1, Smem Size = 128
Launch Parameters: BlockDim.x = 16, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4, GridDim.y = -1, GridDim.z = -1, Smem Size = 128
Launch Parameters: BlockDim.x = 16, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4, GridDim.y = -1, GridDim.z = -1, Smem Size = 128
Launch Parameters: BlockDim.x = 16, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4, GridDim.y = -1, GridDim.z = -1, Smem Size = 128
Input sample #5
arg 0: torch.Size([8, 128, 768]), (98304, 768, 1)
kwargs: {}

===== Reduction Stats ========
total_reduction_numel: 1 * 786432
total_iteration_numel: 1
vectorize_factor: 16
n_tensor_inputs: 1
max_input_dtype_size: 1
block(256, 1, 1)

===== Reduction Parameters ========

Red On Fastest Dim

Iteration Domain: blockIdx.y /
Inner Reduction Domain: cross block - threadIdx.x /  pad to warp / cross grid - blockIdx.x / split grid dim / split grid dimension / vectorize / factor 16
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 12, GridDim.y = -1, GridDim.z = -1, Smem Size = 0

====================================

Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 12, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 12, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 12, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 12, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 12, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Input sample #6
arg 0: torch.Size([2, 128, 1024]), (131072, 1024, 1)
kwargs: {}

===== Reduction Stats ========
total_reduction_numel: 1 * 262144
total_iteration_numel: 1
vectorize_factor: 16
n_tensor_inputs: 1
max_input_dtype_size: 1
block(128, 1, 1)

===== Reduction Parameters ========

Red On Fastest Dim

Iteration Domain: blockIdx.y /
Inner Reduction Domain: cross block - threadIdx.x /  pad to warp / cross grid - blockIdx.x / split grid dim / split grid dimension / vectorize / factor 16
Launch Parameters: BlockDim.x = 128, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 16, GridDim.y = -1, GridDim.z = -1, Smem Size = 0

====================================

Launch Parameters: BlockDim.x = 128, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 16, GridDim.y = -1, GridDim.z = -1, Smem Size = 1024
Launch Parameters: BlockDim.x = 128, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 16, GridDim.y = -1, GridDim.z = -1, Smem Size = 1024
Launch Parameters: BlockDim.x = 128, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 16, GridDim.y = -1, GridDim.z = -1, Smem Size = 1024
Launch Parameters: BlockDim.x = 128, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 16, GridDim.y = -1, GridDim.z = -1, Smem Size = 1024
Launch Parameters: BlockDim.x = 128, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 16, GridDim.y = -1, GridDim.z = -1, Smem Size = 1024
Torchbench
Input sample #0
arg 0: torch.Size([4, 512, 768]), (393216, 768, 1)
kwargs: {}

===== Reduction Stats ========
total_reduction_numel: 1 * 1572864
total_iteration_numel: 1
vectorize_factor: 16
n_tensor_inputs: 1
max_input_dtype_size: 1
block(256, 1, 1)

===== Reduction Parameters ========

Red On Fastest Dim

Iteration Domain: blockIdx.y /
Inner Reduction Domain: cross block - threadIdx.x /  pad to warp / cross grid - blockIdx.x / split grid dim / split grid dimension / vectorize / factor 16
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 24, GridDim.y = -1, GridDim.z = -1, Smem Size = 0

====================================


===== Reduction Stats ========
total_reduction_numel: 1 * 1572864
total_iteration_numel: 1
vectorize_factor: 16
n_tensor_inputs: 1
max_input_dtype_size: 1
block(256, 1, 1)

===== Reduction Parameters ========

Red On Fastest Dim

Iteration Domain: blockIdx.y /
Inner Reduction Domain: cross block - threadIdx.x /  pad to warp / cross grid - blockIdx.x / split grid dim / split grid dimension / vectorize / factor 16
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 24, GridDim.y = -1, GridDim.z = -1, Smem Size = 0

====================================

ptxas info    : 3 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel1ENS_6TensorIbLi3EEENS0_IbLi0EEENS0_IxLi1EEES3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel1ENS_6TensorIbLi3EEENS0_IbLi0EEENS0_IxLi1EEES3_
ptxas         .     0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 22 registers, 16 bytes smem, 424 bytes cmem[0], 8 bytes cmem[2]

Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 24, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 24, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 24, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 24, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Launch Parameters: BlockDim.x = 256, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 24, GridDim.y = -1, GridDim.z = -1, Smem Size = 2048
Input sample #1
arg 0: torch.Size([2048]), (1,)
kwargs: {}

===== Reduction Stats ========
total_reduction_numel: 1 * 2048
total_iteration_numel: 1
vectorize_factor: 16
n_tensor_inputs: 1
max_input_dtype_size: 1
block(16, 1, 1)

===== Reduction Parameters ========

Red On Fastest Dim

Iteration Domain: blockIdx.y /
Inner Reduction Domain: cross block - threadIdx.x / cross grid - blockIdx.x / split grid dim / split grid dimension / vectorize / factor 16
Launch Parameters: BlockDim.x = 16, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 8, GridDim.y = -1, GridDim.z = -1, Smem Size = 0

====================================


===== Reduction Stats ========
total_reduction_numel: 1 * 2048
total_iteration_numel: 1
vectorize_factor: 16
n_tensor_inputs: 1
max_input_dtype_size: 1
block(16, 1, 1)

===== Reduction Parameters ========

Red On Fastest Dim

Iteration Domain: blockIdx.y /
Inner Reduction Domain: cross block - threadIdx.x / cross grid - blockIdx.x / split grid dim / split grid dimension / vectorize / factor 16
Launch Parameters: BlockDim.x = 16, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 8, GridDim.y = -1, GridDim.z = -1, Smem Size = 0

====================================

ptxas info    : 3 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel2ENS_6TensorIbLi1EEENS0_IbLi0EEENS0_IxLi1EEES3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel2ENS_6TensorIbLi1EEENS0_IbLi0EEENS0_IxLi1EEES3_
ptxas         .     0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 22 registers, 16 bytes smem, 408 bytes cmem[0], 8 bytes cmem[2]

Launch Parameters: BlockDim.x = 16, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 8, GridDim.y = -1, GridDim.z = -1, Smem Size = 128
Launch Parameters: BlockDim.x = 16, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 8, GridDim.y = -1, GridDim.z = -1, Smem Size = 128
Launch Parameters: BlockDim.x = 16, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 8, GridDim.y = -1, GridDim.z = -1, Smem Size = 128
Launch Parameters: BlockDim.x = 16, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 8, GridDim.y = -1, GridDim.z = -1, Smem Size = 128
Launch Parameters: BlockDim.x = 16, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 8, GridDim.y = -1, GridDim.z = -1, Smem Size = 128

`any` is implemented as `ne(sum(input), False)` here: https://github.com/pytorch/pytorch/blob/f884e817d448228cb8b0685f774ede1d8207ff72/torch/_refs/__init__.py#L2055-L2069

ATen, by contrast, implements it as a reduction kernel with logical OR here: https://github.com/pytorch/pytorch/blob/f884e817d448228cb8b0685f774ede1d8207ff72/aten/src/ATen/native/cuda/ReduceLogicKernel.cu#L28-L29
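
For illustration, a minimal sketch of the difference (the `any_via_sum` helper is made up, and the simplified decomposition is an assumption based on the lines linked above):

```python
import torch

def any_via_sum(a: torch.Tensor) -> torch.Tensor:
    # Simplified form of the torch._refs decomposition linked above:
    # cast to bool, sum (promoting to an integer accumulator), then
    # compare the sum against False.
    return torch.ne(torch.sum(a.to(torch.bool)), False)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.zeros(4, 128, 1024, dtype=torch.bool, device=device)  # shape from sample #0

# Eager torch.any dispatches to ATen's logical-OR reduction kernel;
# the decomposition reaches the same answer via an integer sum.
assert bool(any_via_sum(x)) == bool(torch.any(x)) == False
x[1, 7, 3] = True
assert bool(any_via_sum(x)) == bool(torch.any(x)) == True
```

Because the sum promotes the one-byte bool input to a wider integer accumulator, nvFuser ends up scheduling an integer sum-reduction where ATen performs a single-byte OR-reduction, which plausibly accounts for the gap above.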

Versions

Checked on upstream master.

IvanYashchuk, Nov 03 '22