Implement initial heterogeneous support MVP for CPU (+ maybe something else).
TLDR: programs will be able to selectively allocate buffers that are compatible with all devices they are used on, submit command buffers using resources from each other, and synchronize with semaphores to order operations. Performance is not a goal of this effort, only straight-line execution and the addition of passes/APIs to enable the support. An initial MVP will be CPU+CPU only but will use nothing that any other HAL implementation could not also do; HIP support is a stretch if anyone is interested in helping out.
Heterogeneous device support builds on the current homogeneous multi-device support by extending what was done there with a few new features (see the sub-issues below).
Once implemented, we'll have compiled programs that can run and request allocations on devices selected from the subset of devices a particular buffer is used on. A policy defined by the application (possibly derived from available devices, but TBD) will be used for selection. For example, a buffer used on CPU and NPU could always be allocated on the NPU if it requires special allocations (vs normal host malloc). Transfers through staging buffers in host memory can be used when two devices cannot import/export, though such compiler support may come later.
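As an illustration of the selection policy described above, here is a minimal Python sketch. It is not the planned IREE API: `Device`, `needs_special_allocation`, and `select_allocation_device` are hypothetical names, and the rule "prefer the device that needs special allocations" is only the CPU + NPU example from the paragraph above turned into code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Device:
    """Hypothetical device description (not an IREE type)."""
    name: str
    # True if the device cannot use plain host malloc memory (e.g. an NPU).
    needs_special_allocation: bool


def select_allocation_device(used_on, available):
    """Pick which device should allocate a buffer used on all of `used_on`.

    Only devices the buffer is actually used on and that are available at
    runtime are candidates. Devices needing special allocations win; otherwise
    the first candidate is used.
    """
    candidates = [d for d in used_on if d in available]
    if not candidates:
        raise ValueError("none of the devices the buffer is used on is available")
    for device in candidates:
        if device.needs_special_allocation:
            return device
    return candidates[0]


cpu = Device("cpu", needs_special_allocation=False)
npu = Device("npu", needs_special_allocation=True)
```

With this policy a buffer used on both `cpu` and `npu` is allocated by `npu`, while a buffer used only on `cpu` stays a normal host allocation. A real policy would be application-defined, as noted above.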
After improving the compiler, the remaining work is entirely within the runtime HAL to make buffer and semaphore interop more seamless. For an MVP of CPU + CPU not much is required. For the stretch goal of CPU + [some accelerator], buffers are mostly taken care of: the local executors always map PERSISTENT for use in dispatches via the buffer vtable, so as long as the non-CPU device implements mapping properly, host pointers can be resolved. Semaphores are the trickier item, but the recently added iree_hal_semaphore_import_timepoint/iree_hal_semaphore_export_timepoint can help with that in the meantime (even if slow to start). Any non-CPU HAL will need to support IREE_HAL_EXTERNAL_TIMEPOINT_TYPE_WAIT_PRIMITIVE to interop with the CPU side.
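To make the idea of exporting a timepoint as a wait primitive concrete, here is a conceptual Python sketch. It does not use or mirror the IREE C API: `TimelineSemaphore`, `export_timepoint`, and `signal` are hypothetical names, and a `threading.Event` stands in for the OS wait primitive that a non-CPU HAL would hand to the CPU side.

```python
import threading


class TimelineSemaphore:
    """Hypothetical timeline semaphore with a monotonically increasing payload."""

    def __init__(self, initial=0):
        self._value = initial
        self._lock = threading.Lock()
        self._timepoints = []  # pending (target value, event) pairs

    def export_timepoint(self, value):
        """Return an event that is set once the payload reaches `value`.

        The event plays the role of an exported wait primitive. Any number of
        waiters can export the same value, and the event is never auto-reset.
        """
        event = threading.Event()
        with self._lock:
            if self._value >= value:
                event.set()  # already reached: waiters must not block
            else:
                self._timepoints.append((value, event))
        return event

    def signal(self, value):
        """Advance the payload and wake every timepoint that is now reached."""
        with self._lock:
            if value <= self._value:
                raise ValueError("timeline values must strictly increase")
            self._value = value
            ready = [e for v, e in self._timepoints if v <= value]
            self._timepoints = [(v, e) for v, e in self._timepoints if v > value]
        for event in ready:
            event.set()
```

A CPU-side consumer would call `export_timepoint(n).wait()` before touching a buffer produced by the other device. This is the slow but universal baseline; anything faster is an optimization on top.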
Improvements to compiler support for eliding more transfers, automatically detecting when transfers are still beneficial (such as in NUMA systems), and declarative topologies for indicating links will be left for future work. Other HAL targets besides CPU and the selected experimental target will also be left for future work. There are several known places with poor performance characteristics around import/export of semaphores and potentially required scoped buffer mapping that will be left for future work. As a baseline, any HAL implementation wanting to interop will need to support WAIT_PRIMITIVE and map_range; all other faster paths are considered optimizations.
The largest open design question on the compiler side is around the required cost modeling we cannot (currently/ever) do. Transfers are inserted today with the assumption that all devices are NUMA and that the cost of transferring amortizes the traffic of fetching in consumers across a potentially slow bus without DMA support. To start, a coarse non-default flag will be added to elide everything possible. I'm not sure there's any good solution that does not require algorithm-aware analysis that we don't have today (measuring reads/writes and reuse). An explicit attribute on transfer ops inserted in frontends may be required to indicate whether a transfer is load-bearing for performance and should not be elided, to allow tuning.
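The interaction between a coarse "elide everything possible" flag and a load-bearing attribute can be sketched as below. This is a toy Python model, not compiler code: `Transfer`, `load_bearing`, `can_share`, and `remaining_transfers` are hypothetical names, and "possible" is assumed here to mean that both devices can access the same memory.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Transfer:
    """Hypothetical model of a transfer op between two devices."""
    source: str
    target: str
    # Assumed frontend-provided marker: keep this transfer for performance.
    load_bearing: bool = False


def remaining_transfers(transfers, can_share, elide_all=False):
    """Return the transfers that survive elision.

    With the flag off (the default), nothing is elided. With it on, a transfer
    is dropped only if the two devices can share the memory and the transfer
    was not marked load-bearing.
    """
    if not elide_all:
        return list(transfers)
    kept = []
    for t in transfers:
        elidable = can_share(t.source, t.target) and not t.load_bearing
        if not elidable:
            kept.append(t)
    return kept


def cpu_only(a, b):
    """Assumed sharing rule for a CPU + CPU setup: only CPUs share memory."""
    return a.startswith("cpu") and b.startswith("cpu")


transfers = [
    Transfer("cpu0", "cpu1"),
    Transfer("cpu0", "cpu1", load_bearing=True),
    Transfer("cpu0", "gpu"),
]
```

Here `remaining_transfers(transfers, cpu_only, elide_all=True)` drops only the first transfer: the second is pinned by the attribute and the third crosses devices that cannot share memory under the assumed rule.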
Thanks a lot for taking the time to outline this. We would be happy to help with any issues you created here.
We’ve been doing some prototyping ourselves using the metal/local-task drivers, where we took a shortcut for the “programs will be able to selectively allocate buffers that are compatible with all devices” by simply allocating everything as shared using the metal allocator and manually working around the semaphore issues by hand-modifying the HAL IR.
I have some questions about how to handle buffer import/export—or whether it's even necessary at all for the MVP. My observation so far is that drivers typically require buffers to be allocated using their own allocator, since many of the iree_hal_*_command_buffer_* functions expect a native buffer handle. However, this doesn't seem to be the case for the local-task device, as I couldn’t find any references indicating that it requires its own native handle (which I assume is iree_hal_heap_buffer_t).
Given this, is my understanding correct that for the MVP, buffer import/export might not be required at all? As you mentioned in (#20857/#20855), shared allocations would be performed by the device that supports them (metal, in this example), so the HAL functions on the GPU device would still have access to their native handles, and the local devices wouldn’t care?
If the import/export is still needed would it be something that is handled solely in the runtime or would the transfer ops have to be translated into some form of import/export ops instead of eliding them completely?
If it's of interest to anyone, this is the IR snippet I got running for local-task/metal when replacing the local-task allocator with the metal allocator:
metal/local-task example
It originated from the tools/test/iree-run-module-multi.mlir testcase. There are not a lot of modifications; they mostly resolve semaphore errors such as joining on multiple fences that are not from the same device.
Compile command I used:
./build/tools/iree-compile --compile-from=vm --iree-execution-model=async-external --iree-hal-target-device=device_a=local[0] --iree-hal-target-device=device_b=metal --iree-hal-local-target-device-backends=vmvx working_multi_device.mlir -o multi.vmfb
Run command:
./build/tools/iree-run-module --module=multi.vmfb --device=local-task --device=metal --input=4xf32=10,11,12,13 --function=multi_device_mul
module attributes {vm.toplevel} {
vm.module public @module {
vm.global.ref private mutable @device_a : !vm.ref<!hal.device>
vm.rodata private @_utf8_hal_device_id_369FCE3B885986F6 {alignment = 1 : i64} "hal.device.id"
vm.rodata private @_utf8_local_2C7344D4E05782F4 {alignment = 1 : i64} "local*"
vm.rodata private @_utf8_hal_executable_format_C7128E2AE1BE720D {alignment = 1 : i64} "hal.executable.format"
vm.rodata private @_utf8_vmvx_bytecode_fb_90AB3AF67FE09641 {alignment = 1 : i64} "vmvx-bytecode-fb"
vm.rodata private @multi_device_mul_dispatch_0_vmvx_bytecode_fb {alignment = 16 : i64, mime_type = "application/x-flatbuffers"} dense<"0xvector<1397xi8>
vm.rodata private @_utf8_metal_FB846B13C56E663B {alignment = 1 : i64} "metal"
vm.rodata private @_utf8_metal_msl_fb_BD2413475D7AB9A0 {alignment = 1 : i64} "metal-msl-fb"
vm.rodata private @multi_device_mul_dispatch_1_metal_msl_fb {alignment = 16 : i64, mime_type = "application/x-flatbuffers"} dense<"0xvector<6336xi8>
vm.global.ref private mutable @__device_a_executable_0_multi_device_mul_dispatch_0 : !vm.ref<!hal.executable>
vm.global.ref private mutable @device_b : !vm.ref<!hal.device>
vm.global.ref private mutable @__device_b_executable_0_multi_device_mul_dispatch_1 : !vm.ref<!hal.executable>
vm.func private @__multi_device_mul_memoize_apply() -> !vm.ref<!hal.command_buffer> attributes {inlining_policy = #util.inline.never} {
%c13 = vm.const.i32 13
%c28 = vm.const.i32 28
%c2 = vm.const.i32 2
%null = vm.const.ref.zero : !vm.ref<!hal.buffer>
%c1 = vm.const.i32 1
%c3 = vm.const.i32 3
%zero = vm.const.i32.zero
%c64 = vm.const.i64 64
%c16 = vm.const.i64 16
%zero_0 = vm.const.i64.zero
%c-1 = vm.const.i64 -1
%device_a = vm.global.load.ref @device_a : !vm.ref<!hal.device>
%__device_a_executable_0_multi_device_mul_dispatch_0 = vm.global.load.ref @__device_a_executable_0_multi_device_mul_dispatch_0 : !vm.ref<!hal.executable>
%ref = vm.call @hal.command_buffer.create(%device_a, %zero, %c3, %c-1, %c3) : (!vm.ref<!hal.device>, i32, i32, i64, i32) -> !vm.ref<!hal.command_buffer>
vm.call.variadic @hal.command_buffer.dispatch(%ref, %__device_a_executable_0_multi_device_mul_dispatch_0, %zero, %c1, %c1, %c1, %zero_0, [], [(%zero, %zero, %null, %zero_0, %c16), (%zero, %c2, %null, %zero_0, %c64)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call @hal.command_buffer.execution_barrier(%ref, %c28, %c13, %zero_0) : (!vm.ref<!hal.command_buffer>, i32, i32, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref, %c2, %c1, %null, %zero_0, %null, %zero_0, %c16, %zero_0) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64, i64) -> ()
vm.call @hal.command_buffer.execution_barrier(%ref, %c28, %c13, %zero_0) : (!vm.ref<!hal.command_buffer>, i32, i32, i64) -> ()
vm.call @hal.command_buffer.finalize(%ref) : (!vm.ref<!hal.command_buffer>) -> ()
vm.return %ref : !vm.ref<!hal.command_buffer>
}
vm.global.ref private mutable @__multi_device_mul_memoize_result_0_device_a : !vm.ref<!hal.command_buffer>
vm.func private @__multi_device_mul_memoize_apply_0() -> !vm.ref<!hal.command_buffer> attributes {inlining_policy = #util.inline.never} {
%c13 = vm.const.i32 13
%c28 = vm.const.i32 28
%c2 = vm.const.i32 2
%null = vm.const.ref.zero : !vm.ref<!hal.buffer>
%c1 = vm.const.i32 1
%c3 = vm.const.i32 3
%zero = vm.const.i32.zero
%c64 = vm.const.i64 64
%c16 = vm.const.i64 16
%zero_0 = vm.const.i64.zero
%c-1 = vm.const.i64 -1
%device_b = vm.global.load.ref @device_b : !vm.ref<!hal.device>
%__device_b_executable_0_multi_device_mul_dispatch_1 = vm.global.load.ref @__device_b_executable_0_multi_device_mul_dispatch_1 : !vm.ref<!hal.executable>
%ref = vm.call @hal.command_buffer.create(%device_b, %zero, %c3, %c-1, %c3) : (!vm.ref<!hal.device>, i32, i32, i64, i32) -> !vm.ref<!hal.command_buffer>
vm.call.variadic @hal.command_buffer.dispatch(%ref, %__device_b_executable_0_multi_device_mul_dispatch_1, %zero, %c1, %c1, %c1, %zero_0, [], [(%zero, %zero, %null, %zero_0, %c16), (%zero, %c2, %null, %zero_0, %c64)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call @hal.command_buffer.execution_barrier(%ref, %c28, %c13, %zero_0) : (!vm.ref<!hal.command_buffer>, i32, i32, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref, %c2, %c1, %null, %zero_0, %null, %zero_0, %c16, %zero_0) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64, i64) -> ()
vm.call @hal.command_buffer.execution_barrier(%ref, %c28, %c13, %zero_0) : (!vm.ref<!hal.command_buffer>, i32, i32, i64) -> ()
vm.call @hal.command_buffer.finalize(%ref) : (!vm.ref<!hal.command_buffer>) -> ()
vm.return %ref : !vm.ref<!hal.command_buffer>
}
vm.global.ref private mutable @__multi_device_mul_memoize_result_0_device_b : !vm.ref<!hal.command_buffer>
vm.import private @hal.buffer.assert(%buffer : !vm.ref<!hal.buffer>, %message : !vm.buffer, %allocator : !vm.ref<!hal.allocator>, %minimum_length : i64, %memory_types : i32, %buffer_usage : i32)
vm.import private @hal.buffer_view.create(%buffer : !vm.ref<!hal.buffer>, %source_offset : i64, %source_length : i64, %element_type : i32, %encoding_type : i32, %shape : i64 ...) -> !vm.ref<!hal.buffer_view> attributes {nosideeffects}
vm.import private @hal.buffer_view.assert(%buffer_view : !vm.ref<!hal.buffer_view>, %message : !vm.buffer, %element_type : i32, %encoding_type : i32, %shape : i64 ...)
vm.import private @hal.buffer_view.buffer(%buffer_view : !vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer> attributes {nosideeffects}
vm.import private @hal.command_buffer.create(%device : !vm.ref<!hal.device>, %modes : i32, %command_categories : i32, %queue_affinity : i64, %binding_capacity : i32) -> !vm.ref<!hal.command_buffer> attributes {minimum_version = 6 : i32}
vm.import private @hal.command_buffer.finalize(%command_buffer : !vm.ref<!hal.command_buffer>)
vm.import private @hal.command_buffer.execution_barrier(%command_buffer : !vm.ref<!hal.command_buffer>, %source_stage_mask : i32, %target_stage_mask : i32, %flags : i64)
vm.import private @hal.command_buffer.copy_buffer(%command_buffer : !vm.ref<!hal.command_buffer>, %source_buffer_slot : i32, %target_buffer_slot : i32, %source_buffer : !vm.ref<!hal.buffer>, %source_offset : i64, %target_buffer : !vm.ref<!hal.buffer>, %target_offset : i64, %length : i64, %flags : i64)
vm.import private @hal.command_buffer.dispatch(%command_buffer : !vm.ref<!hal.command_buffer>, %executable : !vm.ref<!hal.executable>, %entry_point : i32, %workgroup_x : i32, %workgroup_y : i32, %workgroup_z : i32, %flags : i64, %constants : i32 ..., %bindings : tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.import private @hal.device.allocator(%device : !vm.ref<!hal.device>) -> !vm.ref<!hal.allocator> attributes {nosideeffects}
vm.import private @hal.device.query.i64(%device : !vm.ref<!hal.device>, %category : !vm.buffer, %key : !vm.buffer) -> (i32, i64) attributes {nosideeffects}
vm.import private @hal.device.queue.alloca(%device : !vm.ref<!hal.device>, %queue_affinity : i64, %wait_fence : !vm.ref<!hal.fence>, %signal_fence : !vm.ref<!hal.fence>, %pool : i64, %memory_types : i32, %buffer_usage : i32, %allocation_size : i64, %flags : i64) -> !vm.ref<!hal.buffer>
vm.import private @hal.device.queue.dealloca(%device : !vm.ref<!hal.device>, %queue_affinity : i64, %wait_fence : !vm.ref<!hal.fence>, %signal_fence : !vm.ref<!hal.fence>, %buffer : !vm.ref<!hal.buffer>, %flags : i64)
vm.import private @hal.device.queue.execute.indirect(%device : !vm.ref<!hal.device>, %queue_affinity : i64, %wait_fence : !vm.ref<!hal.fence>, %signal_fence : !vm.ref<!hal.fence>, %command_buffer : !vm.ref<!hal.command_buffer>, %flags : i64, %binding_table : tuple<!vm.ref<!hal.buffer>, i64, i64> ...)
vm.import private @hal.devices.count() -> i32 attributes {nosideeffects}
vm.import private @hal.devices.get(%index : i32) -> !vm.ref<!hal.device> attributes {nosideeffects}
vm.import private @hal.executable.create(%device : !vm.ref<!hal.device>, %queue_affinity : i64, %executable_format : !vm.buffer, %executable_data : !vm.buffer, %constants : !vm.buffer) -> !vm.ref<!hal.executable> attributes {nosideeffects}
vm.import private @hal.fence.create(%device : !vm.ref<!hal.device>, %flags : i64) -> !vm.ref<!hal.fence>
vm.import private @hal.fence.join(%flags : i64, %fences : !vm.ref<!hal.fence> ...) -> !vm.ref<!hal.fence> attributes {nosideeffects}
vm.import private @hal.fence.signal(%fence : !vm.ref<!hal.fence>) -> () attributes {nosideeffects}
vm.import private @hal.fence.await(%timeout_millis : i32, %flags : i64, %fences : !vm.ref<!hal.fence> ...) -> i32 attributes {vm.yield}
vm.rodata private @_utf8_input0_2D1C6C95BF4E8E4E {alignment = 1 : i64} "input0"
vm.rodata private @_utf8_tensor_8F58A528F0C2C1FD {alignment = 1 : i64} "tensor"
vm.func private @multi_device_mul(%arg0: !vm.ref<!hal.buffer_view>, %arg1: !vm.ref<!hal.fence>, %arg2: !vm.ref<!hal.fence>) -> !vm.ref<!hal.buffer_view> attributes {iree.reflection = {iree.abi.declaration = "async func @multi_device_mul(%input0: tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>}) -> (%output0: tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>})", iree.abi.model = "coarse-fences"}} {
%c48 = vm.const.i32 48
%c3075 = vm.const.i32 3075
%c16 = vm.const.i32 16
%c1 = vm.const.i32 1
%c553648160 = vm.const.i32 553648160
%c4 = vm.const.i64 4
%c16_0 = vm.const.i64 16
%zero = vm.const.i64.zero
%c64 = vm.const.i64 64
%c-1 = vm.const.i64 -1
%device_a = vm.global.load.ref @device_a : !vm.ref<!hal.device>
%device_b = vm.global.load.ref @device_b : !vm.ref<!hal.device>
%__multi_device_mul_memoize_result_0_device_a = vm.global.load.ref @__multi_device_mul_memoize_result_0_device_a : !vm.ref<!hal.command_buffer>
%__multi_device_mul_memoize_result_0_device_b = vm.global.load.ref @__multi_device_mul_memoize_result_0_device_b : !vm.ref<!hal.command_buffer>
%_utf8_input0_2D1C6C95BF4E8E4E = vm.const.ref.rodata @_utf8_input0_2D1C6C95BF4E8E4E : !vm.buffer
vm.call.variadic @hal.buffer_view.assert(%arg0, %_utf8_input0_2D1C6C95BF4E8E4E, %c553648160, %c1, [%c4]) : (!vm.ref<!hal.buffer_view>, !vm.buffer, i32, i32, i64 ...)
%ref = vm.call @hal.buffer_view.buffer(%arg0) {nosideeffects} : (!vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer>
%ref_1 = vm.call @hal.device.allocator(%device_a) {nosideeffects} : (!vm.ref<!hal.device>) -> !vm.ref<!hal.allocator>
%_utf8_tensor_8F58A528F0C2C1FD = vm.const.ref.rodata @_utf8_tensor_8F58A528F0C2C1FD : !vm.buffer
vm.call @hal.buffer.assert(%ref, %_utf8_tensor_8F58A528F0C2C1FD, %ref_1, %c16_0, %c16, %c3075) : (!vm.ref<!hal.buffer>, !vm.buffer, !vm.ref<!hal.allocator>, i64, i32, i32) -> ()
%ref_2 = vm.call @hal.fence.create(%device_b, %zero) : (!vm.ref<!hal.device>, i64) -> !vm.ref<!hal.fence>
%ref_3 = vm.call @hal.device.queue.alloca(%device_b, %c-1, %arg1, %ref_2, %zero, %c48, %c3075, %c16_0, %zero) : (!vm.ref<!hal.device>, i64, !vm.ref<!hal.fence>, !vm.ref<!hal.fence>, i64, i32, i32, i64, i64) -> !vm.ref<!hal.buffer>
%join_b0 = vm.call.variadic @hal.fence.join(%zero, [%ref_2]) {nosideeffects} : (i64, !vm.ref<!hal.fence> ...) -> !vm.ref<!hal.fence>
%ref_4 = vm.call @hal.fence.create(%device_a, %zero) : (!vm.ref<!hal.device>, i64) -> !vm.ref<!hal.fence>
%ref_5 = vm.call @hal.device.queue.alloca(%device_a, %c-1, %arg1, %ref_4, %zero, %c48, %c3075, %c64, %zero) : (!vm.ref<!hal.device>, i64, !vm.ref<!hal.fence>, !vm.ref<!hal.fence>, i64, i32, i32, i64, i64) -> !vm.ref<!hal.buffer>
%ref_6 = vm.call.variadic @hal.fence.join(%zero, [%ref_4]) {nosideeffects} : (i64, !vm.ref<!hal.fence> ...) -> !vm.ref<!hal.fence>
%ref_7 = vm.call @hal.fence.create(%device_a, %zero) : (!vm.ref<!hal.device>, i64) -> !vm.ref<!hal.fence>
vm.call.variadic @hal.device.queue.execute.indirect(%device_a, %c-1, %ref_6, %ref_7, %__multi_device_mul_memoize_result_0_device_a, %zero, [(%ref, %zero, %c16_0), (%ref_3, %zero, %c16_0), (%ref_5, %zero, %c64)]) : (!vm.ref<!hal.device>, i64, !vm.ref<!hal.fence>, !vm.ref<!hal.fence>, !vm.ref<!hal.command_buffer>, i64, tuple<!vm.ref<!hal.buffer>, i64, i64> ...)
%ref_8 = vm.call @hal.fence.create(%device_b, %zero) : (!vm.ref<!hal.device>, i64) -> !vm.ref<!hal.fence>
vm.call @hal.device.queue.dealloca(%device_a, %c-1, %ref_7, %ref_8, %ref_5, %zero) : (!vm.ref<!hal.device>, i64, !vm.ref<!hal.fence>, !vm.ref<!hal.fence>, !vm.ref<!hal.buffer>, i64) -> ()
%ref_9 = vm.call @hal.fence.create(%device_a, %zero) : (!vm.ref<!hal.device>, i64) -> !vm.ref<!hal.fence>
%ref_10 = vm.call @hal.device.queue.alloca(%device_a, %c-1, %ref_8, %ref_9, %zero, %c48, %c3075, %c16_0, %zero) : (!vm.ref<!hal.device>, i64, !vm.ref<!hal.fence>, !vm.ref<!hal.fence>, i64, i32, i32, i64, i64) -> !vm.ref<!hal.buffer>
%join_b1 = vm.call.variadic @hal.fence.join(%zero, [%ref_9]) {nosideeffects} : (i64, !vm.ref<!hal.fence> ...) -> !vm.ref<!hal.fence>
%ref_11 = vm.call @hal.fence.create(%device_b, %zero) : (!vm.ref<!hal.device>, i64) -> !vm.ref<!hal.fence>
%ref_12 = vm.call @hal.device.queue.alloca(%device_b, %c-1, %ref_8, %ref_11, %zero, %c48, %c3075, %c64, %zero) : (!vm.ref<!hal.device>, i64, !vm.ref<!hal.fence>, !vm.ref<!hal.fence>, i64, i32, i32, i64, i64) -> !vm.ref<!hal.buffer>
%ref_13 = vm.call.variadic @hal.fence.join(%zero, [ %ref_11]) {nosideeffects} : (i64, !vm.ref<!hal.fence> ...) -> !vm.ref<!hal.fence>
%ref_14 = vm.call @hal.fence.create(%device_b, %zero) : (!vm.ref<!hal.device>, i64) -> !vm.ref<!hal.fence>
vm.call.variadic @hal.device.queue.execute.indirect(%device_b, %c-1, %ref_13, %ref_14, %__multi_device_mul_memoize_result_0_device_b, %zero, [(%ref_3, %zero, %c16_0), (%ref_10, %zero, %c16_0), (%ref_12, %zero, %c64)]) : (!vm.ref<!hal.device>, i64, !vm.ref<!hal.fence>, !vm.ref<!hal.fence>, !vm.ref<!hal.command_buffer>, i64, tuple<!vm.ref<!hal.buffer>, i64, i64> ...)
%0 = vm.call.variadic @hal.fence.await(%c-1, %zero, [%ref_14]) : (i64, i64, !vm.ref<!hal.fence> ...) -> i32
vm.call @hal.fence.signal(%arg2) : (!vm.ref<!hal.fence>) -> ()
%ref_15 = vm.call.variadic @hal.buffer_view.create(%ref_10, %zero, %c16_0, %c553648160, %c1, [%c4]) {nosideeffects} : (!vm.ref<!hal.buffer>, i64, i64, i32, i32, i64 ...) -> !vm.ref<!hal.buffer_view>
vm.return %ref_15 : !vm.ref<!hal.buffer_view>
}
vm.export @multi_device_mul attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "async func @multi_device_mul(%input0: tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>}) -> (%output0: tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>})", iree.abi.model = "coarse-fences"}}
vm.export @__init
vm.func private @__init() {
%c1 = vm.const.i32 1
%null = vm.const.ref.zero : !vm.buffer
%c14 = vm.const.i32 14
%c-1 = vm.const.i64 -1
%c18 = vm.const.i32 18
%zero = vm.const.i32.zero
%zero_0 = vm.const.i64.zero
%c1_1 = vm.const.i64 1
%null_2 = vm.const.ref.zero : !vm.ref<!hal.device>
%0 = vm.call @hal.devices.count() {nosideeffects} : () -> i32
%1 = vm.ext.i32.i64.s %0 : i32 -> i64
vm.br ^bb1(%zero_0, %zero_0, %null_2 : i64, i64, !vm.ref<!hal.device>)
^bb1(%2: i64, %3: i64, %4: !vm.ref<!hal.device>): // 2 preds: ^bb0, ^bb4
%rnz = vm.cmp.nz.ref %4 : !vm.ref<!hal.device>
%5 = vm.xor.i32 %rnz, %c1 : i32
%slt = vm.cmp.lt.i64.s %2, %1 : i64
%6 = vm.and.i32 %5, %slt : i32
vm.cond_br %6, ^bb2, ^bb5
^bb2: // pred: ^bb1
%7 = vm.trunc.i64.i32 %2 : i64 -> i32
%ref = vm.call @hal.devices.get(%7) {nosideeffects} : (i32) -> !vm.ref<!hal.device>
%_utf8_hal_device_id_369FCE3B885986F6 = vm.const.ref.rodata @_utf8_hal_device_id_369FCE3B885986F6 : !vm.buffer
%_utf8_local_2C7344D4E05782F4 = vm.const.ref.rodata @_utf8_local_2C7344D4E05782F4 : !vm.buffer
%8:2 = vm.call @hal.device.query.i64(%ref, %_utf8_hal_device_id_369FCE3B885986F6, %_utf8_local_2C7344D4E05782F4) {nosideeffects} : (!vm.ref<!hal.device>, !vm.buffer, !vm.buffer) -> (i32, i64)
%nz = vm.cmp.nz.i64 %8#1 : i64
%9 = vm.select.i32 %8#0, %nz, %zero : i32
vm.cond_br %9, ^bb3, ^bb4(%zero : i32)
^bb3: // pred: ^bb2
%_utf8_hal_executable_format_C7128E2AE1BE720D = vm.const.ref.rodata @_utf8_hal_executable_format_C7128E2AE1BE720D : !vm.buffer
%_utf8_vmvx_bytecode_fb_90AB3AF67FE09641 = vm.const.ref.rodata @_utf8_vmvx_bytecode_fb_90AB3AF67FE09641 : !vm.buffer
%10:2 = vm.call @hal.device.query.i64(%ref, %_utf8_hal_executable_format_C7128E2AE1BE720D, %_utf8_vmvx_bytecode_fb_90AB3AF67FE09641) {nosideeffects} : (!vm.ref<!hal.device>, !vm.buffer, !vm.buffer) -> (i32, i64)
%nz_3 = vm.cmp.nz.i64 %10#1 : i64
%11 = vm.select.i32 %10#0, %nz_3, %zero : i32
vm.br ^bb4(%11 : i32)
^bb4(%12: i32): // 2 preds: ^bb2, ^bb3
%eq = vm.cmp.eq.i64 %3, %zero_0 : i64
%13 = vm.select.i64 %12, %c1_1, %zero_0 : i64
%14 = vm.add.i64 %3, %13 : i64
%15 = vm.and.i32 %12, %eq : i32
%ref_4 = vm.select.ref %15, %ref, %null_2 : !vm.ref<!hal.device>
%16 = vm.add.i64 %2, %c1_1 : i64
vm.br ^bb1(%16, %14, %ref_4 : i64, i64, !vm.ref<!hal.device>)
^bb5: // pred: ^bb1
vm.cond_br %5, ^bb6, ^bb7
^bb6: // pred: ^bb5
vm.fail %c18, "HAL device `device_a` not found or unavailable: #hal.device.target<\22local\22, {ordinal = 0 : index}, [#hal.executable.target<\22vmvx\22, \22vmvx-bytecode-fb\22, {iree.encoding.resolver = #iree_cpu.vmvx_encoding_layout<>, ukernels = \22none\22}>]>"
^bb7: // pred: ^bb5
%_utf8_hal_executable_format_C7128E2AE1BE720D_5 = vm.const.ref.rodata @_utf8_hal_executable_format_C7128E2AE1BE720D : !vm.buffer
%_utf8_vmvx_bytecode_fb_90AB3AF67FE09641_6 = vm.const.ref.rodata @_utf8_vmvx_bytecode_fb_90AB3AF67FE09641 : !vm.buffer
%17:2 = vm.call @hal.device.query.i64(%4, %_utf8_hal_executable_format_C7128E2AE1BE720D_5, %_utf8_vmvx_bytecode_fb_90AB3AF67FE09641_6) {nosideeffects} : (!vm.ref<!hal.device>, !vm.buffer, !vm.buffer) -> (i32, i64)
%nz_7 = vm.cmp.nz.i64 %17#1 : i64
%18 = vm.select.i32 %17#0, %nz_7, %zero : i32
%19 = vm.select.i64 %18, %zero_0, %c-1 : i64
%eq_8 = vm.cmp.eq.i64 %19, %zero_0 : i64
vm.global.store.ref %4, @device_a : !vm.ref<!hal.device>
vm.cond_br %eq_8, ^bb8, ^bb9
^bb8: // pred: ^bb7
%multi_device_mul_dispatch_0_vmvx_bytecode_fb = vm.const.ref.rodata @multi_device_mul_dispatch_0_vmvx_bytecode_fb : !vm.buffer
%ref_9 = vm.call @hal.executable.create(%4, %c-1, %_utf8_vmvx_bytecode_fb_90AB3AF67FE09641_6, %multi_device_mul_dispatch_0_vmvx_bytecode_fb, %null) {nosideeffects} : (!vm.ref<!hal.device>, i64, !vm.buffer, !vm.buffer, !vm.buffer) -> !vm.ref<!hal.executable>
vm.global.store.ref %ref_9, @__device_a_executable_0_multi_device_mul_dispatch_0 : !vm.ref<!hal.executable>
vm.br ^bb10(%zero_0, %zero_0, %null_2 : i64, i64, !vm.ref<!hal.device>)
^bb9: // pred: ^bb7
vm.fail %c14, "HAL device `device_a` does not support any variant of executable `multi_device_mul_dispatch_0`; available formats: [vmvx-bytecode-fb]"
^bb10(%20: i64, %21: i64, %22: !vm.ref<!hal.device>): // 2 preds: ^bb8, ^bb13
%rnz_10 = vm.cmp.nz.ref %22 : !vm.ref<!hal.device>
%23 = vm.xor.i32 %rnz_10, %c1 : i32
%slt_11 = vm.cmp.lt.i64.s %20, %1 : i64
%24 = vm.and.i32 %23, %slt_11 : i32
vm.cond_br %24, ^bb11, ^bb14
^bb11: // pred: ^bb10
%25 = vm.trunc.i64.i32 %20 : i64 -> i32
%ref_12 = vm.call @hal.devices.get(%25) {nosideeffects} : (i32) -> !vm.ref<!hal.device>
%_utf8_hal_device_id_369FCE3B885986F6_13 = vm.const.ref.rodata @_utf8_hal_device_id_369FCE3B885986F6 : !vm.buffer
%_utf8_metal_FB846B13C56E663B = vm.const.ref.rodata @_utf8_metal_FB846B13C56E663B : !vm.buffer
%26:2 = vm.call @hal.device.query.i64(%ref_12, %_utf8_hal_device_id_369FCE3B885986F6_13, %_utf8_metal_FB846B13C56E663B) {nosideeffects} : (!vm.ref<!hal.device>, !vm.buffer, !vm.buffer) -> (i32, i64)
%nz_14 = vm.cmp.nz.i64 %26#1 : i64
%27 = vm.select.i32 %26#0, %nz_14, %zero : i32
vm.cond_br %27, ^bb12, ^bb13(%zero : i32)
^bb12: // pred: ^bb11
%_utf8_metal_msl_fb_BD2413475D7AB9A0 = vm.const.ref.rodata @_utf8_metal_msl_fb_BD2413475D7AB9A0 : !vm.buffer
%28:2 = vm.call @hal.device.query.i64(%ref_12, %_utf8_hal_executable_format_C7128E2AE1BE720D_5, %_utf8_metal_msl_fb_BD2413475D7AB9A0) {nosideeffects} : (!vm.ref<!hal.device>, !vm.buffer, !vm.buffer) -> (i32, i64)
%nz_15 = vm.cmp.nz.i64 %28#1 : i64
%29 = vm.select.i32 %28#0, %nz_15, %zero : i32
vm.br ^bb13(%29 : i32)
^bb13(%30: i32): // 2 preds: ^bb11, ^bb12
%eq_16 = vm.cmp.eq.i64 %21, %zero_0 : i64
%31 = vm.select.i64 %30, %c1_1, %zero_0 : i64
%32 = vm.add.i64 %21, %31 : i64
%33 = vm.and.i32 %30, %eq_16 : i32
%ref_17 = vm.select.ref %33, %ref_12, %null_2 : !vm.ref<!hal.device>
%34 = vm.add.i64 %20, %c1_1 : i64
vm.br ^bb10(%34, %32, %ref_17 : i64, i64, !vm.ref<!hal.device>)
^bb14: // pred: ^bb10
vm.cond_br %23, ^bb15, ^bb16
^bb15: // pred: ^bb14
vm.fail %c18, "HAL device `device_b` not found or unavailable: #hal.device.target<\22metal\22, [#hal.executable.target<\22metal-spirv\22, \22metal-msl-fb\22, {iree.gpu.target = #iree_gpu.target<arch = \22apple\22, features = \22spirv:v1.3,cap:Shader\22, wgp = <compute = fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [], subgroup_size_choices = [32], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 32768, max_workgroup_counts = [65535, 65535, 65535]>>}>]>"
^bb16: // pred: ^bb14
%_utf8_metal_msl_fb_BD2413475D7AB9A0_18 = vm.const.ref.rodata @_utf8_metal_msl_fb_BD2413475D7AB9A0 : !vm.buffer
%35:2 = vm.call @hal.device.query.i64(%22, %_utf8_hal_executable_format_C7128E2AE1BE720D_5, %_utf8_metal_msl_fb_BD2413475D7AB9A0_18) {nosideeffects} : (!vm.ref<!hal.device>, !vm.buffer, !vm.buffer) -> (i32, i64)
%nz_19 = vm.cmp.nz.i64 %35#1 : i64
%36 = vm.select.i32 %35#0, %nz_19, %zero : i32
%37 = vm.select.i64 %36, %zero_0, %c-1 : i64
%eq_20 = vm.cmp.eq.i64 %37, %zero_0 : i64
vm.global.store.ref %22, @device_b : !vm.ref<!hal.device>
vm.cond_br %eq_20, ^bb17, ^bb18
^bb17: // pred: ^bb16
%multi_device_mul_dispatch_1_metal_msl_fb = vm.const.ref.rodata @multi_device_mul_dispatch_1_metal_msl_fb : !vm.buffer
%ref_21 = vm.call @hal.executable.create(%22, %c-1, %_utf8_metal_msl_fb_BD2413475D7AB9A0_18, %multi_device_mul_dispatch_1_metal_msl_fb, %null) {nosideeffects} : (!vm.ref<!hal.device>, i64, !vm.buffer, !vm.buffer, !vm.buffer) -> !vm.ref<!hal.executable>
vm.global.store.ref %ref_21, @__device_b_executable_0_multi_device_mul_dispatch_1 : !vm.ref<!hal.executable>
%ref_22 = vm.call @__multi_device_mul_memoize_apply() : () -> !vm.ref<!hal.command_buffer>
vm.global.store.ref %ref_22, @__multi_device_mul_memoize_result_0_device_a : !vm.ref<!hal.command_buffer>
%ref_23 = vm.call @__multi_device_mul_memoize_apply_0() : () -> !vm.ref<!hal.command_buffer>
vm.global.store.ref %ref_23, @__multi_device_mul_memoize_result_0_device_b : !vm.ref<!hal.command_buffer>
vm.return
^bb18: // pred: ^bb16
vm.fail %c14, "HAL device `device_b` does not support any variant of executable `multi_device_mul_dispatch_1`; available formats: [metal-msl-fb]"
}
}
}
Yep! The MVP only needs semaphore handling changes on the CPU side as it already uses the iree_hal_buffer_t vtable to perform the mapping instead of casting things directly. Import/export are required for non-CPU targets that want to share resources that have no host pointers - for example, Vulkan and CUDA/HIP. My expectation for that case is that either the import/export need to be extremely fast (implemented in drivers as a pointer indirection) or extremely infrequent (done as part of startup/loading tasks, etc). To have import/export work in mainline latency-critical execution it'll likely need a cache on the HAL side of imported buffers and that may require weak references so it's a bit of a ways off.
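The imported-buffer cache with weak references mentioned above could look roughly like the following Python sketch. It only illustrates the caching idea: `ImportedBuffer`, `ImportCache`, and `import_fn` are hypothetical, and a real HAL-side cache would live in C and have to deal with device lifetime and thread safety.

```python
import weakref


class ImportedBuffer:
    """Hypothetical wrapper for a buffer imported from another device."""

    def __init__(self, handle):
        self.handle = handle


class ImportCache:
    """Caches imports by external handle without keeping the buffers alive."""

    def __init__(self, import_fn):
        self._import_fn = import_fn  # the expensive driver import call
        # Entries disappear automatically once no one else holds the buffer.
        self._entries = weakref.WeakValueDictionary()
        self.import_count = 0

    def get(self, handle):
        """Return the cached import for `handle`, importing on a miss."""
        buffer = self._entries.get(handle)
        if buffer is None:
            buffer = self._import_fn(handle)
            self._entries[handle] = buffer
            self.import_count += 1
        return buffer
```

Repeated `get` calls for the same handle on a latency-critical path then pay the import cost only once, while the weak references let the cache drop entries when the program releases the buffer.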
I’ve been looking into the semaphore timepoint export/import. We currently don’t use/have HIP available, but we’d be happy to contribute vulkan & local_* to the initial MVP. I’d like to start from the vulkan -> local_* angle, if that sounds good to you?
For vulkan, timepoint export appears to be “straightforward” through vkGetSemaphoreFdKHR & iree_make_wait_primitive. (The timepoint semaphore itself could possibly be used for the CPU to wait on, but that looks like a potential optimization to me.)
Importing the timepoint into local_task, I found the following steps might work:
- iree_hal_semaphore_acquire_timepoint of the value
- iree_wait_handle_wrap_primitive of the wait_primitive to import.
- Store the wait_handle into the iree_hal_task_timepoint_t.event
Does that sound correct? Anything else I am missing?
Thanks a lot for taking the time to outline this. We would be happy to help with any issues you created here.
We’ve been doing some prototyping ourselves using the metal/local-task drivers, where we took a shortcut for the “programs will be able to selectively allocate buffers that are compatible with all devices” by simply allocating everything as shared using the metal allocator and manually working around the semaphore issues by hand-modifying the HAL IR.
I have some questions about how to handle buffer import/export—or whether it's even necessary at all for the MVP. My observation so far is that drivers typically require buffers to be allocated using their own allocator, since many of the
iree_hal_*_command_buffer_*functions expect a native buffer handle. However, this doesn't seem to be the case for the local-task device, as I couldn’t find any references indicating that it requires its own native handle (which I assume isiree_hal_heap_buffer_t).Given this, is my understanding correct that for the MVP, buffer import/export might not be required at all? As you mentioned in (#20857/#20855), shared allocations would be performed by the device that supports them (metal, in this example), so the HAL functions on the GPU device would still have access to their native handles, and the local devices wouldn’t care?
If the import/export is still needed would it be something that is handled solely in the runtime or would the transfer ops have to be translated into some form of import/export ops instead of eliding them completely?
If its of any interest for someone, this is the IR snipped i got running for local-task/metal when replacing the local-task allocator with the metal allocator:
The metal/local-task example originated from the tools/test/iree-run-module-multi.mlir test case. There are not a lot of modifications, mostly to resolve semaphore errors like joining on multiple fences that are not from the same device.
Compile command I used:
./build/tools/iree-compile --compile-from=vm --iree-execution-model=async-external --iree-hal-target-device=device_a=local[0] --iree-hal-target-device=device_b=metal --iree-hal-local-target-device-backends=vmvx working_multi_device.mlir -o multi.vmfb
Run command:
./build/tools/iree-run-module --module=multi.vmfb --device=local-task --device=metal --input=4xf32=10,11,12,13 --function=multi_device_mul
Thanks for sharing this example. I'm trying to execute it on CPU+Metal with v3.4.0, but it's not working so far: I don't get the resulting hal.buffer_view. How did you set up your environment?
Vulkan doesn't use the iree_hal_semaphore_acquire_timepoint API so that won't work there. Binary semaphores in Vulkan don't work either as they are auto-reset and don't allow multiple waiters (something we do). I've not looked into how timeline semaphores are exported (via e.g. vkGetSemaphoreWin32HandleKHR) but that's where you'd want to start. If there's not a way to do that then it'll need a potentially complex implementation in the Vulkan HAL for managing native semaphores as well as exported timepoints in a way that does not induce host round-trips in the common case (and ideally not in the uncommon case). The entire Vulkan HAL needs a rewrite at some point, so adding a lot of new infra there is scary unless it can be quarantined and carried forward. But maybe there's a really simple solution and none of that's needed; it just needs someone to do some experiments :)
Metal, CUDA, and HIP are much easier to start with because they do allow (though slow/higher latency) host callbacks to be inserted into the submission queue and those callbacks can then signal wait primitives directly.
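To illustrate why a host callback in the submission queue is enough to signal a wait primitive, here is a toy Python model. It is not any real Metal, CUDA, or HIP API: SubmissionQueue and its methods are hypothetical, a worker thread stands in for the device, and a threading.Event stands in for the wait primitive. It assumes the queue executes entries strictly in order.

```python
import queue
import threading


class SubmissionQueue:
    """Toy in-order device queue driven by a single worker thread."""

    def __init__(self):
        self._entries = queue.Queue()
        self._worker = threading.Thread(target=self._run, daemon=True)
        self._worker.start()

    def _run(self):
        while True:
            entry = self._entries.get()
            if entry is None:  # shutdown sentinel
                break
            entry()

    def enqueue_dispatch(self, work):
        """Queue device work (any callable in this model)."""
        self._entries.put(work)

    def enqueue_host_callback(self, callback):
        """Queue a host callback that runs after all earlier entries."""
        self._entries.put(callback)

    def shutdown(self):
        self._entries.put(None)
        self._worker.join()


if __name__ == "__main__":
    device_queue = SubmissionQueue()
    results = []
    wait_primitive = threading.Event()

    device_queue.enqueue_dispatch(lambda: results.append("dispatch done"))
    # The callback signals the primitive a CPU-side waiter is blocked on.
    device_queue.enqueue_host_callback(wait_primitive.set)

    wait_primitive.wait(timeout=5)
    print(results)
    device_queue.shutdown()
```

Because the callback is ordered behind the dispatch, anything woken by the primitive can rely on the earlier work having completed, at the cost of the extra host hop mentioned above.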
@Achiirua if you want to execute it yourself it would require some local changes. I pushed them here: https://github.com/ziereis/iree/tree/metal_cpu_prototype. But keep in mind this is no proper solution to the general problem :D
With this I get the following output:
ziereis@Thomass-MacBook-Pro ~/p/iree (metal_cpu_prototype)> ./build/tools/iree-run-module --module=multi.vmfb --device=local-task --device=metal --input=4xf32=10,11,12,13 --function=multi_device_mul --task_topology_group_count=1
Creating Metal HAL driver...
Metal driver created successfully.
Metal device created successfully.
Metal device created successfully.
iree_hal_metal_allocator_allocate_buffer
allocating metal buffer with params: type=DEVICE_LOCAL, usage=TRANSFER|DISPATCH_STORAGE, size=16
EXEC @multi_device_mul
iree_hal_metal_allocator_allocate_buffer
allocating metal buffer with params: type=DEVICE_LOCAL, usage=TRANSFER|DISPATCH_STORAGE, size=16
iree_hal_metal_allocator_allocate_buffer
allocating metal buffer with params: type=DEVICE_LOCAL, usage=TRANSFER|DISPATCH_STORAGE, size=64
iree_hal_metal_allocator_allocate_buffer
allocating metal buffer with params: type=DEVICE_LOCAL, usage=TRANSFER|DISPATCH_STORAGE, size=16
iree_hal_metal_allocator_allocate_buffer
allocating metal buffer with params: type=DEVICE_LOCAL, usage=TRANSFER|DISPATCH_STORAGE, size=64
iree_hal_metal_allocator_allocate_buffer
allocating metal buffer with params: type=HOST_LOCAL|DEVICE_VISIBLE, usage=TRANSFER|MAPPING, size=16
result[0]: hal.buffer_view
4xf32=0 55 144 273
@ziereis thanks for pushing this workaround
https://gist.github.com/benvanik/ecc9b37fb2b670ce1ed2fb0d7c694287 shows what I'm shooting for after #20965 on the compiler side with zero copies once we have the transfer elision pass. Initially it'll let us do homogeneous zero-copy sharding as it sidesteps the runtime issues around synchronization. I'll see if I can get something quickly done for at least making the local executors able to wait and signal non-local semaphores - other targets will need to be able to wait on host events, though, via import or a waiter parking lot thread.
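The "waiter parking lot thread" idea can be sketched as a single thread that holds pending host events and forwards each one once it fires. The Python below is a rough model under stated assumptions, not a proposal for the HAL: WaiterParkingLot is a hypothetical name, threading.Event stands in for a host event, the forwarding action is an arbitrary callable, and polling replaces a real multi-wait for brevity.

```python
import threading


class WaiterParkingLot:
    """Toy parking lot: one thread watches many host events."""

    def __init__(self, poll_interval=0.001):
        self._lock = threading.Lock()
        self._parked = []  # (host_event, on_signaled) pairs
        self._stop = threading.Event()
        self._poll_interval = poll_interval
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def park(self, host_event, on_signaled):
        """Register an action to run once host_event is set."""
        with self._lock:
            self._parked.append((host_event, on_signaled))

    def _run(self):
        while not self._stop.is_set():
            ready, waiting = [], []
            with self._lock:
                # Single pass so an event set mid-scan is never dropped.
                for entry in self._parked:
                    (ready if entry[0].is_set() else waiting).append(entry)
                self._parked = waiting
            for _, on_signaled in ready:
                on_signaled()  # e.g. signal the non-host semaphore
            self._stop.wait(self._poll_interval)

    def stop(self):
        self._stop.set()
        self._thread.join()


if __name__ == "__main__":
    lot = WaiterParkingLot()
    host_event = threading.Event()
    device_side_signaled = threading.Event()
    lot.park(host_event, device_side_signaled.set)
    host_event.set()
    print(device_side_signaled.wait(timeout=5))
    lot.stop()
```

A real implementation would presumably block on all parked primitives at once instead of polling; the sketch only shows the ownership model, where targets that cannot import host events delegate the wait to one shared thread.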
Dropping a note that https://github.com/iree-org/iree/blob/f26a830f71a35d3264333fa3f6a12f6cb5d35e30/compiler/src/iree/compiler/Dialect/Stream/Transforms/EmplaceAllocations.cpp#L91-L95 will need the topology check logic in order to perform in-place updates across devices.
@ziereis I'm trying to build your workaround and get to the point that you reached, to see if I can contribute anything here :)
Could you maybe share the toolchain that you used and your compile flags?