
peft_method works fine with lora, but raises errors when using prefix and llama_adapter

cyberyu opened this issue 1 year ago · 10 comments

System Info

[pip3] numpy==1.26.3
[pip3] torch==2.2.0+cu118
[pip3] triton==2.2.0
[conda] numpy 1.26.3 pypi_0 pypi
[conda] torch 2.2.0+cu118 pypi_0 pypi
[conda] triton 2.2.0 pypi_0 pypi

GCP instance g2-standard-48, 4xL4 GPU

Information

  • [X] The official example scripts
  • [ ] My own modified scripts

🐛 Describe the bug

When changing peft_method from lora to the other two methods (llama_adapter and prefix), errors come up:

llama_adapter reports: ValueError: Must flatten tensors with uniform requires_grad when use_orig_params=False

prefix reports an error inside the transformers package (modeling_llama.py) at past_key_value.get_usable_length: AttributeError: 'tuple' object has no attribute 'get_usable_length'

The transformers/torch versions I use:
transformers.__version__: 4.36.2
torch.__version__: 2.3.0a0+gitc7e9c15

Error logs

==== llama_adapter case ====

Running command:

CUDA_VISIBLE_DEVICES=0,1 torchrun --nnodes 1 --nproc_per_node 2 examples/finetuning.py --enable_fsdp --dataset grammar_dataset --use_peft --peft_method llama_adapter --model_name meta-llama/Llama-2-7b-hf --output_dir ./output-models-grammar-llama_adapter --num_epochs 1 --pure_bf16

Outputs:

[2024-01-27 18:51:26,053] torch.distributed.run: [WARNING] [2024-01-27 18:51:26,053] torch.distributed.run: [WARNING] ***************************************** [2024-01-27 18:51:26,053] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. [2024-01-27 18:51:26,053] torch.distributed.run: [WARNING] ***************************************** Clearing GPU cache for all ranks --> Running with torch dist debug set to detail Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:10<00:00, 5.02s/it] Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:09<00:00, 4.98s/it] --> Model meta-llama/Llama-2-7b-hf --> meta-llama/Llama-2-7b-hf has 6738.415616 Million params trainable params: 1,228,830 || all params: 6,739,644,446 || trainable%: 0.018232860944605387 trainable params: 1,228,830 || all params: 6,739,644,446 || trainable%: 0.018232860944605387 bFloat16 enabled for mixed precision - using bfSixteen policy Traceback (most recent call last): File "/home/user/newlama/llama-recipes/examples/finetuning.py", line 8, in fire.Fire(main) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/fire/core.py", line 141, in Fire component_trace = _Fire(component, args, parsed_flag_args, context, name) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire component, remaining_args = _CallAndUpdateTrace( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace component = fn(*varargs, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/llama_recipes/finetuning.py", line 144, in main model = FSDP( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 477, in init _auto_wrap( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type] File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap wrapped_child, num_wrapped_params = _recursive_wrap( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap wrapped_child, num_wrapped_params = _recursive_wrap( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap wrapped_child, num_wrapped_params = _recursive_wrap( [Previous line repeated 2 more times] File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 561, in _recursive_wrap return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 490, in _wrap return wrapper_cls(module, **kwargs) File 
"/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 503, in init _init_param_handle_from_module( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 594, in _init_param_handle_from_module _init_param_handle_from_params(state, managed_params, fully_sharded_module) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 606, in _init_param_handle_from_params handle = FlatParamHandle( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 573, in init self._init_flat_param_and_metadata( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 623, in _init_flat_param_and_metadata ) = self._validate_tensors_to_flatten(params) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 770, in _validate_tensors_to_flatten raise ValueError( ValueError: Must flatten tensors with uniform requires_grad when use_orig_params=False Traceback (most recent call last): File "/home/user/newlama/llama-recipes/examples/finetuning.py", line 8, in fire.Fire(main) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/fire/core.py", line 141, in Fire component_trace = _Fire(component, args, parsed_flag_args, context, name) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire component, remaining_args = _CallAndUpdateTrace( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace component = fn(*varargs, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/llama_recipes/finetuning.py", line 144, in main model = FSDP( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 477, in init _auto_wrap( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type] File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap wrapped_child, num_wrapped_params = _recursive_wrap( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap wrapped_child, num_wrapped_params = _recursive_wrap( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap wrapped_child, num_wrapped_params = _recursive_wrap( [Previous line repeated 2 more times] File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 561, in _recursive_wrap return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 490, in _wrap return wrapper_cls(module, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 503, in init _init_param_handle_from_module( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 594, in _init_param_handle_from_module _init_param_handle_from_params(state, managed_params, fully_sharded_module) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 606, in 
_init_param_handle_from_params handle = FlatParamHandle( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 573, in init self._init_flat_param_and_metadata( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 623, in _init_flat_param_and_metadata ) = self._validate_tensors_to_flatten(params) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 770, in _validate_tensors_to_flatten raise ValueError( ValueError: Must flatten tensors with uniform requires_grad when use_orig_params=False [2024-01-27 18:51:51,094] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 1150407) of binary: /opt/conda/envs/newlama/bin/python3.9 Traceback (most recent call last): File "/opt/conda/envs/newlama/bin/torchrun", line 8, in sys.exit(main()) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/init.py", line 347, in wrapper return f(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/run.py", line 812, in main run(args) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/run.py", line 803, in run elastic_launch( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 135, in call return launch_agent(self._config, self._entrypoint, list(args)) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent raise ChildFailedError( torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

examples/finetuning.py FAILED

Failures:
[1]:
  time       : 2024-01-27_18:51:51
  host       : instance-2.us-central1-a.c.training-382917.internal
  rank       : 1 (local_rank: 1)
  exitcode   : 1 (pid: 1150408)
  error_file : <N/A>
  traceback  : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

Root Cause (first observed failure):
[0]:
  time       : 2024-01-27_18:51:51
  host       : instance-2.us-central1-a.c.training-382917.internal
  rank       : 0 (local_rank: 0)
  exitcode   : 1 (pid: 1150407)
  error_file : <N/A>
  traceback  : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

==== prefix case ====

Input command:

torchrun --nnodes 1 --nproc_per_node 2 examples/finetuning.py --enable_fsdp --dataset grammar_dataset --use_peft --peft_method prefix --model_name meta-llama/Llama-2-7b-hf --output_dir ./output-models-grammar-prefix --num_epochs 1 --pure_bf16

--------------outputs----------------: [2024-01-27 18:49:51,058] torch.distributed.run: [WARNING] [2024-01-27 18:49:51,058] torch.distributed.run: [WARNING] ***************************************** [2024-01-27 18:49:51,058] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. [2024-01-27 18:49:51,058] torch.distributed.run: [WARNING] ***************************************** Clearing GPU cache for all ranks --> Running with torch dist debug set to detail Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:09<00:00, 5.00s/it] Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:10<00:00, 5.00s/it] --> Model meta-llama/Llama-2-7b-hf --> meta-llama/Llama-2-7b-hf has 6738.415616 Million params trainable params: 7,864,320 || all params: 6,746,279,936 || trainable%: 0.11657269005446738 trainable params: 7,864,320 || all params: 6,746,279,936 || trainable%: 0.11657269005446738 bFloat16 enabled for mixed precision - using bfSixteen policy --> applying fsdp activation checkpointing... --> applying fsdp activation checkpointing... Preprocessing dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13016/13016 [00:10<00:00, 1283.08it/s] Preprocessing dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13016/13016 [00:10<00:00, 1277.85it/s] --> Training Set Length = 484 Generating train split: 0 examples [00:00, ? examples/s]/opt/conda/envs/newlama/lib/python3.9/site-packages/datasets/download/streaming_download_manager.py:778: FutureWarning: The 'verbose' keyword in pd.read_csv is deprecated and will be removed in a future version. return pd.read_csv(xopen(filepath_or_buffer, "rb", download_config=download_config), **kwargs) Generating train split: 2988 examples [00:00, 192317.78 examples/s] Preprocessing dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2988/2988 [00:01<00:00, 1531.01it/s] Preprocessing dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2988/2988 [00:01<00:00, 1535.34it/s] --> Validation Set Length = 82 /opt/conda/envs/newlama/lib/python3.9/site-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats. 
warnings.warn( /opt/conda/envs/newlama/lib/python3.9/site-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats. warnings.warn( Training Epoch: 0: 0%| | 0/60 [00:00<?, ?it/s]Traceback (most recent call last): File "/home/user/newlama/llama-recipes/examples/finetuning.py", line 8, in fire.Fire(main) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/fire/core.py", line 141, in Fire component_trace = _Fire(component, args, parsed_flag_args, context, name) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire component, remaining_args = _CallAndUpdateTrace( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace component = fn(*varargs, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/llama_recipes/finetuning.py", line 237, in main results = train( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/llama_recipes/utils/train_utils.py", line 79, in train loss = model(**batch).loss File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward output = self._fsdp_wrapped_module(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/peft/peft_model.py", line 1108, in forward return self.base_model( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward outputs = self.model( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1070, in forward layer_outputs = decoder_layer( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward output = self._fsdp_wrapped_module(*args, **kwargs) File 
"/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 168, in forward return self.checkpoint_fn( # type: ignore[misc] File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/_compile.py", line 24, in inner return torch._dynamo.disable(fn, recursive)(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 17, in inner return fn(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 489, in checkpoint ret = function(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 798, in forward hidden_states, self_attn_weights, present_key_value = self.self_attn( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 703, in forward kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) AttributeError: 'tuple' object has no attribute 'get_usable_length' Traceback (most recent call last): File "/home/user/newlama/llama-recipes/examples/finetuning.py", line 8, in fire.Fire(main) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/fire/core.py", line 141, in Fire component_trace = _Fire(component, args, parsed_flag_args, context, name) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire component, remaining_args = _CallAndUpdateTrace( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace component = fn(*varargs, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/llama_recipes/finetuning.py", line 237, in main results = train( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/llama_recipes/utils/train_utils.py", line 79, in train loss = model(**batch).loss File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward output = self._fsdp_wrapped_module(*args, **kwargs) File 
"/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/peft/peft_model.py", line 1108, in forward return self.base_model( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward outputs = self.model( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1070, in forward layer_outputs = decoder_layer( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward output = self._fsdp_wrapped_module(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 168, in forward return self.checkpoint_fn( # type: ignore[misc] File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/_compile.py", line 24, in inner return torch._dynamo.disable(fn, recursive)(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 17, in inner return fn(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 489, in checkpoint ret = function(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 798, in forward hidden_states, self_attn_weights, present_key_value = self.self_attn( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return 
self._call_impl(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 703, in forward kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) AttributeError: 'tuple' object has no attribute 'get_usable_length' Training Epoch: 0: 0%| | 0/60 [00:01<?, ?it/s] Training Epoch: 0: 0%| | 0/60 [00:01<?, ?it/s] [2024-01-27 18:50:36,119] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 1150053) of binary: /opt/conda/envs/newlama/bin/python3.9 Traceback (most recent call last): File "/opt/conda/envs/newlama/bin/torchrun", line 8, in sys.exit(main()) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/init.py", line 347, in wrapper return f(*args, **kwargs) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/run.py", line 812, in main run(args) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/run.py", line 803, in run elastic_launch( File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 135, in call return launch_agent(self._config, self._entrypoint, list(args)) File "/opt/conda/envs/newlama/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent raise ChildFailedError( torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

examples/finetuning.py FAILED

Failures:
[1]:
  time       : 2024-01-27_18:50:36
  host       : instance-2.us-central1-a.c.training-382917.internal
  rank       : 1 (local_rank: 1)
  exitcode   : 1 (pid: 1150054)
  error_file : <N/A>
  traceback  : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

Root Cause (first observed failure):
[0]:
  time       : 2024-01-27_18:50:36
  host       : instance-2.us-central1-a.c.training-382917.internal
  rank       : 0 (local_rank: 0)
  exitcode   : 1 (pid: 1150053)
  error_file : <N/A>
  traceback  : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

Expected behavior

Since LoRA works fine, I didn't expect so much trouble when switching to the other two methods. I checked all previous bug reports and didn't find similar cases.

cyberyu avatar Jan 27 '24 19:01 cyberyu

Hey, did you manage to solve this? I am getting the same error.

Mugheeera avatar Apr 01 '24 18:04 Mugheeera

Just a quick try: I wonder if setting use_orig_params=False here would help to bypass the error?

HamidShojanazeri avatar Apr 01 '24 19:04 HamidShojanazeri

use_orig_params=False results in these errors:

  • with --peft_method prefix: [rank0]: RuntimeError: The size of tensor a (4096) must match the size of tensor b (4126) at non-singleton dimension 3
  • with --peft_method llama_adapter: ValueError: Must flatten tensors with uniform requires_grad when use_orig_params=False

Mugheeera avatar Apr 02 '24 10:04 Mugheeera

Hey, did you manage to solve this? I am getting the same error.

Not yet. I didn't try anything new afterwards.

cyberyu avatar Apr 02 '24 20:04 cyberyu

same error here

wang-sj16 avatar Apr 22 '24 19:04 wang-sj16

Hi @cyberyu @Mugheeera @wang-sj16, I looked into llama_adapter and it turns out that, the way llama_adapter and FSDP are written, they are currently incompatible. llama_adapter adds nn.Parameters to the model layers, and these do not get wrapped separately by FSDP, which only wraps nn.Modules. We will add docs and probably a config check to make users aware of this.
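To make the failure mode concrete, here is a small standalone sketch. ToyAdapterLayer is a hypothetical stand-in for llama_adapter's injected parameters, not its actual implementation; the FSDP behavior described in the comments is the check from the traceback above.

```python
import torch

class ToyAdapterLayer(torch.nn.Module):
    """Stand-in for what llama_adapter does: frozen base weights plus an extra
    trainable nn.Parameter attached directly to an existing layer."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        for p in self.linear.parameters():
            p.requires_grad_(False)  # base model weights are frozen during PEFT
        self.adaption_gate = torch.nn.Parameter(torch.zeros(1))  # trainable adapter param

layer = ToyAdapterLayer()
print({name: p.requires_grad for name, p in layer.named_parameters()})
# {'adaption_gate': True, 'linear.weight': False, 'linear.bias': False}

# FSDP's auto-wrap policy wraps nn.Modules (e.g. the decoder layers), so an extra
# trainable nn.Parameter lands in the same FlatParameter as the frozen weights of
# its parent module. With the default use_orig_params=False, FSDP refuses to
# flatten that mix and raises:
#   ValueError: Must flatten tensors with uniform requires_grad when use_orig_params=False
```

In principle, constructing FSDP with use_orig_params=True relaxes this check, but as noted in this thread the llama_adapter + FSDP combination is not supported in llama-recipes at this point.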

Still looking into prefix, but it might be a similar story there.

mreso avatar May 01 '24 23:05 mreso

Turns out that in the case of prefix tuning it's actually an incompatibility between the current peft and transformers/llama implementations. Previously, past_key_values used tuples as a data structure, but "recently" (~5 months ago) some models in transformers were updated to use a dedicated Cache data structure, including llama. The peft model still creates a tuple (https://github.com/huggingface/peft/blob/main/src/peft/peft_model.py#L517), which is incompatible with the current llama implementation in transformers. Will create a PR to remove prefix tuning for now until peft allows it with llama. @cyberyu @Mugheeera @wang-sj16 please feel free to create an issue in the peft repo if you want this use case supported.
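To make the data-structure mismatch concrete, here is a small standalone sketch (the shapes are made up for illustration; this only demonstrates the tuple-vs-Cache difference, it is not a workaround inside llama-recipes):

```python
import torch
from transformers.cache_utils import DynamicCache  # Cache API introduced in transformers 4.36

# peft's prefix tuning still builds the legacy past_key_values format:
# a tuple holding one (key_states, value_states) pair per decoder layer.
num_layers, batch, kv_heads, prefix_len, head_dim = 2, 1, 4, 30, 128
legacy_past = tuple(
    (torch.zeros(batch, kv_heads, prefix_len, head_dim),
     torch.zeros(batch, kv_heads, prefix_len, head_dim))
    for _ in range(num_layers)
)

# A plain tuple has no get_usable_length(), hence the AttributeError in the logs above.
print(hasattr(legacy_past, "get_usable_length"))  # False

# The current LlamaAttention expects a Cache object, which does have it:
cache = DynamicCache.from_legacy_cache(legacy_past)
print(cache.get_usable_length(10, layer_idx=0))   # 30, i.e. the prefix length
```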

mreso avatar May 02 '24 00:05 mreso

Thank you so much for looking into this.

cyberyu avatar May 02 '24 00:05 cyberyu

Has this been changed yet? I'm getting the error here with llama2 7b:

transformers==4.3.2 torch==2.1.0 deepspeed==0.12.2

File "/root/.venv/unlrn/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 402, in forward kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) AttributeError: 'list' object has no attribute 'get_usable_length'

aengusl avatar May 20 '24 10:05 aengusl

Hi @aengusl, which techniques are you trying to run? Are you running the latest version of llama-recipes? Please note that prefix fine-tuning is currently disabled completely, as the peft library is incompatible with the current llama model provided by transformers, and llama_adapter only works without FSDP. The newest llama-recipes should reflect these changes.
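For reference, running llama_adapter without FSDP (e.g. on a single GPU) should still work; dropping --enable_fsdp and the torchrun launcher from the commands above gives something along the lines of python examples/finetuning.py --use_peft --peft_method llama_adapter --dataset grammar_dataset --model_name meta-llama/Llama-2-7b-hf --output_dir ./output-models-grammar-llama_adapter --num_epochs 1 (flags taken from the commands earlier in this thread; whether the 7B model fits on a single GPU depends on available memory and quantization settings).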

mreso avatar May 20 '24 16:05 mreso

Closing this issue due to inactivity, feel free to reopen if there are further questions.

mreso avatar Jun 28 '24 00:06 mreso