
[Codegen][LLVMGPU] Add a DropVectorUnitDims call

Open krzysz00 opened this issue 5 months ago • 9 comments

DropVectorUnitDims was factored out of the CPU pipeline last year. Now we have a use case for adding it to the LLVMGPU pipeline.

Specifically, LLVMGPUVectorLowering will lower the broadcast in

%scale = ... : vector<1xf32>
%scale.bcast = vector.broadcast %scale : vector<1xf32> to vector<32x1xf32>
%res = arith.scaling_truncf %data, %scale.bcast : ...

to a series of inserts, which the later ArithToAMDGPU patterns don't recognize as a broadcast. This produces inefficient code.

However, if we first run the unit-dimension elimination patterns, the scaling truncation runs on a vector<32xf32> for the scales. Broadcast lowering then rewrites the vector<1xf32> to vector<2xf32> broadcast into a splat, which ArithToAMDGPU does pattern-match as a uniform scale.
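To make that concrete, here is a rough sketch (illustrative only, not an actual IR dump) of the snippet above after the unit dimension is dropped:

// Illustrative sketch: with the unit dim gone, the broadcast becomes rank-1.
%scale = ... : vector<1xf32>
%scale.bcast = vector.broadcast %scale : vector<1xf32> to vector<32xf32>
%res = arith.scaling_truncf %data, %scale.bcast : ...
// Later unrolling to the native width yields small rank-1 broadcasts (the
// vector<1xf32> to vector<2xf32> case mentioned above), which broadcast
// lowering turns into splats that ArithToAMDGPU matches as a uniform scale.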

I also expect this to improve vectorization generally, giving some minor performance improvements overall (but I haven't had time to measure it).

ci-extra: test_torch

krzysz00 avatar Jul 28 '25 19:07 krzysz00

FYI, I'm considering revamping https://github.com/iree-org/iree/pull/17530 for the CPU backends. It is a more aggressive version that may flatten something like tensor<...x8x2xi8>, depending on the native vector size.
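For a rough sense of the kind of flattening meant, here is a hypothetical sketch (made-up leading shape, assuming the trailing 8x2 block of i8 matches a 16-byte native vector):

// Hypothetical illustration only: collapse the contiguous innermost 8x2xi8
// block (16 bytes) into a single flat dimension.
%flat = tensor.collapse_shape %t [[0], [1, 2]] : tensor<4x8x2xi8> into tensor<4x16xi8>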

hanhanW avatar Jul 28 '25 19:07 hanhanW

On further investigation, I don't know that we'll need this for the case I'm interested in; I'll use this as a vehicle for a pipeline test once the more relevant upstream PR lands.

krzysz00 avatar Jul 29 '25 15:07 krzysz00

I think this needs perf numbers. This can lead to weird shuffles if done on non-innermost dimensions, and we need to make sure we aren't regressing.

Groverkss avatar Jul 29 '25 16:07 Groverkss

I think this needs perf numbers. This can lead to weird shuffles if done on non-innermost dimensions, and we need to make sure we aren't regressing.

I suspect part of that might have to do with the placement of LLVMGPUVectorLowering, which Max is taking a poke at?

krzysz00 avatar Jul 29 '25 17:07 krzysz00

I updated this branch and ran the CI benchmarks.

I'll post the table below, showing mostly flat to 1.03x speedups depending on which benchmark we're looking at. I haven't investigated why the performance changes, or whether this amount of model-level change is noise, but I figured I'd flag it. (This was also partly me making sure I can still run benchmarks.)

Here's my summarization script.

#!/usr/bin/env perl
# Summarize two model-benchmark logs into a Markdown table of per-benchmark speedups.
# Usage: summarize.pl old-results.log new-results.log
use warnings;
use strict;
use utf8;
use v5.30;

if (scalar @ARGV != 2) {
    die "Usage: $0 [old] [new]";
}

my %old_times;
my %new_times;
open(my $old_file, "<", $ARGV[0]) || die "Couldn't open old results $ARGV[0]";
open(my $new_file, "<", $ARGV[1]) || die "Couldn't open new results $ARGV[1]";

# Benchmark lines look like:
#   model_benchmark_run:model_benchmark_run.py:<line> <name> benchmark time: <ms> ms
while (<$old_file>) {
    next unless /model_benchmark_run:model_benchmark_run.py:(\d+) (.+) benchmark time: ([0-9.]+) ms/;
    $old_times{$2} = $3;
}

while (<$new_file>) {
    next unless /model_benchmark_run:model_benchmark_run.py:(\d+) (.+) benchmark time: ([0-9.]+) ms/;
    $new_times{$2} = $3;
    die "Benchmark $2 not in old data" unless exists($old_times{$2});
}
die "Old benhmarks missing in new set" unless keys(%old_times) == keys(%new_times);

print("|Benchmark|Old time|New time|Speedup (old / new)|\n");
print("|------------|---------|-----------|--------------|\n");
foreach my $bench (keys %old_times) {
    my $old_time = $old_times{$bench} * 1.0;
    my $new_time = $new_times{$bench} * 1.0;
    printf("|$bench|%.2f|%.2f|%.2f|\n", $old_time, $new_time, $old_time / $new_time);
}

| Benchmark | Old time (ms) | New time (ms) | Speedup (old / new) |
|---|---|---|---|
| sdxl clip | 7.34 | 7.13 | 1.03 |
| sdxl unet_fp16 | 65.13 | 65.35 | 1.00 |
| sdxl vae | 65.50 | 65.58 | 1.00 |
| llama 8b_f16_prefill_data_tiling | 50.99 | 50.96 | 1.00 |
| llama 8b_f16_decode_data_tiling | 30.26 | 30.06 | 1.01 |
| llama 8b_f16_decode | 8.87 | 8.61 | 1.03 |
| llama 8b_f16_prefill | 32.60 | 32.01 | 1.02 |
| sdxl punet_int8_fp16 | 43.36 | 43.17 | 1.00 |
| sdxl punet_int8_fp8 | 42.94 | 42.97 | 1.00 |

krzysz00 avatar Sep 24 '25 02:09 krzysz00

Investigative note: on llama 8b f16 decode, this change meaningfully reduced register usage. For example:

     .symbol:         'prefill_bs?$async_dispatch_14_matmul_like_Dx14336x4096_f16xf16xf32.kd'
 ...
-    .vgpr_count:     194
+    .vgpr_count:     168
     .vgpr_spill_count: 0
     wavefront_size: 64
-    .agpr_count:     46
+    .agpr_count:     32

I'm not entirely sure why, but at least we have some explanation for the perf differences.

krzysz00 avatar Sep 24 '25 23:09 krzysz00

This is less related to flattening and more related to the fact that our broadcast lowering is "bad". https://github.com/iree-org/iree/issues/21978 is bad for the same reason.

The correct way to lower a broadcast is to not unroll any of the broadcast dims, and instead rely on the other ops being unrolled to 1-D vectors, which folds the broadcast away. After that, you don't need to lower the broadcast at all.
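Roughly what that folding looks like, as a hypothetical sketch (not IR from this PR):

// Hypothetical sketch: a 2-D broadcast whose consumer is unrolled row-by-row.
%b = vector.broadcast %src : vector<4xf32> to vector<8x4xf32>
%row = vector.extract %b[3] : vector<4xf32> from vector<8x4xf32>
// The extract-of-broadcast folds straight to %src, so once the consumers are
// unrolled to 1-D vectors the broadcast is dead and never needs lowering.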

Groverkss avatar Sep 25 '25 08:09 Groverkss

The problem with "just" unrolling is the vector<...x1xf32> values we get fairly often, so even with that fixed, we still need this pass.
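As a hypothetical example of the shapes in question (not taken from this PR):

// Hypothetical: a value like this shows up fairly often after vectorization.
%v = ... : vector<32x1xf32>
// Unrolling alone still leaves vector<1xf32> pieces; the unit-dim patterns
// instead rewrite uses so they operate on a flat vector<32xf32>, e.g. via
%flat = vector.shape_cast %v : vector<32x1xf32> to vector<32xf32>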

krzysz00 avatar Sep 25 '25 14:09 krzysz00

@krzysz00 Can you rebase this? I added ci-extra: test_torch to it. If it improves perf, let's land this; it is effectively doing SLP vectorization, which is okay to do here.

Groverkss avatar Oct 27 '25 12:10 Groverkss