Reduce Memory Usage by performing operations component-wise for Equilibrium Objectives
Looking at the traceback for an OOM error, I saw that a lot of the memory is taken up by concatenation operations where we stack 3 arrays to form a vector, as in our basis vectors or the magnetic field vector B. It could be helpful, memory-wise, to write things out component-wise instead, at least for the equilibrium objectives. That should reduce memory usage, since the concatenated intermediates would never need to be materialized (rough sketch below).
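Rough sketch of what I mean; the function names, variable names, and shapes here are made up for illustration, not DESC's actual compute functions:

```python
import jax.numpy as jnp

def e_theta_vec(dR_dt, dphi_dt, dZ_dt, R):
    # current style: stack three (N,) component arrays into one (3, N)
    # basis vector, which XLA materializes as an intermediate buffer
    return jnp.concatenate(
        [dR_dt[None, :], (R * dphi_dt)[None, :], dZ_dt[None, :]]
    )

def B_vec(B_sup_theta, B_sup_zeta, e_theta, e_zeta):
    # vector form: each (N,) * (3, N) broadcast allocates another (3, N) temp
    return B_sup_theta[None, :] * e_theta + B_sup_zeta[None, :] * e_zeta

def B_componentwise(B_sup_theta, B_sup_zeta,
                    e_theta_R, e_theta_phi, e_theta_Z,
                    e_zeta_R, e_zeta_phi, e_zeta_Z):
    # component-wise form: three (N,) operations, with no (3, N)
    # concatenation appearing in the computation graph at all
    B_R = B_sup_theta * e_theta_R + B_sup_zeta * e_zeta_R
    B_phi = B_sup_theta * e_theta_phi + B_sup_zeta * e_zeta_phi
    B_Z = B_sup_theta * e_theta_Z + B_sup_zeta * e_zeta_Z
    return B_R, B_phi, B_Z
```

The component-wise version trades one stacked (3, N) buffer for three (N,) arrays that XLA can fuse directly into the downstream operations.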
example traceback:
2023-12-23 21:03:51.178436: W external/xla/xla/service/hlo_rematerialization.cc:2202] Can't reduce memory use below 59.49GiB (63880937472 bytes) by rematerialization; only reduced to 64.19GiB (68923960570 bytes), down from 71.23GiB (76482936842 bytes) originally
2023-12-23 21:04:02.619504: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 69.48GiB (rounded to 74602902016)requested by op
2023-12-23 21:04:02.619827: W external/tsl/tsl/framework/bfc_allocator.cc:497] ******************__________________________________________________________________________________
2023-12-23 21:04:02.620390: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 74602901848 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 2.46MiB
constant allocation: 192.5KiB
maybe_live_out allocation: 900.36MiB
preallocated temp allocation: 69.48GiB
preallocated temp fragmentation: 1.77GiB (2.54%)
total allocation: 70.36GiB
total fragmentation: 1.77GiB (2.51%)
Peak buffers:
Buffer 1:
Size: 2.64GiB
Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/mul" source_file="/home/dpanici/DESC/desc/compute/_field.py" source_line=513
XLA Label: broadcast
Shape: f64[3,8375,14091]
==========================
Buffer 2:
Size: 2.64GiB
Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/add" source_file="/home/dpanici/DESC/desc/compute/_field.py" source_line=408
XLA Label: fusion
Shape: f64[3,8375,14091]
==========================
Buffer 3:
Size: 2.64GiB
Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/add" source_file="/home/dpanici/DESC/desc/compute/_field.py" source_line=408
XLA Label: fusion
Shape: f64[3,8375,14091]
==========================
Buffer 4:
Size: 2.64GiB
Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/add" source_file="/home/dpanici/DESC/desc/compute/_field.py" source_line=408
XLA Label: fusion
Shape: f64[3,8375,14091]
==========================
Buffer 5:
Size: 2.64GiB
Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/add" source_file="/home/dpanici/DESC/desc/compute/_field.py" source_line=408
XLA Label: fusion
Shape: f64[3,8375,14091]
==========================
Buffer 6:
Size: 2.64GiB
Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/add" source_file="/home/dpanici/DESC/desc/compute/_field.py" source_line=408
XLA Label: fusion
Shape: f64[3,8375,14091]
==========================
Buffer 7:
Size: 2.64GiB
Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/concatenate[dimension=1]" source_file="/home/dpanici/DESC/desc/compute/_basis_vectors.py" source_line=3283 deduplicated_name="fusion.300"
XLA Label: fusion
Shape: f64[3,8375,14091]
==========================
Buffer 8:
Size: 2.64GiB
Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/add" source_file="/home/dpanici/DESC/desc/compute/_field.py" source_line=306
XLA Label: fusion
Shape: f64[3,8375,14091]
==========================
Buffer 9:
Size: 2.64GiB
Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/concatenate[dimension=1]" source_file="/home/dpanici/DESC/desc/compute/_basis_vectors.py" source_line=2401 deduplicated_name="fusion.300"
XLA Label: fusion
Shape: f64[3,8375,14091]
==========================
Buffer 10:
Size: 2.64GiB
Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/concatenate[dimension=1]" source_file="/home/dpanici/DESC/desc/compute/_basis_vectors.py" source_line=2261 deduplicated_name="fusion.300"
XLA Label: fusion
Shape: f64[3,8375,14091]
==========================
Buffer 11:
Size: 2.64GiB
Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/concatenate[dimension=1]" source_file="/home/dpanici/DESC/desc/compute/_basis_vectors.py" source_line=2336
XLA Label: fusion
Shape: f64[3,8375,14091]
==========================
Buffer 12:
Size: 2.64GiB
Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/concatenate[dimension=1]" source_file="/home/dpanici/DESC/desc/compute/_basis_vectors.py" source_line=2336
XLA Label: fusion
Shape: f64[3,8375,14091]
==========================
Buffer 13:
Size: 2.64GiB
Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/concatenate[dimension=1]" source_file="/home/dpanici/DESC/desc/compute/_basis_vectors.py" source_line=2336
XLA Label: fusion
Shape: f64[3,8375,14091]
==========================
Buffer 14:
Size: 2.64GiB
Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/concatenate[dimension=1]" source_file="/home/dpanici/DESC/desc/compute/_basis_vectors.py" source_line=2336
XLA Label: fusion
Shape: f64[3,8375,14091]
==========================
Buffer 15:
Size: 900.36MiB
Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/jit(cross)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2,), start_index_map=(2,)) slice_sizes=(8375, 14091, 1) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/home/dpanici/DESC/desc/compute/utils.py" source_line=489
XLA Label: fusion
Shape: f64[1,8375,14091]
==========================
Probably would not help, so closing for now.
@dpanici can you document at some point why it won't help? I saved a lot of memory by making sure operations are fused in #1440.
I don't remember why I said it would not help much. It's probably worth spending the hour to rewrite the F balance equations component-wise and see if there's a difference in memory use.
Might also be worth looking into jnp.einsum in places where we currently use [None, :] to broadcast arrays for multiplication, as that may be more memory-efficient (sketch below).
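Something like this; the shapes and names are made up for illustration:

```python
import jax.numpy as jnp

# two scalar coefficient fields and their matching (3, N) basis vectors
coeffs = jnp.ones((2, 8375))    # e.g. B^theta, B^zeta evaluated at N nodes
bases = jnp.ones((2, 3, 8375))  # e.g. e_theta, e_zeta

# current pattern: each coeffs[k][None, :] * bases[k] broadcast allocates
# its own (3, N) temporary before the sum
out_broadcast = sum(coeffs[k][None, :] * bases[k] for k in range(2))

# einsum expresses the same contraction as a single op, which may let XLA
# avoid materializing the broadcasted intermediates
out_einsum = jnp.einsum("kn,kin->in", coeffs, bases)

assert jnp.allclose(out_broadcast, out_einsum)
```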