Update the xla commit to experiment with a potential fix for the OOM
Try to fix the large number of zero-broadcasts generated from a big pad in the SPMD partitioner.
For a sample HLO:
HloModule module_0076.reactant_loop_after_spmd_partitioner_cp_pattern
ENTRY main() -> f64[3056,123] {
%constant.946 = f64[] constant(0)
%subtract.1055 = f64[3056,12272]{1,0} parameter(0), sharding={devices=[2,2]<=[2,2]T(1,0)}
%slice.1057 = f64[3056,1]{1,0} slice(%subtract.1055), slice={[0:3056], [12271:12272]}, sharding={devices=[2,2]<=[2,2]T(1,0)}
ROOT %pad.1059 = f64[3056,123]{1,0} pad(%slice.1057, %constant.946), padding=0_0x0_122, sharding={devices=[2,2]<=[2,2]T(1,0)}
}
before the fix, the result looks like:
after the fix, the result looks like: https://screenshot.googleplex.com/
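For context, here is a minimal JAX sketch (not the author's reproducer) of the slice-then-pad pattern in the sample HLO above; it assumes 4 devices arranged in a 2x2 mesh (faked on CPU via XLA_FLAGS) and placeholder axis names:

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"  # fake 4 CPU devices

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.config.update("jax_enable_x64", True)  # the sample HLO uses f64

# 2x2 mesh mirroring the devices=[2,2] sharding in the sample HLO (axis names are placeholders)
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ("x", "y"))
a = jax.device_put(jnp.zeros((3056, 12272), dtype=jnp.float64), NamedSharding(mesh, P("x", "y")))

def slice_then_pad(a):
    col = a[:, 12271:12272]                  # slice -> f64[3056,1]
    return jnp.pad(col, ((0, 0), (0, 122)))  # pad   -> f64[3056,123]

# Post-optimization HLO, i.e. after the SPMD partitioner has handled the pad
print(jax.jit(slice_then_pad).lower(a).compile().as_text())
```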
I fixed a bug in the workflow where we were using the wrong xla commit (apparently this is the first time we're using a custom xla commit, so we didn't spot it before) and also copied the changes from #1243 to replicate the conditions which trigger the OOM.
For readers, this is related to #671.
This is failing to fetch the GB-25 repository, without much information: https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/16985145489/job/48166877204#step:11:63. Sigh.
Also, the sed for the XLA replacement is now wrong (due to a lot of Bazel reshuffling done over the past week to make things more stable).
The new way to do so is in https://github.com/EnzymeAD/Reactant.jl/blob/07d4fcacb935f3e915ca40c1ab7c98a210a93efc/deps/ReactantExtra/WORKSPACE#L198:
xla_workspace(NEW_XLA_PATCHES)
becomes
xla_workspace(NEW_XLA_PATCHES, 'abc2342343432432432')
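For illustration only, a minimal Python sketch of that replacement (the sed equivalent), assuming the WORKSPACE path above and using the example commit hash; this is not the actual workflow step:

```python
# Sketch: pin a custom XLA commit by rewriting the xla_workspace(...) call in
# Reactant's WORKSPACE (path and commit hash below are examples).
import re
from pathlib import Path

workspace = Path("deps/ReactantExtra/WORKSPACE")
xla_commit = "abc2342343432432432"  # placeholder commit from the comment above

text = workspace.read_text()
text = re.sub(
    r"xla_workspace\(NEW_XLA_PATCHES\)",
    f'xla_workspace(NEW_XLA_PATCHES, "{xla_commit}")',
    text,
)
workspace.write_text(text)
```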
https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/16999045099/job/48196660878?pr=1307#step:18:720
2025-08-15 21:13:25.777440: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3181] Can't reduce memory use below 10.64GiB (11429599132 bytes) by rematerialization; only reduced to 58.50GiB (62817533356 bytes), down from 61.39GiB (65919097928 bytes) originally
same as https://github.com/EnzymeAD/Enzyme-JAX/pull/1243#issuecomment-3146860015 😢 XLA dump uploaded to https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/16999045099/artifacts/3777223905
Thanks @giordano for the quick verification; I'm just trying to understand the effect of the above XLA commit to help us better prioritize the optimization direction.
Can I assume:
without the commit (link), after rematerialization, the memory usage is 67.69 GiB:
2025-08-02 23:46:25.965629: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3423] Can't reduce memory use below 10.64GiB (11429599125 bytes) by rematerialization; only reduced to 67.69GiB (72679996812 bytes), down from 71.05GiB (76287081352 bytes) originally
after the commit (link), after rematerialization, the memory usage is 58.50 GiB:
2025-08-15 21:13:25.777440: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3181] Can't reduce memory use below 10.64GiB (11429599132 bytes) by rematerialization; only reduced to 58.50GiB (62817533356 bytes), down from 61.39GiB (65919097928 bytes) originally
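For reference, here is the arithmetic behind that comparison, using only the byte counts quoted in the two warnings above:

```python
# Bytes reported by the two rematerialization warnings quoted above
without_commit = 72679996812  # 67.69 GiB
with_commit = 62817533356     # 58.50 GiB

GiB = 2**30
print((without_commit - with_commit) / GiB)  # ≈ 9.19 GiB reduction
```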
Question:
- Is there any Enzyme-level difference between these two runs that accounts for this ~10 GB memory usage reduction? If not, it means the HLO optimization did improve the memory usage, and we can prioritize the same methodology for further XLA memory optimization.
cc: @wsmoses
I believe I linked two different lines, i.e. lines logged after two different compilation stages (we compile two different kernels). This should be a more direct comparison (always after the first kernel): #1243:
2025-08-02 23:41:14.191988: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3423] Can't reduce memory use below 10.64GiB (11429599132 bytes) by rematerialization; only reduced to 58.50GiB (62817533356 bytes), down from 61.39GiB (65919097928 bytes) originally
2025-08-15 21:13:25.777440: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3181] Can't reduce memory use below 10.64GiB (11429599132 bytes) by rematerialization; only reduced to 58.50GiB (62817533356 bytes), down from 61.39GiB (65919097928 bytes) originally
In any case I changed this PR to always run "vanilla XLA" vs your PR, for a quicker comparison.
x/ref https://github.com/openxla/xla/pull/30307
@felixwqp I warmly recommend using git push --force-with-lease instead of --force, so you don't keep undoing my fixes 🙂
Ah, thank you for this suggestion! Will use it going forward. Apologies for overriding your commits; I'm new to the GitHub review process, so any suggestions are helpful and welcome!
I only want to update the XLA commit to 1ac176a9b8b4800bc2753d944eec62a39e6189b8 to verify that the HLO dump looks as intended. No need to trigger the OOM anymore.
The new commit failed with:
Error: The artifact name is not valid: julia-environment-1.11-1ac176a9b8b4800bc2753d944eec62a39e6189b8-mg_sharded-factors-ap/persistent_compile_cache. Contains the following character: Forward slash /
Should I adjust reactant_commit?
reactant_commit:
- 'ap/persistent_compile_cache'
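For illustration only (not necessarily how the workflow fixes it), the offending character could be sanitized out of the ref before it is used in the artifact name:

```python
# Hypothetical sketch: upload-artifact names cannot contain "/", so a ref like
# "ap/persistent_compile_cache" needs to be sanitized before use.
reactant_commit = "ap/persistent_compile_cache"
artifact_suffix = reactant_commit.replace("/", "-")
print(artifact_suffix)  # ap-persistent_compile_cache
```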
That was supposed to be addressed by #1297 🤔
Hopefully fixed by #1316. I rebased on main. Please don't override my changes again 😅 Edit: confirmed that this fixed the artifact issue.
After Billy fixed a mistake in the build system:
- XLA dump on this PR: https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/17060619747/artifacts/3795845666 (7.44 MB)
- XLA dump on main: https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/17051302903/artifacts/3792563288 (18.4 MB)
Only based on the archive size, your PR is doing something! 😁 (I'm on the phone, can't open the archives)
Yeah, a list of zero-broadcasts is coalesced as expected in https://github.com/EnzymeAD/Enzyme-JAX/pull/1307#issue-3324447702;
e.g., in the new dump, %broadcast.764 is the coalesced broadcast:
%broadcast.764 = f64[760,1527]{1,0} broadcast(%constant.304), dimensions={}, metadata={op_name="pad.5725"}
%collective-permute.1 = f64[760,1]{1,0} collective-permute(%slice.383), channel_id=2, source_target_pairs={{0,2},{1,3}}, metadata={op_name="pad.5725"}
%concatenate.5 = f64[760,1528]{1,0} concatenate(%broadcast.764, %collective-permute.1), dimensions={1}, metadata={op_name="pad.5725"}
%pad.39 = f64[760,3055]{1,0} pad(%concatenate.5, %constant.304), padding=0_0x1527_0, metadata={op_name="pad.5725"}
Next, I will start evaluating the memory usage change for this XLA diff.
Thanks Billy for applying the large grid.
In the meantime, I will see if I can get some data from xprof.
The simulation step is quite a bit faster than on main currently:
- about 16 minutes on main: https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/17059851883/job/48370674137
- about 9 minutes here: https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/17062835799/job/48373063875
I think most of the improvement is in compile time; we don't get the warning about long XLA compilation anymore.
2025-08-19 09:15:10.544797: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3183] Can't reduce memory use below 10.64GiB (11429599125 bytes) by rematerialization; only reduced to 56.60GiB (60775368484 bytes), down from 66.05GiB (70922254524 bytes) originally
So we're now down to 56.60 GiB from 58.50 GiB.
In short: a definite compile-time improvement is confirmed, along with a slight memory reduction, though there is more memory reduction to go.
Somewhat good news: calling only initialize! + update_state! (without time_step!) is sufficient to get an OOM on the device (time_step! would use even more memory): https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/17173076863/job/48727161438?pr=1307#step:19:728
2025-08-23 09:38:22.863446: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3183] Can't reduce memory use below 10.64GiB (11428802313 bytes) by rematerialization; only reduced to 46.84GiB (50297252461 bytes), down from 63.27GiB (67935638125 bytes) originally
[...]
E0000 00:00:1755941951.386733 3750 pjrt_stream_executor_client.cc:3081] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 51712165888 bytes. [tf-allocator-allocation-error='']
2025-08-23 09:39:11.389852: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_2_bfc) ran out of memory trying to allocate 48.16GiB (rounded to 51712165888)requested by op
If the cause is memory fragmentation maybe the environment variable 'TF_GPU_ALLOCATOR=cuda_malloc_async' will improve the situation.
The XLA dump is quite a bit smaller.
I further reduced the code by removing some kernels, and this is still using more memory than necessary:
2025-08-23 14:04:03.144019: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3183] Can't reduce memory use below 10.64GiB (11428802359 bytes) by rematerialization; only reduced to 12.40GiB (13313787004 bytes), down from 12.40GiB (13313787004 bytes) originally
The job didn't crash because the 12.4 GiB still fit within the total memory, but you can see at the end of the job that the peak memory usage is larger than the memory used after the model creation:
┌ Info: [0] allocations
│ GordonBell25.allocatorstats() =
│ AllocatorStats
│ --------------
│ num_allocs: 41
│ bytes_in_use: 21894366976
│ peak_bytes_in_use: 28943843072
│ largest_alloc_size: 7049475840
│ bytes_limit: 31804391424
│ bytes_reserved: 0
│ peak_bytes_reserved: 0
│ bytes_reservable_limit: nothing
│ largest_free_block_bytes: 0
│ pool_bytes: 31804391424
│ peak_pool_bytes: 31804391424
└
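To make that concrete, here are the relevant numbers from the stats above converted to GiB:

```python
# Numbers taken from the allocator stats above
bytes_in_use = 21894366976       # ≈ 20.39 GiB still held after model creation
peak_bytes_in_use = 28943843072  # ≈ 26.96 GiB at the peak of the run
bytes_limit = 31804391424        # ≈ 29.62 GiB device budget

GiB = 2**30
print((peak_bytes_in_use - bytes_in_use) / GiB)  # ≈ 6.57 GiB of transient peak
print((bytes_limit - peak_bytes_in_use) / GiB)   # ≈ 2.66 GiB of headroom left
```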
The XLA dump is even smaller
Thank you Mose!
There are primarily two modules, reactant__100 and reactant__first. I will start with reactant__first because it looks like it has more memory allocated, but if you happen to know where the memory bottleneck happens, I will focus on that one instead.
No, I think we're a bit clueless about where memory is going 🫠
Based on the dump, I performed a memory profile analysis on the HLO input (module_0097.reactant_first_t....before_optimizations.txt), run in a Google-internal H100 environment.
Top Temporary Ops Summary
Here is a breakdown of memory usage by operation type, focusing on temporary allocations:
| Operation Type | Total Memory (MiB) | Percentage | Number of Operations | Op Names | Framework Op |
|---|---|---|---|---|---|
| param | 20880 | 75.87% | * | | |
| copy | 2880 | 10.46% | 2 | copy.29, copy.30 | N/A |
| loop_add_fusion | 2880 | 10.46% | 2 | loop_add_fusion{0}, loop_add_fusion{2} | add.221 |
| loop_select_fusion | 862.89 | 3.14% | 3 | loop_select_fusion{0}, loop_select_fusion{1}, loop_select_fusion{2} | pad.178 |
| rest ops | 18.51 | 0.07% | | | |
| TOTAL | 27521.5 | 100% | | | |
Questions
- Parameter Memory Usage: The HLO input parameters consume 20880 MiB, which accounts for 75.87% of the total memory allocation. Is this magnitude of memory usage for input parameters considered reasonable or expected from an MLIR perspective for this type of model or operation?
- If the parameter memory usage is larger than expected, could this suggest potential inefficiencies or optimization opportunities in the StableHLO to HLO conversion process?
> Is this magnitude of memory usage for input parameters considered reasonable or expected from an MLIR perspective for this type of model or operation?
Billy please do correct me, but I believe that's indeed expected. That's the memory we have right after the model generation. In https://github.com/EnzymeAD/Enzyme-JAX/pull/1243#issuecomment-3146844668 I had anticipated about 20 GB only based on scaling the input parameters from previous runs, so I think these 20 GB are what we expect.
I need to stare more at the minimization, but the basic gist here is that the parameters will have some amount of memory usage, but in principle we should be able to use no additional memory. Specifically, the original code just had that original allocation and updated it in place.
In practice when we do the full, loop-based version I expect we'll have a factor of two for the induction variables or something, but here I would expect that these allocations should be eliminable.
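To illustrate the "updated it in place" point above, here is a small JAX sketch (illustrative only, not the GB-25 code) where buffer donation lets the output reuse the parameter's allocation, so a step needs essentially no memory beyond the parameters themselves:

```python
import jax
import jax.numpy as jnp

# Donating the input tells XLA it may reuse that buffer for the output, which is
# the in-place-update behaviour described above. (Donation can be ignored on some
# backends, e.g. CPU, in which case a copy is made and a warning is printed.)
step = jax.jit(lambda state: state + 1.0, donate_argnums=0)

state = jnp.zeros((1024, 1024))
state = step(state)  # the old `state` buffer is invalidated and may be reused
```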
With https://github.com/EnzymeAD/Reactant.jl/pull/1619 we get a slightly lower peak memory. main:
│ peak_bytes_in_use: 28943842560
vs PR:
│ peak_bytes_in_use: 28534432768
but the warning we get during compilation mentions a much larger memory buffer. main:
W0000 00:00:1756938493.424201 5944 hlo_rematerialization.cc:3183] Can't reduce memory use below 10.64GiB (11428802359 bytes) by rematerialization; only reduced to 12.40GiB (13313787004 bytes), down from 12.40GiB (13313787004 bytes) originally
vs PR
W0000 00:00:1756941213.091395 59266 hlo_rematerialization.cc:3183] Can't reduce memory use below 10.64GiB (11428802359 bytes) by rematerialization; only reduced to 16.16GiB (17352329320 bytes), down from 16.23GiB (17427728524 bytes) originally
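Converting those numbers, for easier comparison:

```python
GiB = 2**30
# peak_bytes_in_use: main vs PR (from the allocator stats above)
print((28943842560 - 28534432768) / GiB)  # ≈ 0.38 GiB lower peak with the PR
# rematerialization estimates: PR vs main (from the compilation warnings above)
print((17352329320 - 13313787004) / GiB)  # ≈ 3.76 GiB larger estimate with the PR
```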
I'm curious if https://github.com/EnzymeAD/Enzyme-JAX/pull/1363 creates similar improvements [or more], or causes more chaos.
Also note that the absence of the dus_to_pad comm op in that PR causes all-gathers to return, which is bad.
@felixwqp the file module_0051.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt in https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/17686880847/artifacts/4001249975 is slightly less than 200 lines of code, is that workable? Edit: ignore this, it was incomplete.
@felixwqp: @glwagner further reduced the kernels, I hope now we got something even more useful.
In this run we launched the program three times separately, using 3 different kernels: one filling the halo regions only in the east-west direction (which is also the direction in which we do the sharding, so we expect device-device communication here), one filling the halo regions only in the north-south direction (no sharding), and another one filling all the halo regions (device-device communication also here). We see:
- for east-west:
│ num_allocs: 41
│ bytes_in_use: 23404316416
│ peak_bytes_in_use: 24918199296
│ largest_alloc_size: 1513882880
│ bytes_limit: 31804391424
- for north-south:
│ num_allocs: 41
│ bytes_in_use: 23404316416
│ peak_bytes_in_use: 23443442432
│ largest_alloc_size: 1509949440
│ bytes_limit: 31804391424
- for all regions
│ num_allocs: 40
│ bytes_in_use: 21894366976
│ peak_bytes_in_use: 23409627136
│ largest_alloc_size: 1515260160
│ bytes_limit: 31804391424
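For readability, the peak_bytes_in_use values above in GiB:

```python
GiB = 2**30
peaks = {
    "east-west (sharded direction)": 24918199296,
    "north-south": 23443442432,
    "all regions": 23409627136,
}
for name, peak in peaks.items():
    print(f"{name}: {peak / GiB:.2f} GiB peak")
# east-west peaks roughly 1.4 GiB higher than the other two runs
```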
Here is the XLA dump from the run: simulation-xla-dump-1.11--main-main.zip. Here are the line counts of the three modules:
% wc -l simulation-xla-dump-1.11--main-main/xla_dump_*/module_00*.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt
468 simulation-xla-dump-1.11--main-main/xla_dump_all/module_0049.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt
111 simulation-xla-dump-1.11--main-main/xla_dump_east_west/module_0051.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt
197 simulation-xla-dump-1.11--main-main/xla_dump_north_south/module_0051.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt
776 total
I presume the more interesting ones to look at are the east-west and north-south ones, perhaps especially the former, which does the device-device communication (and is also the shorter one), but comparing it with the north-south one may also be useful. Hope this helps!
Greg please do correct me if I said anything wrong above!