模型训练,到底在做什么
一说到模型训练,大家脑子里冒出来的词通常不是同一层面的东西。有人先想到 loss,有人先想到数据,有人先想到 7B、70B 这样的参数规模,也有人会立刻问需要几张卡、多少显存、训练多久。它们都属于训练,但并不在回答同一个问题。
这也是为什么“训练”这个主题很容易越讲越乱。因为它天然横跨了好几层:你要先说明模型究竟在优化什么,再说明数据是怎样被喂进去的,然后还要解释训练量、显存、优化器、batch、工程技巧这些变量为什么会彼此牵连。只要其中任何两层被混在一起,讨论就会迅速从原理滑向口号。
这篇文章的写法
我会把“训练”拆成一条完整但尽量清晰的流水线来看:训练目标、训练数据、训练量、显存与资源、优化器与工程技巧。你给出的两道题,则会放在“训练量与显存”这一节里,作为具体的落点。
如果只保留一句最核心的话,那么大语言模型训练其实是在做一件非常朴素的事:给定前文,尽可能准确地预测下一个 token。对因果语言模型而言,这就是标准的 autoregressive next-token prediction。
形式上,模型试图学习一个条件概率分布:
也就是说,它并不是在背整段话,而是在序列的每一个位置上,根据已经看到的上下文,对下一个 token 赋予尽量高的概率。训练过程中真正被最小化的,通常就是 cross-entropy loss。你可以把它理解成:真实 token 已经给定,模型需要为这个真实答案分配尽量大的概率,否则就会被 loss 惩罚。
这一点非常关键,因为它决定了训练的底层逻辑并不神秘。模型能力不是凭空出现的,而是在海量位置上反复做这种“条件概率校正”。参数更新的本质,就是把原本分散、不稳定、充满误差的概率分布,一步步压向更符合真实语料分布的方向。
训练目标的最简描述
- 输入
- 一串 token 序列
- 目标
- 预测每个位置的下一个 token
- 损失函数
- 通常是 cross-entropy
- 参数更新
- 让真实 token 的概率更高,让错误分布被压低
模型并不直接看自然语言文本,而是先看被分词器切开的 token 序列。像 BPE 这样的分词方法,本质上是在字符与词之间寻找一个适合统计建模的中间粒度。它要可逆、能处理任意文本、还能尽量把高频子词稳定地复用出来。
所以训练数据的第一层,不是“有多少篇文章”,而是最后被切成了多少 token。而训练数据的第二层,则是这些 token 到底来自哪里。网页文本、书籍、代码、百科、论坛、问答,它们的混合比例会直接塑造模型的语感、知识分布与推理习惯。
这一步远不是把更多文本扔进去就结束了。数据质量至少包含三个重要问题:
- 清洗:去掉乱码、模板垃圾、低信息密度文本
- 去重:避免重复内容把模型推向记忆而不是泛化
- 配比:不同来源的 token 比例会改变模型学到的分布
去重尤其值得单独强调。已有研究指出,大规模语料里往往包含大量近重复样本和长重复片段。如果不做去重,模型不仅更容易背诵训练语料,还会浪费本该用于学习新模式的 token 预算。换句话说,脏 token 和重复 token 并不等于有效训练量。
到了这里,训练这个主题才会进入最常被讨论、也最容易被误解的一层:训练量。因为“训练量”这个词本身就可能在指三件不同的事。
第一种是参数量。7B、13B、70B 这些数字,回答的是模型本身有多大。第二种是token 数。1T、2T、14.8T tokens 这些数字,回答的是模型到底看了多少离散训练单位。第三种则是训练 compute,也就是为了把这些 token 灌进模型,总共花掉了多少浮点运算。
如果只想做一阶量级估算,decoder-only Transformer 的训练 FLOPs 常被写成:
这里的 $N$ 是参数量,$T$ 是训练 token 数。这个式子最重要的意义,不是常数 6 本身,而是它告诉你:训练 compute 的主导项,本质上就是“参数量 × token 数”。模型做大和训练更久,会一起把成本推高。
而 Chinchilla 的贡献,则是把这件事再往前推进一步。它指出,很多大模型并不是“规模不够”,而是 undertrained:参数扩上去了,但 token 没有按比例增加。在固定 compute budget 下,更合理的做法往往不是一味堆参数,而是让参数量和训练 token 数近似同步扩张。一个被广泛引用的经验总结,就是 1 个参数大约对应 20 个 token 的量级关系。
训练量的三层口径
- 模型规模
- 参数量 $N$
- 数据规模
- token 数 $T$
- 训练成本
- 粗略看 $6NT$ 对应的 compute
到这里可以先下一个结论:讨论训练量时,如果只说“模型有多大”,但不说 token 数;或者只说“训练了多少 token”,但不说参数量,其实都只说了一半。
训练一旦落到工程上,最先拦住你的通常不是理论,而是显存。因为训练时你保存的不只是参数,还包括梯度、优化器状态,以及为了反向传播而暂存的中间激活。
最容易产生错觉的地方在于,很多人一开始只会算参数本身。比如一个 7B 模型,FP16 每个参数 2 Byte,于是会得到:
这一步没有错,但它只是在算一份参数副本。训练时还要额外承担:
- FP32 主权重
- FP16 或 BF16 参数副本
- 梯度
- Adam 一阶矩
- Adam 二阶矩
- 激活值与临时缓冲
前五项基本都与参数量线性相关,最后一项激活值则与 micro-batch、sequence length、hidden size 强相关。也就是说,训练显存同时受两类变量支配:一类是“每个参数要背多少状态”,另一类是“每一步前向反向要暂存多少中间结果”。
为什么 sequence length 很可怕
上下文一拉长,注意力层的计算与内存都会迅速膨胀。像 FlashAttention 这类工作,本质上就是在针对这一瓶颈做 IO-aware 的重写:不是改变模型学什么,而是让同样的注意力计算少搬内存、少浪费显存。
例题 1
题目:设模型参数量 $N = 7B$,训练数据量 $T = 1T$ token,采用 FP16 混合精度训练(Adam 优化器)。以下关于显存占用的估算,说法正确的是:
- A. 使用梯度检查点(Gradient Checkpointing)会增加显存占用
- B. 训练时需保存 FP32 主权重、FP16 参数、FP16 梯度,以及 Adam 的一阶和二阶矩(FP32),总显存约 $7B \times (4 + 2 + 2 + 4 + 4)\text{ Byte} = 7B \times 16\text{ Byte} \approx 112\text{GB}$
- C. 仅模型参数占用显存 $= 7B \times 2\text{ Byte} = 14GB$(FP16),训练总显存约等于 14GB
- D. 训练显存仅取决于参数量,与 batch size 和序列长度无关
答案:B
这道题最想纠正的,就是“把一份参数副本误认为训练总显存”的直觉。14GB 只是一份 FP16 参数;而 Adam 训练时,你还要负担主权重、梯度以及两份优化器状态。题目给出的 112GB,是参数相关状态的粗略总账。
更重要的是,即便这个数估出来了,你也还没有把激活值完整算进去。所以训练一个 7B 模型时,真正的瓶颈常常不是“模型能不能放进去”,而是“你有没有足够空间完成一次 forward + backward”。
例题 2
题目:训练脚本支持 gradient accumulation,你把 micro_batch = 1 → 4,并把 accumulation 相应减小,global batch 保持不变。通常最可能出现的变化是( )
- A. 一定吞吐下降
- B. 参数量减少
- C. 显存占用上升
- D. 一定收敛更好
答案:C
这道题的关键在于区分 global batch 和 micro-batch。global batch 表示一次参数更新最终汇总了多少样本;micro-batch 表示单次真正落到卡上的那一小批样本有多大。
当 global batch 保持不变、但 micro-batch 从 1 变到 4 时,每次前向反向同时处理的样本数更多,于是要保留的激活值也更多。accumulation 减少,只是说明你少累积了几轮梯度;它并不能消除“单轮执行规模更大”这个事实。因此显存通常会上升。
训练不是只要有 loss、有数据就能自然完成。真正让模型动起来的是优化器,而真正让大模型跑得动的是一整套工程技巧。它们并不改变训练目标,却强烈改变训练的稳定性、速度和资源边界。
先说优化器。SGD、Adam、AdamW 的区别,表面上是更新公式不同,实质上则是:你希望参数在噪声梯度中怎样移动、怎样保留历史信息、怎样施加正则化。Adam 之所以常用,是因为它用一阶矩和二阶矩去自适应调节更新步幅;AdamW 则进一步把 weight decay 从梯度更新中解耦,使正则化行为更干净。
再说工程技巧。mixed precision 让你在大部分计算中使用低精度,从而节省显存与带宽,但关键状态仍保留高精度。gradient checkpointing 通过重算减少激活内存。FlashAttention 改写了注意力计算的 IO 路径,减少长序列下的显存浪费。gradient accumulation 则是在显存有限时,用多次小步来近似更大的更新 batch。
这些技巧之所以重要,是因为它们共同说明了一件事:训练并不只是一个数学问题,也是一个资源调度问题。很多时候,我们不是不知道该用什么模型,而是知道这个模型如果直接跑,根本落不到现有硬件上。
大 batch 为什么不是无限有利
batch 做大确实有利于并行吞吐,但并不是越大越好。已有研究用 gradient noise scale 来描述“最大有用 batch size”的尺度:超过某个范围以后,继续增大 batch 对时间效率可能有利,对数据效率却未必继续划算。
现在回过头看,“训练”这个大主题其实是一条连续链条。以下是每个环节的核心变量、关键结论和代表性来源。
| 环节 | 核心问题 | 关键变量 / 结论 | 代表性来源 |
|---|---|---|---|
| 训练目标 | 模型在每一个 token 位置学什么 | Next-token prediction · Cross-entropy loss | #Brown et al., 2020 |
| 训练数据 | token 从哪里来、质量如何 | BPE tokenizer · 去重(减少背诵)· 配比决定知识分布 | #Lee et al., 2022 · tiktoken |
| 训练量 | 投入了多少算力 | 参数量 $N$ · token 数 $T$ · $6NT$ · Chinchilla 1:20 | #Hoffmann et al., 2022 · #Kaplan et al., 2020 |
| 显存与资源 | 能否在硬件上成立 | 参数副本 · 梯度 · Adam 状态 · 激活值 · sequence length | NVIDIA mixed precision docs · HF perf docs |
| 优化器与工程技巧 | 能否稳定高效完成训练 | AdamW(解耦 weight decay)· mixed precision · checkpointing · FlashAttention · accumulation | #Loshchilov and Hutter, 2019 · #Dao et al., 2022 · #McCandlish et al., 2018 |
- 训练目标 决定模型在每个 token 位置究竟要学什么
- 训练数据 决定模型能从世界里看到什么、看到多少、看到得是否干净
- 训练量 决定你在参数、token 和 compute 上到底投入了多少
- 显存与资源 决定这件事能不能在现实硬件上成立
- 优化器与工程技巧 决定你能否稳定、高效地把整个过程跑完
很多围绕大模型训练的争论,之所以看起来各说各话,恰恰是因为它们落在这条链的不同位置。有人强调数据质量,有人强调 scaling law,有人强调显存瓶颈,有人强调优化器与 kernel 工程。其实他们不是在谈不同的世界,而是在谈同一条训练流水线的不同环节。
所以,如果以后再有人说“训练一个模型很难”,一个更具体的追问方式应该是:难在什么地方?是目标函数难定义,还是数据难构造?是 token 预算不够,还是显存放不下?是优化不稳定,还是工程吞吐太低?只有把问题放回这条链上,我们才知道真正该优化的是哪一段。
一页回忆
- 训练目标:next-token prediction + cross-entropy。
- 训练数据:token 化之后进入模型,质量取决于清洗、去重、配比。
- 训练量:至少包含参数量、token 数、训练 compute 三层。
- Chinchilla:固定 compute 下,参数量和 token 数应同步扩张;常见经验比约 1:20。
- 训练显存:参数副本、梯度、优化器状态、激活值共同组成。
- micro-batch:直接影响单次执行时的激活峰值。
- AdamW:把 weight decay 与自适应梯度更新解耦。
- 工程技巧:mixed precision、checkpointing、FlashAttention、gradient accumulation 都是在重写资源边界。
- #Brown et al., 2020 — Language Models are Few-Shot Learners
- #Kaplan et al., 2020 — Scaling Laws for Neural Language Models
- #Hoffmann et al., 2022 — Training Compute-Optimal Large Language Models
- #Loshchilov and Hutter, 2019 — Decoupled Weight Decay Regularization
- #Lee et al., 2022 — Deduplicating Training Data Makes Language Models Better
- #McCandlish et al., 2018 — An Empirical Model of Large-Batch Training
- #Dao et al., 2022 — FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- NVIDIA Docs — Train With Mixed Precision
- Hugging Face Transformers Docs — Efficient Training on a Single GPU
- Hugging Face Accelerate Docs — Performing gradient accumulation with Accelerate
- PyTorch Blog — Current and New Activation Checkpointing Techniques in PyTorch
- OpenAI — tiktoken
训练不是单个技巧,也不是单个数字。它更像一条从概率目标到硬件现实的长链:前面是分布建模,后面是资源约束,中间穿过数据、token、参数、优化器与各种工程折中。真正理解训练,不是背下几条经验结论,而是知道每一条结论落在这条链的哪一段上。