DPO(Direct Preference Optimization) 是一种用于有监督指令微调后模型偏好对齐的训练方法,目标是让模型更倾向于输出人类偏好的回答(chosen),而不是次优回答(rejected)。
一、查看DPO训练数据集格式 1 2 3 4 5 6 7 8 9 import jsonpretrain_dataset_path=r'D:\MyFile\github\minimind-master\minimind_dataset\dpo.jsonl' with open (pretrain_dataset_path, 'r' , encoding='utf-8' ) as f: for line_num, line in enumerate (f, 1 ): data = json.loads(line.strip()) break print (data.keys()) print (data)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 { 'chosen': [ { 'content': 'How many moles of HBr are required to react ...', 'role': 'user' }, { 'content': 'To determine the number of moles of HBr ...', 'role': 'assistant' } ], 'rejected': [ { 'content': 'How many moles of HBr are required to react ...', 'role': 'user' }, { 'content': 'To answer this question, we need to write ...', 'role': 'assistant' } ] }
用于DPO训练的数据集中,每一条是数据都包含至少两个assistant回答,一个优、一个劣,“chosen”对应优,“rejected”对应劣。
在DPO训练时,模型会学习让“chosen”回答的概率高于“rejected”回答,从而实现偏好对齐。
二、准备DPO训练数据加载器 构建符合PyTorch的Dataloader的Dataset类:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 import jsonimport torchfrom torch.utils.data import Datasetclass DPODataset (Dataset ): def __init__ (self, file_path, tokenizer, max_length=4096 ): super ().__init__() self .tokenizer = tokenizer self .max_length = max_length self .padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 self .bos_id = tokenizer('<|im_start|>assistant' , add_special_tokens=False ).input_ids self .eos_id = tokenizer('<|im_end|>' , add_special_tokens=False ).input_ids with open (file_path, 'r' , encoding='utf-8' ) as f: self .data = [] for line in f: line = line.strip() obj = json.loads(line) self .data.append(obj) def __len__ (self ): return len (self .data) def __getitem__ (self, index ): item = self .data[index] chosen = item['chosen' ] rejected = item['rejected' ] chosen_prompt = self .tokenizer.apply_chat_template( chosen, tokenize=False , add_generation_prompt=False ) rejected_prompt = self .tokenizer.apply_chat_template( rejected, tokenize=False , add_generation_prompt=False ) chosen_encoding = self .tokenizer( chosen_prompt, truncation=True , max_length=self .max_length, padding='max_length' ) rejected_encoding = self .tokenizer( rejected_prompt, truncation=True , max_length=self .max_length, padding='max_length' ) chosen_input_ids = chosen_encoding['input_ids' ] rejected_input_ids = rejected_encoding['input_ids' ] chosen_loss_mask = self ._generate_loss_mask(chosen_input_ids) rejected_loss_mask = self ._generate_loss_mask(rejected_input_ids) x_chosen = torch.tensor(chosen_input_ids[:-1 ], dtype=torch.long) y_chosen = torch.tensor(chosen_input_ids[1 :], dtype=torch.long) mask_chosen = torch.tensor(chosen_loss_mask[1 :], dtype=torch.long) x_rejected = torch.tensor(rejected_input_ids[:-1 ], dtype=torch.long) y_rejected = torch.tensor(rejected_input_ids[1 :], dtype=torch.long) mask_rejected = torch.tensor(rejected_loss_mask[1 :], dtype=torch.long) return { 'x_chosen' : x_chosen, 'y_chosen' : y_chosen, 'mask_chosen' : mask_chosen, 'x_rejected' : x_rejected, 'y_rejected' : y_rejected, 'mask_rejected' : mask_rejected } def _generate_loss_mask (self, input_ids ): """ 根据 <|im_start|>assistant 和 <|im_end|> 的位置标记哪些 token 应该参与损失计算。 返回一个和 input_ids 等长的 0/1 mask。 """ loss_mask = [0 ] * len (input_ids) i = 0 while i < len (input_ids): if input_ids[i:i + len (self .bos_id)] == self .bos_id: start = i + len (self .bos_id) end = start while end < len (input_ids): if input_ids[end:end + len (self .eos_id)] == self .eos_id: break end += 1 for j in range (start + 1 , min (end + len (self .eos_id) + 1 , self .max_length)): loss_mask[j] = 1 i = end + len (self .eos_id) if end < len (input_ids) else len (input_ids) else : i += 1 return loss_mask
DPODataset和之前的SFTDataset的处理逻辑是完全一致的,只不过DPODataset中需要处理两次(chosen和rejected),因此上述代码中包含的函数介绍可以去看SFTDataset,这里不再重复介绍。
三、DPO 损失函数 DPO的目标是让训练后模型更偏好人类认为更好的答案(chosen),而不是差的答案(rejected),并且这种偏好是在对比参考模型(refrence model)的基础上学来的。
这里的参考模型,一般指的是微调前的模型,比如做了预训练和SFT之后的模型。
参考:https://allam.vercel.app/post/dpo/
DPO旨在以一种更简单、更稳定的方式替代传统RLHF中复杂的奖励建模过程。它的核心在于:使用一个直接可微的损失函数,来优化模型对人类偏好的响应倾向,而无需训练单独的奖励模型或使用复杂的强化学习方法(如PPO)。
具体来说,DPO在一对偏好样本上进行优化:它增加人类偏好响应中token的对数概率,同时减少非偏好响应中的对数概率,从而促使模型产生更符合人类意图的输出。
从数学角度看,这一过程相当于为模型引入了一个隐式奖励函数,该函数通过log-ratio的差值衡量当前策略相对于参考策略对人类偏好的一致程度,并直接用于梯度优化。
设:
$\pi$ 是当前模型(policy model)
$\pi_\text{ref}$ 是参考模型(reference model)
$x$ 是输入 prompt
$y^+$ 是人类偏好的回答(chosen)
$y^-$ 是次优回答(rejected)
$\beta$ 是温度超参(调节梯度幅度)
DPO loss 如下:
$$ \mathcal{L}{\text{DPO}} = \mathbb{E} {(x, y^+, y^-) \sim \mathcal{D}} \left[ -\log \sigma \left( \beta \cdot \left( \log \frac{\pi(y^+|x)}{\pi_{\text{ref}}(y^+|x)} - \log \frac{\pi(y^-|x)}{\pi_{\text{ref}}(y^-|x)} \right) \right) \right] $$
其中 $\sigma$ 是 sigmoid 函数。
在上述公式的log差值项中,前一个表示模型对于人类偏好chosen回复$y^+$的对数概率,后一个表示模型对于rejected回复$y^-$的对数概率,DPO loss的目标是最大化两者的差值,也就是鼓励模型$\pi$相较于$\pi_\text{ref}$更加偏好$y^+$而非$y^-$。其中除以$\pi_\text{ref}$的作用是作为一个正则化因子,确保训练后的模型过度偏离原始模型。
在MiniMind的代码实现中,根据对数运算的性质,调换了DPO loss中的对数项顺序,如下: $$ \mathcal{L}{\text{DPO}} = \mathbb{E} {(x, y^+, y^-) \sim \mathcal{D}} \left[ -\log \sigma \left( \beta \cdot \left( \log \frac{\pi(y^+|x)}{\pi(y^-|x)} - \log \frac{\pi_{\text{ref}}(y^+|x)}{\pi_{\text{ref}}(y^-|x)} \right) \right) \right] $$
代码实现上述DPO loss:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 def dpo_loss (ref_probs, probs, mask, beta ): seq_lengths = mask.sum (dim=1 , keepdim=True ) ref_probs = (ref_probs * mask).sum (dim=1 ) / seq_lengths.squeeze(1 ) probs = (probs * mask).sum (dim=1 ) / seq_lengths.squeeze(1 ) batch_size = ref_probs.shape[0 ] chosen_ref_probs = ref_probs[:batch_size // 2 ] reject_ref_probs = ref_probs[batch_size // 2 :] chosen_probs = probs[:batch_size // 2 ] reject_probs = probs[batch_size // 2 :] pi_logratios = chosen_probs - reject_probs ref_logratios = chosen_ref_probs - reject_ref_probs logits = pi_logratios - ref_logratios loss = -F.logsigmoid(beta * logits) return loss.mean()
在Step 3中,之所以取batch的前后一半分别作为chosen和rejected,是因为在MiniMind的train函数中,为了并行执行训练,对chosen和rejected做了拼接(在数据加载器中做了padding,因此可以拼接),相应的代码在下一节展示。
四、开始训练DPO 训练DPO的代码在SFT训练代码的基础上,将交叉熵损失换成了DPO loss,如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 for step, batch in enumerate (train_loader): x_chosen = batch['x_chosen' ].to(args.device) x_rejected = batch['x_rejected' ].to(args.device) y_chosen = batch['y_chosen' ].to(args.device) y_rejected = batch['y_rejected' ].to(args.device) mask_chosen = batch['mask_chosen' ].to(args.device) mask_rejected = batch['mask_rejected' ].to(args.device) x = torch.cat([x_chosen, x_rejected], dim=0 ) y = torch.cat([y_chosen, y_rejected], dim=0 ) mask = torch.cat([mask_chosen, mask_rejected], dim=0 ) lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate) for param_group in optimizer.param_groups: param_group['lr' ] = lr with ctx: with torch.no_grad(): ref_outputs = ref_model(x) ref_logits = ref_outputs.logits ref_probs = logits_to_probs(ref_logits, y) ref_probs = ref_probs * mask outputs = model(x) logits = outputs.logits probs = logits_to_probs(logits, y) probs = probs * mask loss = dpo_loss(ref_probs, probs, mask, beta=0.1 ) loss = loss / args.accumulation_steps scaler.scale(loss).backward() if (step + 1 ) % args.accumulation_steps == 0 : scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True )
上述代码中有一个函数logits_to_probs,可以将输入的logits(shape为[2 x batch_size, seq_len, vocab_size])转换成输出的对数概率probs(shape为[2 x batch_size, seq_len]),其定义如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 def logits_to_probs (logits, labels ): log_probs = F.log_softmax(logits, dim=2 ) probs = torch.gather(log_probs, dim=2 , index=labels.unsqueeze(2 )) probs = probs.squeeze(-1 ) return probs
输入的logits表示模型在该位置预测下一个token是词表中某个词的原始分数,shape为[batch_size, seq_len, vocab_size]。
第一步,将logits使用log_softmax转换为对数概率log_probs,即log_probs表示模型在该位置预测下一个token是词表中某个词的对数概率,shape不变。
第二步,通过torch.gather,从log_probs中查询输入的真实标签labels中每个token对应位置的log概率,shape为[batch_size, seq_len],这是每个位置上真实标签的模型预测对数概率,也就是DPO loss的输入。