Denoising Diffusion Probabilistic Models
基础定义
diffusion中的几个符号要比较清楚:
self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
self.alpha = 1 - self.beta
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
self.sigma2 = self.beta
self.n_steps = n_steps # 这个训练的时候生成一批t,这些t是随机的且小于n_steps
前向过程 q ( x_t | x_0 )
整体就是实现一个采样的过程,也就是根据时间步 t 和初始状态 x0,采样一个状态 x_t。
𝐼 是一个单位矩阵(identity matrix),也称为恒等矩阵。在矩阵乘法中类似于数字1的作用。
def q_sample(self,
x0: torch.Tensor,
t: torch.Tensor,
eps: Optional[torch.Tensor] = None):
if eps is None:
eps = torch.randn_like(x0)
mean, var = self.q_xt_x0(x0, t)
# 从均值为mean,方差为var的正态分布中采样,返回这个采样结果
return mean + (var ** 0.5) * eps
def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor)
-> Tuple[torch.Tensor, torch.Tensor]:
# 复现公式罢了
mean = gather(self.alpha_bar, t) ** 0.5 * x0
var = 1 - gather(self.alpha_bar, t)
return mean, var
逆向过程 p ( x_t-1 | x_t )
值得注意的是,DDPM中的方差是不预测的,即前向过程和逆向过程分布的方差一致!但是代码还是简化了方差的表示,原本的方差是求解出来的,会更加复杂一点,原本的方差如下:
为什么要简化这个方差呢?理由如下:
逆向过程整体就是实现一个采样的过程,公式如下:
训练的时候计算损失函数,使用MSE损失函数(均方误差)即可(因为输入是一批数据)。
def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
# 获得一些参数
eps_theta = self.eps_model(xt, t)
alpha_bar = gather(self.alpha_bar, t)
alpha = gather(self.alpha, t)
eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
# 采样公式
mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
var = gather(self.sigma2, t)
eps = torch.randn(xt.shape, device=xt.device)
return mean + (var ** .5) * eps
Unet model for DDPM
Unet模型结构如下所示,主要可以分为Up Block、Down Block、Middle Block:
为什么要将Encoder和Decoder连接起来呢?一般认为,Encoder包含更多的空间信息,Decoder包含更多的语义信息,将两者结合起来,可以实现像素级别的分割效果。
由于Up Block、Down Block、Middle Block都是由ResidualBlock和AttentionBlock构成的。而且AttentionBlock就是常见的Multi-head attention,因此这里详细分析一下ResidualBlock。
需要记住一个很重要的点是ResidualBlock和AttentionBlock的输入和输出维度一致。
Residual Block
先补充一些概念:
- 组归一化(Group Normalization)通过将通道划分为多个组,并在每个组内进行归一化,来解决批归一化在小批量情况下表现不佳的问题。
- Shortcut 是指捷径连接,用于在输入和输出维度不同时调整,以正确地执行加法操作。
-
先归一化,再激活,最后卷积,在激活之前进行归一化,可以确保激活函数输入的分布更稳定和集中,从而使得网络训练更加稳定。在归一化和激活后的特征图上进行卷积操作,在稳定和集中分布的数据上进行操作,能够更有效地学习和提取特征。
-
Residual Block实现了时间信息t的编码。但是输入的t是经过编码的向量!而不是一个数字!
x has shape [batch_size, in_channels, height, width] t has shape [batch_size, time_channels]
class ResidualBlock(Module):
def __init__(self, in_channels: int, out_channels: int, time_channels: int,
n_groups: int = 32, dropout: float = 0.1):
super().__init__()
self.norm1 = nn.GroupNorm(n_groups, in_channels)
self.act1 = Swish()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
self.norm2 = nn.GroupNorm(n_groups, out_channels)
self.act2 = Swish()
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
if in_channels != out_channels:
self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
else:
self.shortcut = nn.Identity()
self.time_emb = nn.Linear(time_channels, out_channels)
self.time_act = Swish()
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, t: torch.Tensor):
h = self.conv1(self.act1(self.norm1(x)))
h += self.time_emb(self.time_act(t))[:, :, None, None]
h = self.conv2(self.dropout(self.act2(self.norm2(h))))
return h + self.shortcut(x)
Down
- DownBlock是图中的水平箭头,不过稍加改进由ResidualBlock和AttentionBlock构成。
- Downsample则是个卷积层,用于下采样,是图中的向下的箭头。
down = []
out_channels = in_channels = n_channels
for i in range(n_resolutions):
out_channels = in_channels * ch_mults[i]
for _ in range(n_blocks):
down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
in_channels = out_channels
if i < n_resolutions - 1:
down.append(Downsample(in_channels))
self.down = nn.ModuleList(down)
Middle、Up类似,这里不再赘述了,可以自行查看代码。
Unet
def forward(self, x: torch.Tensor, t: torch.Tensor):
t = self.time_emb(t)
x = self.image_proj(x)
h = [x]
# 下采样
for m in self.down:
x = m(x, t)
h.append(x)
x = self.middle(x, t)
# 上采样
for m in self.up:
if isinstance(m, Upsample):
x = m(x, t)
# 拼接操作
else:
s = h.pop()
x = torch.cat((x, s), dim=1)
x = m(x, t)
return self.final(self.act(self.norm(x)))