transformers
Generate: support for left-padding on GPTNeoX and Llama
What does this PR do?
As the title indicates, this PR adds left-padding support to GPTNeoX and Llama.
It adds the position_ids input, propagates it all the way to the position embeddings, and gathers the position embeddings according to the values in position_ids. All slow tests now pass for both models, including the newly added left-padding support test and the GPTNeoX integration test.
It also makes a few changes to Llama to bring it closer to the other models 🤗
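Roughly speaking, the mechanism looks like the sketch below (a minimal illustration, not the exact diff in this PR): with left-padding, a token's index in the padded sequence no longer matches its true position, so position_ids are built from the attention mask and then used to index the rotary cos/sin caches.

```python
import torch

# Minimal sketch of the idea (not the exact transformers implementation):
# build position_ids from the attention mask so positions start at 0 on the
# first real token, then index the precomputed rotary caches with them.

def build_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)  # dummy value on pad slots
    return position_ids

def gather_rotary(cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor):
    # cos/sin: [max_positions, dim] caches; pick the rows matching each token's
    # true position instead of assuming positions 0..seq_len-1.
    return cos[position_ids], sin[position_ids]

mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence
                     [1, 1, 1, 1, 1]])  # unpadded sequence
print(build_position_ids(mask))
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```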
The failing CI is fixed by #22383 :)
@ArthurZucker @sgugger woopsie, I forgot that it affected the weight loading code -- I come from a place where weight names have to be specified 👼 Reverted (self.llama is self.model again)!
It appears this may have broken FSDP. For example, as specified in the Alpaca repo, fine-tuning with --fsdp "full_shard auto_wrap" --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer worked before this commit, but after it the run fails with an error such as:
File "/home/fsuser/.local/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 313, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/home/fsuser/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'position_ids'
Reverting the commit fixes it, although perhaps the problem is with accelerate not supporting position_ids? cc: @ArthurZucker
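(For reference, that TypeError is the generic Python failure mode whenever a wrapped or monkey-patched layer's forward() predates a newly passed keyword; the sketch below only illustrates the mechanism and is not a diagnosis of this particular FSDP/Alpaca setup.)

```python
import torch
import torch.nn as nn

# Illustration only: a layer whose forward() was written before position_ids
# existed will reject it once the updated model starts passing it down.
class OldStyleLayer(nn.Module):
    def forward(self, hidden_states, attention_mask=None):  # no position_ids
        return hidden_states

layer = OldStyleLayer()
layer(torch.zeros(1, 4, 8), attention_mask=None, position_ids=None)
# TypeError: forward() got an unexpected keyword argument 'position_ids'
```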
@jquesnelle can you paste the full stack trace? It would allow us to find the root cause :D (maybe, as you mention, the problem is in accelerate... or maybe it comes from the Alpaca repo!)
I'm seeing a pretty significant performance hit on RedPajama-7b-chat that I think is due to this change. I ran the PyTorch profiler, and all of the repeat operators in apply_rotary_pos_emb are expensive and run mostly on the CPU. Reverting to transformers 4.27.x resolves the performance issue.
You should try the main branch: #22785 removed the repeat, which solves this.
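(Loosely, the change replaces repeat-style expansion of the rotary caches with direct indexing by position_ids; the snippet below is a rough sketch of the difference, not the actual #22785 diff.)

```python
import torch

# Rough sketch: expanding the cos cache with repeat materialises a new tensor
# every forward pass, while indexing the cached table with position_ids yields
# the same values without the extra copy.
cos_cached = torch.randn(2048, 128)                          # [max_positions, head_dim]
position_ids = torch.arange(512).unsqueeze(0).expand(2, -1)  # [batch, seq_len]

cos_repeated = cos_cached[:512].unsqueeze(0).repeat(2, 1, 1)  # repeat-style
cos_gathered = cos_cached[position_ids]                       # gather by position

assert torch.equal(cos_repeated, cos_gathered)
```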