RL4LMs
'GPT2Model' object has no attribute 'first_device'
I get the following error when running python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/dialog/gpt2_ppo.yml. I have double-checked that transformers==4.18.0 is installed.
Traceback (most recent call last):
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/scripts/training/train_text_generation.py", line 84, in <module>
    main(
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/scripts/training/train_text_generation.py", line 55, in main
    trainer.train_and_eval()
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 232, in train_and_eval
    self._alg.learn(self._n_steps_per_iter)
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/algorithms/ppo/ppo.py", line 342, in learn
    return super().learn(
  File "/opt/anaconda3/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 247, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/alg_wrappers.py", line 384, in collect_rollouts
    rollout_info = self.generate_batch(
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/alg_wrappers.py", line 159, in generate_batch
    gen_output = self.policy.generate(
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/policy/base_policy.py", line 230, in generate
    inputs=input_ids.to(self.get_policy_first_device()),
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/policy/causal_policy.py", line 259, in get_policy_first_device
    self._policy_model.transformer.first_device
  File "/opt/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1185, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'GPT2Model' object has no attribute 'first_device'
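
The last frames show get_policy_first_device reading transformer.first_device unconditionally. As far as I can tell, GPT2Model only gets that attribute once parallelize() has been called on it, which may never happen on a machine without CUDA (the paths above suggest macOS). This is an assumption, not something confirmed by the maintainers. A minimal sketch of a defensive lookup follows. FakeTransformer and resolve_first_device are hypothetical names used only for illustration and are not part of RL4LMs or transformers; the stand-in class just mimics how torch.nn.Module raises AttributeError for unknown attributes.

```python
class FakeTransformer:
    """Hypothetical stand-in for a model that never had parallelize() called."""

    def __getattr__(self, name):
        # Only invoked when normal lookup fails, mirroring the behaviour of
        # torch.nn.Module.__getattr__ seen at the bottom of the traceback.
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, name)
        )


def resolve_first_device(transformer, default="cpu"):
    """Return transformer.first_device if it is set, else a fallback device.

    Assumption: falling back to "cpu" is acceptable when the model was
    never spread across GPUs.
    """
    return getattr(transformer, "first_device", default)


# Model that was never parallelized: the fallback is used.
unparallelized = FakeTransformer()
print(resolve_first_device(unparallelized))  # cpu

# Model where the attribute has been set: its value is used.
parallelized = FakeTransformer()
parallelized.first_device = "cuda:0"
print(resolve_first_device(parallelized))  # cuda:0
```

If this diagnosis is right, the same getattr fallback could be tried inside get_policy_first_device in causal_policy.py, but I have not verified that the rest of the training loop runs on CPU.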