一、什么是知识蒸馏 在传统监督学习中,模型学习的是ground-truth,而在知识蒸馏中,学生模型不仅学习ground-truth,还学习 教师模型的输出分布(soft targets) ,从而获得更丰富的上下文、模糊边界和类间关系信息。
最常见的方式,让学生学习教师的输出概率分布:
$$ \mathcal{L}_{\text{KD}} = T^2 \cdot \text{KL} \left( \text{softmax}\left(\frac{z_t}{T}\right) ,||, \text{softmax}\left(\frac{z_s}{T}\right) \right) $$
$z_t$: 教师模型 logits,$z_s$: 学生模型 logits
$T$: 温度(temperature),用于平滑 logits
结合原始 Cross-Entropy loss:
$$ \mathcal{L}{\text{Total}} = \alpha \cdot \mathcal{L} {\text{CE}} + (1 - \alpha) \cdot \mathcal{L}_{\text{KD}} $$
相应的蒸馏损失代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 def distillation_loss_fn (student_logits, teacher_logits, temperature=1.0 , reduction='batchmean' ): with torch.no_grad(): teacher_probs = F.softmax(teacher_logits / temperature, dim=-1 ).detach() student_log_probs = F.log_softmax(student_logits / temperature, dim=-1 ) kl = F.kl_div( student_log_probs, teacher_probs, reduction=reduction ) return (temperature ** 2 ) * kl
二、知识蒸馏的训练过程 原始的有监督微调过程只计算了模型预测输出与真实标签之间的损失,引入知识蒸馏后,还需要额外计算(学生)模型预测输出与教师模型输出之间的损失,然后把这两部分加权求和即可。
在MiniMind中,教师模型和学生模型的配置如下(由于并没有非常强大的MiniMind系列模型作为教师模型,这里仅用于演示知识蒸馏的过程):
1 2 lm_config_student = MiniMindConfig(hidden_size=512 , num_hidden_layers=8 ) lm_config_teacher = MiniMindConfig(hidden_size=768 , num_hidden_layers=16 )
基于知识蒸馏的训练代码如下,这里仍然复用SFT的数据加载器:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 def train_epoch (epoch, wandb, alpha=0.0 , temperature=1.0 ): start_time = time.time() if teacher_model is not None : teacher_model.eval () teacher_model.requires_grad_(False ) for step, (X, Y, loss_mask) in enumerate (train_loader): X = X.to(args.device) Y = Y.to(args.device) loss_mask = loss_mask.to(args.device) lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate) for param_group in optimizer.param_groups: param_group['lr' ] = lr with ctx: res = model(X) student_logits = res.logits if teacher_model is not None : with torch.no_grad(): teacher_logits = teacher_model(X).logits vocab_size_student = student_logits.size(-1 ) teacher_logits = teacher_logits[..., :vocab_size_student] loss_mask_flat = loss_mask.view(-1 ) ce_loss = F.cross_entropy( student_logits.view(-1 , student_logits.size(-1 )), Y.view(-1 ), ignore_index=0 , reduction='none' ) ce_loss = torch.sum (ce_loss * loss_mask_flat) / loss_mask_flat.sum () if lm_config_student.use_moe: ce_loss += res.aux_loss if teacher_model is not None : student_logits_flat = student_logits.view(-1 , student_logits.size(-1 )) teacher_logits_flat = teacher_logits.view(-1 , teacher_logits.size(-1 )) distill_loss = distillation_loss_fn( student_logits_flat[loss_mask_flat == 1 ], teacher_logits_flat[loss_mask_flat == 1 ], temperature=temperature ) else : distill_loss = torch.tensor(0.0 , device=args.device) loss = (alpha * ce_loss + (1 - alpha) * distill_loss) / args.accumulation_steps scaler.scale(loss).backward() if (step + 1 ) % args.accumulation_steps == 0 : scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True )
三、推理(Reasoning)模型的知识蒸馏 上面的知识蒸馏方式是一种白盒蒸馏 ,即:学生模型学习的是教师模型输出的logits信息。
还有一种黑盒蒸馏 的方式,在这种情况下,学生模型只能学习教师模型的最终输出,而不能洞察到logits信息。
这里将使用黑盒蒸馏的方式,实现推理模型的蒸馏。
3.1 查看推理模型的训练数据格式 1 2 3 4 5 6 7 8 import jsonpretrain_dataset_path=r'D:\MyFile\github\minimind-master\minimind_dataset\r1_mix_1024.jsonl' with open (pretrain_dataset_path, 'r' , encoding='utf-8' ) as f: for line_num, line in enumerate (f, 1 ): data = json.loads(line.strip()) break print (data)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 { 'conversations': [ { 'role': 'user', 'content': '请用一段话描述阿里巴巴集团的企业文化。' }, { 'role': 'assistant', 'content': '<think>\n嗯,用户让我用一段话描述阿里巴巴集团的企业文化。首先,我需要明确阿里巴巴集团的平均文化是什么。企业文化的定义是企业内部发生、指导管理方向的准则和规范性的精神或规范。\n\n阿里巴巴集团在2000年代中期开始发展,之前也有类似的概念,比如微软的人工智能理念。所以,我可以从这个发展中来寻找关键词和结构。\n\n接下来,我要考虑用户可能的意图。用户可能想了解阿里巴巴的企业文化,或者需要用于某个报告、宣传材料或者内部交流中。因此,我需要确保描述既全面又有深度,涵盖战略、创新、效率等各个方面。\n\n我还需要用简洁的语言来表达,避免过于复杂的词汇,让段落易于理解。同时,要注意逻辑连贯,段落之间要有过渡,让读者能够顺畅地跟随观点。\n\n可能还需要引用一些关键点,比如战略协作、用户体验、创新、效率、数字化转型、客户支持、全球化等,这些都是阿里巴巴成功的关键点,能够帮助用户全面了解企业文化。\n\n最后,检查语法和结构是否正确,确保专业性和准确性,这样用户可以直接使用这段文字,或者稍作修改,就能得到满意的效果。\n</think>\n<answer>\n阿里巴巴集团的企业文化以战略协作为核心,倡导员工之间的资源整合与创新思维的碰撞,旨在实现企业价值的提升和持续发展。阿里巴巴将其视为一个协同共生的整体,致力于通过技术创新、用户体验优化和数字化转型,致力于为客户创造最大价值。企业内外部力量协同作战,推动企业从单纯的业务拓展延伸至价值创造和社会服务,并在全球范围内进行全方位的数字化转型,以满足多样化、个性化、高端化的客户需求。阿里巴巴集团ix platform的建立旨在帮助员工实现高效协作,激发创新精神,吸引更多优秀人才加入,共同推动企业不断向前发展。\n</answer>' } ] }
可以看到,推理模型的数据格式和SFT一致,只不过在assistant的回答中多了如下特殊token:
1 2 3 4 <think> </think> <answer> </answer>
<think>和</think>之间是模型的思考(推理)过程,<answer>和</answer>之间是模型的回答。在之前介绍的SFT中,assistant的回答只包含<answer>和</answer>之间的内容。
3.2 推理模型的黑盒蒸馏过程 训练代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 def train_epoch (epoch, wandb ): start_of_think_ids = tokenizer('<think>' ).input_ids end_of_think_ids = tokenizer('</think>' ).input_ids start_of_answer_ids = tokenizer('<answer>' ).input_ids end_of_answer_ids = tokenizer('</answer>' ).input_ids loss_fct = nn.CrossEntropyLoss(reduction='none' ) start_time = time.time() for step, (X, Y, loss_mask) in enumerate (train_loader): X = X.to(args.device) Y = Y.to(args.device) loss_mask = loss_mask.to(args.device) lr = get_lr( epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate ) for param_group in optimizer.param_groups: param_group['lr' ] = lr with ctx: res = model(X) loss = loss_fct( res.logits.view(-1 , res.logits.size(-1 )), Y.view(-1 ) ).view(Y.size()) special_token_mask = torch.isin( Y.view(-1 ), torch.tensor( start_of_think_ids + end_of_think_ids + start_of_answer_ids + end_of_answer_ids ).to(args.device) ) loss_mask = loss_mask.view(-1 ) loss_mask_sum = loss_mask.sum () loss_mask[special_token_mask] = 10 loss_mask = loss_mask.view(Y.size()) loss = (loss * loss_mask).sum () / loss_mask_sum loss += res.aux_loss loss = loss / args.accumulation_steps scaler.scale(loss).backward() if (step + 1 ) % args.accumulation_steps == 0 : scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_( model.parameters(), args.grad_clip ) scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True )