[WIP] Add diffllama
What does this PR do?
This PR adds the code for DiffLlama, which is a Llama model with Differential Transformer attention. Please refer to the Differential Transformer paper. @ArthurZucker
I am still working on the code, and this is my first time contributing to transformers or any other OSS project, so I may ask you for some help.
I still have an error located in modeling_diffllama.py@377: apply_rotary_pos_emb. The variable "query_states" must be torch.Size([2, 32, 10, 128]), but it is torch.Size([2, 64, 10, 64]). I need to change either "query_states" or "cos"/"sin".
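For reference, a minimal shape sketch of one possible fix, assuming RoPE is applied before the heads are split into the two differential groups; the names mirror the ones above, but this is illustrative rather than the PR's actual code.

```python
import torch

# apply_rotary_pos_emb expects the standard [batch, num_heads, seq_len, head_dim]
# layout, i.e. torch.Size([2, 32, 10, 128]) here, matching cos/sin built for head_dim=128.
batch, num_heads, seq_len, head_dim = 2, 32, 10, 128
query_states = torch.randn(batch, num_heads, seq_len, head_dim)

# One option: apply RoPE in this layout first, and only afterwards split each
# head into the two halves used by differential attention.
q1, q2 = query_states.chunk(2, dim=-1)
print(q1.shape, q2.shape)  # torch.Size([2, 32, 10, 64]) twice

# The alternative is to keep the [2, 64, 10, 64] layout and build cos/sin for
# head_dim // 2 instead of head_dim.
```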
I've finished implementing the normal/eager attention, and I can run it with AutoModelForCausalLM.generate(). Next I'll adapt it for FlashAttention2 and SDPA attention.
I also adapted the code to fit modular transformers.
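For context, this is roughly how the generate() check above can be run; the checkpoint path is a placeholder, not a published DiffLlama checkpoint.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# "path/to/diffllama-checkpoint" stands in for whatever local checkpoint is
# used during development.
tokenizer = AutoTokenizer.from_pretrained("path/to/diffllama-checkpoint")
model = AutoModelForCausalLM.from_pretrained("path/to/diffllama-checkpoint")

inputs = tokenizer("Hello, my name is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```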
@bzantium I found that the attention was still mis-implemented relative to the paper as of e072544a3bfc69b8a903e062729f861108ffecd3. So I'll revert to e072544a3bfc69b8a903e062729f861108ffecd3 and re-implement it with your suggested code style.
With the FlashAttention2 commit included: DiffLlamaAttention and DiffLlamaSdpaAttention output the same tensor. DiffLlamaFlashAttention2 cannot be run on its own, just as the original LlamaFlashAttention2 cannot.
However, DiffLlamaForCausalLM with eager and with sdpa DON'T output the same tensor, and eager vs. flash_attention_2 DON'T match either. Only DiffLlamaForCausalLM with sdpa and with flash_attention_2 output the same tensor.
I cannot figure out why they don't produce the same outputs.
Thanks for re-implementing. I checked that the DiffLlama model with these 3 attention implementations produces approximately the same results. I think they don't match exactly because of small numerical differences between the libraries, flash-attn and torch.nn.functional.
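For reference, a rough sketch of this kind of tolerance-based comparison; the checkpoint path, dummy input, and tolerances are placeholders rather than the PR's actual test.

```python
import torch
from transformers import AutoModelForCausalLM

# Dummy input ids; the vocab size 32000 is only a placeholder.
input_ids = torch.randint(0, 32000, (2, 10))

logits = {}
for impl in ("eager", "sdpa"):  # flash_attention_2 additionally needs a GPU and fp16/bf16
    model = AutoModelForCausalLM.from_pretrained(
        "path/to/diffllama-checkpoint", attn_implementation=impl
    ).eval()
    with torch.no_grad():
        logits[impl] = model(input_ids).logits

# Exact equality is not expected across different kernels; agreement within a
# small tolerance usually is.
print(torch.allclose(logits["eager"], logits["sdpa"], atol=1e-4, rtol=1e-4))
print((logits["eager"] - logits["sdpa"]).abs().max().item())
```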
I think the last commit is not right. Looking at the original code, the GroupNorm is applied to each head, not to the whole self.hidden_size, which also matches the paper, so I think the code I suggested is right.
to: @weak-kajuma
Sorry, I think I was wrong and your interpretation is correct. I'll revert.
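For illustration, a minimal sketch of the per-head normalization agreed on above, here realized as an RMSNorm over head_dim; shapes and names are illustrative, not the PR's exact code.

```python
import torch
import torch.nn as nn

# Attention output in [batch, num_heads, seq_len, head_dim] layout.
batch, num_heads, seq_len, head_dim = 2, 32, 10, 64
attn_output = torch.randn(batch, num_heads, seq_len, head_dim)

# Per-head normalization: statistics are computed over head_dim only, so each
# head is normalized independently (as opposed to normalizing over
# num_heads * head_dim at once).
per_head_norm = nn.RMSNorm(head_dim)  # requires torch >= 2.4
normed = per_head_norm(attn_output)
print(normed.shape)  # torch.Size([2, 32, 10, 64])
```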
Feel free to ping @Cyrilvallez once you think this is ready for review!
Could you review this PR? to: @Cyrilvallez
Could you make all tests pass? to: @weak-kajuma
I found that you need to place diffllama alphabetically in src/transformers/__init__.py to pass check_code_quality.
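For illustration, a hedged sketch of what the alphabetical placement looks like; the neighboring entries and the exported name are examples, not the exact contents of the file.

```python
_import_structure = {
    # ...
    "models.dialogpt": [],
    "models.diffllama": ["DiffLlamaConfig"],  # new entry, inserted alphabetically
    "models.dinat": ["DinatConfig"],
    # ...
}
```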
to: @weak-kajuma
I think running make fixup should help you with this!
Rebasing / merging from main will fix the other unrelated tests!
To pass the test_initialization and test_mismatched_shapes_have_properly_initialized_weights tests, I would like to change/add to the code of tests/test_modeling_common.py. But this is common code. Could I change/add to the code like below?
tests/test_modeling_common.py:705
```diff
 def test_initialization(self):
     config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
     configs_no_init = _config_zero_init(config)
+    configs_no_init.zero_init = True
     for model_class in self.all_model_classes:
         model = model_class(config=configs_no_init)
         for name, param in model.named_parameters():
             if param.requires_grad:
                 self.assertIn(
                     ((param.data.mean() * 1e9).round() / 1e9).item(),
                     [0.0, 1.0],
                     msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                 )
```
and similarly for test_mismatched_shapes_have_properly_initialized_weights at line 3437.
All tests pass other than tests/utils/test_modeling_utils.py::ModelUtilsTest::test_generation_config_is_loaded_with_model, which is unrelated to adding this model.
Could you review this PR again? And could you tell me how to fix that failure? to: @Cyrilvallez
This failing test seems to only be due to a CI internal error (this happens sometimes unfortunately). When it happens, you can push an empty commit to re-trigger the CIs.
> To pass the test_initialization and test_mismatched_shapes_have_properly_initialized_weights tests, I want to change/add to the code of tests/test_modeling_common.py. But this is common code. Could I change/add to the code like below? [...]
We want to avoid that; see my review for more info 🤗
All of your review comments are implemented. I tried the test many times, but it still didn't pass. What should I do? to: @Cyrilvallez
Hey! Sorry, we were all off for a week on a company-wide offsite! 🤗 @Cyrilvallez should be back on Monday!
I wonder, is this PR still a work in progress? Or has most of the implementation been finalized and it is waiting for the test coverage review?
Hey, sorry for the delay!
In order to use modular transformers, you need to create a new file, modular_diffllama.py, in which you can use inheritance from the different Llama classes. Then, to automatically create the modeling_diffllama.py file, just use our CLI: python utils/modular_model_converter.py --files_to_parse src/transformers/models/diffllama/modular_diffllama.py from the root of the transformers repo 🤗
LMK if you need more guidance for this! You can find some modular examples, e.g. here
Basically, for any class similar to a Llama class, you can directly inherit from it to avoid rewriting it; e.g. if DiffLlamaRotaryEmbedding is similar to LlamaRotaryEmbedding, you can use
```python
class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
    pass
```
in the modular file. In your case, you will probably need to only rewrite the attention classes 😉
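To make this concrete, here is a hypothetical skeleton of what modular_diffllama.py could look like; the class names follow the PR, the bodies are placeholders, and in practice the attention classes (and a DiffLlamaConfig) would be written out in full.

```python
# Hypothetical sketch only; a real modular file inside the transformers repo
# would typically use relative imports (from ..llama.modeling_llama import ...).
from transformers.models.llama.modeling_llama import (
    LlamaForCausalLM,
    LlamaMLP,
    LlamaModel,
    LlamaRotaryEmbedding,
)


class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
    pass


class DiffLlamaMLP(LlamaMLP):
    pass


class DiffLlamaModel(LlamaModel):
    pass


class DiffLlamaForCausalLM(LlamaForCausalLM):
    pass


# Only the attention classes would actually be rewritten here, implementing the
# differential attention from the Differential Transformer paper (the difference
# of two softmax attention maps scaled by a learnable lambda).
```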
Are you still working on this PR, @weak-kajuma?
@Cyrilvallez Could you review again? I made modular_diffllama.py.
You may need to rebase/merge on main though for modular to work perfectly as you seem to be a bit far behind. If something does not work as expected after my comments, you should try that first 🤗
@Cyrilvallez Could you review again? Modular transformers turned out to be very easy and pleasant to use. Also, all tests now pass after merging the latest changes.
@Cyrilvallez any plans to review this PR?
I know that the main change of https://github.com/huggingface/transformers/pull/35235 is about Attention. But I may not be able to adapt the differential attention to match https://github.com/huggingface/transformers/pull/35235 myself. I know you are busy, but I would like to ask you to make that change.
Cool let's merge then! 🤗