trlx
ILQL training batch2 tensor dimensions error
Hi, I'm trying to run ILQL training with a GPT-J network trained with this code. I don't get this error with the original pre-trained network, nor with flan-xl.
Traceback (most recent call last):
  File "/home/jupyter/trlx/examples/summarize_rlhf/ilql_gptj.py", line 118, in <module>
    main()
  File "/home/jupyter/trlx/examples/summarize_rlhf/ilql_gptj.py", line 109, in main
    trlx.train(
  File "/home/jupyter/trlx/trlx/trlx.py", line 126, in train
    trainer.learn()
  File "/home/jupyter/trlx/trlx/trainer/accelerate_base_trainer.py", line 539, in learn
    results = self.evaluate()
  File "/home/jupyter/trlx/trlx/trainer/accelerate_base_trainer.py", line 384, in evaluate
    samples = self.generate_eval(prompts["input_ids"], prompts["attention_mask"])
  File "/home/jupyter/trlx/trlx/trainer/accelerate_base_trainer.py", line 276, in generate_eval
    return self.accelerator.unwrap_model(self.model).generate(
  File "/home/jupyter/trlx/trlx/models/modeling_ilql.py", line 307, in generate
    out = self.forward(
  File "/home/jupyter/trlx/trlx/models/modeling_ilql.py", line 263, in forward
    outputs = self.base_model(**forward_kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gptj/modeling_gptj.py", line 854, in forward
    transformer_outputs = self.transformer(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gptj/modeling_gptj.py", line 689, in forward
    outputs = block(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gptj/modeling_gptj.py", line 309, in forward
    attn_outputs = self.attn(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gptj/modeling_gptj.py", line 257, in forward
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gptj/modeling_gptj.py", line 183, in _attn
    attn_output = torch.matmul(attn_weights, value)
RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [256, 101] but got: [256, 1].
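For anyone reading the error message: a batched matrix multiply takes a first operand of shape (b, n, m) and needs a second operand of shape (b, m, p). So the message says the attention weights cover 101 key positions while the value tensor holds only 1. Below is a minimal pure-Python sketch of that shape rule. `bmm_output_shape` is a hypothetical helper written for illustration, not a PyTorch function, and the query length of 1 and last dimension of 256 are assumed values. The leading 256 would be consistent with 16 evaluation prompts times the 16 attention heads of GPT-J-6B, but that is a guess.

```python
def bmm_output_shape(batch1_shape, batch2_shape):
    """Return the output shape of a batched matmul, or raise on mismatch.

    batch1 is (b, n, m) and batch2 must be (b, m, p); the result is (b, n, p).
    """
    b, n, m = batch1_shape
    if tuple(batch2_shape[:2]) != (b, m):
        raise ValueError(
            "Expected size for first two dimensions of batch2 tensor to be: "
            f"[{b}, {m}] but got: [{batch2_shape[0]}, {batch2_shape[1]}]."
        )
    return (b, n, batch2_shape[2])


# Illustrative shapes only: the query length (1) and last dim (256) are assumptions.
attn_weights_shape = (256, 1, 101)  # weights over 101 key positions
value_ok_shape = (256, 101, 256)    # values for all 101 positions
value_bad_shape = (256, 1, 256)     # values for a single position only

print(bmm_output_shape(attn_weights_shape, value_ok_shape))
try:
    bmm_output_shape(attn_weights_shape, value_bad_shape)
except ValueError as err:
    print(err)
```

The second call reproduces the wording of the error above. This only illustrates what the shapes mean; it does not show why the value tensor ends up with a single position in this run.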
This is my config:
config = TRLConfig(
    train=TrainConfig(
        seq_length=768,
        epochs=epochs,
        total_steps=total_steps,
        batch_size=batch_size,
        checkpoint_interval=eval_and_checkpoint,
        eval_interval=eval_and_checkpoint,
        pipeline="PromptPipeline",
        trainer="AccelerateILQLTrainer",
        save_best=True,
        checkpoint_dir="ckpts_ilql",
    ),
    model=ModelConfig(
        model_path=pretrained_model_path,
        num_layers_unfrozen=-1,
    ),
    tokenizer=TokenizerConfig(
        tokenizer_path="gpt2",
        truncation_side="right",
    ),
    optimizer=OptimizerConfig(
        name="adamw",
        kwargs={
            "lr": 5.0e-5,
            "betas": [0.9, 0.999],
            "eps": 1.0e-8,
            "weight_decay": 1.0e-6,
        },
    ),
    scheduler=SchedulerConfig(
        name="cosine_annealing",
        kwargs=dict(T_max=1e12, eta_min=5.0e-5),
    ),
    method=ILQLConfig(
        name="ILQLConfig",
        tau=0.7,
        gamma=0.99,
        cql_scale=0.1,
        awac_scale=1,
        alpha=0.001,
        beta=0,
        steps_for_target_q_sync=5,
        two_qs=True,
        gen_kwargs=dict(max_new_tokens=256, top_k=20, beta=4, temperature=1.0),
    ),
)
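One thing that can be read off the config is the token budget left for each prompt. The sketch below assumes that seq_length is meant to cover both the prompt and the generated tokens; that reading is an assumption, not something confirmed in this thread. `fits_budget` is a hypothetical helper.

```python
# Values copied from the config above.
seq_length = 768       # train.seq_length
max_new_tokens = 256   # method.gen_kwargs["max_new_tokens"]

# Assumption: seq_length must hold the prompt plus the generated continuation.
max_prompt_tokens = seq_length - max_new_tokens


def fits_budget(prompt_token_count, budget=max_prompt_tokens):
    """Return True if a prompt of this many tokens leaves room for generation."""
    return prompt_token_count <= budget


print(max_prompt_tokens)
print(fits_budget(100))
```

Under that assumption a prompt of about 100 tokens, the length suggested by the error message, is well within the budget, so prompt length alone would not explain the failure.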
Thanks.
Hi @GenVr! Can you show your training code as well alongside your config? There might be an error in how you passed the training data in. Thanks!
@maxreciprocate For the dataset and training, I use this train() call:
trlx.train(
    samples=[(text, output) for text, output in zip(ttv_ds['train']['text'], ttv_ds['train']['output'])],
    rewards=labels,
    eval_prompts=ttv_ds['validation']['text'][:16],
    config=config,
)
Where:
samples = [(string, string), (string, string), ...]  # list of (string, string) tuples
labels = [0, 1, 0, 1, ...]  # list of 0/1 labels
eval_prompts = [string, string, ...]  # list of strings
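Since the question was whether the training data is passed in correctly, here is a minimal sketch of a sanity check that could be run on the three inputs before calling trlx.train(). `check_ilql_inputs` is a hypothetical helper, not part of trlx, and it only checks the shapes described above: (string, string) pairs, one reward per sample, and a list of prompt strings.

```python
def check_ilql_inputs(samples, rewards, eval_prompts):
    """Raise ValueError if the inputs do not match the shapes described above."""
    if len(samples) != len(rewards):
        raise ValueError(
            f"{len(samples)} samples but {len(rewards)} rewards; expected one reward per sample"
        )
    for i, sample in enumerate(samples):
        # Each sample should be a (prompt, output) pair of strings.
        if len(sample) != 2 or not all(isinstance(part, str) for part in sample):
            raise ValueError(f"samples[{i}] is not a (string, string) pair")
    for i, reward in enumerate(rewards):
        # bool is excluded on purpose so True/False are not silently accepted.
        if isinstance(reward, bool) or not isinstance(reward, (int, float)):
            raise ValueError(f"rewards[{i}] is not a number")
    for i, prompt in enumerate(eval_prompts):
        if not isinstance(prompt, str):
            raise ValueError(f"eval_prompts[{i}] is not a string")
    return True


# Toy data standing in for the real dataset.
toy_samples = [("a post", "a summary"), ("another post", "another summary")]
toy_rewards = [0, 1]
toy_eval_prompts = ["a validation post"]
print(check_ilql_inputs(toy_samples, toy_rewards, toy_eval_prompts))
```

If a check like this passes, the input format is probably not the cause. That would fit the observation that the same pipeline works with the original pre-trained network.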
Thanks for your answer!