
[ROCM] SDXL UNet and VAE failures: attention dispatch exceeds the shared memory limit after tiling + decomposition.

Open · monorimet opened this issue 1 year ago · 1 comment

What happened?

Dispatch benchmark IR for an easy repro (compiled_unet_main_dispatch_91_rocm_hsaco_fb_benchmark.mlir):

module attributes {hal.device.targets = [#hal.device.target<"rocm", {executable_targets = [#hal.executable.target<"rocm", "rocm-hsaco-fb", {target_arch = "gfx1100", ukernels = "none"}>], legacy_sync}>]} {
  hal.executable private @main_dispatch_91 {
    hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {target_arch = "gfx1100", ukernels = "none"}>) {
      hal.executable.export public @main_dispatch_91_attention_20x4096x64xf16 ordinal(0) layout(#hal.pipeline.layout<push_constants = 4, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) attributes {subgroup_size = 32 : index, translation_info = #iree_codegen.translation_info<LLVMGPUDistribute>, workgroup_size = [64 : index, 1 : index, 1 : index]} {
      ^bb0(%arg0: !hal.device):
        %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
        hal.return %x, %y, %z : index, index, index
      }
      builtin.module {
        func.func @main_dispatch_91_attention_20x4096x64xf16() {
          %0 = hal.interface.constant.load[0] : i32
          %1 = hal.interface.constant.load[1] : i32
          %2 = hal.interface.constant.load[2] : i32
          %3 = hal.interface.constant.load[3] : i32
          %4 = arith.index_castui %0 : i32 to index
          %5 = arith.index_castui %1 : i32 to index
          %6 = arith.index_castui %2 : i32 to index
          %7 = arith.index_castui %3 : i32 to index
          %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%4) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
          %9 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
          %10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
          %11 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%7) : !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
          %12 = flow.dispatch.tensor.load %8, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
          %13 = flow.dispatch.tensor.load %9, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
          %14 = flow.dispatch.tensor.load %10, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
          %15 = tensor.empty() : tensor<20x4096x64xf16>
          %16 = iree_linalg_ext.attention {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 64]]>} ins(%12, %13, %14 : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>) outs(%15 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
          flow.dispatch.tensor.store %16, %11, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : tensor<20x4096x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
          return
        }
      }
    }
  }
  util.global private mutable @main_dispatch_91_rocm_hsaco_fb_main_dispatch_91_attention_20x4096x64xf16_buffer : !hal.buffer
  util.initializer {
    %c839971840 = arith.constant 839971840 : index
    %c-1_i64 = arith.constant -1 : i64
    %c0 = arith.constant 0 : index
    %device_0 = hal.devices.get %c0 : !hal.device
    %allocator = hal.device.allocator<%device_0 : !hal.device> : !hal.allocator
    %buffer = hal.allocator.allocate<%allocator : !hal.allocator> affinity(%c-1_i64) type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage") : !hal.buffer{%c839971840}
    util.global.store %buffer, @main_dispatch_91_rocm_hsaco_fb_main_dispatch_91_attention_20x4096x64xf16_buffer : !hal.buffer
    util.return
  }
  util.func public @main_dispatch_91_rocm_hsaco_fb_main_dispatch_91_attention_20x4096x64xf16(%arg0: i32) attributes {iree.abi.stub, iree.reflection = {iree.benchmark = "dispatch"}} {
    %c-1_i32 = arith.constant -1 : i32
    %c-1_i64 = arith.constant -1 : i64
    %c419985920 = arith.constant 419985920 : index
    %c1 = arith.constant 1 : index
    %c419985728 = arith.constant 419985728 : index
    %c94927168_i32 = arith.constant 94927168 : i32
    %c136870208_i32 = arith.constant 136870208 : i32
    %c126384448_i32 = arith.constant 126384448 : i32
    %c84441408_i32 = arith.constant 84441408 : i32
    %c0 = arith.constant 0 : index
    %0 = arith.index_cast %arg0 : i32 to index
    %device_0 = hal.devices.get %c0 : !hal.device
    %cmd = hal.command_buffer.create device(%device_0 : !hal.device) mode("OneShot|AllowInlineExecution") categories(Dispatch) : !hal.command_buffer
    %pipeline_layout = hal.pipeline_layout.lookup device(%device_0 : !hal.device) layout(<push_constants = 4, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) : !hal.pipeline_layout
    hal.command_buffer.push_constants<%cmd : !hal.command_buffer> layout(%pipeline_layout : !hal.pipeline_layout) offset(0) values([%c84441408_i32, %c126384448_i32, %c136870208_i32, %c94927168_i32]) : i32, i32, i32, i32
    %main_dispatch_91_rocm_hsaco_fb_main_dispatch_91_attention_20x4096x64xf16_buffer = util.global.load @main_dispatch_91_rocm_hsaco_fb_main_dispatch_91_attention_20x4096x64xf16_buffer : !hal.buffer
    hal.command_buffer.push_descriptor_set<%cmd : !hal.command_buffer> layout(%pipeline_layout : !hal.pipeline_layout)[%c0] bindings([
      %c0 = (%main_dispatch_91_rocm_hsaco_fb_main_dispatch_91_attention_20x4096x64xf16_buffer : !hal.buffer)[%c0, %c419985728], 
      %c1 = (%main_dispatch_91_rocm_hsaco_fb_main_dispatch_91_attention_20x4096x64xf16_buffer : !hal.buffer)[%c419985920, %c419985728]
    ])
    %workgroup_x, %workgroup_y, %workgroup_z = hal.executable.calculate_workgroups device(%device_0 : !hal.device) target(@main_dispatch_91::@rocm_hsaco_fb::@main_dispatch_91_attention_20x4096x64xf16) : index, index, index
    %exe = hal.executable.lookup device(%device_0 : !hal.device) executable(@main_dispatch_91) : !hal.executable
    %ordinal = hal.executable.export.ordinal target(@main_dispatch_91::@rocm_hsaco_fb::@main_dispatch_91_attention_20x4096x64xf16) : index
    scf.for %arg1 = %c0 to %0 step %c1 {
      hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe : !hal.executable)[%ordinal] workgroups([%workgroup_x, %workgroup_y, %workgroup_z])
      hal.command_buffer.execution_barrier<%cmd : !hal.command_buffer> source("Dispatch|CommandRetire") target("CommandIssue|Dispatch") flags("None")
    }
    hal.command_buffer.finalize<%cmd : !hal.command_buffer>
    %1 = util.null : !hal.fence
    %fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
    hal.device.queue.execute<%device_0 : !hal.device> affinity(%c-1_i64) wait(%1) signal(%fence) commands([%cmd])
    %status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) : i32
    util.status.check_ok %status, "failed to wait on timepoint"
    util.return
  }
}

Compiles successfully with this command:

iree-compile compiled_unet_main_dispatch_91_rocm_hsaco_fb_benchmark.mlir --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --iree-rocm-link-bc --iree-rocm-bc-dir=C:\AMD\ROCm\5.5\amdgcn\bitcode --mlir-print-debuginfo=false --mlir-print-op-on-diagnostic=false --iree-rocm-target-chip=gfx1100 --iree-opt-const-expr-hoisting=False --iree-codegen-linalg-max-constant-fold-elements=9223372036854775807 --iree-opt-strip-assertions=true --verify=false -o compiled_unet_main_dispatch_91_rocm_hsaco_fb_benchmark.vmfb

At runtime, it hits a shared memory allocation over the allowed limit:

iree-benchmark-module --device=rocm --module=compiled_unet_main_dispatch_91_rocm_hsaco_fb_benchmark.vmfb --parameters=model=./stable_diffusion_xl_base_1_0_unet.safetensors --function=main       
C:\V\iree\experimental\rocm\native_executable.c:137: INVALID_ARGUMENT; function 'main_dispatch_91_attention_20x4096x64xf16' requested shared memory size of 66304 larger than allowed size of 65536; while invoking native function hal.executable.create; while calling import;                
[ 1]   native hal.executable.create:0 -
[ 0] bytecode module.__init:298 C:\V\SHARK-Turbine\sdxl_attn\unet_rocm_dispatches\compiled_unet_main_dispatch_91_rocm_hsaco_fb_benchmark.mlir:40:18

The error is caught at compile time if --iree-llvmgpu-shared-memory-limit=65536 is passed on the iree-compile command line.
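
For reference, here is the compile command from above with only that flag added (everything else unchanged); with it, iree-compile should report the 66304-byte request against the 65536-byte gfx1100 limit at compile time instead of deferring the failure to hal.executable.create at runtime:

iree-compile compiled_unet_main_dispatch_91_rocm_hsaco_fb_benchmark.mlir --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --iree-rocm-link-bc --iree-rocm-bc-dir=C:\AMD\ROCm\5.5\amdgcn\bitcode --mlir-print-debuginfo=false --mlir-print-op-on-diagnostic=false --iree-rocm-target-chip=gfx1100 --iree-llvmgpu-shared-memory-limit=65536 --iree-opt-const-expr-hoisting=False --iree-codegen-linalg-max-constant-fold-elements=9223372036854775807 --iree-opt-strip-assertions=true --verify=false -o compiled_unet_main_dispatch_91_rocm_hsaco_fb_benchmark.vmfb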

I suppose we just need to change the tile size for the decomposed attention op, but I'm not sure of a good way to approach this. Any help would be appreciated.
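
For illustration only, one way to experiment is to hand-edit the lowering_config on the iree_linalg_ext.attention op in the dispatch benchmark IR above and recompile (with the shared-memory limit flag set so the compiler verifies the result). The halved tile below is an untested guess, not a known-good configuration:

// Hypothetical edit: shrink the second workgroup tile size from 64 to 32 and
// recompile; whether this brings the decomposed attention under the 64 KiB LDS
// budget of gfx1100 (and what it costs in performance) still needs to be measured.
%16 = iree_linalg_ext.attention {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 32]]>} ins(%12, %13, %14 : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>) outs(%15 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>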

What component(s) does this issue relate to?

Compiler

Version information

Latest IREE (https://github.com/openxla/iree/commit/5d8907e82fc1eb741a4d4d27f5cae865323fd1d7)

(Notably enabled by https://github.com/openxla/iree/commit/946375cad71786462bcfd63dde6fe305d1e3b9ff)

monorimet · Feb 22 '24 22:02

Very similar issue on SDXL VAE:

python ..\models\turbine_models\custom_models\sdxl_inference\vae.py --compile_to=vmfb --external_weights=safetensors --device=rocm --variant="decode" --precision="fp16" --iree_target_triple=gfx1100 --external_weight_path=stable_diffusion_xl_base_1_0_vae.safetensors 
C:\V\SHARK-Turbine\turb.env\Lib\site-packages\diffusers\utils\outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
C:\V\SHARK-Turbine\turb.env\Lib\site-packages\diffusers\utils\outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
Traceback (most recent call last):
  File "C:\V\SHARK-Turbine\models\turbine_models\custom_models\sdxl_inference\vae.py", line 156, in <module>
    mod_str = export_vae_model(
              ^^^^^^^^^^^^^^^^^
  File "C:\V\SHARK-Turbine\models\turbine_models\custom_models\sdxl_inference\vae.py", line 148, in export_vae_model
    utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)
  File "C:\V\SHARK-Turbine\models\turbine_models\custom_models\sd_inference\utils.py", line 99, in compile_to_vmfb
    flatbuffer_blob = ireec.compile_str(
                      ^^^^^^^^^^^^^^^^^^
  File "C:\V\iree-build\compiler\bindings\python\iree\compiler\tools\core.py", line 299, in compile_str
    result = invoke_immediate(cl, immediate_input=input_bytes)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\V\iree-build\compiler\bindings\python\iree\compiler\tools\binaries.py", line 198, in invoke_immediate
    raise CompilerToolError(process)
iree.compiler.tools.binaries.CompilerToolError: Error invoking IREE compiler tool iree-compile.exe
Error code: 1
Diagnostics:
<stdin>:638:14: error: 'func.func' op uses 221952 bytes of shared memory; exceeded the limit of 166912 bytes
    %172:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%165, %168, %171, %float0.000000e00, %false_169, %none_170, %none_171) : (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384],f32>)
             ^
<stdin>:250:3: note: called from
  func.func @main(%arg0: tensor<1x4x128x128xf16>) -> tensor<1x3x1024x1024xf16> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
  ^
<stdin>:638:14: error: Failures have been detected while processing an MLIR pass pipeline
    %172:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%165, %168, %171, %float0.000000e00, %false_169, %none_170, %none_171) : (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384],f32>)
             ^
<stdin>:250:3: note: called from
  func.func @main(%arg0: tensor<1x4x128x128xf16>) -> tensor<1x3x1024x1024xf16> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
  ^
<stdin>:638:14: note: Pipeline failed while executing [`TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_10, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_3, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_17, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_21, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_22, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_25, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_26, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `LLVMGPULowerExecutableTarget` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_27, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `LLVMGPULowerExecutableTarget` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_40, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `LLVMGPULowerExecutableTarget` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `LLVMGPULowerExecutableTarget` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_41, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_83, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `LLVMGPULowerExecutableTarget` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @main_dispatch_84, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `LLVMGPULowerExecutableTarget` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `OptimizeTensorInsertExtractSlices` on 'func.func' operation: @main_dispatch_40_conv_2d_nchw_fchw_1x512x128x128x512x3x3_f16, `Canonicalizer` on 'builtin.module' operation, `GPUDistributeSharedMemoryCopy` on 'func.func' operation: @main_dispatch_27_generic_512x16384_f32xf16xf16xf16xf32, `Canonicalizer` on 'func.func' operation: @main_dispatch_17_conv_2d_nchw_fchw_1x512x128x128x512x3x3_f16, `IREEComprehensiveBufferize` on 'builtin.module' operation, `ConvertToROCDL` on 'builtin.module' operation, `FoldMemRefAliasOps` on 'func.func' operation: @main_dispatch_26_matmul_transpose_b_16384x512x512_f16xf16xf32, `ExtractAddressComputationGPU` on 'builtin.module' operation, `Canonicalizer` on 'func.func' operation: @main_dispatch_21_generic_32x16x16384_f16xf32xf32xf16xf16xf16, `GPUCheckResourceUsage` on 'builtin.module' operation, `LoopInvariantCodeMotion` on 'func.func' operation: 
@main_dispatch_22_matmul_transpose_b_16384x512x512_f16]: reproducer generated at `./shark_tmp/core-reproducer.mlir`
    %172:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%165, %168, %171, %float0.000000e00, %false_169, %none_170, %none_171) : (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384],f32>)
             ^
<stdin>:638:14: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {target_arch = "gfx1100", ukernels = "none"}>
    %172:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%165, %168, %171, %float0.000000e00, %false_169, %none_170, %none_171) : (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384],f32>)
             ^
<stdin>:250:3: note: called from
  func.func @main(%arg0: tensor<1x4x128x128xf16>) -> tensor<1x3x1024x1024xf16> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
  ^
<stdin>:638:14: error: failed to translate executables
    %172:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%165, %168, %171, %float0.000000e00, %false_169, %none_170, %none_171) : (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384,512],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,1,16384,512],f16>, !torch.vtensor<[1,1,16384],f32>)
             ^
<stdin>:250:3: note: called from
  func.func @main(%arg0: tensor<1x4x128x128xf16>) -> tensor<1x3x1024x1024xf16> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
  ^


Invoked with:
 iree-compile.exe C:\V\iree-build\compiler\bindings\python\iree\compiler\tools\..\_mlir_libs\iree-compile.exe - --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=./shark_tmp/core-reproducer.mlir --iree-opt-strip-assertions=true --verify=false --iree-hal-target-backends=rocm --iree-rocm-target-chip=gfx1100 --iree-rocm-link-bc=true --iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode --iree-vm-bytecode-module-strip-source-map=true --iree-vm-target-truncate-unsupported-floats --iree-opt-const-expr-hoisting=False --iree-codegen-linalg-max-constant-fold-elements=9223372036854775807

Need more information? Set IREE_SAVE_TEMPS=/some/dir in your environment to save all artifacts and reproducers

If I understand correctly, we just need a good way to control/improve the tiling for these tiled + decomposed attention op dispatches.
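
As a stop-gap on the SHARK-Turbine side, the limit flag could also be threaded through the Python compile path so that an over-allocating dispatch (like the UNet one above) fails at compile time rather than at hal.executable.create. A minimal sketch, assuming the extra_args keyword of iree.compiler's compile_str and a placeholder vae_decode.mlir input; this only moves where the error surfaces, it does not fix the tiling itself:

# Minimal sketch: compile an exported MLIR module for gfx1100 while enforcing
# the 64 KiB shared memory limit in the compiler (values taken from the logs above).
# "vae_decode.mlir" is a hypothetical file name for the exported module.
import iree.compiler as ireec

with open("vae_decode.mlir", "r") as f:
    module_str = f.read()

flatbuffer_blob = ireec.compile_str(
    module_str,
    target_backends=["rocm"],
    extra_args=[
        "--iree-rocm-target-chip=gfx1100",
        "--iree-llvmgpu-shared-memory-limit=65536",  # gfx1100 per-workgroup LDS limit
    ],
)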

monorimet · Feb 23 '24 07:02