1. From DINO to DINOv2

After DINO was released, iBOT extended it with patch-level distillation and other new features. DINOv2 then built further on iBOT's ideas, improving training stability and feature representation quality.

In short:

DINO → iBOT → DINOv2
iBOT = DINO + patch-level distillation and other improvements
DINOv2 = iBOT + training-stability & representation-quality tricks

DINO was covered in the previous article, so this post first introduces iBOT and then the tricks unique to DINOv2.

2. iBOT

iBOT inherits all of DINO's core ideas, including EMA updates of the teacher network's parameters and [CLS] token distillation/alignment. If you look at the official GitHub code, you will find that iBOT is a direct modification of the DINO codebase. On top of that, iBOT adds a BERT-style MLM training objective. The following sections walk through these components one by one.

2.1 The same multi-view augmentation strategy as DINO

iBOT's data augmentation logic is identical to DINO's. Here is the code:

class DataAugmentationiBOT(object):
    def __init__(self, global_crops_scale, local_crops_scale, global_crops_number, local_crops_number):
        # Base augmentations: flip + color jitter + grayscale
        flip_and_color_jitter = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),  # horizontal flip
            transforms.RandomApply(  # randomly apply color jitter
                [transforms.ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),  # random grayscale conversion
        ])

        # Normalization
        normalize = transforms.Compose([
            transforms.ToTensor(),  # [H, W, C] (0~255) → [C, H, W] (0~1)
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225)),  # ImageNet statistics
        ])

        self.global_crops_number = global_crops_number

        # First global crop (strong augmentation)
        self.global_transfo1 = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
            # output [3, 224, 224]
            flip_and_color_jitter,
            utils.GaussianBlur(1.0),  # always apply Gaussian blur
            normalize,
        ])

        # Remaining global crops (weaker augmentation + solarization)
        self.global_transfo2 = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
            flip_and_color_jitter,
            utils.GaussianBlur(0.1),  # low-probability blur
            utils.Solarization(0.2),  # random solarization
            normalize,
        ])

        # Local crops (small patches, 96×96)
        self.local_crops_number = local_crops_number
        self.local_transfo = transforms.Compose([
            transforms.RandomResizedCrop(96, scale=local_crops_scale, interpolation=Image.BICUBIC),
            # output [3, 96, 96]
            flip_and_color_jitter,
            utils.GaussianBlur(p=0.5),
            normalize,
        ])

    def __call__(self, image):
        """
        Input:
            image: original PIL Image (H, W, 3)
        Output:
            crops: a list of global_crops_number + local_crops_number augmented images
                - first global_crops_number images: 224x224 global views
                - last local_crops_number images: 96x96 local views
            each image has shape [3, H, W] (Tensor)
        """
        crops = []
        # first global crop
        crops.append(self.global_transfo1(image))
        # remaining global crops
        for _ in range(self.global_crops_number - 1):
            crops.append(self.global_transfo2(image))
        # local crops
        for _ in range(self.local_crops_number):
            crops.append(self.local_transfo(image))
        return crops

2.2 BERT-style patch token distillation on top of DINO's CLS token distillation

In DINO, the distillation target is the [CLS] token, i.e. the global representation of the whole image: the goal is for the student network's global feature to match the teacher's.
iBOT goes one step further: it distills not only the [CLS] token but also the patch tokens (local representations).

The approach is similar to BERT's Masked Language Modeling (MLM):

A random subset of patches in the student's input image is masked, while the teacher sees the complete image. The student must predict the embeddings of the masked patches, matching the teacher's patch embeddings at those positions as closely as possible.

This is a prediction of high-level semantic vectors, unlike MAE, which directly reconstructs raw pixel values.
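The masking step described above can be sketched numerically. This is a minimal illustration with made-up shapes and a ~30% mask ratio (the actual ratio and blockwise mask sampling differ in iBOT): masked positions of the student's patch embeddings are replaced by a shared, learnable mask token before entering the backbone.

```python
import numpy as np

rng = np.random.default_rng(0)
B, P, D = 2, 16, 8                       # batch, patches per image, embedding dim
patch_emb = rng.normal(size=(B, P, D))   # patch embeddings after the patch-embed layer
mask_token = rng.normal(size=(D,))       # learnable [MASK] embedding (random stand-in here)

# Random mask: True = this patch is hidden from the student
mask = rng.random((B, P)) < 0.3          # ~30% mask ratio (illustrative)

# Replace masked positions with the shared mask token
student_in = np.where(mask[..., None], mask_token, patch_emb)
assert student_in.shape == (B, P, D)
```

The teacher receives `patch_emb` unmodified; the patch loss is then computed only at positions where `mask` is True.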

iBOT's total loss therefore has two parts:

  • CLS-level loss (same as DINO: cross-view distillation of global semantics)
  • Patch-level loss (BERT-style: predicting the representations of masked patches)

The total loss is a weighted sum of the two. The code:

class IBOTLoss(nn.Module):
    def __init__(self, out_dim, patch_out_dim, lambda_patch=1.0, student_temp=0.1, teacher_temp=0.04, center_momentum=0.9):
        super().__init__()
        self.out_dim = out_dim
        self.patch_out_dim = patch_out_dim
        self.lambda_patch = lambda_patch  # patch loss weight
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp
        self.center_momentum = center_momentum

        # initialize the teacher centers
        self.register_buffer("center_cls", torch.zeros(1, out_dim))          # [1, out_dim]
        self.register_buffer("center_patch", torch.zeros(1, patch_out_dim))  # [1, patch_out_dim]

    def forward(self, student_cls, teacher_cls, student_patch, teacher_patch, epoch):
        """
        student_cls:   [(2+N)*B, out_dim]           # student CLS outputs
        teacher_cls:   [2*B, out_dim]               # teacher CLS outputs
        student_patch: [(2+N)*B, P, patch_out_dim]  # student patch token outputs
        teacher_patch: [2*B, P, patch_out_dim]      # teacher patch token outputs

        out_dim is the dimension of image → backbone → fc head
        patch_out_dim is the dimension of image → backbone
        """

        # =======================
        # 1. CLS loss (global views)
        # =======================
        # center and temperature-scale the teacher CLS
        teacher_cls_centered = (teacher_cls - self.center_cls) / self.teacher_temp
        # shape: [2*B, out_dim]
        teacher_cls_soft = torch.softmax(teacher_cls_centered, dim=-1)
        # shape: [2*B, out_dim]

        # temperature-scale the student CLS
        student_cls_scaled = student_cls / self.student_temp
        # shape: [(2+N)*B, out_dim]

        # CLS cross-entropy loss:
        # all student views vs the teacher's global views
        cls_loss = -(teacher_cls_soft.detach() * F.log_softmax(student_cls_scaled, dim=-1)).sum(dim=-1).mean()
        # cls_loss: scalar

        # =======================
        # 2. Patch loss (local tokens)
        # =======================
        # randomly mask some patches (mask probability is tunable, ~15% here)
        mask = (torch.rand(student_patch.shape[0], student_patch.shape[1], device=student_patch.device) < 0.15)
        # shape: [(2+N)*B, P] bool, True = masked

        # center + temperature-scale the teacher patches too
        teacher_patch_centered = (teacher_patch - self.center_patch) / self.teacher_temp
        # shape: [2*B, P, patch_out_dim]
        teacher_patch_soft = torch.softmax(teacher_patch_centered, dim=-1)
        # shape: [2*B, P, patch_out_dim]

        # temperature-scale the student patches
        student_patch_scaled = student_patch / self.student_temp
        # shape: [(2+N)*B, P, patch_out_dim]

        # compute the loss only on the masked patches
        patch_loss = -(teacher_patch_soft.detach() * F.log_softmax(student_patch_scaled, dim=-1))
        patch_loss = (patch_loss.sum(dim=-1) * mask).sum() / mask.sum()
        # patch_loss: scalar

        # =======================
        # 3. Total loss
        # =======================
        loss = cls_loss + self.lambda_patch * patch_loss

        return loss

    @torch.no_grad()
    def update_center(self, teacher_cls, teacher_patch):
        """
        Update the teacher center vectors (EMA).
        teacher_cls:   [2*B, out_dim]
        teacher_patch: [2*B, P, patch_out_dim]
        """
        batch_center_cls = teacher_cls.mean(dim=0, keepdim=True)  # [1, out_dim]
        batch_center_patch = teacher_patch.reshape(-1, teacher_patch.shape[-1]).mean(dim=0, keepdim=True)  # [1, patch_out_dim]

        # EMA update of the centers
        self.center_cls = self.center_cls * self.center_momentum + batch_center_cls * (1 - self.center_momentum)
        self.center_patch = self.center_patch * self.center_momentum + batch_center_patch * (1 - self.center_momentum)

Note that because iBOT adds patch token prediction, it must maintain not only the teacher's CLS token center but also a separate patch token center.

The above is a simplified version for clarity; the full iBOTLoss code is as follows:

class iBOTLoss(nn.Module):
    def __init__(self, out_dim, patch_out_dim, ngcrops, nlcrops,
                 warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp2, teacher_temp2,
                 warmup_teacher_temp_epochs, nepochs,
                 student_temp=0.1, center_momentum=0.9, center_momentum2=0.9,
                 lambda1=1.0, lambda2=1.0, mim_start_epoch=0):
        super().__init__()

        # ---------------------- parameters ----------------------
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.center_momentum2 = center_momentum2
        self.ngcrops = ngcrops
        self.nlcrops = nlcrops
        self.ncrops = ngcrops + nlcrops
        self.lambda1 = lambda1
        self.lambda2 = lambda2

        # ---------------------- teacher centers ----------------------
        self.register_buffer("center", torch.zeros(1, out_dim))
        # center: [1, out_dim]            # running center of CLS tokens
        self.register_buffer("center2", torch.zeros(1, 1, patch_out_dim))
        # center2: [1, 1, patch_out_dim]  # running center of patch tokens

        # ---------------------- teacher temperature schedules ----------------------
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))
        # shape: [nepochs]

        self.teacher_temp2_schedule = (
            np.concatenate((
                np.linspace(warmup_teacher_temp2, teacher_temp2, warmup_teacher_temp_epochs),
                np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp2
            )) if mim_start_epoch == 0 else np.concatenate((
                np.ones(mim_start_epoch) * warmup_teacher_temp2,
                np.linspace(warmup_teacher_temp2, teacher_temp2, warmup_teacher_temp_epochs),
                np.ones(nepochs - warmup_teacher_temp_epochs - mim_start_epoch) * teacher_temp2
            ))
        )
        # shape: [nepochs]

    def forward(self, student_output, teacher_output, student_local_cls, student_mask, epoch):
        """
        Compute the student/teacher distillation loss.
        """

        # ---------------------- unpack student ----------------------
        student_cls, student_patch = student_output
        # student_cls:   [ng*B, out_dim]  (global crops; local-crop CLS arrives via student_local_cls)
        # student_patch: [ng*B, num_patches, patch_out_dim]

        # ---------------------- unpack teacher ----------------------
        teacher_cls, teacher_patch = teacher_output
        # teacher_cls:   [ng*B, out_dim]                     # global crops only
        # teacher_patch: [ng*B, num_patches, patch_out_dim]  # global crops only

        # ---------------------- append student_local_cls ----------------------
        if student_local_cls is not None:
            student_cls = torch.cat([student_cls, student_local_cls])
            # student_cls: [(ng+nl)*B, out_dim]

        # ---------------------- temperature scaling + chunking ----------------------
        student_cls = student_cls / self.student_temp
        student_cls_c = student_cls.chunk(self.ncrops)
        # student_cls_c: list of length ng+nl, each [B, out_dim]

        student_patch = student_patch / self.student_temp
        student_patch_c = student_patch.chunk(self.ngcrops)
        # student_patch_c: list of length ng, each [B, num_patches, patch_out_dim]

        # ---------------------- teacher softmax + chunking ----------------------
        temp = self.teacher_temp_schedule[epoch]
        temp2 = self.teacher_temp2_schedule[epoch]

        teacher_cls_c = F.softmax((teacher_cls - self.center) / temp, dim=-1)
        teacher_cls_c = teacher_cls_c.detach().chunk(self.ngcrops)
        # teacher_cls_c: list of length ng, each [B, out_dim]

        teacher_patch_c = F.softmax((teacher_patch - self.center2) / temp2, dim=-1)
        teacher_patch_c = teacher_patch_c.detach().chunk(self.ngcrops)
        # teacher_patch_c: list of length ng, each [B, num_patches, patch_out_dim]

        # ---------------------- loss accumulators ----------------------
        total_loss1, n_loss_terms1 = 0, 0  # cls loss
        total_loss2, n_loss_terms2 = 0, 0  # patch loss

        # ---------------------- per-view losses ----------------------
        for q in range(len(teacher_cls_c)):       # q ∈ [0..ng-1], teacher global crops
            for v in range(len(student_cls_c)):   # v ∈ [0..ng+nl-1], all student crops
                if v == q:  # same global view: masked patch prediction
                    # patch-level loss
                    # teacher_patch_c[q]: [B, num_patches, patch_out_dim]
                    # student_patch_c[v]: [B, num_patches, patch_out_dim]
                    # student_mask[v]:    [B, num_patches]
                    loss2 = torch.sum(
                        -teacher_patch_c[q] * F.log_softmax(student_patch_c[v], dim=-1),
                        dim=-1
                    )  # [B, num_patches]

                    mask = student_mask[v].flatten(-2, -1)  # [B, num_patches]
                    loss2 = torch.sum(loss2 * mask.float(), dim=-1) / mask.sum(dim=-1).clamp(min=1.0)
                    # [B]

                    total_loss2 += loss2.mean()  # scalar
                    n_loss_terms2 += 1
                else:  # different views: same as DINO
                    # cls-level loss
                    # teacher_cls_c[q]: [B, out_dim]
                    # student_cls_c[v]: [B, out_dim]
                    loss1 = torch.sum(
                        -teacher_cls_c[q] * F.log_softmax(student_cls_c[v], dim=-1),
                        dim=-1
                    )  # [B]

                    total_loss1 += loss1.mean()  # scalar
                    n_loss_terms1 += 1

        # ---------------------- aggregate ----------------------
        total_loss1 = total_loss1 / n_loss_terms1 * self.lambda1  # scalar
        total_loss2 = total_loss2 / n_loss_terms2 * self.lambda2  # scalar
        total_loss = dict(
            cls=total_loss1, patch=total_loss2, loss=total_loss1 + total_loss2
        )
        self.update_center(teacher_cls, teacher_patch)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_cls, teacher_patch):
        """
        Update the teacher centers (EMA).
        """

        # ---------------------- CLS center ----------------------
        cls_center = torch.sum(teacher_cls, dim=0, keepdim=True)  # [1, out_dim]
        dist.all_reduce(cls_center)
        cls_center = cls_center / (len(teacher_cls) * dist.get_world_size())  # [1, out_dim]
        self.center = self.center * self.center_momentum + cls_center * (1 - self.center_momentum)
        # self.center: [1, out_dim]

        # ---------------------- patch center ----------------------
        patch_center = torch.sum(teacher_patch.mean(1), dim=0, keepdim=True)  # [1, patch_out_dim]
        dist.all_reduce(patch_center)
        patch_center = patch_center / (len(teacher_patch) * dist.get_world_size())  # [1, patch_out_dim]
        self.center2 = self.center2 * self.center_momentum2 + patch_center * (1 - self.center_momentum2)
        # self.center2: [1, 1, patch_out_dim]

As an example, assume batch_size = 64 and N = 8 local crops; the forward pass looks like this:

# Forward the student.
# The input `images` is a list of 2 global crops + N local crops.
# E.g. with batch_size=64, N=8:
# images[0] -> [64, 3, 224, 224]  # global view 1
# images[1] -> [64, 3, 224, 224]  # global view 2
# images[2] -> [64, 3, 96, 96]    # local view 1
# ...
# images[9] -> [64, 3, 96, 96]    # local view 8
student_output, student_output_tokens = student(images)
# student_output:        [(2+N)*64, out_dim]                     # CLS token outputs (projection head)
# student_output_tokens: [(2+N)*64, num_patches, patch_out_dim]  # patch token outputs


# Forward the teacher.
# The teacher only sees the two global crops.
teacher_output, teacher_output_tokens = teacher(images[:2])
# teacher_output:        [2*64, out_dim]                         # global views only
# teacher_output_tokens: [2*64, num_patches, patch_out_dim]


# Loss
loss = ibot_loss(
    student_output,          # student CLS outputs
    teacher_output,          # teacher CLS outputs
    student_output_tokens,   # student patch outputs
    teacher_output_tokens,   # teacher patch outputs
    epoch
)
# Inside ibot_loss:
# 1) center + temperature-scale the teacher outputs (as in DINO)
# 2) CLS loss: all student views vs the teacher's two global views (cross-entropy)
# 3) patch loss: randomly mask some patches, student prediction vs teacher prediction (cross-entropy)
# final loss = CLS_loss + λ * Patch_loss


# EMA update of the teacher parameters.
# m is the momentum coefficient, typically 0.996 ~ 0.999.
with torch.no_grad():
    m = momentum_schedule[step]  # momentum depends on the training step
    for param_q, param_k in zip(student.parameters(), teacher.parameters()):
        param_k.data.mul_(m).add_(param_q.data * (1. - m))
# The teacher is an EMA of the student and receives no gradients.


# Update the teacher output centers (to prevent collapse).
ibot_loss.update_center(teacher_output, teacher_output_tokens)
# teacher_output:        [2*64, out_dim]
# teacher_output_tokens: [2*64, num_patches, patch_out_dim]
# Internally computes the batch means and EMA-updates center_cls and center_patch.

3. DINOv2's model optimization tricks

3.1 A stronger data processing pipeline

The paper builds the LVD-142M dataset. The whole pipeline relies only on images, with no text or metadata, and consists of three main steps: data sourcing, deduplication, and self-supervised retrieval.

3.1.1 Data sources

  • Curated data: ImageNet-22k, the ImageNet-1k train set, Google Landmarks, etc.
  • Uncurated data: raw images crawled from the public web; after URL filtering, deduplication, NSFW filtering, and face blurring, roughly 1.2 billion images remain.

3.1.2 Deduplication

  • Remove near-duplicates from the uncurated data to increase diversity.
  • Remove images that duplicate test/validation sets to avoid data leakage.

3.1.3 Self-supervised image retrieval

  • Compute image embeddings with a self-supervised ViT-H/16 network.
  • Run k-means clustering on the uncurated data.
  • Retrieval strategy:
    • Large query sets: retrieve the 4 nearest neighbors of each image.
    • Small query sets: sample images from the corresponding cluster.

3.1.4 Implementation details

  • Deduplication and retrieval rely on Faiss GPU-accelerated indexes.
  • On a cluster of 20 nodes × 8 GPUs, the whole pipeline takes less than 2 days.
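The retrieval step can be illustrated with a brute-force sketch (the real pipeline uses Faiss GPU indexes on hundreds of millions of embeddings; the random embeddings and sizes here are made up, and only the cosine k-nearest-neighbor idea is shown):

```python
import numpy as np

rng = np.random.default_rng(0)
curated = rng.normal(size=(5, 32))      # embeddings of curated query images (hypothetical)
uncurated = rng.normal(size=(100, 32))  # embeddings of the uncurated pool (hypothetical)

def l2n(x):
    # L2-normalize rows so the dot product becomes cosine similarity
    return x / np.linalg.norm(x, axis=1, keepdims=True)

sims = l2n(curated) @ l2n(uncurated).T  # cosine similarities, [5, 100]

k = 4  # as in the paper: 4 nearest neighbors per query image
nn_idx = np.argsort(-sims, axis=1)[:, :k]  # indices of retrieved uncurated images
assert nn_idx.shape == (5, 4)
```

The retrieved indices select which uncurated images join the pretraining set; at scale this exact search is replaced by an approximate GPU index.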

3.2 Stronger training tricks

3.2.1 Untying head weights

In ablations from earlier work, a shared projection head (one MLP used by both the DINO and iBOT objectives) gives more stable training and slightly better performance at small data scale or low training resolution.

At large scale and high resolution, however, a shared head tends to cause conflicts:

  • DINO's image-level loss focuses on global features.
  • iBOT's patch-level loss focuses on local patch features.
  • With a single shared head, the two optimization directions interfere, making it hard to optimize global and local features at the same time.

Independent projection heads let DINO and iBOT each pursue their own feature objective during large-scale self-supervised training, improving the quality of both global and local features as well as downstream performance.
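Structurally, untying the heads just means instantiating two separate projection MLPs on top of the same backbone output. A minimal sketch (ProjHead and all dimensions here are hypothetical stand-ins, not DINOv2's actual head architecture):

```python
import torch
import torch.nn as nn

class ProjHead(nn.Module):
    """Toy projection head: backbone features -> prototype logits."""
    def __init__(self, in_dim, out_dim, hidden=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.GELU(), nn.Linear(hidden, out_dim))

    def forward(self, x):
        return self.mlp(x)

backbone_dim, dino_out, ibot_out = 64, 1024, 1024

# Two independent heads: one per objective (instead of one shared MLP)
dino_head = ProjHead(backbone_dim, dino_out)  # image-level (CLS) objective
ibot_head = ProjHead(backbone_dim, ibot_out)  # patch-level (masked token) objective

cls_tok = torch.randn(4, backbone_dim)       # CLS features for 4 images
patches = torch.randn(4, 16, backbone_dim)   # 16 patch features per image
assert dino_head(cls_tok).shape == (4, dino_out)
assert ibot_head(patches).shape == (4, 16, ibot_out)
```

Each head receives gradients only from its own loss term, so the two objectives no longer fight over one set of projection weights.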

3.2.2 Sinkhorn-Knopp centering

DINO and iBOT center the teacher's output before applying the softmax. However, the EMA center is a single vector and cannot force the feature distribution within a batch to be uniform; with large batches or high-dimensional output vectors, some output dimensions may be almost never activated, leaving the distribution unbalanced.

Sinkhorn-Knopp is an iterative matrix normalization algorithm, first used for this purpose in SwAV, that enforces a balanced feature assignment in unsupervised learning. Like DINO's centering, its input is a [batch_size, out_dim] tensor.

Sinkhorn-Knopp iteratively rescales the matrix so that each output dimension (prototype) receives a balanced total probability within the batch; after several iterations it yields a doubly normalized matrix. By forcing every output dimension to be used within the batch, the features are distributed more uniformly, which strengthens their discriminative power.

The following example illustrates the difference between Sinkhorn-Knopp and DINO/iBOT centering:

import torch
import torch.nn.functional as F

torch.manual_seed(42)

# ===== Simulated inputs =====
batch_size = 4
out_dim = 10        # class token output dimension
patch_out_dim = 10  # patch token output dimension
n_patches = 8
mask_indices = [1, 3, 5]  # masked patches

student_temp = 0.1
teacher_temp = 0.04

# class tokens (image-level), shape: [batch, out_dim] = [4, 10]
student_class_token = torch.randn(batch_size, out_dim)
teacher_class_token = torch.randn(batch_size, out_dim)

# patch tokens, shape: [batch, n_patches, patch_out_dim] = [4, 8, 10]
student_patch_tokens = torch.randn(batch_size, n_patches, patch_out_dim)
teacher_patch_tokens = torch.randn(batch_size, n_patches, patch_out_dim)

# teacher center vector c (the DINO/iBOT EMA center)
c = torch.zeros(out_dim)

# ===== Helper functions =====

def dino_centering(teacher_out, student_out, teacher_temp=0.04, student_temp=0.1, center=c):
    # teacher_out: [batch, out_dim]
    # student_out: [batch, out_dim]
    teacher_probs = F.softmax((teacher_out - center) / teacher_temp, dim=1)
    student_probs = F.softmax(student_out / student_temp, dim=1)
    loss = -(teacher_probs * student_probs.log()).sum(dim=1).mean()
    return teacher_probs, student_probs, loss

def ibot_centering(teacher_patch, student_patch, mask_indices, teacher_temp=0.04, student_temp=0.1, center=c):
    # teacher_patch: [batch, n_patches, patch_out_dim]
    # student_patch: [batch, n_patches, patch_out_dim]
    student_mask_tokens = student_patch[:, mask_indices, :]  # [batch, n_mask, patch_out_dim]
    teacher_mask_tokens = teacher_patch[:, mask_indices, :]  # [batch, n_mask, patch_out_dim]

    teacher_probs = F.softmax((teacher_mask_tokens - center) / teacher_temp, dim=2)
    student_probs = F.softmax(student_mask_tokens / student_temp, dim=2)
    loss = -(teacher_probs * student_probs.log()).sum(dim=2).mean()
    return teacher_probs, student_probs, loss

def sinkhorn_knopp_centering(teacher_out, student_out, student_temp=0.1, n_iters=3, epsilon=1e-6):
    # teacher_out: [batch, out_dim]
    # student_out: [batch, out_dim]
    Q = torch.exp(teacher_out / 0.05).t()  # [out_dim, batch]
    Q /= Q.sum()
    r = torch.ones(Q.shape[0], device=Q.device) / Q.shape[0]      # [out_dim]
    c_vec = torch.ones(Q.shape[1], device=Q.device) / Q.shape[1]  # [batch]

    for _ in range(n_iters):
        Q /= (Q.sum(dim=1, keepdim=True) + epsilon)  # balance rows (prototypes)
        Q *= r.view(-1, 1)
        Q /= (Q.sum(dim=0, keepdim=True) + epsilon)  # balance columns (samples)
        Q *= c_vec.view(1, -1)
    Q /= Q.sum(dim=0, keepdim=True)  # make each column (sample) a distribution
    teacher_probs = Q.t()  # [batch, out_dim]

    student_probs = F.softmax(student_out / student_temp, dim=1)
    loss = -(teacher_probs * student_probs.log()).sum(dim=1).mean()
    return teacher_probs, student_probs, loss

# ===== Run and compare =====
t_dino, s_dino, loss_dino = dino_centering(teacher_class_token, student_class_token)
t_ibot, s_ibot, loss_ibot = ibot_centering(teacher_patch_tokens, student_patch_tokens, mask_indices)
t_sk, s_sk, loss_sk = sinkhorn_knopp_centering(teacher_class_token, student_class_token)

# ===== Output =====
print("DINO Loss:", loss_dino.item())
print("iBOT Loss:", loss_ibot.item())
print("SK Loss:", loss_sk.item())

print("Teacher probs DINO shape:", t_dino.shape)  # [4, 10]
print("Student probs DINO shape:", s_dino.shape)  # [4, 10]

print("Teacher probs iBOT shape:", t_ibot.shape)  # [4, 3, 10]
print("Student probs iBOT shape:", s_ibot.shape)  # [4, 3, 10]

print("Teacher probs SK shape:", t_sk.shape)  # [4, 10]
print("Student probs SK shape:", s_sk.shape)  # [4, 10]

3.2.3 KoLeo regularization

The KoLeo regularizer is added to the total loss. It operates directly on the embeddings, with no subsequent softmax, and spreads the features of a batch uniformly in the embedding space, preventing feature collapse.

The principle: L2-normalize every vector in the batch, compute each vector's distance to its nearest neighbor, and maximize that minimum distance.

The code:

def koleo_loss(features):
    """
    features: [batch, num_patches, dim] or [batch, dim]
    automatically flattens batch * num_patches
    """
    if features.dim() == 3:  # 3-D for the student's patch tokens
        batch, num_patches, dim = features.shape
        features = features.reshape(batch * num_patches, dim)  # [batch*num_patches, dim]
    features = F.normalize(features, p=2, dim=1)  # [N, dim]
    N = features.size(0)
    dist = torch.cdist(features, features, p=2)  # [N, N] pairwise distances
    mask = torch.eye(N, device=features.device).bool()
    dist.masked_fill_(mask, float('inf'))  # ignore self-distances
    d_min, _ = dist.min(dim=1)  # [N] nearest-neighbor distance per vector
    loss = -torch.log(d_min).mean()  # maximize the minimum distances
    return loss

A full example:

import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------------------------
# KoLeo regularizer
# ------------------------------
def koleo_loss(features):
    """
    features: [batch, num_patches, dim] or [batch, dim]
    automatically flattens batch * num_patches
    """
    if features.dim() == 3:
        batch, num_patches, dim = features.shape
        features = features.reshape(batch * num_patches, dim)  # [batch*num_patches, dim]
    features = F.normalize(features, p=2, dim=1)  # [N, dim]
    N = features.size(0)
    dist = torch.cdist(features, features, p=2)  # [N, N]
    mask = torch.eye(N, device=features.device).bool()
    dist.masked_fill_(mask, float('inf'))
    d_min, _ = dist.min(dim=1)  # [N]
    loss = -torch.log(d_min).mean()
    return loss

# ------------------------------
# DINO class-level loss
# ------------------------------
def dino_loss(student_class, teacher_class, center, temp=0.1):
    """
    student_class: [batch, out_dim]
    teacher_class: [batch, out_dim]
    center: [1, out_dim]
    """
    teacher_probs = F.softmax((teacher_class - center) / temp, dim=1)  # [batch, out_dim]
    student_probs = F.log_softmax(student_class / temp, dim=1)         # [batch, out_dim]
    loss = -(teacher_probs * student_probs).sum(dim=1).mean()          # scalar
    return loss

# ------------------------------
# iBOT patch-level loss
# ------------------------------
def ibot_patch_loss(student_patch, teacher_patch, center, temp=0.1):
    """
    student_patch: [batch, num_patches, patch_out_dim]
    teacher_patch: [batch, num_patches, patch_out_dim]
    center: [1, patch_out_dim]
    """
    batch, num_patches, dim = student_patch.shape
    # reshape to [batch*num_patches, dim]
    student_patch_flat = student_patch.reshape(batch * num_patches, dim)
    teacher_patch_flat = teacher_patch.reshape(batch * num_patches, dim)

    teacher_probs = F.softmax((teacher_patch_flat - center) / temp, dim=1)  # [B*P, dim]
    student_probs = F.log_softmax(student_patch_flat / temp, dim=1)         # [B*P, dim]
    loss = -(teacher_probs * student_probs).sum(dim=1).mean()               # scalar
    return loss

# ------------------------------
# Simulated model outputs
# ------------------------------
batch_size = 8
out_dim = 128
num_patches = 16
patch_out_dim = 64

# random student/teacher outputs (simulating network predictions);
# the student tensors require grad so the backward pass below works
student_class = torch.randn(batch_size, out_dim, requires_grad=True)                     # [8, 128]
teacher_class = torch.randn(batch_size, out_dim)                                         # [8, 128]
student_patch = torch.randn(batch_size, num_patches, patch_out_dim, requires_grad=True)  # [8, 16, 64]
teacher_patch = torch.randn(batch_size, num_patches, patch_out_dim)                      # [8, 16, 64]

# EMA center vectors
center_class = torch.zeros(1, out_dim)        # [1, 128]
center_patch = torch.zeros(1, patch_out_dim)  # [1, 64]

# ------------------------------
# Hyperparameters
# ------------------------------
lambda_ibot = 1.0
lambda_koleo = 0.1
temp = 0.1

# ------------------------------
# Individual losses
# ------------------------------
loss_class = dino_loss(student_class, teacher_class, center_class, temp)        # scalar
loss_patch = ibot_patch_loss(student_patch, teacher_patch, center_patch, temp)  # scalar
loss_koleo_class = koleo_loss(student_class)  # scalar
loss_koleo_patch = koleo_loss(student_patch)  # scalar

# ------------------------------
# Total loss
# ------------------------------
loss_total = loss_class + lambda_ibot * loss_patch + lambda_koleo * (loss_koleo_class + loss_koleo_patch)

print("Loss DINO/class-level:", loss_class.item())
print("Loss iBOT/patch-level:", loss_patch.item())
print("Loss KoLeo/class:", loss_koleo_class.item())
print("Loss KoLeo/patch:", loss_koleo_patch.item())
print("Loss total:", loss_total.item())

# ------------------------------
# Backward pass example
# ------------------------------
optimizer = torch.optim.Adam([student_class, student_patch], lr=1e-3)
optimizer.zero_grad()
loss_total.backward()
optimizer.step()

3.2.4 A short high-resolution training phase

Higher image resolution is essential for pixel-level downstream tasks such as segmentation and detection, where small objects vanish at low resolution. But training at high resolution is slow and memory-hungry, so DINOv2 raises the image resolution to 518×518 only for a short period at the end of pretraining. The resulting model is both efficient to train and well suited to pixel-level downstream tasks.
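Switching resolution mid-training requires adapting the ViT's positional embeddings to the new patch grid. A common sketch (assuming patch size 14, so 224×224 gives a 16×16 grid and 518×518 a 37×37 grid; the random tensor stands in for pretrained embeddings) is bicubic interpolation:

```python
import torch
import torch.nn.functional as F

# Pretrained positional embeddings for a 224x224 input with patch size 14 -> 16x16 grid
old_grid, dim = 16, 8
pos = torch.randn(1, old_grid * old_grid, dim)

# Target: 518x518 input with patch size 14 -> 37x37 grid (518 / 14 = 37)
new_grid = 37
pos_2d = pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)  # [1, dim, 16, 16]
pos_2d = F.interpolate(pos_2d, size=(new_grid, new_grid), mode="bicubic", align_corners=False)
pos_new = pos_2d.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
assert pos_new.shape == (1, 37 * 37, dim)
```

With the interpolated embeddings in place, the same backbone can continue training at the higher resolution for the final phase.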

4. DINOv2's engineering optimization tricks

  • FlashAttention replaces the standard self-attention layers.

  • Sequence packing runs the forward pass over global and local views together.

    An intuitive analogy:

    Think of each sequence as a "train" of a different length: a large crop = 196 tokens, a small crop = 49 tokens.

    Original method: run the two trains through the Transformer separately, which is inefficient.

    Sequence packing: couple them into one long train with partitions (a block-diagonal mask) in between, so they are computed in parallel without interfering with each other.

import torch

def sequence_packing_forward(student_tokens_list, transformer):
    """
    Pack token sequences of different lengths into one Transformer forward pass.

    Args:
        student_tokens_list: list of tensors, each with shape [B, L_i, C]
            - B: batch size
            - L_i: number of tokens in this crop
            - C: token embedding dimension
        transformer: the transformer module

    Returns:
        outputs_list: transformer outputs split back per sequence, each [B, L_i, C]
    """
    B = student_tokens_list[0].shape[0]
    C = student_tokens_list[0].shape[2]

    # total length
    L_list = [tokens.shape[1] for tokens in student_tokens_list]
    L_total = sum(L_list)

    # concatenate into one long sequence
    packed_tokens = torch.cat(student_tokens_list, dim=1)  # [B, L_total, C]

    # build the block-diagonal mask
    # mask shape [B, L_total, L_total]
    # True means attention is **blocked**; start from all-True
    mask = torch.ones(B, L_total, L_total, device=packed_tokens.device, dtype=torch.bool)
    start = 0
    for L in L_list:
        mask[:, start:start+L, start:start+L] = False  # attention allowed only within each sequence
        start += L

    # transformer forward
    # assumes the transformer accepts a mask argument
    output = transformer(packed_tokens, attention_mask=mask)  # [B, L_total, C]

    # split the output back into the original sequences
    outputs_list = []
    start = 0
    for L in L_list:
        outputs_list.append(output[:, start:start+L, :])  # shape [B, L_i, C]
        start += L

    return outputs_list
  • Fully-Sharded Data Parallel (FSDP) from PyTorch 2.0 shards the model across GPUs.
  • Knowledge distillation compresses large models (e.g. ViT-G) into smaller ones (e.g. ViT-L).
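The distillation step reuses essentially the same cross-entropy objective seen earlier, with the frozen large model acting as the teacher. A minimal sketch (random tensors stand in for real head outputs; the temperatures are illustrative):

```python
import torch
import torch.nn.functional as F

# Hypothetical setup: frozen large teacher (e.g. ViT-G) and a small trainable student
teacher_logits = torch.randn(4, 1024)                       # teacher head output (frozen)
student_logits = torch.randn(4, 1024, requires_grad=True)   # student head output

t_temp, s_temp = 0.04, 0.1
t_probs = F.softmax(teacher_logits / t_temp, dim=-1).detach()  # sharp teacher targets
loss = -(t_probs * F.log_softmax(student_logits / s_temp, dim=-1)).sum(-1).mean()
loss.backward()
assert student_logits.grad is not None  # gradients flow only into the student
```

Unlike pretraining, the teacher here is not an EMA of the student but a fully trained large model kept fixed throughout.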

References: