RL4LMs
`num_beams` for generation raises a ValueError

I tried to set the `num_beams` generation parameter to 3, but got the error below.
Config:

```yaml
tokenizer:
  model_name: "t5-base"
  padding_side: "right"
  truncation_side: "right"
  truncation: True
  padding: True
  max_length: 128
  # pad_token_as_eos_token: False

reward_fn:
  id: meteor

datapool:
  id: wmt16
  args:
    train_path: "data/train.csv"
    eval_path: "data/eval.csv"
    test_path: "data/test.xlsx"

env:
  n_envs: 10
  args:
    max_prompt_length: 128
    max_episode_length: 128
    terminate_on_eos: True
    prompt_truncation_side: "right"
    context_start_token: 0

alg:
  id: ppo
  args:
    n_steps: 2
    batch_size: 20
    verbose: 2
    learning_rate: 0.000001
    n_epochs: 5
    ent_coef: 0.0
  kl_div:
    coeff: 0.001
    target_kl: 0.2
  policy:
    id: seq2seq_lm_actor_critic_policy
    args:
      model_name: "t5-base"
      apply_model_parallel: True
      prompt_truncation_side: "right"
      generation_kwargs:
        do_sample: True
        num_beams: 3
        max_length: 128
        length_penalty: 0.85
        repetition_penalty: 2.0
        max_new_tokens: 128

train_evaluation:
  eval_batch_size: 1
  n_iters: 10
  eval_every: 10
  save_every: 1
  metrics:
    - id: meteor
      args: {}
    - id: sacre_bleu
      args:
        tokenize: "intl"
  generation_kwargs:
    do_sample: True
    num_beams: 3
    max_length: 128
    length_penalty: 0.85
    max_new_tokens: 128
    repetition_penalty: 2.0
```
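Note that the traceback goes through `policy.generate` (`policy.py`), which is also the path evaluation takes, so both `generation_kwargs` blocks hit the same code. As a workaround (my suggestion, not an official fix), dropping to a single beam avoids the shape mismatch at the cost of beam search:

```yaml
# Hypothetical workaround, not an official fix: with num_beams: 1 the
# model is not expanded beam-wise, so the log-prob shapes line up.
generation_kwargs:
  do_sample: True
  num_beams: 1        # was 3; beam search multiplies the batch by num_beams
  max_length: 128
  length_penalty: 0.85
  repetition_penalty: 2.0
  max_new_tokens: 128
```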
Error:

```
Evaluating:   0%|          | 0/1 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "scripts/training/train_text_generation.py", line 71, in <module>
    args.log_to_wandb)
  File "scripts/training/train_text_generation.py", line 42, in main
    trainer.train_and_eval()
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 198, in train_and_eval
    self._evaluate_on_datapools(epoch=iter_start)
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 193, in _evaluate_on_datapools
    gen_kwargs=self._eval_gen_kwargs)
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py", line 41, in evaluate_on_samples
    dt_control_token, gen_kwargs)
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py", line 99, in generate_text
    gen_kwargs=gen_kwargs)["gen_texts"]
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/policy.py", line 324, in generate
    log_probs = distribution.log_prob(actions_at_step)
  File "/home/user/conda/lib/python3.7/site-packages/torch/distributions/categorical.py", line 117, in log_prob
    self._validate_sample(value)
  File "/home/user/conda/lib/python3.7/site-packages/torch/distributions/distribution.py", line 277, in _validate_sample
    format(actual_shape, expected_shape))
ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([10]) vs torch.Size([30]).
```
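My reading of the mismatch (an interpretation, not from the RL4LMs docs): HuggingFace beam search expands every input row into `num_beams` candidate rows before scoring, so the logits inside `policy.generate` cover `n_envs * num_beams` rows, while `actions_at_step` still holds one token id per environment. The names below mirror the config; the rest is illustrative:

```python
# Sketch of the shape arithmetic behind the ValueError.
n_envs = 10       # env.n_envs in the config
num_beams = 3     # policy generation_kwargs.num_beams

actions_shape = (n_envs,)           # one sampled token id per environment
logits_rows = n_envs * num_beams    # rows after beam-search expansion

# torch.distributions.Categorical.log_prob then tries to broadcast
# torch.Size([10]) against torch.Size([30]) and raises.
print(f"{actions_shape} vs ({logits_rows},)")
```

This also explains the exact numbers in the message: `torch.Size([10])` is `n_envs` and `torch.Size([30])` is `n_envs * num_beams`.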