
CUDA Out of Memory for the Falcon 7B model on A100 80GB GPU

[Open] k21993 opened this issue 2 years ago • 25 comments

I am trying to reproduce the Falcon-7B LoRA fine-tuning on the Alpaca dataset. I followed the steps to convert the checkpoints to the Lightning format, and I downloaded and tokenized the Alpaca dataset as instructed. When I run:

python finetune/lora.py --checkpoint_dir checkpoints/tiiuae/falcon-7b/

I get the following traceback:

{'eval_interval': 100, 'save_interval': 100, 'eval_iters': 100, 'log_interval': 1, 'devices': 1, 'learning_rate': 0.0003, 'batch_size': 4, 'micro_batch_size': 4, 'gradient_accumulation_iters': 1, 'max_iters': 50000, 'weight_decay': 0.01, 'lora_r': 8, 'lora_alpha': 16, 'lora_dropout': 0.05, 'warmup_iters': 100}
Using bfloat16 Automatic Mixed Precision (AMP)
Global seed set to 1337
Loading model 'checkpoints/tiiuae/falcon-7b/lit_model.pth' with {'block_size': 2048, 'vocab_size': 50254, 'padding_multiple': 512, 'padded_vocab_size': 65024, 'n_layer': 32, 'n_head': 71, 'n_embd': 4544, 'rotary_percentage': 1.0, 'parallel_residual': True, 'bias': False, 'n_query_groups': 1, 'shared_attention_norm': True}
Number of trainable parameters: 3506176
Validating ...
Recommend a movie for me to watch during the weekend and explain the reason.
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Recommend a movie for me to watch during the weekend and explain the reason.

### Response:
[The Martian](https://www.imdb.com/title/tt1878107) is a really good movie to watch during the weekend. It is set on Mars and is based on the book by Andy Weir. Weir is a retired engineer who won an international writing contest for promising science fiction writers. The movie is funny and at the same time it is thoughtful and inspiring. I will recommend this movie to you because of the following reasons.

1. The movie
Estimated TFLOPs: 384.19
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ finetune/lora.py:288 in <module>                                                                 │
│                                                                                                  │
│   285 │   │   message="Remove `.no_backward_sync()` from your code",                             │
│   286 │   )                                                                                      │
│   287 │                                                                                          │
│ ❱ 288 │   CLI(setup)                                                                             │
│   289                                                                                            │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/jsonargparse/cli.py:85 in CLI                             │
│                                                                                                  │
│    82 │   │   │   return parser                                                                  │
│    83 │   │   cfg = parser.parse_args(args)                                                      │
│    84 │   │   cfg_init = parser.instantiate_classes(cfg)                                         │
│ ❱  85 │   │   return _run_component(component, cfg_init)                                         │
│    86 │                                                                                          │
│    87 │   subcommands = parser.add_subcommands(required=True)                                    │
│    88 │   comp_dict = {c.__name__: c for c in components}                                        │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/jsonargparse/cli.py:147 in _run_component                 │
│                                                                                                  │
│   144 def _run_component(component, cfg):                                                        │
│   145 │   cfg.pop("config", None)                                                                │
│   146 │   if not inspect.isclass(component):                                                     │
│ ❱ 147 │   │   return component(**cfg)                                                            │
│   148 │   subcommand = cfg.pop("subcommand")                                                     │
│   149 │   if not subcommand:                                                                     │
│   150 │   │   return component(**cfg)                                                            │
│                                                                                                  │
│ finetune/lora.py:75 in setup                                                                     │
│                                                                                                  │
│    72 │   print(hparams)                                                                         │
│    73 │                                                                                          │
│    74 │   fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision)      │
│ ❱  75 │   fabric.launch(main, data_dir, checkpoint_dir, out_dir, precision)                      │
│    76                                                                                            │
│    77                                                                                            │
│    78 def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path, precisio   │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/lightning/fabric/fabric.py:759 in launch                  │
│                                                                                                  │
│   756 │   │   │   │   f"To use the `{type(self.strategy).__name__}` strategy, `.launch()` need   │
│   757 │   │   │   │   " that contains the code to launch in processes."                          │
│   758 │   │   │   )                                                                              │
│ ❱ 759 │   │   return self._wrap_and_launch(function, self, *args, **kwargs)                      │
│   760 │                                                                                          │
│   761 │   def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:                     │
│   762 │   │   """Trigger the callback methods with the given name and arguments.                 │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/lightning/fabric/fabric.py:841 in _wrap_and_launch        │
│                                                                                                  │
│   838 │   │   to_run = partial(self._wrap_with_setup, to_run)                                    │
│   839 │   │   if (launcher := self._strategy.launcher) is not None:                              │
│   840 │   │   │   return launcher.launch(to_run, *args, **kwargs)                                │
│ ❱ 841 │   │   return to_run(*args, **kwargs)                                                     │
│   842 │                                                                                          │
│   843 │   def _wrap_with_setup(self, to_run: Callable, *args: Any, **kwargs: Any) -> Any:        │
│   844 │   │   self._strategy.setup_environment()                                                 │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/lightning/fabric/fabric.py:846 in _wrap_with_setup        │
│                                                                                                  │
│   843 │   def _wrap_with_setup(self, to_run: Callable, *args: Any, **kwargs: Any) -> Any:        │
│   844 │   │   self._strategy.setup_environment()                                                 │
│   845 │   │   with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(Bat   │
│ ❱ 846 │   │   │   return to_run(*args, **kwargs)                                                 │
│   847 │                                                                                          │
│   848 │   def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn   │
│   849 │   │   initial_device = next(model.parameters(), torch.tensor(0)).device                  │
│                                                                                                  │
│ finetune/lora.py:112 in main                                                                     │
│                                                                                                  │
│   109 │   │   max_seq_length = json.load(data_config_path).get("max_seq_length", model.config.   │
│   110 │                                                                                          │
│   111 │   train_time = time.time()                                                               │
│ ❱ 112 │   train(fabric, model, optimizer, train_data, val_data, checkpoint_dir, out_dir, max_s   │
│   113 │   fabric.print(f"Training time: {(time.time()-train_time):.2f}s")                        │
│   114 │                                                                                          │
│   115 │   # Save the final LoRA checkpoint at the end of training                                │
│                                                                                                  │
│ finetune/lora.py:138 in train                                                                    │
│                                                                                                  │
│   135 │   estimated_flops = estimate_flops(model) * micro_batch_size                             │
│   136 │   fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")    │
│   137 │   if not isinstance(fabric.strategy, DeepSpeedStrategy):  # unsupported                  │
│ ❱ 138 │   │   measured_flops = measure_flops(                                                    │
│   139 │   │   │   model, torch.randint(0, 1, (micro_batch_size, model.config.block_size), devi   │
│   140 │   │   )                                                                                  │
│   141 │   │   fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}"   │
│                                                                                                  │
│ code/lit-parrot/lit_parrot/speed_monitor.py:269 in measure_flops                                 │
│                                                                                                  │
│   266 │   flop_counter = FlopCounterMode(model, display=False)                                   │
│   267 │   ctx = nullcontext() if model.training else torch.no_grad()                             │
│   268 │   with ctx, flop_counter:                                                                │
│ ❱ 269 │   │   y = model(x)                                                                       │
│   270 │   │   if model.training:                                                                 │
│   271 │   │   │   y.sum().backward()                                                             │
│   272 │   return flop_counter.get_total_flops()                                                  │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1502 in _wrapped_call_impl     │
│                                                                                                  │
│   1499 │   │   if self._compiled_call_impl is not None:                                          │
│   1500 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1501 │   │   else:                                                                             │
│ ❱ 1502 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1503 │                                                                                         │
│   1504 │   def _call_impl(self, *args, **kwargs):                                                │
│   1505 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1548 in _call_impl             │
│                                                                                                  │
│   1545 │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)   │
│   1546 │   │   │   args = bw_hook.setup_input_hook(args)                                         │
│   1547 │   │                                                                                     │
│ ❱ 1548 │   │   result = forward_call(*args, **kwargs)                                            │
│   1549 │   │   if _global_forward_hooks or self._forward_hooks:                                  │
│   1550 │   │   │   for hook_id, hook in (                                                        │
│   1551 │   │   │   │   *_global_forward_hooks.items(),                                           │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/lightning/fabric/wrappers.py:116 in forward               │
│                                                                                                  │
│   113 │   │   args, kwargs = self._precision.convert_input((args, kwargs))                       │
│   114 │   │                                                                                      │
│   115 │   │   with self._precision.forward_context():                                            │
│ ❱ 116 │   │   │   output = self._forward_module(*args, **kwargs)                                 │
│   117 │   │                                                                                      │
│   118 │   │   output = self._precision.convert_output(output)                                    │
│   119 │   │   return output                                                                      │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1502 in _wrapped_call_impl     │
│                                                                                                  │
│   1499 │   │   if self._compiled_call_impl is not None:                                          │
│   1500 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1501 │   │   else:                                                                             │
│ ❱ 1502 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1503 │                                                                                         │
│   1504 │   def _call_impl(self, *args, **kwargs):                                                │
│   1505 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1548 in _call_impl             │
│                                                                                                  │
│   1545 │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)   │
│   1546 │   │   │   args = bw_hook.setup_input_hook(args)                                         │
│   1547 │   │                                                                                     │
│ ❱ 1548 │   │   result = forward_call(*args, **kwargs)                                            │
│   1549 │   │   if _global_forward_hooks or self._forward_hooks:                                  │
│   1550 │   │   │   for hook_id, hook in (                                                        │
│   1551 │   │   │   │   *_global_forward_hooks.items(),                                           │
│                                                                                                  │
│ code/lit-parrot/lit_parrot/model.py:92 in forward                                                │
│                                                                                                  │
│    89 │   │                                                                                      │
│    90 │   │   if input_pos is None:  # proxy for use_cache=False                                 │
│    91 │   │   │   for block in self.transformer.h:                                               │
│ ❱  92 │   │   │   │   x, *_ = block(x, (cos, sin), mask, max_seq_length)                         │
│    93 │   │   else:                                                                              │
│    94 │   │   │   self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, c   │
│    95 │   │   │   for i, block in enumerate(self.transformer.h):                                 │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1502 in _wrapped_call_impl     │
│                                                                                                  │
│   1499 │   │   if self._compiled_call_impl is not None:                                          │
│   1500 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1501 │   │   else:                                                                             │
│ ❱ 1502 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1503 │                                                                                         │
│   1504 │   def _call_impl(self, *args, **kwargs):                                                │
│   1505 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1548 in _call_impl             │
│                                                                                                  │
│   1545 │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)   │
│   1546 │   │   │   args = bw_hook.setup_input_hook(args)                                         │
│   1547 │   │                                                                                     │
│ ❱ 1548 │   │   result = forward_call(*args, **kwargs)                                            │
│   1549 │   │   if _global_forward_hooks or self._forward_hooks:                                  │
│   1550 │   │   │   for hook_id, hook in (                                                        │
│   1551 │   │   │   │   *_global_forward_hooks.items(),                                           │
│                                                                                                  │
│ code/lit-parrot/lit_parrot/model.py:158 in forward                                               │
│                                                                                                  │
│   155 │   │   kv_cache: Optional[KVCache] = None,                                                │
│   156 │   ) -> Tuple[torch.Tensor, Optional[KVCache]]:                                           │
│   157 │   │   n_1 = self.norm_1(x)                                                               │
│ ❱ 158 │   │   h, new_kv_cache = self.attn(n_1, rope, mask, max_seq_length, input_pos, kv_cache   │
│   159 │   │   if self.config.parallel_residual:                                                  │
│   160 │   │   │   n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)             │
│   161 │   │   │   x = x + h + self.mlp(n_2)                                                      │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1502 in _wrapped_call_impl     │
│                                                                                                  │
│   1499 │   │   if self._compiled_call_impl is not None:                                          │
│   1500 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1501 │   │   else:                                                                             │
│ ❱ 1502 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1503 │                                                                                         │
│   1504 │   def _call_impl(self, *args, **kwargs):                                                │
│   1505 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1548 in _call_impl             │
│                                                                                                  │
│   1545 │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)   │
│   1546 │   │   │   args = bw_hook.setup_input_hook(args)                                         │
│   1547 │   │                                                                                     │
│ ❱ 1548 │   │   result = forward_call(*args, **kwargs)                                            │
│   1549 │   │   if _global_forward_hooks or self._forward_hooks:                                  │
│   1550 │   │   │   for hook_id, hook in (                                                        │
│   1551 │   │   │   │   *_global_forward_hooks.items(),                                           │
│                                                                                                  │
│ code/lit-parrot/lit_parrot/model.py:233 in forward                                               │
│                                                                                                  │
│   230 │   │   │   kv_cache = k, v                                                                │
│   231 │   │                                                                                      │
│   232 │   │   # efficient attention using Flash Attention CUDA kernels                           │
│ ❱ 233 │   │   y = F.scaled_dot_product_attention(                                                │
│   234 │   │   │   q, k, v, attn_mask=mask, dropout_p=0.0, scale=1.0 / math.sqrt(self.config.he   │
│   235 │   │   )                                                                                  │
│   236                                                                                            │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/utils/flop_counter.py:395 in __torch_dispatch__     │
│                                                                                                  │
│   392 │                                                                                          │
│   393 │   def __torch_dispatch__(self, func, types, args=(), kwargs=None):                       │
│   394 │   │   kwargs = kwargs if kwargs else {}                                                  │
│ ❱ 395 │   │   out = func(*args, **kwargs)                                                        │
│   396 │   │   func_packet = func._overloadpacket                                                 │
│   397 │   │   if func_packet in self.flop_mapping:                                               │
│   398 │   │   │   flop_count_func = self.flop_mapping[func_packet]                               │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_ops.py:401 in __call__                             │
│                                                                                                  │
│   398 │   │   )                                                                                  │
│   399 │                                                                                          │
│   400 │   def __call__(self, *args, **kwargs):                                                   │
│ ❱ 401 │   │   return self._op(*args, **kwargs or {})                                             │
│   402 │                                                                                          │
│   403 │   def __hash__(self):                                                                    │
│   404 │   │   return hash(self._op)                                                              │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
OutOfMemoryError: CUDA out of memory. Tried to allocate 2.22 GiB. GPU 0 has a total capacty of 79.15 GiB of which 228.38 MiB is free. Including non-PyTorch memory, this process
has 78.93 GiB memory in use. Of the allocated memory 76.28 GiB is allocated by PyTorch, and 2.14 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory 
is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
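The failed 2.22 GiB allocation happens inside measure_flops. That function runs a full forward and backward pass on a dummy input of shape (micro_batch_size, block_size) = (4, 2048). The arithmetic below shows that 2.22 GiB is exactly the size of one dense batch × heads × T × T tensor in bf16. This suggests the attention scores are being materialized during the probe rather than computed by a fused kernel. That interpretation is ours and is not confirmed in this thread. The 2-bytes-per-element figure is also an assumption, based on the bf16 AMP shown in the log.

```python
# Back-of-the-envelope size of one dense attention-score tensor, B x H x T x T.
# Values from the log above: micro_batch_size=4, n_head=71, block_size=2048.
# Assumption: elements are bf16 (2 bytes each).

GiB = 1024 ** 3

def attention_scores_bytes(batch: int, n_head: int, seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes for a dense batch x heads x seq_len x seq_len tensor."""
    return batch * n_head * seq_len * seq_len * bytes_per_elem

per_layer = attention_scores_bytes(batch=4, n_head=71, seq_len=2048)
print(f"one layer: {per_layer / GiB:.2f} GiB")       # one layer: 2.22 GiB

# Hypothetical worst case: each of the 32 layers keeps one such tensor alive for
# the backward pass that measure_flops runs in training mode.
print(f"32 layers: {32 * per_layer / GiB:.1f} GiB")  # 32 layers: 71.0 GiB
```

If that reading is right, the FLOP probe alone needs roughly as much memory as the 76.28 GiB PyTorch reports as allocated. That would explain why skipping measure_flops, or using the DeepSpeed branch that skips it, avoids this particular OOM.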

It is also using only 1 of the 8 GPUs I have. Please help me resolve these issues as soon as possible. Thanks!

k21993 · Jun 16 '23 08:06

I have the same issue on a 48 GB GPU; following to see what the solution is.

griff4692 · Jun 16 '23 09:06

tl;dr: you can force the strategy to be DeepSpeed and it should run. The default DeepSpeed config is ZeRO stage 2, which is effective even on a single GPU.
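Forcing the strategy amounts to passing a DeepSpeedStrategy to Lightning Fabric. Below is a minimal sketch, assuming Lightning 2.x with deepspeed installed. The names build_ds_config and make_fabric are hypothetical helpers, not part of lit-parrot. The config keys match the ds_config posted later in this thread, and stage 2 is the default mentioned above.

```python
# Sketch: hand Lightning Fabric an explicit DeepSpeed ZeRO config.
# build_ds_config / make_fabric are hypothetical helper names.

def build_ds_config(micro_batch_size: int, gradient_accumulation_iters: int, stage: int = 2) -> dict:
    """DeepSpeed config dict with the same keys as the ds_config used in this thread."""
    return {
        "train_micro_batch_size_per_gpu": micro_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_iters,
        "zero_optimization": {"stage": stage},
    }

def make_fabric(ds_config: dict, devices: int = 1, precision: str = "bf16-mixed"):
    # Imported lazily so the config helper works without lightning/deepspeed installed.
    import lightning as L
    from lightning.fabric.strategies import DeepSpeedStrategy

    return L.Fabric(devices=devices, strategy=DeepSpeedStrategy(config=ds_config), precision=precision)
```

Usage would look like fabric = make_fabric(build_ds_config(micro_batch_size=4, gradient_accumulation_iters=1)), followed by fabric.launch(...) as in finetune/lora.py.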

griff4692 · Jun 16 '23 09:06

@griff4692 Thanks for the pointer! I hardcoded the strategy as strategy = DeepSpeedStrategy(config=ds_config) in the script and it runs. However, there are three issues that I see:

  1. Peak GPU memory is ~30 GB.
  2. It only runs on 1 GPU even if multiple GPUs are available.
  3. Why does it run out of memory on an 80 GB GPU when DeepSpeed is not enabled? The model is ~30 GB, right?

Do you know why this is the case?
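The "~30 GB" figure matches the weights kept in fp32. With bf16 mixed precision ("Using bfloat16 Automatic Mixed Precision" in the log), the parameters stay in fp32 and only the compute runs in bf16. The rough count below is our own arithmetic from the config printed in the log. It assumes a 4× MLP expansion, untied input and output embeddings, and LayerNorms with weight and bias, none of which is printed in the log. Activations for the forward and backward pass come on top of this.

```python
# Rough Falcon-7B parameter count from the config in the log.
# Assumptions: 4x MLP expansion, untied wte/lm_head, LayerNorm with weight + bias,
# no bias in Linear layers (bias=False in the config).

def falcon_param_count(n_layer=32, n_embd=4544, n_head=71, n_query_groups=1,
                       padded_vocab_size=65024, mlp_ratio=4) -> int:
    head_size = n_embd // n_head                               # 64
    qkv = n_embd * (n_head + 2 * n_query_groups) * head_size   # fused q, k, v projection
    attn_proj = n_embd * n_embd
    mlp = 2 * n_embd * (mlp_ratio * n_embd)                    # up- and down-projection
    norm = 2 * n_embd                                          # one shared LayerNorm per block
    per_block = qkv + attn_proj + mlp + norm
    embeddings = 2 * padded_vocab_size * n_embd                # wte + lm_head
    return n_layer * per_block + embeddings + 2 * n_embd       # + final LayerNorm

n = falcon_param_count()
print(f"{n:,} parameters")                    # 7,217,189,760 parameters
print(f"fp32 weights: {n * 4 / 1e9:.1f} GB")  # fp32 weights: 28.9 GB
print(f"bf16 weights: {n * 2 / 1e9:.1f} GB")  # bf16 weights: 14.4 GB
```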

k21993 · Jun 16 '23 09:06

The devices constant in lora.py is set to 1. You could try changing it and see what happens.

griff4692 · Jun 16 '23 09:06

Aah I didn't realize it was hardcoded there, thanks!

k21993 · Jun 16 '23 10:06

@awaelchli @lantiga maybe show a warning if more devices are available?
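The suggested warning could be as small as the sketch below. resolve_devices is a hypothetical helper. In the script, the available count could come from torch.cuda.device_count().

```python
import warnings

def resolve_devices(requested: int, available: int) -> int:
    """Return the device count to use, warning when GPUs would sit idle (hypothetical helper)."""
    if requested > available:
        raise ValueError(f"requested {requested} devices but only {available} are available")
    if requested < available:
        warnings.warn(f"Using {requested} of {available} available GPUs; increase `devices` to use more.")
    return requested
```

For example, devices = resolve_devices(requested=1, available=torch.cuda.device_count()) would emit a warning on an 8-GPU machine.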

aniketmaurya · Jun 16 '23 10:06

@k21993 LoRA with Falcon 7B should work on a single GPU with ~16 GB. If not, you can change micro_batch_size = 4 to micro_batch_size = 1 (it only affects the runtime) or try reducing the LoRA rank.
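Lowering micro_batch_size while keeping batch_size fixed only changes how many micro-batches are accumulated per optimizer step (batch_size / micro_batch_size). It does not change the effective batch, which is why it affects runtime but not the result. The toy check below is pure Python and uses a one-parameter least-squares model as a stand-in for the real network. It shows that averaging micro-batch gradients reproduces the full-batch gradient.

```python
from typing import Sequence

# Toy model: y_hat = w * x, loss = mean squared error over the batch.

def mse_grad(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """d/dw of mean((w*x - y)^2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def accumulated_grad(w: float, xs: Sequence[float], ys: Sequence[float], micro_batch_size: int) -> float:
    """Gradient accumulation: each micro-batch gradient is scaled by 1/steps and summed."""
    assert len(xs) % micro_batch_size == 0, "batch_size must be divisible by micro_batch_size"
    steps = len(xs) // micro_batch_size
    total = 0.0
    for i in range(0, len(xs), micro_batch_size):
        total += mse_grad(w, xs[i:i + micro_batch_size], ys[i:i + micro_batch_size]) / steps
    return total

xs = [float(i) for i in range(64)]  # batch_size = 64
ys = [3.0 * x + 1.0 for x in xs]
w = 0.5
full = mse_grad(w, xs, ys)
```

Activation memory scales with micro_batch_size, while the optimizer still sees the same averaged gradient. The cost is more forward and backward passes per optimizer step.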

rasbt · Jun 16 '23 18:06

What else did you change? Even when I change micro_batch_size = 4 to micro_batch_size = 1, LoRA with Falcon 7B does not fit on a single 24 GB GPU.

xy990 · Jun 16 '23 20:06

That's weird, here are the complete settings I used https://github.com/rasbt/LLM-finetuning-scripts/blob/main/lit-benchmarks/falcon-7b/finetune/lora.py

via

python finetune/lora.py  --checkpoint_dir checkpoints/tiiuae/falcon-7b/

The peak memory use was 16.97 GB according to

print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
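This line reports max_memory_reserved, which is what PyTorch's caching allocator held at peak. Tools such as nvidia-smi also count the CUDA context and any other processes on the card, so the two kinds of numbers are not directly comparable. The sketch below reports both allocator counters around a region of code. peak_memory_report is a hypothetical helper, and it returns None when torch or a CUDA device is unavailable.

```python
import sys
from typing import Callable, Optional, Tuple

def peak_memory_report(fn: Callable[[], object], stream=sys.stderr) -> Optional[Tuple[float, float]]:
    """Run fn() and report peak allocated/reserved CUDA memory in GB (None without CUDA)."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    torch.cuda.reset_peak_memory_stats()
    fn()
    allocated = torch.cuda.max_memory_allocated() / 1e9  # live tensors only
    reserved = torch.cuda.max_memory_reserved() / 1e9    # tensors + allocator cache
    print(f"peak allocated: {allocated:.2f} GB, peak reserved: {reserved:.2f} GB", file=stream)
    return allocated, reserved
```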

rasbt · Jun 16 '23 20:06

@rasbt @aniketmaurya

  1. I tried running with 8 A100 (80GB) GPUs with the settings:
batch_size = 64
micro_batch_size = 4
lora_r = 8
devices=8

It runs for ~15k iterations and eventually fails with:

OutOfMemoryError: CUDA out of memory. Tried to allocate 632.00 MiB. GPU 0 has a total capacty of 79.15 GiB of which 452.44 MiB is free. Process 147633 has 32.01 GiB memory in 
use. Including non-PyTorch memory, this process has 46.70 GiB memory in use. Of the allocated memory 42.63 GiB is allocated by PyTorch, and 1.05 GiB is reserved by PyTorch but 
unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and 
PYTORCH_CUDA_ALLOC_CONF
  2. If I set devices=1 and run
python finetune/lora.py --checkpoint_dir checkpoints/tiiuae/falcon-7b/

It fails right at startup.

  3. If I set devices=1 and hardcode strategy=deepspeed, it still uses a lot of memory: [screenshot from 2023-06-16]

k21993 · Jun 16 '23 20:06

Regarding the 1-GPU setting above, you mention micro_batch_size = 4. If you set micro_batch_size = 1, it should theoretically fit: 67,775 MiB / 4 ≈ 16,944 MiB.

rasbt · Jun 16 '23 20:06

Regarding multi-GPU training, it is currently set to DeepSpeed stage 2, which is not very memory-efficient (it optimizes for speed). DeepSpeed stage 3 is more memory-efficient, but there is currently a bug with stage 3 and multi-GPU training (#161). The 1-GPU case should definitely work, though.
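ZeRO stage 2 shards optimizer states and gradients across GPUs, and stage 3 additionally shards the parameters. With LoRA, the trainable state is small (3,506,176 parameters in the log above), so most of the memory goes to frozen weights and activations. Below is a sketch of a stage 3 config in the same dict form used in this thread. zero3_config is a hypothetical helper. The CPU-offload entries are optional DeepSpeed settings added here for illustration, and per the comment above, stage 3 with multiple GPUs was affected by a bug at the time.

```python
# Sketch of a DeepSpeed ZeRO-3 config dict (same shape as the ds_config in this thread).
# offload_optimizer / offload_param are optional DeepSpeed keys added for illustration;
# they trade GPU memory for CPU memory and slower steps.

def zero3_config(micro_batch_size: int, gradient_accumulation_iters: int, offload: bool = False) -> dict:
    zero = {"stage": 3}
    if offload:
        zero["offload_optimizer"] = {"device": "cpu"}
        zero["offload_param"] = {"device": "cpu"}
    return {
        "train_micro_batch_size_per_gpu": micro_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_iters,
        "zero_optimization": zero,
    }

ds_config = zero3_config(micro_batch_size=1, gradient_accumulation_iters=64)
```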

rasbt · Jun 16 '23 20:06

I have a fix in #171 that will reduce the memory requirements for fine-tuning and training

carmocca · Jun 19 '23 19:06

@carmocca Based on the PR, this seems to be a fix for the adapter method but not for LoRA. Can you outline the basic steps to make these changes for LoRA?

k21993 · Jun 19 '23 20:06

@k21993 the fix above also applies to LoRA.

carmocca · Jun 19 '23 20:06

Hey @carmocca, I tried your fix: the memory requirement seems to be the same, while the iteration time decreases from ~10 s to ~7 s.

Here's my config:

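max_seq_len = 2048
micro_batch_size = 2
batch_size = 64
# number of micro-batches accumulated per optimizer step (64 // 2 = 32)
gradient_accumulation_iters = batch_size // micro_batch_size
lora_r = 64
lora_alpha = 128
devices = 1
ds_config = {
    "train_micro_batch_size_per_gpu": micro_batch_size,
    "gradient_accumulation_steps": gradient_accumulation_iters,
    "zero_optimization": {"stage": 2},
}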

The memory occupied is the same (~73 GB): [screenshot from 2023-06-19]

k21993 · Jun 19 '23 21:06

I did not do a deep analysis, but here is what helped in my case (memory consumption is now constant at ~16 GB with a micro_batch_size of 1). First, I removed the SpeedMonitor because, for some reason, it needed a lot of memory. Second, I noticed that more and more memory was consumed over the course of training, so I now call torch.cuda.empty_cache() every n iterations, and memory consumption now stays constant over time too.
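A minimal way to call something every n iterations is shown below. every_n_iters is a hypothetical helper, and the value of n is up to you. Note that torch.cuda.empty_cache() only releases cached blocks that no tensor is using. It lowers reserved memory (and what nvidia-smi shows) and can help with fragmentation, but it does not free live tensors.

```python
from typing import Callable

def every_n_iters(fn: Callable[[], None], n: int) -> Callable[[int], bool]:
    """Return a hook that calls fn() on every n-th iteration; the hook returns True when it fired."""
    if n <= 0:
        raise ValueError("n must be positive")

    def hook(iter_num: int) -> bool:
        if (iter_num + 1) % n == 0:
            fn()
            return True
        return False

    return hook
```

In the training loop, this could be set up as maybe_empty_cache = every_n_iters(torch.cuda.empty_cache, n=100), with maybe_empty_cache(iter_num) called once per iteration.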

peerdavid (Jun 20 '23 06:06)

I'm currently following the instructions for fine-tuning Falcon 7B with adapter V2 and ran into similar issues. I deleted the following lines in train:

    if not isinstance(fabric.strategy, DeepSpeedStrategy):  # unsupported
        measured_flops = measure_flops(
            model, torch.randint(0, 1, (micro_batch_size, model.config.block_size), device=fabric.device)
        )
        fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
    else:
        measured_flops = None

and simply replaced them with measured_flops = None. That seemed to fix everything for me on an NVIDIA RTX A6000 (48GB). It might also explain why setting the strategy to DeepSpeed seems to fix things, since the FLOP measurement is skipped under DeepSpeed.

fozziethebeat (Jun 20 '23 07:06)

I lied, I still ran into an OOM issue about 80 steps in after fixing a NaN problem (solved by using --precision bf16-mixed).

I've tried using adapter_v2.py, adapter.py and lora.py. All quickly OOM on my 48GB GPU (within 80 steps). Not sure what's causing this yet.

EDIT: With some tweaking, changing these settings got me a few more steps (up to about 600) before OOM:

    batch_size = 64 / devices
    micro_batch_size = 1

Broadly, it'd be nice if the scripts referenced in the guide worked as reported. Even with all these tweaks, the minimum VRAM usage I'm seeing when training starts is ~30 GB, not 16 GB.

fozziethebeat (Jun 20 '23 08:06)

@fozziethebeat What are your micro_batch_size and max_seq_len? Since the sequence length is local to each batch, training may reach a batch later on that is long enough to cause an OOM.
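A minimal sketch of why the padded length is local to each batch, assuming samples are lists of token IDs that are right-padded with a placeholder `pad_id` of 0. `pad_batch` and `pad_to` are hypothetical names; the fine-tuning scripts' own padding logic may differ.

```python
from typing import List, Optional


def pad_batch(samples: List[List[int]], pad_id: int = 0,
              pad_to: Optional[int] = None) -> List[List[int]]:
    """Right-pad token-ID lists so every row in the batch has the same length.

    By default rows are padded to the longest sample in *this* batch, so the
    tensor shape (and activation memory) changes from batch to batch.
    Passing pad_to forces a fixed length, e.g. the dataset maximum.
    """
    if not samples:
        return []
    longest = max(len(s) for s in samples)
    target = longest if pad_to is None else pad_to
    if target < longest:
        raise ValueError("pad_to is shorter than the longest sample")
    return [s + [pad_id] * (target - len(s)) for s in samples]


print(pad_batch([[1, 2], [3, 4, 5, 6]]))  # [[1, 2, 0, 0], [3, 4, 5, 6]]
```

Padding every batch to a fixed `pad_to` equal to the dataset maximum is one way to see the worst-case memory from the first step.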

k21993 (Jun 20 '23 08:06)

I'm using the default max_seq_length as generated by running

python scripts/prepare_alpaca.py --checkpoint_dir checkpoints/tiiuae/falcon-7b/

Looking at the config directly, it looks like it's 1079. That doesn't seem too extreme to me, and it is lower than the block size (2048) reported for falcon-7b.

fozziethebeat (Jun 20 '23 08:06)

So I'm having the same issue: memory consumption is generally constant, but after about 50 steps an OOM is raised. I logged the sequence length, and in my case it's definitely caused by the sequence length (thanks for the hint, @k21993): the OOM happens exactly after the 1079-token sample occurs. All other samples up to this point are <= 650 tokens, and the OOM is raised right after this batch, which is fine IMO.

Update: when I restrict the token length, it trains without OOMs :) Still, it's worth mentioning that I use a 3090 GPU, so I have only 24 GB of VRAM.
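The comment does not say how the length was restricted. A minimal sketch of two common options, assuming each sample is a list of token IDs: truncating long samples, or skipping them entirely. `restrict_length` and its `mode` argument are hypothetical, not part of the fine-tuning scripts.

```python
from typing import List


def restrict_length(samples: List[List[int]], max_len: int,
                    mode: str = "truncate") -> List[List[int]]:
    """Cap sequence length to bound peak activation memory.

    mode="truncate" cuts every sample to at most max_len tokens;
    mode="skip" drops samples longer than max_len instead.
    """
    if max_len <= 0:
        raise ValueError("max_len must be positive")
    if mode == "truncate":
        return [s[:max_len] for s in samples]
    if mode == "skip":
        return [s for s in samples if len(s) <= max_len]
    raise ValueError(f"unknown mode: {mode!r}")
```

Truncating keeps every example but can cut off the end of a long response; skipping keeps the remaining examples intact but discards data.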

peerdavid (Jun 20 '23 09:06)

I merged #173, which should fix the FLOPs counter issue.

I'll try replicating the sequence length issues you are seeing now

carmocca (Jun 20 '23 15:06)

Noticing the same thing on my end. Specifically, iteration 251 gets an input of around 600 tokens and crashes on my 3090. I modified the script to skip any inputs above 600 tokens, and it trains a little longer but crashes later on a ~500-token input. The memory usage appears to creep up slowly over a few minutes of training; maybe something is not being released correctly.

cipher982 (Jun 20 '23 16:06)

Hey all. Using current main, here's what I'm calling:

python finetune/adapter.py --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true

With micro_batch_size=1, I get a constant ~16 GB of usage. It may seem to creep up slowly, but that is just the CUDA caching allocator holding on to more memory than it needs. As https://github.com/Lightning-AI/lit-parrot/issues/159#issuecomment-1598193614 mentioned, calling empty_cache() will keep it down, but beware: it slows training down a lot, so don't call it often if you need it.

In terms of model requirements, here's what you should expect:

Number of trainable parameters: 1365330
Number of non trainable parameters: 7217189760
Sum: 7218555090

Model weights fp32: 7218555090 * 4 / 1e9 = 28.87 GB
AdamW fp32: 2 * 4 * 1365330 / 1e9 = 0.01 GB

Which matches the observed 29.02 GB returned by torch.cuda.memory_reserved() and --precision bf16-mixed. Using 16-true or bf16-true, the memory is halved.
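The arithmetic above can be reproduced with a small helper. This is a sketch under the same simplifications as the estimate: it counts only the weights and AdamW's two fp32 states (exp_avg, exp_avg_sq) per trainable parameter, and ignores gradients, activations, and allocator overhead. `model_memory_gb` is a hypothetical name.

```python
def model_memory_gb(trainable: int, frozen: int, bytes_per_param: int = 4):
    """Rough weight and AdamW-state memory in GB (1 GB = 1e9 bytes).

    The AdamW term always uses 4 bytes per state for simplicity.
    """
    weights_gb = (trainable + frozen) * bytes_per_param / 1e9
    adamw_gb = 2 * 4 * trainable / 1e9
    return weights_gb, adamw_gb


weights_fp32, adamw = model_memory_gb(1_365_330, 7_217_189_760)
weights_bf16, _ = model_memory_gb(1_365_330, 7_217_189_760, bytes_per_param=2)
print(f"{weights_fp32:.2f} {adamw:.2f} {weights_bf16:.2f}")  # 28.87 0.01 14.44
```

Passing `bytes_per_param=2` mirrors the halving seen with `16-true` / `bf16-true`.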

All is working as expected so far. Now, if I force all inputs to be of the maximum sequence length for the alpaca dataset (1079), the max memory reserved does jump to 24.5 GB.

I'll open a PR to try to alleviate that jump, which is caused by an autograd issue in the backward pass. However, you might still need to tweak max_seq_length depending on your available GPU memory.

carmocca (Jun 21 '23 00:06)

Thank you! This so far seems to be the needed fix.

Trying it now on main, and so far it's working really smoothly. Using the exact command you tried, I'm seeing ~29 GB of VRAM usage and no NaNs in my loss function. So far I'm at step 600 with no issues.

I do see small memory increases but it's much less dramatic than before.

EDIT: posted too soon. Hit an OOM after iter 1599 step 100

fozziethebeat (Jun 21 '23 04:06)

I merged #178, which should slightly decrease memory usage.

I'll also be adding #182, which includes a change so that the longest Alpaca sequence is loaded first, so that any OOM happens at the beginning.
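The change in #182 is not reproduced here. As a hypothetical sketch of the idea for an in-memory list of tokenized samples, the helper below moves the longest sample to the front, so a too-large max_seq_length fails on the first step instead of hundreds of iterations in.

```python
from typing import List


def longest_first(samples: List[List[int]]) -> List[List[int]]:
    """Move the single longest sample to the front, keeping the rest in order."""
    if not samples:
        return []
    # max() returns the first index among ties, so ordering stays deterministic.
    idx = max(range(len(samples)), key=lambda i: len(samples[i]))
    return [samples[idx]] + samples[:idx] + samples[idx + 1:]


print(longest_first([[1], [1, 2, 3], [1, 2]]))  # [[1, 2, 3], [1], [1, 2]]
```

This only front-loads the worst-case sequence length; if usage already sits near the GPU's limit, small fluctuations later in training can still trigger an OOM.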

For the DeepSpeed issues, I'll be replacing it with FSDP in #118.

Closing this issue. Feel free to open new ones for any new issues. Thank you all

carmocca (Jun 21 '23 16:06)

Should this stay under 48 GB of VRAM usage when we run the command below at HEAD?

python finetune/adapter.py \
    --data_dir data/alpaca  \
    --checkpoint_dir checkpoints/tiiuae/falcon-7b \
    --out_dir out/adapter/alpaca --precision bf16-true

I've just tried this out, and I still see an OOM at iter 1599 step 100.

fozziethebeat (Jun 22 '23 00:06)

Trying now on an A6000, and it looks like I'm basically maxed out at ~48 GB right from the start, so it's possible that usage moves up and down a bit from there and hits an OOM.

cipher982 (Jun 22 '23 00:06)

That's exactly what I noticed. It started at 100% VRAM usage and then something at iter 1599 step 100 kills it with the tiniest increase of memory.

fozziethebeat (Jun 22 '23 00:06)