CUDA Out of Memory for the Falcon 7B model on A100 80GB GPU
I am trying to reproduce the Falcon-7B LoRA fine-tuning on the Alpaca dataset. I followed the steps to convert the checkpoints to the Lightning format and downloaded and tokenized the Alpaca dataset as instructed. When I run:
python finetune/lora.py --checkpoint_dir checkpoints/tiiuae/falcon-7b/
I get the following traceback:
{'eval_interval': 100, 'save_interval': 100, 'eval_iters': 100, 'log_interval': 1, 'devices': 1, 'learning_rate': 0.0003, 'batch_size': 4, 'micro_batch_size': 4, 'gradient_accumulation_iters': 1, 'max_iters': 50000, 'weight_decay': 0.01, 'lora_r': 8, 'lora_alpha': 16, 'lora_dropout': 0.05, 'warmup_iters': 100}
Using bfloat16 Automatic Mixed Precision (AMP)
Global seed set to 1337
Loading model 'checkpoints/tiiuae/falcon-7b/lit_model.pth' with {'block_size': 2048, 'vocab_size': 50254, 'padding_multiple': 512, 'padded_vocab_size': 65024, 'n_layer': 32, 'n_head': 71, 'n_embd': 4544, 'rotary_percentage': 1.0, 'parallel_residual': True, 'bias': False, 'n_query_groups': 1, 'shared_attention_norm': True}
Number of trainable parameters: 3506176
Validating ...
Recommend a movie for me to watch during the weekend and explain the reason.
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Recommend a movie for me to watch during the weekend and explain the reason.
### Response:
[The Martian](https://www.imdb.com/title/tt1878107) is a really good movie to watch during the weekend. It is set on Mars and is based on the book by Andy Weir. Weir is a retired engineer who won an international writing contest for promising science fiction writers. The movie is funny and at the same time it is thoughtful and inspiring. I will recommend this movie to you because of the following reasons.
1. The movie
Estimated TFLOPs: 384.19
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ finetune/lora.py: │
│ 288 in <module> │
│ │
│ 285 │ │ message="Remove `.no_backward_sync()` from your code", │
│ 286 │ ) │
│ 287 │ │
│ ❱ 288 │ CLI(setup) │
│ 289 │
│ │
│ /opt/conda/lib/python3.8/site-packages/jsonargparse/cli.py:85 in CLI │
│ │
│ 82 │ │ │ return parser │
│ 83 │ │ cfg = parser.parse_args(args) │
│ 84 │ │ cfg_init = parser.instantiate_classes(cfg) │
│ ❱ 85 │ │ return _run_component(component, cfg_init) │
│ 86 │ │
│ 87 │ subcommands = parser.add_subcommands(required=True) │
│ 88 │ comp_dict = {c.__name__: c for c in components} │
│ │
│ /opt/conda/lib/python3.8/site-packages/jsonargparse/cli.py:147 in _run_component │
│ │
│ 144 def _run_component(component, cfg): │
│ 145 │ cfg.pop("config", None) │
│ 146 │ if not inspect.isclass(component): │
│ ❱ 147 │ │ return component(**cfg) │
│ 148 │ subcommand = cfg.pop("subcommand") │
│ 149 │ if not subcommand: │
│ 150 │ │ return component(**cfg) │
│ │
│ finetune/lora.py: │
│ 75 in setup │
│ │
│ 72 │ print(hparams) │
│ 73 │ │
│ 74 │ fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision) │
│ ❱ 75 │ fabric.launch(main, data_dir, checkpoint_dir, out_dir, precision) │
│ 76 │
│ 77 │
│ 78 def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path, precisio │
│ │
│ /opt/conda/lib/python3.8/site-packages/lightning/fabric/fabric.py:759 in launch │
│ │
│ 756 │ │ │ │ f"To use the `{type(self.strategy).__name__}` strategy, `.launch()` need │
│ 757 │ │ │ │ " that contains the code to launch in processes." │
│ 758 │ │ │ ) │
│ ❱ 759 │ │ return self._wrap_and_launch(function, self, *args, **kwargs) │
│ 760 │ │
│ 761 │ def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None: │
│ 762 │ │ """Trigger the callback methods with the given name and arguments. │
│ │
│ /opt/conda/lib/python3.8/site-packages/lightning/fabric/fabric.py:841 in _wrap_and_launch │
│ │
│ 838 │ │ to_run = partial(self._wrap_with_setup, to_run) │
│ 839 │ │ if (launcher := self._strategy.launcher) is not None: │
│ 840 │ │ │ return launcher.launch(to_run, *args, **kwargs) │
│ ❱ 841 │ │ return to_run(*args, **kwargs) │
│ 842 │ │
│ 843 │ def _wrap_with_setup(self, to_run: Callable, *args: Any, **kwargs: Any) -> Any: │
│ 844 │ │ self._strategy.setup_environment() │
│ │
│ /opt/conda/lib/python3.8/site-packages/lightning/fabric/fabric.py:846 in _wrap_with_setup │
│ │
│ 843 │ def _wrap_with_setup(self, to_run: Callable, *args: Any, **kwargs: Any) -> Any: │
│ 844 │ │ self._strategy.setup_environment() │
│ 845 │ │ with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(Bat │
│ ❱ 846 │ │ │ return to_run(*args, **kwargs) │
│ 847 │ │
│ 848 │ def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn │
│ 849 │ │ initial_device = next(model.parameters(), torch.tensor(0)).device │
│ │
│ finetune/lora.py: │
│ 112 in main │
│ │
│ 109 │ │ max_seq_length = json.load(data_config_path).get("max_seq_length", model.config. │
│ 110 │ │
│ 111 │ train_time = time.time() │
│ ❱ 112 │ train(fabric, model, optimizer, train_data, val_data, checkpoint_dir, out_dir, max_s │
│ 113 │ fabric.print(f"Training time: {(time.time()-train_time):.2f}s") │
│ 114 │ │
│ 115 │ # Save the final LoRA checkpoint at the end of training │
│ │
│ finetune/lora.py: │
│ 138 in train │
│ │
│ 135 │ estimated_flops = estimate_flops(model) * micro_batch_size │
│ 136 │ fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") │
│ 137 │ if not isinstance(fabric.strategy, DeepSpeedStrategy): # unsupported │
│ ❱ 138 │ │ measured_flops = measure_flops( │
│ 139 │ │ │ model, torch.randint(0, 1, (micro_batch_size, model.config.block_size), devi │
│ 140 │ │ ) │
│ 141 │ │ fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}" │
│ │
│ code/lit-parrot/lit_parrot/speed_ │
│ monitor.py:269 in measure_flops │
│ │
│ 266 │ flop_counter = FlopCounterMode(model, display=False) │
│ 267 │ ctx = nullcontext() if model.training else torch.no_grad() │
│ 268 │ with ctx, flop_counter: │
│ ❱ 269 │ │ y = model(x) │
│ 270 │ │ if model.training: │
│ 271 │ │ │ y.sum().backward() │
│ 272 │ return flop_counter.get_total_flops() │
│ │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1502 in _wrapped_call_impl │
│ │
│ 1499 │ │ if self._compiled_call_impl is not None: │
│ 1500 │ │ │ return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] │
│ 1501 │ │ else: │
│ ❱ 1502 │ │ │ return self._call_impl(*args, **kwargs) │
│ 1503 │ │
│ 1504 │ def _call_impl(self, *args, **kwargs): │
│ 1505 │ │ forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo │
│ │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1548 in _call_impl │
│ │
│ 1545 │ │ │ bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks) │
│ 1546 │ │ │ args = bw_hook.setup_input_hook(args) │
│ 1547 │ │ │
│ ❱ 1548 │ │ result = forward_call(*args, **kwargs) │
│ 1549 │ │ if _global_forward_hooks or self._forward_hooks: │
│ 1550 │ │ │ for hook_id, hook in ( │
│ 1551 │ │ │ │ *_global_forward_hooks.items(), │
│ │
│ /opt/conda/lib/python3.8/site-packages/lightning/fabric/wrappers.py:116 in forward │
│ │
│ 113 │ │ args, kwargs = self._precision.convert_input((args, kwargs)) │
│ 114 │ │ │
│ 115 │ │ with self._precision.forward_context(): │
│ ❱ 116 │ │ │ output = self._forward_module(*args, **kwargs) │
│ 117 │ │ │
│ 118 │ │ output = self._precision.convert_output(output) │
│ 119 │ │ return output │
│ │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1502 in _wrapped_call_impl │
│ │
│ 1499 │ │ if self._compiled_call_impl is not None: │
│ 1500 │ │ │ return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] │
│ 1501 │ │ else: │
│ ❱ 1502 │ │ │ return self._call_impl(*args, **kwargs) │
│ 1503 │ │
│ 1504 │ def _call_impl(self, *args, **kwargs): │
│ 1505 │ │ forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo │
│ │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1548 in _call_impl │
│ │
│ 1545 │ │ │ bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks) │
│ 1546 │ │ │ args = bw_hook.setup_input_hook(args) │
│ 1547 │ │ │
│ ❱ 1548 │ │ result = forward_call(*args, **kwargs) │
│ 1549 │ │ if _global_forward_hooks or self._forward_hooks: │
│ 1550 │ │ │ for hook_id, hook in ( │
│ 1551 │ │ │ │ *_global_forward_hooks.items(), │
│ │
│ code/lit-parrot/lit_parrot/model. │
│ py:92 in forward │
│ │
│ 89 │ │ │
│ 90 │ │ if input_pos is None: # proxy for use_cache=False │
│ 91 │ │ │ for block in self.transformer.h: │
│ ❱ 92 │ │ │ │ x, *_ = block(x, (cos, sin), mask, max_seq_length) │
│ 93 │ │ else: │
│ 94 │ │ │ self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, c │
│ 95 │ │ │ for i, block in enumerate(self.transformer.h): │
│ │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1502 in _wrapped_call_impl │
│ │
│ 1499 │ │ if self._compiled_call_impl is not None: │
│ 1500 │ │ │ return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] │
│ 1501 │ │ else: │
│ ❱ 1502 │ │ │ return self._call_impl(*args, **kwargs) │
│ 1503 │ │
│ 1504 │ def _call_impl(self, *args, **kwargs): │
│ 1505 │ │ forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo │
│ │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1548 in _call_impl │
│ │
│ 1545 │ │ │ bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks) │
│ 1546 │ │ │ args = bw_hook.setup_input_hook(args) │
│ 1547 │ │ │
│ ❱ 1548 │ │ result = forward_call(*args, **kwargs) │
│ 1549 │ │ if _global_forward_hooks or self._forward_hooks: │
│ 1550 │ │ │ for hook_id, hook in ( │
│ 1551 │ │ │ │ *_global_forward_hooks.items(), │
│ │
│ code/lit-parrot/lit_parrot/model. │
│ py:158 in forward │
│ │
│ 155 │ │ kv_cache: Optional[KVCache] = None, │
│ 156 │ ) -> Tuple[torch.Tensor, Optional[KVCache]]: │
│ 157 │ │ n_1 = self.norm_1(x) │
│ ❱ 158 │ │ h, new_kv_cache = self.attn(n_1, rope, mask, max_seq_length, input_pos, kv_cache │
│ 159 │ │ if self.config.parallel_residual: │
│ 160 │ │ │ n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) │
│ 161 │ │ │ x = x + h + self.mlp(n_2) │
│ │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1502 in _wrapped_call_impl │
│ │
│ 1499 │ │ if self._compiled_call_impl is not None: │
│ 1500 │ │ │ return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] │
│ 1501 │ │ else: │
│ ❱ 1502 │ │ │ return self._call_impl(*args, **kwargs) │
│ 1503 │ │
│ 1504 │ def _call_impl(self, *args, **kwargs): │
│ 1505 │ │ forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo │
│ │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1548 in _call_impl │
│ │
│ 1545 │ │ │ bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks) │
│ 1546 │ │ │ args = bw_hook.setup_input_hook(args) │
│ 1547 │ │ │
│ ❱ 1548 │ │ result = forward_call(*args, **kwargs) │
│ 1549 │ │ if _global_forward_hooks or self._forward_hooks: │
│ 1550 │ │ │ for hook_id, hook in ( │
│ 1551 │ │ │ │ *_global_forward_hooks.items(), │
│ │
│ code/lit-parrot/lit_parrot/model. │
│ py:233 in forward │
│ │
│ 230 │ │ │ kv_cache = k, v │
│ 231 │ │ │
│ 232 │ │ # efficient attention using Flash Attention CUDA kernels │
│ ❱ 233 │ │ y = F.scaled_dot_product_attention( │
│ 234 │ │ │ q, k, v, attn_mask=mask, dropout_p=0.0, scale=1.0 / math.sqrt(self.config.he │
│ 235 │ │ ) │
│ 236 │
│ │
│ /opt/conda/lib/python3.8/site-packages/torch/utils/flop_counter.py:395 in __torch_dispatch__ │
│ │
│ 392 │ │
│ 393 │ def __torch_dispatch__(self, func, types, args=(), kwargs=None): │
│ 394 │ │ kwargs = kwargs if kwargs else {} │
│ ❱ 395 │ │ out = func(*args, **kwargs) │
│ 396 │ │ func_packet = func._overloadpacket │
│ 397 │ │ if func_packet in self.flop_mapping: │
│ 398 │ │ │ flop_count_func = self.flop_mapping[func_packet] │
│ │
│ /opt/conda/lib/python3.8/site-packages/torch/_ops.py:401 in __call__ │
│ │
│ 398 │ │ ) │
│ 399 │ │
│ 400 │ def __call__(self, *args, **kwargs): │
│ ❱ 401 │ │ return self._op(*args, **kwargs or {}) │
│ 402 │ │
│ 403 │ def __hash__(self): │
│ 404 │ │ return hash(self._op) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
OutOfMemoryError: CUDA out of memory. Tried to allocate 2.22 GiB. GPU 0 has a total capacty of 79.15 GiB of which 228.38 MiB is free. Including non-PyTorch memory, this process
has 78.93 GiB memory in use. Of the allocated memory 76.28 GiB is allocated by PyTorch, and 2.14 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory
is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
It also uses just 1 GPU and not the 8 that I have. Please help me resolve these issues. Thanks!
I have the same issue on a 48GB GPU - following along to see what the solution is.
tl;dr: you can force the strategy to be DeepSpeed and it should run. The default DeepSpeed config is stage 2, which is effective even on a single GPU.
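Roughly, the change looks like this in finetune/lora.py's setup() (a sketch, not the repo's exact code; micro_batch_size, gradient_accumulation_iters, devices, and precision are the script's existing hyperparameters):
import lightning as L
from lightning.fabric.strategies import DeepSpeedStrategy

# Sketch: hardcode DeepSpeed ZeRO stage 2 instead of the default strategy.
ds_config = {
    "train_micro_batch_size_per_gpu": micro_batch_size,
    "gradient_accumulation_steps": gradient_accumulation_iters,
    "zero_optimization": {"stage": 2},
}
strategy = DeepSpeedStrategy(config=ds_config)
fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision)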
@griff4692 Thanks for the pointer. I hardcoded the strategy as strategy = DeepSpeedStrategy(config=ds_config) here and it runs! However, there are a few issues that I see:
- Peak GPU memory is ~30 GB.
- It only runs on 1 GPU even if multiple GPUs are available.
- Why does it run out of memory on an 80GB GPU when DeepSpeed is not enabled? The model is only ~30 GB, right?
Do you know why this is the case?
The devices constant in lora.py is set to 1. You could try changing it and see what happens.
Aah I didn't realize it was hardcoded there, thanks!
@awaelchli @lantiga maybe show a warning if more devices are available?
@k21993 LoRA with Falcon 7B should work on a single GPU with ~16 GB. If not, you can change micro_batch_size = 4 to micro_batch_size = 1 (it only affects the runtime) or try reducing the LoRA rank.
What else did you change? Even when I change micro_batch_size = 4 to micro_batch_size = 1, LoRA with Falcon 7B does not work on a single GPU with 24 GB.
That's weird. Here are the complete settings I used: https://github.com/rasbt/LLM-finetuning-scripts/blob/main/lit-benchmarks/falcon-7b/finetune/lora.py
Running via
python finetune/lora.py --checkpoint_dir checkpoints/tiiuae/falcon-7b/
the peak memory use was 16.97 GB, according to
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
@rasbt @aniketmaurya
- I tried running with 8 A100 (80GB) GPUs with the settings:
batch_size = 64
micro_batch_size = 4
lora_r = 8
devices=8
It runs for ~15k iterations and eventually fails with:
OutOfMemoryError: CUDA out of memory. Tried to allocate 632.00 MiB. GPU 0 has a total capacty of 79.15 GiB of which 452.44 MiB is free. Process 147633 has 32.01 GiB memory in
use. Including non-PyTorch memory, this process has 46.70 GiB memory in use. Of the allocated memory 42.63 GiB is allocated by PyTorch, and 1.05 GiB is reserved by PyTorch but
unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and
PYTORCH_CUDA_ALLOC_CONF
- If I set devices=1 and run
python finetune/lora.py --checkpoint_dir checkpoints/tiiuae/falcon-7b/
it fails on startup itself.
- If I set devices=1 and hardcode strategy=deepspeed, it still uses a lot of memory.
Regarding the 1 GPU setting you have above, you mention micro_batch_size = 4. So if you set this to micro_batch_size = 1, then theoretically it should work: 67,775 MiB / 4 ≈ 16,943 MiB
Regarding multi-GPU training, it is currently set to DeepSpeed stage 2, which is not very memory-efficient (it optimizes for speed). If you set this to DeepSpeed stage 3, it is more memory-efficient, but there is currently a bug with stage 3 & multi-GPU (#161). The 1 GPU case should definitely work, though.
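For reference, switching to stage 3 is just a one-key change in the hardcoded ds_config shown above (sketch):
ds_config = {
    "train_micro_batch_size_per_gpu": micro_batch_size,
    "gradient_accumulation_steps": gradient_accumulation_iters,
    "zero_optimization": {"stage": 3},  # stage 2 -> 3: lower memory, slower; see #161 for the multi-GPU bug
}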
I have a fix in #171 that will reduce the memory requirements for fine-tuning and training
@carmocca Based on the PR, it seems like this is a fix for the adapter method but not LoRA. Can you outline the basic steps to make these changes for LoRA?
@k21993 the fix above also applies to lora
Hey @carmocca, I tried your fix and the memory requirement seems to be the same, while the iteration time decreases from ~10s to ~7s.
Here's my config:
max_seq_len = 2048
micro_batch_size = 2
batch_size = 64
lora_r = 64
lora_alpha = 128
devices = 1
ds_config = {
"train_micro_batch_size_per_gpu": micro_batch_size,
"gradient_accumulation_steps": gradient_accumulation_iters,
"zero_optimization": {"stage": 2},
}
The memory occupied is the same (~73 GB)
I did not do a deep analysis, but here is what helped in my case (memory consumption is now constant at ~16GB with a micro_batch_size of 1): First, I removed the SpeedMonitor, because for some reason it needed a lot of memory. Second, I saw that more and more memory was consumed over the course of training -- I now call torch.cuda.empty_cache() every n iterations, and the memory consumption is constant over time too.
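Concretely, the periodic cache clearing is just something like the following inside the training loop (a sketch; I am assuming the loop counter is named iter_num, and the interval of 100 is arbitrary):
import torch

# Sketch: release unoccupied cached CUDA memory every N iterations so the
# reserved memory reported by the allocator does not keep growing.
EMPTY_CACHE_INTERVAL = 100

def maybe_empty_cache(iter_num: int) -> None:
    if iter_num % EMPTY_CACHE_INTERVAL == 0:
        torch.cuda.empty_cache()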
I'm currently following the instructions for fine-tuning Falcon 7B with adapter v2 and ran into similar issues. I deleted the following lines in train:
if not isinstance(fabric.strategy, DeepSpeedStrategy): # unsupported
measured_flops = measure_flops(
model, torch.randint(0, 1, (micro_batch_size, model.config.block_size), device=fabric.device)
)
fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
else:
measured_flops = None
and just replaced them with measured_flops = None. That seemed to fix everything for me on an NVIDIA RTX A6000 (48GB). That might be why setting the strategy to deepspeed seems to fix things.
I lied, I still ran into an OOM issue about 80 steps in after fixing a NaN problem (solved by using --precision bf16-mixed).
I've tried using adapter_v2.py, adapter.py and lora.py. All quickly OOM on my 48GB GPU (within 80 steps). Not sure what's causing this yet.
EDIT: With some tweaking, changing these settings got me a few more steps (up to about 600) before OOM:
batch_size = 64 / devices
micro_batch_size = 1
Broadly, it'd be nice if the scripts referenced in the guide worked as reported. Even with all these tweaks, the minimum VRAM usage I'm seeing when training starts is ~30GB, not 16GB.
@fozziethebeat What's your micro_batch_size and max_seq_len? Since the sequence length is local to the batch, maybe it finds a batch later in your training that is big enough to cause OOM.
I'm using the default max_seq_length as generated by running
python scripts/prepare_alpaca.py --checkpoint_dir checkpoints/tiiuae/falcon-7b/
Looking at the config directly, it looks like it's 1079. That doesn't seem too extreme to me and is lower than the block size (2048) reported by falcon-7b.
So I'm having the same issue -- memory consumption is constant in general, but after about 50 steps an OOM is raised. I logged the sequence length, and in my case it's definitely because of the sequence length (thanks for the hint @k21993) -- it happens exactly after the "1079 sample" occurs. All other samples are <= 650 until this point, and exactly after this batch an OOM is raised -- which is fine IMO...
Update: When I restrict the token length it trains without OOMs :) Still, it's worth mentioning that I use a 3090 GPU, so I only have 24GB of VRAM.
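For anyone who wants to do the same, restricting the token length can be done by filtering the tokenized data before training, roughly like this (a sketch; it assumes train.pt from scripts/prepare_alpaca.py is a list of dicts with an "input_ids" tensor, and the 640-token cap and output filename are just examples):
import torch

# Sketch: drop over-long Alpaca samples to avoid OOM spikes on long batches.
data = torch.load("data/alpaca/train.pt")
max_len = 640  # pick a cap that fits your GPU memory
filtered = [sample for sample in data if len(sample["input_ids"]) <= max_len]
print(f"kept {len(filtered)} of {len(data)} samples")
torch.save(filtered, "data/alpaca/train_capped.pt")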
I merged #173, that should fix the FLOPs counter issue.
I'll try replicating the sequence length issues you are seeing now
Noticing the same thing on my end. Specifically, iter 251 gets a token length around 600 and crashes on my 3090. I modified the script to skip any inputs above 600, and it trains a little longer but crashes later on around a 500-token input. It appears the memory usage slowly creeps up over a few minutes while training; maybe something is not being released correctly.
Hey all. Using current main, here's what I'm calling:
python finetune/adapter.py --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true
with micro_batch_size=1, I get a constant ~16GB use. It might seem to slowly creep up, but that is just the CUDA allocator keeping more than it needs. As https://github.com/Lightning-AI/lit-parrot/issues/159#issuecomment-1598193614 mentioned, empty_cache() will keep it down, but beware that it will slow things down a lot, so don't call it often if you need it.
In terms of memory requirements, here's what you should expect:
Number of trainable parameters: 1365330
Number of non trainable parameters: 7217189760
Sum: 7218555090
Model weights fp32: 7218555090 * 4 / 1e9 = 28.87 GB
AdamW fp32: 2 * 4 * 1365330 / 1e9 = 0.01 GB
This matches the observed 29.02 GB returned by torch.cuda.memory_reserved() with --precision bf16-mixed. Using 16-true or bf16-true, the memory is halved.
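The same estimate as a quick back-of-the-envelope calculation, in case you want to plug in your own numbers (a sketch using the parameter counts printed above):
# Rough memory estimate for adapter fine-tuning (fp32 weights + AdamW state).
trainable = 1_365_330
non_trainable = 7_217_189_760
total = trainable + non_trainable            # 7_218_555_090

weights_fp32_gb = total * 4 / 1e9            # 4 bytes per fp32 parameter -> ~28.87 GB
adamw_fp32_gb = 2 * 4 * trainable / 1e9      # two fp32 states per trainable parameter -> ~0.01 GB
print(f"weights: {weights_fp32_gb:.2f} GB, optimizer: {adamw_fp32_gb:.2f} GB")
# With 16-true / bf16-true, the weight memory is roughly halved.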
All is working as expected so far. Now, if I force all inputs to be of the maximum sequence length for the alpaca dataset (1079), the max memory reserved does jump to 24.5 GB.
I'll open a PR trying to alleviate that jump, as it's caused by an autograd issue with backward. However, you might still need to tweak the max_seq_length depending on your available GPU memory
Thank you! This so far seems to be the needed fix.
Trying now on main, and so far this is working really smoothly. Using the exact command you tried, I'm seeing ~29GB VRAM usage and no NaNs in my loss function. So far at step 600 and no issues.
I do see small memory increases but it's much less dramatic than before.
EDIT: Posted too soon. Hit an OOM after iter 1599, step 100.
I merged #178 which should be a small decrease in memory usage.
I'll also be adding #182 which includes a change so that the longest alpaca sequence is loaded first, so that OOM happens at the beginning.
For the deepspeed issues, I'll be replacing it with FSDP in #118
Closing this issue. Feel free to open new ones for any new issues. Thank you all
Should this be staying under 48GB VRAM usage when we run the command below at head?
python finetune/adapter.py \
--data_dir data/alpaca \
--checkpoint_dir checkpoints/tiiuae/falcon-7b \
--out_dir out/adapter/alpaca --precision bf16-true
I've just tried this out and I still see an OOM at iter 1599, step 100.
Trying now on an A6000, and it looks like I am basically maxed out at ~48GB right from the start. So it's possible it moves up/down a bit from there and hits an OOM.
That's exactly what I noticed. It started at 100% VRAM usage and then something at iter 1599 step 100 kills it with the tiniest increase of memory.