Error with PPO training about hidden state in-place modification
Code to reproduce:
import trl
from unsloth import FastLanguageModel
import torch
from tqdm import tqdm
from transformers import AutoTokenizer
from datasets import load_dataset
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
config = PPOConfig(
    model_name="mistralai/Mistral-7B-Instruct-v0.2",
    learning_rate=1.41e-5,
)
LOG_FREQ = 32
LORA_ALPHA = 32
LORA_RANK = 32
def build_dataset(config, dataset_name="imdb", input_min_text_length=2, input_max_text_length=8):
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # load imdb with datasets
    ds = load_dataset(dataset_name, split="train[:10%]")
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)
    input_size = LengthSampler(input_min_text_length, input_max_text_length)
    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample
    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds
dataset = build_dataset(config)
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token
def get_unsloth_model(base_model_name):
    model, _ = FastLanguageModel.from_pretrained(
        model_name=base_model_name,
        max_seq_length=2048,
        load_in_4bit=True,
        device_map="auto",
        trust_remote_code=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        target_modules=[
            "q_proj",
            "v_proj",
            "k_proj",
            "o_proj",  # attention (self_attn)
            "gate_proj",
            "down_proj",
            "up_proj",  # FFN (mlp)
        ],
        r=LORA_RANK,
        lora_alpha=LORA_ALPHA,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing=False,
    )
    trl.trainer.peft_module_casting_to_bf16(model)
    return model
model = AutoModelForCausalLMWithValueHead.from_pretrained(get_unsloth_model(config.model_name))
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(get_unsloth_model(config.model_name))
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]
    response_tensors = [ppo_trainer.generate(x, **generation_kwargs).squeeze() for x in query_tensors]
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
    rewards = [torch.tensor(0.5) for _ in range(len(response_tensors))]
    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
Output:
Unsloth: You passed in `mistralai/Mistral-7B-Instruct-v0.2` and `load_in_4bit = True`.
We shall load `unsloth/mistral-7b-instruct-v0.2-bnb-4bit` for 4x faster loading.
==((====))==  Unsloth: Fast Mistral patching release 2024.3
   \\   /|    GPU: NVIDIA GeForce RTX 3090. Max memory: 23.681 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2+cu121. CUDA = 8.6. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. Xformers = 0.0.25.post1. FA = True.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
WARNING:root:The `device_map` argument is not provided. We will override the device_map argument. to set the entire model on the current device. If you want to set the model on multiple devices, please provide a custom `device_map` argument.
==((====))==  Unsloth: Fast Mistral patching release 2024.3
   \\   /|    GPU: NVIDIA GeForce RTX 3090. Max memory: 23.681 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2+cu121. CUDA = 8.6. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. Xformers = 0.0.25.post1. FA = True.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
WARNING:root:The `device_map` argument is not provided. We will override the device_map argument. to set the entire model on the current device. If you want to set the model on multiple devices, please provide a custom `device_map` argument.
0it [00:55, ?it/s]
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 99
     96 rewards = [torch.tensor(0.5) for _ in range(len(response_tensors))]
     98 #### Run PPO step
---> 99 stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    100 ppo_trainer.log_stats(stats, batch, rewards)
File /usr/lib/python3.11/contextlib.py:81, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     78 @wraps(func)
     79 def inner(*args, **kwds):
     80     with self._recreate_cm():
---> 81         return func(*args, **kwds)
File ~/code/user/.venv/lib/python3.11/site-packages/trl/trainer/ppo_trainer.py:788, in PPOTrainer.step(self, queries, responses, scores, response_masks)
    785 with self.accelerator.accumulate(self.model):
    786     model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}
--> 788     logprobs, logits, vpreds, _ = self.batched_forward_pass(
    789         self.model,
    790         mini_batch_dict[\"queries\"],
    791         mini_batch_dict[\"responses\"],
    792         model_inputs,
    793         return_logits=True,
    794     )
    795     train_stats = self.train_minibatch(
    796         mini_batch_dict["logprobs"],
    797         mini_batch_dict["values"],
   (...)
    803         mini_batch_dict["returns"],
    804     )
    805     all_stats.append(train_stats)
File /usr/lib/python3.11/contextlib.py:81, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     78 @wraps(func)
     79 def inner(*args, **kwds):
     80     with self._recreate_cm():
---> 81         return func(*args, **kwds)
File ~/code/user/.venv/lib/python3.11/site-packages/trl/trainer/ppo_trainer.py:984, in PPOTrainer.batched_forward_pass(self, model, queries, responses, model_inputs, return_logits, response_masks)
    982 if response_masks is not None:
    983     response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
--> 984 logits, _, values = model(**input_kwargs)
    985 #logits, _, values = model(**input_kwargs, use_cache=False)
    987 if self.is_encoder_decoder:
File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)
File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1561, in Module._call_impl(self, *args, **kwargs)
   1558     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1559     args = bw_hook.setup_input_hook(args)
-> 1561 result = forward_call(*args, **kwargs)
   1562 if _global_forward_hooks or self._forward_hooks:
   1563     for hook_id, hook in (
   1564         *_global_forward_hooks.items(),
   1565         *self._forward_hooks.items(),
   1566     ):
   1567         # mark that always called hook is run
File ~/code/user/.venv/lib/python3.11/site-packages/trl/models/modeling_value_head.py:170, in AutoModelForCausalLMWithValueHead.forward(self, input_ids, past_key_values, attention_mask, **kwargs)
    167 if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
    168     kwargs.pop("past_key_values")
--> 170 base_model_output = self.pretrained_model(
    171     input_ids=input_ids,
    172     attention_mask=attention_mask,
    173     **kwargs,
    174 )
    176 last_hidden_state = base_model_output.hidden_states[-1]
    177 lm_logits = base_model_output.logits
File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)
File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None
File ~/code/user/.venv/lib/python3.11/site-packages/unsloth/models/llama.py:857, in PeftModelForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
    844 def PeftModelForCausalLM_fast_forward(
    845     self,
    846     input_ids=None,
   (...)
    855     **kwargs,
    856 ):
--> 857     return self.base_model(
    858         input_ids=input_ids,
    859         causal_mask=causal_mask,
    860         attention_mask=attention_mask,
    861         inputs_embeds=inputs_embeds,
    862         labels=labels,
    863         output_attentions=output_attentions,
    864         output_hidden_states=output_hidden_states,
    865         return_dict=return_dict,
    866         **kwargs,
    867     )
File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)
File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None
File ~/code/user/.venv/lib/python3.11/site-packages/peft/tuners/tuners_utils.py:161, in BaseTuner.forward(self, *args, **kwargs)
    160 def forward(self, *args: Any, **kwargs: Any):
--> 161     return self.model.forward(*args, **kwargs)
File ~/code/user/.venv/lib/python3.11/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)
File ~/code/user/.venv/lib/python3.11/site-packages/unsloth/models/mistral.py:212, in MistralForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs)
    204     outputs = LlamaModel_fast_forward_inference(
    205         self,
    206         input_ids,
   (...)
    209         attention_mask = attention_mask,
    210     )
    211 else:
--> 212     outputs = self.model(
    213         input_ids=input_ids,
    214         causal_mask=causal_mask,
    215         attention_mask=attention_mask,
    216         position_ids=position_ids,
    217         past_key_values=past_key_values,
    218         inputs_embeds=inputs_embeds,
    219         use_cache=use_cache,
    220         output_attentions=output_attentions,
    221         output_hidden_states=output_hidden_states,
    222         return_dict=return_dict,
    223     )
    224 pass
    226 hidden_states = outputs[0]
File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)
File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None
File ~/code/user/.venv/lib/python3.11/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)
File ~/code/user/.venv/lib/python3.11/site-packages/unsloth/models/llama.py:655, in LlamaModel_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs)
    645     layer_outputs = torch.utils.checkpoint.checkpoint(
    646         create_custom_forward(decoder_layer),
    647         hidden_states,
   (...)
    652         preserve_rng_state=False,
    653     )
    654 else:
--> 655     layer_outputs = decoder_layer(
    656         hidden_states,
    657         causal_mask=causal_mask,
    658         attention_mask=attention_mask,
    659         position_ids=position_ids,
    660         past_key_value=past_key_value,
    661         output_attentions=output_attentions,
    662         use_cache=use_cache,
    663         padding_mask=padding_mask,
    664     )
    665 pass
    667 hidden_states = layer_outputs[0]
File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)
File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None
File ~/code/user/.venv/lib/python3.11/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)
File ~/code/user/.venv/lib/python3.11/site-packages/unsloth/models/llama.py:423, in LlamaDecoderLayer_fast_forward(self, hidden_states, causal_mask, attention_mask, position_ids, past_key_value, output_attentions, use_cache, padding_mask, *args, **kwargs)
    412 hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
    413 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    414     hidden_states=hidden_states,
    415     causal_mask=causal_mask,
   (...)
    421     padding_mask=padding_mask,
    422 )
--> 423 hidden_states += residual
    425 # Fully Connected
    426 residual = hidden_states
RuntimeError: Output 0 of LoRA_WBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.
@Robinysh Thanks for the report - I'll take a look - sorry for the delayed response!
Hi @danielhanchen, the same thing happened when I ran a custom PyTorch training loop with a Gemma model.
To be specific, it fails at line 90 of gemma.py.
@Robinysh @shizheng-rlfresh OOHHH I actually never tried PPO, but since it also generates on the fly, that's likely where the inplace issue comes from.
Could you please try wrapping your generate call in this context manager and report the results? It might still fail on backprop, but if it doesn't, that would be helpful to know.
class fast_eval_mode:
    """
    Convert to model.eval(), then revert to previous state
    Behavior
    - DOESNT disable grad
    - Disable dropout layers
    - Freeze BatchNorm
    """
    def __init__(self, model):
        self.model = model
    def __enter__(self):
        self.was_training = self.model.training
        if self.was_training:
            self.model.eval()
    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.was_training:
            self.model.train()
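For example, a possible usage sketch in the reproduction loop above (untested; whether this alone avoids the error is exactly what needs checking):
# Untested sketch: wrap only the generation step with fast_eval_mode.
# `model`, `ppo_trainer`, `generation_kwargs`, and `query_tensors` all come from
# the reproduction script at the top of this issue.
with fast_eval_mode(model):
    response_tensors = [ppo_trainer.generate(x, **generation_kwargs).squeeze() for x in query_tensors]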
I encountered the same error:
RuntimeError: Output 0 of LoRA_WBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.
I tried the approach you provided, but still encounter the same error.
@haosdent here's how I fixed the same error in my PPO code. I think this solution will work for trl's PPOTrainer too:
Wrap the batched_forward_pass call in ppo_trainer.py (around line 798, shown below) with this context manager:
from contextlib import ContextDecorator
from typing import Any

class disable_caching(ContextDecorator):  # noqa: N801
    def __init__(self, model):
        self.model = model
        self.prev_value: Any = "UNSET"  # config values may be T/F/None
    def __enter__(self):
        self.prev_value = self.model.config.use_cache
        self.model.config.use_cache = False
    def __exit__(self, *exc):
        if self.prev_value != "UNSET":
            self.model.config.use_cache = self.prev_value
        self.prev_value = "UNSET"
# @line 798 of ppo_trainer.py
                    with disable_caching(self.model):
                        logprobs, logits, vpreds, _ = self.batched_forward_pass(
                            self.model,
                            mini_batch_dict["queries"],
                            mini_batch_dict["responses"],
                            model_inputs,
                            return_logits=True,
                        )
I believe the model config defaults to use_cache=True (which is great for sampling from the policy), but we need to disable it when doing the forward pass for the loss.
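For completeness, a sketch of applying the same idea from user code without patching trl (untested; assumes the value-head wrapper exposes .config the same way self.model does inside the trainer):
# Untested sketch: turn the KV cache off around the whole PPO step instead of
# editing ppo_trainer.py. `model`, `ppo_trainer`, `query_tensors`, `response_tensors`,
# and `rewards` are the same objects as in the reproduction script above.
# The ref_model is left untouched, which should be fine since its forward pass
# runs without grad.
with disable_caching(model):
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)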
Cool, thanks a lot for your solution @sidnarayanan. It works like a charm.
Also, I had to call trl.trainer.peft_module_casting_to_bf16(model) to avoid the bfloat error.
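For reference, a minimal sketch of where that cast can go, assuming the same setup as the reproduction script above (the exact placement is a guess, not confirmed in this thread):
# Sketch (placement assumed): cast the PEFT model before wrapping it with the value head.
peft_model = get_unsloth_model(config.model_name)
trl.trainer.peft_module_casting_to_bf16(peft_model)
model = AutoModelForCausalLMWithValueHead.from_pretrained(peft_model)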
Oh if any of you are willing to do a PR to fix the issue, that'll be awesome :) Thanks again!
You mean sending a PR to trl, right? @danielhanchen
Ye, if possible - another way is to inject it directly via Unsloth.
After checking the code, I feel maybe we should update Unsloth instead; I'll get back to you later. @danielhanchen