
Reduce Memory Usage by performing operations component-wise for Equilibrium Objectives

Open dpanici opened this issue 1 year ago • 6 comments

In looking at the traceback for an OOM error, I saw that a lot of the memory is taken up by concatenation operations where we assemble 3 arrays into a vector, e.g. in our basis vectors or the vector B. Memory-wise, it could help to write these quantities out component-wise instead, at least for the equilibrium objectives, since that would avoid the concatenations holding onto memory.
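A minimal sketch of the idea (the function names and shapes here are hypothetical, not the actual DESC code): instead of stacking components into a (3, N) array and combining, compute each component directly so no stacked intermediate is materialized.

```python
import jax.numpy as jnp

# Hypothetical example: B = B^theta e_theta + B^zeta e_zeta, with each basis
# vector given by its three cylindrical components of shape (N,).

# Concatenated form: builds (3, N) intermediates before combining.
def B_concatenated(e_t_R, e_t_p, e_t_Z, e_z_R, e_z_p, e_z_Z, Bt, Bz):
    e_theta = jnp.stack([e_t_R, e_t_p, e_t_Z])  # (3, N) temporary
    e_zeta = jnp.stack([e_z_R, e_z_p, e_z_Z])   # (3, N) temporary
    return Bt * e_theta + Bz * e_zeta           # another (3, N) temporary

# Component-wise form: no stacked intermediates; each component is a single
# fused multiply-add over (N,) arrays.
def B_componentwise(e_t_R, e_t_p, e_t_Z, e_z_R, e_z_p, e_z_Z, Bt, Bz):
    B_R = Bt * e_t_R + Bz * e_z_R
    B_p = Bt * e_t_p + Bz * e_z_p
    B_Z = Bt * e_t_Z + Bz * e_z_Z
    return B_R, B_p, B_Z
```

Both forms give identical results; whether the component-wise version actually saves peak memory depends on what XLA fuses, which is what would need to be measured.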

example traceback:

2023-12-23 21:03:51.178436: W external/xla/xla/service/hlo_rematerialization.cc:2202] Can't reduce memory use below 59.49GiB (63880937472 bytes) by rematerialization; only reduced to 64.19GiB (68923960570 bytes), down from 71.23GiB (76482936842 bytes) originally
2023-12-23 21:04:02.619504: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 69.48GiB (rounded to 74602902016)requested by op 
2023-12-23 21:04:02.619827: W external/tsl/tsl/framework/bfc_allocator.cc:497] ******************__________________________________________________________________________________
2023-12-23 21:04:02.620390: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 74602901848 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    2.46MiB
              constant allocation:   192.5KiB
        maybe_live_out allocation:  900.36MiB
     preallocated temp allocation:   69.48GiB
  preallocated temp fragmentation:    1.77GiB (2.54%)
                 total allocation:   70.36GiB
              total fragmentation:    1.77GiB (2.51%)
Peak buffers:
	Buffer 1:
		Size: 2.64GiB
		Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/mul" source_file="/home/dpanici/DESC/desc/compute/_field.py" source_line=513
		XLA Label: broadcast
		Shape: f64[3,8375,14091]
		==========================

	Buffer 2:
		Size: 2.64GiB
		Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/add" source_file="/home/dpanici/DESC/desc/compute/_field.py" source_line=408
		XLA Label: fusion
		Shape: f64[3,8375,14091]
		==========================

	Buffer 3:
		Size: 2.64GiB
		Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/add" source_file="/home/dpanici/DESC/desc/compute/_field.py" source_line=408
		XLA Label: fusion
		Shape: f64[3,8375,14091]
		==========================

	Buffer 4:
		Size: 2.64GiB
		Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/add" source_file="/home/dpanici/DESC/desc/compute/_field.py" source_line=408
		XLA Label: fusion
		Shape: f64[3,8375,14091]
		==========================

	Buffer 5:
		Size: 2.64GiB
		Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/add" source_file="/home/dpanici/DESC/desc/compute/_field.py" source_line=408
		XLA Label: fusion
		Shape: f64[3,8375,14091]
		==========================

	Buffer 6:
		Size: 2.64GiB
		Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/add" source_file="/home/dpanici/DESC/desc/compute/_field.py" source_line=408
		XLA Label: fusion
		Shape: f64[3,8375,14091]
		==========================

	Buffer 7:
		Size: 2.64GiB
		Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/concatenate[dimension=1]" source_file="/home/dpanici/DESC/desc/compute/_basis_vectors.py" source_line=3283 deduplicated_name="fusion.300"
		XLA Label: fusion
		Shape: f64[3,8375,14091]
		==========================

	Buffer 8:
		Size: 2.64GiB
		Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/add" source_file="/home/dpanici/DESC/desc/compute/_field.py" source_line=306
		XLA Label: fusion
		Shape: f64[3,8375,14091]
		==========================

	Buffer 9:
		Size: 2.64GiB
		Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/concatenate[dimension=1]" source_file="/home/dpanici/DESC/desc/compute/_basis_vectors.py" source_line=2401 deduplicated_name="fusion.300"
		XLA Label: fusion
		Shape: f64[3,8375,14091]
		==========================

	Buffer 10:
		Size: 2.64GiB
		Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/concatenate[dimension=1]" source_file="/home/dpanici/DESC/desc/compute/_basis_vectors.py" source_line=2261 deduplicated_name="fusion.300"
		XLA Label: fusion
		Shape: f64[3,8375,14091]
		==========================

	Buffer 11:
		Size: 2.64GiB
		Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/concatenate[dimension=1]" source_file="/home/dpanici/DESC/desc/compute/_basis_vectors.py" source_line=2336
		XLA Label: fusion
		Shape: f64[3,8375,14091]
		==========================

	Buffer 12:
		Size: 2.64GiB
		Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/concatenate[dimension=1]" source_file="/home/dpanici/DESC/desc/compute/_basis_vectors.py" source_line=2336
		XLA Label: fusion
		Shape: f64[3,8375,14091]
		==========================

	Buffer 13:
		Size: 2.64GiB
		Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/concatenate[dimension=1]" source_file="/home/dpanici/DESC/desc/compute/_basis_vectors.py" source_line=2336
		XLA Label: fusion
		Shape: f64[3,8375,14091]
		==========================

	Buffer 14:
		Size: 2.64GiB
		Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/concatenate[dimension=1]" source_file="/home/dpanici/DESC/desc/compute/_basis_vectors.py" source_line=2336
		XLA Label: fusion
		Shape: f64[3,8375,14091]
		==========================

	Buffer 15:
		Size: 900.36MiB
		Operator: op_name="jit(jac_scaled)/jit(main)/vmap(jvp(jit(compute_scaled)))/jit(cross)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2,), start_index_map=(2,)) slice_sizes=(8375, 14091, 1) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/home/dpanici/DESC/desc/compute/utils.py" source_line=489
		XLA Label: fusion
		Shape: f64[1,8375,14091]
		==========================

dpanici avatar Dec 24 '23 02:12 dpanici

Also mentioned in this paper, section V.C, as a way to speed up compilation times:

[image attachment]

dpanici avatar Dec 24 '23 02:12 dpanici

Probably would not help, so closing for now.

dpanici avatar Jan 08 '24 21:01 dpanici

Can you document why it won't help at some point? I saved a lot of memory by making sure operations are fused in #1440.

unalmis avatar Feb 09 '25 04:02 unalmis

@dpanici

YigitElma avatar May 20 '25 02:05 YigitElma

I don't remember why I said it would not help much. It's probably worth spending the hour to re-write the F balance equations component-wise and see if there's a difference in memory use.
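One way to compare the two formulations without running a full optimization is to inspect the compiled executable's memory stats (a sketch; `f` is a placeholder function, and whether `memory_analysis()` returns anything depends on the backend and JAX version):

```python
import jax
import jax.numpy as jnp

# Placeholder computation standing in for an equilibrium objective.
def f(x):
    return jnp.stack([x, 2 * x, 3 * x]).sum()

# Lower and compile for a representative input shape, then query the
# compiler's buffer-assignment statistics.
compiled = jax.jit(f).lower(jnp.ones(1000)).compile()
stats = compiled.memory_analysis()  # may be None on some backends
if stats is not None:
    print(stats.temp_size_in_bytes)  # peak temporary allocation, in bytes
```

Comparing `temp_size_in_bytes` for the concatenated vs. component-wise versions of the same objective would answer the question directly.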

dpanici avatar May 22 '25 15:05 dpanici

It might be worth looking into jnp.einsum where we currently use [None, :] to broadcast arrays for multiplication, as it may be more memory-efficient.
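A minimal sketch of the comparison (shapes here are hypothetical, not the actual DESC arrays):

```python
import jax.numpy as jnp

# Hypothetical data: coefficients c of shape (2, 3) scaling a stack of
# vectors v of shape (2, 3, 4), then summed over the middle axis.
c = jnp.arange(6.0).reshape(2, 3)
v = jnp.arange(24.0).reshape(2, 3, 4)

# None-indexing broadcasts c to (2, 3, 4) before multiplying, so the full
# broadcasted product appears as an intermediate.
out_broadcast = (c[:, :, None] * v).sum(axis=1)

# einsum expresses the same multiply-and-reduce as one contraction, which
# XLA can often lower without materializing the broadcasted intermediate.
out_einsum = jnp.einsum("ij,ijk->ik", c, v)
```

The two are numerically identical; any memory difference would come from how XLA fuses each form, so it would still need to be measured on a real objective.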

dpanici avatar Nov 06 '25 20:11 dpanici