[Codegen][LLVMGPU] Add a DropVectorUnitDims call
DropVectorUnitDims was factored out of the CPU pipeline last year. Now we have a use case for adding it to the LLVMGPU pipeline.
Specifically, LLVMGPUVectorLowering will lower the broadcast in
%scale = ... : vector<1xf32>
%scale.bcast = vector.broadcast %scale : vector<1xf32> to vector<32x1xf32>
%res = arith.scaling_truncf %data, %scale.bcast : ...
to a series of inserts, which the later ArithToAMDGPU patterns don't recognize as a broadcast. This produces inefficient code.
However, if we first run the unit-dimension elimination patterns, the scaling truncation operates on a vector<32xf32> for the scales. Broadcast lowering then rewrites the vector<1xf32> to vector<2xf32> broadcast into a splat, which ArithToAMDGPU does pattern-match as a uniform scale.
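Concretely, right after the unit-dim dropping and before any further unrolling, the scale side looks roughly like this (hand-written sketch, not an actual IR dump from the pipeline):

```mlir
// Sketch of the IR after the unit-dim dropping patterns: the scales are now
// rank-1, so the broadcast is a plain rank-1 stretch that later lowers to a splat.
%scale = ... : vector<1xf32>
%scale.bcast = vector.broadcast %scale : vector<1xf32> to vector<32xf32>
%res = arith.scaling_truncf %data, %scale.bcast : ...
```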
I also expect this to improve vectorization generally, yielding some minor performance improvements overall (but I haven't had time to benchmark it yet).
ci-extra: test_torch
FYI, I'm considering revamping https://github.com/iree-org/iree/pull/17530 for CPU backends. It is a more aggressive version that may flatten something like tensor<...x8x2xi8>, depending on the native vector size.
On further investigation, I'm not sure we'll need this for the case I'm interested in. I'll use this as a vehicle for a pipeline test once the more relevant upstream PR lands.
I think this needs perf numbers. This can lead to weird shuffles if done on non-innermost dimensions, and we need to make sure we aren't regressing.
I suspect part of that might have to do with the placement of LLVMGPUVectorLowering, which Max is taking a poke at?
I updated this branch and ran the CI benchmarks.
I'll post the table below, showing mostly flat to 1.03x speedups depending on which benchmark we're looking at. I haven't investigated why the performance changed, or whether this amount of model-level change is noise, but I figured I'd flag it. (This was partly me making sure I can still run benchmarks.)
Here's my summarization script.
#!/usr/bin/env perl
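# Compare two model-benchmark logs and print a markdown table of
# old vs. new times and the resulting speedups.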
use warnings;
use strict;
use utf8;
use v5.30;
if (scalar @ARGV != 2) {
die "Usage: $0 [old] [new]";
}
my %old_times;
my %new_times;
open(my $old_file, "<", $ARGV[0]) || die "Couldn't open old results $ARGV[0]";
open(my $new_file, "<", $ARGV[1]) || die "Couldn't open new results $ARGV[1]";
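# Benchmark results appear in the logs as lines like:
#   ... model_benchmark_run:model_benchmark_run.py:<line> <name> benchmark time: <time> ms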
while (<$old_file>) {
next unless /model_benchmark_run:model_benchmark_run.py:(\d+) (.+) benchmark time: ([0-9.]+) ms/;
$old_times{$2} = $3;
}
while (<$new_file>) {
next unless /model_benchmark_run:model_benchmark_run.py:(\d+) (.+) benchmark time: ([0-9.]+) ms/;
$new_times{$2} = $3;
die "Benchmark $2 not in old data" unless exists($old_times{$2});
}
die "Old benhmarks missing in new set" unless keys(%old_times) == keys(%new_times);
print("|Benchmark|Old time|New time|Speedup (old / new)|\n");
print("|------------|---------|-----------|--------------|\n");
foreach my $bench (keys %old_times) {
my $old_time = $old_times{$bench} * 1.0;
my $new_time = $new_times{$bench} * 1.0;
printf("|$bench|%.2f|%.2f|%.2f|\n", $old_time, $new_time, $old_time / $new_time);
}
| Benchmark | Old time | New time | Speedup (old / new) |
|---|---|---|---|
| sdxl clip | 7.34 | 7.13 | 1.03 |
| sdxl unet_fp16 | 65.13 | 65.35 | 1.00 |
| sdxl vae | 65.50 | 65.58 | 1.00 |
| llama 8b_f16_prefill_data_tiling | 50.99 | 50.96 | 1.00 |
| llama 8b_f16_decode_data_tiling | 30.26 | 30.06 | 1.01 |
| llama 8b_f16_decode | 8.87 | 8.61 | 1.03 |
| llama 8b_f16_prefill | 32.60 | 32.01 | 1.02 |
| sdxl punet_int8_fp16 | 43.36 | 43.17 | 1.00 |
| sdxl punet_int8_fp8 | 42.94 | 42.97 | 1.00 |
Investigative note: on llama 8b f16 decode, this change meaningfully reduced register usage. For example:
.symbol: 'prefill_bs?$async_dispatch_14_matmul_like_Dx14336x4096_f16xf16xf32.kd'
...
- .vgpr_count: 194
+ .vgpr_count: 168
.vgpr_spill_count: 0
.wavefront_size: 64
- .agpr_count: 46
+ .agpr_count: 32
Not entirely sure why, but at least we have some explanation for the perf differences.
This is less related to flattening and more related to the fact that our broadcast lowering is "bad". https://github.com/iree-org/iree/issues/21978 is bad for the same reason.
The correct way to lower a broadcast is to not unroll any broadcasted dims and instead rely on the other ops being unrolled to 1-D vectors, at which point the broadcast folds away. You don't need to lower the broadcast at all after that.
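A rough sketch of that folding (the ops, names, and shapes here are made up for illustration, not taken from this PR):

```mlir
// Before unrolling: a broadcast feeding a 2-D elementwise op.
%b = vector.broadcast %src : vector<4xf32> to vector<2x4xf32>
%r = arith.addf %b, %other : vector<2x4xf32>

// After unrolling only the addf to 1-D slices, extracting from the broadcast
// folds straight back to %src, so the broadcast never needs its own lowering.
// (The vector.insert ops that reassemble the 2-D result are omitted.)
%o0 = vector.extract %other[0] : vector<4xf32> from vector<2x4xf32>
%r0 = arith.addf %src, %o0 : vector<4xf32>
%o1 = vector.extract %other[1] : vector<4xf32> from vector<2x4xf32>
%r1 = arith.addf %src, %o1 : vector<4xf32>
```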
The problem with "just" unrolling is the vector<.... x 1 x f32> we get decently often, so even with that fixed, we still need this pass
@krzysz00 Can you rebase this? I added ci-extra: test_torch to it. If it improves perf, let's land this; this is effectively doing SLP vectorization, which is okay to do here.