How can we resume training from lora model?
System Info
- transformers version: 4.39.0.dev0
- Platform: Linux-3.10.0-1160.71.1.el7.x86_64-x86_64-with-glibc2.17
- Python version: 3.10.13
- Huggingface_hub version: 0.21.4
- Safetensors version: 0.4.2
- Accelerate version: 0.27.2
- Accelerate config:
  - compute_environment: LOCAL_MACHINE
  - distributed_type: DEEPSPEED
  - mixed_precision: bf16
  - use_cpu: False
  - debug: False
  - num_processes: 4
  - machine_rank: 0
  - num_machines: 1
  - rdzv_backend: static
  - same_network: True
  - main_training_function: main
  - deepspeed_config: {'gradient_accumulation_steps': 16, 'zero3_init_flag': False, 'zero_stage': 0}
  - downcast_bf16: no
  - tpu_use_cluster: False
  - tpu_use_sudo: False
  - tpu_env: []
- PyTorch version (GPU?): 2.2.1 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
- peft version: 0.9.1.dev0
Who can help?
@younesbelkada
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [X] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
I found a relevant PR, #24274, but it does not seem to work for me (there is a comment on that PR from someone in exactly the same situation whose problem was never resolved). In short, when using the transformers Trainer, if training is interrupted, simply calling trainer.train(resume_from_checkpoint=True) produces the following error:
ValueError: Can't find a checkpoint index (pytorch_model.bin.index.json or model.safetensors.index.json) in ***/checkpoint-50.
Reproduction code:
import torch
from transformers import (
    MambaForCausalLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig
from dataset import MyDataset  # custom dataset class, definition not shown

model_dir = 'mamba-2.8b-hf'
output_dir = './mamba-translate'

tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side='left')
model = MambaForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16)
collator = DataCollatorForSeq2Seq(tokenizer, model)

# train/eval datasets come from the custom MyDataset class above (arguments omitted here)
train_dataset = MyDataset(...)
eval_dataset = MyDataset(...)

lora_config = LoraConfig(
    r=64,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none",
    use_rslora=True,
)
# Attach the LoRA adapter in place; this is the code path that later fails to resume.
model.add_adapter(lora_config)

trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    args=TrainingArguments(
        output_dir="/data/ruanjh/mamba-translate-2.8b-lora",
        overwrite_output_dir=True,
        remove_unused_columns=False,
        gradient_accumulation_steps=16,
        # gradient_checkpointing=True,
        evaluation_strategy='steps',
        eval_delay=100,
        eval_steps=50,
        save_strategy='steps',
        save_steps=50,
        save_total_limit=3,
        load_best_model_at_end=True,
        dataloader_num_workers=10,
        learning_rate=2e-3,
        num_train_epochs=30,
        # auto_find_batch_size=True,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        logging_steps=5,
        bf16=True,
        prediction_loss_only=True,
        lr_scheduler_type="cosine",
        # torch_compile=True,
        # torch_compile_backend='inductor',
        # torch_compile_mode='max-autotune',
        optim='adamw_apex_fused',
        # save_safetensors=False,
    ),
    data_collator=collator,
)

trainer.train(resume_from_checkpoint=True)
Expected behavior
Resume training correctly
I found the likely cause. In transformers/PEFT there are two ways to turn a model into a PEFT model: model.add_adapter(lora_config) or model = get_peft_model(model, lora_config). The former causes the error above, while the latter resumes correctly.
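To make the two code paths concrete, here is a minimal sketch reusing model and lora_config from the reproduction script above; the two variants are alternatives, not meant to be applied together:

```python
from peft import get_peft_model

# Variant 1: attach the adapter in place (what the reproduction script does).
# Checkpoints saved from this setup fail to resume with the
# "Can't find a checkpoint index" error.
model.add_adapter(lora_config)

# Variant 2: wrap the model in a PeftModel instead.
# With this variant, trainer.train(resume_from_checkpoint=True) resumes
# correctly, but see the follow-up error below.
# model = get_peft_model(model, lora_config)
```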
However, using the latter raises a different error:
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
also cc @pacman100
@rangehow
Apologies for the delay, could you try to run a single training step on CPU and report the error here? Alternatively, you can also run CUDA_LAUNCH_BLOCKING=1 python xxx and paste the error traceback here 🙏
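For reference, a minimal sketch of the single-step CPU run being requested, using the same model and adapter setup as the reproduction script (an assumption, not a definitive recipe):

```python
import torch
from transformers import MambaForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_dir = "mamba-2.8b-hf"  # same checkpoint as in the reproduction script
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = MambaForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float32)  # CPU, fp32

lora_config = LoraConfig(
    r=64,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# One forward/backward pass on a tiny batch to surface the error outside CUDA.
input_ids = tokenizer("a short test sentence", return_tensors="pt").input_ids
out = model(input_ids=input_ids, labels=input_ids)
out.loss.backward()
print(out.loss.item())
```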
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi, I seem to be having the same problem: When calling trainer.train(resume_from_checkpoint=True) to work with a previous checkpoint generated during LoRA training (in particular, I am working with this notebook), the execution fails with the error below. This suggests that the code is looking for the full model weights in the checkpoint directory (which don't exist there), not just the adapter checkpoints.
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[16], line 1
----> 1 trainer.train(resume_from_checkpoint=True)

File /opt/conda/lib/python3.11/site-packages/transformers/trainer.py:1920, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1918 if resume_from_checkpoint is not None:
   1919     if not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled:
-> 1920         self._load_from_checkpoint(resume_from_checkpoint)
   1921 # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
   1922 state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))

File /opt/conda/lib/python3.11/site-packages/transformers/trainer.py:2635, in Trainer._load_from_checkpoint(self, resume_from_checkpoint, model)
   2632     logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
   2633 else:
   2634     # We load the sharded checkpoint
-> 2635     load_result = load_sharded_checkpoint(
   2636         model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors
   2637     )
   2638 if not is_sagemaker_mp_enabled():
   2639     self._issue_warnings_after_load(load_result)

File /opt/conda/lib/python3.11/site-packages/transformers/modeling_utils.py:492, in load_sharded_checkpoint(model, folder, strict, prefer_safe)
   488 if not index_present and not (safe_index_present and is_safetensors_available()):
   489     filenames = (
   490         (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME) if is_safetensors_available() else (WEIGHTS_INDEX_NAME,)
   491     )
--> 492     raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.")
   494 load_safe = False
   495 if safe_index_present:

ValueError: Can't find a checkpoint index (pytorch_model.bin.index.json or model.safetensors.index.json) in ./idefics3-llama-vqav2/checkpoint-1000.
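For reference, a minimal sketch of what such a LoRA checkpoint directory typically contains versus what load_sharded_checkpoint() expects (the path is the one from the traceback above; adjust it to your own run):

```python
import os

# Hypothetical path taken from the traceback above.
ckpt = "./idefics3-llama-vqav2/checkpoint-1000"
print(sorted(os.listdir(ckpt)))
# A LoRA checkpoint saved by Trainer typically holds adapter_config.json and
# adapter_model.safetensors (plus optimizer/scheduler/trainer state), but
# load_sharded_checkpoint() only looks for pytorch_model.bin.index.json or
# model.safetensors.index.json, hence the ValueError on resume.
```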
I'm having the same problem as well.
cc @SunMarc
Thanks for the reproducer @joris-sense! I'll try to reproduce and fix the issue!
Hi, is there any progress on this ticket?