ESC
输入关键词搜索文章
目录

模型训练,到底在做什么

目标函数 · 数据 · 训练量 · 显存 · 优化器
把训练这条流水线拆开以后,很多看似神秘的工程问题都会变成一笔一笔可计算的账
5核心章节
2例题
6NT粗略训练 FLOPs
1:20Chinchilla 经验比
Part 0
训练这件事,为什么总是被讲得又大又散

一说到模型训练,大家脑子里冒出来的词通常不是同一层面的东西。有人先想到 loss,有人先想到数据,有人先想到 7B、70B 这样的参数规模,也有人会立刻问需要几张卡、多少显存、训练多久。它们都属于训练,但并不在回答同一个问题。

这也是为什么“训练”这个主题很容易越讲越乱。因为它天然横跨了好几层:你要先说明模型究竟在优化什么,再说明数据是怎样被喂进去的,然后还要解释训练量、显存、优化器、batch、工程技巧这些变量为什么会彼此牵连。只要其中任何两层被混在一起,讨论就会迅速从原理滑向口号。

这篇文章的写法

我会把“训练”拆成一条完整但尽量清晰的流水线来看:训练目标训练数据训练量显存与资源优化器与工程技巧。你给出的两道题,则会放在“训练量与显存”这一节里,作为具体的落点。

Chapter 1
训练的第一层:模型到底在优化什么

如果只保留一句最核心的话,那么大语言模型训练其实是在做一件非常朴素的事:给定前文,尽可能准确地预测下一个 token。对因果语言模型而言,这就是标准的 autoregressive next-token prediction。

形式上,模型试图学习一个条件概率分布:

$$p(x_1,x_2,\ldots,x_T)=\prod_{t=1}^{T} p(x_t \mid x_1,x_2,\ldots,x_{t-1})$$

也就是说,它并不是在背整段话,而是在序列的每一个位置上,根据已经看到的上下文,对下一个 token 赋予尽量高的概率。训练过程中真正被最小化的,通常就是 cross-entropy loss。你可以把它理解成:真实 token 已经给定,模型需要为这个真实答案分配尽量大的概率,否则就会被 loss 惩罚。

这一点非常关键,因为它决定了训练的底层逻辑并不神秘。模型能力不是凭空出现的,而是在海量位置上反复做这种“条件概率校正”。参数更新的本质,就是把原本分散、不稳定、充满误差的概率分布,一步步压向更符合真实语料分布的方向。

训练目标的最简描述

  • 输入
    一串 token 序列
  • 目标
    预测每个位置的下一个 token
  • 损失函数
    通常是 cross-entropy
  • 参数更新
    让真实 token 的概率更高,让错误分布被压低
Chapter 2
训练的第二层:数据并不是“文本很多”这么简单

模型并不直接看自然语言文本,而是先看被分词器切开的 token 序列。像 BPE 这样的分词方法,本质上是在字符与词之间寻找一个适合统计建模的中间粒度。它要可逆、能处理任意文本、还能尽量把高频子词稳定地复用出来。

所以训练数据的第一层,不是“有多少篇文章”,而是最后被切成了多少 token。而训练数据的第二层,则是这些 token 到底来自哪里。网页文本、书籍、代码、百科、论坛、问答,它们的混合比例会直接塑造模型的语感、知识分布与推理习惯。

这一步远不是把更多文本扔进去就结束了。数据质量至少包含三个重要问题:

  1. 清洗:去掉乱码、模板垃圾、低信息密度文本
  2. 去重:避免重复内容把模型推向记忆而不是泛化
  3. 配比:不同来源的 token 比例会改变模型学到的分布

去重尤其值得单独强调。已有研究指出,大规模语料里往往包含大量近重复样本和长重复片段。如果不做去重,模型不仅更容易背诵训练语料,还会浪费本该用于学习新模式的 token 预算。换句话说,脏 token 和重复 token 并不等于有效训练量

一个容易忽略的事实:数据量和有效信息量不是同义词。训练看起来吃了很多 token,不代表它真的看到了很多新的结构。
Chapter 3
训练的第三层:所谓“训练量”,到底在说哪一笔账

到了这里,训练这个主题才会进入最常被讨论、也最容易被误解的一层:训练量。因为“训练量”这个词本身就可能在指三件不同的事。

第一种是参数量。7B、13B、70B 这些数字,回答的是模型本身有多大。第二种是token 数。1T、2T、14.8T tokens 这些数字,回答的是模型到底看了多少离散训练单位。第三种则是训练 compute,也就是为了把这些 token 灌进模型,总共花掉了多少浮点运算。

如果只想做一阶量级估算,decoder-only Transformer 的训练 FLOPs 常被写成:

$$\text{Training FLOPs} \approx 6NT$$

这里的 $N$ 是参数量,$T$ 是训练 token 数。这个式子最重要的意义,不是常数 6 本身,而是它告诉你:训练 compute 的主导项,本质上就是“参数量 × token 数”。模型做大和训练更久,会一起把成本推高。

而 Chinchilla 的贡献,则是把这件事再往前推进一步。它指出,很多大模型并不是“规模不够”,而是 undertrained:参数扩上去了,但 token 没有按比例增加。在固定 compute budget 下,更合理的做法往往不是一味堆参数,而是让参数量和训练 token 数近似同步扩张。一个被广泛引用的经验总结,就是 1 个参数大约对应 20 个 token 的量级关系。

训练量的三层口径

  • 模型规模
    参数量 $N$
  • 数据规模
    token 数 $T$
  • 训练成本
    粗略看 $6NT$ 对应的 compute

到这里可以先下一个结论:讨论训练量时,如果只说“模型有多大”,但不说 token 数;或者只说“训练了多少 token”,但不说参数量,其实都只说了一半。

Chapter 4
训练的第四层:显存为什么总是第一个爆炸

训练一旦落到工程上,最先拦住你的通常不是理论,而是显存。因为训练时你保存的不只是参数,还包括梯度、优化器状态,以及为了反向传播而暂存的中间激活。

最容易产生错觉的地方在于,很多人一开始只会算参数本身。比如一个 7B 模型,FP16 每个参数 2 Byte,于是会得到:

$$7B \times 2\text{ Byte} \approx 14\text{GB}$$

这一步没有错,但它只是在算一份参数副本。训练时还要额外承担:

  1. FP32 主权重
  2. FP16 或 BF16 参数副本
  3. 梯度
  4. Adam 一阶矩
  5. Adam 二阶矩
  6. 激活值与临时缓冲

前五项基本都与参数量线性相关,最后一项激活值则与 micro-batch、sequence length、hidden size 强相关。也就是说,训练显存同时受两类变量支配:一类是“每个参数要背多少状态”,另一类是“每一步前向反向要暂存多少中间结果”。

为什么 sequence length 很可怕

上下文一拉长,注意力层的计算与内存都会迅速膨胀。像 FlashAttention 这类工作,本质上就是在针对这一瓶颈做 IO-aware 的重写:不是改变模型学什么,而是让同样的注意力计算少搬内存、少浪费显存。

Chapter 4.1
例题一:7B 模型为什么远不止 14GB

例题 1

题目:设模型参数量 $N = 7B$,训练数据量 $T = 1T$ token,采用 FP16 混合精度训练(Adam 优化器)。以下关于显存占用的估算,说法正确的是:

  1. A. 使用梯度检查点(Gradient Checkpointing)会增加显存占用
  2. B. 训练时需保存 FP32 主权重、FP16 参数、FP16 梯度,以及 Adam 的一阶和二阶矩(FP32),总显存约 $7B \times (4 + 2 + 2 + 4 + 4)\text{ Byte} = 7B \times 16\text{ Byte} \approx 112\text{GB}$
  3. C. 仅模型参数占用显存 $= 7B \times 2\text{ Byte} = 14GB$(FP16),训练总显存约等于 14GB
  4. D. 训练显存仅取决于参数量,与 batch size 和序列长度无关

答案:B

这道题最想纠正的,就是“把一份参数副本误认为训练总显存”的直觉。14GB 只是一份 FP16 参数;而 Adam 训练时,你还要负担主权重、梯度以及两份优化器状态。题目给出的 112GB,是参数相关状态的粗略总账。

更重要的是,即便这个数估出来了,你也还没有把激活值完整算进去。所以训练一个 7B 模型时,真正的瓶颈常常不是“模型能不能放进去”,而是“你有没有足够空间完成一次 forward + backward”。

易错点:参数显存 ≠ 训练显存;checkpointing 也不会加显存,它是用重算去换更低的激活占用。
Chapter 4.2
例题二:global batch 不变,为什么显存还是会上升

例题 2

题目:训练脚本支持 gradient accumulation,你把 micro_batch = 1 → 4,并把 accumulation 相应减小,global batch 保持不变。通常最可能出现的变化是( )

  1. A. 一定吞吐下降
  2. B. 参数量减少
  3. C. 显存占用上升
  4. D. 一定收敛更好

答案:C

这道题的关键在于区分 global batchmicro-batch。global batch 表示一次参数更新最终汇总了多少样本;micro-batch 表示单次真正落到卡上的那一小批样本有多大。

当 global batch 保持不变、但 micro-batch 从 1 变到 4 时,每次前向反向同时处理的样本数更多,于是要保留的激活值也更多。accumulation 减少,只是说明你少累积了几轮梯度;它并不能消除“单轮执行规模更大”这个事实。因此显存通常会上升。

一句话记忆:gradient accumulation 的意义,是用更多小步去模拟本来放不下的大 batch;反过来把 micro-batch 调大,本质上就是把显存压力重新拉高。
Chapter 5
训练的第五层:优化器与工程技巧,为什么会决定你能不能把想法跑起来

训练不是只要有 loss、有数据就能自然完成。真正让模型动起来的是优化器,而真正让大模型跑得动的是一整套工程技巧。它们并不改变训练目标,却强烈改变训练的稳定性、速度和资源边界。

先说优化器。SGD、Adam、AdamW 的区别,表面上是更新公式不同,实质上则是:你希望参数在噪声梯度中怎样移动、怎样保留历史信息、怎样施加正则化。Adam 之所以常用,是因为它用一阶矩和二阶矩去自适应调节更新步幅;AdamW 则进一步把 weight decay 从梯度更新中解耦,使正则化行为更干净。

再说工程技巧。mixed precision 让你在大部分计算中使用低精度,从而节省显存与带宽,但关键状态仍保留高精度。gradient checkpointing 通过重算减少激活内存。FlashAttention 改写了注意力计算的 IO 路径,减少长序列下的显存浪费。gradient accumulation 则是在显存有限时,用多次小步来近似更大的更新 batch。

这些技巧之所以重要,是因为它们共同说明了一件事:训练并不只是一个数学问题,也是一个资源调度问题。很多时候,我们不是不知道该用什么模型,而是知道这个模型如果直接跑,根本落不到现有硬件上。

大 batch 为什么不是无限有利

batch 做大确实有利于并行吞吐,但并不是越大越好。已有研究用 gradient noise scale 来描述“最大有用 batch size”的尺度:超过某个范围以后,继续增大 batch 对时间效率可能有利,对数据效率却未必继续划算。

Chapter 6
把整条训练链放回一起看

现在回过头看,“训练”这个大主题其实是一条连续链条。以下是每个环节的核心变量、关键结论和代表性来源。

环节 核心问题 关键变量 / 结论 代表性来源
训练目标 模型在每一个 token 位置学什么 Next-token prediction · Cross-entropy loss #Brown et al., 2020
训练数据 token 从哪里来、质量如何 BPE tokenizer · 去重(减少背诵)· 配比决定知识分布 #Lee et al., 2022 · tiktoken
训练量 投入了多少算力 参数量 $N$ · token 数 $T$ · $6NT$ · Chinchilla 1:20 #Hoffmann et al., 2022 · #Kaplan et al., 2020
显存与资源 能否在硬件上成立 参数副本 · 梯度 · Adam 状态 · 激活值 · sequence length NVIDIA mixed precision docs · HF perf docs
优化器与工程技巧 能否稳定高效完成训练 AdamW(解耦 weight decay)· mixed precision · checkpointing · FlashAttention · accumulation #Loshchilov and Hutter, 2019 · #Dao et al., 2022 · #McCandlish et al., 2018
  • 训练目标 决定模型在每个 token 位置究竟要学什么
  • 训练数据 决定模型能从世界里看到什么、看到多少、看到得是否干净
  • 训练量 决定你在参数、token 和 compute 上到底投入了多少
  • 显存与资源 决定这件事能不能在现实硬件上成立
  • 优化器与工程技巧 决定你能否稳定、高效地把整个过程跑完

很多围绕大模型训练的争论,之所以看起来各说各话,恰恰是因为它们落在这条链的不同位置。有人强调数据质量,有人强调 scaling law,有人强调显存瓶颈,有人强调优化器与 kernel 工程。其实他们不是在谈不同的世界,而是在谈同一条训练流水线的不同环节。

所以,如果以后再有人说“训练一个模型很难”,一个更具体的追问方式应该是:难在什么地方?是目标函数难定义,还是数据难构造?是 token 预算不够,还是显存放不下?是优化不稳定,还是工程吞吐太低?只有把问题放回这条链上,我们才知道真正该优化的是哪一段。

Review
复习速查

一页回忆

  • 训练目标:next-token prediction + cross-entropy。
  • 训练数据:token 化之后进入模型,质量取决于清洗、去重、配比。
  • 训练量:至少包含参数量、token 数、训练 compute 三层。
  • Chinchilla:固定 compute 下,参数量和 token 数应同步扩张;常见经验比约 1:20。
  • 训练显存:参数副本、梯度、优化器状态、激活值共同组成。
  • micro-batch:直接影响单次执行时的激活峰值。
  • AdamW:把 weight decay 与自适应梯度更新解耦。
  • 工程技巧:mixed precision、checkpointing、FlashAttention、gradient accumulation 都是在重写资源边界。
结语

训练不是单个技巧,也不是单个数字。它更像一条从概率目标到硬件现实的长链:前面是分布建模,后面是资源约束,中间穿过数据、token、参数、优化器与各种工程折中。真正理解训练,不是背下几条经验结论,而是知道每一条结论落在这条链的哪一段上。