
OOM on training Mistral hypernet

Open • kdcyberdude opened this issue 1 year ago • 2 comments

Hi @bminixhofer, I am getting an OOM error (logs below) when training a multilingual hypernet for Mistral. I have also tried this on two A100 (80GB) GPUs. Not sure what is wrong!

I have created a branch with a script to reproduce this. You can run the ./install script on any vast.ai instance: https://github.com/bminixhofer/zett/compare/main...kdcyberdude:zett:main

2024-06-10 00:21:14.077507: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.00GiB (rounded to 1073741824)requested by op
2024-06-10 00:21:14.078739: W external/tsl/tsl/framework/bfc_allocator.cc:497] ****************************************************************************************************
2024-06-10 00:21:14.078795: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1073741824 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    1.00GiB
              constant allocation:         0B
        maybe_live_out allocation:    1.00GiB
     preallocated temp allocation:         0B
                 total allocation:    2.00GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
        Buffer 1:
                Size: 1.00GiB
                Entry Parameter Subshape: pred[1,32768,32768]
                ==========================

        Buffer 2:
                Size: 1.00GiB
                XLA Label: fusion
                Shape: pred[1,1,32768,32768]
                ==========================

  0%|          | 0/100000 [00:37<?, ?it/s]
Traceback (most recent call last):
  File "/workspace/zett/train.py", line 1625, in <module>
    main()
  File "/workspace/zett/train.py", line 1526, in main
    state, train_metric = current_step_fn(state, batch)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/zett/train.py", line 1203, in train_step
    (loss, (lexical_loss, mean_lexical_overlap)), grad = grad_fn(state.params)
                                                         ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/zett/train.py", line 1116, in compute_loss
    ) = compute_embeddings_and_logits(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/zett/train.py", line 1092, in compute_embeddings_and_logits
    logits = model_fn(
             ^^^^^^^^^
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 502, in __call__
    outputs = self.module.apply(
              ^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 677, in __call__
    outputs = self.model(
              ^^^^^^^^^^^
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 605, in __call__
    outputs = self.layers(
              ^^^^^^^^^^^^
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 556, in __call__
    layer_outputs = block(
                    ^^^^^^
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 374, in __call__
    outputs = self.self_attn(
              ^^^^^^^^^^^^^^^
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 241, in setup
    casual_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/flax/linen/attention.py", line 810, in make_causal_mask
    return make_attention_mask(
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/flax/linen/attention.py", line 786, in make_attention_mask
    mask = jnp.expand_dims(mask, axis=-3)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 912, in expand_dims
    return lax.expand_dims(a, axis)
           ^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1073741824 bytes.

kdcyberdude, Jun 10 '24 00:06

Hi, it looks like it is trying to create a very large causal mask because of the high max_position_embeddings (32768 here, which is where the pred[1,32768,32768] buffers come from). You can try manually lowering max_position_embeddings to the block_size, which should make it a lot easier on memory (and should be safe to do).
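To make the scaling concrete: the mask has one entry per (query, key) pair, so its size grows with the square of max_position_embeddings. A minimal sketch in plain Python, assuming one byte per boolean element (which matches the 1.00GiB the log reports for pred[1,32768,32768]). `causal_mask` and `causal_mask_bytes` are hypothetical helpers, not zett or Flax functions, and 128 is only an example block_size.

```python
# Sketch (hypothetical helpers, not zett or Flax APIs): size of the boolean causal mask
# that, per the traceback, the Flax Mistral attention builds in setup() via
#   make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
# The result has shape [1, 1, L, L] with L = max_position_embeddings.

GiB = 1024 ** 3


def causal_mask(seq_len: int) -> list[list[bool]]:
    """Pure-Python stand-in showing the mask's structure: query q may attend to key k iff k <= q."""
    return [[q >= k for k in range(seq_len)] for q in range(seq_len)]


def causal_mask_bytes(seq_len: int, bytes_per_element: int = 1) -> int:
    """Bytes for a [1, 1, seq_len, seq_len] mask, assuming 1 byte per bool (XLA's pred type)."""
    return seq_len * seq_len * bytes_per_element


print(causal_mask(3))                  # [[True, False, False], [True, True, False], [True, True, True]]
print(causal_mask_bytes(32768) / GiB)  # 1.0 -> L = 32768, as in the log
print(causal_mask_bytes(128))          # 16384 -> a hypothetical block_size of 128
```

Going from 32768 to 128 shrinks the mask by a factor of 65536. With a Hugging Face config object, lowering the limit would amount to something like `config.max_position_embeddings = block_size` before the Flax model is constructed; this sketch does not show where in zett's train.py that config is created.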

bminixhofer, Jun 11 '24 20:06

Hi @bminixhofer, do I need to set max_position_embeddings to 128 when initializing the roberta-base model in zett/model/init.py?

I also tried without the pretrained hypernet model, and it still runs out of memory.

Also, what are the VRAM requirements for training this on a GPU?

Logs

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1718174206.615502 474103 hlo_rematerialization.cc:2946] Can't reduce memory use below -18.34GiB (-19688159409 bytes) by rematerialization; only reduced to 21.55GiB (23140745816 bytes), down from 21.55GiB (23140745816 bytes) originally
2024-06-12 12:06:47.545238: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to allocate request for 112.00MiB (117440512B) on device ordinal 0
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   16.50GiB
              constant allocation:        22B
        maybe_live_out allocation:   21.55GiB
     preallocated temp allocation:    13.4KiB
                 total allocation:   38.05GiB
              total fragmentation:    13.4KiB (0.00%)
Peak buffers:
    Buffer 1:
            Size: 1000.00MiB
            Entry Parameter Subshape: f32[32000,8192]
            ==========================

    Buffer 2:
            Size: 1000.00MiB
            XLA Label: fusion
            Shape: f32[32000,8192]
            ==========================

    Buffer 3:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[4096,8192]
            ==========================

    Buffer 4:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 5:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 6:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 7:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 8:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 9:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 10:
            Size: 128.00MiB
            Operator: op_name="jit(init_state)/jit(main)/broadcast_in_dim[shape=(8192, 4096) broadcast_dimensions=()]" source_file="/mnt/pi/proj/jun/zett/train.py" source_line=770
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 11:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[4096,8192]
            ==========================

    Buffer 12:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[4096,8192]
            ==========================

    Buffer 13:
            Size: 128.00MiB
            Operator: op_name="jit(init_state)/jit(main)/broadcast_in_dim[shape=(4096, 8192) broadcast_dimensions=()]" source_file="/mnt/pi/proj/jun/zett/train.py" source_line=770
            XLA Label: fusion
            Shape: f32[4096,8192]
            ==========================

    Buffer 14:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

    Buffer 15:
            Size: 128.00MiB
            XLA Label: fusion
            Shape: f32[8192,4096]
            ==========================

Traceback (most recent call last):
  File "/home/kd/anaconda3/envs/zett/lib/python3.11/runpy.py", line 198, in _run_module_as_main
    return _run_code(code, main_globals, None,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kd/anaconda3/envs/zett/lib/python3.11/runpy.py", line 88, in _run_code
    exec(code, run_globals)
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/main.py", line 39, in <module>
    cli.main()
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="main")
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/mnt/pi/proj/jun/zett/train.py", line 1629, in <module>
    main()
  File "/mnt/pi/proj/jun/zett/train.py", line 848, in main
    state = jax.jit(
            ^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 112.00MiB (117440512B) on device ordinal 0
BufferAssignment OOM Debugging.

PS: I am new to JAX.
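Each peak-buffer size in these logs is just the number of elements in the shape times the bytes per element (4 for f32). A minimal sketch for checking that by hand; `buffer_bytes` is a hypothetical helper, and the dtype table is an assumed subset covering only a few XLA element types.

```python
import math
import re

# Bytes per element for a few XLA element types (assumed subset; pred is XLA's bool).
BYTES_PER_ELEMENT = {"pred": 1, "bf16": 2, "f16": 2, "f32": 4}
MiB = 1024 ** 2


def buffer_bytes(shape: str) -> int:
    """Size in bytes of an XLA shape string such as 'f32[32000,8192]' (hypothetical helper)."""
    match = re.fullmatch(r"(\w+)\[([\d,]*)\]", shape)
    if match is None:
        raise ValueError(f"unrecognized shape string: {shape!r}")
    dtype, dims = match.groups()
    # Product of all dimensions; an empty list (a scalar like 'f32[]') gives 1 element.
    n_elements = math.prod(int(d) for d in dims.split(",") if d)
    return n_elements * BYTES_PER_ELEMENT[dtype]


for s in ["f32[32000,8192]", "f32[4096,8192]", "pred[1,1,32768,32768]"]:
    print(s, buffer_bytes(s) / MiB, "MiB")
# f32[32000,8192] 1000.0 MiB
# f32[4096,8192] 128.0 MiB
# pred[1,1,32768,32768] 1024.0 MiB
```

This reproduces the 1000.00MiB and 128.00MiB figures in the buffer list above and the 1 GiB pred mask from the first log. The peak-buffer list shows only the largest individual allocations; the total allocation line in the stats block (38.05GiB here) is the overall figure for that compiled computation.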

kdcyberdude, Jun 12 '24 06:06