Model saving when using Trainer with Accelerate
System Info
- python: 3.9.18
- transformers: 4.39.0
- pytorch: 2.0.1
- accelerate: 0.28.0
Who can help?
@muellerzr @pacman100
Information
- [ ] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
```python
from accelerate import Accelerator
from transformers import AutoTokenizer, LlamaForCausalLM, Trainer

accelerator = Accelerator()

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
model = LlamaForCausalLM.from_pretrained(llm_path, cache_dir=cache_dir)

train_dataset = ToyDataset(args)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

# Wrap the model with Accelerate (DDP) before passing it to Trainer
model = accelerator.prepare(model)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    callbacks=[TrainerCallback],
)
trainer.train()
trainer.save_state()
trainer.save_model()

# Reload the checkpoint saved by Trainer.save_model(): the parameter keys do not match
model = LlamaForCausalLM.from_pretrained(fine_tuned_model_path)
```
Expected behavior
Because I am using Accelerate, which wraps the model with DDP, every parameter key in the saved model is prefixed with 'module.'. When I save the model with Trainer.save_model() and load it again with LlamaForCausalLM.from_pretrained(), none of the parameter keys match, so every weight is freshly initialized. Could an argument be added to Trainer.save_model() to account for Accelerate and unwrap the model before saving?
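For illustration, here is a minimal sketch of what the key mismatch looks like and a loading-side workaround; the checkpoint file name and paths are assumptions (Trainer may write model.safetensors or pytorch_model.bin depending on the save format):

```python
import torch
from transformers import LlamaForCausalLM

# Inspect the checkpoint written by trainer.save_model()
# (use safetensors.torch.load_file for model.safetensors).
state_dict = torch.load("fine_tuned_model_path/pytorch_model.bin", map_location="cpu")
print(list(state_dict)[:2])
# Wrapped save:              ['module.model.embed_tokens.weight', ...]
# from_pretrained() expects: ['model.embed_tokens.weight', ...]

# Strip the DDP "module." prefix so the keys match the bare LlamaForCausalLM
# (str.removeprefix needs Python 3.9+).
state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
model = LlamaForCausalLM.from_pretrained(llm_path, state_dict=state_dict)
```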
cc @pacman100
Gentle ping @pacman100 @muellerzr
@amyeroberts is this up for taking?
Yes it is @nakranivaibhav! 🤗
(And agree with the proposal, I thought about this too recently)
Alright @muellerzr, I will look into it.
Okay, so when I try to replicate this on a smaller model, bert-base-uncased, there is no error. Does the model have to be big enough to get sharded across devices for the error to come up? @fezsid Can I get the args passed into the Trainer, or the script, to replicate the error? @muellerzr
@nakranivaibhav I am facing this issue when I use multiple GPUs with Accelerate. I don't think the model size matters as such. As a workaround, I am using the following to save my weights:
```python
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
    args.output_dir,
    is_main_process=accelerator.is_main_process,
    save_function=accelerator.save,
)
```
Below is a simplified version of the script I use to train my model. It works right now using unwrapped_model.save_pretrained(), but it would be nice if it could be integrated into the trainer class.
```python
import argparse
import os

import transformers
from accelerate import Accelerator
from transformers import AutoTokenizer, LlamaForCausalLM, Trainer

# CustomDataset, CustomCollator, TrainerCallback and TrainingArguments are
# assumed to be defined elsewhere in the project.


def main(args):
    accelerator = Accelerator()
    args.dataloader_num_workers = accelerator.num_processes

    model = LlamaForCausalLM.from_pretrained(args.llm_path, cache_dir=args.cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(args.llm_path)

    train_dataset = CustomDataset(args)
    val_dataset = CustomDataset(args)
    data_collator = CustomCollator()

    # Wrap the model with Accelerate (DDP) before passing it to Trainer
    model = accelerator.prepare(model)

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        callbacks=[TrainerCallback],
    )

    # Start training
    trainer.train()

    # Save model, tokenizer and state
    trainer.save_state()
    tokenizer.save_pretrained(args.output_dir)
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        args.output_dir,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
    )

    # Evaluate
    trainer.evaluate()
    accelerator.wait_for_everyone()


if __name__ == '__main__':
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument('--config', type=str)
    config = cli_parser.parse_args().config

    args_parser = transformers.HfArgumentParser(TrainingArguments)
    args = args_parser.parse_yaml_file(config, allow_extra_keys=True)[0]

    os.environ['WANDB_PROJECT'] = args.project_name
    main(args)
```
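For completeness, a quick sanity check (hypothetical output directory) that the unwrapped save reloads without a key mismatch:

```python
from transformers import LlamaForCausalLM

# The checkpoint written by unwrapped_model.save_pretrained() has no
# "module." prefix, so every parameter key matches and nothing is re-initialized.
model = LlamaForCausalLM.from_pretrained("path/to/output_dir")
```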