
Using different language models

Open j-datta opened this issue 1 year ago • 4 comments

How can I use different language models from Hugging Face for knowledge distillation in this setup?

j-datta avatar Jun 28 '24 14:06 j-datta

Recently, I've been going through your repository. As a starting point, I want to ask one thing: is it possible to use the Mistral-7B and LLaMA-7B models as a teacher in your framework? And how can I add an AutoTokenizer to your framework?

j-datta avatar Jul 01 '24 13:07 j-datta

You can modify GLMModel to match the architecture of other models. To add an AutoTokenizer, you can implement a class similar to BertWordPieceTokenizer.

Currently, there is no simple and quick method to directly use Hugging Face's language models.

aitsc avatar Jul 01 '24 13:07 aitsc
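
As an illustration of the suggestion above, here is a minimal sketch of wrapping a Hugging Face AutoTokenizer behind a GLM-style tokenizer interface. The method and attribute names (EncodeAsIds, DecodeIds, num_tokens, and so on) are assumptions modeled on what BertWordPieceTokenizer-style classes typically expose; they would need to be checked against the actual base class in this repository.

```python
# Hypothetical sketch: wrapping a Hugging Face AutoTokenizer behind the kind
# of interface BertWordPieceTokenizer exposes in GLM-style code bases. The
# method/attribute names below are assumptions; verify them against the repo.
from transformers import AutoTokenizer

class HFAutoTokenizer:
    def __init__(self, model_name_or_path: str):
        self.text_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        # Vocabulary size, including any added tokens.
        self.num_tokens = len(self.text_tokenizer)

    def EncodeAsIds(self, text: str):
        # GLM-style tokenizers may return a richer tokenization object;
        # returning the raw id list keeps this sketch self-contained.
        return self.text_tokenizer.encode(text, add_special_tokens=False)

    def DecodeIds(self, ids):
        return self.text_tokenizer.decode(ids)

    def IdToToken(self, idx: int) -> str:
        return self.text_tokenizer.convert_ids_to_tokens(idx)

    def TokenToId(self, token: str) -> int:
        return self.text_tokenizer.convert_tokens_to_ids(token)
```
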

Thank you for your reply.

I want to ask another question: how can I pass GPT-2 as the teacher model via a command-line argument?

I am new to this area. Sorry for all these questions.

j-datta avatar Jul 03 '24 13:07 j-datta

I've loaded the pre-trained GPT-2 model from Hugging Face and converted the .bin file to a .pt file. But when I pass this model as the teacher model, I get an error reporting missing keys and unexpected keys.

Missing keys (GLM-style parameter names; the per-layer pattern repeats for layers 0–11):

['transformer.position_embeddings.weight', 'transformer.block_position_embeddings.weight',
 'transformer.layers.{i}.input_layernorm.weight', 'transformer.layers.{i}.input_layernorm.bias',
 'transformer.layers.{i}.attention.query_key_value.weight', 'transformer.layers.{i}.attention.query_key_value.bias',
 'transformer.layers.{i}.attention.dense.weight', 'transformer.layers.{i}.attention.dense.bias',
 'transformer.layers.{i}.post_attention_layernorm.weight', 'transformer.layers.{i}.post_attention_layernorm.bias',
 'transformer.layers.{i}.mlp.dense_h_to_4h.weight', 'transformer.layers.{i}.mlp.dense_h_to_4h.bias',
 'transformer.layers.{i}.mlp.dense_4h_to_h.weight', 'transformer.layers.{i}.mlp.dense_4h_to_h.bias']

Unexpected keys (GPT-2-style parameter names; the per-layer pattern repeats for layers 0–11):

['position_embeddings.weight',
 'transformer.layers.{i}.ln_1.weight', 'transformer.layers.{i}.ln_1.bias',
 'transformer.layers.{i}.attn.c_attn.weight', 'transformer.layers.{i}.attn.c_attn.bias',
 'transformer.layers.{i}.attn.c_proj.weight', 'transformer.layers.{i}.attn.c_proj.bias',
 'transformer.layers.{i}.ln_2.weight', 'transformer.layers.{i}.ln_2.bias',
 'transformer.layers.{i}.mlp.c_fc.weight', 'transformer.layers.{i}.mlp.c_fc.bias',
 'transformer.layers.{i}.mlp.c_proj.weight', 'transformer.layers.{i}.mlp.c_proj.bias']

j-datta avatar Jul 04 '24 11:07 j-datta
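
The two key lists above line up one-to-one (ln_1 with input_layernorm, attn.c_attn with attention.query_key_value, attn.c_proj with attention.dense, ln_2 with post_attention_layernorm, mlp.c_fc with mlp.dense_h_to_4h, mlp.c_proj with mlp.dense_4h_to_h), which suggests the checkpoint mainly needs its parameter names remapped. The sketch below is a hypothetical remapping inferred from those lists: the file paths are placeholders, the transposes reflect the fact that Hugging Face GPT-2 stores attention/MLP weights in Conv1D ([in, out]) orientation while Linear-style layers expect [out, in], and GLM-specific parameters such as transformer.block_position_embeddings.weight have no GPT-2 counterpart and would still be missing after the rename.

```python
# Hypothetical remapping of GPT-2-style parameter names to the GLM-style
# names this framework expects, inferred from the missing/unexpected key
# lists in the error message above.
import torch

# Rename table inferred from the error message; trailing dots keep the
# .weight/.bias suffixes intact.
RENAMES = {
    "ln_1.": "input_layernorm.",
    "attn.c_attn.": "attention.query_key_value.",
    "attn.c_proj.": "attention.dense.",
    "ln_2.": "post_attention_layernorm.",
    "mlp.c_fc.": "mlp.dense_h_to_4h.",
    "mlp.c_proj.": "mlp.dense_4h_to_h.",
}

def remap_gpt2_to_glm(state_dict):
    new_sd = {}
    for name, tensor in state_dict.items():
        new_name = name
        if name == "position_embeddings.weight":
            new_name = "transformer.position_embeddings.weight"
        for old, new in RENAMES.items():
            new_name = new_name.replace(old, new)
        # HF GPT-2 uses Conv1D modules whose weights are stored [in, out];
        # Linear-style layers expect [out, in], so transpose the 2-D
        # attention/MLP projection weights.
        if new_name.endswith(".weight") and tensor.dim() == 2 and (
            ".attention." in new_name or ".mlp." in new_name
        ):
            tensor = tensor.t().contiguous()
        # Note: 'transformer.block_position_embeddings.weight' (GLM's 2D
        # positional encoding) has no GPT-2 equivalent and is not created here.
        new_sd[new_name] = tensor
    return new_sd

sd = torch.load("gpt2_teacher.pt", map_location="cpu")     # hypothetical path
torch.save(remap_gpt2_to_glm(sd), "gpt2_teacher_glm.pt")   # hypothetical path
```

Even with the names and orientations fixed, shape or semantic mismatches may remain (for example, GPT-2's 1024-position embedding table versus whatever sequence length the GLM-style teacher expects), so the converted checkpoint should be validated before use.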