一、什么是知识蒸馏

在传统监督学习中,模型学习的是ground-truth,而在知识蒸馏中,学生模型不仅学习ground-truth,还学习 教师模型的输出分布(soft targets),从而获得更丰富的上下文、模糊边界和类间关系信息。

最常见的方式,让学生学习教师的输出概率分布:

$$
\mathcal{L}_{\text{KD}} = T^2 \cdot \text{KL} \left( \text{softmax}\left(\frac{z_t}{T}\right) ,||, \text{softmax}\left(\frac{z_s}{T}\right) \right)
$$

  • $z_t$: 教师模型 logits,$z_s$: 学生模型 logits
  • $T$: 温度(temperature),用于平滑 logits

结合原始 Cross-Entropy loss:

$$
\mathcal{L}{\text{Total}} = \alpha \cdot \mathcal{L}{\text{CE}} + (1 - \alpha) \cdot \mathcal{L}_{\text{KD}}
$$

相应的蒸馏损失代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
# student_logits: Tensor of shape (batch_size, seq_len, vocab_size)
# teacher_logits: Tensor of shape (batch_size, seq_len, vocab_size)

# --- 1. 计算教师概率分布(soft targets) ---
with torch.no_grad():
# teacher_probs: (batch_size, seq_len, vocab_size),每个位置的概率分布,softmax 后每行加和为1
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).detach()

# --- 2. 计算学生的 log softmax ---
# student_log_probs: (batch_size, seq_len, vocab_size),log 概率,用于 KL 散度计算
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)

# --- 3. 计算 KL 散度 ---
# KL(P_teacher || P_student):蒸馏损失,衡量教师分布和学生分布的差距
# 输出是标量(如果 reduction='batchmean'),否则依 reduction 而变
kl = F.kl_div(
student_log_probs, # Input log-probabilities from student
teacher_probs, # Target probabilities from teacher
reduction=reduction # batchmean:KL 总和除以 batch_size
)

# --- 4. 温度缩放补偿 ---
# 因为 logits 除以了 temperature,所以梯度变小了,需要乘 T^2 来保持梯度 scale 一致
return (temperature ** 2) * kl

二、知识蒸馏的训练过程

原始的有监督微调过程只计算了模型预测输出与真实标签之间的损失,引入知识蒸馏后,还需要额外计算(学生)模型预测输出与教师模型输出之间的损失,然后把这两部分加权求和即可。

在MiniMind中,教师模型和学生模型的配置如下(由于并没有非常强大的MiniMind系列模型作为教师模型,这里仅用于演示知识蒸馏的过程):

1
2
lm_config_student = MiniMindConfig(hidden_size=512, num_hidden_layers=8)
lm_config_teacher = MiniMindConfig(hidden_size=768, num_hidden_layers=16)

基于知识蒸馏的训练代码如下,这里仍然复用SFT的数据加载器:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
start_time = time.time()

# 设置教师模型为 eval 模式,不参与梯度计算
if teacher_model is not None:
teacher_model.eval()
teacher_model.requires_grad_(False)

# 遍历一个 epoch 中所有的 batch
for step, (X, Y, loss_mask) in enumerate(train_loader):
# X: [batch_size, seq_len],模型输入
# Y: [batch_size, seq_len],ground truth label(一般是右移的 X)
# loss_mask: [batch_size, seq_len],标记有效位置

X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)

# 获取当前 step 的学习率(可选 warmup/cosine)
lr = get_lr(epoch * iter_per_epoch + step,
args.epochs * iter_per_epoch,
args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr

# === 学生模型前向传播 ===
with ctx: # 支持 mixed precision 或 no_grad
res = model(X) # forward pass
student_logits = res.logits
# student_logits: [batch_size, seq_len, vocab_size]

# === 教师模型前向传播 ===
if teacher_model is not None:
with torch.no_grad():
teacher_logits = teacher_model(X).logits
# teacher_logits: [batch_size, seq_len, vocab_size_teacher]

vocab_size_student = student_logits.size(-1)
teacher_logits = teacher_logits[..., :vocab_size_student]
# 裁剪到 vocab_size,避免词表不一致

# === Cross Entropy Loss(标准监督训练)===
loss_mask_flat = loss_mask.view(-1)
# loss_mask_flat: [batch_size * seq_len]

ce_loss = F.cross_entropy(
student_logits.view(-1, student_logits.size(-1)), # [batch_size * seq_len, vocab_size]
Y.view(-1), # [batch_size * seq_len]
ignore_index=0, # 忽略 padding token
reduction='none'
) # ce_loss: [batch_size * seq_len],逐 token loss

ce_loss = torch.sum(ce_loss * loss_mask_flat) / loss_mask_flat.sum()
# 仅统计有效 token 的平均 loss

# 若模型为 MoE(Mixture of Experts),加上稀疏门控损失
if lm_config_student.use_moe:
ce_loss += res.aux_loss

# === Distillation Loss(知识蒸馏)===
if teacher_model is not None:
# 筛选有效位置再计算 KL 蒸馏 loss
student_logits_flat = student_logits.view(-1, student_logits.size(-1)) # [batch_size * seq_len, vocab_size]
teacher_logits_flat = teacher_logits.view(-1, teacher_logits.size(-1)) # 同上

distill_loss = distillation_loss_fn(
student_logits_flat[loss_mask_flat == 1], # [num_valid_tokens, vocab_size]
teacher_logits_flat[loss_mask_flat == 1], # [num_valid_tokens, vocab_size]
temperature=temperature
) # 一般为 soft cross-entropy 或 KLDivLoss
else:
distill_loss = torch.tensor(0.0, device=args.device)

# === 总损失(加权融合)===
# alpha 越大,越靠 ground-truth;越小,越依赖 teacher
loss = (alpha * ce_loss + (1 - alpha) * distill_loss) / args.accumulation_steps
# 注意:这里支持梯度累积(accumulation_steps > 1)

# 反向传播(支持 AMP 自动混合精度)
scaler.scale(loss).backward()

# 每 accumulation_steps 次才执行一次优化器更新
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer) # 解除 AMP 放大
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) # 梯度裁剪
scaler.step(optimizer) # 优化器更新
scaler.update() # AMP 更新 scaler
optimizer.zero_grad(set_to_none=True) # 清空梯度

三、推理(Reasoning)模型的知识蒸馏

上面的知识蒸馏方式是一种白盒蒸馏,即:学生模型学习的是教师模型输出的logits信息。

还有一种黑盒蒸馏的方式,在这种情况下,学生模型只能学习教师模型的最终输出,而不能洞察到logits信息。

这里将使用黑盒蒸馏的方式,实现推理模型的蒸馏。

3.1 查看推理模型的训练数据格式

1
2
3
4
5
6
7
8
import json
pretrain_dataset_path=r'D:\MyFile\github\minimind-master\minimind_dataset\r1_mix_1024.jsonl'
with open(pretrain_dataset_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f, 1):
data = json.loads(line.strip())
break

print(data)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
{
'conversations':
[
{
'role': 'user',
'content': '请用一段话描述阿里巴巴集团的企业文化。'
},

{
'role': 'assistant',
'content':
'<think>\n嗯,用户让我用一段话描述阿里巴巴集团的企业文化。首先,我需要明确阿里巴巴集团的平均文化是什么。企业文化的定义是企业内部发生、指导管理方向的准则和规范性的精神或规范。\n\n阿里巴巴集团在2000年代中期开始发展,之前也有类似的概念,比如微软的人工智能理念。所以,我可以从这个发展中来寻找关键词和结构。\n\n接下来,我要考虑用户可能的意图。用户可能想了解阿里巴巴的企业文化,或者需要用于某个报告、宣传材料或者内部交流中。因此,我需要确保描述既全面又有深度,涵盖战略、创新、效率等各个方面。\n\n我还需要用简洁的语言来表达,避免过于复杂的词汇,让段落易于理解。同时,要注意逻辑连贯,段落之间要有过渡,让读者能够顺畅地跟随观点。\n\n可能还需要引用一些关键点,比如战略协作、用户体验、创新、效率、数字化转型、客户支持、全球化等,这些都是阿里巴巴成功的关键点,能够帮助用户全面了解企业文化。\n\n最后,检查语法和结构是否正确,确保专业性和准确性,这样用户可以直接使用这段文字,或者稍作修改,就能得到满意的效果。\n</think>\n<answer>\n阿里巴巴集团的企业文化以战略协作为核心,倡导员工之间的资源整合与创新思维的碰撞,旨在实现企业价值的提升和持续发展。阿里巴巴将其视为一个协同共生的整体,致力于通过技术创新、用户体验优化和数字化转型,致力于为客户创造最大价值。企业内外部力量协同作战,推动企业从单纯的业务拓展延伸至价值创造和社会服务,并在全球范围内进行全方位的数字化转型,以满足多样化、个性化、高端化的客户需求。阿里巴巴集团ix platform的建立旨在帮助员工实现高效协作,激发创新精神,吸引更多优秀人才加入,共同推动企业不断向前发展。\n</answer>'
}
]
}

可以看到,推理模型的数据格式和SFT一致,只不过在assistant的回答中多了如下特殊token:

1
2
3
4
<think>
</think>
<answer>
</answer>

<think></think>之间是模型的思考(推理)过程,<answer></answer>之间是模型的回答。在之前介绍的SFT中,assistant的回答只包含<answer></answer>之间的内容。

3.2 推理模型的黑盒蒸馏过程

训练代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def train_epoch(epoch, wandb):
# 处理标签中的特殊占位符 token(用于思考链等任务)
start_of_think_ids = tokenizer('<think>').input_ids
end_of_think_ids = tokenizer('</think>').input_ids
start_of_answer_ids = tokenizer('<answer>').input_ids
end_of_answer_ids = tokenizer('</answer>').input_ids

loss_fct = nn.CrossEntropyLoss(reduction='none') # 不加权平均,为了后续自定义加权
start_time = time.time()

for step, (X, Y, loss_mask) in enumerate(train_loader):
# X: [batch_size, seq_len] - 输入 token 序列
# Y: [batch_size, seq_len] - 目标 token 序列(label)
# loss_mask: [batch_size, seq_len] - loss 权重掩码,1 表示参与 loss,0 表示忽略

X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)

# 更新学习率(逐步衰减)
lr = get_lr(
epoch * iter_per_epoch + step,
args.epochs * iter_per_epoch,
args.learning_rate
)
for param_group in optimizer.param_groups:
param_group['lr'] = lr

with ctx: # 混合精度上下文
res = model(X)
# res.logits: [batch_size, seq_len, vocabulary_size]
# 将 logits reshape 为 [batch_size * seq_len, vocabulary_size]
# 将 labels reshape 为 [batch_size * seq_len]
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)), # [batch_size * seq_len, vocabulary_size]
Y.view(-1) # [batch_size * seq_len]
).view(Y.size()) # reshape 回 [batch_size, seq_len],便于与 mask 对应

# 识别出特殊 token 所在位置,用于增强 loss 权重
special_token_mask = torch.isin(
Y.view(-1),
torch.tensor(
start_of_think_ids +
end_of_think_ids +
start_of_answer_ids +
end_of_answer_ids
).to(args.device)
)

# 增强特殊 token 的 loss 惩罚权重
loss_mask = loss_mask.view(-1) # [batch_size * seq_len]
loss_mask_sum = loss_mask.sum() # 总有效 token 数
loss_mask[special_token_mask] = 10 # 提升特殊 token loss 权重为 10
loss_mask = loss_mask.view(Y.size()) # reshape 回 [batch_size, seq_len]

# 应用 mask 后计算最终 loss
# loss: scalar,单个样本平均 loss(考虑 mask 权重)
loss = (loss * loss_mask).sum() / loss_mask_sum

# 加入辅助 loss(例如知识蒸馏用到的 loss)
loss += res.aux_loss

# 梯度累积
loss = loss / args.accumulation_steps

# 自动混合精度训练
scaler.scale(loss).backward()

# 每 accumulation_steps 步执行一次反向传播更新
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer) # 取消缩放用于梯度裁剪
torch.nn.utils.clip_grad_norm_(
model.parameters(),
args.grad_clip
) # 梯度裁剪,防止梯度爆炸

scaler.step(optimizer) # optimizer 更新权重
scaler.update() # 更新缩放器

optimizer.zero_grad(set_to_none=True) # 清空梯度以备下次迭代