DiTok Stage 2 训练指南
概述
本文档介绍 DiTok Stage 2 训练流程,用于微调扩散解码器。
训练策略:
- Stage 1: 端到端训练 Encoder + Quantizer + Decoder (使用预训练的 TiTok) 这个过程我们并不需要真的进行,可以使用 TiTok 已经训练好的结果。
- Stage 2 (本文档): 冻结 Encoder 和 Quantizer,只训练 Decoder
预训练 Checkpoint:
本教程使用预训练的 TiTok L32 checkpoint:
- 文件: tokenizer_titok_l32_imagenet.safetensors
- 大小: ~2.4 GB
- 来源: TiTok 官方预训练模型
- 特点: 32 个 latent tokens,在 ImageNet 上训练
配置文件中已默认设置好预训练 checkpoint 路径,无需手动下载。
为什么需要 Stage 2:
- TiTok 已经在大规模数据上训练好了,能很好地提取图像特征
- 我们只需要学习如何从这些特征生成清晰的图像
- 冻结 Encoder 和 Quantizer 可以减少训练参数,加快训练速度
关键特性:
- 冻结 Encoder 和 Quantizer
- 使用 v-prediction loss 训练扩散解码器
- 支持可选的感知损失 (VGG 或 LPIPS)
- EMA (指数移动平均) 用于稳定推理
- 使用 Accelerate 支持分布式训练
- 余弦退火学习率调度
环境设置
本项目使用 uv 作为包管理器。
# 1. 安装 uv (如果还没有安装)
curl -LsSf https://astral.sh/uv/install.sh | sh
# 2. 安装项目依赖
uv sync
# 3. (可选) 添加 LPIPS 感知损失支持
uv add lpips
项目结构:
DiTok/
├── ditok.py # DiTok 模型定义
├── modules/ # 核心模块
│ ├── titok_blocks.py # TiTok Transformer blocks
│ ├── ditok_blocks.py # DiTok Transformer blocks
│ ├── quantizer.py # Vector Quantization
│ └── losses.py # 损失函数 (LPIPS, VGG等)
├── utils/ # 工具脚本
│ ├── train_ditok.py # Stage 2 训练脚本
│ └── logger.py # 日志工具
├── configs/ # 配置文件
│ └── training/DiTok/stage2/
│ └── ditok_b64.yaml # Stage 2 配置
├── pyproject.toml # 项目配置
├── uv.lock # 依赖锁定
└── .python-version # Python 3.12
快速开始
单 GPU 训练:
# 使用 uv run 运行训练脚本
uv run python utils/train_ditok.py
多 GPU 训练 (推荐):
# 使用 uv run 配合 Accelerate 启动 8 个 GPU
uv run accelerate launch --num_processes=8 --main_process_port=9999 utils/train_ditok.py
自定义配置:
# 注意: 当前版本暂不支持命令行参数
# 请直接修改 configs/training/DiTok/stage2/ditok_l32.yaml
uv run python utils/train_ditok.py
配置文件详解
配置文件位于: configs/training/DiTok/stage2/ditok_l32.yaml
核心配置项:
# ========== 模型配置 ==========
model:
vq_model:
vit_enc_patch_size: 16 # Encoder patch size
vit_enc_model_size: "base" # Encoder model size
vit_dec_patch_size: 16 # Decoder patch size
vit_dec_model_size: "base" # Decoder model size
num_latent_tokens: 32 # L32: 32 latent tokens
token_size: 8 # Token dimension
codebook_size: 1024 # VQ codebook size
commitment_cost: 0.25 # VQ commitment cost
# ========== 预训练 Checkpoint ==========
pretrained_checkpoint: "./checkpoints/tokenizer_titok_l32_imagenet.safetensors"
# ========== 训练配置 ==========
training:
output_dir: "./output/ditok_l32_stage2"
per_gpu_batch_size: 32 # 每个 GPU 的 batch size
gradient_accumulation_steps: 1 # 梯度累积步数
mixed_precision: "bf16" # 混合精度训练
max_train_steps: 500000 # 最大训练步数
lr: 1.0e-4 # 初始学习率
min_lr: 1.0e-6 # 最小学习率
warmup_steps: 10000 # 预热步数
lr_schedule: "cosine" # 学习率调度
gradient_clip: 1.0 # 梯度裁剪
# Stage 2 特定配置
freeze_encoder: true # 冻结 encoder
freeze_quantizer: true # 冻结 quantizer
# 感知损失配置
perceptual_loss_type: "lpips" # 'none', 'vgg', 'lpips'
perceptual_weight: 0.5 # 感知损失权重
# Diffusion 训练参数
P_std: 0.8 # 时间步采样标准差
P_mean: -0.8 # 时间步采样均值
t_eps: 5e-2 # 最小时间步
noise_scale: 1.0 # 噪声缩放
label_drop_prob: 0.1 # CFG 标签丢弃概率
v_weight: 1.0 # v-prediction loss 权重
gamma_weight: 1.0 # gamma loss 权重
# EMA 参数
ema_decay1: 0.9999 # 主 EMA 衰减率
ema_decay2: 0.9996 # 辅助 EMA 衰减率
# 生成参数
sampling_method: "heun" # 采样方法
num_sampling_steps: 50 # 采样步数
cfg: 1.0 # CFG 强度
# ========== 数据集配置 ==========
dataset:
data_path: "./data/imagenet" # ImageNet 数据集路径
preprocessing:
crop_size: 256 # 图像裁剪大小
训练流程详解
训练脚本主要函数:
- get_config(): 加载配置文件
- create_model_and_loss_module(): 创建模型并冻结 Encoder 和 Quantizer
- create_optimizer(): 创建 AdamW 优化器
- create_lr_scheduler(): 创建余弦退火学习率调度器
- create_dataloader(): 创建数据加载器
- train_one_epoch(): 训练一个 epoch
- save_checkpoint(): 保存检查点
关键训练步骤:
完整的训练流程包括以下 7 个主要函数的调用:
# ========== 1. 加载配置文件 ==========
config = get_config() # 加载 YAML 配置或使用默认配置
# ========== 2. 创建模型并冻结参数 ==========
model, _, _ = create_model_and_loss_module(config, logger, accelerator)
# - 创建 DiTok 模型
# - 加载预训练 checkpoint (如果有)
# - 冻结 encoder 和 quantizer
# ========== 3. 创建优化器 ==========
optimizer = create_optimizer(config, logger, model)
# - 收集可训练参数 (diffusion_decoder)
# - 创建 AdamW 优化器
# ========== 4. 创建学习率调度器 ==========
lr_scheduler, _ = create_lr_scheduler(config, logger, accelerator, optimizer)
# - 创建余弦退火学习率调度器
# ========== 5. 创建数据加载器 ==========
train_dataloader, _ = create_dataloader(config, logger, accelerator)
# - 加载 ImageNet 数据集
# - 应用数据增强 (center crop, random flip)
# ========== 6. 训练循环 ==========
global_step = 0
while global_step < config.training.max_train_steps:
global_step = train_one_epoch(
config, logger, accelerator, model, optimizer, lr_scheduler,
train_dataloader, global_step
)
# 每个 epoch 内:
# - 前向传播计算 loss
# - 反向传播更新参数
# - 更新 EMA (主进程)
# - 更新学习率
# - 定期保存 checkpoint
# ========== 7. 保存最终检查点 ==========
save_checkpoint(model, optimizer, lr_scheduler, global_step, config, logger)
# - 保存模型参数和优化器状态
# - 保存 EMA1 和 EMA2 参数
加载预训练 Checkpoint 详解
load_pretrained_checkpoint 函数负责从 SafeTensors 文件加载预训练权重:
def load_pretrained_checkpoint(model, checkpoint_path, logger):
"""加载预训练 TiTok checkpoint
处理 SafeTensors 格式,并映射键名:
- checkpoint 'decoder.*' -> model 'diffusion_decoder.*'
- checkpoint 'encoder.*' -> model 'encoder.*'
"""
# 判断文件格式
if checkpoint_path.endswith('.safetensors'):
from safetensors import safe_open
with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
state_dict = {key: f.get_tensor(key) for key in f.keys()}
else:
checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict = checkpoint['model']
# 创建键名映射
mapped_state_dict = {}
for key, value in state_dict.items():
if key.startswith('decoder.'):
new_key = 'diffusion_decoder.' + key[8:]
mapped_state_dict[new_key] = value
elif key.startswith('encoder.'):
mapped_state_dict[key] = value
# 加载权重
load_result = model.load_state_dict(mapped_state_dict, strict=False)
# 记录加载结果
logger.info(f"Loaded {len(load_result.loaded_keys)} keys")
logger.info(f"Missing {len(load_result.missing_keys)} keys")
return model
创建模型并冻结参数详解
def create_model_and_loss_module(config, logger, accelerator):
"""创建 DiTok 模型
Stage 2 训练冻结 encoder 和 quantizer,只训练 diffusion decoder。
"""
model = DiTok(config)
# 加载预训练 checkpoint
pretrained_checkpoint = config.training.get("pretrained_checkpoint", None)
if pretrained_checkpoint and os.path.exists(pretrained_checkpoint):
model = load_pretrained_checkpoint(model, pretrained_checkpoint, logger)
# Stage 2: 冻结 encoder
if config.training.freeze_encoder:
model.encoder.eval()
for param in model.encoder.parameters():
param.requires_grad = False
logger.info("Encoder frozen for Stage 2 training")
# Stage 2: 冻结 quantizer
if config.training.freeze_quantizer:
model.quantizer.eval()
for param in model.quantizer.parameters():
param.requires_grad = False
logger.info("Quantizer frozen for Stage 2 training")
# 统计可训练参数
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"Trainable parameters: {n_params:,}")
return model, None, None
创建优化器详解
def create_optimizer(config, logger, model):
"""创建 AdamW 优化器
只优化 diffusion decoder 的参数。
"""
# 收集所有可训练参数
trainable_params = [p for p in model.parameters() if p.requires_grad]
# 按模块分组统计
param_count_by_module = {}
for name, param in model.named_parameters():
if param.requires_grad:
module = name.split('.')[0]
param_count_by_module[module] = param_count_by_module.get(module, 0) + param.numel()
logger.info(f"Trainable parameters by module:")
for module, count in sorted(param_count_by_module.items()):
logger.info(f" {module}: {count:,}")
# 创建 AdamW 优化器
optimizer = torch.optim.AdamW(
trainable_params,
lr=config.training.lr,
betas=(0.9, 0.95),
weight_decay=config.training.weight_decay
)
return optimizer
训练循环详解
def train_one_epoch(config, logger, accelerator, model, optimizer,
train_dataloader, global_step):
"""训练一个 epoch
关键点:
1. Encoder 和 Quantizer 必须保持 eval 模式
2. 图像需要归一化到 [0, 1] 范围
3. EMA 只在主进程中更新
"""
model.train()
# 确保 encoder 和 quantizer 处于 eval 模式
model.encoder.eval()
model.quantizer.eval()
for batch_idx, (images, labels) in enumerate(train_dataloader):
# 移动数据到设备
images = images.to(accelerator.device)
images = images.float() / 255.0 # 归一化到 [0, 1]
labels = labels.to(accelerator.device)
# 前向传播
loss_dict = model(images, labels)
loss = loss_dict['loss']
# 反向传播
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# 更新 EMA (只在主进程)
if accelerator.is_main_process:
model.update_ema()
# 更新学习率
lr_scheduler.step()
global_step += 1
# 日志记录
if global_step % config.training.get("log_freq", 100) == 0:
log_parts = [f"Step {global_step}"]
for key in ['loss', 'v_loss', 'gamma_loss', 'perceptual_loss']:
if key in loss_dict:
value = loss_dict[key].item() if torch.is_tensor(loss_dict[key]) else loss_dict[key]
log_parts.append(f"{key}={value:.4f}")
logger.info(" ".join(log_parts))
# 保存 checkpoint
if global_step % config.training.get("save_freq", 10000) == 0:
save_checkpoint(model, optimizer, lr_scheduler, global_step, config, logger)
if global_step >= config.training.max_train_steps:
break
return global_step
保存 Checkpoint 详解
def save_checkpoint(model, optimizer, lr_scheduler, global_step, config, logger):
"""保存检查点
保存内容:
1. 模型参数和优化器状态
2. EMA1 和 EMA2 参数(如果已初始化)
"""
output_dir = config.training.output_dir
# 保存模型检查点
model_save_path = os.path.join(output_dir, f"checkpoint-{global_step}.pt")
torch.save({
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict() if lr_scheduler else None,
'global_step': global_step,
'config': config
}, model_save_path)
# 保存 EMA 参数
if model.ema_params1 is not None and model.ema_params2 is not None:
ema_save_path = os.path.join(output_dir, f"ema-{global_step}.pt")
# 使用 named_parameters() 确保键名匹配
ema_state_dict1 = {
name: param.data.cpu()
for (name, _), param in zip(model.named_parameters(), model.ema_params1)
}
ema_state_dict2 = {
name: param.data.cpu()
for (name, _), param in zip(model.named_parameters(), model.ema_params2)
}
torch.save({
'ema_state_dict1': ema_state_dict1,
'ema_state_dict2': ema_state_dict2,
'global_step': global_step,
'ema_decay1': model.ema_decay1,
'ema_decay2': model.ema_decay2,
}, ema_save_path)
logger.info(f"Checkpoint saved at step {global_step} (with EMA)")
else:
logger.info(f"Checkpoint saved at step {global_step} (no EMA yet)")
加载配置文件详解
def get_config():
"""加载训练配置
支持从 YAML 文件加载配置,如果文件不存在则使用默认配置。
"""
# 配置文件路径
config_path = "configs/training/DiTok/stage2/ditok_l32.yaml"
# 尝试加载配置文件
if os.path.exists(config_path):
config = OmegaConf.load(config_path)
else:
# 配置文件不存在时使用默认配置
print(f"Warning: {config_path} not found, using default config")
config = OmegaConf.create({
"dataset": {
"preprocessing": {"crop_size": 256}
},
"model": {
"vq_model": {
"vit_enc_patch_size": 16,
"vit_enc_model_size": "base",
"vit_dec_patch_size": 16,
"vit_dec_model_size": "base",
"num_latent_tokens": 32,
"token_size": 8,
"codebook_size": 1024,
"commitment_cost": 0.25,
"use_l2_norm": False,
"clustering_vq": False,
"is_legacy": False,
}
},
"training": {
"output_dir": "./output/ditok_l32_stage2",
"per_gpu_batch_size": 32,
"lr": 1.0e-4,
"min_lr": 1.0e-6,
"max_train_steps": 500000,
"warmup_steps": 10000,
"lr_schedule": "cosine",
# Stage 2 特定配置
"freeze_encoder": True,
"freeze_quantizer": True,
# Diffusion 训练参数
"P_std": 0.8,
"P_mean": -0.8,
"v_weight": 1.0,
"gamma_weight": 1.0,
# EMA 参数
"ema_decay1": 0.9999,
"ema_decay2": 0.9996,
}
})
return config
配置项说明:
- dataset: 数据集预处理配置
- model.vq_model: Encoder 和 Decoder 架构配置
- training: 训练超参数 (学习率、batch size、步数等)
- training.freeze_encoder/quantizer: Stage 2 冻结设置
- training.P_std/P_mean: 时间步采样参数 (Karras et al., 2022)
- training.ema_decay1/2: EMA 衰减率
创建学习率调度器详解
def create_lr_scheduler(config, logger, accelerator, optimizer):
"""创建学习率调度器
使用余弦退火调度器,从初始学习率逐渐降低到最小学习率。
"""
if config.training.lr_schedule == "cosine":
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=config.training.max_train_steps, # 总步数
eta_min=config.training.min_lr # 最小学习率
)
else:
scheduler = None
return scheduler, None
余弦退火学习率曲线:
学习率按照余弦曲线从初始值衰减到最小值:
其中:
- \(\eta_{max} = 1.0 \times 10^{-4}\) (初始学习率)
- \(\eta_{min} = 1.0 \times 10^{-6}\) (最小学习率)
- \(T_{max} = 500000\) (最大训练步数)
学习率调度示例:
| Step | 学习率 | 说明 |
|---|---|---|
| 0 | 1.00e-4 | 初始学习率 |
| 10k | ~1.00e-4 | Warmup 后 |
| 250k | 5.05e-5 | 中点 (50%) |
| 500k | 1.00e-6 | 最小学习率 |
创建数据加载器详解
def create_dataloader(config, logger, accelerator):
"""创建数据加载器
加载 ImageNet 数据集并应用数据增强。
"""
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
# 数据增强流程
transform_train = transforms.Compose([
# 1. Center crop 到指定大小 (256x256)
transforms.Lambda(lambda img: center_crop_arr(img,
config.dataset.preprocessing.crop_size)),
# 2. 随机水平翻转 (数据增强)
transforms.RandomHorizontalFlip(),
# 3. 转换为 PyTorch tensor (uint8 [0, 255])
transforms.PILToTensor()
])
# 数据集路径: data_path/train/
data_path = config.dataset.get("data_path", "./data/imagenet")
train_dataset = ImageFolder(
os.path.join(data_path, 'train'),
transform=transform_train
)
# 创建数据加载器
train_loader = DataLoader(
train_dataset,
batch_size=config.training.per_gpu_batch_size, # 每个GPU的batch size
shuffle=True, # 打乱数据
num_workers=config.training.get("num_workers", 4), # 数据加载进程数
pin_memory=True, # 加速GPU数据传输
drop_last=True # 丢弃最后不完整的batch
)
logger.info(f"Train dataset size: {len(train_dataset)}")
logger.info(f"Train dataloader size: {len(train_loader)}")
return train_loader, None
def center_crop_arr(img, size):
"""Center crop 辅助函数
将图像从中心裁剪到指定大小。
"""
import numpy as np
img = np.asarray(img)
h, w = img.shape[:2]
c_h, c_w = size
# 计算裁剪起始位置
top = (h - c_h) // 2
left = (w - c_w) // 2
return img[top:top+c_h, left:left+c_w]
数据增强流程说明:
Center Crop (256x256):
- 从原始图像中心裁剪出 256×256 区域
- 保证所有图像尺寸一致
- 适合 ImageNet (通常图像 ≥ 256)
Random Horizontal Flip:
- 随机水平翻转 (p=0.5)
- 增加数据多样性
- 不改变图像语义
PILToTensor:
- 转换为 PyTorch tensor
- 值域: uint8 [0, 255]
- 后续在训练中归一化到 [0, 1]
数据加载器参数说明:
| 参数 | 值 | 说明 |
|---|---|---|
| batch_size | 32 | 每个 GPU 的 batch size |
| shuffle | True | 每个 epoch 打乱数据 |
| num_workers | 4 | 数据加载进程数 (可调整) |
| pin_memory | True | 锁页内存,加速 GPU 传输 |
| drop_last | True | 丢弃不完整的最后 batch |
数据加载器性能优化:
# GPU 内存充足时增大 batch size
per_gpu_batch_size: 64
# CPU 核心多时增加 num_workers
num_workers: 8 # 根据 CPU 核心数调整
# 使用 gradient_accumulation 增大有效 batch size
per_gpu_batch_size: 16
gradient_accumulation_steps: 4
# 有效 batch size = 16 * 4 * num_gpus
EMA (指数移动平均)
EMA 的作用:
- 平滑模型参数,减少推理时的随机波动
- 提高生成质量和稳定性
- 双 EMA 提供不同程度的平滑
EMA 更新公式:
使用两个 EMA 跟踪器:
- EMA1 (decay=0.9999): 主 EMA,用于推理
- EMA2 (decay=0.9996): 辅助 EMA,衰减更快
保存 EMA 参数:
检查点文件包含:
- checkpoint-{step}.pt: 模型参数和优化器状态
- ema-{step}.pt: EMA1 和 EMA2 参数
损失函数
总损失组成:
其中:
- \(\mathcal{L}_v\): v-prediction loss (主要损失)
- \(\mathcal{L}_\gamma\): gamma loss (可选,负对数似然)
- \(\mathcal{L}_p\): perceptual loss (可选,VGG 或 LPIPS)
- \(\mathcal{L}_{bpp}\): bpp loss (Stage 2 中为 0)
损失权重:
- v_weight = 1.0: 主要损失,通常设为 1.0
- gamma_weight = 0.0: 可选,通常设为 0.0
- perceptual_weight = 0.5: 可选,通常设为 0.5
感知损失选择
DiTok 支持两种感知损失:
VGGPerceptualLoss (快速):
- 基于 VGG-16BN
- 计算效率高
- 效果一般
LPIPSLoss (推荐):
- 学习得到的感知距离
- 与人类感知高度一致
- 需要安装: uv add lpips
配置选项:
perceptual_loss_type: "lpips" # 'none', 'vgg', 'lpips'
perceptual_weight: 0.5
lpips_net: "vgg" # 'vgg' or 'alex'
数据准备
数据集结构:
data_path/
├── train/
│ ├── n01440764/
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ └── ...
│ ├── n01443537/
│ └── ...
└── val/
└── ...
数据增强:
- Center crop (256x256)
- Random horizontal flip
- 转换为 PyTorch tensor
训练监控
日志输出:
训练脚本会输出以下信息:
- 每个 100 步输出损失
- 每个 10000 步保存检查点
- 每隔 test_interval 步(默认 3000)进行评估
- TensorBoard 日志 (在 logging_dir 中)
日志示例:
Step 100: loss=3.4567, v_loss=3.4123
Step 200: loss=3.2345, v_loss=3.1987
...
评估结果
训练过程中会自动保存评估结果,便于追踪模型性能变化。
评估结果文件结构:
output/ditok_l32_stage2_filelist/
├── evaluation_history/
│ ├── eval_step00003000.json # 单次评估的详细结果
│ ├── eval_step00006000.json
│ ├── eval_step00009000.json
│ └── evaluation_summary.jsonl # 所有结果的汇总文件
└── visualizations/
├── comparison_step00003000.png # 原图与重建图对比
├── comparison_step00006000.png
└── ...
单个评估结果文件格式:
eval_step00003000.json:
{
"global_step": 3000,
"timestamp": "2025-02-07T12:34:56.789012",
"metrics": {
"rFID": 25.4321,
"IS": 18.7654,
"entropy": 9.8765
},
"image_path": "output/ditok_l32_stage2_filelist/visualizations/comparison_step00003000.png",
"image_path_absolute": "/完整路径/.../comparison_step00003000.png"
}
汇总文件格式:
evaluation_summary.jsonl (JSON Lines 格式,每行一个 JSON 对象):
{"global_step": 3000, "timestamp": "2025-02-07T12:34:56", "metrics": {"rFID": 25.43, "IS": 18.76}, "image_path": "output/..."}
{"global_step": 6000, "timestamp": "2025-02-07T13:45:67", "metrics": {"rFID": 23.21, "IS": 19.54}, "image_path": "output/..."}
{"global_step": 9000, "timestamp": "2025-02-07T14:56:78", "metrics": {"rFID": 21.09, "IS": 20.12}, "image_path": "output/..."}
查看评估结果:
使用 scripts/show_eval_results.py 脚本查看和分析评估结果:
# 显示所有评估结果汇总表格
uv run scripts/show_eval_results.py --output-dir output/ditok_l32_stage2_filelist
# 显示最近 N 次评估
uv run scripts/show_eval_results.py --output-dir output/ditok_l32_stage2_filelist --top-k 5
# 显示指标趋势分析
uv run scripts/show_eval_results.py --output-dir output/ditok_l32_stage2_filelist --trend rFID
# 导出为 CSV 文件
uv run scripts/show_eval_results.py --output-dir output/ditok_l32_stage2_filelist --export results.csv
输出示例:
汇总表格:
================================================================================
Evaluation Summary
================================================================================
Step Timestamp rFID IS Image
--------------------------------------------------------------------------------
3000 2025-02-07 12:34 25.4321 18.7654 output/...comparison_step00003000.png
6000 2025-02-07 13:45 23.2100 19.5400 output/...comparison_step00006000.png
9000 2025-02-07 14:56 21.0900 20.1200 output/...comparison_step00009000.png
================================================================================
趋势分析:
================================================================================
Metric Trend Analysis: rFID
================================================================================
Total evaluations: 10
Min value: 21.0900 (at step 9000)
Max value: 25.4321 (at step 3000)
Mean value: 22.8765
Latest value: 21.0900 (at step 9000)
Total change: -4.3421 (-17.08%)
Improving steps: 8/9
Detailed Values:
--------------------------------------------------------------------------------
Step rFID Change
--------------------------------------------------------------------------------
3000 25.4321
6000 23.2100 -2.2221
9000 21.0900 -2.1200
================================================================================
评估配置:
在配置文件中设置评估间隔:
training:
test_interval: 3000 # 每 3000 步评估一次
自动保存时机:
评估结果会在以下时机自动保存:
- 定期评估: 每隔 test_interval 步(默认 3000 步)
- 训练结束: 最终模型评估后
路径说明:
- 相对路径 (image_path): 相对于项目根目录的路径,便于在代码中使用
- 绝对路径 (image_path_absolute): 完整的文件系统路径,便于直接访问文件
- 时间戳: JSON 内部包含 ISO 8601 格式的保存时间
编程接口:
也可以在 Python 代码中读取评估历史:
from utils.viz_utils import load_evaluation_history, print_evaluation_summary
# 加载评估历史
history = load_evaluation_history("output/ditok_l32_stage2_filelist")
# 获取最近 N 次评估
from utils.viz_utils import get_latest_evaluation_results
recent_results = get_latest_evaluation_results("output/ditok_l32_stage2_filelist", top_k=5)
# 打印汇总表格
print_evaluation_summary("output/ditok_l32_stage2_filelist")
检查点保存
检查点文件:
- checkpoint-{step}.pt: 模型参数和优化器状态
- ema-{step}.pt: EMA 参数
- final_model.pt: 最终模型
加载检查点:
import torch
# 加载模型检查点
checkpoint = torch.load("checkpoint-100000.pt")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
# 加载 EMA 检查点
ema_checkpoint = torch.load("ema-100000.pt")
# 使用 EMA1 参数进行推理
model.load_state_dict(ema_checkpoint['ema_state_dict1'])
# 或者使用 EMA2 参数
# model.load_state_dict(ema_checkpoint['ema_state_dict2'])
常见问题
Q: 为什么要使用 uv?
A: uv 是一个快速的 Python 包管理器,比 pip 快 10-100 倍。它提供了:
- 极快的依赖解析和安装速度
- 锁定文件 (uv.lock) 确保可重现的环境
- 更好的依赖管理体验
- 与 pyproject.toml 标准完全兼容
Q: 如何调整 batch size?
A: 修改配置文件中的 per_gpu_batch_size。如果 GPU 内存不足,可以减小这个值或使用梯度累积。
Q: 如何使用多个 GPU?
A: 使用 uv run 配合 Accelerate 启动训练脚本:
uv run accelerate launch --num_processes=8 utils/train_ditok.py
Q: 如何恢复训练?
A: 在 create_model_and_loss_module() 函数中添加从检查点加载的逻辑,或者修改配置文件的 output_dir 为之前的输出目录。
Q: 如何使用感知损失?
A: 在配置文件中设置:
- 快速实验: perceptual_loss_type: "vgg"
- 最佳质量: perceptual_loss_type: "lpips"
Q: EMA 参数有什么用?
A: EMA 参数用于推理时,可以提高生成质量。训练时会自动更新 EMA,推理时加载 EMA 参数即可。
Q: 如何查看训练过程中的评估结果?
A: 使用 scripts/show_eval_results.py 脚本:
# 查看所有评估结果
uv run scripts/show_eval_results.py --output-dir output/ditok_l32_stage2_filelist
# 查看指标趋势
uv run scripts/show_eval_results.py --output-dir output/... --trend rFID
# 导出为 CSV
uv run scripts/show_eval_results.py --output-dir output/... --export results.csv
Q: 评估结果保存在哪里?
A: 评估结果保存在 output_dir/evaluation_history/ 目录下:
- eval_step00003000.json: 单次评估的详细结果
- evaluation_summary.jsonl: 所有结果的汇总文件(JSON Lines 格式)
Q: 如何修改评估频率?
A: 在配置文件中修改 test_interval 参数:
training:
test_interval: 3000 # 每 3000 步评估一次
Q: 评估结果包含哪些指标?
A: 评估结果包含以下指标:
- rFID: reconstruction Fréchet Inception Distance(重建质量)
- IS: Inception Score(图像质量)
- entropy: Codebook 使用熵(量化多样性)
性能优化建议
1. 混合精度训练:
使用 mixed_precision: "bf16" 可以:
- 减少 GPU 内存使用
- 加速训练 (约 2x)
- 保持数值稳定性
2. 梯度累积:
如果 GPU 内存不足,可以使用梯度累积:
per_gpu_batch_size: 16 # 减小 batch size
gradient_accumulation_steps: 2 # 累积 2 步
# 有效 batch size = 16 * 2 * num_gpus
3. 数据加载优化:
增加 num_workers 可以加速数据加载:
num_workers: 8 # 根据CPU核心数调整
4. TF32 (Ampere GPU):
如果使用 Ampere GPU (A100, 3090等),可以启用 TF32:
enable_tf32: true
参考资料
相关脚本:
- scripts/compute_bpp.py: 计算 Bits Per Pixel (BPP) 指标
- scripts/show_eval_results.py: 查看和分析评估结果
- scripts/show_ckpt_structure.py: 显示 checkpoint 结构
相关文档:
- docs/EVAL_RESULTS_EXAMPLE.md: 评估结果文件格式说明
- docs/Config-usage.md: 配置文件使用说明
相关论文:
- TiTok: An Image is Worth 32 Tokens (NeurIPS 2024)
- DiT: Scalable Diffusion Models with Transformers (ICLR 2023)
- JiT: Just Image Transformer
- DiffEIC: Towards Extreme Image Compression with Diffusion Models
- LPIPS: The Unreasonable Effectiveness of Deep Features as a Perceptual Metric (CVPR 2018)