[大模型面试]手推 Bert 模型的参数量
2024-11-27

Bert 模型主要用的倒是左侧编码器部分
这几个关键参数需要牢记
让我们分层计算参数:
Embedding 层参数:
- Token Embeddings: V × H = 30,522 × 768
- Position Embeddings: 512 × 768 (最大序列长度为512 [Max_position_length])
- Segment Embeddings: 2 × 768 (句子A/B嵌入)
- Layer Normalization: 2 × 768 (gamma和beta参数)
每个Transformer层的参数:
Self-Attention:
- Query/Key/Value matrices: 3 × (768 × 768)
- Output projection: 768 × 768
- Attention层的偏置项: 4 × 768
Feed Forward Network:
这里包含两个全连接层,首先由768维升到3072 维,然后再由3072 维度降到 768 维「4*d_model 是 transformers 的标准中间层标准」
Input (768维) → FFN1 (3072维) → GELU → FFN2 (768维)
在高维空间中:
[数据点之间的距离更容易区分]
[特征之间的关系更容易被线性分离]
[非线性变换(如GELU)在高维空间中更有效]
First layer: 768 × 3072 + 3072 (权重+偏置)
Second layer: 3072 × 768 + 768 (权重+偏置)
Layer Normalization (2个):
每个Layer Norm: 2 × 768 (gamma和beta参数)
其实这里的 FFN 操作也可以类比人的思考过程:
1 先展开思考(768→3072):考虑所有可能性
2 深入分析(GELU):在更大的空间中处理信息
3 总结归纳(3072→768):提取最重要的信息
可以通过代码进行计算
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
| def calculate_bert_base_params():
# Model dimensions
hidden_size = 768
num_layers = 12
vocab_size = 30522
max_position = 512
# Embedding parameters
embedding_params = (
vocab_size * hidden_size + # Token embeddings
max_position * hidden_size + # Position embeddings
2 * hidden_size + # Segment embeddings (Segment A/ Segment B)->用于标志句子的先后顺序,用于 NSP 任务
2 * hidden_size # Layer norm parameters
)
# Parameters for one transformer layer
attention_params = (
4 * hidden_size * hidden_size + # QKV matrices + output projection
4 * hidden_size # Biases
)
ffn_params = (
hidden_size * (4 * hidden_size) + # First FFN layer (hidden_size -> 4*hidden_size)
4 * hidden_size + # First layer bias
(4 * hidden_size) * hidden_size + # Second FFN layer (4*hidden_size -> hidden_size)
hidden_size # Second layer bias
)
layer_norm_params = 4 * hidden_size # 2 layer norms per transformer layer
# Parameters per layer
params_per_layer = attention_params + ffn_params + layer_norm_params
# Total transformer layers parameters
transformer_params = params_per_layer * num_layers
# Total parameters
total_params = embedding_params + transformer_params
return total_params
total = calculate_bert_base_params()
print(f"BERT-base total parameters: {total:,}")
|