
Code Analysis of DDIM Sampling and DDIM Inversion

Reference code: HF: DDIM Inversion

Theory

DDIM Sampling (Denoising Process)

Quick implementation

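import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)  # replace the default scheduler with DDIM
image = pipe(prompt, negative_prompt=negative_prompt).images[0]  # prompt / negative_prompt: your own strings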

Detailed implementation

The code is already fairly clear, but to make the logic easier to follow I drew a mind map of it.

[Figure: mind map of the DDIM sampling code]

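A few things are worth noting:

  • set_timesteps generates the timesteps t for the denoising (sampling) process, in descending order. Adding noise requires the reverse order, so the corresponding code deserves a careful look in the DDIM Inversion part below.

  • The text information is passed into the UNet as encoder_hidden_states. Many other kinds of conditioning, such as images or structural information, can be injected the same way.

  • This is a latent diffusion model, so the noise is generated directly in latent space (a 4×64×64 tensor for a 512×512 image), not in pixel space!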
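The first point can be made concrete with a short sketch. It reproduces the "leading" timestep spacing that diffusers' DDIMScheduler applies, assuming the values in the Stable Diffusion v1.5 scheduler config (num_train_timesteps=1000, steps_offset=1). The function name ddim_timesteps is hypothetical. It only mimics what pipe.scheduler.set_timesteps produces and does not call it.

```python
def ddim_timesteps(num_inference_steps, num_train_timesteps=1000, steps_offset=1):
    """Hypothetical helper mimicking DDIM 'leading' spacing (assumed config values)."""
    step_ratio = num_train_timesteps // num_inference_steps
    # Evenly spaced training timesteps, shifted by steps_offset
    ascending = [i * step_ratio + steps_offset for i in range(num_inference_steps)]
    # Sampling (denoising) walks from the noisiest timestep down to the cleanest
    return ascending[::-1]


sampling_ts = ddim_timesteps(50)
# Inversion (adding noise) needs the same timesteps in ascending order
inversion_ts = list(reversed(sampling_ts))

print(sampling_ts[:3], "...", sampling_ts[-1])    # [981, 961, 941] ... 1
print(inversion_ts[:3], "...", inversion_ts[-1])  # [1, 21, 41] ... 981
```

With 50 steps, consecutive timesteps are 1000 // 50 = 20 apart. This matches the 1000 // num_inference_steps offset that sample() and invert() use to locate the neighbouring timestep.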

import torch
from tqdm.auto import tqdm

# Sample function (regular DDIM); uses the global pipe and device defined above
@torch.no_grad()
def sample(
    prompt,
    start_step=0,
    start_latents=None,
    guidance_scale=3.5,
    num_inference_steps=30,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):

    # Encode prompt
    text_embeddings = pipe._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # create timesteps
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Create start latents
    if start_latents is None:
        start_latents = torch.randn(1, 4, 64, 64, device=device)
        start_latents *= pipe.scheduler.init_noise_sigma
    latents = start_latents.clone()

    for i in tqdm(range(start_step, num_inference_steps)):

        t = pipe.scheduler.timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

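        # Manual deterministic DDIM update (eta = 0), roughly what this library call does:
        # latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
        prev_t = max(1, t.item() - (1000 // num_inference_steps))  # t-1
        alpha_t = pipe.scheduler.alphas_cumprod[t.item()]      # cumulative alpha at t
        alpha_t_prev = pipe.scheduler.alphas_cumprod[prev_t]   # cumulative alpha at t-1
        # Estimate the clean latent x0 from x_t and the predicted noise
        predicted_x0 = (latents - (1 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt()
        # Direction pointing to x_t, rescaled for timestep t-1
        direction_pointing_to_xt = (1 - alpha_t_prev).sqrt() * noise_pred
        # Combine the two to get x_{t-1}; no fresh noise is added
        latents = alpha_t_prev.sqrt() * predicted_x0 + direction_pointing_to_xt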

    # decode latents
    images = pipe.decode_latents(latents)
    images = pipe.numpy_to_pil(images)

    return images
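The last lines of the loop are the deterministic (eta = 0) DDIM update, written out below. The symbols map to the code as follows:

- ᾱ_t = alphas_cumprod[t]
- w = guidance_scale
- c is the prompt embedding and ∅ the negative (empty) prompt embedding
- t-1 stands for prev_t = t - 1000 // num_inference_steps

Each line corresponds term by term to the guidance step, predicted_x0 and the final latents assignment:

```latex
\begin{aligned}
\epsilon &= \epsilon_\theta(x_t, t, \varnothing) + w\left(\epsilon_\theta(x_t, t, c) - \epsilon_\theta(x_t, t, \varnothing)\right) \\
\hat{x}_0 &= \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon}{\sqrt{\bar{\alpha}_t}} \\
x_{t-1} &= \sqrt{\bar{\alpha}_{t-1}}\,\hat{x}_0 + \sqrt{1-\bar{\alpha}_{t-1}}\,\epsilon
\end{aligned}
```

No random noise is injected at any step, so the whole trajectory is determined by the starting latents. This determinism is what makes DDIM Inversion possible.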

Calling it is simple, so here is the code without further analysis:

image = sample("Watercolor painting of a beach sunset", 
                negative_prompt=negative_prompt,
                num_inference_steps=50)[0]

DDIM Inversion

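  • Process the image to obtain the starting latents (Encode the image)
from torchvision import transforms as tfms

# input_image is the PIL image to invert
# tfms.functional.to_tensor(img) converts it into a CHW tensor with values in [0, 1]
# unsqueeze(0) adds the batch dimension
# * 2 - 1 maps the values to [-1, 1], the input range the VAE expects
with torch.no_grad():
    latent = pipe.vae.encode(tfms.functional.to_tensor(input_image).unsqueeze(0).to(device) * 2 - 1)

# latent.latent_dist is the latent distribution; .sample() draws the latents from it.
# The VAE encoder outputs a distribution (a Gaussian described by its mean and variance)
# rather than a single tensor, which is why sample() is needed.
# 0.18215 is the VAE's fixed latent scaling factor (pipe.vae.config.scaling_factor); it needs no tuning.
l = 0.18215 * latent.latent_dist.sample()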
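The value-range and shape bookkeeping around the VAE can be checked without loading any model. The sketch below is a minimal illustration that assumes SD v1's conventions:

- pixel values in [0, 255] and a VAE input range of [-1, 1]
- a spatial downsampling factor of 8 with 4 latent channels
- the 0.18215 scaling factor

The helper names are hypothetical. from_vae_output mirrors what pipe.decode_latents does after the VAE decoder: divide by 2, add 0.5, clamp to [0, 1].

```python
SCALING_FACTOR = 0.18215  # SD v1 VAE latent scaling factor
VAE_DOWNSAMPLE = 8        # assumed spatial downsampling factor of the SD v1 VAE
LATENT_CHANNELS = 4       # assumed number of latent channels


def to_vae_input(pixel):
    """uint8 pixel in [0, 255] -> [0, 1] (like to_tensor) -> [-1, 1] (the '* 2 - 1')."""
    return pixel / 255.0 * 2 - 1


def from_vae_output(value):
    """Decoder output in [-1, 1] -> [0, 1], clamped (as decode_latents does)."""
    return min(max(value / 2 + 0.5, 0.0), 1.0)


def latent_shape(height, width):
    """Shape of one image's latent: (channels, H / 8, W / 8)."""
    return (LATENT_CHANNELS, height // VAE_DOWNSAMPLE, width // VAE_DOWNSAMPLE)


print(to_vae_input(0), to_vae_input(255))  # -1.0 1.0
print(latent_shape(512, 512))              # (4, 64, 64)

z = 1.234                    # an unscaled value drawn from latent_dist.sample()
scaled = SCALING_FACTOR * z  # what the diffusion model actually works with
print(abs(scaled / SCALING_FACTOR - z) < 1e-12)  # decoding divides the factor back out: True
```

This explains why start_latents in sample() has shape (1, 4, 64, 64). The shape is a 512×512 image seen through the VAE.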
  • Inversion

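Note that the function returns all the intermediate latents, concatenated along the batch dimension by torch.cat, rather than only the final latent.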

## Inversion
@torch.no_grad()
def invert(
    start_latents,
    prompt,
    guidance_scale=3.5,
    num_inference_steps=80,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):
    # Encode prompt
    text_embeddings = pipe._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # get latents
    latents = start_latents.clone()

    # keep a list of the inverted latents
    intermediate_latents = []

    # set_timesteps
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = reversed(pipe.scheduler.timesteps)  # ascending order: inversion goes from low to high noise

    for i in tqdm(range(1, num_inference_steps), total=num_inference_steps - 1):

        #  skip the final iteration
        if i >= num_inference_steps - 1:
            continue

        t = timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        current_t = max(0, t.item() - (1000 // num_inference_steps))  # t - 1 (less noisy)
        next_t = t.item()  # t (noisier); a plain int for indexing alphas_cumprod
        alpha_t = pipe.scheduler.alphas_cumprod[current_t]
        alpha_t_next = pipe.scheduler.alphas_cumprod[next_t]

        # Inverted update step: x(t) <- x(t-1)
        # Estimate x0 from the current latents, then re-noise it to the next (noisier) timestep
        latents = (
            (latents - (1 - alpha_t).sqrt() * noise_pred) * (alpha_t_next.sqrt() / alpha_t.sqrt())
            + (1 - alpha_t_next).sqrt() * noise_pred
        )

        intermediate_latents.append(latents)

    return torch.cat(intermediate_latents)
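The loop starts at i = 1 and skips the last index, so invert() does not return one latent per timestep. The sketch below replays only the index logic of that loop in plain Python. It assumes the same "leading" spacing with steps_offset=1 as in SD v1.5's scheduler config. inverted_timesteps is a hypothetical helper, not part of diffusers.

```python
def inverted_timesteps(num_inference_steps, num_train_timesteps=1000, steps_offset=1):
    """Hypothetical helper: timestep of each latent that invert() appends (assumed spacing)."""
    step_ratio = num_train_timesteps // num_inference_steps
    # Stand-in for reversed(pipe.scheduler.timesteps): ascending order
    ascending = [i * step_ratio + steps_offset for i in range(num_inference_steps)]
    appended = []
    for i in range(1, num_inference_steps):
        if i >= num_inference_steps - 1:  # same skip as in invert()
            continue
        appended.append(ascending[i])     # the latent at next_t = t is appended
    return appended


ts50 = inverted_timesteps(50)
print(len(ts50), ts50[0], ts50[-1])  # 48 21 961

ts80 = inverted_timesteps(80)        # invert()'s default num_inference_steps
print(len(ts80), ts80[0], ts80[-1])  # 78 13 937
```

Under these assumptions, num_inference_steps=50 gives 48 latents. The last one (inverted_latents[-1], decoded below) sits at timestep 961 rather than 981.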

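# Run the inversion on the encoded image latents l
# input_image_prompt is assumed to be a text description of input_image, defined beforehand
inverted_latents = invert(l, input_image_prompt, num_inference_steps=50)

# Decode the final inverted latents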
with torch.no_grad():
    im = pipe.decode_latents(inverted_latents[-1].unsqueeze(0))
pipe.numpy_to_pil(im)[0]
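The inverted update in invert() is the algebraic inverse of the DDIM step in sample(), provided both use the same noise prediction. The sketch below checks this with plain Python floats standing in for latent tensors, which works because both updates are elementwise. It assumes:

- the "scaled_linear" beta schedule with beta_start=0.00085 and beta_end=0.012 from the SD v1.5 scheduler config
- 50 timesteps spaced 20 apart
- toy_eps, a hypothetical stand-in for the UNet that depends only on t

```python
import math


def alphas_cumprod(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012):
    """Cumulative product of (1 - beta) for a 'scaled_linear' schedule (assumed SD v1.5 values)."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    out, prod = [], 1.0
    for i in range(num_train_timesteps):
        beta = (lo + (hi - lo) * i / (num_train_timesteps - 1)) ** 2  # linear in sqrt(beta)
        prod *= 1.0 - beta
        out.append(prod)
    return out


def ddim_step(x_t, eps, a_t, a_prev):
    """Denoising step x_t -> x_{t-1}, same algebra as sample()."""
    pred_x0 = (x_t - math.sqrt(1 - a_t) * eps) / math.sqrt(a_t)
    return math.sqrt(a_prev) * pred_x0 + math.sqrt(1 - a_prev) * eps


def ddim_invert_step(x_prev, eps, a_prev, a_next):
    """Inverted step x_{t-1} -> x_t, same algebra as invert()."""
    pred_x0 = (x_prev - math.sqrt(1 - a_prev) * eps) / math.sqrt(a_prev)
    return math.sqrt(a_next) * pred_x0 + math.sqrt(1 - a_next) * eps


def toy_eps(t):
    """Hypothetical UNet stand-in: depends only on the timestep, not on the latent."""
    return 0.3 + 0.0005 * t


abar = alphas_cumprod()
timesteps = list(range(1, 1000, 20))  # 1, 21, ..., 981 (assumed spacing)

x0 = 0.7
x = x0
for t_cur, t_next in zip(timesteps[:-1], timesteps[1:]):  # inversion: walk up in noise
    x = ddim_invert_step(x, toy_eps(t_next), abar[t_cur], abar[t_next])

for t_cur, t_prev in zip(reversed(timesteps[1:]), reversed(timesteps[:-1])):  # sampling: walk down
    x = ddim_step(x, toy_eps(t_cur), abar[t_cur], abar[t_prev])

print(abs(x - x0) < 1e-9)  # True: the round trip recovers x0 up to float rounding
```

The round trip is exact here only because toy_eps ignores the latent. The real UNet's prediction depends on its input. invert() feeds it the less noisy latent x_{t-1}, while sampling later evaluates it at x_t. So in practice the reconstruction from inverted latents is only approximate.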