ESC
输入关键词搜索文章
目录

DiTok Stage 2 训练指南

概述

本文档介绍 DiTok Stage 2 训练流程,用于微调扩散解码器。

训练策略:

预训练 Checkpoint:

本教程使用预训练的 TiTok L32 checkpoint:

配置文件中已默认设置好预训练 checkpoint 路径,无需手动下载。

为什么需要 Stage 2:

  1. TiTok 已经在大规模数据上训练好了,能很好地提取图像特征
  2. 我们只需要学习如何从这些特征生成清晰的图像
  3. 冻结 Encoder 和 Quantizer 可以减少训练参数,加快训练速度

关键特性:

环境设置

本项目使用 uv 作为包管理器。

# 1. 安装 uv (如果还没有安装)
curl -LsSf https://astral.sh/uv/install.sh | sh

# 2. 安装项目依赖
uv sync

# 3. (可选) 添加 LPIPS 感知损失支持
uv add lpips

项目结构:

DiTok/
├── ditok.py                    # DiTok 模型定义
├── modules/                    # 核心模块
│   ├── titok_blocks.py        # TiTok Transformer blocks
│   ├── ditok_blocks.py        # DiTok Transformer blocks
│   ├── quantizer.py           # Vector Quantization
│   └── losses.py              # 损失函数 (LPIPS, VGG等)
├── utils/                      # 工具脚本
│   ├── train_ditok.py         # Stage 2 训练脚本
│   └── logger.py              # 日志工具
├── configs/                    # 配置文件
│   └── training/DiTok/stage2/
│       └── ditok_b64.yaml     # Stage 2 配置
├── pyproject.toml              # 项目配置
├── uv.lock                     # 依赖锁定
└── .python-version             # Python 3.12

快速开始

单 GPU 训练:

# 使用 uv run 运行训练脚本
uv run python utils/train_ditok.py

多 GPU 训练 (推荐):

# 使用 uv run 配合 Accelerate 启动 8 个 GPU
uv run accelerate launch --num_processes=8 --main_process_port=9999 utils/train_ditok.py

自定义配置:

# 注意: 当前版本暂不支持命令行参数
# 请直接修改 configs/training/DiTok/stage2/ditok_l32.yaml
uv run python utils/train_ditok.py

配置文件详解

配置文件位于: configs/training/DiTok/stage2/ditok_l32.yaml

核心配置项:

# ========== 模型配置 ==========
model:
  vq_model:
    vit_enc_patch_size: 16        # Encoder patch size
    vit_enc_model_size: "base"     # Encoder model size
    vit_dec_patch_size: 16        # Decoder patch size
    vit_dec_model_size: "base"     # Decoder model size
    num_latent_tokens: 32         # L32: 32 latent tokens
    token_size: 8                 # Token dimension
    codebook_size: 1024           # VQ codebook size
    commitment_cost: 0.25         # VQ commitment cost

# ========== 预训练 Checkpoint ==========
pretrained_checkpoint: "./checkpoints/tokenizer_titok_l32_imagenet.safetensors"

# ========== 训练配置 ==========
training:
  output_dir: "./output/ditok_l32_stage2"
  per_gpu_batch_size: 32         # 每个 GPU 的 batch size
  gradient_accumulation_steps: 1  # 梯度累积步数
  mixed_precision: "bf16"        # 混合精度训练
  max_train_steps: 500000        # 最大训练步数
  lr: 1.0e-4                    # 初始学习率
  min_lr: 1.0e-6                 # 最小学习率
  warmup_steps: 10000           # 预热步数
  lr_schedule: "cosine"          # 学习率调度
  gradient_clip: 1.0             # 梯度裁剪

  # Stage 2 特定配置
  freeze_encoder: true           # 冻结 encoder
  freeze_quantizer: true         # 冻结 quantizer

  # 感知损失配置
  perceptual_loss_type: "lpips"  # 'none', 'vgg', 'lpips'
  perceptual_weight: 0.5         # 感知损失权重

  # Diffusion 训练参数
  P_std: 0.8                     # 时间步采样标准差
  P_mean: -0.8                   # 时间步采样均值
  t_eps: 5e-2                    # 最小时间步
  noise_scale: 1.0               # 噪声缩放
  label_drop_prob: 0.1           # CFG 标签丢弃概率
  v_weight: 1.0                  # v-prediction loss 权重
  gamma_weight: 1.0              # gamma loss 权重

  # EMA 参数
  ema_decay1: 0.9999             # 主 EMA 衰减率
  ema_decay2: 0.9996             # 辅助 EMA 衰减率

  # 生成参数
  sampling_method: "heun"        # 采样方法
  num_sampling_steps: 50         # 采样步数
  cfg: 1.0                      # CFG 强度

# ========== 数据集配置 ==========
dataset:
  data_path: "./data/imagenet"    # ImageNet 数据集路径
  preprocessing:
    crop_size: 256               # 图像裁剪大小

训练流程详解

训练脚本主要函数:

  1. get_config(): 加载配置文件
  2. create_model_and_loss_module(): 创建模型并冻结 Encoder 和 Quantizer
  3. create_optimizer(): 创建 AdamW 优化器
  4. create_lr_scheduler(): 创建余弦退火学习率调度器
  5. create_dataloader(): 创建数据加载器
  6. train_one_epoch(): 训练一个 epoch
  7. save_checkpoint(): 保存检查点

关键训练步骤:

完整的训练流程包括以下 7 个主要函数的调用:

# ========== 1. 加载配置文件 ==========
config = get_config()  # 加载 YAML 配置或使用默认配置

# ========== 2. 创建模型并冻结参数 ==========
model, _, _ = create_model_and_loss_module(config, logger, accelerator)
# - 创建 DiTok 模型
# - 加载预训练 checkpoint (如果有)
# - 冻结 encoder 和 quantizer

# ========== 3. 创建优化器 ==========
optimizer = create_optimizer(config, logger, model)
# - 收集可训练参数 (diffusion_decoder)
# - 创建 AdamW 优化器

# ========== 4. 创建学习率调度器 ==========
lr_scheduler, _ = create_lr_scheduler(config, logger, accelerator, optimizer)
# - 创建余弦退火学习率调度器

# ========== 5. 创建数据加载器 ==========
train_dataloader, _ = create_dataloader(config, logger, accelerator)
# - 加载 ImageNet 数据集
# - 应用数据增强 (center crop, random flip)

# ========== 6. 训练循环 ==========
global_step = 0
while global_step < config.training.max_train_steps:
    global_step = train_one_epoch(
        config, logger, accelerator, model, optimizer, lr_scheduler,
        train_dataloader, global_step
    )
    # 每个 epoch 内:
    # - 前向传播计算 loss
    # - 反向传播更新参数
    # - 更新 EMA (主进程)
    # - 更新学习率
    # - 定期保存 checkpoint

# ========== 7. 保存最终检查点 ==========
save_checkpoint(model, optimizer, lr_scheduler, global_step, config, logger)
# - 保存模型参数和优化器状态
# - 保存 EMA1 和 EMA2 参数

加载预训练 Checkpoint 详解

load_pretrained_checkpoint 函数负责从 SafeTensors 文件加载预训练权重:

def load_pretrained_checkpoint(model, checkpoint_path, logger):
    """加载预训练 TiTok checkpoint

    处理 SafeTensors 格式,并映射键名:
    - checkpoint 'decoder.*' -> model 'diffusion_decoder.*'
    - checkpoint 'encoder.*' -> model 'encoder.*'
    """
    # 判断文件格式
    if checkpoint_path.endswith('.safetensors'):
        from safetensors import safe_open
        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            state_dict = {key: f.get_tensor(key) for key in f.keys()}
    else:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        state_dict = checkpoint['model']

    # 创建键名映射
    mapped_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith('decoder.'):
            new_key = 'diffusion_decoder.' + key[8:]
            mapped_state_dict[new_key] = value
        elif key.startswith('encoder.'):
            mapped_state_dict[key] = value

    # 加载权重
    load_result = model.load_state_dict(mapped_state_dict, strict=False)

    # 记录加载结果
    logger.info(f"Loaded {len(load_result.loaded_keys)} keys")
    logger.info(f"Missing {len(load_result.missing_keys)} keys")

    return model

创建模型并冻结参数详解

def create_model_and_loss_module(config, logger, accelerator):
    """创建 DiTok 模型

    Stage 2 训练冻结 encoder 和 quantizer,只训练 diffusion decoder。
    """
    model = DiTok(config)

    # 加载预训练 checkpoint
    pretrained_checkpoint = config.training.get("pretrained_checkpoint", None)
    if pretrained_checkpoint and os.path.exists(pretrained_checkpoint):
        model = load_pretrained_checkpoint(model, pretrained_checkpoint, logger)

    # Stage 2: 冻结 encoder
    if config.training.freeze_encoder:
        model.encoder.eval()
        for param in model.encoder.parameters():
            param.requires_grad = False
        logger.info("Encoder frozen for Stage 2 training")

    # Stage 2: 冻结 quantizer
    if config.training.freeze_quantizer:
        model.quantizer.eval()
        for param in model.quantizer.parameters():
            param.requires_grad = False
        logger.info("Quantizer frozen for Stage 2 training")

    # 统计可训练参数
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Trainable parameters: {n_params:,}")

    return model, None, None

创建优化器详解

def create_optimizer(config, logger, model):
    """创建 AdamW 优化器

    只优化 diffusion decoder 的参数。
    """
    # 收集所有可训练参数
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    # 按模块分组统计
    param_count_by_module = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            module = name.split('.')[0]
            param_count_by_module[module] = param_count_by_module.get(module, 0) + param.numel()

    logger.info(f"Trainable parameters by module:")
    for module, count in sorted(param_count_by_module.items()):
        logger.info(f"  {module}: {count:,}")

    # 创建 AdamW 优化器
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=config.training.lr,
        betas=(0.9, 0.95),
        weight_decay=config.training.weight_decay
    )

    return optimizer

训练循环详解

def train_one_epoch(config, logger, accelerator, model, optimizer,
                     train_dataloader, global_step):
    """训练一个 epoch

    关键点:
    1. Encoder 和 Quantizer 必须保持 eval 模式
    2. 图像需要归一化到 [0, 1] 范围
    3. EMA 只在主进程中更新
    """
    model.train()
    # 确保 encoder 和 quantizer 处于 eval 模式
    model.encoder.eval()
    model.quantizer.eval()

    for batch_idx, (images, labels) in enumerate(train_dataloader):
        # 移动数据到设备
        images = images.to(accelerator.device)
        images = images.float() / 255.0  # 归一化到 [0, 1]
        labels = labels.to(accelerator.device)

        # 前向传播
        loss_dict = model(images, labels)
        loss = loss_dict['loss']

        # 反向传播
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

        # 更新 EMA (只在主进程)
        if accelerator.is_main_process:
            model.update_ema()

        # 更新学习率
        lr_scheduler.step()
        global_step += 1

        # 日志记录
        if global_step % config.training.get("log_freq", 100) == 0:
            log_parts = [f"Step {global_step}"]
            for key in ['loss', 'v_loss', 'gamma_loss', 'perceptual_loss']:
                if key in loss_dict:
                    value = loss_dict[key].item() if torch.is_tensor(loss_dict[key]) else loss_dict[key]
                    log_parts.append(f"{key}={value:.4f}")
            logger.info(" ".join(log_parts))

        # 保存 checkpoint
        if global_step % config.training.get("save_freq", 10000) == 0:
            save_checkpoint(model, optimizer, lr_scheduler, global_step, config, logger)

        if global_step >= config.training.max_train_steps:
            break

    return global_step

保存 Checkpoint 详解

def save_checkpoint(model, optimizer, lr_scheduler, global_step, config, logger):
    """保存检查点

    保存内容:
    1. 模型参数和优化器状态
    2. EMA1 和 EMA2 参数(如果已初始化)
    """
    output_dir = config.training.output_dir

    # 保存模型检查点
    model_save_path = os.path.join(output_dir, f"checkpoint-{global_step}.pt")
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict() if lr_scheduler else None,
        'global_step': global_step,
        'config': config
    }, model_save_path)

    # 保存 EMA 参数
    if model.ema_params1 is not None and model.ema_params2 is not None:
        ema_save_path = os.path.join(output_dir, f"ema-{global_step}.pt")

        # 使用 named_parameters() 确保键名匹配
        ema_state_dict1 = {
            name: param.data.cpu()
            for (name, _), param in zip(model.named_parameters(), model.ema_params1)
        }
        ema_state_dict2 = {
            name: param.data.cpu()
            for (name, _), param in zip(model.named_parameters(), model.ema_params2)
        }

        torch.save({
            'ema_state_dict1': ema_state_dict1,
            'ema_state_dict2': ema_state_dict2,
            'global_step': global_step,
            'ema_decay1': model.ema_decay1,
            'ema_decay2': model.ema_decay2,
        }, ema_save_path)

        logger.info(f"Checkpoint saved at step {global_step} (with EMA)")
    else:
        logger.info(f"Checkpoint saved at step {global_step} (no EMA yet)")

加载配置文件详解

def get_config():
    """加载训练配置

    支持从 YAML 文件加载配置,如果文件不存在则使用默认配置。
    """
    # 配置文件路径
    config_path = "configs/training/DiTok/stage2/ditok_l32.yaml"

    # 尝试加载配置文件
    if os.path.exists(config_path):
        config = OmegaConf.load(config_path)
    else:
        # 配置文件不存在时使用默认配置
        print(f"Warning: {config_path} not found, using default config")
        config = OmegaConf.create({
            "dataset": {
                "preprocessing": {"crop_size": 256}
            },
            "model": {
                "vq_model": {
                    "vit_enc_patch_size": 16,
                    "vit_enc_model_size": "base",
                    "vit_dec_patch_size": 16,
                    "vit_dec_model_size": "base",
                    "num_latent_tokens": 32,
                    "token_size": 8,
                    "codebook_size": 1024,
                    "commitment_cost": 0.25,
                    "use_l2_norm": False,
                    "clustering_vq": False,
                    "is_legacy": False,
                }
            },
            "training": {
                "output_dir": "./output/ditok_l32_stage2",
                "per_gpu_batch_size": 32,
                "lr": 1.0e-4,
                "min_lr": 1.0e-6,
                "max_train_steps": 500000,
                "warmup_steps": 10000,
                "lr_schedule": "cosine",
                # Stage 2 特定配置
                "freeze_encoder": True,
                "freeze_quantizer": True,
                # Diffusion 训练参数
                "P_std": 0.8,
                "P_mean": -0.8,
                "v_weight": 1.0,
                "gamma_weight": 1.0,
                # EMA 参数
                "ema_decay1": 0.9999,
                "ema_decay2": 0.9996,
            }
        })

    return config

配置项说明:

创建学习率调度器详解

def create_lr_scheduler(config, logger, accelerator, optimizer):
    """创建学习率调度器

    使用余弦退火调度器,从初始学习率逐渐降低到最小学习率。
    """
    if config.training.lr_schedule == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=config.training.max_train_steps,  # 总步数
            eta_min=config.training.min_lr           # 最小学习率
        )
    else:
        scheduler = None

    return scheduler, None

余弦退火学习率曲线:

学习率按照余弦曲线从初始值衰减到最小值:

$$\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\frac{t\pi}{T_{max}}\right)$$

其中:

学习率调度示例:

Step 学习率 说明
0 1.00e-4 初始学习率
10k ~1.00e-4 Warmup 后
250k 5.05e-5 中点 (50%)
500k 1.00e-6 最小学习率

创建数据加载器详解

def create_dataloader(config, logger, accelerator):
    """创建数据加载器

    加载 ImageNet 数据集并应用数据增强。
    """
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader
    from torchvision.datasets import ImageFolder

    # 数据增强流程
    transform_train = transforms.Compose([
        # 1. Center crop 到指定大小 (256x256)
        transforms.Lambda(lambda img: center_crop_arr(img,
                         config.dataset.preprocessing.crop_size)),
        # 2. 随机水平翻转 (数据增强)
        transforms.RandomHorizontalFlip(),
        # 3. 转换为 PyTorch tensor (uint8 [0, 255])
        transforms.PILToTensor()
    ])

    # 数据集路径: data_path/train/
    data_path = config.dataset.get("data_path", "./data/imagenet")
    train_dataset = ImageFolder(
        os.path.join(data_path, 'train'),
        transform=transform_train
    )

    # 创建数据加载器
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.training.per_gpu_batch_size,  # 每个GPU的batch size
        shuffle=True,                                  # 打乱数据
        num_workers=config.training.get("num_workers", 4),  # 数据加载进程数
        pin_memory=True,                               # 加速GPU数据传输
        drop_last=True                                 # 丢弃最后不完整的batch
    )

    logger.info(f"Train dataset size: {len(train_dataset)}")
    logger.info(f"Train dataloader size: {len(train_loader)}")

    return train_loader, None


def center_crop_arr(img, size):
    """Center crop 辅助函数

    将图像从中心裁剪到指定大小。
    """
    import numpy as np

    img = np.asarray(img)
    h, w = img.shape[:2]
    c_h, c_w = size

    # 计算裁剪起始位置
    top = (h - c_h) // 2
    left = (w - c_w) // 2

    return img[top:top+c_h, left:left+c_w]

数据增强流程说明:

  1. Center Crop (256x256):

    • 从原始图像中心裁剪出 256×256 区域
    • 保证所有图像尺寸一致
    • 适合 ImageNet (通常图像 ≥ 256)
  2. Random Horizontal Flip:

    • 随机水平翻转 (p=0.5)
    • 增加数据多样性
    • 不改变图像语义
  3. PILToTensor:

    • 转换为 PyTorch tensor
    • 值域: uint8 [0, 255]
    • 后续在训练中归一化到 [0, 1]

数据加载器参数说明:

参数 说明
batch_size 32 每个 GPU 的 batch size
shuffle True 每个 epoch 打乱数据
num_workers 4 数据加载进程数 (可调整)
pin_memory True 锁页内存,加速 GPU 传输
drop_last True 丢弃不完整的最后 batch

数据加载器性能优化:

# GPU 内存充足时增大 batch size
per_gpu_batch_size: 64

# CPU 核心多时增加 num_workers
num_workers: 8  # 根据 CPU 核心数调整

# 使用 gradient_accumulation 增大有效 batch size
per_gpu_batch_size: 16
gradient_accumulation_steps: 4
# 有效 batch size = 16 * 4 * num_gpus

EMA (指数移动平均)

EMA 的作用:

EMA 更新公式:

$$\theta_{EMA} = \text{decay} \times \theta_{EMA} + (1 - \text{decay}) \times \theta$$

使用两个 EMA 跟踪器:

保存 EMA 参数:

检查点文件包含:

损失函数

总损失组成:

$$\mathcal{L} = w_v \cdot \mathcal{L}_v + w_\gamma \cdot \mathcal{L}_\gamma + w_p \cdot \mathcal{L}_p + \mathcal{L}_{bpp}$$

其中:

损失权重:

感知损失选择

DiTok 支持两种感知损失:

  1. VGGPerceptualLoss (快速):

    • 基于 VGG-16BN
    • 计算效率高
    • 效果一般
  2. LPIPSLoss (推荐):

    • 学习得到的感知距离
    • 与人类感知高度一致
    • 需要安装: uv add lpips

配置选项:

perceptual_loss_type: "lpips"  # 'none', 'vgg', 'lpips'
perceptual_weight: 0.5
lpips_net: "vgg"              # 'vgg' or 'alex'

数据准备

数据集结构:

data_path/
├── train/
│   ├── n01440764/
│   │   ├── n01440764_10026.JPEG
│   │   ├── n01440764_10027.JPEG
│   │   └── ...
│   ├── n01443537/
│   └── ...
└── val/
    └── ...

数据增强:

训练监控

日志输出:

训练脚本会输出以下信息:

日志示例:

Step 100: loss=3.4567, v_loss=3.4123
Step 200: loss=3.2345, v_loss=3.1987
...

评估结果

训练过程中会自动保存评估结果,便于追踪模型性能变化。

评估结果文件结构:

output/ditok_l32_stage2_filelist/
├── evaluation_history/
│   ├── eval_step00003000.json           # 单次评估的详细结果
│   ├── eval_step00006000.json
│   ├── eval_step00009000.json
│   └── evaluation_summary.jsonl         # 所有结果的汇总文件
└── visualizations/
    ├── comparison_step00003000.png      # 原图与重建图对比
    ├── comparison_step00006000.png
    └── ...

单个评估结果文件格式:

eval_step00003000.json:

{
  "global_step": 3000,
  "timestamp": "2025-02-07T12:34:56.789012",
  "metrics": {
    "rFID": 25.4321,
    "IS": 18.7654,
    "entropy": 9.8765
  },
  "image_path": "output/ditok_l32_stage2_filelist/visualizations/comparison_step00003000.png",
  "image_path_absolute": "/完整路径/.../comparison_step00003000.png"
}

汇总文件格式:

evaluation_summary.jsonl (JSON Lines 格式,每行一个 JSON 对象):

{"global_step": 3000, "timestamp": "2025-02-07T12:34:56", "metrics": {"rFID": 25.43, "IS": 18.76}, "image_path": "output/..."}
{"global_step": 6000, "timestamp": "2025-02-07T13:45:67", "metrics": {"rFID": 23.21, "IS": 19.54}, "image_path": "output/..."}
{"global_step": 9000, "timestamp": "2025-02-07T14:56:78", "metrics": {"rFID": 21.09, "IS": 20.12}, "image_path": "output/..."}

查看评估结果:

使用 scripts/show_eval_results.py 脚本查看和分析评估结果:

# 显示所有评估结果汇总表格
uv run scripts/show_eval_results.py --output-dir output/ditok_l32_stage2_filelist

# 显示最近 N 次评估
uv run scripts/show_eval_results.py --output-dir output/ditok_l32_stage2_filelist --top-k 5

# 显示指标趋势分析
uv run scripts/show_eval_results.py --output-dir output/ditok_l32_stage2_filelist --trend rFID

# 导出为 CSV 文件
uv run scripts/show_eval_results.py --output-dir output/ditok_l32_stage2_filelist --export results.csv

输出示例:

汇总表格:

================================================================================
Evaluation Summary
================================================================================

Step         Timestamp            rFID            IS              Image
--------------------------------------------------------------------------------
3000         2025-02-07 12:34     25.4321         18.7654         output/...comparison_step00003000.png
6000         2025-02-07 13:45     23.2100         19.5400         output/...comparison_step00006000.png
9000         2025-02-07 14:56     21.0900         20.1200         output/...comparison_step00009000.png

================================================================================

趋势分析:

================================================================================
Metric Trend Analysis: rFID
================================================================================

Total evaluations: 10
Min value: 21.0900 (at step 9000)
Max value: 25.4321 (at step 3000)
Mean value: 22.8765
Latest value: 21.0900 (at step 9000)
Total change: -4.3421 (-17.08%)
Improving steps: 8/9

Detailed Values:
--------------------------------------------------------------------------------
Step         rFID                 Change
--------------------------------------------------------------------------------
3000         25.4321
6000         23.2100              -2.2221
9000         21.0900              -2.1200

================================================================================

评估配置:

在配置文件中设置评估间隔:

training:
  test_interval: 3000  # 每 3000 步评估一次

自动保存时机:

评估结果会在以下时机自动保存:

  1. 定期评估: 每隔 test_interval 步(默认 3000 步)
  2. 训练结束: 最终模型评估后

路径说明:

编程接口:

也可以在 Python 代码中读取评估历史:

from utils.viz_utils import load_evaluation_history, print_evaluation_summary

# 加载评估历史
history = load_evaluation_history("output/ditok_l32_stage2_filelist")

# 获取最近 N 次评估
from utils.viz_utils import get_latest_evaluation_results
recent_results = get_latest_evaluation_results("output/ditok_l32_stage2_filelist", top_k=5)

# 打印汇总表格
print_evaluation_summary("output/ditok_l32_stage2_filelist")

检查点保存

检查点文件:

加载检查点:

import torch

# 加载模型检查点
checkpoint = torch.load("checkpoint-100000.pt")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])

# 加载 EMA 检查点
ema_checkpoint = torch.load("ema-100000.pt")
# 使用 EMA1 参数进行推理
model.load_state_dict(ema_checkpoint['ema_state_dict1'])
# 或者使用 EMA2 参数
# model.load_state_dict(ema_checkpoint['ema_state_dict2'])

常见问题

Q: 为什么要使用 uv?

A: uv 是一个快速的 Python 包管理器,比 pip 快 10-100 倍。它提供了:

Q: 如何调整 batch size?

A: 修改配置文件中的 per_gpu_batch_size。如果 GPU 内存不足,可以减小这个值或使用梯度累积。

Q: 如何使用多个 GPU?

A: 使用 uv run 配合 Accelerate 启动训练脚本:

uv run accelerate launch --num_processes=8 utils/train_ditok.py

Q: 如何恢复训练?

A: 在 create_model_and_loss_module() 函数中添加从检查点加载的逻辑,或者修改配置文件的 output_dir 为之前的输出目录。

Q: 如何使用感知损失?

A: 在配置文件中设置:

Q: EMA 参数有什么用?

A: EMA 参数用于推理时,可以提高生成质量。训练时会自动更新 EMA,推理时加载 EMA 参数即可。

Q: 如何查看训练过程中的评估结果?

A: 使用 scripts/show_eval_results.py 脚本:

# 查看所有评估结果
uv run scripts/show_eval_results.py --output-dir output/ditok_l32_stage2_filelist

# 查看指标趋势
uv run scripts/show_eval_results.py --output-dir output/... --trend rFID

# 导出为 CSV
uv run scripts/show_eval_results.py --output-dir output/... --export results.csv

Q: 评估结果保存在哪里?

A: 评估结果保存在 output_dir/evaluation_history/ 目录下:

Q: 如何修改评估频率?

A: 在配置文件中修改 test_interval 参数:

training:
  test_interval: 3000  # 每 3000 步评估一次

Q: 评估结果包含哪些指标?

A: 评估结果包含以下指标:

性能优化建议

1. 混合精度训练:

使用 mixed_precision: "bf16" 可以:

2. 梯度累积:

如果 GPU 内存不足,可以使用梯度累积:

per_gpu_batch_size: 16        # 减小 batch size
gradient_accumulation_steps: 2  # 累积 2 步
# 有效 batch size = 16 * 2 * num_gpus

3. 数据加载优化:

增加 num_workers 可以加速数据加载:

num_workers: 8  # 根据CPU核心数调整

4. TF32 (Ampere GPU):

如果使用 Ampere GPU (A100, 3090等),可以启用 TF32:

enable_tf32: true

参考资料

相关脚本:

相关文档:

相关论文: