RL4LMs
OOM on summarization example
Hi there, I'm having OOM errors when running the summarization example on an 80GB A100 (CUDA 11.8).
I'm also getting some TensorFlow/TensorRT warnings; I'm wondering if they're related:
2022-11-08 22:44:46.878785: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-08 22:44:47.016183: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-11-08 22:44:47.979748: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2022-11-08 22:44:47.979824: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2022-11-08 22:44:47.979834: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
OOM error:
╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ │
│ /mnt/home/code/RL4LMs/scripts/training/train_text_generation.py:66 in │
│ <module> │
│ │
│ 63 │ │ │ │ │ │ help="Whether to use wandb logging") │
│ 64 │ args = parser.parse_args() │
│ 65 │ │
│ ❱ 66 │ main(args.config_path, │
│ 67 │ │ args.project_name, │
│ 68 │ │ args.experiment_name, │
│ 69 │ │ args.base_path_to_store_results, │
│ /mnt/home/code/RL4LMs/scripts/training/train_text_generation.py:42 in main │
│ │
│ 39 │ │ │ │ │ │ │ │ on_policy_alg_config=config["alg"], │
│ 40 │ │ │ │ │ │ │ │ train_eval_config=config["train_evalu │
│ 41 │ │ │ │ │ │ │ │ tracker=tracker) │
│ ❱ 42 │ trainer.train_and_eval() │
│ 43 │
│ 44 │
│ 45 if __name__ == "__main__": │
│ │
│ /mnt/home/code/RL4LMs/rl4lms/envs/text_generation/training_utils.py:205 in │
│ train_and_eval │
│ │
│ 202 │ │ │ self._trainer_state["current_iter"] = epoch │
│ 203 │ │ │ │
│ 204 │ │ │ # inner rollout and learn loop for on-policy algorithm │
│ ❱ 205 │ │ │ self._alg.learn(self._n_steps_per_iter) │
│ 206 │ │ │ │
│ 207 │ │ │ # save the policy checkpoint │
│ 208 │ │ │ if (epoch + 1) % self._train_eval_config.get("save_every", │
│ │
│ /mnt/home/code/RL4LMs/rl4lms/algorithms/ppo/ppo.py:347 in learn │
│ │
│ 344 │ │ reset_num_timesteps: bool = True, │
│ 345 │ ) -> "PPO": │
│ 346 │ │ │
│ ❱ 347 │ │ return super().learn( │
│ 348 │ │ │ total_timesteps=total_timesteps, │
│ 349 │ │ │ callback=callback, │
│ 350 │ │ │ log_interval=log_interval, │
│ │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/stable_baselines3/common/on │
│ _policy_algorithm.py:267 in learn │
│ │
│ 264 │ │ │ │ self.logger.record("time/total_timesteps", self.num_ti │
│ 265 │ │ │ │ self.logger.dump(step=self.num_timesteps) │
│ 266 │ │ │ │
│ ❱ 267 │ │ │ self.train() │
│ 268 │ │ │
│ 269 │ │ callback.on_training_end() │
│ 270 │
│ │
│ /mnt/home/code/RL4LMs/rl4lms/algorithms/ppo/ppo.py:224 in train │
│ │
│ 221 │ │ │ │ if self.use_sde: │
│ 222 │ │ │ │ │ self.policy.reset_noise(self.batch_size) │
│ 223 │ │ │ │ │
│ ❱ 224 │ │ │ │ values, log_prob, entropy = self.policy.evaluate_actio │
│ 225 │ │ │ │ │ rollout_data.observations, actions) │
│ 226 │ │ │ │ values = values.flatten() │
│ 227 │ │ │ │ # Normalize advantage │
│ │
│ /mnt/home/code/RL4LMs/rl4lms/envs/text_generation/policy.py:211 in │
│ evaluate_actions │
│ │
│ 208 │ │ │
│ 209 │ │ _, log_prob, entropy, _, _ = self.forward_policy(obs=obs, │
│ 210 │ │ │ │ │ │ │ │ │ │ │ │ │ │ actions=acti │
│ ❱ 211 │ │ values, _ = self.forward_value(obs) │
│ 212 │ │ │
│ 213 │ │ return values, log_prob, entropy │
│ 214 │
│ │
│ /mnt/home/code/RL4LMs/rl4lms/envs/text_generation/policy.py:447 in │
│ forward_value │
│ │
│ 444 │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │
│ 445 │ │ │
│ 446 │ │ # and forrward pass to get hidden states │
│ ❱ 447 │ │ outputs = self._value_model( │
│ 448 │ │ │ **model_inputs, │
│ 449 │ │ │ output_hidden_states=True, │
│ 450 │ │ │ decoder_attention_mask=decoder_attn_mask, │
│ │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py: │
│ 1110 in _call_impl │
│ │
│ 1107 │ │ # this function, and just call forward. │
│ 1108 │ │ if not (self._backward_hooks or self._forward_hooks or self._ │
│ 1109 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │ │ │ return forward_call(*input, **kwargs) │
│ 1111 │ │ # Do not call functions when jit is used │
│ 1112 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1113 │ │ if self._backward_hooks or _global_backward_hooks: │
│ │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ ling_t5.py:1648 in forward │
│ │
│ 1645 │ │ │ │ decoder_attention_mask = decoder_attention_mask.to(se │
│ 1646 │ │ │
│ 1647 │ │ # Decode │
│ ❱ 1648 │ │ decoder_outputs = self.decoder( │
│ 1649 │ │ │ input_ids=decoder_input_ids, │
│ 1650 │ │ │ attention_mask=decoder_attention_mask, │
│ 1651 │ │ │ inputs_embeds=decoder_inputs_embeds, │
│ │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py: │
│ 1110 in _call_impl │
│ │
│ 1107 │ │ # this function, and just call forward. │
│ 1108 │ │ if not (self._backward_hooks or self._forward_hooks or self._ │
│ 1109 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │ │ │ return forward_call(*input, **kwargs) │
│ 1111 │ │ # Do not call functions when jit is used │
│ 1112 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1113 │ │ if self._backward_hooks or _global_backward_hooks: │
│ │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ ling_t5.py:1040 in forward │
│ │
│ 1037 │ │ │ │ │ None, # past_key_value is always None with gradi │
│ 1038 │ │ │ │ ) │
│ 1039 │ │ │ else: │
│ ❱ 1040 │ │ │ │ layer_outputs = layer_module( │
│ 1041 │ │ │ │ │ hidden_states, │
│ 1042 │ │ │ │ │ attention_mask=extended_attention_mask, │
│ 1043 │ │ │ │ │ position_bias=position_bias, │
│ │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py: │
│ 1110 in _call_impl │
│ │
│ 1107 │ │ # this function, and just call forward. │
│ 1108 │ │ if not (self._backward_hooks or self._forward_hooks or self._ │
│ 1109 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │ │ │ return forward_call(*input, **kwargs) │
│ 1111 │ │ # Do not call functions when jit is used │
│ 1112 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1113 │ │ if self._backward_hooks or _global_backward_hooks: │
│ │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ ling_t5.py:699 in forward │
│ │
│ 696 │ │ │ else: │
│ 697 │ │ │ │ query_length = None │
│ 698 │ │ │ │
│ ❱ 699 │ │ │ cross_attention_outputs = self.layer[1]( │
│ 700 │ │ │ │ hidden_states, │
│ 701 │ │ │ │ key_value_states=encoder_hidden_states, │
│ 702 │ │ │ │ attention_mask=encoder_attention_mask, │
│ │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py: │
│ 1110 in _call_impl │
│ │
│ 1107 │ │ # this function, and just call forward. │
│ 1108 │ │ if not (self._backward_hooks or self._forward_hooks or self._ │
│ 1109 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │ │ │ return forward_call(*input, **kwargs) │
│ 1111 │ │ # Do not call functions when jit is used │
│ 1112 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1113 │ │ if self._backward_hooks or _global_backward_hooks: │
│ │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ ling_t5.py:613 in forward │
│ │
│ 610 │ │ output_attentions=False, │
│ 611 │ ): │
│ 612 │ │ normed_hidden_states = self.layer_norm(hidden_states) │
│ ❱ 613 │ │ attention_output = self.EncDecAttention( │
│ 614 │ │ │ normed_hidden_states, │
│ 615 │ │ │ mask=attention_mask, │
│ 616 │ │ │ key_value_states=key_value_states, │
│ │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py: │
│ 1110 in _call_impl │
│ │
│ 1107 │ │ # this function, and just call forward. │
│ 1108 │ │ if not (self._backward_hooks or self._forward_hooks or self._ │
│ 1109 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │ │ │ return forward_call(*input, **kwargs) │
│ 1111 │ │ # Do not call functions when jit is used │
│ 1112 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1113 │ │ if self._backward_hooks or _global_backward_hooks: │
│ │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ ling_t5.py:509 in forward │
│ │
│ 506 │ │ ) │
│ 507 │ │ │
│ 508 │ │ # compute scores │
│ ❱ 509 │ │ scores = torch.matmul( │
│ 510 │ │ │ query_states, key_states.transpose(3, 2) │
│ 511 │ │ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_stat │
│ 512 │
╰──────────────────────────────────────────────────────────────────────────────╯
RuntimeError: CUDA out of memory. Tried to allocate 150.00 MiB (GPU 0; 79.35 GiB
total capacity; 76.08 GiB already allocated; 108.19 MiB free; 77.20 GiB reserved
in total by PyTorch) If reserved memory is >> allocated memory try setting
max_split_size_mb to avoid fragmentation. See documentation for Memory
Management and PYTORCH_CUDA_ALLOC_CONF
Any clues as to what the issue is? 80GB seems like a lot for just a T5-base model.
Hey, this is expected, I think. Keep in mind we have three T5 models loaded into memory: policy, value, and reference policy. So far, we have used 4 GPUs to run the summarization tasks. If you have just one GPU, try reducing n_envs to a lower value and also reduce the PPO batch size. Otherwise, I would suggest running with more GPUs if possible.
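For reference, here is a minimal sketch of the kind of config overrides this amounts to. The config path and the exact key names (env.n_envs, alg.args.batch_size) are assumptions based on the repo's example configs, so double-check them against the YAML you are actually running:

import yaml

# Hypothetical path to the summarization PPO config; adjust to your checkout.
config_path = "scripts/training/task_configs/summarization/t5_ppo.yml"

with open(config_path) as f:
    config = yaml.safe_load(f)

# Fewer parallel rollout environments => smaller generation batches.
config["env"]["n_envs"] = 2

# Smaller PPO minibatches => less activation memory during the train step.
config["alg"]["args"]["batch_size"] = 16

with open(config_path, "w") as f:
    yaml.safe_dump(config, f)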
I see, thanks. I didn't expect there to be three full models. Are there any advantages vs. plugging three heads onto one language-model trunk?
Also, I'm just curious: does the n_envs parameter scale up the batch size? Why does it influence GPU memory?
Many thanks.
Yes, people seem to usually just use different heads.
@gabrielhuang Sure, shared layers would be parameter-efficient; I am not sure how much of a performance change it would bring, though.
Regarding n_envs, it controls the batch_size used for generating rollouts, so you can think of it as batched generation.
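To make the memory connection concrete, here is a rough illustration of batched rollout generation (the model name, prompt, and lengths are placeholders, not the actual RL4LMs generation code):

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

n_envs = 8  # rollouts are generated for all envs at once, so this is the generation batch size

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base").to("cuda")

prompts = ["summarize: some long article text ..."] * n_envs
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")

with torch.no_grad():
    # Activation memory during generation grows roughly linearly with n_envs.
    outputs = model.generate(**inputs, max_new_tokens=64)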
I'm trying to get google/flan-t5-xxl to run on a single 80GB A100 GPU as a seq2seq policy.
Is there already a way to set the precision to bfloat16? (I don't see one, but just to be sure.) If not, I'll write a policy for that.
I will also try to add model sharing between the policy and the value models, and to allow freezing parts of the model.
Enabling offloading of a model from GPU memory to CPU memory when it's not in use would likely be helpful too; a rough sketch of that idea is below.
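Something along these lines, purely as a sketch (this is not RL4LMs code; with_cpu_offload is a hypothetical helper, and batch is assumed to be a dict of tensors already on the GPU):

import torch

def with_cpu_offload(model, batch):
    # Move the model to GPU only for the forward pass, then park it back on CPU.
    model.to("cuda")
    with torch.no_grad():
        outputs = model(**batch)
    model.to("cpu")
    torch.cuda.empty_cache()  # release the cached blocks the model was occupying
    return outputs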
@gabrielhuang have you started doing work like this? (I'm also at Mila)
@JulesGM We don't have support for precision settings yet. You can implement this in a new policy, and possibly we can then try to merge it into the existing classes (by configuring some args).
This is my current approach, indeed: just allowing the user to pass kwargs for from_pretrained and Linear. Passing torch_dtype to from_pretrained and dtype to Linear works.
I suppose adding AMP mixed-precision autocasting and gradient scaling would be a good / important idea, though.
import copy
import sys

import torch
import transformers

sys.path.append("/home/mila/g/gagnonju/RL4LMs")

import rl4lms.envs.text_generation.registry as rl4lms_registry
import rl4lms.envs.text_generation.policy.seq2seq_policy as rl4lms_seq2seq_policy
from rl4lms.envs.text_generation import hf_generation_utils


class PrecisionControlSeq2SeqLMActorCriticPolicy(rl4lms_seq2seq_policy.Seq2SeqLMActorCriticPolicy):
    def __init__(
        self,
        *args,
        from_pretrained_kwargs,
        head_kwargs,
        **kwargs,
    ):
        # Extra kwargs forwarded to the value model's from_pretrained(...) and to the
        # value head's nn.Linear (e.g. torch_dtype / dtype for precision control).
        self._from_pretrained_kwargs = from_pretrained_kwargs
        self._head_kwargs = head_kwargs
        super().__init__(*args, **kwargs)

    def _build_model_heads(self, model_name: str):
        self._policy_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self._policy_model.__class__ = hf_generation_utils.override_generation_routines(
            type(self._policy_model)
        )

        self._value_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
            model_name, **self._from_pretrained_kwargs
        )
        self._ref_model = copy.deepcopy(self._policy_model).eval()

        self._value_head = torch.nn.Linear(
            self._value_model.config.hidden_size, 1, bias=False, **self._head_kwargs,
        ).to(self.device)

        # apply model parallel
        if torch.cuda.is_available():
            if self._apply_model_parallel and self._policy_model.is_parallelizable:
                self._policy_model.parallelize()
                self._ref_model.parallelize()
                self._value_model.parallelize()
                self._value_head = self._value_head.to(self.device)
            else:  # else defaults to data parallel
                self._policy_model = torch.nn.DataParallel(self._policy_model.to(self.device))
                self._ref_model = torch.nn.DataParallel(self._ref_model.to(self.device))
                self._value_model = torch.nn.DataParallel(self._value_model.to(self.device))
                self._value_head = torch.nn.DataParallel(self._value_head.to(self.device))


rl4lms_registry.PolicyRegistry.add(
    "precision_control_seq2seq_lm_actor_critic",
    PrecisionControlSeq2SeqLMActorCriticPolicy,
)
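For completeness, these are roughly the extra kwargs the policy is meant to receive (how they get threaded through the YAML config is not shown, so treat this as illustrative only):

import torch

policy_kwargs = dict(
    from_pretrained_kwargs={"torch_dtype": torch.bfloat16},  # load the value model weights in bf16
    head_kwargs={"dtype": torch.bfloat16},                   # keep the value head in the same dtype
)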
Looks like Stable Baselines 3 doesn't support bfloat16, because of all the a_tensor_name.cpu().numpy() calls. Indeed, doing that with a bfloat16 tensor raises an exception, because torch tries to build a NumPy array with the bfloat16 dtype, which NumPy does not support.
For Stable Baselines 3 (and then RL4LMs) to support bfloat16, it would suffice to change a_tensor_name.cpu().numpy() to a_tensor_name.cpu().float().numpy().
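A quick way to reproduce the failure and the fix in isolation:

import torch

t = torch.randn(4, dtype=torch.bfloat16)

# t.numpy() raises a TypeError here, because NumPy has no bfloat16 dtype.

x = t.cpu().float().numpy()  # upcast to float32 first, then convert
print(x.dtype)               # float32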
Hey, thanks for your solution!
I do have one question: is there a reason you didn't pass from_pretrained_kwargs to the initialization of _policy_model?
self._policy_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name)
@JulesGM
@JulesGM Hey, so I tried what you suggested, passing **{"torch_dtype": torch.float16} to from_pretrained and **{"dtype": torch.float16} to Linear, and I got the following error:
Traceback (most recent call last):
File "/home/nlp/sloboda1/anaconda3/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/nlp/sloboda1/anaconda3/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/nlp/sloboda1/.vscode-server/extensions/ms-python.python-2022.20.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/main.py", line 39, in
I even tried passing **{"torch_dtype": torch.float16} to the from_pretrained of self._policy_model and still got that error.
Is there anything else you converted to FP16, by any chance?
Yes, I made a bunch of other changes in the end.
May I ask if there is complete code with the changes that I could learn from?