
CUDA Out of Memory Error with 16 GB GPU when running src/qa_prediction/predict_answer.py

Open • himanshunaidu opened this issue 1 year ago • 4 comments

Greetings, I hope you are all doing well.

First of all, this is really great work, and I am glad it is being maintained so well. I intend to use this project as a baseline for my own research, so I have a small question about its GPU requirements.

When I run src/qa_prediction/gen_rule_path.py, it runs without much trouble and uses up to 13 GB of GPU memory. However, when I run the final prediction script, src/qa_prediction/predict_answer.py, I run out of CUDA memory fairly quickly.

Command:

python src/qa_prediction/predict_answer.py \
        --model_name RoG \
        --model_path rmanluo/RoG \
        -d RoG-cwq \
        --prompt_path prompts/llama2_predict.txt \
        --add_rule \
        --rule_path results/gen_rule_path/RoG-cwq/RoG/test/predictions_3_False.jsonl

Output:

Load dataset from finished
Save results to:  results/KGQA/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl
Prepare pipline for inference...
/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/transformers/models/auto/tokenization_auto.py:823: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.
  warnings.warn(
/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/transformers/pipelines/__init__.py:763: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.
  warnings.warn(
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:04<00:00,  2.05s/it]
Device set to use cuda:0
  0%|▍                                                                                                                  | 6/1628 [00:03<13:36,  1.99it/s]
Traceback (most recent call last):
  File "/home/ubuntu/ClaimBenchKG_Baselines/reasoning-on-graphs/src/qa_prediction/predict_answer.py", line 229, in <module>
    main(args, LLM)
  File "/home/ubuntu/ClaimBenchKG_Baselines/reasoning-on-graphs/src/qa_prediction/predict_answer.py", line 172, in main
    res = prediction(data, processed_list, input_builder, model)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/ClaimBenchKG_Baselines/reasoning-on-graphs/src/qa_prediction/predict_answer.py", line 78, in prediction
    prediction = model.generate_sentence(input)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/ClaimBenchKG_Baselines/reasoning-on-graphs/src/qa_prediction/../llms/language_models/llama.py", line 34, in generate_sentence
    outputs = self.generator(llm_input, return_full_text=False, max_new_tokens=self.args.max_new_tokens)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/transformers/pipelines/text_generation.py", line 285, in __call__
    return super().__call__(text_inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/transformers/pipelines/base.py", line 1362, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/transformers/pipelines/base.py", line 1369, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/transformers/pipelines/base.py", line 1269, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/transformers/pipelines/text_generation.py", line 383, in _forward
    generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/transformers/generation/utils.py", line 2255, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/transformers/generation/utils.py", line 3254, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 831, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 589, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 348, in forward
    hidden_states = self.mlp(hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/rog3/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 186, in forward
    down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
                               ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 84.00 MiB. GPU 0 has a total capacity of 14.58 GiB of which 53.56 MiB is free. Including non-PyTorch memory, this process has 14.52 GiB memory in use. Of the allocated memory 13.94 GiB is allocated by PyTorch, and 467.66 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
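The last line of the traceback suggests one cheap thing to try before anything else: PyTorch's `expandable_segments` allocator option. Here only about 468 MiB was reserved but unallocated and the failing request was 84 MiB, so reducing fragmentation may buy a little headroom but is unlikely to rescue a long prompt. The sketch below relaunches the same command with the variable set in the child process's environment, since it has to be in place before PyTorch initializes CUDA. The helper name `run_with_expandable_segments` and the constant `ISSUE_CMD` are hypothetical.

```python
import os
import subprocess
import sys

# The command from this issue, unchanged.
ISSUE_CMD = [
    sys.executable, "src/qa_prediction/predict_answer.py",
    "--model_name", "RoG",
    "--model_path", "rmanluo/RoG",
    "-d", "RoG-cwq",
    "--prompt_path", "prompts/llama2_predict.txt",
    "--add_rule",
    "--rule_path", "results/gen_rule_path/RoG-cwq/RoG/test/predictions_3_False.jsonl",
]


def run_with_expandable_segments(cmd: list[str]) -> subprocess.CompletedProcess:
    """Hypothetical helper: run `cmd` with PyTorch's expandable-segments allocator enabled."""
    env = dict(os.environ)
    # Overwrites any existing PYTORCH_CUDA_ALLOC_CONF for the child process only.
    env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    return subprocess.run(cmd, env=env, check=True)
```

Calling `run_with_expandable_segments(ISSUE_CMD)` from the repository root would start the prediction run with the allocator hint applied. Whether that alone avoids the OOM depends on how long the longest prompt is.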

The documentation does state that the minimum GPU memory needed to run the model is 12 GB, so it is understandable that actual usage may go somewhat above that minimum.
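For a rough sense of where the memory goes, here is a back-of-envelope estimate. It assumes RoG has the LLaMA-2-7B shape (about 6.74 billion parameters, 32 layers, hidden size 4096, no grouped-query attention) and is loaded with 16-bit weights. None of these numbers come from this thread. Under those assumptions the weights alone need about 12.5 GiB. Each prompt token then adds 0.5 MiB of KV cache, and activations such as the MLP intermediate that failed in the traceback come on top of that. A long prompt with many rule paths can therefore push past the 14.58 GiB card.

```python
GIB = 1024 ** 3  # bytes per GiB

# Assumed LLaMA-2-7B shape; not stated in this thread.
N_PARAMS = 6.74e9
N_LAYERS = 32
HIDDEN = 4096


def weight_gib(bytes_per_param: float, n_params: float = N_PARAMS) -> float:
    """Memory for the weights alone, ignoring buffers and quantization constants."""
    return n_params * bytes_per_param / GIB


def kv_cache_gib(n_tokens: int, bytes_per_value: int = 2) -> float:
    """KV cache for one sequence: one key and one value vector of size HIDDEN per layer per token."""
    return 2 * N_LAYERS * HIDDEN * bytes_per_value * n_tokens / GIB


print(f"16-bit weights: {weight_gib(2):.2f} GiB")
print(f"4-bit weights (lower bound): {weight_gib(0.5):.2f} GiB")
print(f"KV cache per token: {kv_cache_gib(1) * 1024:.2f} MiB")
print(f"KV cache for 2048 tokens: {kv_cache_gib(2048):.2f} GiB")
```

The 4-bit figure is a lower bound: quantization constants and any layers kept in 16-bit precision add to it.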

However, just in case, I wanted to ask whether there are any small optimizations that could bring GPU memory usage just below 14.5 GB. From what I have read of the code, the command above should already incur close to the minimum possible usage, but I could be wrong.

I ask because it will be quite a while before I have access to a bigger instance, so I thought it wouldn't hurt to ask.

Please let me know whether this is possible; if not, we can close the issue right away.

Thank you!

himanshunaidu • Jan 20 '25 09:01

My machine also has a 16 GB GPU, and I ran into the same problem. It works perfectly with 4-bit quantization; that is how I ran the authors' code.
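The exact change used here was not shared. The following is a minimal sketch of the standard transformers + bitsandbytes route to 4-bit loading, not necessarily what was done. The NF4, fp16-compute, and double-quantization settings are assumptions. It needs `transformers`, `accelerate`, and `bitsandbytes` on a CUDA machine, and the helper names `rog_4bit_kwargs` and `load_rog_4bit` are hypothetical.

```python
# Sketch: load RoG with 4-bit NF4 weights via transformers + bitsandbytes.


def rog_4bit_kwargs() -> dict:
    """Arguments for transformers.BitsAndBytesConfig (assumed settings, not from this thread)."""
    return {
        "load_in_4bit": True,                 # store linear-layer weights in 4 bits
        "bnb_4bit_quant_type": "nf4",         # NormalFloat4 quantization
        "bnb_4bit_compute_dtype": "float16",  # matmuls run in fp16
        "bnb_4bit_use_double_quant": True,    # also quantize the quantization constants
    }


def load_rog_4bit(model_path: str = "rmanluo/RoG"):
    """Return (tokenizer, model) with the model loaded in 4-bit."""
    # Imported lazily so the rest of this file can run without a GPU stack.
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=BitsAndBytesConfig(**rog_4bit_kwargs()),
        device_map="auto",  # let accelerate place the quantized layers on the GPU
    )
    return tokenizer, model
```

Running `tokenizer, model = load_rog_4bit()` and then `transformers.pipeline("text-generation", model=model, tokenizer=tokenizer)` gives a generator that could stand in for the one used in `src/llms/language_models/llama.py`. How to wire it into that file is left open. Quantized weights can shift the generated answers somewhat, so scores may not match the paper exactly.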

lijiabao2 • Mar 22 '25 14:03

I see. Yeah, quantization should definitely help. The problem is that I want to use this as a baseline, so I need to run it exactly as the authors did. But yes, thanks for the info; I'll keep it in mind.

himanshunaidu • Mar 25 '25 17:03

Greetings, I hope you are all doing well.

The link to the pre-trained weights, https://huggingface.co/rmanluo/RoG, cannot be found. I hope you can update it.

YRVGFO9588 • Apr 28 '25 13:04

Quoting lijiabao2: "My machine has a 16GB GPU and I also encountered this problem. It works perfectly after using 4-bit quantization. I ran through the author's code in this way."

Hi bro, did you get similar results?

chrislouis0106 • May 27 '25 02:05