VIT/DeiT 代码分析
代码参考:pytorch_classification/vision_transformer
原理参考:DeiT:注意力也能蒸馏
原理介绍
VIT
原理思想比较简单,将图片分块然后展开,添加上位置编码输入到Transformer模型中。
就像论文作者所提到的那样当不使用 JFT-300 大数据集时,效果不如CNN模型。这说明Transformer结构若想取得理想的性能和泛化能力就需要很大的数据集。在小规模数据集上训练的VIT效果将会很差,为了解决这个问题,一种称为DeiT的改进方法出现了。
DeiT:Data-efficient image Transformers
本质是利用知识蒸馏的方法让模型不仅能学习到硬标签,还能学习到软标签。
硬标签是指只有label(softmax之后只有0和1)
软标签是指有softmax的各个类的预测值(就不再是0和1了)。
如何实现呢,其实也很简单,只需要加上一个distillation-token就好了,而对于Transformer本身应该也要有一个Class-token,所以DeiT就需要额外加两个token了。训练原理如下:
模型有两个损失:一个是监督损失,一个是蒸馏损失。其中蒸馏损失包含软蒸馏损失和硬蒸馏损失。监督损失、硬蒸馏损失都是用分类交叉熵作为损失函数,软蒸馏损失是用KL散度作为损失函数。
编码模块
编码模块包括分块、展开、换位、引入cls-token和dist-token、位置编码等步骤。
transfomer的输入格式为:[batch_size, seq_len, hidden_size],因此对于图片数据[B, C, H, W],我们也要处理为类似的结构才能输入到transformer中。
- 分块/展开/换位
分块通过卷积就可以实现,展开只要flatten即可。
# 只展示核心代码
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
super().__init__()
...
# 通道数目:in_c = 3
# 输出通道数(多少个卷积核):embed_dim = 768
# 卷积核大小:kenerl_size = 16 * 16
# 步长:stride = 16 * 16
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
# 层归一化:在单个数据样本的所有通道上进行操作
# 可以学习这个条件判断,nn.Identity()表示啥也不做!
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
# proj: [B, C, H, W] -> [B, embed_dim, H, W]
# flatten: [B, embed_dim, H, W] -> [B, embed_dim, HW]
# transpose: [B, embed_dim, HW] -> [B, HW, embed_dim]符合Transformer的格式
x = self.proj(x).flatten(2).transpose(1, 2)
x = self.norm(x)
return x
- 引入cls-token和dist-token
...
# cls-token是必须要加的否则做不了蒸馏
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
# dist-token是可选的,根据是否蒸馏来决定
if self.dist_token is None:
x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
else:
x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
- 位置编码
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
self.pos_drop = nn.Dropout(p=drop_ratio)
...
x = self.pos_drop(x + self.pos_embed)
Transformer模块
blocks = block * depth = block * 12
block = attention + mlp
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
- drop_path
本次使用到的是Stochastic Depth正则化技术,主要用于深度神经网络,尤其是在 ResNet(Residual Networks)中。它的主要思想是在训练过程中随机地丢弃(skip)一些残差块(residual blocks),从而减少网络的有效深度。这种方法有助于防止过拟合,并提高网络的泛化能力。
def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
# (B,1,1,1) = (B,) + (1,1,1) = (B,) + (1,) * (4 - 1)
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
# 随机数:[0,1] + keep_prob
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
# 二值化,要么0要么1
random_tensor.floor_()
# 除以 keep_prob:因为删去的一些神经元,但是要保证整体输出不太缩小
# 乘以 random_tensor:随机删除一些神经元
output = x.div(keep_prob) * random_tensor
return output
- Muti-head Attention
多头注意力机制的核心思想是并行计算。通过将输入的特征维度 C 分割成多个头,每个头独立计算注意力,这样可以并行处理,提高计算效率和模型性能。
def forward(self, x):
# [batch_size, num_patches + 1, total_embed_dim]
B, N, C = x.shape
# qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
# reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
# permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # 一般8个头
# [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
q, k, v = qkv[0], qkv[1], qkv[2]
# transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
# @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
# @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
# transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
# reshape: -> [batch_size, num_patches + 1, total_embed_dim]
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
# ! 整合多头信息:线性层可以将不同头的输出组合在一起,学习到更高级的特征。
x = self.proj(x)
x = self.proj_drop(x)
return x
- MLP
GELU 是一种激活函数,广泛应用于深度学习模型中,特别是在 Transformer 架构中。GELU 比 ReLU 等激活函数更加平滑,这有助于梯度流动,从而可能导致更好的训练性能。
class Mlp(nn.Module):
"""
MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x