Latent Diffusion Models
Code: nn.labml.ai: Latent Diffusion Models
Paper: High-Resolution Image Synthesis with Latent Diffusion Models
The core difference between DDPM and LDM lies in the Unet. We already implemented a Unet in the DDPM code walkthrough, but the Unet here is not the same as DDPM's, because Stable Diffusion has to inject conditioning information (c) to control generation.
Unet for Stable Diffusion
We analyzed the Unet for DDPM in detail before: it is built mainly from ResidualBlock and AttentionBlock modules, where the ResidualBlock fuses in the time information t and the AttentionBlock is there purely to improve model quality.
In the Unet for Stable Diffusion, the ResidualBlock stays the same and still fuses in the time information t. The AttentionBlock, however, is now guided by the conditioning information c: multi-head self-attention becomes cross-attention, and the block is upgraded into a SpatialTransformer.
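Schematically, each stage of the Unet therefore threads two side inputs through its sub-modules: the time embedding into the residual block and the conditioning into the spatial transformer. Below is a minimal, purely illustrative sketch of that wiring (this class is not the library's actual block):

import torch
import torch.nn as nn

class UNetStageSketch(nn.Module):
    """Illustrative pairing of a ResidualBlock (time) with a SpatialTransformer (condition)."""
    def __init__(self, res_block: nn.Module, spatial_transformer: nn.Module):
        super().__init__()
        self.res_block = res_block                      # fuses the time embedding t
        self.spatial_transformer = spatial_transformer  # cross-attends to the condition c

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x:     [batch_size, channels, height, width]
        # t_emb: [batch_size, d_time_emb]
        # cond:  [batch_size, seq_len, d_cond]  (e.g. CLIP text embeddings)
        x = self.res_block(x, t_emb)
        x = self.spatial_transformer(x, cond)
        return x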
SpatialTransformer
To make this easier to follow, I sketched its rough structure, shown below:
def forward(self, x: torch.Tensor, cond: torch.Tensor):
    # `x` has shape `[batch_size, channels, height, width]`
    b, c, h, w = x.shape
    # Keep the input for the residual connection
    x_in = x
    # Normalize
    x = self.norm(x)
    # Initial 1x1 convolution
    x = self.proj_in(x)
    # Transpose and reshape from `[batch_size, channels, height, width]`
    # to `[batch_size, height * width, channels]` -- flattening the feature map
    # is simply how a vision transformer has to consume it
    x = x.permute(0, 2, 3, 1).view(b, h * w, c)
    # Apply the transformer blocks (self-attention + cross-attention with `cond`)
    for block in self.transformer_blocks:
        x = block(x, cond)
    # Reshape and transpose back to `[batch_size, channels, height, width]`
    x = x.view(b, h, w, c).permute(0, 3, 1, 2)
    # Final 1x1 convolution
    x = self.proj_out(x)
    # Residual connection
    return x + x_in
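For context, here is roughly what the module's `__init__` sets up, following the labml implementation (treat details such as the GroupNorm group count and constructor signature as assumptions of this sketch); it relies on the BasicTransformerBlock shown next:

import torch
import torch.nn as nn

class SpatialTransformer(nn.Module):
    def __init__(self, channels: int, n_heads: int, n_layers: int, d_cond: int):
        super().__init__()
        # Group normalization over the feature map
        self.norm = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)
        # Initial 1x1 convolution
        self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
        # A stack of transformer blocks, each doing self-attention,
        # cross-attention with `cond`, and a feed-forward network
        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(channels, n_heads, channels // n_heads, d_cond=d_cond)
            for _ in range(n_layers)
        ])
        # Final 1x1 convolution back to the U-Net's channel count
        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)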
The transformer_blocks used here are a stack of BasicTransformerBlock modules:
# In `__init__`: `attn1` is self-attention, `attn2` cross-attends to the condition
self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head)
self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head)

def forward(self, x: torch.Tensor, cond: torch.Tensor):
    # Self-attention
    x = self.attn1(self.norm1(x)) + x
    # Cross-attention with conditioning
    x = self.attn2(self.norm2(x), cond=cond) + x
    # Feed-forward network
    x = self.ff(self.norm3(x)) + x
    return x
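The `ff` above is a position-wise feed-forward network; in the Stable Diffusion code it uses a GeGLU gate rather than a plain MLP. A minimal sketch (the expansion factor and dropout value are assumptions of this sketch):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GeGLU(nn.Module):
    """Gated GELU: project to twice the width, use one half to gate the other."""
    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.proj = nn.Linear(d_in, d_out * 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)

class FeedForward(nn.Module):
    """Position-wise feed-forward block used as `self.ff` above."""
    def __init__(self, d_model: int, d_mult: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            GeGLU(d_model, d_model * d_mult),
            nn.Dropout(0.),
            nn.Linear(d_model * d_mult, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)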
The all-important cross-attention and self-attention (attn1, attn2) can both be implemented by a single forward method:
def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
    # If `cond` is `None` we perform self-attention
    has_cond = cond is not None
    if not has_cond:
        cond = x
    # Get query, key and value vectors
    q = self.to_q(x)
    k = self.to_k(cond)
    v = self.to_v(cond)
    # Use flash attention if it's available, we are doing self-attention,
    # and the head size is less than or equal to `128`
    if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
        return self.flash_attention(q, k, v)
    # Otherwise, fall back to normal attention
    else:
        return self.normal_attention(q, k, v)
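For completeness, here is a standalone sketch of the math inside `normal_attention` (in the real module it is a method of CrossAttention and finishes with a `to_out` linear projection, which is omitted here; the einsum formulation follows the usual multi-head scaled dot-product attention):

import torch

def normal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     n_heads: int, scale: float) -> torch.Tensor:
    # Split into heads: `[batch_size, seq_len, n_heads, d_head]`
    q = q.view(*q.shape[:2], n_heads, -1)
    k = k.view(*k.shape[:2], n_heads, -1)
    v = v.view(*v.shape[:2], n_heads, -1)
    # Scaled dot-product scores: `[batch_size, n_heads, seq_q, seq_k]`
    attn = torch.einsum('bihd,bjhd->bhij', q, k) * scale
    # Softmax over the key dimension
    attn = attn.softmax(dim=-1)
    # Weighted sum of values: back to `[batch_size, seq_q, n_heads, d_head]`
    out = torch.einsum('bhij,bjhd->bihd', attn, v)
    # Merge heads to `[batch_size, seq_q, n_heads * d_head]`
    return out.reshape(*out.shape[:2], -1)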
Autoencoder (first_stage_model)
encode & decode
def encode(self, img: torch.Tensor) -> 'GaussianDistribution':
    # Get embeddings with shape `[batch_size, z_channels * 2, z_height, z_width]`
    z = self.encoder(img)
    # Get the moments (mean and log-variance) with shape
    # `[batch_size, emb_channels * 2, z_height, z_width]`
    moments = self.quant_conv(z)
    # Return the latent distribution
    return GaussianDistribution(moments)
The main parameters of a Gaussian distribution are its mean and variance, i.e., its first and second moments (hence the name `moments`).
def decode(self, z: torch.Tensor):
    # Map from the quantized embedding space back to the encoder's latent space
    z = self.post_quant_conv(z)
    # Decode the latent back to an image
    return self.decoder(z)
GaussianDistribution
import torch

class GaussianDistribution:
    def __init__(self, parameters: torch.Tensor):
        # Split `parameters` into mean and log of variance along the channel dimension
        self.mean, log_var = torch.chunk(parameters, 2, dim=1)
        # Clamp the log of variance for numerical stability
        self.log_var = torch.clamp(log_var, -30.0, 20.0)
        # Standard deviation
        self.std = torch.exp(0.5 * self.log_var)

    def sample(self):
        # Reparameterization trick: mean + std * epsilon
        return self.mean + self.std * torch.randn_like(self.std)
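Putting the pieces together, here is a hypothetical usage sketch of the autoencoder round trip; `autoencoder` is assumed to be an instance of the Autoencoder above, and the `0.18215` latent scaling factor is the value Stable Diffusion v1 applies when feeding latents to the diffusion U-Net (an assumption of this sketch, not shown in the code above):

import torch

scaling_factor = 0.18215  # latent scaling used by Stable Diffusion v1 (assumption)

img = torch.randn(1, 3, 512, 512)           # a dummy input image in [-1, 1]

# Encode: image -> posterior over latents, sample with the reparameterization
# trick, then scale for the diffusion model
dist = autoencoder.encode(img)              # GaussianDistribution
z = dist.sample() * scaling_factor          # e.g. [1, 4, 64, 64]

# Decode: undo the scaling and map the latent back to pixel space
recon = autoencoder.decode(z / scaling_factor)   # [1, 3, 512, 512]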
CLIPTextEmbedder (cond_stage_model)
from typing import List

import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer

class CLIPTextEmbedder(nn.Module):
    def __init__(self, version: str = "openai/clip-vit-large-patch14", device="cuda:0", max_length: int = 77):
        super().__init__()
        # Load the tokenizer and the CLIP text encoder from Hugging Face
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version).eval()
        self.device = device
        self.max_length = max_length

    def forward(self, prompts: List[str]):
        # Tokenize the prompts, padding/truncating to `max_length`
        batch_encoding = self.tokenizer(
            prompts,
            truncation=True,
            max_length=self.max_length,
            return_length=True,  # include a `length` field in the output
            # Overflowing (truncated) tokens can be kept for later processing
            # in some cases, but they are not needed here
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt")
        # Move the token ids to the text encoder's device
        tokens = batch_encoding["input_ids"].to(self.device)
        # Return the per-token embeddings from the last hidden state
        return self.transformer(input_ids=tokens).last_hidden_state
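A hypothetical usage sketch: the embedder turns a batch of prompts into a fixed-length sequence of token embeddings, which is exactly the `cond` tensor consumed by the cross-attention layers above (for `clip-vit-large-patch14` the hidden size is 768, so the shape is `[batch_size, 77, 768]`):

# Assumes the CLIPTextEmbedder class above is in scope; "cpu" is used only for this sketch.
embedder = CLIPTextEmbedder(device="cpu")

prompts = ["a photograph of an astronaut riding a horse", ""]
cond = embedder(prompts)   # shape: [2, 77, 768]

# `cond` is what gets passed to the U-Net and, inside it, to `attn2` in each
# BasicTransformerBlock; the empty prompt is typically the unconditional branch
# used for classifier-free guidance.
print(cond.shape)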