torchtune

Add Phi4

Open krammnic opened this issue 1 year ago • 7 comments

Context

What is the purpose of this PR? Is it to

  • [X] add a new feature
  • [ ] fix a bug
  • [ ] update tests and/or documentation
  • [ ] other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

  • #2190

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • [X] run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • [X] add unit tests for any new functionality
  • [X] update docstrings for any new or updated methods or classes
  • [X] run unit tests via pytest tests
  • [ ] run recipe tests via pytest tests -m integration_test
  • [ ] manually run any new or modified recipes with sufficient proof of correctness
  • [ ] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it. Here is a docstring example and a tutorial example

  • [ ] I did not change any public API
  • [X] I have added an example to docs or docstrings

krammnic avatar Dec 21 '24 23:12 krammnic

Helpful Links

See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2197

Note: Links to docs will display an error until the docs builds have been completed.

No Failures

As of commit 4f38c14726723304671bd0adfbafffde65040610 with merge base b3964af5aeca5ec314af0af2202219b0bb89deab: Looks good so far! There are no failures yet.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] avatar Dec 21 '24 23:12 pytorch-bot[bot]

Should we wait until Phi-4 is on HF?

krammnic avatar Dec 22 '24 06:12 krammnic

We actually only need changes in the tokenizer, since Phi3 already uses full attention (without a sliding window).
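To make the full-attention vs. sliding-window distinction concrete, here is a minimal pure-Python sketch contrasting a plain causal mask with a sliding-window causal mask. The helper names and the window size used below are illustrative assumptions, not torchtune APIs or values from any Phi config.

```python
def causal_mask(seq_len):
    # mask[i][j] is True when query position i may attend to key position j
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def sliding_window_causal_mask(seq_len, window):
    # Causal, but each query only sees the most recent `window` keys (itself included)
    return [[i - window < j <= i for j in range(seq_len)] for i in range(seq_len)]


for row in sliding_window_causal_mask(4, window=2):
    print(["x" if allowed else "." for allowed in row])
```

Once the window is at least as long as the sequence, the two masks coincide, which is why a model without a sliding window needs no special mask handling.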

krammnic avatar Dec 22 '24 07:12 krammnic

I assume I need to do a run with Phi4.

krammnic avatar Dec 24 '24 09:12 krammnic

@joecummings It seems you haven't gone on holiday yet :) Could you give me some comments on this PR?

krammnic avatar Dec 25 '24 08:12 krammnic

Haha yes I'm still (somewhat) here. I asked and it looks like the Phi4 team is planning on fixing some license issues with Hugging Face and should have the model on the Hub soon. So eventually the true test will be to grab the official model from Hugging Face and do a forward pass; however, if you want to iron out any potential discrepancies right away, I'd just grab one of the unofficial uploads like this for your testing.

Happy holidays to you @krammnic - been a pleasure working with you on torchtune this year!

joecummings avatar Dec 25 '24 13:12 joecummings

@joecummings Thanks for the comments! I'll do some runs with this, then.

krammnic avatar Dec 25 '24 14:12 krammnic

Hi @krammnic just checking in on this PR. I saw the model is on Hugging Face (as of yesterday I believe). Have you done a parity check with their model? And is this ready for review? If so let me know and we can take a look

ebsmothers avatar Jan 10 '25 00:01 ebsmothers

Hi, I will run tests today and ping you when it's ready for review!

krammnic avatar Jan 10 '25 23:01 krammnic

The forward pass is probably working now (I get an OOM because my GPUs are busy with other experiments). One odd point is that I had to set num_heads = 20, which is half the real num_heads (I assume that's a torchtune quirk?). There is also a naming inconsistency: the official description is "Phi-4 small language model", but we probably can't name it "small".

krammnic avatar Jan 11 '25 20:01 krammnic

No, the forward pass fails for both num_heads = 20 and num_heads = 40:

For num_heads = 20 I get:

        size mismatch for layers.39._checkpoint_wrapped_module.attn.q_proj.weight: copying a param with shape torch.Size([2560, 5120]) from checkpoint, the shape in current model is torch.Size([5120, 5120]).

For num_heads = 40 (the original value) I get:

        size mismatch for layers.39._checkpoint_wrapped_module.attn.k_proj.weight: copying a param with shape torch.Size([2560, 5120]) from checkpoint, the shape in current model is torch.Size([1280, 5120]).
        size mismatch for layers.39._checkpoint_wrapped_module.attn.v_proj.weight: copying a param with shape torch.Size([2560, 5120]) from checkpoint, the shape in current model is torch.Size([1280, 5120]).

The same issue occurs for every layer. I took the params directly from config.json. Am I missing something?
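A back-of-the-envelope check of these shapes, using the config.json values quoted later in this thread (hidden_size 5120, 40 attention heads, 10 KV heads). It assumes, as the uniform 2560 checkpoint shapes suggest, that the HF checkpoint stores one fused qkv_proj (q rows, then k, then v) and that the converter splits it into three equal chunks. `expected_qkv_rows` is a hypothetical helper mirroring the `head_dim = embed_dim // num_heads` projection formulas.

```python
# Values from the Phi-4 config.json fragment quoted in this thread
HIDDEN_SIZE = 5120
NUM_KV_HEADS = 10


def expected_qkv_rows(num_heads, num_kv_heads=NUM_KV_HEADS, embed_dim=HIDDEN_SIZE):
    """Output rows the model's q/k/v projections expect for a given num_heads."""
    head_dim = embed_dim // num_heads
    return (num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim)


# Fused qkv_proj rows for the real config: 40 query heads + 2 * 10 KV heads, head_dim 128
true_head_dim = HIDDEN_SIZE // 40
fused_rows = 40 * true_head_dim + 2 * NUM_KV_HEADS * true_head_dim  # 7680

# Assumption: the converter chunks the fused weight into three equal parts
naive_chunk = fused_rows // 3  # 2560 rows each for q, k and v

print(expected_qkv_rows(40))  # (5120, 1280, 1280): q, k and v all differ from 2560
print(expected_qkv_rows(20))  # (5120, 2560, 2560): only q differs from 2560
print(naive_chunk)            # 2560
```

Under these assumptions the 2560 rows come from the equal split rather than from the model, and num_heads = 20 only "fixes" k and v because it doubles head_dim to 256.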

krammnic avatar Jan 11 '25 20:01 krammnic

Hardcoding like this fixes the issue:

 q_proj=nn.Linear(embed_dim, 2560, bias=False),
 k_proj=nn.Linear(embed_dim, 2560, bias=False),
 v_proj=nn.Linear(embed_dim, 2560, bias=False),

We should probably revise the formulas specifically for Phi4.

krammnic avatar Jan 11 '25 21:01 krammnic

Nit: the tokenizer field should be changed in all configs.

krammnic avatar Jan 11 '25 21:01 krammnic

Now I'm getting RuntimeError: shape '[2, 308, 40, 128]' is invalid for input of size 1576960, probably for the same reason.

krammnic avatar Jan 11 '25 23:01 krammnic

With num_heads=40, num_kv_heads=10, embed_dim=5120, let's calculate:

head_dim = 5120 / 40 = 128

The problem appears already here:

 q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
 k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
 v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),

This is already a problem: the three sizes are not 2560 each, but 5120, 1280 and 1280. Suppose we hardcode them as I showed earlier. We then hit the same problem here:

 q_per_kv = self.num_heads // self.num_kv_heads
 q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)

Error: RuntimeError: shape '[2, 308, 40, 128]' is invalid for input of size 1576960

Part of config.json for reference:

  "hidden_size": 5120,
  "initializer_range": 0.02,
  "intermediate_size": 17920,
  "max_position_embeddings": 16384,
  "model_type": "phi3",
  "num_attention_heads": 40,
  "num_hidden_layers": 40,
  "num_key_value_heads": 10,
  "original_max_position_embeddings": 16384,

So the product would have to be half as large to match. Am I missing something? (I hope I haven't miscalculated.) Something odd is behind this problem; I'll try to work it out ASAP. @ebsmothers I'm not sure it is fixable without touching the Phi3 model or creating a separate model for Phi4.
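The element counts in this RuntimeError can be checked with plain arithmetic. Batch size 2 and sequence length 308 are read off the error message; the other values come from the config.json fragment above. The check suggests the config itself is consistent (num_attention_heads * head_dim equals hidden_size) and the factor of two comes from the 2560-wide q_proj.

```python
# From config.json and from the error message shape '[2, 308, 40, 128]'
hidden_size = 5120
num_heads, num_kv_heads = 40, 10
batch, seq_len = 2, 308

head_dim = hidden_size // num_heads   # 128
q_per_kv = num_heads // num_kv_heads  # 4 query heads share each KV head

# Elements that q.view(b, s_x, num_kv_heads * q_per_kv, head_dim) requires
expected = batch * seq_len * num_kv_heads * q_per_kv * head_dim
# Elements produced by a q_proj hardcoded to 2560 output features
actual = batch * seq_len * 2560

print(expected, actual, expected // actual)  # 3153920 1576960 2
```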

krammnic avatar Jan 11 '25 23:01 krammnic

Oh, and I think we first need to discuss #2212.

krammnic avatar Jan 12 '25 22:01 krammnic

The analysis above is the only architecture issue, and it does not actually align with the tech report.

krammnic avatar Jan 14 '25 20:01 krammnic

Hey @krammnic, sorry I completely missed this comment of yours previously. Did you get this figured out? If not I think we need to make some updates like this to the Phi3 convert weights function to support loading of models with GQA. I think you will also need to update num_heads from 20 to 40 like I suggested previously, but at least after these changes you are able to load the checkpoint. We should also make sure that the forward matches what's on HF -- it's possible that we may need to permute the fused QKV projection instead of naively splitting it as done in my changes. I've put together a minimal script showing roughly how you can do this here. The numbers do not line up, so it needs further investigation whether my torch.split is incorrect or whether I am just passing incorrect arguments to the HF version in that gist.
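A minimal sketch of a GQA-aware split of the fused projection, as opposed to an equal three-way split. `split_fused_qkv` is a hypothetical helper that works on a plain list of row indices so it runs without dependencies; with real tensors the same slicing corresponds to `torch.split(weight, [q_size, kv_size, kv_size], dim=0)`. It assumes the HF row order q, k, v and does not attempt any permutation that RoPE layout differences might require.

```python
def split_fused_qkv(fused_rows, num_heads, num_kv_heads, head_dim):
    """Split the rows (dim 0) of a fused qkv weight into q, k and v using GQA sizes."""
    q_size = num_heads * head_dim
    kv_size = num_kv_heads * head_dim
    if len(fused_rows) != q_size + 2 * kv_size:
        raise ValueError(f"expected {q_size + 2 * kv_size} rows, got {len(fused_rows)}")
    q = fused_rows[:q_size]
    k = fused_rows[q_size:q_size + kv_size]
    v = fused_rows[q_size + kv_size:]
    return q, k, v


# Row indices stand in for the 7680 rows of the Phi-4 fused qkv_proj weight
rows = list(range(7680))
q, k, v = split_fused_qkv(rows, num_heads=40, num_kv_heads=10, head_dim=128)
print(len(q), len(k), len(v))  # 5120 1280 1280
```

Validating the total row count up front turns a silent mis-split into an immediate error when the config and checkpoint disagree.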

Separately, you mentioned in another comment there are some tokenizer issues blocking you. Can you elaborate on that? I'm happy to take a look here as well.

ebsmothers avatar Jan 23 '25 01:01 ebsmothers

Thanks for the comments, I'll check! Regarding the tokenizer, I think I can refer you to #2212.

krammnic avatar Jan 29 '25 12:01 krammnic

@krammnic regarding the tokenizer: it doesn't have tokenizer.model, but it does have vocab.json and merges.txt. So I think this should be sufficient to construct the base tokenizer from one of our existing classes (though separately I think we can start supporting HF tokenizer.json without too much extra effort).
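For reference, vocab.json (token to id) and merges.txt (merge pairs in priority order) are the two inputs a GPT-2-style BPE needs. The sketch below is hypothetical and deliberately simplified: it parses both files and greedily applies the lowest-ranked merge to a single word, skipping the byte-to-unicode mapping and regex pre-tokenization a real GPT-2 tokenizer performs. It is not torchtune's implementation.

```python
import json


def load_bpe_files(vocab_json_text, merges_txt_text):
    vocab = json.loads(vocab_json_text)  # token string -> id
    merges = []
    for line in merges_txt_text.splitlines():
        line = line.strip()
        if not line or line.startswith("#version"):
            continue  # skip blank lines and the "#version: ..." header
        left, right = line.split()
        merges.append((left, right))
    ranks = {pair: i for i, pair in enumerate(merges)}  # lower rank = higher priority
    return vocab, ranks


def bpe(word, ranks):
    parts = list(word)
    while len(parts) > 1:
        pairs = [(parts[i], parts[i + 1]) for i in range(len(parts) - 1)]
        best = min(pairs, key=lambda p: ranks.get(p, float("inf")))
        if best not in ranks:
            break  # no mergeable pair left
        merged, i = [], 0
        while i < len(parts):
            if i < len(parts) - 1 and (parts[i], parts[i + 1]) == best:
                merged.append(parts[i] + parts[i + 1])
                i += 2
            else:
                merged.append(parts[i])
                i += 1
        parts = merged
    return parts


vocab, ranks = load_bpe_files(
    '{"l": 0, "o": 1, "w": 2, "e": 3, "r": 4, "lo": 5, "low": 6, "er": 7}',
    "#version: 0.2\nl o\nlo w\ne r\n",
)
tokens = bpe("lower", ranks)
print(tokens, [vocab[t] for t in tokens])  # ['low', 'er'] [6, 7]
```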

ebsmothers avatar Jan 29 '25 15:01 ebsmothers

I've made all the requested changes! The only remaining point is probably the issue @ebsmothers observed.

krammnic avatar Feb 02 '25 13:02 krammnic

Hi @krammnic I took another look at the parity issue. I think you should push the changes I suggested here and start running E2E tests. I adapted the script I shared before to mock out RoPE embeddings (there are some minor differences between our implementation and the one in HF). With those mocked out, I get forward pass parity within 1e-5. We should probably still make sure that nothing funky is happening with RoPE, but hopefully this is enough to unblock you for now. Also sorry I couldn't get it up sooner, but there is some initial progress on supporting HF tokenizers built from tokenizer.json files in #2350

ebsmothers avatar Feb 06 '25 01:02 ebsmothers

@ebsmothers I hadn't seen this message! Thanks for working on this. I came up with something similar and was able to run the forward pass (with a worse implementation than yours), so I assume it should be fine. For now, we only need to fix the tokenizer problem and then we can merge.

krammnic avatar Feb 08 '25 16:02 krammnic

Created a new GPT2BaseTokenizer class to solve #2212 for this model.

krammnic avatar Feb 08 '25 19:02 krammnic

Loss on alpaca_cleaned after ~500 steps of single-device LoRA: [W&B loss chart, 2/9/2025]

krammnic avatar Feb 08 '25 22:02 krammnic

@ebsmothers It's ready for review. But I don't really like this merge conflict... ^_^

krammnic avatar Feb 08 '25 23:02 krammnic

@krammnic reviewing now. For the merge conflict, you should put GPT2BaseTokenizer under torchtune/modules/transforms/tokenizers, not torchtune/modules/tokenizers (the story will be similar for the unit test). You can see #2231 for the cause of the merge conflict. Your final changes should import GPT2BaseTokenizer in torchtune/modules/transforms/tokenizers/__init__.py, but leave torchtune/modules/tokenizers/__init__.py unchanged (we are deprecating it, so since your changes come in after the deprecation has started, you don't need to expose it in the old path).

ebsmothers avatar Feb 09 '25 22:02 ebsmothers

@ebsmothers @SalmanMohammadi I've done almost all of the required fixes (I still need to cover Phi4Tokenizer with a test for special tokens). You can check it out.

krammnic avatar Feb 10 '25 12:02 krammnic

@ebsmothers Probably done, except for the merge conflict. I'm not sure how to resolve it without write access.

krammnic avatar Feb 10 '25 18:02 krammnic

Done with all requested fixes

krammnic avatar Feb 11 '25 08:02 krammnic