DiTok v1.0 实验记录
概述
DiTok (Diffusion Tokenizer) 是基于 TiTok (1D Tokenizer) 的扩散模型改进版本。
核心思想: 结合 TiTok 的高效 tokenization 和扩散模型的生成能力,实现高质量的图像生成。
主要改进:
- 使用扩散解码器替代 TiTok 的确定性解码器
- 引入时间步条件 \(t\) 和类别标签条件 \(y\)
- 支持 EMA (Exponential Moving Average) 用于稳定推理
- 两阶段训练: Stage 1 联合训练, Stage 2 微调解码器
模型架构:
- Encoder: Vision Transformer (TiTok)
- Quantizer: Vector Quantization
- Decoder: Diffusion Transformer (DiTok)
TimestepEmbedder - 时间步嵌入
时间步嵌入模块负责将连续时间步 \(t \in [0, 1]\) 映射到高维空间。
为什么需要: 在扩散模型中,不同的时间步需要不同的预测能力。早期时间步需要重构整体结构,晚期时间步只需要细化细节。
实现方式: 使用正弦位置编码 (sinusoidal position embedding) + MLP 投影。
class TimestepEmbedder(nn.Module):
"""时间步嵌入模块
将连续时间步 t ∈ [0, 1] 映射到 hidden_size 维度的向量。
使用正弦位置编码 + MLP 投影的方式,与 DiT 一致。
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""创建正弦时间步嵌入
使用不同频率的正弦和余弦函数来编码时间信息。
这是一种绝对位置编码,可以唯一确定每个时间步。
Args:
t: 时间步张量 [batch_size]
dim: 嵌入维度
max_period: 最大周期,控制频率范围
Returns:
embedding: [batch_size, dim]
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
"""
Args:
t: 连续时间步 [batch_size], 范围 [0, 1]
Returns:
embedding: [batch_size, hidden_size]
"""
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
设计要点:
- 正弦编码: 提供唯一的时间步表示
- MLP 投影: 将低维频率编码映射到高维空间
- SiLU 激活: 平滑的非线性变换
LabelEmbedder - 类别标签嵌入
类别标签嵌入模块负责将离散的类别标签映射到连续的嵌入空间。
为什么需要: 在条件生成任务中,模型需要根据类别信息生成相应的图像。直接使用 one-hot 编码维度太高且稀疏,嵌入映射更高效。
class LabelEmbedder(nn.Module):
"""类别标签嵌入模块
将离散的类别标签映射到连续的嵌入空间。
支持分类器自由引导 (Classifier-Free Guidance)。
"""
def __init__(self, num_classes, hidden_size):
super().__init__()
# 创建一个额外的 token 用于"无条件"生成
# 当 CFG 时,部分样本会将标签替换为这个特殊 token
self.embedding_table = nn.Embedding(num_classes + 1, hidden_size)
def forward(self, labels):
"""
Args:
labels: 类别标签 [batch_size], 范围 [0, num_classes-1]
对于 CFG,可能有 num_classes 作为特殊的"无条件"标签
Returns:
embedding: [batch_size, hidden_size]
"""
return self.embedding_table(labels)
CFG (Classifier-Free Guidance):
- 训练时: 以一定概率随机将标签替换为"无条件"标签 (num_classes)
- 推理时: 同时计算条件生成和无条件生成,然后加权组合
- 效果: 增强生成样本与条件标签的一致性
DiTokDecoder - 扩散解码器
DiTokDecoder 是核心改进部分,负责从加噪的图像 patch tokens 重建清晰的图像。
关键设计决策:
- 接受两个输入: 加噪的图像 tokens (\(z_{patches}\)) + 量化的 latent tokens (\(z_{quantized}\))
- 使用条件 token 替代了 TiTok 中的 class_tokens
- 条件 token 由条件信息 \(t\) 和 \(y\) 组成,通过独立 token 传入,而非 DiT 中常见的 AdaLN-Zero.
- 这学习了 UViT 的设计。
- \(t\) 与 \(y\) 两个条件信息因为是相对独立(正交)的,因此可以直接相加。
- 之所以不用 AdaLN-Zero, 是为了在结构上与 TiTok 保持一致。
输入输出:
- Input: \(z_{patches} \in \mathbb{R}^{B \times 1024 \times D}\), \(z_{quantized} \in \mathbb{R}^{B \times 8 \times 1 \times 32}\), \(t \in [0, 1]\), \(labels \in \{0, ..., 999\}\)
- Output: \(x_{pred} \in \mathbb{R}^{B \times 3 \times 256 \times 256}\)
初始化
class DiTokDecoder(TiTokDecoder):
"""扩散模型的解码器,继承自 TiTokDecoder。
完全参考 TiTokDecoder 的 forward 结构,但接受两个输入:
1. z: 加噪后的图像 patch tokens(1024个)
2. z_quantized: 量化的 latent tokens(32个)
条件注入方式:
1. 时间步条件 t(通过 TimestepEmbedder)
2. 类别标签 labels(通过 LabelEmbedder)
3. 通过独立的条件 token 传入(参考 UViT,而非 DiT 的 AdaLN-Zero)
"""
def __init__(self, config):
# 继承 TiTokDecoder 的所有组件
super().__init__(config)
# 获取隐藏层维度
hidden_size = self.width # 例如: 768 for base model
num_classes = config.get("num_classes", 1000)
# ========== 时间步嵌入 ==========
frequency_embedding_size = config.get("frequency_embedding_size", 256)
self.t_embedder = TimestepEmbedder(hidden_size, frequency_embedding_size)
# ========== 标签嵌入 ==========
self.y_embedder = LabelEmbedder(num_classes, hidden_size)
# ========== 条件投影 ==========
# 将 t_emb + y_emb 投影到 hidden_size
# 这样可以让条件信息更灵活地融合
self.cond_proj = nn.Linear(hidden_size, hidden_size)
组件说明:
- TiTokDecoder 基础结构: Transformer 解码器,包含 self-attention 和 MLP
- TimestepEmbedder: 将时间步 \(t\) 映射到高维空间
- LabelEmbedder: 将类别标签映射到嵌入空间
- cond_proj: 条件投影层,融合 \(t\) 和 \(y\) 信息
前向传播 - 第一步: 处理量化 latent tokens
def forward(self, z_patches, z_quantized, t, labels):
"""
前向传播,带时间步条件和标签条件。
Args:
z_patches: 加噪后的图像 patch tokens(已 patchified)
shape [B, grid_size**2, width]
例如: [2, 1024, 768]
z_quantized: 量化的潜在表示(32个tokens)
shape [B, token_size, 1, num_latent_tokens]
例如: [2, 8, 1, 32]
t: 连续时间步 ∈ [0, 1], shape [B]
labels: 类别标签, shape [B]
Returns:
decoded: 解码后的图像, shape [B, C, H, W]
例如: [2, 3, 256, 256]
"""
batch_size = z_patches.shape[0]
# ========== 第一步:处理 z_quantized(32个 latent tokens)==========
Nq, Cq, Hq, Wq = z_quantized.shape
assert Hq == 1 and Wq == self.num_latent_tokens, \
f"Expected z_quantized shape [B, C, 1, {self.num_latent_tokens}], got {z_quantized.shape}"
# Reshape: [B, C, 1, num_latent_tokens] → [B, num_latent_tokens, C]
# 例如: [2, 8, 1, 32] → [2, 32, 8]
x_latent = z_quantized.reshape(Nq, Cq*Hq, Wq).permute(0, 2, 1)
# 通过 decoder_embed 投影到 width 维度
# 例如: [2, 32, 8] → [2, 32, 768]
x_latent = self.decoder_embed(x_latent)
# 添加 latent_token_positional_embedding
# 这是为了让模型知道每个 latent token 的位置信息
x_latent = x_latent + self.latent_token_positional_embedding[:x_latent.shape[1]]
说明:
- z_quantized 是原始图像经过 encoder 和 quantizer 后得到的离散 latent 表示
- 这些 latent tokens 包含了图像的高层语义信息
- 通过 decoder_embed 将 token_size 维度投影到 width 维度
- 添加位置编码以保留位置信息
前向传播 - 第二步: 验证图像 patch tokens
# ========== 第二步:验证 z_patches(1024个图像 patch tokens)==========
# z_patches 应该已经是 [B, grid_size**2, width] 格式(patchified)
assert z_patches.ndim == 3, f"Expected z_patches to be 3D tensor, got shape {z_patches.shape}"
assert z_patches.shape[1] == self.grid_size**2, \
f"Expected z_patches shape [B, {self.grid_size**2}, width], got {z_patches.shape}"
# 如果 z_patches 的最后一维不等于 width,需要投影
z = z_patches
if z.shape[-1] != self.width:
if not hasattr(self, 'z_proj'):
# 动态创建投影层 (lazy initialization)
self.z_proj = nn.Linear(z.shape[-1], self.width, device=z.device)
z = self.z_proj(z) # [B, grid_size**2, width]
说明:
- z_patches 是加噪图像经过 patchify 后得到的 tokens
- 在训练时,这些 tokens 会包含不同程度的噪声
- 在推理时,初始时刻这些是纯噪声
- 动态投影层确保维度匹配 (只在需要时创建)
前向传播 - 第三步: 创建条件 token
# ========== 第三步:创建条件 token ==========
# 计算条件嵌入
t_emb = self.t_embedder(t) # [B, hidden_size]
y_emb = self.y_embedder(labels) # [B, hidden_size]
# 将 t_emb 和 y_emb 相加 (假设它们正交)
# 这是一种简单但有效的条件融合方式
cond = t_emb + y_emb # [B, hidden_size]
# 通过 cond_proj 进行投影,增加灵活性
cond = self.cond_proj(cond) # [B, hidden_size]
# 将 cond 作为条件 token (扩展一个序列维度)
cond_token = cond.unsqueeze(1) # [B, 1, width]
# 拼接条件 token 和 z (图像 patch tokens)
# 序列结构: [cond_token, z_0, z_1, ..., z_1023]
partial = torch.cat([cond_token, z], dim=1) # [B, 1 + grid_size**2, width]
# 加上完整的 positional_embedding(第一个位置给条件 token)
# 位置编码让模型知道每个 token 在序列中的位置
partial = partial + self.positional_embedding # [B, 1 + grid_size**2, width]
设计思路:
- 条件融合: \(t\) 和 \(y\) 相加是因为它们来自不同的嵌入空间,近似正交
- 独立 token: 将条件作为独立 token,参考 UViT 而非 DiT 的 AdaLN-Zero
- 位置编码: 条件 token 也有位置编码 (位置 0)
前向传播 - 第四步: 拼接所有 tokens
# ========== 第四步:拼接 x_latent(它已经有 latent_token_positional_embedding 了)==========
# 完整序列结构: [cond_token, z_patches, x_latent]
# 例如: [1, 1024, 32] = 1057 个 tokens
x = torch.cat([partial, x_latent], dim=1) # [B, 1 + grid_size**2 + num_latent_tokens, width]
序列结构:
[cond_token, z_0, z_1, ..., z_1023, latent_0, latent_1, ..., latent_31]
0 1 2 ... 1024 1025 1026 ... 1056
- cond_token: 全局条件 (时间步 + 类别)
- z_patches: 局部图像信息 (加噪的 patch tokens)
- x_latent: 全局语义信息 (量化的 latent tokens)
前向传播 - 第五步: Transformer 处理
# ========== 第五步:通过 Transformer(参考 TiTokDecoder)==========
# Layer Normalization (预处理)
x = self.ln_pre(x)
# 转置为 LND 格式 (Transformer 的标准输入格式)
x = x.permute(1, 0, 2) # NLD -> LND
# 例如: [2, 1057, 768] -> [1057, 2, 768]
# 通过多层 Transformer
for i in range(self.num_layers): # 例如: 12 层
x = self.transformer[i](x)
# 转置回 NLD 格式
x = x.permute(1, 0, 2) # LND -> NLD
Transformer 结构:
- Multi-head Self-Attention: 让所有 tokens 之间相互交互
- Feed-Forward Network: 非线性变换
- Layer Normalization: 稳定训练
前向传播 - 第六步: 提取图像 patch tokens
# ========== 第六步:提取图像 patch tokens ==========
# 序列是:[cond_token, z, x_latent]
# 我们要取 z 的部分:从索引 1 到 1+grid_size**2
x = x[:, 1:1+self.grid_size**2] # [B, grid_size**2, width]
为什么只提取 z:
- 我们只需要预测清晰的图像,不需要预测条件或 latent tokens
- 条件 token 和 latent tokens 只是辅助信息
前向传播 - 第七步: 重建图像
# ========== 第七步:最终处理(参考 TiTokDecoder)==========
# Layer Normalization (后处理)
x = self.ln_post(x)
# 重塑为空间格式: N L D -> N D H W
# 例如: [2, 1024, 768] -> [2, 768, 16, 16]
x = x.permute(0, 2, 1).reshape(batch_size, self.width, self.grid_size, self.grid_size)
# 通过 FFN 进行上采样
# FFN 包含: Conv2d(768, 768) + Rearrange (unpatchify)
# 将 16x16 grid 重建为 256x256 图像
x = self.ffn(x.contiguous())
# 最后的卷积层: 生成 RGB 图像
x = self.conv_out(x) # [B, 3, 256, 256]
return x
FFN (Feed-Forward Network):
- Conv2d(768, 768): 增加非线性
- Rearrange: unpatchify 操作,将 patch tokens 重排为像素
- 输出: [B, 3, 256, 256] 的 RGB 图像
DiTok - 完整模型
DiTok 是完整的端到端模型,包含 Encoder、Quantizer 和 Decoder。
两阶段训练策略:
- Stage 1: 联合训练 Encoder + Quantizer + Decoder (不需要,直接用预训练的 TiTok)
- Stage 2: 冻结 Encoder 和 Quantizer,只训练 Decoder
为什么冻结 Encoder 和 Quantizer:
- TiTok 已经在大量数据上训练好了,能很好地提取图像特征
- 我们只需要学习如何从这些特征生成清晰的图像
- 减少训练参数,加快训练速度
初始化
class DiTok(nn.Module):
def __init__(self, config):
super().__init__()
# ========== Encoder: TiTok 编码器 ==========
self.encoder = TiTokEncoder(config)
# ========== Learnable Latent Tokens ==========
# 可学习的 latent tokens,用于 cross-attention
# 与 TiTok 保持一致
self.num_latent_tokens = config.model.vq_model.num_latent_tokens
scale = self.encoder.width ** -0.5 # Xavier 初始化
self.latent_tokens = nn.Parameter(
scale * torch.randn(self.num_latent_tokens, self.encoder.width)
)
# ========== Quantizer: 向量量化器 ==========
vq_config = config.model.vq_model
self.quantizer = VectorQuantizer(
codebook_size=vq_config.codebook_size, # 例如: 1024
token_size=vq_config.token_size, # 例如: 8
commitment_cost=vq_config.commitment_cost, # 例如: 0.25
use_l2_norm=vq_config.use_l2_norm, # False
clustering_vq=vq_config.get("clustering_vq", False)
)
# ========== Decoder: DiTok 扩散解码器 ==========
self.diffusion_decoder = DiTokDecoder(config)
Stage 2 训练配置
# ========== Stage 2 训练配置 ==========
self.finetune_decoder = config.get("finetune_decoder", False)
self.freeze_encoder = config.get("freeze_encoder", False)
self.freeze_quantizer = config.get("freeze_quantizer", False)
# 如果启用 Stage 2,冻结 encoder 和 quantizer
if self.freeze_encoder:
self.encoder.eval() # 设置为 eval 模式
for param in self.encoder.parameters():
param.requires_grad = False # 冻结参数
print("Encoder frozen for Stage 2 training")
if self.freeze_quantizer:
self.quantizer.eval()
for param in self.quantizer.parameters():
param.requires_grad = False
print("Quantizer frozen for Stage 2 training")
冻结参数的意义:
- requires_grad = False: 不计算梯度,节省内存
- eval(): 关闭 Dropout 和 BatchNorm 的训练行为
扩散训练参数
# ========== Diffusion 训练参数 (from JiT) ==========
# 时间步采样参数 (logit-normal 分布)
self.P_std = config.get("P_std", 0.8) # 标准差
self.P_mean = config.get("P_mean", -0.8) # 均值
self.t_eps = config.get("t_eps", 5e-2) # 最小时间步 (避免除零)
# 噪声缩放
self.noise_scale = config.get("noise_scale", 1.0)
# 标签丢弃概率 (用于 Classifier-Free Guidance)
self.label_drop_prob = config.get("label_drop_prob", 0.1)
# ========== 损失权重 ==========
self.v_weight = config.get("v_weight", 1.0) # v-prediction loss 权重
self.gamma_weight = config.get("gamma_weight", 0.0) # 负对数似然权重
self.reconstruction_weight = config.get("reconstruction_weight", 1.0)
时间步采样:
- 使用 logit-normal 分布: \(t = \sigma(\mathcal{N}(\mu, \sigma^2))\)
- \(\mu = -0.8, \sigma = 0.8\): 更偏向小时间步 (晚期采样)
- t_eps: 防止 \(t=1\) 时除以零
Logvar 参数 (用于 gamma_loss)
# ========== Logvar 参数 (用于 gamma_loss) ==========
# 为了在连续时间扩散中使用 gamma_loss,需要:
# 1. 将连续时间 t ∈ [0, 1] 离散化为 num_timesteps 个 bin
# 2. 每个 bin 学习一个独立的 logvar 值
self.num_timesteps = config.get("num_timesteps", 1000)
logvar_init = config.get("logvar_init", 0.0)
self.learn_logvar = config.get("learn_logvar", False)
# 创建 logvar 张量
logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
if self.learn_logvar:
self.logvar = nn.Parameter(logvar, requires_grad=True)
else:
self.register_buffer('logvar', logvar)
Gamma Loss 的作用:
- 约束 \(v\) 的分布为正态分布 \(\mathcal{N}(v_{pred}, \exp(\text{logvar}_t))\)
- 提供额外的训练信号,改进生成质量
感知损失配置 (可选)
# ========== 感知损失配置 (Stage 2 可选) ==========
# 支持 perceptual_loss_type: 'vgg', 'lpips', 'none'
# - 'vgg': 使用 VGG-16BN 感知损失 (计算快,效果一般)
# - 'lpips': 使用 LPIPS 感知损失 (需要 pip install lpips,效果最好)
# - 'none': 不使用感知损失
perceptual_loss_type = config.get("perceptual_loss_type", "none").lower()
if perceptual_loss_type == "lpips":
from modules.ditok_blocks import LPIPSLoss
self.perceptual_loss = LPIPSLoss(config)
self.perceptual_weight = config.get("perceptual_weight", 0.5)
self.use_perceptual_loss = True
elif perceptual_loss_type == "vgg":
from modules.ditok_blocks import VGGPerceptualLoss
self.perceptual_loss = VGGPerceptualLoss(config)
self.perceptual_weight = config.get("perceptual_weight", 0.5)
self.use_perceptual_loss = True
else:
self.perceptual_loss = None
self.perceptual_weight = 0.0
self.use_perceptual_loss = False
感知损失的作用:
- 像素级 MSE 不一定能反映感知质量
- 特征空间的距离更符合人眼感知
- LPIPS 与人类感知高度一致 (推荐使用)
EMA 参数
# ========== EMA (Exponential Moving Average) 参数 ==========
self.ema_decay1 = config.get("ema_decay1", 0.9999) # 主 EMA
self.ema_decay2 = config.get("ema_decay2", 0.9996) # 辅助 EMA (衰减更快)
self.ema_params1 = None # EMA1 参数
self.ema_params2 = None # EMA2 参数
# ========== 生成参数 ==========
self.sampling_method = config.get("sampling_method", "heun") # 采样方法
self.num_sampling_steps = config.get("num_sampling_steps", 50)
self.cfg_scale = config.get("cfg", 1.0) # CFG 强度
self.cfg_interval = (
config.get("interval_min", 0.0),
config.get("interval_max", 1.0)
)
EMA 的作用:
- 参数平滑,减少推理时的随机波动
- 提高生成质量和稳定性
- 双 EMA: 提供不同程度的平滑
前向传播 - 第一步: 编码和量化
def forward(self, x, labels):
"""
DiTok 前向传播和损失计算。
支持 Stage 1 和 Stage 2 训练:
- Stage 1: 训练整个模型(encoder + quantizer + decoder)
- Stage 2: 冻结 encoder 和 quantizer,只训练 decoder
损失组成:
- v_loss: 扩散模型的 v-prediction loss(MSE)
- gamma_loss: v 的负对数似然(可选)
- perceptual_loss: 感知损失(Stage 2,可选)
- bpp_loss: 在 Stage 2 中设为 0
"""
loss_dict = {}
# ========== 第一步:通过 Encoder 生成编码值 ==========
if self.freeze_encoder:
# Stage 2: 使用 torch.no_grad() 节省内存
with torch.no_grad():
z_encoded = self.encoder(x, self.latent_tokens)
else:
# Stage 1: 正常前向传播
z_encoded = self.encoder(x, self.latent_tokens)
# 输出形状: z_encoded [B, token_size, 1, num_latent_tokens]
# 例如: [2, 8, 1, 32]
Encoder 的工作原理:
- 图像 \(x\) 经过 patch embedding
- 与 learnable latent tokens 进行 cross-attention
- 输出压缩的 latent 表示
量化
# ========== 量化 (VQ 模式) ==========
if self.freeze_quantizer:
# Stage 2: 冻结量化器
with torch.no_grad():
z_quantized, quant_result_dict = self.quantizer(z_encoded)
# 将量化损失设为 0 (因为不更新量化器)
bpp_loss = torch.tensor(0.0, device=x.device)
quant_result_dict['quantizer_loss'] = torch.tensor(0.0, device=x.device)
quant_result_dict['codebook_loss'] = torch.tensor(0.0, device=x.device)
quant_result_dict['commitment_loss'] = torch.tensor(0.0, device=x.device)
else:
# Stage 1: 正常训练量化器
z_quantized, quant_result_dict = self.quantizer(z_encoded)
bpp_loss = quant_result_dict.get('quantizer_loss', torch.tensor(0.0, device=x.device))
Vector Quantization:
- 将每个 latent token 映射到最近的 codebook 向量
- Commitment loss: 鼓励 encoder 输出接近 codebook
- Codebook loss: 更新 codebook 向量
前向传播 - 第二步: 加噪
# ========== 第二步:加噪得到中间态 ==========
# 在图像空间加噪,然后转换为 patch tokens
t = self.sample_t(x.size(0), device=x.device) # [B]
# 采样标准高斯噪声
e = torch.randn_like(x) * self.noise_scale # [B, C, H, W]
# 加噪公式: z_noisy = t * x + (1 - t) * e
# t=0: 纯噪声; t=1: 原始图像
z_noisy = t.view(-1, *([1] * (x.ndim - 1))) * x + (1 - t.view(-1, *([1] * (x.ndim - 1)))) * e
# 将加噪图像转换为 patch tokens (复用 encoder 的 patch_embed)
z_patches = self.encoder.patch_embed(z_noisy) # [B, width, H', W']
z_patches = z_patches.reshape(z_patches.shape[0], z_patches.shape[1], -1) # [B, width, H'*W']
z_patches = z_patches.permute(0, 2, 1) # [B, H'*W', width]
加噪公式说明:
- \(z_{noisy} = t \cdot x + (1-t) \cdot \epsilon\)
- 这是连续时间的扩散过程
- \(t\) 越大,保留的原始信息越多
前向传播 - 第三步: 标签丢弃 (CFG)
# ========== Classifier-Free Guidance 训练 ==========
labels_dropped = self.drop_labels(labels) if self.training else labels
CFG 机制:
- 训练时: 随机将部分样本的标签替换为"无条件"标签
- 推理时: \(\hat{x} = x_{uncond} + \text{CFG} \cdot (x_{cond} - x_{uncond})\)
- 效果: 增强条件一致性
前向传播 - 第四步: Decoder 预测
# ========== 第三步:通过 Diffusion Decoder 生成预测值 ==========
x_pred = self.diffusion_decoder(z_patches, z_quantized, t, labels_dropped)
# 输出: [B, 3, 256, 256]
Decoder 的输入:
- z_patches: 加噪的图像 tokens (需要预测清晰的)
- z_quantized: 原始图像的 latent 表示 (作为条件)
- t: 当前时间步
- labels: 类别标签
前向传播 - 第五步: 计算 v-prediction loss
# ========== 第四步:计算 v 的损失 ==========
# v 是从加噪图像到原始图像的"速度"
# 真实速度: v = (x - z_noisy) / (1 - t)
# 预测速度: v_pred = (x_pred - z_noisy) / (1 - t)
# 计算真实速度
v = (x - z_noisy) / (1 - t.view(-1, *([1] * (x.ndim - 1)))).clamp_min(self.t_eps)
# 计算预测速度
v_pred = (x_pred - z_noisy) / (1 - t.view(-1, *([1] * (x.ndim - 1)))).clamp_min(self.t_eps)
# MSE loss
v_loss = (v - v_pred) ** 2
v_loss = v_loss.mean(dim=(1, 2, 3)) # [B]
loss_dict.update({'v_loss': v_loss.mean()})
V-Prediction 的优势:
- 相比 $ε$-prediction, \(v\) 的分布更均匀
- 更容易学习,训练更稳定
- 来自 Progressive Distillation 论文
前向传播 - 第六步: 计算 gamma_loss (可选)
# ========== 第五步:计算 gamma_loss(v 的负对数似然)==========
# 目的:让 v 的分布呈现正态分布 N(v_pred, exp(logvar_t))
# 公式:log N(v; v_pred, σ²) = (v - v_pred)² / (2σ²) + log(σ)
if self.gamma_weight > 0:
# 将连续时间 t ∈ [0, 1] 离散化为时间步索引
t_idx = (t * (self.num_timesteps - 1)).long()
t_idx = torch.clamp(t_idx, 0, self.num_timesteps - 1)
# 获取对应时间步的 logvar
logvar_t = self.logvar[t_idx] # [B]
# 计算 gamma_loss
gamma_loss = v_loss / (2 * torch.exp(logvar_t)) + logvar_t / 2
loss_dict.update({'gamma_loss': gamma_loss.mean()})
else:
gamma_loss = torch.tensor(0.0, device=x.device)
Gamma Loss 的作用:
- 提供似然目标,不仅仅是 MSE
- 自适应地调整不同时间步的权重
- 来自 Improved Denoising Diffusion Models
前向传播 - 第七步: 感知损失 (可选)
# ========== 第七步:计算 perceptual loss(Stage 2,可选)==========
if self.use_perceptual_loss and self.perceptual_loss is not None:
# 使用配置的感知损失 (VGG 或 LPIPS)
perceptual_loss = self.perceptual_loss(x, x_pred)
loss_dict.update({'perceptual_loss': perceptual_loss})
else:
perceptual_loss = torch.tensor(0.0, device=x.device)
前向传播 - 第八步: 总损失
# ========== 第七步:bpp loss(Stage 2 中为 0)==========
if self.freeze_quantizer:
bpp_loss = torch.tensor(0.0, device=x.device)
loss_dict.update({'bpp_loss': bpp_loss})
loss_dict.update({'codebook_loss': quant_result_dict['codebook_loss']})
loss_dict.update({'commitment_loss': quant_result_dict['commitment_loss']})
# ========== 第八步:计算总损失 ==========
loss = (
self.v_weight * v_loss.mean() +
self.gamma_weight * gamma_loss.mean() +
self.perceptual_weight * perceptual_loss +
bpp_loss # Stage 2 中为 0
)
loss_dict.update({'loss': loss})
return loss_dict
损失权重调节:
- v_weight: 主要损失,通常设为 1.0
- gamma_weight: 可选,通常设为 1.0
- perceptual_weight: 可选,通常设为 0.5
辅助函数
def drop_labels(self, labels):
"""Randomly drop labels for classifier-free guidance training."""
drop = torch.rand(labels.shape[0], device=labels.device) < self.label_drop_prob
num_classes = labels.max().item() + 1 # infer number of classes
out = torch.where(drop, torch.full_like(labels, num_classes), labels)
return out
def sample_t(self, n: int, device=None):
"""从 logit-normal 分布采样时间步"""
z = torch.randn(n, device=device) * self.P_std + self.P_mean
return torch.sigmoid(z)
EMA 更新
@torch.no_grad()
def update_ema(self):
"""更新 EMA 参数(参考 JiT 实现)
EMA (Exponential Moving Average) 用于平滑模型参数,
提高推理时的稳定性和生成质量。
更新公式:
θ_ema = decay * θ_ema + (1 - decay) * θ
使用两个 EMA 跟踪器:
- EMA1: decay=0.9999, 主 EMA,用于推理
- EMA2: decay=0.9996, 辅助 EMA,衰减更快
"""
# 第一次调用时,初始化 EMA 参数
if self.ema_params1 is None:
self.ema_params1 = [p.clone().detach() for p in self.parameters()]
if self.ema_params2 is None:
self.ema_params2 = [p.clone().detach() for p in self.parameters()]
# 获取当前模型参数
source_params = list(self.parameters())
# 更新 EMA1(主 EMA,用于推理)
for targ, src in zip(self.ema_params1, source_params):
# θ_ema = decay * θ_ema + (1 - decay) * θ
targ.detach().mul_(self.ema_decay1).add_(src.data, alpha=1 - self.ema_decay1)
# 更新 EMA2(辅助 EMA,衰减更快)
for targ, src in zip(self.ema_params2, source_params):
targ.detach().mul_(self.ema_decay2).add_(src.data, alpha=1 - self.ema_decay2)
EMA 的使用:
- 训练时: 每次更新后调用 update_ema()
- 推理时: 使用 EMA 参数而非原始参数
- 保存时: 同时保存原始参数和 EMA 参数
训练脚本调用示例
# 在训练循环中调用 EMA 更新
for batch in dataloader:
loss = model(x, labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()
# 更新 EMA 参数
if accelerator.is_main_process:
model.update_ema()
# 更新学习率
lr_scheduler.step()
完整代码汇总
DiTok 的完整实现包括以下核心模块:
- TimestepEmbedder: 时间步嵌入
- LabelEmbedder: 类别标签嵌入
- DiTokDecoder: 扩散解码器
- DiTok: 端到端模型
- VectorQuantizer: 向量量化器
- VGGPerceptualLoss: 基于 VGG 的感知损失
- LPIPSLoss: 学习得到的感知损失
- EMA: 指数移动平均
关键设计决策:
- 使用 v-prediction 而非 $ε$-prediction
- 使用条件 token 而非 AdaLN-Zero
- 两阶段训练策略
- EMA 用于稳定推理