How is the number of BERT model parameters calculated?
I'm a bit confused about the 110M parameters. How is it calculated?
Typically a Transformer model has many parameters to train. For example, for the BERT-base model: L = 12, H = 768, A = 12, the total number of model parameters to train is 12 * 768 * 12 = 110M.
https://zhuanlan.zhihu.com/p/51413773
12 * 768 * 12 = 110M ?
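For reference, 12 * 768 * 12 is only 110,592, so that product cannot be where the 110M comes from; the figure is the total number of trainable weights in the model. A minimal way to check it, assuming the Hugging Face `transformers` package is installed:

```python
# Minimal sketch: count every trainable weight of bert-base-uncased directly.
# Assumes the Hugging Face `transformers` package is installed.
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
total = sum(p.numel() for p in model.parameters())
print(total)  # 109482240, i.e. roughly 110M
```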
Here is a one-layer Transformer:
Total # parameters: 10,152,448 (shared tensors are counted once: the source/target embeddings and the output projection appear to share weights, which is why the total is smaller than the sum of the rows below).

| Weight name | Size | Count |
|---|---|---|
| encoder.src_word_emb.weight | [5395, 512] | 2,762,240 |
| encoder.position_enc.weight | [33, 512] | 16,896 |
| encoder.layer_stack.0.slf_attn.w_qs.weight | [512, 512] | 262,144 |
| encoder.layer_stack.0.slf_attn.w_qs.bias | [512] | 512 |
| encoder.layer_stack.0.slf_attn.w_ks.weight | [512, 512] | 262,144 |
| encoder.layer_stack.0.slf_attn.w_ks.bias | [512] | 512 |
| encoder.layer_stack.0.slf_attn.w_vs.weight | [512, 512] | 262,144 |
| encoder.layer_stack.0.slf_attn.w_vs.bias | [512] | 512 |
| encoder.layer_stack.0.slf_attn.layer_norm.weight | [512] | 512 |
| encoder.layer_stack.0.slf_attn.layer_norm.bias | [512] | 512 |
| encoder.layer_stack.0.slf_attn.fc.weight | [512, 512] | 262,144 |
| encoder.layer_stack.0.slf_attn.fc.bias | [512] | 512 |
| encoder.layer_stack.0.pos_ffn.w_1.weight | [2048, 512, 1] | 1,048,576 |
| encoder.layer_stack.0.pos_ffn.w_1.bias | [2048] | 2,048 |
| encoder.layer_stack.0.pos_ffn.w_2.weight | [512, 2048, 1] | 1,048,576 |
| encoder.layer_stack.0.pos_ffn.w_2.bias | [512] | 512 |
| encoder.layer_stack.0.pos_ffn.layer_norm.weight | [512] | 512 |
| encoder.layer_stack.0.pos_ffn.layer_norm.bias | [512] | 512 |
| decoder.tgt_word_emb.weight | [5395, 512] | 2,762,240 |
| decoder.position_enc.weight | [33, 512] | 16,896 |
| decoder.layer_stack.0.slf_attn.w_qs.weight | [512, 512] | 262,144 |
| decoder.layer_stack.0.slf_attn.w_qs.bias | [512] | 512 |
| decoder.layer_stack.0.slf_attn.w_ks.weight | [512, 512] | 262,144 |
| decoder.layer_stack.0.slf_attn.w_ks.bias | [512] | 512 |
| decoder.layer_stack.0.slf_attn.w_vs.weight | [512, 512] | 262,144 |
| decoder.layer_stack.0.slf_attn.w_vs.bias | [512] | 512 |
| decoder.layer_stack.0.slf_attn.layer_norm.weight | [512] | 512 |
| decoder.layer_stack.0.slf_attn.layer_norm.bias | [512] | 512 |
| decoder.layer_stack.0.slf_attn.fc.weight | [512, 512] | 262,144 |
| decoder.layer_stack.0.slf_attn.fc.bias | [512] | 512 |
| decoder.layer_stack.0.enc_attn.w_qs.weight | [512, 512] | 262,144 |
| decoder.layer_stack.0.enc_attn.w_qs.bias | [512] | 512 |
| decoder.layer_stack.0.enc_attn.w_ks.weight | [512, 512] | 262,144 |
| decoder.layer_stack.0.enc_attn.w_ks.bias | [512] | 512 |
| decoder.layer_stack.0.enc_attn.w_vs.weight | [512, 512] | 262,144 |
| decoder.layer_stack.0.enc_attn.w_vs.bias | [512] | 512 |
| decoder.layer_stack.0.enc_attn.layer_norm.weight | [512] | 512 |
| decoder.layer_stack.0.enc_attn.layer_norm.bias | [512] | 512 |
| decoder.layer_stack.0.enc_attn.fc.weight | [512, 512] | 262,144 |
| decoder.layer_stack.0.enc_attn.fc.bias | [512] | 512 |
| decoder.layer_stack.0.pos_ffn.w_1.weight | [2048, 512, 1] | 1,048,576 |
| decoder.layer_stack.0.pos_ffn.w_1.bias | [2048] | 2,048 |
| decoder.layer_stack.0.pos_ffn.w_2.weight | [512, 2048, 1] | 1,048,576 |
| decoder.layer_stack.0.pos_ffn.w_2.bias | [512] | 512 |
| decoder.layer_stack.0.pos_ffn.layer_norm.weight | [512] | 512 |
| decoder.layer_stack.0.pos_ffn.layer_norm.bias | [512] | 512 |
| tgt_word_prj.weight | [5395, 512] | 2,762,240 |
-- vocab size = 5395, seq len = 33, embedding = 512; you can replace it with BERT (see the sketch below for how to produce such a listing).
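A listing like the one above can be generated for any PyTorch model (including BERT) by iterating over `named_parameters()`; a minimal sketch:

```python
# Minimal sketch: print every weight name, shape, and element count, plus the total.
# bert-base-uncased is used here only as an example model.
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
total = 0
for name, param in model.named_parameters():
    print(f"weight name: {name} | size: {list(param.shape)} | count: {param.numel()}")
    total += param.numel()
print(f"# parameters: {total}")
```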
- bert-base-uncased, 110M parameters
| Bert-base-uncased | Key | Shape | Count | Count All |
|---|---|---|---|---|
| Embedding | embeddings.word_embeddings.weight | [30522, 768] | 23,440,896 | 23,837,184 |
| | embeddings.position_embeddings.weight | [512, 768] | 393,216 | |
| | embeddings.token_type_embeddings.weight | [2, 768] | 1,536 | |
| | embeddings.LayerNorm.weight | [768] | 768 | |
| | embeddings.LayerNorm.bias | [768] | 768 | |
| Transformer * 12 | encoder.layer.0.attention.self.query.weight | [768, 768] | 589,824 | 7,087,872 * 12 = 85,054,464 |
| | encoder.layer.0.attention.self.query.bias | [768] | 768 | |
| | encoder.layer.0.attention.self.key.weight | [768, 768] | 589,824 | |
| | encoder.layer.0.attention.self.key.bias | [768] | 768 | |
| | encoder.layer.0.attention.self.value.weight | [768, 768] | 589,824 | |
| | encoder.layer.0.attention.self.value.bias | [768] | 768 | |
| | encoder.layer.0.attention.output.dense.weight | [768, 768] | 589,824 | |
| | encoder.layer.0.attention.output.dense.bias | [768] | 768 | |
| | encoder.layer.0.attention.output.LayerNorm.weight | [768] | 768 | |
| | encoder.layer.0.attention.output.LayerNorm.bias | [768] | 768 | |
| | encoder.layer.0.intermediate.dense.weight | [3072, 768] | 2,359,296 | |
| | encoder.layer.0.intermediate.dense.bias | [3072] | 3,072 | |
| | encoder.layer.0.output.dense.weight | [768, 3072] | 2,359,296 | |
| | encoder.layer.0.output.dense.bias | [768] | 768 | |
| | encoder.layer.0.output.LayerNorm.weight | [768] | 768 | |
| | encoder.layer.0.output.LayerNorm.bias | [768] | 768 | |
| Pooler | pooler.dense.weight | [768, 768] | 589,824 | 590,592 |
| | pooler.dense.bias | [768] | 768 | |
| Total | | | | 109,482,240 |
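The 109,482,240 total can also be reproduced by hand from the rows above: embeddings + 12 identical encoder layers + pooler. A small sanity check of that arithmetic (the variable names are just shorthand for the sizes in the table):

```python
# Recompute the bert-base-uncased total from the table above.
# V = vocab size, P = max positions, H = hidden size, FF = intermediate (FFN) size, L = layers.
V, P, H, FF, L = 30522, 512, 768, 3072, 12

embeddings = V * H + P * H + 2 * H + 2 * H   # word + position + token_type + LayerNorm (weight + bias)
per_layer = (
    3 * (H * H + H)      # query, key, value projections (weight + bias)
    + (H * H + H)        # attention output dense
    + 2 * H              # attention output LayerNorm
    + (FF * H + FF)      # intermediate (FFN up-projection)
    + (H * FF + H)       # output (FFN down-projection)
    + 2 * H              # output LayerNorm
)
pooler = H * H + H

print(embeddings, per_layer, embeddings + L * per_layer + pooler)
# 23837184 7087872 109482240
```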
- bert-large-uncased, 340M parameters
| Bert-large-uncased | Key | Shape | Count | Count All |
|---|---|---|---|---|
| Embedding | embeddings.word_embeddings.weight | [30522, 1024] | 31,254,528 | 31,782,912 |
| | embeddings.position_embeddings.weight | [512, 1024] | 524,288 | |
| | embeddings.token_type_embeddings.weight | [2, 1024] | 2,048 | |
| | embeddings.LayerNorm.weight | [1024] | 1,024 | |
| | embeddings.LayerNorm.bias | [1024] | 1,024 | |
| Transformer * 24 | encoder.layer.0.attention.self.query.weight | [1024, 1024] | 1,048,576 | 12,596,224 * 24 = 302,309,376 |
| | encoder.layer.0.attention.self.query.bias | [1024] | 1,024 | |
| | encoder.layer.0.attention.self.key.weight | [1024, 1024] | 1,048,576 | |
| | encoder.layer.0.attention.self.key.bias | [1024] | 1,024 | |
| | encoder.layer.0.attention.self.value.weight | [1024, 1024] | 1,048,576 | |
| | encoder.layer.0.attention.self.value.bias | [1024] | 1,024 | |
| | encoder.layer.0.attention.output.dense.weight | [1024, 1024] | 1,048,576 | |
| | encoder.layer.0.attention.output.dense.bias | [1024] | 1,024 | |
| | encoder.layer.0.attention.output.LayerNorm.weight | [1024] | 1,024 | |
| | encoder.layer.0.attention.output.LayerNorm.bias | [1024] | 1,024 | |
| | encoder.layer.0.intermediate.dense.weight | [4096, 1024] | 4,194,304 | |
| | encoder.layer.0.intermediate.dense.bias | [4096] | 4,096 | |
| | encoder.layer.0.output.dense.weight | [1024, 4096] | 4,194,304 | |
| | encoder.layer.0.output.dense.bias | [1024] | 1,024 | |
| | encoder.layer.0.output.LayerNorm.weight | [1024] | 1,024 | |
| | encoder.layer.0.output.LayerNorm.bias | [1024] | 1,024 | |
| Pooler | pooler.dense.weight | [1024, 1024] | 1,048,576 | 1,049,600 |
| | pooler.dense.bias | [1024] | 1,024 | |
| Total | | | | 335,141,888 |
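The same arithmetic generalizes to a closed-form function of the layer count L and hidden size H (assuming intermediate size 4H, 512 positions, 2 token types, and the 30,522-token WordPiece vocabulary), which reproduces both totals:

```python
# Closed-form count for a BERT encoder (BertModel: embeddings + L layers + pooler, no MLM/NSP heads),
# assuming intermediate size 4*H, 512 positions, 2 token types.
def bert_params(L, H, V=30522, P=512):
    embeddings = V * H + P * H + 2 * H + 2 * H
    per_layer = 4 * (H * H + H) + 2 * H + (4 * H * H + 4 * H) + (4 * H * H + H) + 2 * H
    pooler = H * H + H
    return embeddings + L * per_layer + pooler

print(bert_params(L=12, H=768))   # 109482240 -> ~110M (bert-base)
print(bert_params(L=24, H=1024))  # 335141888 -> ~340M (bert-large)
```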
So does the attention head number get included?
I think the number of attention heads is chosen such that H / A = 64 for all models, where H is the hidden size and A is the number of attention heads (768 / 12 for base, 1024 / 16 for large).
Thanks @liuqiangict. So are the query, key, and value weights shared across all the attention heads of the same layer?
> Thanks @liuqiangict. So are the query, key, and value weights shared across all the attention heads of the same layer?
They are different. If they were shared, the weight size could be reduced by a factor of the number of heads.
> So does the attention head number get included?
Yes, it does. For each head, the attention layer projects the input (of size [768]) down to a smaller size ([64]). There are 12 heads in the attention layer, and 64 * 12 = 768. The implementation does not keep 12 separate heads explicitly; instead, the 12 heads are packed together into one linear layer (768 * 768). In code, the two formulations are the same.
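To make the fused-versus-per-head equivalence concrete, here is a shape-only sketch (the tensor names are illustrative, not taken from the actual BERT source): a single 768 x 768 linear layer produces all 12 query heads at once, and a reshape splits the result into 12 slices of 64.

```python
# Shape-only sketch: one fused 768x768 query projection serves all 12 heads.
import torch

batch, seq_len, hidden, num_heads = 2, 16, 768, 12
head_dim = hidden // num_heads  # 64

x = torch.randn(batch, seq_len, hidden)
w_q = torch.nn.Linear(hidden, hidden)  # one [768, 768] weight, not 12 separate [768, 64] weights

q = w_q(x)                                                       # [2, 16, 768]
q = q.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)  # [2, 12, 16, 64]: one 64-dim query per head
print(q.shape)
```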