跳转至

Denoising Diffusion Probabilistic Models

code: nn.labml.ai: Denoising Diffusion Probabilistic Models

paper: Denoising Diffusion Probabilistic Models

基础定义

diffusion中的几个符号要比较清楚:

An image caption
self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
self.alpha = 1 - self.beta
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
self.sigma2 = self.beta
self.n_steps = n_steps # 这个训练的时候生成一批t,这些t是随机的且小于n_steps

前向过程 q ( x_t | x_0 )

整体就是实现一个采样的过程,也就是根据时间步 t 和初始状态 x0,采样一个状态 x_t。

𝐼 是一个单位矩阵(identity matrix),也称为恒等矩阵。在矩阵乘法中类似于数字1的作用。

An image caption
def q_sample(self, 
            x0: torch.Tensor, 
            t: torch.Tensor, 
            eps: Optional[torch.Tensor] = None):  

    if eps is None:
        eps = torch.randn_like(x0)
    mean, var = self.q_xt_x0(x0, t)

    # 从均值为mean,方差为var的正态分布中采样,返回这个采样结果
    return mean + (var ** 0.5) * eps 
def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) 
                    -> Tuple[torch.Tensor, torch.Tensor]:

    # 复现公式罢了
    mean = gather(self.alpha_bar, t) ** 0.5 * x0
    var = 1 - gather(self.alpha_bar, t)
    return mean, var

逆向过程 p ( x_t-1 | x_t )

值得注意的是,DDPM中的方差是不预测的,即前向过程和逆向过程分布的方差一致!但是代码还是简化了方差的表示,原本的方差是求解出来的,会更加复杂一点,原本的方差如下:

An image caption

为什么要简化这个方差呢?理由如下:

An image caption

逆向过程整体就是实现一个采样的过程,公式如下:

An image caption

训练的时候计算损失函数,使用MSE损失函数(均方误差)即可(因为输入是一批数据)。

def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
    # 获得一些参数
    eps_theta = self.eps_model(xt, t)
    alpha_bar = gather(self.alpha_bar, t)
    alpha = gather(self.alpha, t)
    eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5

    # 采样公式
    mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
    var = gather(self.sigma2, t)
    eps = torch.randn(xt.shape, device=xt.device)
    return mean + (var ** .5) * eps

Unet model for DDPM

Unet模型结构如下所示,主要可以分为Up Block、Down Block、Middle Block:

An image caption

为什么要将Encoder和Decoder连接起来呢?一般认为,Encoder包含更多的空间信息,Decoder包含更多的语义信息,将两者结合起来,可以实现像素级别的分割效果。

由于Up Block、Down Block、Middle Block都是由ResidualBlock和AttentionBlock构成的。而且AttentionBlock就是常见的Multi-head attention,因此这里详细分析一下ResidualBlock。

需要记住一个很重要的点是ResidualBlock和AttentionBlock的输入和输出维度一致。

Residual Block

先补充一些概念:

  • 组归一化(Group Normalization)通过将通道划分为多个组,并在每个组内进行归一化,来解决批归一化在小批量情况下表现不佳的问题。
  • Shortcut 是指捷径连接,用于在输入和输出维度不同时调整,以正确地执行加法操作。
  • 先归一化,再激活,最后卷积,在激活之前进行归一化,可以确保激活函数输入的分布更稳定和集中,从而使得网络训练更加稳定。在归一化和激活后的特征图上进行卷积操作,在稳定和集中分布的数据上进行操作,能够更有效地学习和提取特征。

  • Residual Block实现了时间信息t的编码。但是输入的t是经过编码的向量!而不是一个数字!

x has shape [batch_size, in_channels, height, width] t has shape [batch_size, time_channels]

class ResidualBlock(Module):
    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 n_groups: int = 32, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
            else:
                self.shortcut = nn.Identity()
        self.time_emb = nn.Linear(time_channels, out_channels)
        self.time_act = Swish()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        h = self.conv1(self.act1(self.norm1(x)))
        h += self.time_emb(self.time_act(t))[:, :, None, None]
        h = self.conv2(self.dropout(self.act2(self.norm2(h))))
        return h + self.shortcut(x)

Down

  • DownBlock是图中的水平箭头,不过稍加改进由ResidualBlock和AttentionBlock构成。
  • Downsample则是个卷积层,用于下采样,是图中的向下的箭头。
down = []
out_channels = in_channels = n_channels
for i in range(n_resolutions):
    out_channels = in_channels * ch_mults[i]
    for _ in range(n_blocks):
        down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
        in_channels = out_channels
    if i < n_resolutions - 1:
        down.append(Downsample(in_channels))
self.down = nn.ModuleList(down)

Middle、Up类似,这里不再赘述了,可以自行查看代码。

Unet

def forward(self, x: torch.Tensor, t: torch.Tensor):
    t = self.time_emb(t)
    x = self.image_proj(x)
    h = [x]

    # 下采样
    for m in self.down:
        x = m(x, t)
        h.append(x)
    x = self.middle(x, t)

    # 上采样
    for m in self.up:
        if isinstance(m, Upsample):
            x = m(x, t)
        # 拼接操作
        else:
            s = h.pop()
            x = torch.cat((x, s), dim=1)
            x = m(x, t)
    return self.final(self.act(self.norm(x)))