[Bug Report] Q cannot be reshaped correctly when the model is loaded in 4-bit
Describe the bug
query_input has shape [batch, pos, n_heads, d_model], and the code where the error occurs is meant to reshape the query projection to [batch, pos, n_heads, d_head]. I found that the output of bnb.matmul_4bit still has shape [batch, pos, n_heads, d_model], so it cannot be reshaped to [batch, pos, n_heads, d_head].
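For reference, this is the shape transformation the non-quantized path is expected to perform; a minimal sketch, not the actual TransformerLens code, assuming llama-7b dimensions (n_heads=32, d_model=4096, d_head=128) and made-up batch/pos values:

import torch

# Each head projects its own d_model-sized input down to d_head.
batch, pos, n_heads, d_model, d_head = 2, 5, 32, 4096, 128
query_input = torch.randn(batch, pos, n_heads, d_model)  # split QKV input
W_Q = torch.randn(n_heads, d_model, d_head)
q = torch.einsum("bpnm,nmh->bpnh", query_input, W_Q)
assert q.shape == (batch, pos, n_heads, d_head)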
The likely cause is the following code in abstract_attention.py:
if self.cfg.load_in_4bit:
    nq = int((self.cfg.d_model * self.cfg.d_head * self.cfg.n_heads) / 2)
    self.W_Q = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
    self.W_O = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
else:
    self.W_Q = nn.Parameter(
        torch.empty(
            self.cfg.n_heads,
            self.cfg.d_model,
            self.cfg.d_head,
            dtype=self.cfg.dtype,
        )
    )
    self.W_O = nn.Parameter(
        torch.empty(
            self.cfg.n_heads,
            self.cfg.d_head,
            self.cfg.d_model,
            dtype=self.cfg.dtype,
        )
    )
When the model is loaded in 4-bit, W_Q has shape [(self.cfg.d_model * self.cfg.d_head * self.cfg.n_heads) / 2, 1], which leads to the unexpected output shape from bnb.matmul_4bit. When the model is not loaded in 4-bit, W_Q has shape [n_heads, d_model, d_head], which does not trigger the bug described above.
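For comparison, a rough sketch of the two W_Q layouts, assuming llama-7b dimensions (n_heads=32, d_model=4096, d_head=128); the numbers are illustrative:

n_heads, d_model, d_head = 32, 4096, 128
full_shape = (n_heads, d_model, d_head)          # non-quantized W_Q: [32, 4096, 128], 16,777,216 values
packed_rows = (d_model * d_head * n_heads) // 2  # 8,388,608 uint8 rows, two 4-bit values per byte
packed_shape = (packed_rows, 1)                  # 4-bit W_Q: a flat [8388608, 1] buffer
assert n_heads * d_model * d_head == packed_rows * 2
# The per-head [n_heads, d_model, d_head] structure is not visible in the packed tensor's shape;
# only the quant_state carries the dequantization metadata.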
Code example
Error message:
File ~/python_library/TransformerLens/transformer_lens/components/abstract_attention.py:364, in AbstractAttention.calculate_qkv_matrices(self, query_input, key_input, value_input)
339 if self.cfg.load_in_4bit:
340 q = self.hook_q(
341 # call bitsandbytes method to dequantize and multiply
342 bnb.matmul_4bit(
343 query_input,
344 self.W_Q.t(),
345 bias=None,
346 quant_state=self.W_Q.quant_state,
--> 347 ).reshape(
348 query_input.shape[0],
349 query_input.shape[1],
350 self.cfg.n_heads,
351 self.cfg.d_head,
352 )
353 + self.b_Q
RuntimeError: shape '[20, 22, 32, 128]' is invalid for input of size 57671680
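The element counts in the error are consistent with the last dimension still being d_model rather than d_head; a minimal check, assuming llama-7b dimensions (n_heads=32, d_model=4096, d_head=128), with batch=20 and pos=22 taken from the error above:

batch, pos, n_heads, d_model, d_head = 20, 22, 32, 4096, 128
expected = batch * pos * n_heads * d_head   # 1,802,240 elements for the target [20, 22, 32, 128]
actual = batch * pos * n_heads * d_model    # 57,671,680, matching "input of size 57671680"
assert actual == 57671680 and expected != actual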
Code that triggered the error:
with torch.inference_mode():
    with model.hooks(fwd_hooks=fwd_hooks_corrupted):
        _ = model(corrupted)
System Info
Describe the characteristics of your environment:
- git clone
- Linux
- python 3.11.4
Checklist
- [√] I have checked that there is no similar issue in the repo (required)
When I load the model in 4-bit and set model.cfg.use_split_qkv_input = True, this bug is triggered. Code example:
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir, proxies=proxies, local_files_only=False, low_cpu_mem_usage=True, use_safetensors=False, load_in_4bit=True, torch_dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = HookedTransformer.from_pretrained("llama-7b-hf", center_unembed=False, fold_ln=False, fold_value_biases=False, device='cuda', hf_model=model, tokenizer=tokenizer, hf_model_4bit=True, center_writing_weights=False)
model.cfg.use_split_qkv_input = True
model.generate("The capital of Germany is", max_new_tokens=2, temperature=0)
Error:
File ~/python_library/TransformerLens/transformer_lens/components/abstract_attention.py:195, in AbstractAttention.forward(self, query_input, key_input, value_input, past_kv_cache_entry, additive_attention_mask, attention_mask, position_bias)
167 def forward(
168 self,
169 query_input: Union[
(...)
186 position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None,
187 ) -> Float[torch.Tensor, "batch pos d_model"]:
188 """
189 shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details
190 past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None
191 additive_attention_mask is an optional mask to add to the attention weights. Defaults to None.
192 attention_mask is the attention mask for padded tokens. Defaults to None.
193 """
--> 195 q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)
197 if past_kv_cache_entry is not None:
198 # Appends the new keys and values to the cached values, and automatically updates the cache
199 kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
File ~/python_library/TransformerLens/transformer_lens/components/abstract_attention.py:348, in AbstractAttention.calculate_qkv_matrices(self, query_input, key_input, value_input)
339 if self.cfg.load_in_4bit:
340 print('In calculate_qkv_matrices: query_input.shape =', query_input.shape) # XD debug
341 q = self.hook_q(
342 # call bitsandbytes method to dequantize and multiply
343 bnb.matmul_4bit(
344 query_input,
345 self.W_Q.t(),
346 bias=None,
347 quant_state=self.W_Q.quant_state,
--> 348 ).reshape(
349 query_input.shape[0],
350 query_input.shape[1],
351 self.cfg.n_heads,
352 self.cfg.d_head,
353 )
354 + self.b_Q
355 )
356 else:
357 q = self.hook_q(attn_fn(query_input, self.W_Q, self.b_Q))
RuntimeError: shape '[1, 6, 32, 128]' is invalid for input of size 786432
@po13on In order to investigate this further, I am going to need to see exactly the code you used to initialize TransformerLens. This could be a wide-ranging bug, but more likely it is a specific model causing the issue. I need to see the full block of code you ran that put TransformerLens into this invalid state.
@bryce13950 I'm sorry for providing incomplete code. The model I loaded is vicuna-7b. Below is the complete code:
model_name = 'lmsys/vicuna-7b-v1.3'
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir, proxies=proxies, local_files_only=False, low_cpu_mem_usage=True, use_safetensors=False, load_in_4bit=True, torch_dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = HookedTransformer.from_pretrained("llama-7b-hf", center_unembed=False, fold_ln=False, fold_value_biases=False, device='cuda', hf_model=model, tokenizer=tokenizer, hf_model_4bit=True, center_writing_weights=False)
model.cfg.use_split_qkv_input = True
model.generate("The capital of Germany is", max_new_tokens=2, temperature=0)
No problem! Thanks for providing this. This should be enough for us to recreate it now.