Latent Diffusion Models

Code: nn.labml.ai: Latent Diffusion Models

Paper: High-Resolution Image Synthesis with Latent Diffusion Models

The core difference between DDPM and LDM lies in the Unet. We already implemented a Unet in the DDPM code analysis, but the Unet here is different from DDPM's, because Stable Diffusion has to bring in conditioning information (c) to steer generation.

Unet for Stable Diffusion

We analyzed the Unet for DDPM in detail before: it is mainly built from ResidualBlock and AttentionBlock, where the ResidualBlock fuses in the time information t, and the AttentionBlock exists purely to improve model quality.

In the Unet for Stable Diffusion, the ResidualBlock stays the same and still fuses in the time information t. The AttentionBlock, however, adds guidance from the conditioning information c: multi-head self-attention becomes cross-attention, and the block is upgraded into a SpatialTransformer.
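
To make that split concrete, here is a minimal, self-contained toy sketch (not the labml code; all class names and sizes below are made up for illustration): a residual-style block consumes the time embedding t, while a spatial-transformer-style block consumes the conditioning c through cross-attention.

import torch
from torch import nn

class ToyResidualBlock(nn.Module):
    def __init__(self, channels: int, t_dim: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.t_proj = nn.Linear(t_dim, channels)

    def forward(self, x, t_emb):
        # Add the projected time embedding to every spatial position
        return self.conv(x) + self.t_proj(t_emb)[:, :, None, None]

class ToySpatialTransformer(nn.Module):
    def __init__(self, channels: int, d_cond: int, n_heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(channels, n_heads,
                                          kdim=d_cond, vdim=d_cond, batch_first=True)

    def forward(self, x, cond):
        b, c, h, w = x.shape
        seq = x.view(b, c, h * w).permute(0, 2, 1)   # [b, h*w, c]
        out, _ = self.attn(seq, cond, cond)          # cross-attention with cond
        return x + out.permute(0, 2, 1).view(b, c, h, w)

x = torch.randn(2, 64, 32, 32)               # feature map
t_emb = torch.randn(2, 128)                  # time embedding
cond = torch.randn(2, 77, 768)               # e.g. CLIP text embeddings
x = ToyResidualBlock(64, 128)(x, t_emb)      # fuses t
x = ToySpatialTransformer(64, 768)(x, cond)  # fuses c via cross-attention
print(x.shape)                               # torch.Size([2, 64, 32, 32])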

SpatialTransformer

To make this easier to follow, I sketched its rough structure, shown below:

[Figure: rough structure of the SpatialTransformer]
def forward(self, x: torch.Tensor, cond: torch.Tensor):

    b, c, h, w = x.shape
    x_in = x # For residual connection
    x = self.norm(x)

    # Initial 1x1 convolution
    x = self.proj_in(x)

    # Transpose and reshape from `[batch_size, channels, height, width]`
    # to `[batch_size, height * width, channels]`; a vision Transformer can only
    # operate on a flattened token sequence
    x = x.permute(0, 2, 3, 1).view(b, h * w, c)

    # Apply each transformer block, passing in the conditioning `cond`
    for block in self.transformer_blocks:
        x = block(x, cond)

    # Reshape and transpose back to `[batch_size, channels, height, width]`
    x = x.view(b, h, w, c).permute(0, 3, 1, 2)
    # Final 1x1 convolution, then the residual connection
    x = self.proj_out(x)
    return x + x_in
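
For reference, the layers used in this forward look roughly like the following in the labml implementation (a sketch; treat the exact arguments as assumptions, and BasicTransformerBlock is the block discussed next):

def __init__(self, channels: int, n_heads: int, n_layers: int, d_cond: int):
    super().__init__()
    # Group normalization over the feature map
    self.norm = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)
    # 1x1 convolutions into and out of the transformer
    self.proj_in = nn.Conv2d(channels, channels, kernel_size=1)
    self.proj_out = nn.Conv2d(channels, channels, kernel_size=1)
    # A stack of BasicTransformerBlocks; each head gets channels // n_heads dimensions
    self.transformer_blocks = nn.ModuleList(
        [BasicTransformerBlock(channels, n_heads, channels // n_heads, d_cond=d_cond)
         for _ in range(n_layers)])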

The transformer_blocks used here are a stack of BasicTransformerBlock modules:

# Self-attention: queries, keys and values all come from x
self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head)
# Cross-attention: keys and values come from the conditioning, which has width d_cond
self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head)

def forward(self, x: torch.Tensor, cond: torch.Tensor):
    # Self attention
    x = self.attn1(self.norm1(x)) + x

    # Cross-attention with conditioning
    x = self.attn2(self.norm2(x), cond=cond) + x

    # Feed-forward network
    x = self.ff(self.norm3(x)) + x

    return x
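
The parts not shown above are simple: in the labml code, norm1/norm2/norm3 are LayerNorms over d_model, and ff is a position-wise feed-forward network (a GEGLU variant). A sketch of those layers, taking FeedForward as a helper provided by the repo:

self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.ff = FeedForward(d_model)   # GEGLU-based MLP in the labml repo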

Both the self-attention (attn1) and the all-important cross-attention (attn2) are implemented by the same CrossAttention.forward:

def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
    # If `cond` is `None` we perform self attention
    has_cond = cond is not None
    if not has_cond:
        cond = x

    # Get query, key and value vectors
    q = self.to_q(x)
    k = self.to_k(cond)
    v = self.to_v(cond)

    # Use flash attention if it's available and the head size is less than or equal to `128`
    if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
        return self.flash_attention(q, k, v)
    # Otherwise, fallback to normal attention
    else:
        return self.normal_attention(q, k, v)
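
The flash path and the fallback compute the same thing: normal_attention is ordinary multi-head scaled dot-product attention. A simplified sketch following the labml code (the repo additionally has an in-place softmax path for very large attention maps, omitted here):

def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    # Split heads: [batch, seq, n_heads * d_head] -> [batch, seq, n_heads, d_head]
    q = q.view(*q.shape[:2], self.n_heads, -1)
    k = k.view(*k.shape[:2], self.n_heads, -1)
    v = v.view(*v.shape[:2], self.n_heads, -1)

    # Scaled dot-product scores, shape [batch, n_heads, seq_q, seq_k]
    attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
    attn = attn.softmax(dim=-1)

    # Weighted sum of values, then merge the heads back
    out = torch.einsum('bhij,bjhd->bihd', attn, v)
    out = out.reshape(*out.shape[:2], -1)
    return self.to_out(out)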

Autoencoder (first_stage_model)

encode & decode

def encode(self, img: torch.Tensor) -> 'GaussianDistribution':
    # [batch_size, z_channels * 2, z_height, z_width]
    z = self.encoder(img)

    # [batch_size, emb_channels * 2, z_height, z_width]
    moments = self.quant_conv(z)

    return GaussianDistribution(moments)

The defining parameters of a Gaussian distribution are its mean and variance, i.e. its first and second moments, which is why the tensor above is called moments.

def decode(self, z: torch.Tensor):
    # Map from the embedding space back to the encoder's z space
    z = self.post_quant_conv(z)
    # Decode the latent back into an image
    return self.decoder(z)
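
Putting the two together, a usage sketch of the round trip (hypothetical variable names; in LatentDiffusion the sampled latent is additionally multiplied by a latent scaling factor before the Unet sees it):

posterior = autoencoder.encode(img)   # GaussianDistribution over latents
z = posterior.sample()                # [batch_size, z_channels, z_height, z_width]
recon = autoencoder.decode(z)         # back to [batch_size, img_channels, height, width]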

GaussianDistribution

class GaussianDistribution:
    def __init__(self, parameters: torch.Tensor):
        # Split the channels into the mean and the log-variance
        self.mean, log_var = torch.chunk(parameters, 2, dim=1)
        # Clamp the log-variance for numerical stability
        self.log_var = torch.clamp(log_var, -30.0, 20.0)
        self.std = torch.exp(0.5 * self.log_var)

    def sample(self):
        # Reparameterization: mean + std * noise
        return self.mean + self.std * torch.randn_like(self.std)
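
sample() is the reparameterization trick: z = mean + std * eps with eps ~ N(0, I), so gradients can flow through mean and std. A quick standalone shape check (the parameters tensor is random, purely illustrative):

import torch

params = torch.randn(4, 2 * 4, 32, 32)   # [batch_size, 2 * z_channels, z_height, z_width]
dist = GaussianDistribution(params)
z = dist.sample()                         # z = mean + std * eps
print(z.shape)                            # torch.Size([4, 4, 32, 32])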

CLIPTextEmbedder (cond_stage_model)

from typing import List

from torch import nn
from transformers import CLIPTokenizer, CLIPTextModel


class CLIPTextEmbedder(nn.Module):

    def __init__(self, version: str = "openai/clip-vit-large-patch14", device="cuda:0", max_length: int = 77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version).eval()
        self.device = device
        self.max_length = max_length

    def forward(self, prompts: List[str]):
        batch_encoding = self.tokenizer(
            prompts,
            truncation=True,
            max_length=self.max_length,
            return_length=True,  # the encoding will include a `length` field
            return_overflowing_tokens=False,
            # in some cases you may want to keep the truncated tokens for later
            # processing, but that is not needed here
            padding="max_length",
            return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        return self.transformer(input_ids=tokens).last_hidden_state
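
A quick usage sketch (assumes the transformers package and the pretrained CLIP weights are available; for openai/clip-vit-large-patch14 the text encoder's hidden size is 768, so with max_length=77 the conditioning tensor has shape [batch, 77, 768]):

embedder = CLIPTextEmbedder(device="cpu")       # use "cuda:0" if a GPU is available
cond = embedder(["a painting of a cute cat"])   # hypothetical prompt
print(cond.shape)                               # torch.Size([1, 77, 768])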