
Update the xla commit to experiment with a potential fix for OOM

Open felixwqp opened this issue 5 months ago • 28 comments

This PR tries to fix the large number of zero-broadcasts generated from a big pad in the SPMD partitioner.

For a sample HLO module:

  HloModule module_0076.reactant_loop_after_spmd_partitioner_cp_pattern

  ENTRY main() -> f64[3056,123] {
    %constant.946 = f64[] constant(0)
    %subtract.1055 = f64[3056,12272]{1,0} parameter(0), sharding={devices=[2,2]<=[2,2]T(1,0)}
    %slice.1057 = f64[3056,1]{1,0} slice(%subtract.1055), slice={[0:3056], [12271:12272]}, sharding={devices=[2,2]<=[2,2]T(1,0)}
    ROOT %pad.1059 = f64[3056,123]{1,0} pad(%slice.1057, %constant.946), padding=0_0x0_122, sharding={devices=[2,2]<=[2,2]T(1,0)}
  }
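For readers who don't parse HLO daily, here is a rough JAX sketch of the same pattern (an illustration only; the real workload is generated from Julia/Reactant code, and the function name here is hypothetical):

  import jax
  import jax.numpy as jnp

  jax.config.update("jax_enable_x64", True)  # the HLO above uses f64

  @jax.jit
  def slice_then_pad(x):                # x: f64[3056, 12272]
      last_col = x[:, 12271:12272]      # %slice.1057 -> f64[3056, 1]
      # %pad.1059: zero-pad on the right up to 123 columns (padding=0_0x0_122).
      # With the input sharded over a 2x2 mesh, the SPMD partitioner used to
      # expand this zero padding into many per-shard zero broadcasts.
      return jnp.pad(last_col, ((0, 0), (0, 122)))

  print(slice_then_pad(jnp.zeros((3056, 12272))).shape)  # (3056, 123)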

Before the fix, the result looks like: 4LjPYjEP2XJGwcn

After the fix, the result looks like: https://screenshot.googleplex.com/3BZP5MRC5kMpK2L

felixwqp avatar Aug 15 '25 05:08 felixwqp

I fixed a bug in the workflow where we were using the wrong xla commit (apparently this is the first time we're using a custom xla commit, so we didn't spot it before) and also copied the changes from #1243 to replicate the conditions which trigger the OOM.

For readers, this is related to #671.

giordano avatar Aug 15 '25 07:08 giordano

This is failing to fetch the GB-25 repository, without much information: https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/16985145489/job/48166877204#step:11:63. Sigh.

giordano avatar Aug 15 '25 14:08 giordano

Also, the sed for the xla replacement is now wrong (due to a lot of Bazel reshuffling done over the past week to fix things more stably).

The new way to do so is in https://github.com/EnzymeAD/Reactant.jl/blob/07d4fcacb935f3e915ca40c1ab7c98a210a93efc/deps/ReactantExtra/WORKSPACE#L198:

xla_workspace(NEW_XLA_PATCHES)

becomes

xla_workspace(NEW_XLA_PATCHES, 'abc2342343432432432')
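A rough Python equivalent of that sed step under the new layout might look like the following (a sketch only; the script shape, path handling, and single-occurrence assumption are mine, not from the thread):

  import sys

  def pin_xla_commit(workspace_path: str, commit: str) -> None:
      """Rewrite xla_workspace(NEW_XLA_PATCHES) to pin a specific XLA commit."""
      text = open(workspace_path).read()
      new_text = text.replace("xla_workspace(NEW_XLA_PATCHES)",
                              f"xla_workspace(NEW_XLA_PATCHES, '{commit}')")
      open(workspace_path, "w").write(new_text)

  if __name__ == "__main__":
      # e.g. python pin_xla.py deps/ReactantExtra/WORKSPACE abc2342343432432432
      pin_xla_commit(sys.argv[1], sys.argv[2])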

wsmoses avatar Aug 15 '25 15:08 wsmoses

https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/16999045099/job/48196660878?pr=1307#step:18:720

2025-08-15 21:13:25.777440: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3181] Can't reduce memory use below 10.64GiB (11429599132 bytes) by rematerialization; only reduced to 58.50GiB (62817533356 bytes), down from 61.39GiB (65919097928 bytes) originally

Same as https://github.com/EnzymeAD/Enzyme-JAX/pull/1243#issuecomment-3146860015 😢 XLA dump uploaded to https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/16999045099/artifacts/3777223905.

giordano avatar Aug 15 '25 21:08 giordano

Thanks @giordano for the quick verification; I'm just trying to understand the effect of the above XLA commit, to help us better prioritize the optimization direction.

Can I assume the following:

Without the commit (link), after rematerialization, the memory usage is 67.69 GiB:

2025-08-02 23:46:25.965629: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3423] Can't reduce memory use below 10.64GiB (11429599125 bytes) by rematerialization; only reduced to 67.69GiB (72679996812 bytes), down from 71.05GiB (76287081352 bytes) originally

After the commit (link), after rematerialization, the memory usage is 58.50 GiB:

2025-08-15 21:13:25.777440: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3181] Can't reduce memory use below 10.64GiB (11429599132 bytes) by rematerialization; only reduced to 58.50GiB (62817533356 bytes), down from 61.39GiB (65919097928 bytes) originally

Question:

  1. Is there any Enzyme-level difference between these two runs that accounts for this ~10 GB memory usage reduction? If not, it means the HLO optimization did improve the memory usage, and we can prioritize the same methodology for further XLA memory optimization.

cc: @wsmoses

felixwqp avatar Aug 15 '25 22:08 felixwqp

I believe I linked two different lines, i.e. from after two different compilation stages (we compile two different kernels). This should be a more direct comparison (always after the first kernel). #1243:

2025-08-02 23:41:14.191988: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3423] Can't reduce memory use below 10.64GiB (11429599132 bytes) by rematerialization; only reduced to 58.50GiB (62817533356 bytes), down from 61.39GiB (65919097928 bytes) originally

this PR

2025-08-15 21:13:25.777440: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3181] Can't reduce memory use below 10.64GiB (11429599132 bytes) by rematerialization; only reduced to 58.50GiB (62817533356 bytes), down from 61.39GiB (65919097928 bytes) originally

In any case I changed this PR to always run "vanilla XLA" vs your PR, for a quicker comparison.
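(Side note: a small helper like the one below makes such comparisons less error-prone than eyeballing warning lines from different compilation stages; it is only an illustrative sketch, not part of the CI.)

  import re

  def remat_gib(line: str) -> tuple[float, float]:
      """Return (reduced_to, down_from) in GiB from an hlo_rematerialization warning."""
      reduced, original = re.findall(r"\((\d+) bytes\)", line)[-2:]
      return int(reduced) / 2**30, int(original) / 2**30

  warning = ("Can't reduce memory use below 10.64GiB (11429599132 bytes) by rematerialization; "
             "only reduced to 58.50GiB (62817533356 bytes), down from 61.39GiB (65919097928 bytes) originally")
  print(remat_gib(warning))  # -> approximately (58.50, 61.39)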

giordano avatar Aug 15 '25 22:08 giordano

x/ref https://github.com/openxla/xla/pull/30307

wsmoses avatar Aug 16 '25 06:08 wsmoses

@felixwqp I warmly recommend using git push --force-with-lease instead of --force, so you don't keep undoing my fixes 🙂

giordano avatar Aug 18 '25 16:08 giordano

Ah, thank you for this suggestion! I will use it going forward. Apologies for overriding your commits; I'm new to the GitHub review process, so any suggestions are helpful and welcome!

felixwqp avatar Aug 18 '25 17:08 felixwqp

I only want to update the xla commit to 1ac176a9b8b4800bc2753d944eec62a39e6189b8, to verify that the HLO dump looks as intended. There is no need to trigger the OOM anymore.

The new commit failed with:

Error: The artifact name is not valid: julia-environment-1.11-1ac176a9b8b4800bc2753d944eec62a39e6189b8-mg_sharded-factors-ap/persistent_compile_cache. Contains the following character:  Forward slash /

Should I adjust reactant_commit?

   reactant_commit:
          - 'ap/persistent_compile_cache'

felixwqp avatar Aug 18 '25 19:08 felixwqp

That was supposed to be addressed by #1297 🤔

giordano avatar Aug 18 '25 19:08 giordano

Hopefully fixed by #1316. I rebased on main. Please don't override my changes again 😅 Edit: confirmed that this fixed the artifact issue.

giordano avatar Aug 18 '25 20:08 giordano

After Billy fixed a mistake in the build system:

  • XLA dump on this PR: https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/17060619747/artifacts/3795845666 (7.44 MB)
  • XLA dump on main: https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/17051302903/artifacts/3792563288 (18.4 MB)

Based only on the archive size, your PR is doing something! 😁 (I'm on my phone, so I can't open the archives.)

giordano avatar Aug 19 '25 07:08 giordano

Yeah, a list of zero-broadcasts is coalesced as expected in https://github.com/EnzymeAD/Enzyme-JAX/pull/1307#issue-3324447702.

E.g., from the new dump, %broadcast.764 is the coalesced broadcast:

  %broadcast.764 = f64[760,1527]{1,0} broadcast(%constant.304), dimensions={}, metadata={op_name="pad.5725"}
  %collective-permute.1 = f64[760,1]{1,0} collective-permute(%slice.383), channel_id=2, source_target_pairs={{0,2},{1,3}}, metadata={op_name="pad.5725"}
  %concatenate.5 = f64[760,1528]{1,0} concatenate(%broadcast.764, %collective-permute.1), dimensions={1}, metadata={op_name="pad.5725"}
  %pad.39 = f64[760,3055]{1,0} pad(%concatenate.5, %constant.304), padding=0_0x1527_0, metadata={op_name="pad.5725"}

Besides this, I will start to evaluate the memory usage change for this XLA diff.

Thanks Billy for applying the large grid.

In the meantime, I will see if I can get some data from xprof.

felixwqp avatar Aug 19 '25 07:08 felixwqp

The simulation step is quite a bit faster than on main currently:

  • about 16 minutes on main: https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/17059851883/job/48370674137
  • about 9 minutes here: https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/17062835799/job/48373063875

I think most of the improvement is in compile time, we don't get the warning about long XLA compilation anymore.

giordano avatar Aug 19 '25 08:08 giordano

2025-08-19 09:15:10.544797: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3183] Can't reduce memory use below 10.64GiB (11429599125 bytes) by rematerialization; only reduced to 56.60GiB (60775368484 bytes), down from 66.05GiB (70922254524 bytes) originally

So we're now down to 56.60 GiB from 58.50 GiB.

In short: a definite compile-time improvement confirmed, and a slight memory reduction, though there is more memory reduction to go.

wsmoses avatar Aug 19 '25 09:08 wsmoses

Somewhat good news: calling only initialize! + update_state! (without time_step!) is sufficient to get an OOM on the device (time_step! would cause it to use even more memory): https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/17173076863/job/48727161438?pr=1307#step:19:728

2025-08-23 09:38:22.863446: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3183] Can't reduce memory use below 10.64GiB (11428802313 bytes) by rematerialization; only reduced to 46.84GiB (50297252461 bytes), down from 63.27GiB (67935638125 bytes) originally
[...]
E0000 00:00:1755941951.386733    3750 pjrt_stream_executor_client.cc:3081] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 51712165888 bytes. [tf-allocator-allocation-error='']
2025-08-23 09:39:11.389852: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_2_bfc) ran out of memory trying to allocate 48.16GiB (rounded to 51712165888)requested by op 
If the cause is memory fragmentation maybe the environment variable 'TF_GPU_ALLOCATOR=cuda_malloc_async' will improve the situation. 

The XLA dump is quite a bit smaller.
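(For reference, the allocator's suggestion in the log above would be applied by setting the environment variable before the GPU client initializes; the snippet below is just a sketch assuming a Python entry point, whereas in these CI runs it would go in the job environment.)

  import os

  # Must be set before XLA/TF initializes its GPU allocator.
  os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"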

giordano avatar Aug 23 '25 09:08 giordano

I further reduced the code by removing some kernels, and this is still using more memory than necessary:

2025-08-23 14:04:03.144019: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3183] Can't reduce memory use below 10.64GiB (11428802359 bytes) by rematerialization; only reduced to 12.40GiB (13313787004 bytes), down from 12.40GiB (13313787004 bytes) originally

The job didn't crash because the 12.4 GiB still fit within the total memory, but you can see at the end of the job that the peak memory usage is larger than the memory used after the model creation:

┌ Info: [0] allocations
│   GordonBell25.allocatorstats() =
│    AllocatorStats
│    --------------
│    num_allocs: 41
│    bytes_in_use: 21894366976
│    peak_bytes_in_use: 28943843072
│    largest_alloc_size: 7049475840
│    bytes_limit: 31804391424
│    bytes_reserved: 0
│    peak_bytes_reserved: 0
│    bytes_reservable_limit: nothing
│    largest_free_block_bytes: 0
│    pool_bytes: 31804391424
│    peak_pool_bytes: 31804391424
└    

The XLA dump is even smaller
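A quick sanity check of those numbers, as plain arithmetic on the stats above:

  bytes_in_use      = 21_894_366_976
  peak_bytes_in_use = 28_943_843_072
  bytes_limit       = 31_804_391_424

  print((peak_bytes_in_use - bytes_in_use) / 2**30)  # ~6.57 GiB transient peak above the post-creation state
  print((bytes_limit - peak_bytes_in_use) / 2**30)   # ~2.66 GiB of headroom left at peak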

giordano avatar Aug 23 '25 14:08 giordano

Thank you Mose!

There are primarily two modules, reactant__100 and reactant__first. I will start with reactant__first because it looks like it has more memory allocated, but if you happen to know where the memory bottleneck happens, I will focus on that one.

felixwqp avatar Aug 26 '25 17:08 felixwqp

No, I think we're a bit clueless about where memory is going 🫠

giordano avatar Aug 26 '25 17:08 giordano

Based on the dump, I performed a memory profile analysis on the HLO input (module_0097.reactant_first_t....before_optimizations.txt), run in a Google-internal H100 environment.

Top Temporary Ops Summary

Here is a breakdown of memory usage by operation type, focusing on temporary allocations:

| Operation type     | Total memory (MiB) | Percentage | Number of operations | Op names                                                            | Framework op |
|--------------------|--------------------|------------|----------------------|---------------------------------------------------------------------|--------------|
| param              | 20880              | 75.87%     |                      | *                                                                   |              |
| copy               | 2880               | 10.46%     | 2                    | copy.29, copy.30                                                    | N/A          |
| loop_add_fusion    | 2880               | 10.46%     | 2                    | loop_add_fusion{0}, loop_add_fusion{2}                              | add.221      |
| loop_select_fusion | 862.89             | 3.14%      | 3                    | loop_select_fusion{0}, loop_select_fusion{1}, loop_select_fusion{2} | pad.178      |
| rest ops           | 18.51              | 0.07%      |                      |                                                                     |              |
| TOTAL              | 27521.5            | 100%       |                      |                                                                     |              |

Questions

  1. Parameter Memory Usage: The HLO input parameters consume 20880 MiB, which accounts for 75.87% of the total memory allocation. Is this magnitude of memory usage for input parameters considered reasonable or expected from an MLIR perspective for this type of model or operation?
  2. If the parameter memory usage is larger than expected, could this suggest potential inefficiencies or optimization opportunities in the StableHLO to HLO conversion process?

felixwqp avatar Aug 26 '25 19:08 felixwqp

Is this magnitude of memory usage for input parameters considered reasonable or expected from an MLIR perspective for this type of model or operation?

Billy, please do correct me, but I believe that's indeed expected. That's the memory we have right after the model generation. In https://github.com/EnzymeAD/Enzyme-JAX/pull/1243#issuecomment-3146844668 I had anticipated about 20 GB based only on scaling the input parameters from previous runs, so I think these 20 GB are what we expect.

giordano avatar Aug 26 '25 20:08 giordano

I need to stare more at the minimization, but the basic gist here is that the parameters will have some amount of memory usage, but in principle we should be able to use no additional memory. Specifically, the original code just had that original allocation and updated it in place.

In practice when we do the full, loop-based version I expect we'll have a factor of two for the induction variables or something, but here I would expect that these allocations should be eliminable.
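In JAX terms, "updated it in place" roughly corresponds to donating the input buffer so XLA may reuse it for the output; the sketch below is only an analogy (Reactant/Enzyme-JAX manage buffers differently), but it illustrates the memory behavior being aimed for:

  from functools import partial

  import jax
  import jax.numpy as jnp

  # Donating argument 0 tells XLA it may alias the output onto the input buffer,
  # so the update does not need a second state-sized allocation.
  @partial(jax.jit, donate_argnums=0)
  def step(state):
      return state + 1.0

  state = jnp.zeros((1024, 1024))
  state = step(state)  # 'state' now refers to the new value; the old buffer was donated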

wsmoses avatar Aug 26 '25 20:08 wsmoses

With https://github.com/EnzymeAD/Reactant.jl/pull/1619 we get a slightly lower peak memory. On main:

│    peak_bytes_in_use: 28943842560

vs PR:

│    peak_bytes_in_use: 28534432768

but the warning we get during compilation mentions a much larger memory buffer. On main:

W0000 00:00:1756938493.424201    5944 hlo_rematerialization.cc:3183] Can't reduce memory use below 10.64GiB (11428802359 bytes) by rematerialization; only reduced to 12.40GiB (13313787004 bytes), down from 12.40GiB (13313787004 bytes) originally

vs PR

W0000 00:00:1756941213.091395   59266 hlo_rematerialization.cc:3183] Can't reduce memory use below 10.64GiB (11428802359 bytes) by rematerialization; only reduced to 16.16GiB (17352329320 bytes), down from 16.23GiB (17427728524 bytes) originally

XLA dumps: main vs PR.

giordano avatar Sep 03 '25 23:09 giordano

I'm curious if https://github.com/EnzymeAD/Enzyme-JAX/pull/1363 creates similar improvements [or more], or causes more chaos.

wsmoses avatar Sep 03 '25 23:09 wsmoses

Also note that the absence of the dus_to_pad comm op in that PR causes all-gathers to return, which is bad.

wsmoses avatar Sep 03 '25 23:09 wsmoses

@felixwqp, the file module_0051.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt in https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/17686880847/artifacts/4001249975 is slightly less than 200 lines of code; is that workable? Edit: ignore this, it was incomplete.

giordano avatar Sep 13 '25 02:09 giordano

@felixwqp: @glwagner further reduced the kernels; I hope we now have something even more useful.

In this run we launched the program three times separately, using three different kernels: one filling the halo regions only in the east-west direction (which is also the direction in which we do the sharding, so we expect device-device communication here), one filling the halo regions only in the north-south direction (no sharding), and another one filling all the halo regions (device-device communication here as well). We see:

│    num_allocs: 41
│    bytes_in_use: 23404316416
│    peak_bytes_in_use: 24918199296
│    largest_alloc_size: 1513882880
│    bytes_limit: 31804391424

│    num_allocs: 41
│    bytes_in_use: 23404316416
│    peak_bytes_in_use: 23443442432
│    largest_alloc_size: 1509949440
│    bytes_limit: 31804391424

│    num_allocs: 40
│    bytes_in_use: 21894366976
│    peak_bytes_in_use: 23409627136
│    largest_alloc_size: 1515260160
│    bytes_limit: 31804391424

Here is the XLA dump from the run: simulation-xla-dump-1.11--main-main.zip. Here are the line counts of the three modules:

% wc -l simulation-xla-dump-1.11--main-main/xla_dump_*/module_00*.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt
     468 simulation-xla-dump-1.11--main-main/xla_dump_all/module_0049.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt
     111 simulation-xla-dump-1.11--main-main/xla_dump_east_west/module_0051.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt
     197 simulation-xla-dump-1.11--main-main/xla_dump_north_south/module_0051.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt
     776 total

I presume the more interesting ones to look at are the east-west and north-south ones, perhaps especially the former, which does the device-device communication (and is also the shorter one), but comparing it with the north-south one may also be useful. Hope this helps!

Greg please do correct me if I said anything wrong above!

giordano avatar Sep 13 '25 18:09 giordano