Add Danube2
This PR adds Danube2 from H2O.ai: https://huggingface.co/h2oai/h2o-danube2-1.8b-chat
Hi @rasbt
When I tried inference I could not get the desired response. I took the prompt style from this code:
```python
import torch
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="h2oai/h2o-danube2-1.8b-chat",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# We use the HF Tokenizer chat template to format each message
# https://huggingface.co/docs/transformers/main/en/chat_templating
messages = [
    {"role": "user", "content": "What is 2+2?"},
]
prompt = pipe.tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
res = pipe(
    prompt,
    max_new_tokens=256,
)
print(res[0]["generated_text"])
```
The above code generates the following response:

```
<|prompt|>What is 2+2?</s> <|answer|> Two (2) plus two (2) equals four (4).</s>
```
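For reference, here is a minimal sketch of what the applied chat template appears to expand to, inferred from the generated text above (the exact whitespace handling and the real template live in the tokenizer config on the Hub, so treat this as an approximation):

```python
def danube2_prompt(user_message: str) -> str:
    # Assumed format, inferred from the generated text shown above
    return f"<|prompt|>{user_message}</s><|answer|>"

print(danube2_prompt("What is 2+2?"))
# <|prompt|>What is 2+2?</s><|answer|>
```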
So even when using the chat prompt template that HF applies in the code above, I get random numbers when I try it with LitGPT.
I tried extensively to track down the issue but couldn't find it. Could you please guide me on what the potential issue might be?
And for the failing tokenizer test for Danube2, where do I have to make the change? I don't see a relevant field in the config to change.
Hi there, and sorry for the late response, it's been a super intense week.

Regarding the tokenizer: usually no modification should be necessary, as it loads the tokenizer from the Hub. But maybe there is a special case here; I have to think about this more. The random numbers you are getting could potentially be related to that.

Otherwise, when I added a checkpoint in the past and got weird results, it was usually because there was something about the architecture that required additional adjustments. How I debugged this in the past was by adding a small model test based on a small tensor, e.g.,
https://github.com/Lightning-AI/litgpt/blob/e60f21a49a435efd215cb858200396f7b24baf17/tests/test_model.py#L335
And then, starting at the bottom of the model, printing the outputs one layer at a time and comparing them to the reference implementation to see at which layer the issue first appears.
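A hypothetical sketch of how such a layer-by-layer comparison can be wired up with forward hooks (the helper name and the printed statistics are my own choice for illustration, not anything from litgpt):

```python
import torch

def attach_print_hooks(model: torch.nn.Module, tag: str) -> None:
    """Print simple per-module output statistics so two implementations can be compared."""
    for name, module in model.named_modules():
        def hook(mod, args, output, name=name):
            out = output[0] if isinstance(output, tuple) else output
            if torch.is_tensor(out):
                print(f"[{tag}] {name}: shape={tuple(out.shape)} mean={out.float().mean():+.6f}")
        module.register_forward_hook(hook)

# attach_print_hooks(ours_model, "litgpt")
# attach_print_hooks(theirs_model, "hf")
# Feeding both models the same small input tensor then reveals the first
# module whose output statistics diverge between the two implementations.
```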
Maybe @Andrei-Aksionov has additional tips as he went through the ordeal with the different Gemma implementations, which all had some minor non-documented things in them.
I'll take a look at it tomorrow.
Thanks for informing me.
Hey @Dev-Khant

The config seems to be OK; the only missing part was `rotary_percentage`, which needs to be `1.0`, since it's used in the calculation of `rope_n_elem` (for the RoPE embeddings):
https://github.com/Lightning-AI/litgpt/blob/main/litgpt/config.py#L92

In the RoPE for Mistral they use `self.dim` (analogous to LitGPT's `head_dim`):
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L100

That means that we need to use 100% of the `head_dim` size, hence `rotary_percentage=1.0`.
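For reference, a minimal sketch of that relationship, using the shrunken config from the test below; the formula is an assumption based on the linked line in `litgpt/config.py`:

```python
# Assumed relationship between rotary_percentage and rope_n_elem
# (based on https://github.com/Lightning-AI/litgpt/blob/main/litgpt/config.py#L92).
n_embd, n_head = 32, 16                  # values from the shrunken test config below
head_size = n_embd // n_head             # 2; LitGPT's analogue of Mistral's head_dim / self.dim
rotary_percentage = 1.0
rope_n_elem = int(rotary_percentage * head_size)
assert rope_n_elem == head_size          # RoPE is applied to the full head, as in Mistral
```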
So, this little change + #1328 results in identical output between LitGPT and the HF variant. You can verify it with a simple test:
```python
# Imports added so the snippet is self-contained (module paths assumed from litgpt's layout;
# in tests/test_model.py they are already available).
import pytest
import torch
from transformers import AutoConfig, AutoModelForCausalLM

from litgpt import GPT, Config
from litgpt.scripts.convert_hf_checkpoint import copy_weights_hf_llama


@torch.inference_mode()
@pytest.mark.parametrize("model_name", ["h2o-danube2-1.8b-chat"])
@pytest.mark.parametrize(
    ("device", "dtype"),
    [
        (torch.device("cpu"), torch.float32),
    ],
)
def test_against_original_danube(model_name, device, dtype):
    torch.set_default_dtype(dtype)

    T = 5
    ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86)
    theirs_config = AutoConfig.from_pretrained(
        "/".join(ours_config.hf_config.values()),
        vocab_size=ours_config.padded_vocab_size,
        hidden_size=ours_config.n_embd,
        head_dim=ours_config.head_size,
        num_attention_heads=ours_config.n_head,
        num_hidden_layers=ours_config.n_layer,
        intermediate_size=ours_config.intermediate_size,
        max_position_embeddings=T,
        rms_norm_eps=ours_config.norm_eps,
        num_key_value_heads=ours_config.n_query_groups,
        rope_theta=ours_config.rope_base,
        attention_bias=ours_config.bias,
    )
    assert ours_config.intermediate_size == theirs_config.intermediate_size

    theirs_model = AutoModelForCausalLM.from_config(theirs_config).to(device)
    theirs_state_dict = theirs_model.state_dict()
    state_dict = {}
    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
    ours_model = GPT(ours_config).to(device)
    ours_model.load_state_dict(state_dict)

    # test end to end
    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
    assert x.size(1) == T
    ours_y = ours_model(x)
    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
    torch.testing.assert_close(ours_y, theirs_y)
```
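Assuming the test is added to `tests/test_model.py` next to the other `test_against_*` checks, it can be run in isolation with `pytest tests/test_model.py -k danube`.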
The problem that I couldn't solve yet is that, despite the tokenizer output (with fixes from #1328) and the model output being identical to the HF variant, the generation itself produces a somewhat weird response.
I'll revisit it a bit later. Or maybe you can try to find it in the meantime. We can make it a race: who can find it quicker 🏎️.
Update: after I redownloaded the weights, the model started to show a decent output. Maybe a bit too many new tokens are generated for my liking, but I guess it can be tweaked with generation parameters. In other words, we need to wait till #1328 is merged and this PR should be ready.
Thanks, @Andrei-Aksionov for thoroughly going through the PR. I'll run the above code with changes from #1328 and check the output for both tokenizers.
Also, prior to #1328, what I did was take the tokens generated by the HF tokenizer and pass them to the LitGPT model, but somehow it still generated random words. So, as you said, we can wait for #1328 to get merged, and in the meantime I'll again try to see what is going wrong with the prediction.
Update: Looks like it works properly with the changes from #1328 and the latest weights. @Andrei-Aksionov Please confirm on your side as well. Thanks!
Hi @Andrei-Aksionov @rasbt Can you please review the PR now? It's passing the tests and text generation is working properly.
Hey @Dev-Khant You don't need my approval, since I'm not a maintainer (though I approved anyway 🙂).
@rasbt Since you are a markdown Jedi, could you look at the changes, like the added empty lines? From the commit history I see that I introduced those changes when I merged changes from main, but apparently they are no longer there (in the main branch), or maybe my markdown auto-formatting did it 🤷‍♂️.
Thanks for the ping @Dev-Khant & @Andrei-Aksionov , and thanks so much for this valuable contribution. I'll take a look!
Just played around with it for a bit and it works great. Thanks again for this great contrib!
Thanks @rasbt!