
Cuda OOM during generate() call on 4 GPUs

Open · rggs opened this issue on May 26, 2023 · 5 comments

I'm using the starcoderbase model across four GPUs to run inference. For some reason, one of the prompts I'm using causes a CUDA OOM error. Putting the tensor on the GPU shows about 420 MB of usage (the input shape is [1, 6086]), which still seems high, but not high enough to cause the error. The error occurs during the call to generate(), after the model has been parallelized and the input has been moved to the GPU. The full error is:

```
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.59 GiB (GPU 0; 31.75 GiB total capacity;
23.14 GiB already allocated; 2.50 GiB free; 28.32 GiB reserved in total by PyTorch)
If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.
See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```
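The traceback itself points at one knob to try. A minimal sketch of following that hint (this assumes fragmentation is actually the problem; the 128 MB split size is an arbitrary starting value, not something from the thread):

```python
import os

# PYTORCH_CUDA_ALLOC_CONF is read when CUDA is first initialized, so set it
# before any CUDA allocation, e.g. at the very top of the script (or export it
# in the shell before running `accelerate launch` / `python`).
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")

import torch  # imported only after the allocator option is in place
```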

Here's a minimal example of my code:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, CodeGenForCausalLM, CodeGenTokenizer, AutoConfig
from huggingface_hub import hf_hub_download, snapshot_download
import accelerate


with open("single_prompt.txt", "r") as f:
    prompt = f.read()

tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase", revision="main", device_map="auto")
weights_location = snapshot_download("bigcode/starcoderbase") # The model is sharded and there's a .index.json file to direct the shards

config = AutoConfig.from_pretrained("bigcode/starcoderbase", pad_token_id=tokenizer.eos_token_id)
with accelerate.init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

model.tie_weights()

model = accelerate.load_checkpoint_and_dispatch(model, weights_location, device_map="auto", no_split_module_classes=["GPTBigCodeBlock"])

prompt_tokenized = tokenizer(prompt, return_tensors="pt", truncation=False)

# Manually truncate
max_length = 8192
max_new_tokens = 256

token_len = prompt_tokenized["input_ids"].shape[1]
if token_len > max_length - max_new_tokens:
    prompt_tokenized["input_ids"] = prompt_tokenized["input_ids"][:, token_len - (max_length - max_new_tokens):]
    prompt_tokenized["attention_mask"] = prompt_tokenized["attention_mask"][:, token_len - (max_length - max_new_tokens):]


prompt_tokenized = prompt_tokenized.to("cuda")
sample = model.generate(**prompt_tokenized, max_new_tokens=max_new_tokens)
pc = tokenizer.decode(sample[:, prompt_tokenized["input_ids"].shape[1]:][0], truncate_before_pattern=[r"\n\n^#", "^'''", "\n\n\n"])
```

It's also worth noting that I don't have the same issue with similarly sized models like `CodeGen 16B`, although the max context length is smaller for those models, which might contribute to the issue.

rggs commented on May 26, 2023

Hi, I am not able to reproduce the error on my side. I tried with a `single_prompt.txt` that I created myself, but it probably has significantly fewer tokens than yours. I also use 4 GPUs and I run my code with `accelerate launch`. From what I see in your error, I think I may have more GPU memory than you, which is likely to explain why I do not encounter the issue. You may want to share your `single_prompt.txt` for further analysis.
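For comparison, a quick way to report how many tokens the prompt actually produces (a small sketch, assuming the same `single_prompt.txt` and tokenizer as in the snippet above):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase")

with open("single_prompt.txt", "r") as f:
    prompt = f.read()

# Print the prompt length in tokens so the two setups can be compared directly.
n_tokens = len(tokenizer(prompt)["input_ids"])
print(f"prompt length: {n_tokens} tokens")
```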

ArmelRandy commented on May 31, 2023

Right, that was my thinking too. I just don't understand why this file in particular is causing the error (it's maybe 20% larger than the other files I've used). I'm wondering if there's a good way to ratchet down memory use for larger samples without doing something like splitting the file in two. BTW, the prompt here is the Java implementation of Minesweeper from Rosetta Code: single_prompt.txt

rggs commented on May 31, 2023

You can look at the hardware requirements for StarCoder. Try loading the model in 8-bit with the code provided there.
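For reference, a minimal sketch of 8-bit loading (this assumes `bitsandbytes` is installed; it is not the exact snippet from the hardware-requirements page):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase")

# load_in_8bit quantizes the weights with bitsandbytes, roughly halving weight
# memory relative to fp16; device_map="auto" spreads the layers over the GPUs.
model = AutoModelForCausalLM.from_pretrained(
    "bigcode/starcoderbase",
    device_map="auto",
    load_in_8bit=True,
)
```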

ArmelRandy commented on May 31, 2023

Thanks, I'll try that. Why would it run OOM during the generate call, though? The inputs are put on the GPU prior to that call.

rggs commented on May 31, 2023
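The spike comes from inside generate() itself: the attention-score buffers and the KV cache scale with the prompt length, even though the inputs are already on the GPU. A rough back-of-the-envelope sketch (the 48-head count and the fp32 assumption are mine, not confirmed in the thread):

```python
# Back-of-the-envelope size of the per-layer attention-score matrix that
# generate() materializes for the full prompt.
# Assumptions (not from the thread): 48 attention heads, fp32 activations.
seq_len = 6086        # prompt length reported in the issue
num_heads = 48
bytes_per_elem = 4    # fp32

attn_scores = num_heads * seq_len * seq_len * bytes_per_elem
print(f"attention scores per layer: {attn_scores / 2**30:.2f} GiB")
# ~6.6 GiB, the same order of magnitude as the 6.59 GiB allocation in the traceback
```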

You can create a callback and clear the cache every now and then, and maybe call gc.collect(); a rough sketch is at the end of the thread. To improve performance the allocator "refuses" to let cached memory go, which can end in an OOM. Just a suggestion, no guarantee of success.

phalexo commented on Jun 10, 2023
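A rough sketch of that suggestion, adapted to clear the cache between generate() calls rather than via a callback (the prompt file names are placeholders, and `model`, `tokenizer`, `max_new_tokens` are the objects from the snippet at the top of the thread; this cannot help if a single call genuinely needs more memory than the GPU has):

```python
import gc
import torch

for path in ["prompt_a.txt", "prompt_b.txt"]:  # placeholder prompt files
    with open(path) as f:
        inputs = tokenizer(f.read(), return_tensors="pt").to("cuda")

    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=max_new_tokens)
    print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:]))

    # Drop references and hand cached blocks back to the driver between samples.
    del inputs, out
    gc.collect()
    torch.cuda.empty_cache()
```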