[WIP] Add diffllama

Open · weak-kajuma opened this pull request

What does this PR do?

This PR adds the code for DiffLlama, a Llama model equipped with the Differential Transformer's attention mechanism. Please refer to the Differential Transformer paper. @ArthurZucker
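For readers unfamiliar with the paper, here is a minimal sketch of the core mechanism (illustrative only, not this PR's implementation): each head computes two softmax attention maps and subtracts them, scaled by a learnable λ, to cancel common-mode attention noise. Causal masking is omitted for brevity.

import torch
import torch.nn.functional as F

def differential_attention(q1, q2, k1, k2, v, lambda_full):
    # q1, q2, k1, k2: (batch, heads, seq, head_dim); v: (batch, heads, seq, 2 * head_dim)
    scale = q1.shape[-1] ** -0.5
    attn1 = F.softmax(q1 @ k1.transpose(-1, -2) * scale, dim=-1)
    attn2 = F.softmax(q2 @ k2.transpose(-1, -2) * scale, dim=-1)
    # The second map is subtracted (weighted by a learnable lambda) to cancel noise.
    return (attn1 - lambda_full * attn2) @ v

b, h, s, d = 2, 4, 10, 32
q1, q2, k1, k2 = (torch.randn(b, h, s, d) for _ in range(4))
v = torch.randn(b, h, s, 2 * d)
out = differential_attention(q1, q2, k1, k2, v, lambda_full=0.8)  # (2, 4, 10, 64)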

weak-kajuma avatar Oct 11 '24 08:10 weak-kajuma

I am working on the code now, but this is the first time I am contributing to transformers or any other OSS project, so I may ask you for some help.

weak-kajuma avatar Oct 11 '24 08:10 weak-kajuma

I still have an error in modeling_diffllama.py at line 377, in apply_rotary_pos_emb. The variable "query_states" should be torch.Size([2, 32, 10, 128]), but it is torch.Size([2, 64, 10, 64]). I need to change either "query_states" or "cos" & "sin".
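The shapes suggest the heads were already split for the two differential attention maps (number of heads doubled, head_dim halved) before RoPE was applied, while cos/sin were still built for the full head_dim of 128. A hedged sketch of the two usual ways out (illustrative shapes only, not necessarily the fix used here):

import torch

batch, num_heads, seq_len, head_dim = 2, 32, 10, 128
query_states = torch.randn(batch, num_heads, seq_len, head_dim)

# Option 1: apply the rotary embedding on the full-width heads first, then split
# each head into the two query groups used by differential attention.
# query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
q1, q2 = query_states.chunk(2, dim=-1)  # each is (2, 32, 10, 64)

# Option 2: keep the already-split layout (2, 64, 10, 64), but construct the
# rotary embedding with the halved head dimension so cos/sin match the last dim,
# e.g. something like DiffLlamaRotaryEmbedding(dim=head_dim // 2, ...).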

weak-kajuma avatar Oct 11 '24 13:10 weak-kajuma

I've finished the normal/eager attention, and I can run generation with AutoModelForCausalLM.generate(). Next I'll adapt it for FlashAttention2 and SDPA attention.
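A rough sanity check of that looks something like the sketch below; the config values are placeholders and the classes are the ones this PR adds, so treat it as an assumption rather than a recipe.

import torch
from transformers import DiffLlamaConfig, DiffLlamaForCausalLM

# Tiny randomly initialized model, just to exercise generate(); hyperparameters are illustrative.
config = DiffLlamaConfig(
    vocab_size=1000,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=8,
)
model = DiffLlamaForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 10))
# With a saved checkpoint this would be AutoModelForCausalLM.from_pretrained(...).generate(...)
output_ids = model.generate(input_ids, max_new_tokens=20, do_sample=False)
print(output_ids.shape)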

weak-kajuma avatar Oct 16 '24 13:10 weak-kajuma

I have also adjusted the code to fit modular transformers.

weak-kajuma avatar Oct 16 '24 13:10 weak-kajuma

@bzantium I found that the attention was still not implemented as in the paper as of e072544a3bfc69b8a903e062729f861108ffecd3. So I'll revert to e072544a3bfc69b8a903e062729f861108ffecd3 and re-implement it with your suggested code style.

weak-kajuma avatar Oct 20 '24 11:10 weak-kajuma

With the FlashAttention2 commit included: DiffLlamaAttention and DiffLlamaSdpaAttention output the same tensor. DiffLlamaFlashAttention2 cannot be tested on its own, just as the original LlamaFlashAttention2 cannot.

But at the DiffLlamaForCausalLM level, the eager and sdpa models DON'T output the same tensor, and the eager and flash_attention_2 models DON'T output the same tensor, while the sdpa and flash_attention_2 models DO output the same tensor.

I cannot figure out why they don't output the same tensor.
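One way to narrow this down is to give each attn_implementation the exact same weights and compare the logits with a tolerance instead of exact equality. A sketch, assuming the DiffLlama classes from this PR and placeholder config values:

import torch
from transformers import AutoModelForCausalLM, DiffLlamaConfig

config = DiffLlamaConfig(vocab_size=1000, hidden_size=256, intermediate_size=512,
                         num_hidden_layers=2, num_attention_heads=8)
input_ids = torch.randint(0, config.vocab_size, (1, 16))

reference = AutoModelForCausalLM.from_config(config, attn_implementation="eager").eval()

logits = {}
for impl in ("eager", "sdpa"):  # "flash_attention_2" additionally needs a GPU and fp16/bf16
    model = AutoModelForCausalLM.from_config(config, attn_implementation=impl).eval()
    model.load_state_dict(reference.state_dict())  # identical weights for a fair comparison
    with torch.no_grad():
        logits[impl] = model(input_ids).logits

# Kernels differ numerically, so compare with a tolerance rather than exact equality.
torch.testing.assert_close(logits["eager"], logits["sdpa"], atol=1e-4, rtol=1e-4)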

weak-kajuma avatar Oct 20 '24 13:10 weak-kajuma

Thanks for re-implementing. I have checked that the DiffLlama model produces roughly the same results with these three attention implementations. I think they don't output exactly the same tensors because of small numerical differences between the libraries, flash-attn and torch.nn.functional.

weak-kajuma avatar Oct 21 '24 12:10 weak-kajuma

I think the last commit is not right. If you look at the original code, the group norm is applied per head, not over the whole self.hidden_size, and the paper says the same, so I think the code I suggested is right. to: @weak-kajuma
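For reference, a small sketch of the distinction being made here (the paper's "GroupNorm" amounts to a head-wise RMS normalization in the reference code; this is illustrative, not the PR's exact implementation):

import torch

def rms_norm(x, eps=1e-6):
    # Normalizes over the last dimension only (no learned scale, for brevity).
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

batch, num_heads, seq_len, head_dim = 2, 16, 10, 128
attn_out = torch.randn(batch, num_heads, seq_len, head_dim)

# Per-head normalization, as in the paper: statistics are computed over head_dim
# for each head separately, before the heads are concatenated and projected.
per_head = rms_norm(attn_out)

# For contrast, normalizing the concatenated output over the whole hidden_size
# (num_heads * head_dim) mixes statistics across heads; this is the variant
# being reverted here.
concat = attn_out.transpose(1, 2).reshape(batch, seq_len, num_heads * head_dim)
whole = rms_norm(concat)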

bzantium avatar Oct 21 '24 13:10 bzantium

Sorry, I think I was wrong and your interpretation is correct. I'll revert.

weak-kajuma avatar Oct 21 '24 13:10 weak-kajuma

Feel free to ping @Cyrilvallez once you think this is ready for review!

ArthurZucker avatar Oct 24 '24 15:10 ArthurZucker

Could you review this PR? to: @Cyrilvallez

bzantium avatar Oct 28 '24 14:10 bzantium

Could you make all the tests pass? to: @weak-kajuma

bzantium avatar Oct 29 '24 08:10 bzantium

I found that you need to place diffllama in alphabetical order in src/transformers/__init__.py to pass check_code_quality. to: @weak-kajuma

bzantium avatar Oct 30 '24 00:10 bzantium

I think running make fixup should help you with this!

ArthurZucker avatar Oct 30 '24 09:10 ArthurZucker

Rebasing / merging from main will fix the other unrelated tests!

ArthurZucker avatar Oct 30 '24 09:10 ArthurZucker

To pass the test_initialization and test_mismatched_shapes_have_properly_initialized_weights tests, I would like to change/add code in tests/test_modeling_common.py. But this is common code. May I change it as shown below?

tests/test_modeling_common.py:705

def test_initialization(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    configs_no_init = _config_zero_init(config)
+   configs_no_init.zero_init = True
    for model_class in self.all_model_classes:
        model = model_class(config=configs_no_init)
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.assertIn(
                    ((param.data.mean() * 1e9).round() / 1e9).item(),
                    [0.0, 1.0],
                    msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                )

Similarly for test_mismatched_shapes_have_properly_initialized_weights at line 3437.

weak-kajuma avatar Oct 31 '24 00:10 weak-kajuma

All tests pass other than tests/utils/test_modeling_utils.py::ModelUtilsTest::test_generation_config_is_loaded_with_model, which is unrelated to adding this model.

Could you please review this PR again? And could you tell me how to fix that error? to: @Cyrilvallez

weak-kajuma avatar Nov 01 '24 12:11 weak-kajuma

This failing test seems to only be due to a CI internal error (this happens sometimes unfortunately). When it happens, you can push an empty commit to re-trigger the CIs.

Cyrilvallez avatar Nov 05 '24 15:11 Cyrilvallez

We want to avoid changing the common tests, see my review for more info 🤗
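A common alternative is to override the test in the model's own test class instead. A hedged sketch of what that might look like in tests/models/diffllama/test_modeling_diffllama.py, assuming the failures come from DiffLlama-specific λ parameters that are intentionally initialized away from 0/1 (the "lambda" name filter is an assumption, not the PR's final code; _config_zero_init is imported from tests/test_modeling_common.py):

# Inside the DiffLlama test class, overriding the common test instead of editing it.
def test_initialization(self):
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    configs_no_init = _config_zero_init(config)
    for model_class in self.all_model_classes:
        model = model_class(config=configs_no_init)
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            # The differential-attention lambda parameters are drawn from a normal
            # distribution by design, so they are exempt from the zero/one check.
            if "lambda" in name:
                continue
            self.assertIn(
                ((param.data.mean() * 1e9).round() / 1e9).item(),
                [0.0, 1.0],
                msg=f"Parameter {name} of model {model_class} seems not properly initialized",
            )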

Cyrilvallez avatar Nov 05 '24 15:11 Cyrilvallez

I have implemented all of your review comments. I tried the test many times, but it still doesn't pass. What should I do? to: @Cyrilvallez

weak-kajuma avatar Nov 07 '24 13:11 weak-kajuma

Hey! Sorry, we were all off for a week on a company-wide offsite! 🤗 @Cyrilvallez should be back on Monday!

ArthurZucker avatar Nov 15 '24 22:11 ArthurZucker

I wonder whether this PR is still a work in progress, or whether most of the implementation has been finalized and it is waiting for the test coverage review?

effortprogrammer avatar Nov 20 '24 08:11 effortprogrammer

Hey, sorry for the delay! In order to use modular transformers, you need to create a new file, modular_diffllama.py, in which you can use inheritance from the different Llama classes. Then, to automatically create the modeling_diffllama.py file, just use our CLI from the root of the transformers repo: python utils/modular_model_converter.py --files_to_parse src/transformers/models/diffllama/modular_diffllama.py 🤗 Let me know if you need more guidance on this! You can find a modular example, e.g. here.

Basically, you can directly inherit from any Llama class that is similar to yours to avoid rewriting it; e.g. if DiffLlamaRotaryEmbedding is similar to LlamaRotaryEmbedding, you can use

class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
    pass

in the modular file. In your case, you will probably only need to rewrite the attention classes 😉
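Expanding on that, a hedged sketch of what modular_diffllama.py could roughly look like; the class list is illustrative and only the attention code would be genuinely new:

# src/transformers/models/diffllama/modular_diffllama.py (sketch)
from torch import nn

from ..llama.modeling_llama import (
    LlamaForCausalLM,
    LlamaModel,
    LlamaRotaryEmbedding,
)


class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
    pass


class DiffLlamaAttention(nn.Module):
    # The genuinely new code: differential attention with two softmax maps
    # combined through the learnable lambda re-parameterization from the paper.
    ...


class DiffLlamaModel(LlamaModel):
    pass


class DiffLlamaForCausalLM(LlamaForCausalLM):
    pass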

Cyrilvallez avatar Nov 20 '24 16:11 Cyrilvallez

Are you still working on this PR, @weak-kajuma ?

effortprogrammer avatar Nov 30 '24 15:11 effortprogrammer

@Cyrilvallez Could you review again? I made modular_diffllama.py.

weak-kajuma avatar Dec 04 '24 07:12 weak-kajuma

You may need to rebase/merge on main though for modular to work perfectly as you seem to be a bit far behind. If something does not work as expected after my comments, you should try that first 🤗

Cyrilvallez avatar Dec 04 '24 17:12 Cyrilvallez

@Cyrilvallez Could you review again? Modular transformers is very easy and pleasant to use. Also, all tests now pass after merging the latest changes.

weak-kajuma avatar Dec 06 '24 12:12 weak-kajuma

@Cyrilvallez any plans to review this PR?

effortprogrammer avatar Dec 10 '24 01:12 effortprogrammer

I know the main change in https://github.com/huggingface/transformers/pull/35235 is about attention. But I may not be able to adapt the differential attention to the style of https://github.com/huggingface/transformers/pull/35235. I know you are very busy, but I would like to ask you to make that PR.

weak-kajuma avatar Dec 24 '24 12:12 weak-kajuma

Cool let's merge then! 🤗

ArthurZucker avatar Jan 07 '25 10:01 ArthurZucker