unsloth icon indicating copy to clipboard operation
unsloth copied to clipboard

Error with PPO training about hidden state in-place modification

Open Robinysh opened this issue 1 year ago • 12 comments

Code to reproduce:

import trl
from unsloth import FastLanguageModel
import torch
from tqdm import tqdm

from transformers import AutoTokenizer
from datasets import load_dataset

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

config = PPOConfig(
    model_name="mistralai/Mistral-7B-Instruct-v0.2",
    learning_rate=1.41e-5,
)

LOG_FREQ = 32
LORA_ALPHA = 32
LORA_RANK = 32

def build_dataset(config, dataset_name="imdb", input_min_text_length=2, input_max_text_length=8):
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # load imdb with datasets
    ds = load_dataset(dataset_name, split="train[:10%]")
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds

dataset = build_dataset(config)


def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

def get_unsloth_model(base_model_name):
    model, _ = FastLanguageModel.from_pretrained(
        model_name=base_model_name,
        max_seq_length=2048,
        load_in_4bit=True,
        device_map="auto",
        trust_remote_code=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        target_modules=[
            "q_proj",
            "v_proj",
            "k_proj",
            "o_proj",  # attention (self_attn)
            "gate_proj",
            "down_proj",
            "up_proj",  # FFN (mlp)
        ],
        r=LORA_RANK,
        lora_alpha=LORA_ALPHA,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing=False,
    )
    trl.trainer.peft_module_casting_to_bf16(model)
    return model

model = AutoModelForCausalLMWithValueHead.from_pretrained(get_unsloth_model(config.model_name))
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(get_unsloth_model(config.model_name))
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)

generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}


for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    response_tensors = [ppo_trainer.generate(x, **generation_kwargs).squeeze() for x in query_tensors]
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

    rewards = [torch.tensor(0.5) for _ in range(len(response_tensors))]

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

Output:

Unsloth: You passed in `mistralai/Mistral-7B-Instruct-v0.2` and `load_in_4bit = True`.
We shall load `unsloth/mistral-7b-instruct-v0.2-bnb-4bit` for 4x faster loading.
==((====))==  Unsloth: Fast Mistral patching release 2024.3
   \\   [/](https://vscode-remote+ssh-002dremote-002bdesktop.vscode-resource.vscode-cdn.net/)|    GPU: NVIDIA GeForce RTX 3090. Max memory: 23.681 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2+cu121. CUDA = 8.6. CUDA Toolkit = 12.1.
\        [/](https://vscode-remote+ssh-002dremote-002bdesktop.vscode-resource.vscode-cdn.net/)    Bfloat16 = TRUE. Xformers = 0.0.25.post1. FA = True.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
WARNING:root:The `device_map` argument is not provided. We will override the device_map argument. to set the entire model on the current device. If you want to set the model on multiple devices, please provide a custom `device_map` argument.
==((====))==  Unsloth: Fast Mistral patching release 2024.3
   \\   [/](https://vscode-remote+ssh-002dremote-002bdesktop.vscode-resource.vscode-cdn.net/)|    GPU: NVIDIA GeForce RTX 3090. Max memory: 23.681 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2+cu121. CUDA = 8.6. CUDA Toolkit = 12.1.
\        [/](https://vscode-remote+ssh-002dremote-002bdesktop.vscode-resource.vscode-cdn.net/)    Bfloat16 = TRUE. Xformers = 0.0.25.post1. FA = True.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
WARNING:root:The `device_map` argument is not provided. We will override the device_map argument. to set the entire model on the current device. If you want to set the model on multiple devices, please provide a custom `device_map` argument.
0it [00:55, ?it/s]

RuntimeError                              Traceback (most recent call last)
Cell In[1], line 99
     96 rewards = [torch.tensor(0.5) for _ in range(len(response_tensors))]
     98 #### Run PPO step
---> 99 stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    100 ppo_trainer.log_stats(stats, batch, rewards)

File /usr/lib/python3.11/contextlib.py:81, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     78 @wraps(func)
     79 def inner(*args, **kwds):
     80     with self._recreate_cm():
---> 81         return func(*args, **kwds)

File ~/code/user/.venv/lib/python3.11/site-packages/trl/trainer/ppo_trainer.py:788, in PPOTrainer.step(self, queries, responses, scores, response_masks)
    785 with self.accelerator.accumulate(self.model):
    786     model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}
--> 788     logprobs, logits, vpreds, _ = self.batched_forward_pass(
    789         self.model,
    790         mini_batch_dict[\"queries\"],
    791         mini_batch_dict[\"responses\"],
    792         model_inputs,
    793         return_logits=True,
    794     )
    795     train_stats = self.train_minibatch(
    796         mini_batch_dict[\"logprobs\"],
    797         mini_batch_dict[\"values\"],
   (...)
    803         mini_batch_dict[\"returns\"],
    804     )
    805     all_stats.append(train_stats)

File /usr/lib/python3.11/contextlib.py:81, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     78 @wraps(func)
     79 def inner(*args, **kwds):
     80     with self._recreate_cm():
---> 81         return func(*args, **kwds)

File ~/code/user/.venv/lib/python3.11/site-packages/trl/trainer/ppo_trainer.py:984, in PPOTrainer.batched_forward_pass(self, model, queries, responses, model_inputs, return_logits, response_masks)
    982 if response_masks is not None:
    983     response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
--> 984 logits, _, values = model(**input_kwargs)
    985 #logits, _, values = model(**input_kwargs, use_cache=False)
    987 if self.is_encoder_decoder:

File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1561, in Module._call_impl(self, *args, **kwargs)
   1558     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1559     args = bw_hook.setup_input_hook(args)
-> 1561 result = forward_call(*args, **kwargs)
   1562 if _global_forward_hooks or self._forward_hooks:
   1563     for hook_id, hook in (
   1564         *_global_forward_hooks.items(),
   1565         *self._forward_hooks.items(),
   1566     ):
   1567         # mark that always called hook is run

File ~/code/user/.venv/lib/python3.11/site-packages/trl/models/modeling_value_head.py:170, in AutoModelForCausalLMWithValueHead.forward(self, input_ids, past_key_values, attention_mask, **kwargs)
    167 if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == \"PREFIX_TUNING\":
    168     kwargs.pop(\"past_key_values\")
--> 170 base_model_output = self.pretrained_model(
    171     input_ids=input_ids,
    172     attention_mask=attention_mask,
    173     **kwargs,
    174 )
    176 last_hidden_state = base_model_output.hidden_states[-1]
    177 lm_logits = base_model_output.logits

File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/code/user/.venv/lib/python3.11/site-packages/unsloth/models/llama.py:857, in PeftModelForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
    844 def PeftModelForCausalLM_fast_forward(
    845     self,
    846     input_ids=None,
   (...)
    855     **kwargs,
    856 ):
--> 857     return self.base_model(
    858         input_ids=input_ids,
    859         causal_mask=causal_mask,
    860         attention_mask=attention_mask,
    861         inputs_embeds=inputs_embeds,
    862         labels=labels,
    863         output_attentions=output_attentions,
    864         output_hidden_states=output_hidden_states,
    865         return_dict=return_dict,
    866         **kwargs,
    867     )

File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/code/user/.venv/lib/python3.11/site-packages/peft/tuners/tuners_utils.py:161, in BaseTuner.forward(self, *args, **kwargs)
    160 def forward(self, *args: Any, **kwargs: Any):
--> 161     return self.model.forward(*args, **kwargs)

File ~/code/user/.venv/lib/python3.11/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File ~/code/user/.venv/lib/python3.11/site-packages/unsloth/models/mistral.py:212, in MistralForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs)
    204     outputs = LlamaModel_fast_forward_inference(
    205         self,
    206         input_ids,
   (...)
    209         attention_mask = attention_mask,
    210     )
    211 else:
--> 212     outputs = self.model(
    213         input_ids=input_ids,
    214         causal_mask=causal_mask,
    215         attention_mask=attention_mask,
    216         position_ids=position_ids,
    217         past_key_values=past_key_values,
    218         inputs_embeds=inputs_embeds,
    219         use_cache=use_cache,
    220         output_attentions=output_attentions,
    221         output_hidden_states=output_hidden_states,
    222         return_dict=return_dict,
    223     )
    224 pass
    226 hidden_states = outputs[0]

File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/code/user/.venv/lib/python3.11/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File ~/code/user/.venv/lib/python3.11/site-packages/unsloth/models/llama.py:655, in LlamaModel_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs)
    645     layer_outputs = torch.utils.checkpoint.checkpoint(
    646         create_custom_forward(decoder_layer),
    647         hidden_states,
   (...)
    652         preserve_rng_state=False,
    653     )
    654 else:
--> 655     layer_outputs = decoder_layer(
    656         hidden_states,
    657         causal_mask=causal_mask,
    658         attention_mask=attention_mask,
    659         position_ids=position_ids,
    660         past_key_value=past_key_value,
    661         output_attentions=output_attentions,
    662         use_cache=use_cache,
    663         padding_mask=padding_mask,
    664     )
    665 pass
    667 hidden_states = layer_outputs[0]

File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/code/user/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/code/user/.venv/lib/python3.11/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File ~/code/user/.venv/lib/python3.11/site-packages/unsloth/models/llama.py:423, in LlamaDecoderLayer_fast_forward(self, hidden_states, causal_mask, attention_mask, position_ids, past_key_value, output_attentions, use_cache, padding_mask, *args, **kwargs)
    412 hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
    413 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    414     hidden_states=hidden_states,
    415     causal_mask=causal_mask,
   (...)
    421     padding_mask=padding_mask,
    422 )
--> 423 hidden_states += residual
    425 # Fully Connected
    426 residual = hidden_states

RuntimeError: Output 0 of LoRA_WBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function."

Robinysh avatar Apr 04 '24 17:04 Robinysh

@Robinysh Thanks for the report - ill take a look - sorry on the delayed response!

danielhanchen avatar Apr 05 '24 17:04 danielhanchen

Hi @danielhanchen, same thing happened when i performed a custom pytorch training with gemma model. To be specific, Line 90 of gemma.py

shizheng-rlfresh avatar Apr 10 '24 21:04 shizheng-rlfresh

@Robinysh @shizheng-rlfresh OOHHH I actually never tried PPO, but because it's generating on the fly as well, hence the inplace issue

danielhanchen avatar Apr 11 '24 09:04 danielhanchen

Could you please try wrapping your generate call in this context manager and report results? It might still fail on backprop, but if it doesn't it'd be helpful to know.

class fast_eval_mode:
    """
    Convert to model.eval(), then revert to previous state

    Behavior
    - DOESNT disable grad
    - Disable dropout layers
    - Freeze BatchNorm
    """
    def __init__(self, model):
        self.model = model

    def __enter__(self):
        self.was_training = self.model.training
        if self.was_training:
            self.model.eval()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.was_training:
            self.model.train()

lapp0 avatar May 02 '24 14:05 lapp0

Encounter the same error

RuntimeError: Output 0 of LoRA_WBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.

haosdent avatar May 03 '24 10:05 haosdent

Try the approach you provided, still encounter the same error.

haosdent avatar May 03 '24 10:05 haosdent

@haosdent here's how I fixed the same error in my PPO code. I think this solution will work for trl's PPOTrainer too:

Wrap these lines with this context manager:

class disable_caching(ContextDecorator):  # noqa: N801
    def __init__(self, model):
        self.model = model
        self.prev_value: Any = "UNSET"  # config values may be T/F/None

    def __enter__(self):
        self.prev_value = self.model.config.use_cache
        self.model.config.use_cache = False

    def __exit__(self, *exc):
        if self.prev_value != "UNSET":
            self.model.config.use_cache = self.prev_value
        self.prev_value = "UNSET"
# @line 798 of ppo_trainer.py
                    with disable_caching(self.model):
                          logprobs, logits, vpreds, _ = self.batched_forward_pass(
                              self.model,
                              mini_batch_dict["queries"],
                              mini_batch_dict["responses"],
                              model_inputs,
                              return_logits=True,
                          )

I believe the model config defaults to use_cache=True (which is great for sampling from the policy), but we need to disable it when doing the forward pass for the loss.

sidnarayanan avatar May 03 '24 16:05 sidnarayanan

Cool, thanks a lot for your solution @sidnarayanan . it works like a charm.

Also I have to call trl.trainer.peft_module_casting_to_bf16(model) to avoid the bfloat error

haosdent avatar May 04 '24 08:05 haosdent

Oh if any of you are willing to do a PR to fix the issue, that'll be awesome :) Thanks again!

danielhanchen avatar May 04 '24 10:05 danielhanchen

You mean send a PR to trl, right? @danielhanchen

haosdent avatar May 04 '24 10:05 haosdent

Ye if possible - another way is to directly inject it via Unsloth

danielhanchen avatar May 04 '24 10:05 danielhanchen

After check the code, I feel maybe we should update unsloth instead, would get back later. @danielhanchen

haosdent avatar May 06 '24 02:05 haosdent