Denoising Diffusion Implicit Models
DDPM到DDIM我之前在另一篇笔记中已经很详细的分析过了,下面直接给DDIM表达式:
DDPM代码我也在其他笔记中详细分析过了,因此只阐述代码的核心不同点。
本质上来说,DDIM就是新的一种采样方式,但是模型可以和DDPM一模一样!
DDIM只需要50步就能接近DDPM1000步的效果!
离散化采样
self.n_steps = model.n_steps
# 均匀离散化
if ddim_discretize == 'uniform':
c = self.n_steps // n_steps
self.time_steps = np.asarray(list(range(0, self.n_steps, c))) + 1
# 平方离散化
elif ddim_discretize == 'quad':
self.time_steps = ((np.linspace(0, np.sqrt(self.n_steps * .8), n_steps)) ** 2).astype(int) + 1
else:
raise NotImplementedError(ddim_discretize)
- 均匀离散化(uniform):生成时间步间隔固定,适用于对所有时间步要求一致的情况。
- 平方离散化(quad):生成前期密集、后期稀疏的时间步,这种方法可以在生成早期快速去噪,而在生成后期保持更多细节,提高生成图像的质量。
采样Sample(去噪过程)
得益于离散化和非马尔科夫性质,DDIM采样有多好处:
- 生成过程具有一致性:DDIM的生成过程是确定性的,这意味着在相同的潜在变量条件下,生成的多个样本应该具有相似的高层次特征。这使其生成样本时表现更稳定一致。
- 语义插值:在潜在空间中插值两个不同的潜在变量,DDIM生成的中间样本能够在语义上逐渐过渡。这使得DDIM在图像生成、图像编辑和图像转换等任务中具有很大的应用潜力。
体现在代码上,DDIM的采样如下:
bs = shape[0]
# 可以通过x_last指定初始噪声
x = x_last if x_last is not None else torch.randn(shape, device=device)
# 可以通过skip_steps跳过一些时间步
time_steps = np.flip(self.time_steps)[skip_steps:] # flip是翻转顺序
for i, step in monit.enum('Sample', time_steps):
index = len(time_steps) - i - 1
# 扩充维度
ts = x.new_full((bs,), step, dtype=torch.long)
x, pred_x0, e_t = self.p_sample(x, cond, ts, step, index=index,
repeat_noise=repeat_noise,
temperature=temperature,
uncond_scale=uncond_scale,
uncond_cond=uncond_cond)
return x
具体的每一步采样过程如下:
# 得到预测噪声值e_t
e_t = self.get_eps(x, t, c, uncond_scale=uncond_scale, uncond_cond=uncond_cond)
# 根据e_t计算x_prev, pred_x0
x_prev, pred_x0 = self.get_x_prev_and_pred_x0(e_t, index, x,
temperature=temperature,
repeat_noise=repeat_noise)
return x_prev, pred_x0, e_t
def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor, *,
temperature: float,
repeat_noise: bool):
alpha = self.ddim_alpha[index]
alpha_prev = self.ddim_alpha_prev[index]
sigma = self.ddim_sigma[index]
sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]
pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)
dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * e_t
if sigma == 0.:
noise = 0.
elif repeat_noise:
noise = torch.randn((1, *x.shape[1:]), device=x.device)
else:
noise = torch.randn(x.shape, device=x.device)
noise = noise * temperature
x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise
return x_prev, pred_x0
其中核心的代码公式对应关系如下:
Paint(定制化采样/去噪)
定制化体现在:
- 可以从指定时间步开始,可以跳步骤
- 可以指定输入x,比如可以是一张图片加噪之后得到的噪声
- 可以使用Mask进行特征融合
def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
orig: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
uncond_scale: float = 1.,
uncond_cond: Optional[torch.Tensor] = None,
):
bs = x.shape[0]
time_steps = np.flip(self.time_steps[:t_start])
for i, step in monit.enum('Paint', time_steps):
index = len(time_steps) - i - 1
ts = x.new_full((bs,), step, dtype=torch.long)
x, _, _ = self.p_sample(x, cond, ts, step, index=index,
uncond_scale=uncond_scale,
uncond_cond=uncond_cond)
if orig is not None:
orig_t = self.q_sample(orig, index, noise=orig_noise)
x = orig_t * mask + x * (1 - mask)
return x