Murtadha Ahmed

2 comments by Murtadha Ahmed

Adding `flash_attention_2` works for me:

```
model_args['attn_implementation'] = 'flash_attention_2'
model = LlamaForCausalLM.from_pretrained(model_name, **model_args).eval()
```
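For reference, a minimal self-contained sketch of the same setup. The model name, dtype, device map, and prompt are my own assumptions; `flash_attention_2` also requires the `flash-attn` package and loading the model in fp16/bf16 on a supported GPU.

```
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

model_name = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint; use whichever Llama model you load

model_args = {
    "torch_dtype": torch.float16,              # flash-attention kernels need fp16/bf16
    "device_map": "auto",
    "attn_implementation": "flash_attention_2",
}

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(model_name, **model_args).eval()

inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```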

I tried this code, but it doesn't work:

```
if prefix_parallel and prefix_parallel > 1:
    key_length_ = ((key_length - query_length) // prefix_parallel) + query_length
    causal_mask = self.bias[:, :, key_length_...
```
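The slice the snippet modifies looks like the GPT-2-style causal bias buffer in transformers (`self.bias[:, :, key_length - query_length : key_length, :key_length]`). For comparison, here is a minimal standalone sketch of that unmodified pattern; the buffer construction, shapes, and lengths are all assumptions for illustration, not the actual module:

```
import torch

# Sketch of the GPT-2-style causal mask slicing; max_positions and lengths are assumed.
max_positions = 16
bias = torch.tril(torch.ones(max_positions, max_positions, dtype=torch.bool)).view(
    1, 1, max_positions, max_positions
)

query_length, key_length = 4, 12  # e.g. 4 new tokens attending over 12 cached + new keys

# Unmodified slicing: keep the last `query_length` rows, up to `key_length` columns.
causal_mask = bias[:, :, key_length - query_length : key_length, :key_length]
print(causal_mask.shape)  # torch.Size([1, 1, 4, 12])
```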