pyreft
[P1] Refactor ReftTrainer to save artifacts with the config
The issue is that ReftTrainer.save_model does not save the ReftConfig, only the intervention.
As a workaround, we can load the model from the checkpoint using the following code (by reinstantiating the config manually):
import torch
import pyreft
import pyvene as pv
from transformers import AutoModelForCausalLM

# model_name must be the same base model that was used for training.
reft_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="cuda")

# Re-create the config by hand, since it is not saved with the checkpoint.
reft_config = pyreft.ReftConfig(representations={
    "layer": 15, "component": "block_output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(
        embed_dim=reft_model.config.hidden_size,
        low_rank_dimension=4)})
reft_model = pv.IntervenableModel(reft_config, reft_model)
reft_model.load_intervention('./tmp/checkpoint-78/intervenable_model')

# Move each intervention to the GPU.
device = 'cuda'
for k, v in reft_model.interventions.items():
    v[0].to(device)
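Until the trainer saves the config itself, one option might be to write the plain hyperparameters next to the checkpoint and read them back before rebuilding the ReftConfig. This is only a rough sketch of that idea: the helper names save_reft_hparams and load_reft_hparams and the file name reft_config.json are hypothetical and not part of pyreft or pyvene.

```python
import json
from pathlib import Path

# Hypothetical file name; pyreft does not define or look for this file.
CONFIG_FILE = "reft_config.json"


def save_reft_hparams(checkpoint_dir, hparams):
    """Write the JSON-serializable hyperparameters into the checkpoint directory."""
    path = Path(checkpoint_dir) / CONFIG_FILE
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(hparams, indent=2))
    return path


def load_reft_hparams(checkpoint_dir):
    """Read the hyperparameters back as a plain dict."""
    return json.loads((Path(checkpoint_dir) / CONFIG_FILE).read_text())
```

The dict would hold only the serializable values from the snippet above ("layer", "component", "low_rank_dimension"). The intervention object itself would still have to be constructed from them when loading.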
Please let me know if I am missing something!
Thanks, Bryan