
OOM on summarization example

Open gabrielhuang opened this issue 2 years ago • 15 comments

Hi there, I'm getting OOM errors when running the summarization example on an 80GB A100 (CUDA 11.8).

I'm also getting some TensorFlow/TensorRT warnings; I'm wondering whether that's related:

2022-11-08 22:44:46.878785: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-08 22:44:47.016183: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-11-08 22:44:47.979748: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2022-11-08 22:44:47.979824: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2022-11-08 22:44:47.979834: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

OOM error:

╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│                                                                              │
│ /mnt/home/code/RL4LMs/scripts/training/train_text_generation.py:66 in        │
│ <module>                                                                     │
│                                                                              │
│   63 │   │   │   │   │   │   help="Whether to use wandb logging")            │
│   64 │   args = parser.parse_args()                                          │
│   65 │                                                                       │
│ ❱ 66 │   main(args.config_path,                                              │
│   67 │   │    args.project_name,                                             │
│   68 │   │    args.experiment_name,                                          │
│   69 │   │    args.base_path_to_store_results,                               │
│ /mnt/home/code/RL4LMs/scripts/training/train_text_generation.py:42 in main   │
│                                                                              │
│   39 │   │   │   │   │   │   │   │     on_policy_alg_config=config["alg"],   │
│   40 │   │   │   │   │   │   │   │     train_eval_config=config["train_evalu │
│   41 │   │   │   │   │   │   │   │     tracker=tracker)                      │
│ ❱ 42 │   trainer.train_and_eval()                                            │
│   43                                                                         │
│   44                                                                         │
│   45 if __name__ == "__main__":                                              │
│                                                                              │
│ /mnt/home/code/RL4LMs/rl4lms/envs/text_generation/training_utils.py:205 in   │
│ train_and_eval                                                               │
│                                                                              │
│   202 │   │   │   self._trainer_state["current_iter"] = epoch                │
│   203 │   │   │                                                              │
│   204 │   │   │   # inner rollout and learn loop for on-policy algorithm     │
│ ❱ 205 │   │   │   self._alg.learn(self._n_steps_per_iter)                    │
│   206 │   │   │                                                              │
│   207 │   │   │   # save the policy checkpoint                               │
│   208 │   │   │   if (epoch + 1) % self._train_eval_config.get("save_every", │
│                                                                              │
│ /mnt/home/code/RL4LMs/rl4lms/algorithms/ppo/ppo.py:347 in learn              │
│                                                                              │
│   344 │   │   reset_num_timesteps: bool = True,                              │
│   345 │   ) -> "PPO":                                                        │
│   346 │   │                                                                  │
│ ❱ 347 │   │   return super().learn(                                          │
│   348 │   │   │   total_timesteps=total_timesteps,                           │
│   349 │   │   │   callback=callback,                                         │
│   350 │   │   │   log_interval=log_interval,                                 │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/stable_baselines3/common/on │
│ _policy_algorithm.py:267 in learn                                            │
│                                                                              │
│   264 │   │   │   │   self.logger.record("time/total_timesteps", self.num_ti │
│   265 │   │   │   │   self.logger.dump(step=self.num_timesteps)              │
│   266 │   │   │                                                              │
│ ❱ 267 │   │   │   self.train()                                               │
│   268 │   │                                                                  │
│   269 │   │   callback.on_training_end()                                     │
│   270                                                                        │
│                                                                              │
│ /mnt/home/code/RL4LMs/rl4lms/algorithms/ppo/ppo.py:224 in train              │
│                                                                              │
│   221 │   │   │   │   if self.use_sde:                                       │
│   222 │   │   │   │   │   self.policy.reset_noise(self.batch_size)           │
│   223 │   │   │   │                                                          │
│ ❱ 224 │   │   │   │   values, log_prob, entropy = self.policy.evaluate_actio │
│   225 │   │   │   │   │   rollout_data.observations, actions)                │
│   226 │   │   │   │   values = values.flatten()                              │
│   227 │   │   │   │   # Normalize advantage                                  │
│                                                                              │
│ /mnt/home/code/RL4LMs/rl4lms/envs/text_generation/policy.py:211 in           │
│ evaluate_actions                                                             │
│                                                                              │
│    208 │   │                                                                 │
│    209 │   │   _, log_prob, entropy, _, _ = self.forward_policy(obs=obs,     │
│    210 │   │   │   │   │   │   │   │   │   │   │   │   │   │    actions=acti │
│ ❱  211 │   │   values, _ = self.forward_value(obs)                           │
│    212 │   │                                                                 │
│    213 │   │   return values, log_prob, entropy                              │
│    214                                                                       │
│                                                                              │
│ /mnt/home/code/RL4LMs/rl4lms/envs/text_generation/policy.py:447 in           │
│ forward_value                                                                │
│                                                                              │
│    444 │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │     │
│    445 │   │                                                                 │
│    446 │   │   # and forrward pass to get hidden states                      │
│ ❱  447 │   │   outputs = self._value_model(                                  │
│    448 │   │   │   **model_inputs,                                           │
│    449 │   │   │   output_hidden_states=True,                                │
│    450 │   │   │   decoder_attention_mask=decoder_attn_mask,                 │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py: │
│ 1110 in _call_impl                                                           │
│                                                                              │
│   1107 │   │   # this function, and just call forward.                       │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                     │
│   1111 │   │   # Do not call functions when jit is used                      │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ ling_t5.py:1648 in forward                                                   │
│                                                                              │
│   1645 │   │   │   │   decoder_attention_mask = decoder_attention_mask.to(se │
│   1646 │   │                                                                 │
│   1647 │   │   # Decode                                                      │
│ ❱ 1648 │   │   decoder_outputs = self.decoder(                               │
│   1649 │   │   │   input_ids=decoder_input_ids,                              │
│   1650 │   │   │   attention_mask=decoder_attention_mask,                    │
│   1651 │   │   │   inputs_embeds=decoder_inputs_embeds,                      │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py: │
│ 1110 in _call_impl                                                           │
│                                                                              │
│   1107 │   │   # this function, and just call forward.                       │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                     │
│   1111 │   │   # Do not call functions when jit is used                      │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ ling_t5.py:1040 in forward                                                   │
│                                                                              │
│   1037 │   │   │   │   │   None,  # past_key_value is always None with gradi │
│   1038 │   │   │   │   )                                                     │
│   1039 │   │   │   else:                                                     │
│ ❱ 1040 │   │   │   │   layer_outputs = layer_module(                         │
│   1041 │   │   │   │   │   hidden_states,                                    │
│   1042 │   │   │   │   │   attention_mask=extended_attention_mask,           │
│   1043 │   │   │   │   │   position_bias=position_bias,                      │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py: │
│ 1110 in _call_impl                                                           │
│                                                                              │
│   1107 │   │   # this function, and just call forward.                       │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                     │
│   1111 │   │   # Do not call functions when jit is used                      │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ ling_t5.py:699 in forward                                                    │
│                                                                              │
│    696 │   │   │   else:                                                     │
│    697 │   │   │   │   query_length = None                                   │
│    698 │   │   │                                                             │
│ ❱  699 │   │   │   cross_attention_outputs = self.layer[1](                  │
│    700 │   │   │   │   hidden_states,                                        │
│    701 │   │   │   │   key_value_states=encoder_hidden_states,               │
│    702 │   │   │   │   attention_mask=encoder_attention_mask,                │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py: │
│ 1110 in _call_impl                                                           │
│                                                                              │
│   1107 │   │   # this function, and just call forward.                       │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                     │
│   1111 │   │   # Do not call functions when jit is used                      │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ ling_t5.py:613 in forward                                                    │
│                                                                              │
│    610 │   │   output_attentions=False,                                      │
│    611 │   ):                                                                │
│    612 │   │   normed_hidden_states = self.layer_norm(hidden_states)         │
│ ❱  613 │   │   attention_output = self.EncDecAttention(                      │
│    614 │   │   │   normed_hidden_states,                                     │
│    615 │   │   │   mask=attention_mask,                                      │
│    616 │   │   │   key_value_states=key_value_states,                        │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py: │
│ 1110 in _call_impl                                                           │
│                                                                              │
│   1107 │   │   # this function, and just call forward.                       │
│   1108 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1109 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1110 │   │   │   return forward_call(*input, **kwargs)                     │
│   1111 │   │   # Do not call functions when jit is used                      │
│   1112 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1113 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /mnt/home/miniconda3/lib/python3.9/site-packages/transformers/models/t5/mode │
│ ling_t5.py:509 in forward                                                    │
│                                                                              │
│    506 │   │   )                                                             │
│    507 │   │                                                                 │
│    508 │   │   # compute scores                                              │
│ ❱  509 │   │   scores = torch.matmul(                                        │
│    510 │   │   │   query_states, key_states.transpose(3, 2)                  │
│    511 │   │   )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_stat │
│    512                                                                       │
╰──────────────────────────────────────────────────────────────────────────────╯
RuntimeError: CUDA out of memory. Tried to allocate 150.00 MiB (GPU 0; 79.35 GiB
total capacity; 76.08 GiB already allocated; 108.19 MiB free; 77.20 GiB reserved
in total by PyTorch) If reserved memory is >> allocated memory try setting
max_split_size_mb to avoid fragmentation.  See documentation for Memory
Management and PYTORCH_CUDA_ALLOC_CONF

Any clues as to what the issue is? 80GB seems like a lot for just a T5-base model.

gabrielhuang avatar Nov 08 '22 23:11 gabrielhuang

Hey, this is expected, I think. Keep in mind we have three T5 models loaded into memory: policy, value, and reference policy. So far, we have used 4 GPUs to run summarization tasks. If you have just one GPU, try reducing n_envs to a lower value and also reduce the PPO batch size. Otherwise, I would suggest running with more GPUs if possible.
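
For a single GPU, something like the following can be used as a starting point. This is only a minimal sketch: it assumes the task config exposes env.n_envs and alg.args.batch_size (as in the shipped summarization YAML configs), so the exact key names and file names may differ in your setup.

import yaml

# Hypothetical helper (not part of RL4LMs): load a task config, shrink the
# rollout and PPO batch sizes, and write it back out so the new file can be
# passed to scripts/training/train_text_generation.py via --config_path.
def shrink_config(in_path: str, out_path: str, n_envs: int = 2, batch_size: int = 8) -> None:
    with open(in_path) as f:
        config = yaml.safe_load(f)
    config["env"]["n_envs"] = n_envs                  # fewer parallel rollout envs
    config["alg"]["args"]["batch_size"] = batch_size  # smaller PPO minibatches
    with open(out_path, "w") as f:
        yaml.safe_dump(config, f)

shrink_config("t5_ppo.yml", "t5_ppo_small.yml")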

rajcscw avatar Nov 09 '22 08:11 rajcscw

I see, thanks. I didn't expect there to be three full models. Are there any advantages over plugging three heads onto one language model trunk?

Also, just out of curiosity, does the n_envs parameter scale up the batch size? Why does it influence GPU memory?

Many thanks.

gabrielhuang avatar Nov 09 '22 23:11 gabrielhuang

Yes, people usually seem to just use different heads on a shared trunk.

JulesGM avatar Nov 10 '22 06:11 JulesGM

@gabrielhuang Sure, shared layers would be more parameter-efficient. I am not sure how much of a performance change it would bring.

Regarding n_envs, it controls the batch size for generating rollouts; you can think of it as batched generation.
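
For what it's worth, the shared-trunk variant mentioned above would look roughly like this. It is a minimal sketch, not RL4LMs code; the SharedTrunkActorCritic class and its forward signature are made up for illustration.

import torch
import transformers


class SharedTrunkActorCritic(torch.nn.Module):
    """One seq2seq trunk: the LM head acts as the policy, plus a scalar value head."""

    def __init__(self, model_name: str = "t5-base"):
        super().__init__()
        self.trunk = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.value_head = torch.nn.Linear(self.trunk.config.d_model, 1, bias=False)

    def forward(self, input_ids, attention_mask, decoder_input_ids):
        outputs = self.trunk(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,
        )
        # Policy logits come from the shared LM head.
        logits = outputs.logits
        # Value estimate from the final token's last decoder hidden state.
        last_hidden = outputs.decoder_hidden_states[-1]  # (batch, seq_len, d_model)
        values = self.value_head(last_hidden[:, -1, :]).squeeze(-1)
        return logits, values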

rajcscw avatar Nov 11 '22 14:11 rajcscw

I'm trying to get google/flan-t5-xxl to run on a single 80GB A100 GPU with the seq2seq policy.

Is there already a way to set the precision to bfloat16? (I don't see one, but just to be sure.) If not, I'll write a policy for that.

Also, I will try to add model sharing between the policy and value models, and support for freezing parts of the model.

JulesGM avatar Nov 29 '22 23:11 JulesGM

Enabling offloading a model from GPU memory to CPU memory when it's not in use would likely be helpful too.

JulesGM avatar Nov 29 '22 23:11 JulesGM

@gabrielhuang have you started doing work like this? (I'm also at Mila)

JulesGM avatar Nov 29 '22 23:11 JulesGM

@JulesGM We don't have support for precision settings yet. You can implement this in a new policy, and possibly we can then try to merge it into the existing classes (by configuring some args).

rajcscw avatar Dec 01 '22 09:12 rajcscw

This is my current approach: just allowing the user to pass kwargs for from_pretrained and Linear. Passing torch_dtype to from_pretrained and dtype to Linear works.

I suppose adding AMP mixed-precision auto-casting and gradient scaling would be a good/important addition too, though (sketched after the code below).


import copy
import sys

import torch
import transformers

sys.path.append("/home/mila/g/gagnonju/RL4LMs")
import rl4lms.envs.text_generation.registry as rl4lms_registry
import rl4lms.envs.text_generation.policy.seq2seq_policy as rl4lms_seq2seq_policy
from rl4lms.envs.text_generation import hf_generation_utils 


class PrecisionControlSeq2SeqLMActorCriticPolicy(rl4lms_seq2seq_policy.Seq2SeqLMActorCriticPolicy):
    def __init__(
        self,
        *args,
        from_pretrained_kwargs,
        head_kwargs,
        **kwargs,
    ):
        
        self._from_pretrained_kwargs = from_pretrained_kwargs
        self._head_kwargs = head_kwargs

        super().__init__(*args, **kwargs)

    def _build_model_heads(self, model_name: str):
        # Policy model is loaded with default kwargs here (see the question
        # further down about also passing from_pretrained_kwargs to it).
        self._policy_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self._policy_model.__class__ = hf_generation_utils.override_generation_routines(
            type(self._policy_model)
        )

        # The user-supplied from_pretrained kwargs (e.g. torch_dtype) are applied
        # to the value model; the reference policy is a frozen copy of the policy.
        self._value_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
            model_name, **self._from_pretrained_kwargs)
        self._ref_model = copy.deepcopy(self._policy_model).eval()

        # Scalar value head; head_kwargs can set e.g. its dtype.
        self._value_head = torch.nn.Linear(
            self._value_model.config.hidden_size, 1, bias=False, **self._head_kwargs,
        ).to(self.device)

        # apply model parallel
        if torch.cuda.is_available():
            if self._apply_model_parallel and self._policy_model.is_parallelizable:
                self._policy_model.parallelize()
                self._ref_model.parallelize()
                self._value_model.parallelize()
                self._value_head = self._value_head.to(self.device)
            else:  # else defaults to data parallel
                self._policy_model = torch.nn.DataParallel(self._policy_model.to(self.device))
                self._ref_model    = torch.nn.DataParallel(self._ref_model   .to(self.device))
                self._value_model  = torch.nn.DataParallel(self._value_model .to(self.device))
                self._value_head   = torch.nn.DataParallel(self._value_head  .to(self.device))



rl4lms_registry.PolicyRegistry.add(
    "precision_control_seq2seq_lm_actor_critic",
    PrecisionControlSeq2SeqLMActorCriticPolicy,
)
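
For the AMP route mentioned above, the usual pattern is autocast around the forward/loss computation and a GradScaler around the backward step. This is a minimal sketch; policy, optimizer, batch, and compute_ppo_loss are placeholders, not RL4LMs APIs.

import torch

scaler = torch.cuda.amp.GradScaler()

def train_step(policy, optimizer, batch, compute_ppo_loss):
    optimizer.zero_grad()
    # Forward pass and loss run under autocast (fp16 on CUDA by default).
    with torch.cuda.amp.autocast():
        loss = compute_ppo_loss(policy, batch)
    # Scale the loss so fp16 gradients don't underflow, then unscale and step.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()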

JulesGM avatar Dec 01 '22 20:12 JulesGM

It looks like Stable Baselines 3 doesn't support bfloat16, because of all the a_tensor_name.cpu().numpy() calls. Doing that with a bfloat16 tensor raises an exception, because torch tries to build a NumPy array with the bfloat16 dtype, which NumPy does not support.

JulesGM avatar Dec 03 '22 21:12 JulesGM

In order for Stable Baselines 3 (and then RL4LMs) to support bfloat16, it would suffice to change a_tensor_name.cpu().numpy() to a_tensor_name.cpu().float().numpy().
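
A quick way to see the failure and the workaround (plain PyTorch, nothing RL4LMs-specific):

import torch

t = torch.ones(3, dtype=torch.bfloat16)

try:
    t.cpu().numpy()          # what stable-baselines3 effectively does
except TypeError as err:
    print(err)               # NumPy has no bfloat16 dtype, so this raises

print(t.cpu().float().numpy())  # upcast first -> array([1., 1., 1.], dtype=float32)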

JulesGM avatar Dec 03 '22 21:12 JulesGM

Hey, thanks for your solution! I do have one question: is there a reason you didn't pass the from_pretrained_kwargs to the initialization of _policy_model?

self._policy_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name)

@JulesGM

lovodkin93 avatar Dec 29 '22 10:12 lovodkin93

@JulesGM Hey, so I tried what you suggested, passing **{"torch_dtype": torch.float16} to from_pretrained and **{"dtype": torch.float16} to Linear, and I got the following error:

Traceback (most recent call last):
  File "/home/nlp/sloboda1/anaconda3/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/nlp/sloboda1/anaconda3/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/nlp/sloboda1/.vscode-server/extensions/ms-python.python-2022.20.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/nlp/sloboda1/.vscode-server/extensions/ms-python.python-2022.20.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/nlp/sloboda1/.vscode-server/extensions/ms-python.python-2022.20.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/nlp/sloboda1/.vscode-server/extensions/ms-python.python-2022.20.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/nlp/sloboda1/.vscode-server/extensions/ms-python.python-2022.20.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/nlp/sloboda1/.vscode-server/extensions/ms-python.python-2022.20.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/home/nlp/sloboda1/controlled_reduction/DL_approach/RL4LMs/scripts/training/train_text_generation.py", line 93, in <module>
    main(
  File "/home/nlp/sloboda1/controlled_reduction/DL_approach/RL4LMs/scripts/training/train_text_generation.py", line 64, in main
    trainer.train_and_eval()
  File "/home/nlp/sloboda1/controlled_reduction/DL_approach/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 214, in train_and_eval
    self._alg.learn(self._n_steps_per_iter)
  File "/home/nlp/sloboda1/controlled_reduction/DL_approach/RL4LMs/rl4lms/algorithms/ppo/ppo.py", line 341, in learn
    return super().learn(
  File "/home/nlp/sloboda1/controlled_reduction/DL_approach/RL4LMs/venvs/RL4LMs_venv/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 267, in learn
    self.train()
  File "/home/nlp/sloboda1/controlled_reduction/DL_approach/RL4LMs/rl4lms/algorithms/ppo/ppo.py", line 288, in train
    loss.backward()
  File "/home/nlp/sloboda1/controlled_reduction/DL_approach/RL4LMs/venvs/RL4LMs_venv/lib/python3.9/site-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/nlp/sloboda1/controlled_reduction/DL_approach/RL4LMs/venvs/RL4LMs_venv/lib/python3.9/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Found dtype Float but expected Half

I even tried passing **{"torch_dtype": torch.float16} to the from_pretrained of self._policy_model and still got that error.

Is there anything else you converted to FP16, by any chance?

lovodkin93 avatar Dec 29 '22 11:12 lovodkin93

Yes, I made a bunch of other changes in the end.

JulesGM avatar Jan 03 '23 15:01 JulesGM

May I ask if there is complete code with the changes that I could learn from?

CathyKitten avatar Dec 16 '23 11:12 CathyKitten