[vulkan] Support VK_KHR_buffer_device_address and PhysicalStorageBuffer
Request description
Branching off issue https://github.com/openxla/iree/issues/13196 to see what it would take to implement VK_KHR_buffer_device_address and PhysicalStorageBuffer-based access to Vulkan devices. Increasingly we are dealing with very large tensors (>4GB), while maxStorageBufferRange is limited to 4GB. While we explore options with https://github.com/openxla/iree/issues/13196, this feature request is to see what it would take to move us to using VK_KHR_buffer_device_address.
We are seeing increasing model sizes (Stable Diffusion at 768x768, LLaMA up to 65B, etc.) that we are unable to run on our Vulkan backend today without a multi-process hack. We are also starting to see 16GB+ VRAM allocations on mobile SoC devices, so this is a requirement across the board for Vulkan devices.
Some other references: https://gpuopen-librariesandsdks.github.io/VulkanMemoryAllocator/html/enabling_buffer_device_address.html. There were also some references to it from https://github.com/KhronosGroup/Vulkan-Docs/issues/1016.
@antiagainst @benvanik @stellaraccident
What component(s) does this issue relate to?
No response
Additional context
No response
Adding @antiagainst to take a look when you can.
This would be a major change w.r.t. how we handle buffers across the runtime and kernels. What the buffer device address extension does is enable querying 64-bit physical GPU addresses for storage buffers so that we can populate them in uniform buffers or push constants and then let kernels load them directly as buffer pointers and do load/store thereafter. So it's pretty substantial, at least at the conceptual level.
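For concreteness, here's a minimal sketch of the core mechanism using the standard Vulkan 1.2 / VK_KHR_buffer_device_address API (the helper name is illustrative, not IREE code); the buffer must have been created with VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT and its memory allocated with VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT:

#include <vulkan/vulkan.h>

// Returns a 64-bit GPU-visible pointer that a shader can load and dereference.
VkDeviceAddress query_buffer_address(VkDevice device, VkBuffer buffer) {
  const VkBufferDeviceAddressInfo info = {
      .sType = VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO,
      .buffer = buffer,
  };
  return vkGetBufferDeviceAddress(device, &info);
}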
Now, regarding what needs to be changed to support this extension, a few big parts:
- Runtime-kernel boundary: we'd need a way/convention to make sure we have consistency between the runtime (in IREE) and kernel CodeGen (in MLIR). E.g., using a uniform buffer or push constants for all base buffer addresses. Using push constants avoids indirection, but push constants are already used for shape dimensions and such. If a uniform buffer, which descriptor set? How buffer addresses for different descriptor (set, binding) pairs are organized in the uniform buffer, etc. (One possible layout is sketched after this list.)
- Runtime: we'd need to change how push descriptors are handled in the Vulkan HAL driver, to query device buffer addresses and populate them in the above manner. This actually makes it closer to the LLVMGPU side in a sense.
- Kernel: we'd need to plumb through support for physical storage buffer and related capabilities in MLIR/SPIR-V CodeGen. And then follow the convention to not generate descriptors but load device pointers directly and use them.
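To make that boundary concrete, one possible convention (a sketch only - the struct and names here are hypothetical, not an IREE API) is a fixed-layout record of addresses that the runtime writes into a uniform/storage buffer and the kernel reads back as pointers:

#include <vulkan/vulkan.h>

// Hypothetical convention: one VkDeviceAddress slot per (set, binding) pair,
// in binding order; each entry is naturally 8 bytes, matching what the shader
// would declare on its side.
typedef struct {
  VkDeviceAddress binding_addresses[/*total binding count, e.g.*/ 4];
} dispatch_parameters_t;

// At dispatch time the runtime fills the record and uploads it; the kernel
// indexes the matching slot and treats the address as a buffer pointer.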
I'd need to talk with @benvanik to get a more detailed picture w.r.t. what needs to be changed, esp. on the runtime side.
Yeah, lots tangled up here - we'll need to break it down. There are some easier solutions and some harder ones for sure :) I suspect we'll build an ArgumentBuffer-like thing in the Vulkan HAL driver and map descriptor sets into that instead of native descriptor sets.
(this was how I was imagining enabling secondary command buffer buffer substitution in iree_hal_command_buffer_execute - so doing it may get us that too!)
Setting as P2 for now but leaving open for continued conversations and task lists etc. Please edit as needed.
Took a look at what would be needed. @antiagainst articulated the major parts and then there are a few details:
- need to do #7824 so that we can have alternative shaders (ones with/without device address support) - today, without linking, it'd create an explosion as each shader is its own executable, whereas we should instead have one executable for all shaders with a given target configuration
- executable target configuration for device addresses propagated to vulkan/spir-v env so we can specify targets that do/don't support them
- HAL -> SPIR-V lowering needs to emit all resources as pointers loaded from a base descriptor when the configuration demands it
- SPIR-V ExecutableDef flatbuffer needs a flag for whether device addresses are used (this is how the runtime knows it needs to switch modes)
- FillBufferUnaligned builtin will need to support device addresses (today it's just descriptors)
- direct command buffer needs two paths for push_descriptor_set: one that does what it does today with DescriptorSetArena and another that uses binding tables
The idea would be to have each command buffer have a growable set of staging buffers with uniform-buffer-upload semantics (something we'd have to benchmark, but usually device-local|host-visible|host-coherent). As the command buffer is recorded and push_descriptor_set is called we'd slice off some of the current staging buffer, scribble in our buffer info, and then bind the staging buffer with a dynamic offset as a normal descriptor set operation using DescriptorSetArena as today. The shaders would have a buffer declared for the descriptors and access all descriptors indirectly through it. Something like:
iree_hal_vulkan_direct_command_buffer_push_descriptor_set(...) {
  if (iree_hal_vulkan_native_pipeline_layout_has_indirect_access(pipeline_layout, set)) {
    // stash a mirror of the descriptor state on the command buffer
    // mark parameters as dirty
  } else {
    // existing descriptor set binding path
  }
}
iree_hal_vulkan_direct_command_buffer_dispatch(...) {
  if (parameters are dirty) {
    if (iree_hal_vulkan_native_pipeline_layout_has_indirect_access(pipeline_layout, set)) {
      // upload a new parameters chunk by flushing the command buffer descriptor
      // state for the pipeline layout; this may allocate a new staging buffer
      // if the prior one is exhausted
      iree_hal_vulkan_direct_command_buffer_append_dispatch_parameters(
          pipeline_layout, &staging_descriptor_set, &staging_offset);
      // bind the root descriptor with the dynamic offset of the dispatch -
      // should be cheap
      vkCmdBindDescriptorSets(staging_descriptor_set, staging_offset);
      // reset parameters dirty flag
    }
  }
  vkCmdDispatch(...);
}
And the shader:
#extension GL_EXT_buffer_reference : require
layout(buffer_reference, std430, buffer_reference_align = 16) buffer binding_f32_t {
  float data[];
};
layout(set = 3, binding = 0) buffer root_set_0_t {
  binding_f32_t binding_0;
  binding_f32_t binding_1;
  binding_f32_t binding_2;
} root_set_0;
void main() {
  root_set_0.binding_0.data[0]; // access...
}
Oh, the other thing this intersects with is secondary indirect command buffers (iree_hal_command_buffer_execute_commands) - those are recorded with placeholders and then when executed the placeholders get updated to the bindings passed in via iree_hal_buffer_binding_table_t. There's a bit more bookkeeping required to map a command-buffer-global binding slot to the locations in the staging buffer that need to have that value populated. In the above we can write the device addresses in when flushing the parameters, but here we'd instead scribble aside the offset into the parameter buffer where the address should be written, as we don't actually have it at the time of recording.

When a secondary command buffer is scheduled with iree_hal_command_buffer_execute_commands, the hosting primary command buffer would use one or more vkCmdUpdateBuffer calls to populate the parameters in stream order, and so long as we don't allow multiple overlapping executions of the same secondary buffer (we'll need to track that) we should be safe. I call this out because the compiler side will look identical and it's just additional tracking at runtime on top of the above work to also get reusable command buffers!
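A rough sketch of that patching step under the assumptions above (the bookkeeping struct and helper are hypothetical; vkCmdUpdateBuffer is the real Vulkan call, and is fine for small inline writes like these 8-byte addresses):

#include <vulkan/vulkan.h>

// Recorded while building the secondary command buffer: which parameter-buffer
// offsets need a device address for a given command-buffer-global slot.
typedef struct {
  uint32_t slot;               // binding slot in the binding table
  VkDeviceSize buffer_offset;  // byte offset into the parameter buffer
} binding_patch_t;

// Executed by the hosting primary command buffer, in stream order, once the
// real bindings are known.
static void patch_parameters(VkCommandBuffer primary_cmd, VkBuffer parameter_buffer,
                             const binding_patch_t* patches, uint32_t patch_count,
                             const VkDeviceAddress* resolved_addresses) {
  for (uint32_t i = 0; i < patch_count; ++i) {
    const VkDeviceAddress address = resolved_addresses[patches[i].slot];
    vkCmdUpdateBuffer(primary_cmd, parameter_buffer, patches[i].buffer_offset,
                      sizeof(address), &address);
  }
}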
One thing that may need some fiddling is how to communicate the pipeline layout mode, or whether we want to make it per descriptor set (I think we do) - we can add a bit to iree_hal_descriptor_set_layout_flags_t for whether a set is a native descriptor set (default) or an indirect one, have the compiler emit the flag when creating such sets, and then have it queried during dispatch. The flags are set based on the executable target, which we'd know needs the flag set. When binding we'd then ignore any indirect descriptor sets as those are covered by the parameter upload path.
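A minimal sketch of how that flag could look and be consulted (the INDIRECT bit is the one discussed above, but the exact values and helper are illustrative, not IREE's actual definitions):

#include <stdbool.h>
#include <stdint.h>

// Hypothetical values; only the INDIRECT semantics come from this thread.
typedef uint32_t iree_hal_descriptor_set_layout_flags_t;
enum {
  IREE_HAL_DESCRIPTOR_SET_LAYOUT_FLAG_NONE = 0u,
  IREE_HAL_DESCRIPTOR_SET_LAYOUT_FLAG_INDIRECT = 1u << 0,
};

// Consulted on the push_descriptor_set/dispatch path: indirect sets skip the
// native descriptor set binding path and go through the parameter upload.
static bool set_uses_indirect_bindings(iree_hal_descriptor_set_layout_flags_t flags) {
  return (flags & IREE_HAL_DESCRIPTOR_SET_LAYOUT_FLAG_INDIRECT) != 0;
}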
#14777 has disabled VMA by default and hopefully it sticks (there may be some issues). We'll let that soak a bit and get it into SHARK before fully removing VMA.
After that #14778 makes the runtime Vulkan HAL detect support for buffer device addresses and enables the feature on the allocations we make.
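The detection side is the standard Vulkan feature-chaining query (a sketch of the usual pattern; how #14778 actually structures this in the HAL may differ):

#include <stdbool.h>
#include <vulkan/vulkan.h>

// Returns true if the physical device supports buffer device addresses. The
// populated bda_features struct would then be chained into
// VkDeviceCreateInfo.pNext at device creation to enable the feature, and
// allocations made with VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT.
static bool supports_buffer_device_address(VkPhysicalDevice physical_device) {
  VkPhysicalDeviceBufferDeviceAddressFeatures bda_features = {
      .sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES,
  };
  VkPhysicalDeviceFeatures2 features2 = {
      .sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2,
      .pNext = &bda_features,
  };
  vkGetPhysicalDeviceFeatures2(physical_device, &features2);
  return bda_features.bufferDeviceAddress == VK_TRUE;
}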
The next step is to implement the indirect parameter buffer in the Vulkan HAL in preparation for the compiler using it. I've got some sketches that dovetail with indirect command buffers and will see if I can piece them apart for some incremental work.
Following up from the discussions yesterday, here's the spec I'm going to be shooting for on the compiler/runtime side outside of codegen:
- new HAL_DescriptorSetLayoutFlags_Indirect / IREE_HAL_DESCRIPTOR_SET_LAYOUT_FLAG_INDIRECT flag
- new temporary global compiler flag so that MaterializeInterfaces sets the flag on DescriptorSetLayoutAttr
- runtime command buffer logic that uses set 3 binding N for each original set N, maybe - not sure I like the flag approach, but the below in-memory format is not likely to change
So what would have been:
#version 460
layout(set = 0, binding = 0, std430) buffer set_0_binding_0_t { float data[]; } set_0_binding_0;
layout(set = 0, binding = 1, std430) buffer set_0_binding_1_t { float data[]; } set_0_binding_1;
layout(set = 0, binding = 2, std430) buffer set_0_binding_2_t { float data[]; } set_0_binding_2;
// note no binding 0 used
layout(set = 1, binding = 1, std430) buffer set_1_binding_1_t { float data[]; } set_1_binding_1;
void main() {
  set_0_binding_0.data[0];
  set_0_binding_1.data[0];
  set_0_binding_2.data[0];
  set_1_binding_1.data[1] = 1.0f;
  // ...
}
->
#version 460
#extension GL_EXT_buffer_reference : require
layout(buffer_reference, std430, buffer_reference_align = 16) buffer binding_f32_t {
  float data[];
};
layout(set = 3, binding = 0) buffer set_0_t {
  binding_f32_t binding_0;
  binding_f32_t binding_1;
  binding_f32_t binding_2;
} set_0;
layout(set = 3, binding = 1) buffer set_1_t {
  binding_f32_t unused_binding_0; // note: here for alignment
  binding_f32_t binding_1;
} set_1;
void main() {
  set_0.binding_0.data[0]; // access original set(0) binding(0)
  set_0.binding_1.data[0]; // access original set(0) binding(1)
  set_0.binding_2.data[0]; // access original set(0) binding(2)
  set_1.binding_1.data[1] = 1.0f; // access original set(1) binding(1)
  // ...
}
SPIR-V:
; SPIR-V
; Version: 1.6
; Generator: Khronos Glslang Reference Front End; 11
; Bound: 34
; Schema: 0
OpCapability Shader
OpCapability PhysicalStorageBufferAddresses
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel PhysicalStorageBuffer64 GLSL450
OpEntryPoint GLCompute %main "main" %set_0 %set_1
OpExecutionModeId %main LocalSizeId %uint_1 %uint_1 %uint_1
OpSource GLSL 460
OpSourceExtension "GL_EXT_buffer_reference"
OpName %main "main"
OpName %set_0_t "set_0_t"
OpMemberName %set_0_t 0 "binding_0"
OpMemberName %set_0_t 1 "binding_1"
OpMemberName %set_0_t 2 "binding_2"
OpName %binding_f32_t "binding_f32_t"
OpMemberName %binding_f32_t 0 "data"
OpName %set_0 "set_0"
OpName %set_1_t "set_1_t"
OpMemberName %set_1_t 0 "unused_binding_0"
OpMemberName %set_1_t 1 "binding_1"
OpName %set_1 "set_1"
OpMemberDecorate %set_0_t 0 Offset 0
OpMemberDecorate %set_0_t 1 Offset 8
OpMemberDecorate %set_0_t 2 Offset 16
OpDecorate %set_0_t Block
OpDecorate %_runtimearr_float ArrayStride 4
OpMemberDecorate %binding_f32_t 0 Offset 0
OpDecorate %binding_f32_t Block
OpDecorate %set_0 DescriptorSet 3
OpDecorate %set_0 Binding 0
OpMemberDecorate %set_1_t 0 Offset 0
OpMemberDecorate %set_1_t 1 Offset 8
OpDecorate %set_1_t Block
OpDecorate %set_1 DescriptorSet 3
OpDecorate %set_1 Binding 1
%void = OpTypeVoid
%3 = OpTypeFunction %void
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
OpTypeForwardPointer %_ptr_PhysicalStorageBuffer_binding_f32_t PhysicalStorageBuffer
%set_0_t = OpTypeStruct %_ptr_PhysicalStorageBuffer_binding_f32_t %_ptr_PhysicalStorageBuffer_binding_f32_t %_ptr_PhysicalStorageBuffer_binding_f32_t
%float = OpTypeFloat 32
%_runtimearr_float = OpTypeRuntimeArray %float
%binding_f32_t = OpTypeStruct %_runtimearr_float
%_ptr_PhysicalStorageBuffer_binding_f32_t = OpTypePointer PhysicalStorageBuffer %binding_f32_t
%_ptr_StorageBuffer_set_0_t = OpTypePointer StorageBuffer %set_0_t
%set_0 = OpVariable %_ptr_StorageBuffer_set_0_t StorageBuffer
%int = OpTypeInt 32 1
%int_0 = OpConstant %int 0
%_ptr_StorageBuffer__ptr_PhysicalStorageBuffer_binding_f32_t = OpTypePointer StorageBuffer %_ptr_PhysicalStorageBuffer_binding_f32_t
%int_1 = OpConstant %int 1
%int_2 = OpConstant %int 2
%set_1_t = OpTypeStruct %_ptr_PhysicalStorageBuffer_binding_f32_t %_ptr_PhysicalStorageBuffer_binding_f32_t
%_ptr_StorageBuffer_set_1_t = OpTypePointer StorageBuffer %set_1_t
%set_1 = OpVariable %_ptr_StorageBuffer_set_1_t StorageBuffer
%float_1 = OpConstant %float 1
%_ptr_PhysicalStorageBuffer_float = OpTypePointer PhysicalStorageBuffer %float
%main = OpFunction %void None %3
%5 = OpLabel
%18 = OpAccessChain %_ptr_StorageBuffer__ptr_PhysicalStorageBuffer_binding_f32_t %set_0 %int_0
%19 = OpLoad %_ptr_PhysicalStorageBuffer_binding_f32_t %18
%21 = OpAccessChain %_ptr_StorageBuffer__ptr_PhysicalStorageBuffer_binding_f32_t %set_0 %int_1
%22 = OpLoad %_ptr_PhysicalStorageBuffer_binding_f32_t %21
%24 = OpAccessChain %_ptr_StorageBuffer__ptr_PhysicalStorageBuffer_binding_f32_t %set_0 %int_2
%25 = OpLoad %_ptr_PhysicalStorageBuffer_binding_f32_t %24
%29 = OpAccessChain %_ptr_StorageBuffer__ptr_PhysicalStorageBuffer_binding_f32_t %set_1 %int_1
%30 = OpLoad %_ptr_PhysicalStorageBuffer_binding_f32_t %29
%33 = OpAccessChain %_ptr_PhysicalStorageBuffer_float %30 %int_0 %int_1
OpStore %33 %float_1 Aligned 4
OpReturn
OpFunctionEnd
Update: I decided flags are fine for now as this can be experimental. In #14977 I've added the --iree-vulkan-experimental-indirect-bindings=true compiler flag that changes the executable format to the one required by the runtime (vulkan-spirv-fb-ptr) and sets the Indirect flag on the descriptor set layouts on the exported executable variant functions.
Next steps on the runtime side are to route iree_hal_command_buffer_push_descriptor_set calls down a special parameter buffer path when the IREE_HAL_DESCRIPTOR_SET_LAYOUT_FLAG_INDIRECT flag is set (along with some other goo), while on the compiler side the codegen lowerings will need to inspect the layout flags and when IREE::HAL::DescriptorSetLayoutFlags::Indirect lower to the above SPIR-V binding style.
(there's a lot I don't like about this approach but it's not worth me stalling any longer - we've got enough other cleanup around SPIR-V executables and extensions pending and this is such a big switch that it may help ground out discussions on next steps by being so hideous :)
Thanks @benvanik for the details! @kuhar will help to flesh out the SPIR-V part:
- As we chatted, the majority of the changes can happen in ConvertToSPIRVPass right now - there is a pre-step there analyzing subspan ops and creating SPIR-V global variables for them. We can add the logic there to perform the extra indirect indexing via the argument/parameter buffer.
- The addressing model needs to be PhysicalStorageBuffer64 - the spv.module op is directly created in ConvertToSPIRVPass too, so it should be straightforward to modify.
- And we need to plumb through support for the extension and associated capabilities in upstream MLIR and on the IREE SPIR-V CodeGen side.
I opened a PR with the compiler support: https://github.com/openxla/iree/pull/16301, and landed another one with HAL device queries for the related Vulkan extension: https://github.com/openxla/iree/pull/16282. There's also a landed MLIR PR for the memref-to-SPIR-V conversion: https://github.com/llvm/llvm-project/pull/80243.
With these three PRs in the tree and 64-bit indexing enabled, the following e2e compiles but fails at runtime:
$ ninja iree-compile
$ tools/iree-compile ~/iree/iree/tests/e2e/stablehlo_ops/add.mlir --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=rdna3-7900-linux --iree-vulkan-experimental-indirect-bindings=true -o add.vmfb --mlir-disable-threading --mlir-print-ir-after-all 2>add_all.log
$ tools/iree-check-module --module=add.vmfb --device=vulkan://0