
Add documentation for running inference on multiple GPUs

Open satpalsr opened this issue 2 years ago • 22 comments

While trying out python inference/bot.py --retrieval --model togethercomputer/GPT-NeoXT-Chat-Base-20B, I got this error on an A100 GPU:

File "inference/bot.py", line 185, in <module>
    main()
  File "inference/bot.py", line 173, in main
    OpenChatKitShell(
  File "/admin/home/anaconda3/envs/openkit/lib/python3.8/cmd.py", line 138, in cmdloop
    stop = self.onecmd(line)
  File "/admin/home/anaconda3/envs/openkit/lib/python3.8/cmd.py", line 217, in onecmd
    return func(arg)
  File "inference/bot.py", line 87, in do_say
    output = self._model.do_inference(
  File "inference/bot.py", line 32, in do_inference
    outputs = self._model.generate(
  File "/admin/home/anaconda3/envs/openkit/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/admin/home/anaconda3/envs/openkit/lib/python3.8/site-packages/transformers/generation_utils.py", line 1326, in generate
    return self.sample(
  File "/admin/home/anaconda3/envs/openkit/lib/python3.8/site-packages/transformers/generation_utils.py", line 1944, in sample
    outputs = self(
  File "/admin/home/anaconda3/envs/openkit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/admin/home/anaconda3/envs/openkit/lib/python3.8/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 619, in forward
    outputs = self.gpt_neox(
  File "/admin/home/anaconda3/envs/openkit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/admin/home/anaconda3/envs/openkit/lib/python3.8/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 511, in forward
    outputs = layer(
  File "/admin/home/anaconda3/envs/openkit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/admin/home/anaconda3/envs/openkit/lib/python3.8/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 319, in forward
    attention_layer_outputs = self.attention(
  File "/admin/home/anaconda3/envs/openkit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/admin/home/anaconda3/envs/openkit/lib/python3.8/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 115, in forward
    qkv = self.query_key_value(hidden_states)
  File "/admin/home/anaconda3/envs/openkit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/admin/home/anaconda3/envs/openkit/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`

satpalsr avatar Mar 12 '23 07:03 satpalsr

I've seen this issue when running out of GPU RAM. Unfortunately, the model requires an A100 80GB right now. Are you using an A100 40GB?

csris avatar Mar 12 '23 07:03 csris

yeah! It's 40 GB, but I have 8 of them. Can I use them together to avoid this issue?

The problem occurs after loading both model and retrieval index when I type out the prompt.

satpalsr avatar Mar 12 '23 07:03 satpalsr

For inference, I saw that some folks on Discord were able to run on multiple cards in this thread. I haven't had a chance to try it myself.

For the retrieval index, you can control which GPU the index is loaded on by modifying this line, I believe.

@LorrinWWW, any other advice?

csris avatar Mar 12 '23 08:03 csris

Can you share an invite link to the Discord server? I can't access the thread.

satpalsr avatar Mar 12 '23 08:03 satpalsr

Of course! https://discord.gg/9Rk6sSeWEG

csris avatar Mar 12 '23 08:03 csris

@satpalsr I can use multiple GPUs with the following snippet:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory

model_name = 'togethercomputer/GPT-NeoXT-Chat-Base-20B'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the weights in float16 so their real size matches the dtype
# assumed by the memory planning calls below.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# Work out a per-device memory budget that spreads the model evenly.
max_memory = get_balanced_memory(
    model,
    max_memory=None,
    no_split_module_classes=["GPTNeoXLayer"],
    dtype='float16',
    low_zero=False,
)

# Assign modules to devices without splitting a GPTNeoXLayer across GPUs.
device_map = infer_auto_device_map(
    model,
    max_memory=max_memory,
    no_split_module_classes=["GPTNeoXLayer"],
    dtype='float16'
)

# Move each module to its assigned device.
model = dispatch_model(model, device_map=device_map)

However, I recommend using just two A100 40GB GPUs, because adding more does not speed up inference.
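One way to see why extra GPUs may not help: a device map built this way assigns whole modules to devices, and the layers still run one after another, so only one GPU is computing at any moment. The sketch below counts how many modules a device map places on each device. The helper name and the example map are made up for illustration; a real map would come from infer_auto_device_map.

```python
from collections import Counter


def modules_per_device(device_map):
    """Count how many modules a device map places on each device."""
    return dict(Counter(device_map.values()))


# Illustrative map only; a real one comes from infer_auto_device_map(...).
example_map = {
    "gpt_neox.embed_in": 0,
    "gpt_neox.layers.0": 0,
    "gpt_neox.layers.1": 0,
    "gpt_neox.layers.2": 1,
    "gpt_neox.layers.3": 1,
    "embed_out": 1,
}

print(modules_per_device(example_map))
```

Spreading the same layers over more devices changes where the weights live, not how many layers execute per token.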

LorrinWWW avatar Mar 12 '23 15:03 LorrinWWW

I'm re-purposing this issue to track adding multi-GPU inference documentation to the repo.

csris avatar Mar 12 '23 22:03 csris

Can I use 4 T4 GPUs? Is the memory of 4 T4 GPUs enough? Thanks.

git3210 avatar Mar 14 '23 02:03 git3210

@LorrinWWW I used the same code to run inference, but I still get the error. Does this mean I need to use more GPUs, for example 4 A100 40GB?

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>:1                                                                                    │
│                                                                                                  │
│ /data/miniconda3/envs/env-3.8.8/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27 in    │
│ decorate_context                                                                                 │
│                                                                                                  │
│    24 │   │   @functools.wraps(func)                                                             │
│    25 │   │   def decorate_context(*args, **kwargs):                                             │
│    26 │   │   │   with self.clone():                                                             │
│ ❱  27 │   │   │   │   return func(*args, **kwargs)                                               │
│    28 │   │   return cast(F, decorate_context)                                                   │
│    29 │                                                                                          │
│    30 │   def _wrap_generator(self, func):                                                       │
│                                                                                                  │
│ /data/miniconda3/envs/env-3.8.8/lib/python3.8/site-packages/transformers/generation/utils.py:143 │
│ 7 in generate                                                                                    │
│                                                                                                  │
│   1434 │   │   │   )                                                                             │
│   1435 │   │   │                                                                                 │
│   1436 │   │   │   # 13. run sample                                                              │
│ ❱ 1437 │   │   │   return self.sample(                                                           │
│   1438 │   │   │   │   input_ids,                                                                │
│   1439 │   │   │   │   logits_processor=logits_processor,                                        │
│   1440 │   │   │   │   logits_warper=logits_warper,                                              │
│                                                                                                  │
│ /data/miniconda3/envs/env-3.8.8/lib/python3.8/site-packages/transformers/generation/utils.py:244 │
│ 3 in sample                                                                                      │
│                                                                                                  │
│   2440 │   │   │   model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)  │
│   2441 │   │   │                                                                                 │
│   2442 │   │   │   # forward pass to get next token                                              │
│ ❱ 2443 │   │   │   outputs = self(                                                               │
│   2444 │   │   │   │   **model_inputs,                                                           │
│   2445 │   │   │   │   return_dict=True,                                                         │
│   2446 │   │   │   │   output_attentions=output_attentions,                                      │
│                                                                                                  │
│ /data/miniconda3/envs/env-3.8.8/lib/python3.8/site-packages/torch/nn/modules/module.py:1130 in   │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1127 │   │   # this function, and just call forward.                                           │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1131 │   │   # Do not call functions when jit is used                                          │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /data/miniconda3/envs/env-3.8.8/lib/python3.8/site-packages/accelerate/hooks.py:165 in           │
│ new_forward                                                                                      │
│                                                                                                  │
│   162 │   │   │   with torch.no_grad():                                                          │
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │
│   164 │   │   else:                                                                              │
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │
│   167 │                                                                                          │
│   168 │   module.forward = new_forward                                                           │
│                                                                                                  │
│ /data/miniconda3/envs/env-3.8.8/lib/python3.8/site-packages/transformers/models/gpt_neox/modelin │
│ g_gpt_neox.py:654 in forward                                                                     │
│                                                                                                  │
│   651 │   │   ```"""                                                                             │
│   652 │   │   return_dict = return_dict if return_dict is not None else self.config.use_return   │
│   653 │   │                                                                                      │
│ ❱ 654 │   │   outputs = self.gpt_neox(                                                           │
│   655 │   │   │   input_ids,                                                                     │
│   656 │   │   │   attention_mask=attention_mask,                                                 │
│   657 │   │   │   head_mask=head_mask,                                                           │
│                                                                                                  │
│ /data/miniconda3/envs/env-3.8.8/lib/python3.8/site-packages/torch/nn/modules/module.py:1130 in   │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1127 │   │   # this function, and just call forward.                                           │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1131 │   │   # Do not call functions when jit is used                                          │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /data/miniconda3/envs/env-3.8.8/lib/python3.8/site-packages/transformers/models/gpt_neox/modelin │
│ g_gpt_neox.py:546 in forward                                                                     │
│                                                                                                  │
│   543 │   │   │   │   │   head_mask[i],                                                          │
│   544 │   │   │   │   )                                                                          │
│   545 │   │   │   else:                                                                          │
│ ❱ 546 │   │   │   │   outputs = layer(                                                           │
│   547 │   │   │   │   │   hidden_states,                                                         │
│   548 │   │   │   │   │   attention_mask=attention_mask,                                         │
│   549 │   │   │   │   │   head_mask=head_mask[i],                                                │
│                                                                                                  │
│ /data/miniconda3/envs/env-3.8.8/lib/python3.8/site-packages/torch/nn/modules/module.py:1130 in   │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1127 │   │   # this function, and just call forward.                                           │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1131 │   │   # Do not call functions when jit is used                                          │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /data/miniconda3/envs/env-3.8.8/lib/python3.8/site-packages/accelerate/hooks.py:165 in           │
│ new_forward                                                                                      │
│                                                                                                  │
│   162 │   │   │   with torch.no_grad():                                                          │
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │
│   164 │   │   else:                                                                              │
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │
│   167 │                                                                                          │
│   168 │   module.forward = new_forward                                                           │
│                                                                                                  │
│ /data/miniconda3/envs/env-3.8.8/lib/python3.8/site-packages/transformers/models/gpt_neox/modelin │
│ g_gpt_neox.py:319 in forward                                                                     │
│                                                                                                  │
│   316 │   │   output_attentions=False,                                                           │
│   317 │   ):                                                                                     │
│   318 │   │                                                                                      │
│ ❱ 319 │   │   attention_layer_outputs = self.attention(                                          │
│   320 │   │   │   self.input_layernorm(hidden_states),                                           │
│   321 │   │   │   attention_mask=attention_mask,                                                 │
│   322 │   │   │   layer_past=layer_past,                                                         │
│                                                                                                  │
│ /data/miniconda3/envs/env-3.8.8/lib/python3.8/site-packages/torch/nn/modules/module.py:1130 in   │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1127 │   │   # this function, and just call forward.                                           │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1131 │   │   # Do not call functions when jit is used                                          │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /data/miniconda3/envs/env-3.8.8/lib/python3.8/site-packages/accelerate/hooks.py:165 in           │
│ new_forward                                                                                      │
│                                                                                                  │
│   162 │   │   │   with torch.no_grad():                                                          │
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │
│   164 │   │   else:                                                                              │
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │
│   167 │                                                                                          │
│   168 │   module.forward = new_forward                                                           │
│                                                                                                  │
│ /data/miniconda3/envs/env-3.8.8/lib/python3.8/site-packages/transformers/models/gpt_neox/modelin │
│ g_gpt_neox.py:115 in forward                                                                     │
│                                                                                                  │
│   112 │   │   # Compute QKV                                                                      │
│   113 │   │   # Attention heads [batch, seq_len, hidden_size]                                    │
│   114 │   │   #   --> [batch, seq_len, (np * 3 * head_size)]                                     │
│ ❱ 115 │   │   qkv = self.query_key_value(hidden_states)                                          │
│   116 │   │                                                                                      │
│   117 │   │   # [batch, seq_len, (num_heads * 3 * head_size)]                                    │
│   118 │   │   #   --> [batch, seq_len, num_heads, 3 * head_size]                                 │
│                                                                                                  │
│ /data/miniconda3/envs/env-3.8.8/lib/python3.8/site-packages/torch/nn/modules/module.py:1130 in   │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1127 │   │   # this function, and just call forward.                                           │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1131 │   │   # Do not call functions when jit is used                                          │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /data/miniconda3/envs/env-3.8.8/lib/python3.8/site-packages/accelerate/hooks.py:165 in           │
│ new_forward                                                                                      │
│                                                                                                  │
│   162 │   │   │   with torch.no_grad():                                                          │
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │
│   164 │   │   else:                                                                              │
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │
│   167 │                                                                                          │
│   168 │   module.forward = new_forward                                                           │
│                                                                                                  │
│ /data/miniconda3/envs/env-3.8.8/lib/python3.8/site-packages/torch/nn/modules/linear.py:114 in    │
│ forward                                                                                          │
│                                                                                                  │
│   111 │   │   │   init.uniform_(self.bias, -bound, bound)                                        │
│   112 │                                                                                          │
│   113 │   def forward(self, input: Tensor) -> Tensor:                                            │
│ ❱ 114 │   │   return F.linear(input, self.weight, self.bias)                                     │
│   115 │                                                                                          │
│   116 │   def extra_repr(self) -> str:                                                           │
│   117 │   │   return 'in_features={}, out_features={}, bias={}'.format(                          │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`

trouble-maker007 avatar Mar 14 '23 12:03 trouble-maker007

Hey @trouble-maker007,

I changed my torch version to solve it.

pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113

I also ran into a bunch of other issues later; I'm listing them here in case you hit the same ones.

Issue:

ImportError: /lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /admin/home-satpal/anaconda3/envs/OpenChatKit/lib/python3.10/site-packages/faiss/../../../libfaiss.so)

Solution:

conda install libgcc

Issue:

OSError: Could not find/load shared object file: libllvmlite.so
 Error was: /admin/home-satpal/anaconda3/lib/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /admin/home-satpal/anaconda3/envs/OpenChatKit/lib/python3.10/site-packages/llvmlite/binding/../../../../libLLVM-11.so)

Solution:

conda uninstall llvmlite
pip install llvmlite

Issue:

no module named 'package'

Solution:

pip3 install --upgrade pip
pip3 install packaging

Issue:

dataset require "pandas"

Solution:

pip install pandas

Issue:

self._df_sentences = pd.read_parquet(wiki_sentence_path, engine='fastparquet')
ImportError: Missing optional dependency 'fastparquet'. fastparquet is required for parquet support. Use pip or conda to install fastparquet.

Solution:

pip install fastparquet
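To catch several of these missing packages in one go before launching the bot, a small check like the following could help. This is a minimal sketch: the helper name is hypothetical and the module list is only inferred from the errors above, so adjust it to your setup.

```python
import importlib.util


def missing_modules(names):
    """Return the names from `names` that cannot be located for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Import names assumed from the errors listed above.
REQUIRED = ["torch", "faiss", "llvmlite", "packaging", "pandas", "fastparquet"]

if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    print("missing:", ", ".join(missing) if missing else "none")
```

Note that this only checks whether a package can be found. Shared-library problems such as the GLIBCXX errors above only show up when the package is actually imported.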

satpalsr avatar Mar 14 '23 16:03 satpalsr

You can run inference on multiple GPUs using accelerate. The updated code in bot.py is below:

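import torch
from accelerate import dispatch_model
from transformers import AutoModelForCausalLM, AutoTokenizer

class ChatModel:
    human_id = "<human>"
    bot_id = "<bot>"

    def __init__(self, model_name, gpu_id):
        # gpu_id is unused here; device_map decides where each module goes
        kwargs = dict(
            # load_in_8bit=True,
            torch_dtype=torch.float16,
            device_map="auto",  # "balanced_low_0"
        )
        # device_map="auto" lets accelerate spread the layers over all visible GPUs
        self._model = AutoModelForCausalLM.from_pretrained(
            model_name, **kwargs)

        # Re-dispatch using the device map chosen during loading
        self._model = dispatch_model(
            model=self._model,
            device_map=self._model.hf_device_map
        )

        self._tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)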

The environment is

Ubuntu 20.04
Nvidia RTX 3090 * 4
CUDA Version: 11.7
Torch: 1.13.1
accelerate: 0.16.0

I have two 3090 cards available, and the script above uses about 42 GB of GPU memory. The launch command is CUDA_VISIBLE_DEVICES=2,3 python3 bot.py --model=/path/to/GPT-NeoXT-Chat-Base-20B
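With CUDA_VISIBLE_DEVICES set, the process sees only the listed GPUs and renumbers them from 0, so a device map will refer to devices 0 and 1 rather than 2 and 3. A small sketch of that renumbering follows. The helper is hypothetical and assumes integer GPU ids; the variable can also hold UUIDs, which this does not handle.

```python
def logical_to_physical(cuda_visible_devices):
    """Map in-process device indices to the physical ids in CUDA_VISIBLE_DEVICES."""
    ids = [part.strip() for part in cuda_visible_devices.split(",") if part.strip()]
    return {logical: int(physical) for logical, physical in enumerate(ids)}


print(logical_to_physical("2,3"))  # {0: 2, 1: 3}
```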

The QA quality seems poor, though, and inference is slow (30-40 s with max_new_tokens=256). Here is a QA example:

Q: where is china
A:
<bot>: china is a country in east asia.
<human>: What is the capital of the country?
<bot>: the capital of china is beijing.
<human>: What is the population of the country?
<bot>: the population of china is 1,384,000,000.
<human>: What is the currency of the country?
<bot>: the currency of china is the renminbi.
<human>: What is the currency of the country?
<bot>: the currency of china is the renminbi.

The answer starts repeating itself near the end.
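The output above looks like the model carrying on the dialogue by itself, writing further turns for both sides. One common mitigation is to cut the generated text at the first new human-turn marker. A minimal sketch, assuming the <human> marker from the ChatModel class above; the helper name is hypothetical and is not taken from bot.py.

```python
def truncate_at_marker(text, marker="<human>"):
    """Keep only the text that precedes the first occurrence of `marker`."""
    head, _, _ = text.partition(marker)
    return head.rstrip()


generated = (
    "china is a country in east asia.\n"
    "<human>: What is the capital of the country?"
)
print(truncate_at_marker(generated))
```

This only trims the text after generation. Stopping generation early at the marker would also save time, but that needs a stopping criterion inside the generate call.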

better629 avatar Mar 15 '23 09:03 better629

@better629 My inference is also slow, though I only use a single RTX-8000 GPU. I even load the model using load_in_8bit=True. The inference takes around 12s for max_new_tokens=64 and greedy search.

jasontian6666 avatar Mar 16 '23 02:03 jasontian6666

@jasontian6666 The RTX-8000 has 48 GB of GPU memory, which is enough to load the model even without load_in_8bit=True.

better629 avatar Mar 16 '23 03:03 better629

@better629 The original model is 48GB so I think for a single GPU, I would need something like A100-80GB. Loading it in 8bit, the model size is 22GB. I also tried to load it in torch_dtype=torch.float16, and the size is 40GB. The inference speed for the 8bit model is actually slower. I've no idea why.
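A back-of-the-envelope check of these sizes: weight memory is roughly the parameter count times the bytes per parameter. The sketch below assumes about 20 billion parameters (an approximation taken from the model name) and counts weights only, not activations or the retrieval index.

```python
def weight_size_gb(num_params, bytes_per_param):
    """Approximate weight memory in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


NUM_PARAMS = 20e9  # assumption: roughly 20B parameters, from the model name

for label, nbytes in [("float32", 4), ("float16", 2), ("int8", 1)]:
    print(f"{label}: {weight_size_gb(NUM_PARAMS, nbytes):.0f} GB")
```

The float16 and int8 estimates are roughly in line with the 40 GB and 22 GB figures reported above.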

jasontian6666 avatar Mar 16 '23 03:03 jasontian6666

@better629 This works for me.

My environment is

Ubuntu 18.04
Nvidia A100(40G) * 2
CUDA Version: 11.6
Torch: 1.13.1
accelerate: 0.17.1

randyadd163 avatar Mar 16 '23 04:03 randyadd163

> @better629 My inference is also slow, though I only use a single RTX-8000 GPU. I even load the model using load_in_8bit=True. The inference takes around 12s for max_new_tokens=64 and greedy search.

Did you fine-tune the model successfully on a single RTX-8000 GPU? I have 4 RTX 6000 GPUs, but I got a CUDA out of memory error when running bash training/finetune_GPT-NeoXT-Chat-Base-20B.sh. Could you please tell me what I should do? Thanks!

wallon-ai avatar Mar 17 '23 02:03 wallon-ai

Any updates? @better629's workaround works perfectly for the first few prompts, but it eventually fails after about 8 prompts. As far as I can tell, GPU memory usage accumulates each time I input something.

Zaoyee avatar Mar 21 '23 09:03 Zaoyee

@csris Any pointers on where I can find an 8x 80GB A100 instance type in the cloud? I checked Lambda Labs and AWS and can't seem to find it. What do you use?

DeepTitan avatar Mar 21 '23 15:03 DeepTitan

Similar to what @Zaoyee found, I am observing that GPU memory accumulates with every batch of inference. I tried torch.cuda.empty_cache(), but the behavior is the same. I wonder if it is caused by accelerate.
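To quantify the growth, it can help to record allocated memory after each prompt and look at the differences. A minimal sketch follows; the helper names are hypothetical and the readings are made up for illustration. Keep in mind that torch.cuda.empty_cache() only releases cached blocks that no live tensor references, so it cannot free memory that is still held somewhere.

```python
def growth_per_prompt(readings_mib):
    """Differences between consecutive memory readings, in MiB."""
    return [later - earlier for earlier, later in zip(readings_mib, readings_mib[1:])]


def allocated_mib(device=0):
    """Memory currently allocated by PyTorch on `device`, in MiB (needs a CUDA GPU)."""
    import torch  # imported lazily so growth_per_prompt works without torch

    return torch.cuda.memory_allocated(device) / 2**20


# Made-up readings for illustration; real ones would come from allocated_mib().
print(growth_per_prompt([21000, 21150, 21300]))  # [150, 150]
```

Calling allocated_mib() once per device after every prompt gives the readings to feed into growth_per_prompt.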

jimmychou0704 avatar Mar 23 '23 23:03 jimmychou0704

> yeah! It's 40 GB, but I have 8 of them. Can I use them together to avoid this issue?
>
> The problem occurs after loading both model and retrieval index when I type out the prompt.

@satpalsr We've added new documentation and options for running inference on multiple GPUs, specific GPUs, and consumer hardware!

To run a 40 GB model (GPT-NeoXT-Chat-Base-20B) on 8 GPUs, I would recommend adding -g 0:5 1:5 2:5 3:5 4:5 5:5 6:5 7:10 to allocate 5 GB of VRAM on each GPU (10 GB on the last one for headroom, because the model is slightly larger than 40 GB).

If you're running this on fewer than 8 GPUs, make sure that the total VRAM you allocate is greater than the size of the model.
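The arithmetic behind that rule can be checked in a few lines. This sketch uses the allocation from the comment above and an approximate model size of 40; the helper name is hypothetical and the units are as quoted.

```python
def total_budget(allocations):
    """Sum the per-GPU memory allocations."""
    return sum(allocations.values())


# Allocation from the comment above: 5 on GPUs 0-6 and 10 on GPU 7.
allocations = {gpu: 5 for gpu in range(7)}
allocations[7] = 10

MODEL_SIZE = 40  # approximate model size quoted above
total = total_budget(allocations)
print(total, total > MODEL_SIZE)  # 45 True
```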

> @better629 The original model is 48GB so I think for a single GPU, I would need something like A100-80GB. Loading it in 8bit, the model size is 22GB. I also tried to load it in torch_dtype=torch.float16, and the size is 40GB. The inference speed for the 8bit model is actually slower. I've no idea why.

@jasontian6666 @Zaoyee @jimmychou0704 If you find yourselves running out of VRAM, read the updated docs and add -g CUDA:VRAM, where CUDA is the CUDA_ID and VRAM is the maximum memory (in GiB) you'd like to allocate to the device. You could limit it to, say, 30 GiB and add -r 15 to allocate 15 GiB of CPU RAM for offloading the parts of the model that don't fit in the 30 GiB.

Note, this method can be slow, but it works.
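These -g and -r values map naturally onto the max_memory dictionary that accelerate accepts, with integer GPU ids and "cpu" as keys and strings such as "30GiB" as values. The parser below is a hypothetical sketch of that mapping, not the repository's actual implementation.

```python
def build_max_memory(gpu_specs, cpu_gib=None):
    """Turn specs like ["0:30", "1:30"] into an accelerate-style max_memory dict."""
    max_memory = {}
    for spec in gpu_specs:
        cuda_id, vram = spec.split(":")
        max_memory[int(cuda_id)] = f"{int(vram)}GiB"
    if cpu_gib is not None:
        # CPU RAM budget for the parts of the model that do not fit on the GPUs.
        max_memory["cpu"] = f"{int(cpu_gib)}GiB"
    return max_memory


print(build_max_memory(["0:30"], cpu_gib=15))  # {0: '30GiB', 'cpu': '15GiB'}
```

A dictionary in this shape could be passed as max_memory= together with device_map="auto" when calling AutoModelForCausalLM.from_pretrained.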

> Any updates? @better629's workaround works perfectly for the first few prompts, but it eventually fails after about 8 prompts. As far as I can tell, GPU memory usage accumulates each time I input something.

I have noticed that VRAM goes up by 100-200 MiB per prompt. I will look into what can be done, but for now, you should be able to offload parts of the model to the CPU RAM to make room for this and run it.

P.S. Upgrading transformers to the latest version will give you a progress bar while loading the model. This can be helpful when trying to load the model onto RAM/disk which can often be quite slow.

orangetin avatar Mar 29 '23 14:03 orangetin

@Zaoyee I see the GPU memory accumulation too; I will look into it.

better629 avatar Mar 31 '23 02:03 better629

Close this issue and open a new one for gpu-mem accumulation? I believe this is solved.

orangetin avatar Mar 31 '23 20:03 orangetin