一、查看有监督微调数据集格式

1
2
3
4
5
6
7
8
9
import json
pretrain_dataset_path=r'D:\MyFile\github\minimind-master\minimind_dataset\sft_mini_512.jsonl'
with open(pretrain_dataset_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f, 1):
data = json.loads(line.strip())
break

print(data.keys()) # dict_keys(['text'])
print(data)
1
2
3
4
5
6
7
8
9
10
11
{
'conversations':
[
{
'role': 'user', 'content': '请告诉我在中国古代的“四大发明”是什么?'
},
{
'role': 'assistant', 'content': '中国古代的“四大发明”是指造纸术、印刷术、火药和指南针。这四项发明对世界文明的发展产生了深远的影响:\n\n1. **造纸术**:据史书记载,东汉时期的蔡伦改进了造纸工艺,使得纸张的生产更加便捷、成本更低,质量也更加优良。这一发明极大地促进了文化的传播和保存。\n\n2. **印刷术**:中国古代的印刷术最早可以追溯到唐代的雕版印刷,到了宋代发展出了活字印刷技术。印刷术的发明极大地促进了知识的传播,降低了书籍的成本,对教育和文化的发展起到了重要的推动作用。\n\n3. **火药**:火药最初是在唐代被发现的,最初可能用于医疗或炼金术。到了宋代,人们开始将火药用于军事目的,发明了各种火器。火药的发明改变了战争的面貌,对世界军事技术的发展产生了重大影响。\n\n4. **指南针**:指南针最初在中国被用于风水测量,后来逐渐发展成为航海导航的重要工具。这一发明极大地促进了海上航行技术的发展,对于新航路的开辟和世界地理大发现起到了关键作用。\n\n这四项发明不仅在中国历史上占有重要地位,而且对全世界的科技进步和文明发展都产生了深远的影响。'
}
]
}

二、准备有监督微调数据加载器

构建符合PyTorch的Dataloader的Dataset类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import json
import torch
from torch.utils.data import Dataset

class SFTDataset(Dataset):
def __init__(self, jsonl_path, tokenizer, max_length=1024):
super().__init__()
self.tokenizer = tokenizer # 分词器
self.max_length = max_length # 最大输入长度(会进行截断或填充)
self.samples = self.load_data(jsonl_path) # 加载数据样本
self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids# [1, 1078, 538, 501], [1]是<|im_start|>这个特殊token的id,[1078, 538, 501]是assistant的分词id
self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids# [2]

def __len__(self):
return len(self.samples) # 返回样本数量

def load_data(self, path):
"""从 jsonl 文件加载对话数据"""
samples = []
with open(path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f, 1):
data = json.loads(line.strip()) # 每行为一个 JSON 对象
samples.append(data)
return samples

def _create_chat_prompt(self, conversations):
"""
将对话轮构造成符合 ChatML 格式的字符串:
每一轮用户/助手对话被标注为 'user' / 'assistant'
最终用 tokenizer 的 apply_chat_template 统一构造 prompt。
"""
messages = []
for i, turn in enumerate(conversations):
role = 'user' if i % 2 == 0 else 'assistant' # 偶数轮为用户,奇数轮为助手
messages.append({"role": role, "content": turn['content']})

# 返回字符串形式的 prompt,而非直接 tokenize
return self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False
)

def _generate_loss_mask(self, input_ids):
"""
构建损失掩码,只有 assistant 的回答部分才参与 loss 计算。
找出每一段 assistant 的响应,在其 <|im_start|>assistant 和 <|im_end|> 之间设置 loss_mask 为 1。
"""
loss_mask = [0] * len(input_ids)
i = 0
while i < len(input_ids):
# 找 assistant 开头标志
if input_ids[i:i + len(self.bos_id)] == self.bos_id:
start = i + len(self.bos_id) # 答案起点
end = start
while end < len(input_ids):
# 查找 assistant 的回答终止符 <|im_end|>
if input_ids[end:end + len(self.eos_id)] == self.eos_id:
break
end += 1
# 为 assistant 回答部分(从 start + 1 到 end 之间)设置 loss mask
for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
loss_mask[j] = 1
# 跳过到下一个 segment
i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
else:
i += 1
return loss_mask

def __getitem__(self, index):
sample = self.samples[index]

# 构建 ChatML 格式 prompt(字符串)
prompt = self._create_chat_prompt(sample['conversations'])

# 分词并截断,确保长度 <= max_length
input_ids = self.tokenizer(prompt).input_ids[:self.max_length]

# 右侧填充 pad_token 直到 max_length 长度
input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))

# 生成动态 loss mask,仅对 assistant 响应位置计算 loss
loss_mask = self._generate_loss_mask(input_ids)

# 构建训练样本:
# 模型输入为前 n-1 个 token,预测目标为第 2 到第 n 个 token
X = torch.tensor(input_ids[:-1], dtype=torch.long) # 输入序列
Y = torch.tensor(input_ids[1:], dtype=torch.long) # 目标标签(shifted)
loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long) # 对齐 Y 的位置(从第一个预测 token 开始)

return X, Y, loss_mask

沿着__getitem__方法,逐行向下解析。

2.1 sample = self.samples[index]

sample = self.samples[index]用于获取self.samples中对应index的一条数据,这是从原始.jsonl数据集中读取的,如上所述,它只有一个key叫做conversations,取出其value,示例如下:

1
2
3
4
5
6
7
8
9
[
{
'role': 'user', 'content': '请告诉我在中国古代的“四大发明”是什么?'
},

{
'role': 'assistant', 'content': '中国古代的“四大发明”是指造纸术、印刷术、火药和指南针。这四项发明对世界文明的发展产生了深远的影响:\n\n1. **造纸术**:据史书记载,东汉时期的蔡伦改进了造纸工艺,使得纸张的生产更加便捷、成本更低,质量也更加优良。这一发明极大地促进了文化的传播和保存。\n\n2. **印刷术**:中国古代的印刷术最早可以追溯到唐代的雕版印刷,到了宋代发展出了活字印刷技术。印刷术的发明极大地促进了知识的传播,降低了书籍的成本,对教育和文化的发展起到了重要的推动作用。\n\n3. **火药**:火药最初是在唐代被发现的,最初可能用于医疗或炼金术。到了宋代,人们开始将火药用于军事目的,发明了各种火器。火药的发明改变了战争的面貌,对世界军事技术的发展产生了重大影响。\n\n4. **指南针**:指南针最初在中国被用于风水测量,后来逐渐发展成为航海导航的重要工具。这一发明极大地促进了海上航行技术的发展,对于新航路的开辟和世界地理大发现起到了关键作用。\n\n这四项发明不仅在中国历史上占有重要地位,而且对全世界的科技进步和文明发展都产生了深远的影响。'
}
]

2.2 prompt = self._create_chat_prompt(sample['conversations'])

self._create_chat_prompt(sample['conversations'])将上述sample应用一种称之为ChatML的模板,它是一种专门为多轮对话任务设计的Prompt模板格式,用于格式化输入,模板如下:

1
2
3
4
{% for message in messages %}
<|im_start|>{{ message['role'] }}
{{ message['content'] }}<|im_end|>
{% endfor %}

上述代码使用了self.tokenizer.apply_chat_template方法来应用ChatML模板,其中tokenize=False表示只返回字符串形式的prompt,不进行分词。add_generation_prompt=False表示是否在最后自动添加<|im_start|>assistant这样的生成提示,用于推理阶段.如果是训练数据(已经包括答案),一般设为 False。

应用ChatML模板后得到的prompt为:

1
'<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n请告诉我在中国古代的“四大发明”是什么?<|im_end|>\n<|im_start|>assistant\n中国古代的“四大发明”是指造纸术、印刷术、火药和指南针。这四项发明对世界文明的发展产生了深远的影响:\n\n1. **造纸术**:据史书记载,东汉时期的蔡伦改进了造纸工艺,使得纸张的生产更加便捷、成本更低,质量也更加优良。这一发明极大地促进了文化的传播和保存。\n\n2. **印刷术**:中国古代的印刷术最早可以追溯到唐代的雕版印刷,到了宋代发展出了活字印刷技术。印刷术的发明极大地促进了知识的传播,降低了书籍的成本,对教育和文化的发展起到了重要的推动作用。\n\n3. **火药**:火药最初是在唐代被发现的,最初可能用于医疗或炼金术。到了宋代,人们开始将火药用于军事目的,发明了各种火器。火药的发明改变了战争的面貌,对世界军事技术的发展产生了重大影响。\n\n4. **指南针**:指南针最初在中国被用于风水测量,后来逐渐发展成为航海导航的重要工具。这一发明极大地促进了海上航行技术的发展,对于新航路的开辟和世界地理大发现起到了关键作用。\n\n这四项发明不仅在中国历史上占有重要地位,而且对全世界的科技进步和文明发展都产生了深远的影响。<|im_end|>\n'

紧接着对这个prompt使用tokenizer编码成input_ids,并根据最大序列长度进行padding处理。

2.3 loss_mask = self._generate_loss_mask(input_ids)

这里仅对assistant响应位置(也就是assistant回复的内容)计算loss,因此需要找出每一段assistant的响应,在其<|im_start|>assistant<|im_end|>之间设置loss_mask为1,其余位置的loss_mask均为0。

使用_generate_loss_mask方法实现上述功能。

基本思想就是遍历整个input_ids,查找出现<|im_start|>assistant的位置start,这是模型回复开始的标志;然后继续遍历,找到第一个出现的<|im_end|>的位置end,start到end之间的计算模型的回复,loss_mask设置为1。

如果是多轮对话,就继续往后遍历,查找第二个模型预测开始的位置<|im_start|>assistant,以此类推。

最后,和预训练一样,返回X, Y以及Y对应的loss mask。

现在来构建数据加载器:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from torch.utils.data import DataLoader
from transformers import AutoTokenizer


max_length=512
data_path=r'D:\MyFile\github\minimind-master\minimind_dataset\sft_mini_512.jsonl'
tokenizer = AutoTokenizer.from_pretrained(r'D:\MyFile\github\minimind-master\model')
train_ds = SFTDataset(data_path, tokenizer, max_length)

train_loader = DataLoader(
train_ds,
batch_size=2,
pin_memory=True,
drop_last=False,
shuffle=False,
num_workers=0,
)

查看一下有监督微调的数据总量以及数据的维度信息:

1
2
3
4
print(len(train_loader)) # 607362
for item in train_loader:
print([i.shape for i in item]) # [torch.Size([2, 511]), torch.Size([2, 511]), torch.Size([2, 511])]
break

通过打印看到,有监督微调的数据总量为607362,每一条数据都包含3个PyTorch Tensor,分别是X, Y以及Y对应的padding mask(用于掩掉padding token的loss),shape都是2x511,2是batch_size,511是max_length-1,因为X和Y是正好是偏移一位的。这一点和预训练一样。

三、开始有监督微调

有监督微调代码和常规的模型预训练代码几乎没有区别,直接核心代码段粘贴过来:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
loss_fct = nn.CrossEntropyLoss(reduction='none')
for step, (X, Y, loss_mask) in enumerate(train_loader):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr

with ctx:
res = model(X)
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size())

loss = (loss * loss_mask).sum() / loss_mask.sum()
loss += res.aux_loss
loss = loss / args.accumulation_steps

scaler.scale(loss).backward()

if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

scaler.step(optimizer)
scaler.update()

optimizer.zero_grad(set_to_none=True)