ESC
输入关键词搜索文章
目录

DiTok v1.0 实验记录

概述

DiTok (Diffusion Tokenizer) 是基于 TiTok (1D Tokenizer) 的扩散模型改进版本。

核心思想: 结合 TiTok 的高效 tokenization 和扩散模型的生成能力,实现高质量的图像生成。

主要改进:

  1. 使用扩散解码器替代 TiTok 的确定性解码器
  2. 引入时间步条件 \(t\) 和类别标签条件 \(y\)
  3. 支持 EMA (Exponential Moving Average) 用于稳定推理
  4. 两阶段训练: Stage 1 联合训练, Stage 2 微调解码器

模型架构:

TimestepEmbedder - 时间步嵌入

时间步嵌入模块负责将连续时间步 \(t \in [0, 1]\) 映射到高维空间。

为什么需要: 在扩散模型中,不同的时间步需要不同的预测能力。早期时间步需要重构整体结构,晚期时间步只需要细化细节。

实现方式: 使用正弦位置编码 (sinusoidal position embedding) + MLP 投影。

class TimestepEmbedder(nn.Module):
    """时间步嵌入模块

    将连续时间步 t ∈ [0, 1] 映射到 hidden_size 维度的向量。

    使用正弦位置编码 + MLP 投影的方式,与 DiT 一致。
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """创建正弦时间步嵌入

        使用不同频率的正弦和余弦函数来编码时间信息。
        这是一种绝对位置编码,可以唯一确定每个时间步。

        Args:
            t: 时间步张量 [batch_size]
            dim: 嵌入维度
            max_period: 最大周期,控制频率范围

        Returns:
            embedding: [batch_size, dim]
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        """
        Args:
            t: 连续时间步 [batch_size], 范围 [0, 1]

        Returns:
            embedding: [batch_size, hidden_size]
        """
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb

设计要点:

  1. 正弦编码: 提供唯一的时间步表示
  2. MLP 投影: 将低维频率编码映射到高维空间
  3. SiLU 激活: 平滑的非线性变换

LabelEmbedder - 类别标签嵌入

类别标签嵌入模块负责将离散的类别标签映射到连续的嵌入空间。

为什么需要: 在条件生成任务中,模型需要根据类别信息生成相应的图像。直接使用 one-hot 编码维度太高且稀疏,嵌入映射更高效。

class LabelEmbedder(nn.Module):
    """类别标签嵌入模块

    将离散的类别标签映射到连续的嵌入空间。
    支持分类器自由引导 (Classifier-Free Guidance)。
    """
    def __init__(self, num_classes, hidden_size):
        super().__init__()
        # 创建一个额外的 token 用于"无条件"生成
        # 当 CFG 时,部分样本会将标签替换为这个特殊 token
        self.embedding_table = nn.Embedding(num_classes + 1, hidden_size)

    def forward(self, labels):
        """
        Args:
            labels: 类别标签 [batch_size], 范围 [0, num_classes-1]
                    对于 CFG,可能有 num_classes 作为特殊的"无条件"标签

        Returns:
            embedding: [batch_size, hidden_size]
        """
        return self.embedding_table(labels)

CFG (Classifier-Free Guidance):

DiTokDecoder - 扩散解码器

DiTokDecoder 是核心改进部分,负责从加噪的图像 patch tokens 重建清晰的图像。

关键设计决策:

  1. 接受两个输入: 加噪的图像 tokens (\(z_{patches}\)) + 量化的 latent tokens (\(z_{quantized}\))
  2. 使用条件 token 替代了 TiTok 中的 class_tokens
  3. 条件 token 由条件信息 \(t\)\(y\) 组成,通过独立 token 传入,而非 DiT 中常见的 AdaLN-Zero.
    • 这学习了 UViT 的设计。
    • \(t\)\(y\) 两个条件信息因为是相对独立(正交)的,因此可以直接相加。
    • 之所以不用 AdaLN-Zero, 是为了在结构上与 TiTok 保持一致。

输入输出:

初始化

class DiTokDecoder(TiTokDecoder):
    """扩散模型的解码器,继承自 TiTokDecoder。

    完全参考 TiTokDecoder 的 forward 结构,但接受两个输入:
    1. z: 加噪后的图像 patch tokens(1024个)
    2. z_quantized: 量化的 latent tokens(32个)

    条件注入方式:
    1. 时间步条件 t(通过 TimestepEmbedder)
    2. 类别标签 labels(通过 LabelEmbedder)
    3. 通过独立的条件 token 传入(参考 UViT,而非 DiT 的 AdaLN-Zero)
    """

    def __init__(self, config):
        # 继承 TiTokDecoder 的所有组件
        super().__init__(config)

        # 获取隐藏层维度
        hidden_size = self.width  # 例如: 768 for base model
        num_classes = config.get("num_classes", 1000)

        # ========== 时间步嵌入 ==========
        frequency_embedding_size = config.get("frequency_embedding_size", 256)
        self.t_embedder = TimestepEmbedder(hidden_size, frequency_embedding_size)

        # ========== 标签嵌入 ==========
        self.y_embedder = LabelEmbedder(num_classes, hidden_size)

        # ========== 条件投影 ==========
        # 将 t_emb + y_emb 投影到 hidden_size
        # 这样可以让条件信息更灵活地融合
        self.cond_proj = nn.Linear(hidden_size, hidden_size)

组件说明:

  1. TiTokDecoder 基础结构: Transformer 解码器,包含 self-attention 和 MLP
  2. TimestepEmbedder: 将时间步 \(t\) 映射到高维空间
  3. LabelEmbedder: 将类别标签映射到嵌入空间
  4. cond_proj: 条件投影层,融合 \(t\)\(y\) 信息

前向传播 - 第一步: 处理量化 latent tokens

def forward(self, z_patches, z_quantized, t, labels):
    """
    前向传播,带时间步条件和标签条件。

    Args:
        z_patches: 加噪后的图像 patch tokens(已 patchified)
                  shape [B, grid_size**2, width]
                  例如: [2, 1024, 768]
        z_quantized: 量化的潜在表示(32个tokens)
                  shape [B, token_size, 1, num_latent_tokens]
                  例如: [2, 8, 1, 32]
        t: 连续时间步 ∈ [0, 1], shape [B]
        labels: 类别标签, shape [B]

    Returns:
        decoded: 解码后的图像, shape [B, C, H, W]
                 例如: [2, 3, 256, 256]
    """
    batch_size = z_patches.shape[0]

    # ========== 第一步:处理 z_quantized(32个 latent tokens)==========
    Nq, Cq, Hq, Wq = z_quantized.shape
    assert Hq == 1 and Wq == self.num_latent_tokens, \
        f"Expected z_quantized shape [B, C, 1, {self.num_latent_tokens}], got {z_quantized.shape}"

    # Reshape: [B, C, 1, num_latent_tokens] → [B, num_latent_tokens, C]
    # 例如: [2, 8, 1, 32] → [2, 32, 8]
    x_latent = z_quantized.reshape(Nq, Cq*Hq, Wq).permute(0, 2, 1)

    # 通过 decoder_embed 投影到 width 维度
    # 例如: [2, 32, 8] → [2, 32, 768]
    x_latent = self.decoder_embed(x_latent)

    # 添加 latent_token_positional_embedding
    # 这是为了让模型知道每个 latent token 的位置信息
    x_latent = x_latent + self.latent_token_positional_embedding[:x_latent.shape[1]]

说明:

前向传播 - 第二步: 验证图像 patch tokens

# ========== 第二步:验证 z_patches(1024个图像 patch tokens)==========
# z_patches 应该已经是 [B, grid_size**2, width] 格式(patchified)
assert z_patches.ndim == 3, f"Expected z_patches to be 3D tensor, got shape {z_patches.shape}"
assert z_patches.shape[1] == self.grid_size**2, \
    f"Expected z_patches shape [B, {self.grid_size**2}, width], got {z_patches.shape}"

# 如果 z_patches 的最后一维不等于 width,需要投影
z = z_patches
if z.shape[-1] != self.width:
    if not hasattr(self, 'z_proj'):
        # 动态创建投影层 (lazy initialization)
        self.z_proj = nn.Linear(z.shape[-1], self.width, device=z.device)
    z = self.z_proj(z)  # [B, grid_size**2, width]

说明:

前向传播 - 第三步: 创建条件 token

# ========== 第三步:创建条件 token ==========
# 计算条件嵌入
t_emb = self.t_embedder(t)  # [B, hidden_size]
y_emb = self.y_embedder(labels)  # [B, hidden_size]

# 将 t_emb 和 y_emb 相加 (假设它们正交)
# 这是一种简单但有效的条件融合方式
cond = t_emb + y_emb  # [B, hidden_size]

# 通过 cond_proj 进行投影,增加灵活性
cond = self.cond_proj(cond)  # [B, hidden_size]

# 将 cond 作为条件 token (扩展一个序列维度)
cond_token = cond.unsqueeze(1)  # [B, 1, width]

# 拼接条件 token 和 z (图像 patch tokens)
# 序列结构: [cond_token, z_0, z_1, ..., z_1023]
partial = torch.cat([cond_token, z], dim=1)  # [B, 1 + grid_size**2, width]

# 加上完整的 positional_embedding(第一个位置给条件 token)
# 位置编码让模型知道每个 token 在序列中的位置
partial = partial + self.positional_embedding  # [B, 1 + grid_size**2, width]

设计思路:

  1. 条件融合: \(t\)\(y\) 相加是因为它们来自不同的嵌入空间,近似正交
  2. 独立 token: 将条件作为独立 token,参考 UViT 而非 DiT 的 AdaLN-Zero
  3. 位置编码: 条件 token 也有位置编码 (位置 0)

前向传播 - 第四步: 拼接所有 tokens

# ========== 第四步:拼接 x_latent(它已经有 latent_token_positional_embedding 了)==========
# 完整序列结构: [cond_token, z_patches, x_latent]
# 例如: [1, 1024, 32] = 1057 个 tokens
x = torch.cat([partial, x_latent], dim=1)  # [B, 1 + grid_size**2 + num_latent_tokens, width]

序列结构:

[cond_token, z_0, z_1, ..., z_1023, latent_0, latent_1, ..., latent_31]
     0      1    2  ...  1024   1025       1026  ...           1056

前向传播 - 第五步: Transformer 处理

# ========== 第五步:通过 Transformer(参考 TiTokDecoder)==========
# Layer Normalization (预处理)
x = self.ln_pre(x)

# 转置为 LND 格式 (Transformer 的标准输入格式)
x = x.permute(1, 0, 2)  # NLD -> LND
                     # 例如: [2, 1057, 768] -> [1057, 2, 768]

# 通过多层 Transformer
for i in range(self.num_layers):  # 例如: 12 层
    x = self.transformer[i](x)

# 转置回 NLD 格式
x = x.permute(1, 0, 2)  # LND -> NLD

Transformer 结构:

前向传播 - 第六步: 提取图像 patch tokens

# ========== 第六步:提取图像 patch tokens ==========
# 序列是:[cond_token, z, x_latent]
# 我们要取 z 的部分:从索引 1 到 1+grid_size**2
x = x[:, 1:1+self.grid_size**2]  # [B, grid_size**2, width]

为什么只提取 z:

前向传播 - 第七步: 重建图像

# ========== 第七步:最终处理(参考 TiTokDecoder)==========
# Layer Normalization (后处理)
x = self.ln_post(x)

# 重塑为空间格式: N L D -> N D H W
# 例如: [2, 1024, 768] -> [2, 768, 16, 16]
x = x.permute(0, 2, 1).reshape(batch_size, self.width, self.grid_size, self.grid_size)

# 通过 FFN 进行上采样
# FFN 包含: Conv2d(768, 768) + Rearrange (unpatchify)
# 将 16x16 grid 重建为 256x256 图像
x = self.ffn(x.contiguous())

# 最后的卷积层: 生成 RGB 图像
x = self.conv_out(x)  # [B, 3, 256, 256]

return x

FFN (Feed-Forward Network):

  1. Conv2d(768, 768): 增加非线性
  2. Rearrange: unpatchify 操作,将 patch tokens 重排为像素
  3. 输出: [B, 3, 256, 256] 的 RGB 图像

DiTok - 完整模型

DiTok 是完整的端到端模型,包含 Encoder、Quantizer 和 Decoder。

两阶段训练策略:

为什么冻结 Encoder 和 Quantizer:

  1. TiTok 已经在大量数据上训练好了,能很好地提取图像特征
  2. 我们只需要学习如何从这些特征生成清晰的图像
  3. 减少训练参数,加快训练速度

初始化

class DiTok(nn.Module):
    def __init__(self, config):
        super().__init__()

        # ========== Encoder: TiTok 编码器 ==========
        self.encoder = TiTokEncoder(config)

        # ========== Learnable Latent Tokens ==========
        # 可学习的 latent tokens,用于 cross-attention
        # 与 TiTok 保持一致
        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
        scale = self.encoder.width ** -0.5  # Xavier 初始化
        self.latent_tokens = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.encoder.width)
        )

        # ========== Quantizer: 向量量化器 ==========
        vq_config = config.model.vq_model
        self.quantizer = VectorQuantizer(
            codebook_size=vq_config.codebook_size,    # 例如: 1024
            token_size=vq_config.token_size,          # 例如: 8
            commitment_cost=vq_config.commitment_cost, # 例如: 0.25
            use_l2_norm=vq_config.use_l2_norm,        # False
            clustering_vq=vq_config.get("clustering_vq", False)
        )

        # ========== Decoder: DiTok 扩散解码器 ==========
        self.diffusion_decoder = DiTokDecoder(config)

Stage 2 训练配置

# ========== Stage 2 训练配置 ==========
self.finetune_decoder = config.get("finetune_decoder", False)
self.freeze_encoder = config.get("freeze_encoder", False)
self.freeze_quantizer = config.get("freeze_quantizer", False)

# 如果启用 Stage 2,冻结 encoder 和 quantizer
if self.freeze_encoder:
    self.encoder.eval()  # 设置为 eval 模式
    for param in self.encoder.parameters():
        param.requires_grad = False  # 冻结参数
    print("Encoder frozen for Stage 2 training")

if self.freeze_quantizer:
    self.quantizer.eval()
    for param in self.quantizer.parameters():
        param.requires_grad = False
    print("Quantizer frozen for Stage 2 training")

冻结参数的意义:

  1. requires_grad = False: 不计算梯度,节省内存
  2. eval(): 关闭 Dropout 和 BatchNorm 的训练行为

扩散训练参数

# ========== Diffusion 训练参数 (from JiT) ==========
# 时间步采样参数 (logit-normal 分布)
self.P_std = config.get("P_std", 0.8)    # 标准差
self.P_mean = config.get("P_mean", -0.8) # 均值
self.t_eps = config.get("t_eps", 5e-2)   # 最小时间步 (避免除零)

# 噪声缩放
self.noise_scale = config.get("noise_scale", 1.0)

# 标签丢弃概率 (用于 Classifier-Free Guidance)
self.label_drop_prob = config.get("label_drop_prob", 0.1)

# ========== 损失权重 ==========
self.v_weight = config.get("v_weight", 1.0)  # v-prediction loss 权重
self.gamma_weight = config.get("gamma_weight", 0.0)  # 负对数似然权重
self.reconstruction_weight = config.get("reconstruction_weight", 1.0)

时间步采样:

Logvar 参数 (用于 gamma_loss)

# ========== Logvar 参数 (用于 gamma_loss) ==========
# 为了在连续时间扩散中使用 gamma_loss,需要:
# 1. 将连续时间 t ∈ [0, 1] 离散化为 num_timesteps 个 bin
# 2. 每个 bin 学习一个独立的 logvar 值
self.num_timesteps = config.get("num_timesteps", 1000)
logvar_init = config.get("logvar_init", 0.0)
self.learn_logvar = config.get("learn_logvar", False)

# 创建 logvar 张量
logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
if self.learn_logvar:
    self.logvar = nn.Parameter(logvar, requires_grad=True)
else:
    self.register_buffer('logvar', logvar)

Gamma Loss 的作用:

感知损失配置 (可选)

# ========== 感知损失配置 (Stage 2 可选) ==========
# 支持 perceptual_loss_type: 'vgg', 'lpips', 'none'
# - 'vgg': 使用 VGG-16BN 感知损失 (计算快,效果一般)
# - 'lpips': 使用 LPIPS 感知损失 (需要 pip install lpips,效果最好)
# - 'none': 不使用感知损失
perceptual_loss_type = config.get("perceptual_loss_type", "none").lower()

if perceptual_loss_type == "lpips":
    from modules.ditok_blocks import LPIPSLoss
    self.perceptual_loss = LPIPSLoss(config)
    self.perceptual_weight = config.get("perceptual_weight", 0.5)
    self.use_perceptual_loss = True
elif perceptual_loss_type == "vgg":
    from modules.ditok_blocks import VGGPerceptualLoss
    self.perceptual_loss = VGGPerceptualLoss(config)
    self.perceptual_weight = config.get("perceptual_weight", 0.5)
    self.use_perceptual_loss = True
else:
    self.perceptual_loss = None
    self.perceptual_weight = 0.0
    self.use_perceptual_loss = False

感知损失的作用:

EMA 参数

# ========== EMA (Exponential Moving Average) 参数 ==========
self.ema_decay1 = config.get("ema_decay1", 0.9999)  # 主 EMA
self.ema_decay2 = config.get("ema_decay2", 0.9996)  # 辅助 EMA (衰减更快)
self.ema_params1 = None  # EMA1 参数
self.ema_params2 = None  # EMA2 参数

# ========== 生成参数 ==========
self.sampling_method = config.get("sampling_method", "heun")  # 采样方法
self.num_sampling_steps = config.get("num_sampling_steps", 50)
self.cfg_scale = config.get("cfg", 1.0)  # CFG 强度
self.cfg_interval = (
    config.get("interval_min", 0.0),
    config.get("interval_max", 1.0)
)

EMA 的作用:

前向传播 - 第一步: 编码和量化

def forward(self, x, labels):
    """
    DiTok 前向传播和损失计算。

    支持 Stage 1 和 Stage 2 训练:
    - Stage 1: 训练整个模型(encoder + quantizer + decoder)
    - Stage 2: 冻结 encoder 和 quantizer,只训练 decoder

    损失组成:
    - v_loss: 扩散模型的 v-prediction loss(MSE)
    - gamma_loss: v 的负对数似然(可选)
    - perceptual_loss: 感知损失(Stage 2,可选)
    - bpp_loss: 在 Stage 2 中设为 0
    """
    loss_dict = {}

    # ========== 第一步:通过 Encoder 生成编码值 ==========
    if self.freeze_encoder:
        # Stage 2: 使用 torch.no_grad() 节省内存
        with torch.no_grad():
            z_encoded = self.encoder(x, self.latent_tokens)
    else:
        # Stage 1: 正常前向传播
        z_encoded = self.encoder(x, self.latent_tokens)

    # 输出形状: z_encoded [B, token_size, 1, num_latent_tokens]
    # 例如: [2, 8, 1, 32]

Encoder 的工作原理:

  1. 图像 \(x\) 经过 patch embedding
  2. 与 learnable latent tokens 进行 cross-attention
  3. 输出压缩的 latent 表示

量化

# ========== 量化 (VQ 模式) ==========
if self.freeze_quantizer:
    # Stage 2: 冻结量化器
    with torch.no_grad():
        z_quantized, quant_result_dict = self.quantizer(z_encoded)

    # 将量化损失设为 0 (因为不更新量化器)
    bpp_loss = torch.tensor(0.0, device=x.device)
    quant_result_dict['quantizer_loss'] = torch.tensor(0.0, device=x.device)
    quant_result_dict['codebook_loss'] = torch.tensor(0.0, device=x.device)
    quant_result_dict['commitment_loss'] = torch.tensor(0.0, device=x.device)
else:
    # Stage 1: 正常训练量化器
    z_quantized, quant_result_dict = self.quantizer(z_encoded)
    bpp_loss = quant_result_dict.get('quantizer_loss', torch.tensor(0.0, device=x.device))

Vector Quantization:

  1. 将每个 latent token 映射到最近的 codebook 向量
  2. Commitment loss: 鼓励 encoder 输出接近 codebook
  3. Codebook loss: 更新 codebook 向量

前向传播 - 第二步: 加噪

# ========== 第二步:加噪得到中间态 ==========
# 在图像空间加噪,然后转换为 patch tokens
t = self.sample_t(x.size(0), device=x.device)  # [B]

# 采样标准高斯噪声
e = torch.randn_like(x) * self.noise_scale  # [B, C, H, W]

# 加噪公式: z_noisy = t * x + (1 - t) * e
# t=0: 纯噪声; t=1: 原始图像
z_noisy = t.view(-1, *([1] * (x.ndim - 1))) * x + (1 - t.view(-1, *([1] * (x.ndim - 1)))) * e

# 将加噪图像转换为 patch tokens (复用 encoder 的 patch_embed)
z_patches = self.encoder.patch_embed(z_noisy)  # [B, width, H', W']
z_patches = z_patches.reshape(z_patches.shape[0], z_patches.shape[1], -1)  # [B, width, H'*W']
z_patches = z_patches.permute(0, 2, 1)  # [B, H'*W', width]

加噪公式说明:

前向传播 - 第三步: 标签丢弃 (CFG)

# ========== Classifier-Free Guidance 训练 ==========
labels_dropped = self.drop_labels(labels) if self.training else labels

CFG 机制:

前向传播 - 第四步: Decoder 预测

# ========== 第三步:通过 Diffusion Decoder 生成预测值 ==========
x_pred = self.diffusion_decoder(z_patches, z_quantized, t, labels_dropped)
# 输出: [B, 3, 256, 256]

Decoder 的输入:

  1. z_patches: 加噪的图像 tokens (需要预测清晰的)
  2. z_quantized: 原始图像的 latent 表示 (作为条件)
  3. t: 当前时间步
  4. labels: 类别标签

前向传播 - 第五步: 计算 v-prediction loss

# ========== 第四步:计算 v 的损失 ==========
# v 是从加噪图像到原始图像的"速度"
# 真实速度: v = (x - z_noisy) / (1 - t)
# 预测速度: v_pred = (x_pred - z_noisy) / (1 - t)

# 计算真实速度
v = (x - z_noisy) / (1 - t.view(-1, *([1] * (x.ndim - 1)))).clamp_min(self.t_eps)

# 计算预测速度
v_pred = (x_pred - z_noisy) / (1 - t.view(-1, *([1] * (x.ndim - 1)))).clamp_min(self.t_eps)

# MSE loss
v_loss = (v - v_pred) ** 2
v_loss = v_loss.mean(dim=(1, 2, 3))  # [B]

loss_dict.update({'v_loss': v_loss.mean()})

V-Prediction 的优势:

前向传播 - 第六步: 计算 gamma_loss (可选)

# ========== 第五步:计算 gamma_loss(v 的负对数似然)==========
# 目的:让 v 的分布呈现正态分布 N(v_pred, exp(logvar_t))
# 公式:log N(v; v_pred, σ²) = (v - v_pred)² / (2σ²) + log(σ)

if self.gamma_weight > 0:
    # 将连续时间 t ∈ [0, 1] 离散化为时间步索引
    t_idx = (t * (self.num_timesteps - 1)).long()
    t_idx = torch.clamp(t_idx, 0, self.num_timesteps - 1)

    # 获取对应时间步的 logvar
    logvar_t = self.logvar[t_idx]  # [B]

    # 计算 gamma_loss
    gamma_loss = v_loss / (2 * torch.exp(logvar_t)) + logvar_t / 2
    loss_dict.update({'gamma_loss': gamma_loss.mean()})
else:
    gamma_loss = torch.tensor(0.0, device=x.device)

Gamma Loss 的作用:

前向传播 - 第七步: 感知损失 (可选)

# ========== 第七步:计算 perceptual loss(Stage 2,可选)==========
if self.use_perceptual_loss and self.perceptual_loss is not None:
    # 使用配置的感知损失 (VGG 或 LPIPS)
    perceptual_loss = self.perceptual_loss(x, x_pred)
    loss_dict.update({'perceptual_loss': perceptual_loss})
else:
    perceptual_loss = torch.tensor(0.0, device=x.device)

前向传播 - 第八步: 总损失

# ========== 第七步:bpp loss(Stage 2 中为 0)==========
if self.freeze_quantizer:
    bpp_loss = torch.tensor(0.0, device=x.device)

loss_dict.update({'bpp_loss': bpp_loss})
loss_dict.update({'codebook_loss': quant_result_dict['codebook_loss']})
loss_dict.update({'commitment_loss': quant_result_dict['commitment_loss']})

# ========== 第八步:计算总损失 ==========
loss = (
    self.v_weight * v_loss.mean() +
    self.gamma_weight * gamma_loss.mean() +
    self.perceptual_weight * perceptual_loss +
    bpp_loss  # Stage 2 中为 0
)

loss_dict.update({'loss': loss})
return loss_dict

损失权重调节:

辅助函数

def drop_labels(self, labels):
    """Randomly drop labels for classifier-free guidance training."""
    drop = torch.rand(labels.shape[0], device=labels.device) < self.label_drop_prob
    num_classes = labels.max().item() + 1  # infer number of classes
    out = torch.where(drop, torch.full_like(labels, num_classes), labels)
    return out

def sample_t(self, n: int, device=None):
    """从 logit-normal 分布采样时间步"""
    z = torch.randn(n, device=device) * self.P_std + self.P_mean
    return torch.sigmoid(z)

EMA 更新

@torch.no_grad()
def update_ema(self):
    """更新 EMA 参数(参考 JiT 实现)

    EMA (Exponential Moving Average) 用于平滑模型参数,
    提高推理时的稳定性和生成质量。

    更新公式:
        θ_ema = decay * θ_ema + (1 - decay) * θ

    使用两个 EMA 跟踪器:
    - EMA1: decay=0.9999, 主 EMA,用于推理
    - EMA2: decay=0.9996, 辅助 EMA,衰减更快
    """
    # 第一次调用时,初始化 EMA 参数
    if self.ema_params1 is None:
        self.ema_params1 = [p.clone().detach() for p in self.parameters()]
    if self.ema_params2 is None:
        self.ema_params2 = [p.clone().detach() for p in self.parameters()]

    # 获取当前模型参数
    source_params = list(self.parameters())

    # 更新 EMA1(主 EMA,用于推理)
    for targ, src in zip(self.ema_params1, source_params):
        # θ_ema = decay * θ_ema + (1 - decay) * θ
        targ.detach().mul_(self.ema_decay1).add_(src.data, alpha=1 - self.ema_decay1)

    # 更新 EMA2(辅助 EMA,衰减更快)
    for targ, src in zip(self.ema_params2, source_params):
        targ.detach().mul_(self.ema_decay2).add_(src.data, alpha=1 - self.ema_decay2)

EMA 的使用:

  1. 训练时: 每次更新后调用 update_ema()
  2. 推理时: 使用 EMA 参数而非原始参数
  3. 保存时: 同时保存原始参数和 EMA 参数

训练脚本调用示例

# 在训练循环中调用 EMA 更新
for batch in dataloader:
    loss = model(x, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # 更新 EMA 参数
    if accelerator.is_main_process:
        model.update_ema()

    # 更新学习率
    lr_scheduler.step()

完整代码汇总

DiTok 的完整实现包括以下核心模块:

  1. TimestepEmbedder: 时间步嵌入
  2. LabelEmbedder: 类别标签嵌入
  3. DiTokDecoder: 扩散解码器
  4. DiTok: 端到端模型
  5. VectorQuantizer: 向量量化器
  6. VGGPerceptualLoss: 基于 VGG 的感知损失
  7. LPIPSLoss: 学习得到的感知损失
  8. EMA: 指数移动平均

关键设计决策: