DDIM Sampling and DDIM Inversion Code Analysis
Reference code: HF: DDIM Inversion
How It Works
DDIM Sampling (the denoising process)
Quick implementation
from diffusers import StableDiffusionPipeline, DDIMScheduler

# prompt and negative_prompt are assumed to be defined earlier as strings
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
image = pipe(prompt, negative_prompt=negative_prompt).images[0]
Detailed implementation
The code itself is fairly clear, but to make the logic easier to follow I drew a mind map to aid understanding.
A few things worth noting:
- set_timesteps defines how the timesteps t are generated for the denoising (sampling) process; when adding noise, the order has to be reversed. We will look at the corresponding code in detail in the DDIM Inversion part below (see also the sketch after this list).
- The text information is passed into the UNet as encoder_hidden_states; in fact many kinds of information can be fed in this way, including images, structure, and so on.
- The diffusion model here works in latent space, so the noise is generated directly as latent-space noise!
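A minimal sketch of the first and third points, assuming the pipe loaded above (the exact timestep values depend on the scheduler config):

import torch

# Denoising timesteps run from large t down to small t;
# reversing them gives the order used later for inversion (adding noise)
pipe.scheduler.set_timesteps(30)
print(pipe.scheduler.timesteps)            # descending: from ~999 down towards 0
print(reversed(pipe.scheduler.timesteps))  # ascending: the order DDIM Inversion walks through

# The noise is sampled directly in latent space (1x4x64x64 for SD v1.5 at 512x512),
# not as a 3x512x512 pixel-space image
latents = torch.randn(1, 4, 64, 64) * pipe.scheduler.init_noise_sigma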
import torch
from tqdm.auto import tqdm

# Sample function (regular DDIM)
@torch.no_grad()
def sample(
    prompt,
    start_step=0,
    start_latents=None,
    guidance_scale=3.5,
    num_inference_steps=30,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):
    # Encode prompt
    text_embeddings = pipe._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Create timesteps
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Create start latents if none were given (random noise in latent space)
    if start_latents is None:
        start_latents = torch.randn(1, 4, 64, 64, device=device)
        start_latents *= pipe.scheduler.init_noise_sigma

    latents = start_latents.clone()

    for i in tqdm(range(start_step, num_inference_steps)):
        t = pipe.scheduler.timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Manually written DDIM update step, equivalent to:
        # latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
        prev_t = max(1, t.item() - (1000 // num_inference_steps))  # t-1
        alpha_t = pipe.scheduler.alphas_cumprod[t.item()]
        alpha_t_prev = pipe.scheduler.alphas_cumprod[prev_t]
        predicted_x0 = (latents - (1 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt()
        direction_pointing_to_xt = (1 - alpha_t_prev).sqrt() * noise_pred
        latents = alpha_t_prev.sqrt() * predicted_x0 + direction_pointing_to_xt

    # Decode latents to images
    images = pipe.decode_latents(latents)
    images = pipe.numpy_to_pil(images)

    return images
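The manual update inside the loop (in place of pipe.scheduler.step) is the deterministic DDIM step with $\eta = 0$, writing $\bar\alpha$ for alphas_cumprod and $\epsilon_\theta$ for the (guided) noise prediction:

$$
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}},
\qquad
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1-\bar\alpha_{t-1}}\,\epsilon_\theta(x_t, t)
$$

which corresponds line by line to predicted_x0, direction_pointing_to_xt, and the final latents assignment.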
Calling it is then very simple; the code is given here without further analysis:
image = sample(
    "Watercolor painting of a beach sunset",
    negative_prompt=negative_prompt,
    num_inference_steps=50,
)[0]
DDIM Inversion
- Process the image to obtain the starting latents (Encode the image)
import torchvision.transforms as tfms

# torchvision.transforms.functional.to_tensor(img) scales values into the 0~1 range
# unsqueeze(0) adds a batch dimension
# * 2 - 1 rescales the values into the -1~1 range
with torch.no_grad():
    latent = pipe.vae.encode(tfms.functional.to_tensor(input_image).unsqueeze(0).to(device) * 2 - 1)

# latent.latent_dist.sample() draws the latents; .latent_dist is the latent distribution
# (the VAE encoder outputs a Distribution, i.e. a mean and variance, so we need to sample() from it)
# 0.18215 is a fixed scaling factor baked into this VAE; no need to change it
l = 0.18215 * latent.latent_dist.sample()
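As a quick sanity check (a sketch reusing pipe from above), the latents can be decoded straight back to an image; pipe.decode_latents undoes the 0.18215 scaling internally:

# Decode the encoded latents back to pixel space and view them as a PIL image
with torch.no_grad():
    decoded = pipe.decode_latents(l)
pipe.numpy_to_pil(decoded)[0]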
- Inversion
Note that the function returns the whole stack of intermediate latents (one per stored step), not a single latent.
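The update in the loop below is simply the DDIM step from before rearranged to run forwards, from $x_{t-1}$ to $x_t$, reusing the same noise prediction $\epsilon_\theta$ at both ends:

$$
x_t = \sqrt{\frac{\bar\alpha_t}{\bar\alpha_{t-1}}}\left(x_{t-1} - \sqrt{1-\bar\alpha_{t-1}}\,\epsilon_\theta\right) + \sqrt{1-\bar\alpha_t}\,\epsilon_\theta
$$

In the code, alpha_t holds $\bar\alpha_{t-1}$ (current_t) and alpha_t_next holds $\bar\alpha_t$ (next_t).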
## Inversion
@torch.no_grad()
def invert(
    start_latents,
    prompt,
    guidance_scale=3.5,
    num_inference_steps=80,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):
    # Encode prompt
    text_embeddings = pipe._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Start from the latents of the input image
    latents = start_latents.clone()

    # Keep a list of the inverted latents
    intermediate_latents = []

    # Set timesteps and reverse them: inversion walks from small t to large t
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = reversed(pipe.scheduler.timesteps)

    for i in tqdm(range(1, num_inference_steps), total=num_inference_steps - 1):
        # Skip the final iteration
        if i >= num_inference_steps - 1:
            continue

        t = timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        current_t = max(0, t.item() - (1000 // num_inference_steps))  # t - 1
        next_t = t  # min(999, t.item() + (1000//num_inference_steps))  # t
        alpha_t = pipe.scheduler.alphas_cumprod[current_t]
        alpha_t_next = pipe.scheduler.alphas_cumprod[next_t]

        # Inverted update step: x(t) <- x(t-1)
        latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * (
            alpha_t_next.sqrt() / alpha_t.sqrt()
        ) + (1 - alpha_t_next).sqrt() * noise_pred

        intermediate_latents.append(latents)

    return torch.cat(intermediate_latents)
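To produce the inverted_latents used below, call invert on the encoded image latents l from the encoding step; the prompt string here is only an illustrative placeholder describing the input image:

# Hypothetical prompt describing the input image; replace it with a description of your own image
input_image_prompt = "Photograph of a puppy on the grass"
inverted_latents = invert(l, input_image_prompt, num_inference_steps=50)
inverted_latents.shape  # torch.Size([48, 4, 64, 64]) for 50 steps: one latent per stored step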
# Decode the final inverted latents
with torch.no_grad():
    im = pipe.decode_latents(inverted_latents[-1].unsqueeze(0))
pipe.numpy_to_pil(im)[0]