一、查看有监督微调数据集格式 1 2 3 4 5 6 7 8 9 import jsonpretrain_dataset_path=r'D:\MyFile\github\minimind-master\minimind_dataset\sft_mini_512.jsonl' with open (pretrain_dataset_path, 'r' , encoding='utf-8' ) as f: for line_num, line in enumerate (f, 1 ): data = json.loads(line.strip()) break print (data.keys()) print (data)
1 2 3 4 5 6 7 8 9 10 11 { 'conversations': [ { 'role': 'user', 'content': '请告诉我在中国古代的“四大发明”是什么?' }, { 'role': 'assistant', 'content': '中国古代的“四大发明”是指造纸术、印刷术、火药和指南针。这四项发明对世界文明的发展产生了深远的影响:\n\n1. **造纸术**:据史书记载,东汉时期的蔡伦改进了造纸工艺,使得纸张的生产更加便捷、成本更低,质量也更加优良。这一发明极大地促进了文化的传播和保存。\n\n2. **印刷术**:中国古代的印刷术最早可以追溯到唐代的雕版印刷,到了宋代发展出了活字印刷技术。印刷术的发明极大地促进了知识的传播,降低了书籍的成本,对教育和文化的发展起到了重要的推动作用。\n\n3. **火药**:火药最初是在唐代被发现的,最初可能用于医疗或炼金术。到了宋代,人们开始将火药用于军事目的,发明了各种火器。火药的发明改变了战争的面貌,对世界军事技术的发展产生了重大影响。\n\n4. **指南针**:指南针最初在中国被用于风水测量,后来逐渐发展成为航海导航的重要工具。这一发明极大地促进了海上航行技术的发展,对于新航路的开辟和世界地理大发现起到了关键作用。\n\n这四项发明不仅在中国历史上占有重要地位,而且对全世界的科技进步和文明发展都产生了深远的影响。' } ] }
二、准备有监督微调数据加载器 构建符合PyTorch的Dataloader的Dataset类:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 import jsonimport torchfrom torch.utils.data import Datasetclass SFTDataset (Dataset ): def __init__ (self, jsonl_path, tokenizer, max_length=1024 ): super ().__init__() self .tokenizer = tokenizer self .max_length = max_length self .samples = self .load_data(jsonl_path) self .bos_id = tokenizer('<|im_start|>assistant' , add_special_tokens=False ).input_ids self .eos_id = tokenizer('<|im_end|>' , add_special_tokens=False ).input_ids def __len__ (self ): return len (self .samples) def load_data (self, path ): """从 jsonl 文件加载对话数据""" samples = [] with open (path, 'r' , encoding='utf-8' ) as f: for line_num, line in enumerate (f, 1 ): data = json.loads(line.strip()) samples.append(data) return samples def _create_chat_prompt (self, conversations ): """ 将对话轮构造成符合 ChatML 格式的字符串: 每一轮用户/助手对话被标注为 'user' / 'assistant' 最终用 tokenizer 的 apply_chat_template 统一构造 prompt。 """ messages = [] for i, turn in enumerate (conversations): role = 'user' if i % 2 == 0 else 'assistant' messages.append({"role" : role, "content" : turn['content' ]}) return self .tokenizer.apply_chat_template( messages, tokenize=False , add_generation_prompt=False ) def _generate_loss_mask (self, input_ids ): """ 构建损失掩码,只有 assistant 的回答部分才参与 loss 计算。 找出每一段 assistant 的响应,在其 <|im_start|>assistant 和 <|im_end|> 之间设置 loss_mask 为 1。 """ loss_mask = [0 ] * len (input_ids) i = 0 while i < len (input_ids): if input_ids[i:i + len (self .bos_id)] == self .bos_id: start = i + len (self .bos_id) end = start while end < len (input_ids): if input_ids[end:end + len (self .eos_id)] == self .eos_id: break end += 1 for j in range (start + 1 , min (end + len (self .eos_id) + 1 , self .max_length)): loss_mask[j] = 1 i = end + len (self .eos_id) if end < len (input_ids) else len (input_ids) else : i += 1 return loss_mask def __getitem__ (self, index ): sample = self .samples[index] prompt = self ._create_chat_prompt(sample['conversations' ]) input_ids = self .tokenizer(prompt).input_ids[:self .max_length] input_ids += [self .tokenizer.pad_token_id] * (self .max_length - len (input_ids)) loss_mask = self ._generate_loss_mask(input_ids) X = torch.tensor(input_ids[:-1 ], dtype=torch.long) Y = torch.tensor(input_ids[1 :], dtype=torch.long) loss_mask = torch.tensor(loss_mask[1 :], dtype=torch.long) return X, Y, loss_mask
沿着__getitem__方法,逐行向下解析。
2.1 sample = self.samples[index] sample = self.samples[index]用于获取self.samples中对应index的一条数据,这是从原始.jsonl数据集中读取的,如上所述,它只有一个key叫做conversations,取出其value,示例如下:
1 2 3 4 5 6 7 8 9 [ { 'role': 'user', 'content': '请告诉我在中国古代的“四大发明”是什么?' }, { 'role': 'assistant', 'content': '中国古代的“四大发明”是指造纸术、印刷术、火药和指南针。这四项发明对世界文明的发展产生了深远的影响:\n\n1. **造纸术**:据史书记载,东汉时期的蔡伦改进了造纸工艺,使得纸张的生产更加便捷、成本更低,质量也更加优良。这一发明极大地促进了文化的传播和保存。\n\n2. **印刷术**:中国古代的印刷术最早可以追溯到唐代的雕版印刷,到了宋代发展出了活字印刷技术。印刷术的发明极大地促进了知识的传播,降低了书籍的成本,对教育和文化的发展起到了重要的推动作用。\n\n3. **火药**:火药最初是在唐代被发现的,最初可能用于医疗或炼金术。到了宋代,人们开始将火药用于军事目的,发明了各种火器。火药的发明改变了战争的面貌,对世界军事技术的发展产生了重大影响。\n\n4. **指南针**:指南针最初在中国被用于风水测量,后来逐渐发展成为航海导航的重要工具。这一发明极大地促进了海上航行技术的发展,对于新航路的开辟和世界地理大发现起到了关键作用。\n\n这四项发明不仅在中国历史上占有重要地位,而且对全世界的科技进步和文明发展都产生了深远的影响。' } ]
2.2 prompt = self._create_chat_prompt(sample['conversations']) self._create_chat_prompt(sample['conversations'])将上述sample应用一种称之为ChatML的模板,它是一种专门为多轮对话任务设计的Prompt模板格式,用于格式化输入,模板如下:
1 2 3 4 { <|im_start|>{{ message['role' ] }} {{ message['content' ] }}<|im_end|> {
上述代码使用了self.tokenizer.apply_chat_template方法来应用ChatML模板,其中tokenize=False表示只返回字符串形式的prompt,不进行分词。add_generation_prompt=False表示是否在最后自动添加<|im_start|>assistant这样的生成提示,用于推理阶段.如果是训练数据(已经包括答案),一般设为 False。
应用ChatML模板后得到的prompt为:
1 '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n请告诉我在中国古代的“四大发明”是什么?<|im_end|>\n<|im_start|>assistant\n中国古代的“四大发明”是指造纸术、印刷术、火药和指南针。这四项发明对世界文明的发展产生了深远的影响:\n\n1. **造纸术**:据史书记载,东汉时期的蔡伦改进了造纸工艺,使得纸张的生产更加便捷、成本更低,质量也更加优良。这一发明极大地促进了文化的传播和保存。\n\n2. **印刷术**:中国古代的印刷术最早可以追溯到唐代的雕版印刷,到了宋代发展出了活字印刷技术。印刷术的发明极大地促进了知识的传播,降低了书籍的成本,对教育和文化的发展起到了重要的推动作用。\n\n3. **火药**:火药最初是在唐代被发现的,最初可能用于医疗或炼金术。到了宋代,人们开始将火药用于军事目的,发明了各种火器。火药的发明改变了战争的面貌,对世界军事技术的发展产生了重大影响。\n\n4. **指南针**:指南针最初在中国被用于风水测量,后来逐渐发展成为航海导航的重要工具。这一发明极大地促进了海上航行技术的发展,对于新航路的开辟和世界地理大发现起到了关键作用。\n\n这四项发明不仅在中国历史上占有重要地位,而且对全世界的科技进步和文明发展都产生了深远的影响。<|im_end|>\n'
紧接着对这个prompt使用tokenizer编码成input_ids,并根据最大序列长度进行padding处理。
这里仅对assistant响应位置(也就是assistant回复的内容)计算loss,因此需要找出每一段assistant的响应,在其<|im_start|>assistant和<|im_end|>之间设置loss_mask为1,其余位置的loss_mask均为0。
使用_generate_loss_mask方法实现上述功能。
基本思想就是遍历整个input_ids,查找出现<|im_start|>assistant的位置start,这是模型回复开始的标志;然后继续遍历,找到第一个出现的<|im_end|>的位置end,start到end之间的计算模型的回复,loss_mask设置为1。
如果是多轮对话,就继续往后遍历,查找第二个模型预测开始的位置<|im_start|>assistant,以此类推。
最后,和预训练一样,返回X, Y以及Y对应的loss mask。
现在来构建数据加载器:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 from torch.utils.data import DataLoaderfrom transformers import AutoTokenizermax_length=512 data_path=r'D:\MyFile\github\minimind-master\minimind_dataset\sft_mini_512.jsonl' tokenizer = AutoTokenizer.from_pretrained(r'D:\MyFile\github\minimind-master\model' ) train_ds = SFTDataset(data_path, tokenizer, max_length) train_loader = DataLoader( train_ds, batch_size=2 , pin_memory=True , drop_last=False , shuffle=False , num_workers=0 , )
查看一下有监督微调的数据总量以及数据的维度信息:
1 2 3 4 print (len (train_loader)) for item in train_loader: print ([i.shape for i in item]) break
通过打印看到,有监督微调的数据总量为607362,每一条数据都包含3个PyTorch Tensor,分别是X, Y以及Y对应的padding mask(用于掩掉padding token的loss),shape都是2x511,2是batch_size,511是max_length-1,因为X和Y是正好是偏移一位的。这一点和预训练一样。
三、开始有监督微调 有监督微调代码和常规的模型预训练代码几乎没有区别,直接核心代码段粘贴过来:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 loss_fct = nn.CrossEntropyLoss(reduction='none' ) for step, (X, Y, loss_mask) in enumerate (train_loader): X = X.to(args.device) Y = Y.to(args.device) loss_mask = loss_mask.to(args.device) lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate) for param_group in optimizer.param_groups: param_group['lr' ] = lr with ctx: res = model(X) loss = loss_fct( res.logits.view(-1 , res.logits.size(-1 )), Y.view(-1 ) ).view(Y.size()) loss = (loss * loss_mask).sum () / loss_mask.sum () loss += res.aux_loss loss = loss / args.accumulation_steps scaler.scale(loss).backward() if (step + 1 ) % args.accumulation_steps == 0 : scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True )