pyreft
forward() got an unexpected keyword argument 'unit_locations'
I get an error when calling `trainer.train()`. Please help!
Error
TypeError: LlamaForCausalLM.forward() got an unexpected keyword argument 'unit_locations'
Code
import pyreft
import torch
import transformers
import pandas as pd
# Fine-tune a LoReFT intervention on rinna/llama-3-youko-8b using a small sample
# of the OjousamaTalkScriptDataset (prompt -> completion pairs).

# Llama-3 style chat template; "%s" is substituted with the user prompt.
prompt_no_input_template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
device = 'cpu'
model_id = "rinna/llama-3-youko-8b"

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map=device,
    trust_remote_code=True,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_id,
    model_max_length=2048,
    padding_side="right",
    use_fast=False,
)
# Reuse the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token

# https://github.com/matsuvr/OjousamaTalkScriptDataset
df = pd.read_csv('./OjousamaTalkScriptDataset/ojousamatalkscript200.csv')
sample_df = df.sample(20)
# Read the columns directly instead of iterating rows twice with iterrows().
data_module = pyreft.make_last_position_supervised_data_module(
    tokenizer,
    model,
    [prompt_no_input_template % prompt for prompt in sample_df['prompt']],
    sample_df['completion'].tolist(),
)

# One LoReFT intervention (rank 4) on the output of transformer block 8.
reft_config = pyreft.ReftConfig(representations={
    "layer": 8,
    "component": "block_output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size,
        low_rank_dimension=4,
    ),
})
reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.set_device(device)
reft_model.print_trainable_parameters()

# NOTE(review): 20 samples / batch size 4 = 5 micro-batches, which is fewer than
# gradient_accumulation_steps=8, so one epoch gives roughly a single optimizer
# update. That update falls inside the 100-step warmup, so the learning rate
# stays near zero. Consider more epochs/samples or fewer warmup steps. Confirm
# against the intended run length.
training_args = transformers.TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    warmup_steps=100,
    num_train_epochs=1,
    learning_rate=5e-4,
    logging_steps=1,
    # Paged optimizers come from bitsandbytes and require a CUDA device.
    # Fall back to the plain PyTorch AdamW when training on CPU.
    optim="paged_adamw_32bit" if str(device).startswith("cuda") else "adamw_torch",
    weight_decay=0.0,
    lr_scheduler_type="cosine",
    output_dir="outputs",
    report_to=[],
)

# Pass the ReFT-wrapped model, not the base `model`. The ReFT trainer calls the
# model with intervention arguments such as `unit_locations`. Only the ReFT
# wrapper accepts them; the plain LlamaForCausalLM.forward rejects them, which
# raises "unexpected keyword argument 'unit_locations'".
trainer = pyreft.ReftTrainerForCausalLM(
    model=reft_model,
    tokenizer=tokenizer,
    args=training_args,
    **data_module,
)
_ = trainer.train()
Environment
pyreft 0.0.5
pyvene 0.1.1
torch 2.0.0
transformers 4.39.3