跳转至

Denoising Diffusion Implicit Models

code: nn.labml.ai: Denoising Diffusion Implicit Models

paper: Denoising Diffusion Implicit Models

DDPM到DDIM我之前在另一篇笔记中已经很详细的分析过了,下面直接给DDIM表达式:

An image caption

DDPM代码我也在其他笔记中详细分析过了,因此只阐述代码的核心不同点。

本质上来说,DDIM就是新的一种采样方式,但是模型可以和DDPM一模一样!

DDIM只需要50步就能接近DDPM1000步的效果!

离散化采样

self.n_steps = model.n_steps
# 均匀离散化
if ddim_discretize == 'uniform':
    c = self.n_steps // n_steps
    self.time_steps = np.asarray(list(range(0, self.n_steps, c))) + 1
# 平方离散化
elif ddim_discretize == 'quad':
    self.time_steps = ((np.linspace(0, np.sqrt(self.n_steps * .8), n_steps)) ** 2).astype(int) + 1
else:
    raise NotImplementedError(ddim_discretize)
  • 均匀离散化(uniform):生成时间步间隔固定,适用于对所有时间步要求一致的情况。
  • 平方离散化(quad):生成前期密集、后期稀疏的时间步,这种方法可以在生成早期快速去噪,而在生成后期保持更多细节,提高生成图像的质量。

采样Sample(去噪过程)

得益于离散化和非马尔科夫性质,DDIM采样有多好处:

  • 生成过程具有一致性:DDIM的生成过程是确定性的,这意味着在相同的潜在变量条件下,生成的多个样本应该具有相似的高层次特征。这使其生成样本时表现更稳定一致。
  • 语义插值:在潜在空间中插值两个不同的潜在变量,DDIM生成的中间样本能够在语义上逐渐过渡。这使得DDIM在图像生成、图像编辑和图像转换等任务中具有很大的应用潜力。

体现在代码上,DDIM的采样如下:

bs = shape[0]

# 可以通过x_last指定初始噪声
x = x_last if x_last is not None else torch.randn(shape, device=device)

# 可以通过skip_steps跳过一些时间步
time_steps = np.flip(self.time_steps)[skip_steps:] # flip是翻转顺序

for i, step in monit.enum('Sample', time_steps):
    index = len(time_steps) - i - 1
    # 扩充维度
    ts = x.new_full((bs,), step, dtype=torch.long)
    x, pred_x0, e_t = self.p_sample(x, cond, ts, step, index=index, 
                                    repeat_noise=repeat_noise, 
                                    temperature=temperature, 
                                    uncond_scale=uncond_scale, 
                                    uncond_cond=uncond_cond)
return x

具体的每一步采样过程如下:

# 得到预测噪声值e_t
e_t = self.get_eps(x, t, c, uncond_scale=uncond_scale, uncond_cond=uncond_cond)

# 根据e_t计算x_prev, pred_x0
x_prev, pred_x0 = self.get_x_prev_and_pred_x0(e_t, index, x,
        temperature=temperature,
        repeat_noise=repeat_noise)
return x_prev, pred_x0, e_t

def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor, *,
                           temperature: float,
                           repeat_noise: bool):
    alpha = self.ddim_alpha[index]
    alpha_prev = self.ddim_alpha_prev[index]
    sigma = self.ddim_sigma[index]
    sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]

    pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)
    dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * e_t

    if sigma == 0.:
        noise = 0.
    elif repeat_noise:
        noise = torch.randn((1, *x.shape[1:]), device=x.device)
    else:
        noise = torch.randn(x.shape, device=x.device)

    noise = noise * temperature
    x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise
    return x_prev, pred_x0

其中核心的代码公式对应关系如下:

An image caption

Paint(定制化采样/去噪)

定制化体现在:

  • 可以从指定时间步开始,可以跳步骤
  • 可以指定输入x,比如可以是一张图片加噪之后得到的噪声
  • 可以使用Mask进行特征融合
def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
              orig: Optional[torch.Tensor] = None,
              mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
              uncond_scale: float = 1.,
              uncond_cond: Optional[torch.Tensor] = None,
              ):

        bs = x.shape[0]
        time_steps = np.flip(self.time_steps[:t_start])

        for i, step in monit.enum('Paint', time_steps):
            index = len(time_steps) - i - 1
            ts = x.new_full((bs,), step, dtype=torch.long)

            x, _, _ = self.p_sample(x, cond, ts, step, index=index,
                                    uncond_scale=uncond_scale,
                                    uncond_cond=uncond_cond)
            if orig is not None:
                orig_t = self.q_sample(orig, index, noise=orig_noise)
                x = orig_t * mask + x * (1 - mask)
        return x