
num_beams

tatiana-iazykova opened this issue 1 year ago · 5 comments

I tried to set the num_beams generation parameter to 3, but got the error below.

config:

tokenizer:
  model_name: "t5-base"
  padding_side: right
  truncation_side: right
  truncation: True
  padding: True
  max_length: 128
  # pad_token_as_eos_token: False

reward_fn:
  id: meteor 
  
datapool:
  id: wmt16
  args:
    train_path: "data/train.csv"
    eval_path: "data/eval.csv"
    test_path: "data/test.xlsx"


env:
  n_envs: 10
  args:
    max_prompt_length: 128
    max_episode_length: 128
    terminate_on_eos: True
    prompt_truncation_side: "right"
    context_start_token: 0

alg:
  id: ppo
  args: 
    n_steps: 2
    batch_size: 20
    verbose: 2
    learning_rate: 0.000001
    n_epochs: 5
    ent_coef: 0.0
  kl_div:
    coeff: 0.001
    target_kl: 0.2
  policy:
    id: seq2seq_lm_actor_critic_policy
    args:
      model_name: "t5-base"
      apply_model_parallel: True
      prompt_truncation_side: "right"
      generation_kwargs:
        do_sample: True
        num_beams: 3
        max_length: 128
        length_penalty: 0.85
        repetition_penalty: 2.0
        max_new_tokens: 128

    
train_evaluation:
  eval_batch_size: 1
  n_iters: 10
  eval_every: 10
  save_every: 1
  metrics:
    - id: meteor
      args: {}
    - id: sacre_bleu
      args:
        tokenize: "intl"
  generation_kwargs:
    do_sample: True
    num_beams: 3
    max_length: 128
    length_penalty: 0.85
    max_new_tokens: 128
    repetition_penalty: 2.0
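
A note on the shapes, in case it helps triage: with n_envs: 10 and num_beams: 3, Hugging Face's generate() expands the batch to 10 × 3 = 30 beam hypotheses per decoding step, which matches the [10] vs [30] mismatch in the traceback below. As a guess (not a confirmed fix), a sampling-only generation config avoids the beam expansion entirely; length_penalty is dropped here since it only affects beam-based decoding:

generation_kwargs:
  do_sample: True
  top_k: 50              # hypothetical sampling knob; the point is no num_beams
  max_length: 128
  max_new_tokens: 128
  repetition_penalty: 2.0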

error:

Evaluating:   0%|                                                                                                                               | 0/1 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "scripts/training/train_text_generation.py", line 71, in <module>
    args.log_to_wandb)
  File "scripts/training/train_text_generation.py", line 42, in main
    trainer.train_and_eval()
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 198, in train_and_eval
    self._evaluate_on_datapools(epoch=iter_start)
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 193, in _evaluate_on_datapools
    gen_kwargs=self._eval_gen_kwargs)
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py", line 41, in evaluate_on_samples
    dt_control_token, gen_kwargs)
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py", line 99, in generate_text
    gen_kwargs=gen_kwargs)["gen_texts"]
  File "/home/jovyan/yazykova-tv/rl_allen/RL4LMs/rl4lms/envs/text_generation/policy.py", line 324, in generate
    log_probs = distribution.log_prob(actions_at_step)
  File "/home/user/conda/lib/python3.7/site-packages/torch/distributions/categorical.py", line 117, in log_prob
    self._validate_sample(value)
  File "/home/user/conda/lib/python3.7/site-packages/torch/distributions/distribution.py", line 277, in _validate_sample
    format(actual_shape, expected_shape))
ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([10]) vs torch.Size([30]).
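
For reference, the mismatch reproduces outside RL4LMs in a few lines of PyTorch (a minimal sketch; the shapes 10 and 30 are inferred from the traceback as n_envs × num_beams, and 32128 is just t5-base's vocabulary size):

import torch
from torch.distributions import Categorical

n_envs, num_beams, vocab_size = 10, 3, 32128

# Beam search scores 30 (= 10 envs x 3 beams) hypotheses per step,
# while the environment only tracks one action per env.
logits = torch.randn(n_envs * num_beams, vocab_size)
dist = Categorical(logits=logits)                        # batch_shape = [30]
actions_at_step = torch.randint(vocab_size, (n_envs,))   # shape [10]
dist.log_prob(actions_at_step)
# ValueError: Value is not broadcastable with batch_shape+event_shape:
# torch.Size([10]) vs torch.Size([30])

This suggests the policy's per-token log-prob bookkeeping assumes one hypothesis per environment, so any num_beams > 1 breaks the shape contract.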

tatiana-iazykova · Oct 18 '22 11:10