大模型流水线并行训练优化
在华为昇腾集群上训练千亿参数大模型时,单个 NPU 的 HBM 显存装不下整个模型,需要用流水线并行(Pipeline Parallelism)将模型按序切分为多个连续阶段,分配到不同 NPU 上。
给定 $N$ 层神经网络,需要划分到 $K$ 个 NPU 上,每个 NPU 至少分到 $1$ 层,且分配给同一个 NPU 的层必须在原始顺序中连续。第 $i$ 层的计算耗时为 $C[i]$,显存需求为 $W[i]$,每个 NPU 的显存上限为 $M$。
形式化定义
将 $N$ 层按序切分为 $K$ 段连续子数组:$[1..p_1], [p_1{+}1..p_2], \ldots, [p_{K-1}{+}1..N]$。
第 $j$ 段的计算耗时为 $T_j = \sum_{i \in \text{段 }j} C[i]$,显存为 $S_j = \sum_{i \in \text{段 }j} W[i]$。
输入 / 输出格式
输入:第一行 $N, K, M$;第二行 $N$ 个整数 $C[]$;第三行 $N$ 个整数 $W[]$。
输出:最小的最大计算耗时。若 $K > N$ 或无合法方案,输出 $-1$。
样例 1:$K > N$,直接无解
输入:
3 4 100
1 1 1
1 1 1
输出:-1
解释:3 层网络分给 4 个 NPU,无法保证每个 NPU 至少 1 层,直接返回 $-1$。
样例 2:显存约束导致无解
输入:
4 2 10
2 2 2 2
6 6 6 6
输出:-1
解释:每层显存 $6$,两层合计 $12 > 10$,所以每层必须独占一个 NPU,需要 4 个。但只有 2 个 NPU,无解。
样例 3:最优切分
输入:
5 3 20
5 1 2 3 4
10 5 5 5 10
输出:6
解释:最佳方案 $[\text{层}0,\text{层}1] \mid [\text{层}2,\text{层}3] \mid [\text{层}4]$
- NPU1:耗时 $5{+}1=6$,显存 $10{+}5=15 \le 20$
- NPU2:耗时 $2{+}3=5$,显存 $5{+}5=10 \le 20$
- NPU3:耗时 $4$,显存 $10 \le 20$
最大耗时 $6$,不存在更优的合法方案。
这道题的本质是:在额外约束(每段显存不超过 $M$)下,将数组切分成 $K$ 段,最小化最大段和。这是经典的二分答案问题。
关键观察
观察 1(单调性):如果"最大耗时不超过 $T$"有解,那么 $T' > T$ 也一定有解。答案关于 $T$ 单调,可以二分。
观察 2(贪心验证):给定 $T$,用贪心验证可行性:从左到右扫描,尽量往当前段塞层,当塞入下一层会超 $T$ 或超 $M$ 时就断开新起一段。如果最终段数 $\le K$,则 $T$ 可行。
观察 3(下界):答案至少是 $\max_i C[i]$(最慢的单层耗时),至多是 $\sum_i C[i]$(所有层塞给一个 NPU)。这就是二分的搜索范围。
算法框架:
- 预处理:检查 $K > N$ 或任意 $W[i] > M$,直接返回 $-1$。
- 二分答案:$L = \max(C)$,$R = \sum(C)$,每次取 $mid$ 进行贪心验证。
- 贪心验证:从第 1 层开始,累计计算耗时和显存。当加入下一层会超过 $mid$ 或 $M$ 时,断开新段。
- 若最终段数 $\le K$,说明 $mid$ 可行,尝试更小;否则需要更大。
为什么贪心是对的?
在固定 $T$ 的前提下,每段"尽量多塞"不会比"提前断开"更差——提前断开只会增加段数,让后续段更难安排。所以贪心地让每段尽可能长,段数最少,是最有希望满足 $\le K$ 的策略。
给定阈值 $T$,贪心过程如下:
维护当前段的计算耗时和 $t$、显存和 $s$、已用段数 $cnt$。初始 $t = s = 0$,$cnt = 1$。
对每一层 $i$,判断能否加入当前段:
如果任何时刻 $C[i] > T$ 或 $W[i] > M$,说明单个层就超过了限制,$T$ 不可行。实际上 $C[i] > T$ 在我们的二分下界 $L = \max(C)$ 的设定下不会发生,但 $W[i] > M$ 需要在预处理中捕获。
最终如果 $cnt \le K$,则 $T$ 是一个可行的最大耗时上限。
算法步骤:
- 读入 $N, K, M$ 及数组 $C, W$。
- 若 $K > N$ 或存在 $W[i] > M$,输出 $-1$。
- 二分 $L = \max(C)$,$R = \sum(C) + 1$。
- 贪心验证每个 $mid$。
- 输出 $L$。
def solve():
import sys
data = sys.stdin.read().split()
idx = 0
N, K, M = int(data[idx]), int(data[idx+1]), int(data[idx+2])
idx += 3
C = [int(data[idx+i]) for i in range(N)]
idx += N
W = [int(data[idx+i]) for i in range(N)]
# 无解判定
if K > N:
print(-1)
return
for i in range(N):
if W[i] > M:
print(-1)
return
# 贪心验证:给定阈值 T,最少需要几段
def check(T):
cnt = 1
t = 0 # 当前段计算耗时和
s = 0 # 当前段显存和
for i in range(N):
if t + C[i] <= T and s + W[i] <= M:
t += C[i]
s += W[i]
else:
cnt += 1
t = C[i]
s = W[i]
return cnt <= K
# 二分答案:左闭右开 [L, R)
L = max(C)
R = sum(C) + 1
while L < R:
mid = (L + R) // 2
if check(mid):
R = mid
else:
L = mid + 1
print(L)
solve()
| 维度 | 复杂度 | 说明 |
|---|---|---|
| 时间 | $O(N \log(\sum C))$ | 二分 $\log(\sum C)$ 轮,每轮贪心 $O(N)$ 扫描 |
| 空间 | $O(N)$ | 存储 $C$ 和 $W$ 数组,贪心验证只用常数额外空间 |
对于本题的数据规模($N$ 可达 $10^5$ 级别,$C[i]$ 可达 $10^4$ 级别),$\log(\sum C) \approx 30$ 左右,总操作量约 $3 \times 10^6$,在 2 秒时限内绰绰有余。