Ms-PoE
"Found in the Middle: How Language Models Use Long Contexts Better via Plug-and-Play Positional Encoding" Zhenyu Zhang, Runjin Chen, Shiwei Liu, Zhewei Yao, Olatunji Ruwase, Beidi Chen, Xiaoxia Wu, Zh...
The following is your implementation of the MsPoELlamaRotaryEmbedding:

```
def forward(self, x, seq_len=None):
    # x: [bs, num_attention_heads, seq_len, head_size]
    if seq_len > self.max_seq_len_cached:
        self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
    return (
        self.cos_cached[:, :seq_len].to(dtype=x.dtype),
        self.sin_cached[:, :seq_len].to(dtype=x.dtype),
        ...
```
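For context, the `forward` above follows the standard Llama-style rotary-embedding caching pattern: precompute cos/sin tables once, and rebuild them only when a longer sequence arrives. The sketch below is a minimal, self-contained version of that pattern; the class name, buffer shapes, and constructor arguments (`dim`, `base`, `max_position_embeddings`) are common conventions assumed here, not the repository's exact code.

```python
import torch


class RotaryEmbeddingSketch(torch.nn.Module):
    """Minimal sketch of the cos/sin caching pattern used by the forward() above.
    Names and shapes are assumptions based on common Llama implementations."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000.0):
        super().__init__()
        self.dim = dim
        # Inverse frequencies for each pair of embedding dimensions.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.max_seq_len_cached = 0
        self._set_cos_sin_cache(max_position_embeddings,
                                self.inv_freq.device, torch.float32)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # Rebuild the cached tables whenever a longer sequence is seen.
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq.to(device))
        emb = torch.cat((freqs, freqs), dim=-1)  # [seq_len, dim]
        self.cos_cached = emb.cos().to(dtype)
        self.sin_cached = emb.sin().to(dtype)

    def forward(self, x, seq_len=None):
        # x: [bs, num_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len, x.device, x.dtype)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
```

Ms-PoE itself then reuses these tables per head at different scaling ratios; the sketch only covers the shared caching step.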
1. Support more models, such as Mistral, Gemma, and Qwen2.
2. Support FlashAttention.
3. Truncate the sequence length to under 4k when calculating outliers, to avoid OOM.
```
File "Ms-PoE/utils/modify_arch/llama.py", line 330, in forward
    self.head_order = self._head_wise_statistics(query_states, key_states, q_len, kv_seq_len, bsz, attention_mask)
File "Ms-PoE/utils/modify_arch/llama.py", line 176, in _head_wise_statistics
    raise ValueError(
ValueError: Attention weights should be of size (1,...
```
1. What device did you use for inference on the ZeroScrolls dataset?
2. What input prompt length did you use during inference? It is not mentioned in your paper.
`TypeError: MsPoELlamaAttention.forward() got an unexpected keyword argument 'cache_position'`
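This error typically appears because newer versions of `transformers` pass a `cache_position` keyword into attention modules, which a patched `forward()` with a fixed signature does not accept. One common workaround (an assumption about a reasonable fix, not the repository's official patch) is to accept and ignore unknown keywords via `**kwargs`; the class below is a purely illustrative stand-in for the real attention module.

```python
# Hypothetical sketch: a patched attention forward() that tolerates extra
# keywords (e.g. cache_position) passed by newer transformers versions.
# The class name and argument list are illustrative, not the repo's code.
class PatchedAttention:
    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                past_key_value=None, output_attentions=False, use_cache=False,
                **kwargs):  # swallows cache_position and any future keywords
        # ... the original Ms-PoE attention logic would run here ...
        return hidden_states


attn = PatchedAttention()
# Previously this call raised TypeError; with **kwargs it succeeds.
out = attn.forward("h", cache_position=[0, 1, 2])
```

Pinning `transformers` to the version the repository was developed against is the other common way to avoid this mismatch.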