RL4LMs
Value is not broadcastable with batch_shape+event_shape
My yaml:
```yaml
tokenizer:
  model_name: facebook/bart-large-cnn
  padding_side: left
  truncation_side: left
  pad_token_as_eos_token: False

reward_fn:
  id: rouge
  args:
    rouge_type: "rouge1"

datapool:
  id: cnn_daily_mail
  args:
    prompt_prefix: "Summarize: "
    max_size: 500

env:
  n_envs: 1
  args:
    max_prompt_length: 64
    max_episode_length: 100
    terminate_on_eos: True
    prompt_truncation_side: "right"
    context_start_token: 0

alg:
  id: ppo
  args:
    n_steps: 512
    batch_size: 4
    verbose: 1
    learning_rate: 0.000002
    n_epochs: 5
    ent_coef: 0.0
  kl_div:
    coeff: 0.001
    target_kl: 0.2
  policy:
    id: seq2seq_lm_actor_critic_policy
    args:
      model_name: facebook/bart-large-cnn
      apply_model_parallel: False
      prompt_truncation_side: "right"
      generation_kwargs:
        do_sample: True
        top_k: 50
        min_length: 50
        max_new_tokens: 100

train_evaluation:
  eval_batch_size: 16
  n_iters: 100
  eval_every: 10
  save_every: 1
  metrics:
    - id: meteor
      args: {}
    - id: rouge
    - id: bleu
      args: {}
    - id: bert_score
      args:
        language: en
    # - id: bleurt
    #   args:
    #     config_name: bleurt-large-512
    - id: diversity
      args: {}
    # - id: summaCZS
    #   args:
    #     granularity: sentence
    #     use_ent: True
    #     use_con: False
    # - id: summaCConv
    #   args:
    #     granularity: sentence
  generation_kwargs:
    do_sample: True
    top_k: 0
    temperature: 0.7
    min_length: 50
    max_new_tokens: 100
```
My error:
```
/content/RL4LMs/scripts/training/train_text_generation.py in main(config_path, project_name, experiment_name, base_path_to_store_results, entity_name, log_to_wandb)
     53 tracker=tracker,
     54 )
---> 55 trainer.train_and_eval()
     56
     57

/content/RL4LMs/rl4lms/envs/text_generation/training_utils.py in train_and_eval(self)
    195 # evaluate on val and test set before fine-tuning once
    196 iter_start = self._trainer_state["current_iter"]
--> 197 self._evaluate_on_datapools(epoch=iter_start)
    198
    199 # train for given number of iters

/content/RL4LMs/rl4lms/envs/text_generation/training_utils.py in _evaluate_on_datapools(self, epoch, splits)
    181 splits: List[str] = ["val", "test"]):
    182 for split in splits:
--> 183 evaluate_on_samples(policy=self._alg.policy,
    184 tokenizer=self._tokenizer,
    185 samples=self._samples_by_split[split],

/content/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py in evaluate_on_samples(policy, tokenizer, samples, batch_size, max_prompt_length, metrics, epoch, split_name, tracker, dt_control_token, gen_kwargs)
     39 n_samples = len(samples)
     40 for batch in tqdm(list(get_batch(samples, batch_size)), desc="Evaluating"):
---> 41 batch_generated_texts = generate_text(
     42 policy, tokenizer, batch, max_prompt_length, dt_control_token, gen_kwargs
     43 )

/content/RL4LMs/rl4lms/envs/text_generation/evaluation_utils.py in generate_text(policy, tokenizer, samples, max_prompt_length, dt_control_token, gen_kwargs)
    109 dt_control_token + sample.prompt_or_input_text for sample in samples
    110 ]
--> 111 generated_texts = policy.generate(
    112 tokenizer, prompt_texts, max_prompt_length, gen_kwargs=gen_kwargs
    113 ).gen_texts

/content/RL4LMs/rl4lms/envs/text_generation/policy/base_policy.py in generate(self, tokenizer, texts, max_prompt_length, input_ids, attention_mask, gen_kwargs)
    254 actions_at_step = gen_tokens[:, step]
    255 distribution = Categorical(logits=raw_logits)
--> 256 log_probs = distribution.log_prob(actions_at_step)
    257 step_wise_logprobs.append(log_probs)
    258 step_wise_actions.append(actions_at_step)

/usr/local/lib/python3.8/dist-packages/torch/distributions/categorical.py in log_prob(self, value)
    121 def log_prob(self, value):
    122 if self._validate_args:
--> 123 self._validate_sample(value)
    124 value = value.long().unsqueeze(-1)
    125 value, log_pmf = torch.broadcast_tensors(value, self.logits)

/usr/local/lib/python3.8/dist-packages/torch/distributions/distribution.py in _validate_sample(self, value)
    280 for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
    281 if i != 1 and j != 1 and i != j:
--> 282 raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
    283 format(actual_shape, expected_shape))
    284 try:

ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([100]) vs torch.Size([400]).
```
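For reference, the failing call can be reproduced outside RL4LMs. This is a minimal sketch: `reproduce_mismatch` is a hypothetical helper, the vocabulary size is arbitrary, and the sizes 100 and 400 are taken from the error message. It builds a `Categorical` from one row of logits per scored sequence and asks for the log-probability of one sampled token per generated sequence; when the two counts differ, `log_prob` raises the same `ValueError`. The 4:1 ratio might mean generation scored 4 sequences per prompt. That is an unconfirmed guess, based only on `facebook/bart-large-cnn` shipping with `num_beams: 4` in its model config.

```python
import torch
from torch.distributions import Categorical


def reproduce_mismatch(n_actions: int, n_logit_rows: int, vocab_size: int = 10) -> str:
    """Hypothetical helper: return the error raised by Categorical.log_prob
    when the number of actions differs from the number of logit rows."""
    # One row of logits per sequence the model scored at this step
    # (this becomes the distribution's batch_shape).
    raw_logits = torch.randn(n_logit_rows, vocab_size)
    # One sampled token id per generated sequence at this step.
    actions_at_step = torch.randint(0, vocab_size, (n_actions,))
    # validate_args=True forces the shape check that raised in the traceback.
    distribution = Categorical(logits=raw_logits, validate_args=True)
    try:
        distribution.log_prob(actions_at_step)
    except ValueError as err:
        return str(err)
    return "no error"


# 100 sampled tokens vs. 400 rows of logits, matching the shapes in the error.
print(reproduce_mismatch(100, 400))
# With matching sizes, log_prob succeeds.
print(reproduce_mismatch(100, 100))
```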
By the way, can I use the `max_size` argument of the datapool, as below, to limit how much data is used?

```yaml
datapool:
  id: cnn_daily_mail
  args:
    prompt_prefix: "Summarize: "
    max_size: 500
```