Memory footprint of initializer for data-tiling (2x -> 1x)
When we enable data-tiling, the set_encoding ops on parameters get hoisted into initializers. If I understand correctly, this doubles the memory footprint: the initializer holds both the original parameters it loads and the output buffers for the encoded parameters. This becomes a problem for llama405, where the ~200GB of weights can't be doubled on MI355 (only 288GB available) and we run out of memory. How do we go about reducing the memory footprint of the initializer?
Repro:
util.global private @weight = #flow.parameter.named<"model"::"weight"> : tensor<128256x4096xf16>
util.global private @scale = #flow.parameter.named<"model"::"scale"> : tensor<128256xf32>
util.func public @test(%arg0: tensor<?x4096xf16>) -> tensor<?x128256xf16> {
%arg1 = util.global.load @weight : tensor<128256x4096xf16>
%arg2 = util.global.load @scale : tensor<128256xf32>
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%dim = tensor.dim %arg0, %c0 : tensor<?x4096xf16>
%2 = util.assume.int %dim<umin = 256, umax = 524160, udiv = 256> : index
%17 = tensor.empty(%2) : tensor<?x128256xf32>
%18 = linalg.fill ins(%cst : f32) outs(%17 : tensor<?x128256xf32>) -> tensor<?x128256xf32>
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<?x4096xf16>, tensor<128256x4096xf16>) outs(%18 : tensor<?x128256xf32>) {
^bb0(%in: f16, %in_0: f16, %out: f32):
%22 = arith.extf %in : f16 to f32
%23 = arith.extf %in_0 : f16 to f32
%24 = arith.mulf %22, %23 : f32
%25 = arith.addf %out, %24 : f32
linalg.yield %25 : f32
} -> tensor<?x128256xf32>
%20 = tensor.empty(%2) : tensor<?x128256xf16>
%21 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%19, %arg2 : tensor<?x128256xf32>, tensor<128256xf32>) outs(%20 : tensor<?x128256xf16>) {
^bb0(%in: f32, %scale: f32, %out: f16):
%10 = arith.mulf %in, %scale : f32
%11 = arith.truncf %10 : f32 to f16
linalg.yield %11 : f16
} -> tensor<?x128256xf16>
util.return %21 : tensor<?x128256xf16>
}
Compilation command:
iree-compile repro.mlir \
--iree-hal-target-device=hip \
--iree-hip-target=gfx942 \
-o repro.vmfb \
--iree-dispatch-creation-data-tiling
Run command:
iree-run-module \
--device="hip://0" \
--device_allocator=caching \
--hip_use_streams=true \
--module="repro.vmfb" \
--function=test \
--parameters=model=params.irpa \
--input=8192x4096xf16=1
Note: you can generate a params.irpa file with:
import iree.runtime as rt
import numpy as np
weight = np.ones((128256, 4096), np.float16)
scale = np.ones((128256,), np.float32) * 2.
parameter_index = rt.ParameterIndex()
parameter_index.add_buffer("weight", weight)
parameter_index.add_buffer("scale", scale)
parameter_index.create_archive_file("params.irpa")
To demonstrate where I think the problem is, see the initializer IR snippet below, taken at the end of compilation. It contains an @io_parameters.load, which I think allocates ~1 GB for the original parameters, and an @hal.device.queue.alloca, which allocates another ~1 GB for storing the encoded parameters.
%list = vm.call @io_parameters.load(%__device_0, %c-1_1, %null_0, %ref_12, %_utf8_model_18B28E5820ED409E, %c-1_1, %c48, %c527363, %_const, %_const_3, %_const_4) : (!vm.ref<!hal.device>, i64, !vm.ref<!hal.fence>, !vm.ref<!hal.fence>, !vm.buffer, i64, i32, i32, !vm.buffer, !vm.buffer, !vm.buffer) -> !vm.list<!vm.ref<!hal.buffer>>
%ref_13 = vm.list.get.ref %list, %zero_2 : (!vm.list<!vm.ref<!hal.buffer>>, i32) -> !vm.ref<!hal.buffer>
%ref_14 = vm.list.get.ref %list, %c1 : (!vm.list<!vm.ref<!hal.buffer>>, i32) -> !vm.ref<!hal.buffer>
%ref_15 = vm.call @hal.fence.create(%__device_0, %zero) : (!vm.ref<!hal.device>, i64) -> !vm.ref<!hal.fence>
%ref_16 = vm.call @hal.device.queue.alloca(%__device_0, %c-1_1, %null_0, %ref_15, %zero, %c48, %c527363, %c1051186176, %zero) : (!vm.ref<!hal.device>, i64, !vm.ref<!hal.fence>, !vm.ref<!hal.fence>, i64, i32, i32, i64, i64) -> !vm.ref<!hal.buffer>
%ref_17 = vm.call.variadic @hal.fence.join(%zero, [%ref_12, %ref_15]) {nosideeffects} : (i64, !vm.ref<!hal.fence> ...) -> !vm.ref<!hal.fence>
%ref_18 = vm.call @hal.command_buffer.create(%__device_0, %c1, %c3, %c-1_1, %zero_2) : (!vm.ref<!hal.device>, i32, i32, i64, i32) -> !vm.ref<!hal.command_buffer>
vm.call.variadic @hal.command_buffer.dispatch(%ref_18, %ref_11, %zero_2, %c8208384, %c1, %c1, %zero, [], [(%zero_2, %zero_2, %ref_13, %zero, %c1050673152), (%zero_2, %zero_2, %ref_16, %zero, %c1051186176)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call.variadic @hal.command_buffer.dispatch(%ref_18, %ref_11, %c1, %c2004, %c1, %c1, %zero, [], [(%zero_2, %zero_2, %ref_14, %zero, %c513024), (%zero_2, %zero_2, %ref_16, %zero, %c1051186176)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call @hal.command_buffer.execution_barrier(%ref_18, %c28, %c13, %zero) : (!vm.ref<!hal.command_buffer>, i32, i32, i64) -> ()
vm.call @hal.command_buffer.finalize(%ref_18) : (!vm.ref<!hal.command_buffer>) -> ()
%ref_19 = vm.call @hal.fence.create(%__device_0, %zero) : (!vm.ref<!hal.device>, i64) -> !vm.ref<!hal.fence>
vm.call @hal.device.queue.execute(%__device_0, %c-1_1, %ref_17, %ref_19, %ref_18, %zero) : (!vm.ref<!hal.device>, i64, !vm.ref<!hal.fence>, !vm.ref<!hal.fence>, !vm.ref<!hal.command_buffer>, i64) -> ()
vm.global.store.ref %ref_11, @__device_0_executable_0_matmul_with_params_linked : !vm.ref<!hal.executable>
%20 = vm.call.variadic @hal.fence.await(%c-1, %zero, [%ref_19]) : (i32, i64, !vm.ref<!hal.fence> ...) -> i32
vm.cond_br %20, ^bb11, ^bb10
^bb9: // pred: ^bb7
vm.fail %c14, "HAL device `__device_0` does not support any variant of executable `matmul_with_params_linked`; available formats: [rocm-hsaco-fb]"
^bb10: // pred: ^bb8
vm.global.store.ref %ref_16, @__hoisted_tensor_128256x4096xf16__encoded : !vm.ref<!hal.buffer>
vm.return
Also see the memory footprint in the Tracy screenshots below:
Without data-tiling we see a memory footprint of 1GB on initialization:
With data-tiling we see a memory footprint of 2GB on initialization:
@hanhanW @MaheshRavishankar @benvanik
Thanks @jtuyls for the repro! I can reproduce it without running the test. I'll share the IR and steps.
I recently learned that unexpected memory allocation issues can sometimes be identified in the Stream output. We can use iree-compile with the repro command: iree-compile --iree-hal-target-device=hip --iree-hip-target=gfx942 --iree-dispatch-creation-data-tiling ~/repro.mlir --compile-to=stream -o ~/stream.dt.out.mlir
Here is the dump with the data-tiling flag and without the data-tiling flag.
We can grep stream.resource.alloca to see the total allocation; we find that IREE allocates %c1051186176 bytes in the initializer.
❯ rg stream.resource.alloca stream.*.out.mlir
stream.nodt.out.mlir
87: %result, %result_timepoint = stream.resource.alloca uninitialized on(#hal.device.affinity<@__device_0>) : !stream.resource<external>{%4} => !stream.timepoint
stream.dt.out.mlir
49: %result, %result_timepoint_0 = stream.resource.alloca uninitialized on(#hal.device.affinity<@__device_0>) : !stream.resource<constant>{%c1051186176} => !stream.timepoint
174: %result, %result_timepoint = stream.resource.alloca uninitialized on(#hal.device.affinity<@__device_0>) : !stream.resource<external>{%6} => !stream.timepoint
175: %result_0, %result_timepoint_1 = stream.resource.alloca uninitialized on(#hal.device.affinity<@__device_0>) : !stream.resource<transient>{%5} => !stream.timepoint
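For reference, those byte counts line up with the parameter shapes. Here's a quick sanity check (assuming the encoded layout adds no padding, which the same-shape encode dispatch names in the snippet below suggest):
weight_bytes = 128256 * 4096 * 2            # f16 -> 1050673152, i.e. %c1050673152
scale_bytes = 128256 * 4                    # f32 ->     513024, i.e. %c513024
encoded_bytes = weight_bytes + scale_bytes  # 1051186176, i.e. %c1051186176
So the single ~1GB constant alloca packs both encoded parameters, and it is live alongside the ~1GB parameter load.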
Below is the initializer snippet; it allocates ~1GB for loading the parameters and ~1GB for the new global. This falls into the 2x memory footprint scenario in https://github.com/iree-org/iree/issues/21659. What I'm currently working on in that issue is reducing the memory footprint from 3x to 2x; this issue is the next one we have to fix.
I don't know what the solution is for now; we likely need some input from @benvanik. A potential solution may be breaking the initializers into several chunks based on the available memory on the device (see the sketch after the snippet below). In this context, the minimal memory requirement is 2x the largest parameter, which should be okay for llama.
util.initializer {
%c0 = arith.constant 0 : index
%c1051186176 = arith.constant 1051186176 : index
%c513024 = arith.constant 513024 : index
%c1050673152 = arith.constant 1050673152 : index
%c0_i64 = arith.constant 0 : i64
%results:2, %result_timepoint = stream.parameter.load on(#hal.device.affinity<@__device_0>) {
"model"::"weight"[%c0_i64] : !stream.resource<constant>{%c1050673152},
"model"::"scale"[%c0_i64] : !stream.resource<constant>{%c513024}
} => !stream.timepoint
%result, %result_timepoint_0 = stream.resource.alloca uninitialized on(#hal.device.affinity<@__device_0>) : !stream.resource<constant>{%c1051186176} => !stream.timepoint
%0 = stream.timepoint.join max(%result_timepoint, %result_timepoint_0) => !stream.timepoint
%1 = stream.cmd.execute once on(#hal.device.affinity<@__device_0>) await(%0) => with(%results#0 as %arg0: !stream.resource<constant>{%c1050673152}, %results#1 as %arg1: !stream.resource<constant>{%c513024}, %result as %arg2: !stream.resource<constant>{%c1051186176}) {
stream.cmd.concurrent {
stream.cmd.dispatch @_encoding_0::@_encoding_0_encode_128256x4096xf16_to_128256x4096xf16 {
ro %arg0[%c0 for %c1050673152] : !stream.resource<constant>{%c1050673152},
wo %arg2[%c0 for %c1051186176] : !stream.resource<constant>{%c1051186176}
}
stream.cmd.dispatch @_encoding_1::@_encoding_1_encode_128256xf32_to_128256xf32 {
ro %arg1[%c0 for %c513024] : !stream.resource<constant>{%c513024},
wo %arg2[%c0 for %c1051186176] : !stream.resource<constant>{%c1051186176}
}
}
} => !stream.timepoint
%2 = stream.timepoint.await sync %1 => %result : !stream.resource<constant>{%c1051186176}
util.global.store %2, @__hoisted_tensor_128256x4096xf16__encoded : !stream.resource<constant>
util.return
}
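To illustrate the chunking idea above, here is a hypothetical sketch (plain Python, not IREE code or an IREE API) of grouping parameters into chunks whose peak footprint fits a byte budget; each chunk would be loaded, encoded, and stored before the next one starts, so the initializer's peak is one chunk rather than the whole parameter set:
def chunk_parameters(param_sizes, budget_bytes):
    # Greedily group (name, size) pairs; the 2x factor accounts for the source
    # parameter and its encoded copy being alive at the same time.
    chunks, current, current_bytes = [], [], 0
    for name, size in param_sizes:
        peak = 2 * size
        if peak > budget_bytes:
            raise ValueError(f"{name} alone needs {peak} bytes > budget")
        if current and current_bytes + peak > budget_bytes:
            chunks.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += peak
    if current:
        chunks.append(current)
    return chunks

# The repro's two parameters fit in one chunk with a ~2.2GB budget; a ~200GB
# model with many per-layer weights would split into many chunks instead.
print(chunk_parameters([("weight", 1050673152), ("scale", 513024)], 2_200_000_000))
This is only meant to show why the floor is 2x the largest single parameter.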
The input to Stream, i.e., the output with --compile-to=flow: https://gist.github.com/hanhanW/263bf4b3d8d197e8ab0fd22d5fe09878
The IR dump of the stream transformations: https://gist.github.com/hanhanW/7352a942474080d3f5521aaa1d4b2cea
Coming from https://github.com/iree-org/iree/pull/22118#discussion_r2457572997; I'm mirroring the context from @benvanik so I won't miss it in the future.
I asked whether we should do it in the CombineInitializers pass, and Ben said:
That's something different - that's pipelining (as a degenerate case of short depth and each stage being a small number of operations). We should be pipelining everywhere, and whether we're doing it in an initializer or a function it should be the same. So there won't be any pass that splits initializers or anything. We'll be breaking up large execution regions into smaller chunks, overlapping execution of certain stages, and the chunk granularity will be based on a policy.
For the policy, initially it will use the Stream resource constraints that let us specify e.g. what the maximum size of an allocation is and a few other things like favor-speed or favor-min-peak-memory. I've got some changes in-flight for adding pipelining control (max in-flight operations, max outstanding memory, etc). Devices (indirectly) can then specify their limits. The key here is going to be that the pipeline depth is going to be controlled via SSA values, not flags/constants/etc: we'll have a way to query for an affinity what the ideal pipeline depth is, have all the IR built for being flexible, and let the rest of the compiler do its magic (resolving/folding/unrolling/etc).
The thing preventing all that (and one of my 81 rabbit holes I'm currently down :) is that we've never supported control flow well at this level of the compiler. We've got a lot of work to do to make structured control flow viable and that's a lot of what I'm working on - this PR is part of that workstream as is the big chain of users/benvanik/16168-* branches (there's like 7k more lines of code coming). Once we can support things like scf.for with timepoints, resources, allocations, partitioning, etc we'll be able to start doing some of the pipelining stuff.
All of that will then be useful for our encoding specialization modules (which is where this all started): we don't want to load 500GB of parameters, allocate 500GB of transients, and then calculate them all in one giant command buffer. It's the same behavior we want in normal programs, though, both during initialization and during normal execution (today we greedily throw everything together).
You opened an epically cool stream of work from asking for the specialization stuff ;P
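To make the bounded-in-flight idea concrete, here is a rough, purely illustrative sketch (the submit/wait callbacks are hypothetical placeholders, not IREE runtime APIs) of a fixed pipeline depth: at most max_in_flight chunks are outstanding at once, so the outstanding staging memory stays bounded no matter how many parameters there are:
from collections import deque

def pipelined_encode(chunks, submit_chunk, wait_for, max_in_flight=2):
    # submit_chunk(chunk) would enqueue load + encode + store for one chunk and
    # return a handle; wait_for(handle) would block until that work retires.
    in_flight = deque()
    for chunk in chunks:
        if len(in_flight) == max_in_flight:
            wait_for(in_flight.popleft())  # retire the oldest before adding more
        in_flight.append(submit_chunk(chunk))
    while in_flight:
        wait_for(in_flight.popleft())
Presumably max_in_flight would come from the policy Ben describes (favor-speed vs. favor-min-peak-memory, device limits) rather than from a constant.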
@hanhanW I am wondering what the plan here is to move forward? Do we have a breakdown of all pieces that are needed somewhere? (Sorry if missed it somewhere else, just got back from holiday this week)
I'm not sure. I haven't gotten updates from Ben, and I'd like to check with him. I probably need to study https://github.com/iree-org/iree/pull/22814, which looks relevant.
I've been mainly working on the other issue (3x -> 2x), which happens before Ben's pass. The latest status of my work is revamping my prototype with control-flow support in initializers, which was added to IREE recently: https://github.com/iree-org/iree/issues/22485. I'm picking my work back up after vacation.
Good news: I made some progress with https://github.com/iree-org/iree/pull/22814. I made a few small examples with encodings, and we are able to do the compiler -> tool -> runtime trick for encodings. I'll look at llama models tomorrow.
I hacked GenerateSplatParameterArchive.cpp to generate 1s instead of 0s, so we can check the output. (patch)
Input program: encoded_two_globals_f16f16f32.mlir
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#encoding_lhs = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2]>
#encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2]>
#encoding_res = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2]>
util.global private @weight_0 = #flow.parameter.named<"model"::"weight_0"> : tensor<456x789xf16>
util.global private @weight_1 = #flow.parameter.named<"model"::"weight_1"> : tensor<456x789xf16>
// This global holds the transformed value.
util.global private @encoded_weight_0 : tensor<456x789xf16, #encoding_rhs>
util.global private @encoded_weight_1 : tensor<456x789xf16, #encoding_rhs>
util.initializer {
// Load the raw parameter (all ones from splat).
%raw0 = util.global.load @weight_0 : tensor<456x789xf16>
%encoded0 = flow.tensor.encode %raw0 : tensor<456x789xf16> -> tensor<456x789xf16, #encoding_rhs>
util.global.store %encoded0, @encoded_weight_0 : tensor<456x789xf16, #encoding_rhs>
%raw1 = util.global.load @weight_1 : tensor<456x789xf16>
%encoded1 = flow.tensor.encode %raw1 : tensor<456x789xf16> -> tensor<456x789xf16, #encoding_rhs>
util.global.store %encoded1, @encoded_weight_1 : tensor<456x789xf16, #encoding_rhs>
util.return
}
util.func private @gemm(%lhs: tensor<123x456xf16, #encoding_lhs>, %rhs: tensor<456x789xf16, #encoding_rhs>) -> tensor<123x789xf32> {
%cst = arith.constant 0.0 : f32
%init = tensor.empty() : tensor<123x789xf32, #encoding_res>
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<123x789xf32, #encoding_res>) -> tensor<123x789xf32, #encoding_res>
%op = linalg.matmul
ins(%lhs, %rhs : tensor<123x456xf16, #encoding_lhs>, tensor<456x789xf16, #encoding_rhs>)
outs(%fill : tensor<123x789xf32, #encoding_res>) -> tensor<123x789xf32, #encoding_res>
%result = iree_encoding.unset_encoding %op : tensor<123x789xf32, #encoding_res> -> tensor<123x789xf32>
util.return %result : tensor<123x789xf32>
}
func.func @main(%lhs_src_0: tensor<123x456xf16>, %lhs_src_1: tensor<123x456xf16>) -> tensor<123x789xf32> {
%lhs0 = iree_encoding.set_encoding %lhs_src_0 : tensor<123x456xf16> -> tensor<123x456xf16, #encoding_lhs>
%rhs0 = util.global.load @encoded_weight_0 : tensor<456x789xf16, #encoding_rhs>
%result0 = util.call @gemm(%lhs0, %rhs0) : (tensor<123x456xf16, #encoding_lhs>, tensor<456x789xf16, #encoding_rhs>) -> tensor<123x789xf32>
%lhs1 = iree_encoding.set_encoding %lhs_src_1 : tensor<123x456xf16> -> tensor<123x456xf16, #encoding_lhs>
%rhs1 = util.global.load @encoded_weight_1 : tensor<456x789xf16, #encoding_rhs>
%result1 = util.call @gemm(%lhs1, %rhs1) : (tensor<123x456xf16, #encoding_lhs>, tensor<456x789xf16, #encoding_rhs>) -> tensor<123x789xf32>
// Disable the (unset_encoding -> elementwise) fusion, because backends can not support it.
%operand_0 = util.optimization_barrier %result0 : tensor<123x789xf32>
%operand_1 = util.optimization_barrier %result1 : tensor<123x789xf32>
%init = tensor.empty() : tensor<123x789xf32>
%result = linalg.add ins(%operand_0, %operand_1 : tensor<123x789xf32>, tensor<123x789xf32>) outs(%init : tensor<123x789xf32>) -> tensor<123x789xf32>
return %result : tensor<123x789xf32>
}
Target of compilation: gfx1100 (you can switch to gfx950)
Target of iree-encode-parameters: llvmcpu
#!/usr/bin/env bash
set -ex
build/tools/iree-compile \
--iree-hal-target-device=hip \
--iree-hip-target=gfx1100 \
~/encoded_two_globals_f16f16f32.mlir \
--iree-parameter-encoder-output-file=repro_encoder.mlir \
--iree-parameter-splat=repro_input.irpa \
-o /tmp/repro_main.vmfb
# CPU does not support relayout well, so there is a big stack buffer.
build/tools/iree-compile repro_encoder.mlir \
--iree-hal-target-device=local \
--iree-hal-local-target-device-backends=llvm-cpu --iree-llvmcpu-target-cpu=znver4 \
--iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false \
-o repro_encoder.vmfb
build/tools/iree-encode-parameters \
--module=repro_encoder.vmfb \
--parameters=model=repro_input.irpa \
--output=encoded=repro_output.irpa \
--quiet
build/tools/iree-run-module \
--device=hip \
--module=/tmp/repro_main.vmfb \
--function=main \
--input=123x456xf16=1.0 \
--input=123x456xf16=2.0 \
--parameters=model=repro_input.irpa \
--parameters=encoded=repro_output.irpa
Output:
EXEC @main
result[0]: hal.buffer_view
123x789xf32=[1368 1368 1368 1368 1368 1368 1368 1368 1368 1368 1368 1368...
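The value checks out by hand: with the hacked splat archive every weight element is 1, so each element of the first gemm is 456 * 1.0 = 456, each element of the second is 456 * 2.0 = 912, and the final add gives 1368:
k = 456                      # reduction dimension of the 123x456 @ 456x789 gemms
assert k * 1.0 + k * 2.0 == 1368.0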
Update: there are a few issues in CPU codegen, which needs to support the GPU relayout behavior better, but we can bypass them with flags. However, I ran into https://github.com/iree-org/iree/issues/22941, which crashes in iree-encode-parameters.