[ROCM] SDXL UNet and VAE failures: tiled+decomposed attention dispatches exceed the shared memory limit.
What happened?
Dispatch benchmark IR for an easy repro (compiled_unet_main_dispatch_91_rocm_hsaco_fb_benchmark.mlir):
module attributes {hal.device.targets = [#hal.device.target<"rocm", {executable_targets = [#hal.executable.target<"rocm", "rocm-hsaco-fb", {target_arch = "gfx1100", ukernels = "none"}>], legacy_sync}>]} {
hal.executable private @main_dispatch_91 {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {target_arch = "gfx1100", ukernels = "none"}>) {
hal.executable.export public @main_dispatch_91_attention_20x4096x64xf16 ordinal(0) layout(#hal.pipeline.layout<push_constants = 4, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) attributes {subgroup_size = 32 : index, translation_info = #iree_codegen.translation_info<LLVMGPUDistribute>, workgroup_size = [64 : index, 1 : index, 1 : index]} {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @main_dispatch_91_attention_20x4096x64xf16() {
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = arith.index_castui %0 : i32 to index
%5 = arith.index_castui %1 : i32 to index
%6 = arith.index_castui %2 : i32 to index
%7 = arith.index_castui %3 : i32 to index
%8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%4) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
%9 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
%11 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%7) : !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
%12 = flow.dispatch.tensor.load %8, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
%13 = flow.dispatch.tensor.load %9, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
%14 = flow.dispatch.tensor.load %10, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
%15 = tensor.empty() : tensor<20x4096x64xf16>
%16 = iree_linalg_ext.attention {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 64]]>} ins(%12, %13, %14 : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>) outs(%15 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
flow.dispatch.tensor.store %16, %11, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : tensor<20x4096x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
return
}
}
}
}
util.global private mutable @main_dispatch_91_rocm_hsaco_fb_main_dispatch_91_attention_20x4096x64xf16_buffer : !hal.buffer
util.initializer {
%c839971840 = arith.constant 839971840 : index
%c-1_i64 = arith.constant -1 : i64
%c0 = arith.constant 0 : index
%device_0 = hal.devices.get %c0 : !hal.device
%allocator = hal.device.allocator<%device_0 : !hal.device> : !hal.allocator
%buffer = hal.allocator.allocate<%allocator : !hal.allocator> affinity(%c-1_i64) type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage") : !hal.buffer{%c839971840}
util.global.store %buffer, @main_dispatch_91_rocm_hsaco_fb_main_dispatch_91_attention_20x4096x64xf16_buffer : !hal.buffer
util.return
}
util.func public @main_dispatch_91_rocm_hsaco_fb_main_dispatch_91_attention_20x4096x64xf16(%arg0: i32) attributes {iree.abi.stub, iree.reflection = {iree.benchmark = "dispatch"}} {
%c-1_i32 = arith.constant -1 : i32
%c-1_i64 = arith.constant -1 : i64
%c419985920 = arith.constant 419985920 : index
%c1 = arith.constant 1 : index
%c419985728 = arith.constant 419985728 : index
%c94927168_i32 = arith.constant 94927168 : i32
%c136870208_i32 = arith.constant 136870208 : i32
%c126384448_i32 = arith.constant 126384448 : i32
%c84441408_i32 = arith.constant 84441408 : i32
%c0 = arith.constant 0 : index
%0 = arith.index_cast %arg0 : i32 to index
%device_0 = hal.devices.get %c0 : !hal.device
%cmd = hal.command_buffer.create device(%device_0 : !hal.device) mode("OneShot|AllowInlineExecution") categories(Dispatch) : !hal.command_buffer
%pipeline_layout = hal.pipeline_layout.lookup device(%device_0 : !hal.device) layout(<push_constants = 4, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) : !hal.pipeline_layout
hal.command_buffer.push_constants<%cmd : !hal.command_buffer> layout(%pipeline_layout : !hal.pipeline_layout) offset(0) values([%c84441408_i32, %c126384448_i32, %c136870208_i32, %c94927168_i32]) : i32, i32, i32, i32
%main_dispatch_91_rocm_hsaco_fb_main_dispatch_91_attention_20x4096x64xf16_buffer = util.global.load @main_dispatch_91_rocm_hsaco_fb_main_dispatch_91_attention_20x4096x64xf16_buffer : !hal.buffer
hal.command_buffer.push_descriptor_set<%cmd : !hal.command_buffer> layout(%pipeline_layout : !hal.pipeline_layout)[%c0] bindings([
%c0 = (%main_dispatch_91_rocm_hsaco_fb_main_dispatch_91_attention_20x4096x64xf16_buffer : !hal.buffer)[%c0, %c419985728],
%c1 = (%main_dispatch_91_rocm_hsaco_fb_main_dispatch_91_attention_20x4096x64xf16_buffer : !hal.buffer)[%c419985920, %c419985728]
])
%workgroup_x, %workgroup_y, %workgroup_z = hal.executable.calculate_workgroups device(%device_0 : !hal.device) target(@main_dispatch_91::@rocm_hsaco_fb::@main_dispatch_91_attention_20x4096x64xf16) : index, index, index
%exe = hal.executable.lookup device(%device_0 : !hal.device) executable(@main_dispatch_91) : !hal.executable
%ordinal = hal.executable.export.ordinal target(@main_dispatch_91::@rocm_hsaco_fb::@main_dispatch_91_attention_20x4096x64xf16) : index
scf.for %arg1 = %c0 to %0 step %c1 {
hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe : !hal.executable)[%ordinal] workgroups([%workgroup_x, %workgroup_y, %workgroup_z])
hal.command_buffer.execution_barrier<%cmd : !hal.command_buffer> source("Dispatch|CommandRetire") target("CommandIssue|Dispatch") flags("None")
}
hal.command_buffer.finalize<%cmd : !hal.command_buffer>
%1 = util.null : !hal.fence
%fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
hal.device.queue.execute<%device_0 : !hal.device> affinity(%c-1_i64) wait(%1) signal(%fence) commands([%cmd])
%status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) : i32
util.status.check_ok %status, "failed to wait on timepoint"
util.return
}
}
Compiles successfully with this command:
iree-compile compiled_unet_main_dispatch_91_rocm_hsaco_fb_benchmark.mlir --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --iree-rocm-link-bc --iree-rocm-bc-dir=C:\AMD\ROCm\5.5\amdgcn\bitcode --mlir-print-debuginfo=false --mlir-print-op-on-diagnostic=false --iree-rocm-target-chip=gfx1100 --iree-opt-const-expr-hoisting=False --iree-codegen-linalg-max-constant-fold-elements=9223372036854775807 --iree-opt-strip-assertions=true --verify=false -o compiled_unet_main_dispatch_91_rocm_hsaco_fb_benchmark.vmfb
At runtime it fails because the requested shared memory allocation exceeds the allowed limit:
iree-benchmark-module --device=rocm --module=compiled_unet_main_dispatch_91_rocm_hsaco_fb_benchmark.vmfb --parameters=model=./stable_diffusion_xl_base_1_0_unet.safetensors --function=main
C:\V\iree\experimental\rocm\native_executable.c:137: INVALID_ARGUMENT; function 'main_dispatch_91_attention_20x4096x64xf16' requested shared memory size of 66304 larger than allowed size of 65536; while invoking native function hal.executable.create; while calling import;
[ 1] native hal.executable.create:0 -
[ 0] bytecode module.__init:298 C:\V\SHARK-Turbine\sdxl_attn\unet_rocm_dispatches\compiled_unet_main_dispatch_91_rocm_hsaco_fb_benchmark.mlir:40:18
The error is caught at compile time if --iree-llvmgpu-shared-memory-limit=65536 is added to the iree-compile command.
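For completeness, here is the repro compile command from above with only that flag added; with it, the over-allocation (66304 bytes requested, 768 bytes over the 65536-byte limit reported above) shows up as a compile-time diagnostic instead of a failure at executable load:
iree-compile compiled_unet_main_dispatch_91_rocm_hsaco_fb_benchmark.mlir --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --iree-rocm-link-bc --iree-rocm-bc-dir=C:\AMD\ROCm\5.5\amdgcn\bitcode --mlir-print-debuginfo=false --mlir-print-op-on-diagnostic=false --iree-rocm-target-chip=gfx1100 --iree-opt-const-expr-hoisting=False --iree-codegen-linalg-max-constant-fold-elements=9223372036854775807 --iree-opt-strip-assertions=true --verify=false --iree-llvmgpu-shared-memory-limit=65536 -o compiled_unet_main_dispatch_91_rocm_hsaco_fb_benchmark.vmfb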
I suppose we just need to change the tile size for the decomposed attention op, but I'm not sure of a good way to approach this. Any help would be appreciated.
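To illustrate the kind of change I mean, below is a hypothetical one-line edit of the attention op in the reproducer IR above, shrinking the second tile size from 64 to 32. I'm assuming that entry drives the sequence tile of the decomposed attention and therefore the size of the buffers promoted to shared memory; I have not verified which value actually fits under the 64 KiB limit, so treat this as a sketch rather than a known-good config:
// hypothetical edit: only change vs. the reproducer IR is tile_sizes [[1, 64]] -> [[1, 32]]
%16 = iree_linalg_ext.attention {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 32]]>} ins(%12, %13, %14 : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>) outs(%15 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
Whether a hand-edited lowering_config like this is respected by the pipeline, and what tile actually fits, is exactly the guidance I'm hoping for.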
What component(s) does this issue relate to?
Compiler
Version information
Latest IREE (https://github.com/openxla/iree/commit/5d8907e82fc1eb741a4d4d27f5cae865323fd1d7)
(Notably enabled by https://github.com/openxla/iree/commit/946375cad71786462bcfd63dde6fe305d1e3b9ff)
A very similar issue occurs with the SDXL VAE:
python ..\models\turbine_models\custom_models\sdxl_inference\vae.py --compile_to=vmfb --external_weights=safetensors --device=rocm --variant="decode" --precision="fp16" --iree_target_triple=gfx1100 --external_weight_path=stable_diffusion_xl_base_1_0_vae.safetensors
C:\V\SHARK-Turbine\turb.env\Lib\site-packages\diffusers\utils\outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
torch.utils._pytree._register_pytree_node(
Traceback (most recent call last):
File "C:\V\SHARK-Turbine\models\turbine_models\custom_models\sdxl_inference\vae.py", line 156, in <module>
mod_str = export_vae_model(
^^^^^^^^^^^^^^^^^
File "C:\V\SHARK-Turbine\models\turbine_models\custom_models\sdxl_inference\vae.py", line 148, in export_vae_model
utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)
File "C:\V\SHARK-Turbine\models\turbine_models\custom_models\sd_inference\utils.py", line 99, in compile_to_vmfb
flatbuffer_blob = ireec.compile_str(
^^^^^^^^^^^^^^^^^^
File "C:\V\iree-build\compiler\bindings\python\iree\compiler\tools\core.py", line 299, in compile_str
result = invoke_immediate(cl, immediate_input=input_bytes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\V\iree-build\compiler\bindings\python\iree\compiler\tools\binaries.py", line 198, in invoke_immediate
raise CompilerToolError(process)
iree.compiler.tools.binaries.CompilerToolError: Error invoking IREE compiler tool iree-compile.exe
Error code: 1
Diagnostics:
<stdin>:638:14: error: 'func.func' op uses 221952 bytes of shared memory; exceeded the limit of 166912 bytes
%172:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%165, %168, %171, %float0.000000e00, %false_169, %none_170, %none_171) : (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384],f32>)
^
<stdin>:250:3: note: called from
func.func @main(%arg0: tensor<1x4x128x128xf16>) -> tensor<1x3x1024x1024xf16> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
^
<stdin>:638:14: error: Failures have been detected while processing an MLIR pass pipeline
%172:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%165, %168, %171, %float0.000000e00, %false_169, %none_170, %none_171) : (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384],f32>)
^
<stdin>:250:3: note: called from
func.func @main(%arg0: tensor<1x4x128x128xf16>) -> tensor<1x3x1024x1024xf16> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
^
<stdin>:638:14: note: Pipeline failed while executing [`TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_10, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_3, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_17, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_21, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_22, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_25, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_26, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `LLVMGPULowerExecutableTarget` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_27, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `LLVMGPULowerExecutableTarget` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_40, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `LLVMGPULowerExecutableTarget` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `LLVMGPULowerExecutableTarget` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_41, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_83, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `LLVMGPULowerExecutableTarget` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_84, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `LLVMGPULowerExecutableTarget` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `OptimizeTensorInsertExtractSlices` on 'func.func' operation: @main_dispatch_40_conv_2d_nchw_fchw_1x512x128x128x512x3x3_f16, `Canonicalizer` on 'builtin.module' operation, `GPUDistributeSharedMemoryCopy` on 'func.func' operation: @main_dispatch_27_generic_512x16384_f32xf16xf16xf16xf32, `Canonicalizer` on 'func.func' operation: @main_dispatch_17_conv_2d_nchw_fchw_1x512x128x128x512x3x3_f16, `IREEComprehensiveBufferize` on 'builtin.module' operation, `ConvertToROCDL` on 'builtin.module' operation, `FoldMemRefAliasOps` on 'func.func' operation: @main_dispatch_26_matmul_transpose_b_16384x512x512_f16xf16xf32, `ExtractAddressComputationGPU` on 'builtin.module' operation, `Canonicalizer` on 'func.func' operation: @main_dispatch_21_generic_32x16x16384_f16xf32xf32xf16xf16xf16, `GPUCheckResourceUsage` on 'builtin.module' operation, `LoopInvariantCodeMotion` on 'func.func' operation: 
@main_dispatch_22_matmul_transpose_b_16384x512x512_f16]: reproducer generated at `./shark_tmp/core-reproducer.mlir`
%172:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%165, %168, %171, %float0.000000e00, %false_169, %none_170, %none_171) : (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384],f32>)
^
<stdin>:638:14: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {target_arch = "gfx1100", ukernels = "none"}>
%172:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%165, %168, %171, %float0.000000e00, %false_169, %none_170, %none_171) : (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384],f32>)
^
<stdin>:250:3: note: called from
func.func @main(%arg0: tensor<1x4x128x128xf16>) -> tensor<1x3x1024x1024xf16> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
^
<stdin>:638:14: error: failed to translate executables
%172:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%165, %168, %171, %float0.000000e00, %false_169, %none_170, %none_171) : (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384],f32>)
^
<stdin>:250:3: note: called from
func.func @main(%arg0: tensor<1x4x128x128xf16>) -> tensor<1x3x1024x1024xf16> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
^
Invoked with:
iree-compile.exe C:\V\iree-build\compiler\bindings\python\iree\compiler\tools\..\_mlir_libs\iree-compile.exe - --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=./shark_tmp/core-reproducer.mlir --iree-opt-strip-assertions=true --verify=false --iree-hal-target-backends=rocm --iree-rocm-target-chip=gfx1100 --iree-rocm-link-bc=true --iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode --iree-vm-bytecode-module-strip-source-map=true --iree-vm-target-truncate-unsupported-floats --iree-opt-const-expr-hoisting=False --iree-codegen-linalg-max-constant-fold-elements=9223372036854775807
Need more information? Set IREE_SAVE_TEMPS=/some/dir in your environment to save all artifacts and reproducers
IIUC we just need a good way to control and improve the tiling for these tiled+decomposed attention dispatches.