OOM on training Mistral hypernet
Hi @bminixhofer, I am getting an OOM with the following logs when training a Mistral multilingual hypernet. I have also tried this on two A100 (80GB) GPUs. Not sure what is wrong!
I have created a branch containing a script to reproduce this. You can run the ./install script on any vast.ai instance: https://github.com/bminixhofer/zett/compare/main...kdcyberdude:zett:main
2024-06-10 00:21:14.077507: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.00GiB (rounded to 1073741824) requested by op
2024-06-10 00:21:14.078739: W external/tsl/tsl/framework/bfc_allocator.cc:497] ****************************************************************************************************
2024-06-10 00:21:14.078795: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1073741824 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 1.00GiB
constant allocation: 0B
maybe_live_out allocation: 1.00GiB
preallocated temp allocation: 0B
total allocation: 2.00GiB
total fragmentation: 0B (0.00%)
Peak buffers:
Buffer 1:
Size: 1.00GiB
Entry Parameter Subshape: pred[1,32768,32768]
==========================
Buffer 2:
Size: 1.00GiB
XLA Label: fusion
Shape: pred[1,1,32768,32768]
==========================
0%| | 0/100000 [00:37<?, ?it/s]
Traceback (most recent call last):
File "/workspace/zett/train.py", line 1625, in <module>
main()
File "/workspace/zett/train.py", line 1526, in main
state, train_metric = current_step_fn(state, batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/zett/train.py", line 1203, in train_step
(loss, (lexical_loss, mean_lexical_overlap)), grad = grad_fn(state.params)
^^^^^^^^^^^^^^^^^^^^^
File "/workspace/zett/train.py", line 1116, in compute_loss
) = compute_embeddings_and_logits(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/zett/train.py", line 1092, in compute_embeddings_and_logits
logits = model_fn(
^^^^^^^^^
File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 502, in __call__
outputs = self.module.apply(
^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 677, in __call__
outputs = self.model(
^^^^^^^^^^^
File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 605, in __call__
outputs = self.layers(
^^^^^^^^^^^^
File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 556, in __call__
layer_outputs = block(
^^^^^^
File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 374, in __call__
outputs = self.self_attn(
^^^^^^^^^^^^^^^
File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 241, in setup
casual_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/zett/lib/python3.11/site-packages/flax/linen/attention.py", line 810, in make_causal_mask
return make_attention_mask(
^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/zett/lib/python3.11/site-packages/flax/linen/attention.py", line 786, in make_attention_mask
mask = jnp.expand_dims(mask, axis=-3)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/zett/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 912, in expand_dims
return lax.expand_dims(a, axis)
^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1073741824 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 1.00GiB
constant allocation: 0B
maybe_live_out allocation: 1.00GiB
preallocated temp allocation: 0B
total allocation: 2.00GiB
total fragmentation: 0B (0.00%)
Peak buffers:
Buffer 1:
Size: 1.00GiB
Entry Parameter Subshape: pred[1,32768,32768]
==========================
Buffer 2:
Size: 1.00GiB
XLA Label: fusion
Shape: pred[1,1,32768,32768]
==========================
Hi, it looks like it is trying to create a very large causal mask because of the high max_position_embeddings (32768 for Mistral). You can try manually lowering max_position_embeddings to the block_size, which should make memory usage a lot lighter (and should be safe to do).
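For intuition: a 32768 × 32768 boolean mask is 32768² one-byte entries = 1,073,741,824 bytes = exactly the 1.00GiB pred[1,32768,32768] peak buffer in the log above, and Flax Mistral materializes it in every attention layer's setup. A minimal sketch of the override, assuming the model is loaded through transformers' Flax auto classes (zett's actual loading code in train.py may differ, and the checkpoint name here is just illustrative):

```python
import jax.numpy as jnp
from transformers import AutoConfig, FlaxAutoModelForCausalLM

# Load the config first and shrink the maximum sequence length before the
# model (and its precomputed causal mask) is ever instantiated.
config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
config.max_position_embeddings = 128  # e.g. the training block_size, instead of 32768

model = FlaxAutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    config=config,
    dtype=jnp.bfloat16,  # half-precision params also ease memory pressure
)
```

With this, the per-layer mask shrinks from 32768² bytes (1 GiB) to 128² bytes (16 KiB).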
Hi @bminixhofer, do I need to update max_position_embeddings to 128 when initializing the roberta-base model in zett/model/__init__.py?
I tried without using a pretrained hypernet model as well; it still gives an OOM.
Also, what is the VRAM requirement for training this on a GPU?
Logs
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1718174206.615502 474103 hlo_rematerialization.cc:2946] Can't reduce memory use below -18.34GiB (-19688159409 bytes) by rematerialization; only reduced to 21.55GiB (23140745816 bytes), down from 21.55GiB (23140745816 bytes) originally
2024-06-12 12:06:47.545238: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to allocate request for 112.00MiB (117440512B) on device ordinal 0
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 16.50GiB
constant allocation: 22B
maybe_live_out allocation: 21.55GiB
preallocated temp allocation: 13.4KiB
total allocation: 38.05GiB
total fragmentation: 13.4KiB (0.00%)
Peak buffers:
Buffer 1:
Size: 1000.00MiB
Entry Parameter Subshape: f32[32000,8192]
==========================
Buffer 2:
Size: 1000.00MiB
XLA Label: fusion
Shape: f32[32000,8192]
==========================
Buffer 3:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[4096,8192]
==========================
Buffer 4:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Buffer 5:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Buffer 6:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Buffer 7:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Buffer 8:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Buffer 9:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Buffer 10:
Size: 128.00MiB
Operator: op_name="jit(init_state)/jit(main)/broadcast_in_dim[shape=(8192, 4096) broadcast_dimensions=()]" source_file="/mnt/pi/proj/jun/zett/train.py" source_line=770
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Buffer 11:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[4096,8192]
==========================
Buffer 12:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[4096,8192]
==========================
Buffer 13:
Size: 128.00MiB
Operator: op_name="jit(init_state)/jit(main)/broadcast_in_dim[shape=(4096, 8192) broadcast_dimensions=()]" source_file="/mnt/pi/proj/jun/zett/train.py" source_line=770
XLA Label: fusion
Shape: f32[4096,8192]
==========================
Buffer 14:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Buffer 15:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Traceback (most recent call last):
File "/home/kd/anaconda3/envs/zett/lib/python3.11/runpy.py", line 198, in _run_module_as_main
return _run_code(code, main_globals, None,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/kd/anaconda3/envs/zett/lib/python3.11/runpy.py", line 88, in _run_code
exec(code, run_globals)
File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/main.py", line 39, in
PS: I am new to JAX.
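(In case it helps while debugging: by default JAX preallocates ~75% of GPU memory up front, which can make OOM reports hard to interpret. A minimal sketch of the standard XLA client flags, which must be set before jax is imported — this is a debugging aid, not a fix for the mask size itself:)

```python
import os

# Allocate GPU memory on demand instead of preallocating ~75% of VRAM up
# front; the peaks reported in OOM logs then reflect actual usage more closely.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternatively, keep preallocation but raise the fraction of VRAM XLA may claim:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.90"

import jax  # must be imported only after these variables are set
print(jax.devices())
```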