`None` ref_model in PPO training
Hi there. In the code at https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py#L115, the ref model is set to `None` when PEFT is not used. This seems to cause the error below, since the `None` ref_model is then passed to https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py#L133.
Therefore, can we assume that, whether PEFT is used or not, a ref_model should always be provided for inference (only)?
```
response_tensors, ref_response_tensors = ppo_trainer.generate(
  File "/home/chenyanan/trl/trl/trainer/ppo_trainer.py", line 474, in generate
    ref_response = self._generate_batched(
  File "/home/chenyanan/trl/trl/trainer/ppo_trainer.py", line 546, in _generate_batched
    generations = self.accelerator.unwrap_model(model).generate(**padded_inputs, **generation_kwargs)
  File "/home/chenyanan/trl/trl/models/modeling_value_head.py", line 203, in generate
    return self.pretrained_model.generate(*args, **kwargs)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/transformers/generation/utils.py", line 1575, in generate
    result = self._sample(
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/transformers/generation/utils.py", line 2697, in _sample
    outputs = self(
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 1157, in forward
    outputs = self.model(
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 1042, in forward
    layer_outputs = decoder_layer(
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 757, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 653, in forward
    query_states = self.q_proj(hidden_states)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 687, in forward
    out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 562, in matmul
    return MatMul8bitLt.apply(A, B, out, bias, state)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 341, in forward
    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
  File "/home/chenyanan/anaconda3/envs/trl/lib/python3.10/site-packages/bitsandbytes/functional.py", line 2255, in transform
    prev_device = pre_call(A.device)
AttributeError: 'NoneType' object has no attribute 'device'
```