ESC
输入关键词搜索文章
目录

大模型流水线并行训练优化

CodeFun2000 · P4982
显存约束下的最优切分:二分答案 + 贪心验证
7难度
6通过
华为来源
二分算法标签
Part 0
题目描述

在华为昇腾集群上训练千亿参数大模型时,单个 NPU 的 HBM 显存装不下整个模型,需要用流水线并行(Pipeline Parallelism)将模型按序切分为多个连续阶段,分配到不同 NPU 上。

给定 $N$ 层神经网络,需要划分到 $K$ 个 NPU 上,每个 NPU 至少分到 $1$ 层,且分配给同一个 NPU 的层必须在原始顺序中连续。第 $i$ 层的计算耗时为 $C[i]$,显存需求为 $W[i]$,每个 NPU 的显存上限为 $M$

形式化定义

$N$ 层按序切分为 $K$ 段连续子数组:$[1..p_1], [p_1{+}1..p_2], \ldots, [p_{K-1}{+}1..N]$

$j$ 段的计算耗时为 $T_j = \sum_{i \in \text{段 }j} C[i]$,显存为 $S_j = \sum_{i \in \text{段 }j} W[i]$

$$\min \max_{1 \le j \le K} T_j \quad \text{s.t.} \quad S_j \le M \;\;\forall j$$

输入 / 输出格式

输入:第一行 $N, K, M$;第二行 $N$ 个整数 $C[]$;第三行 $N$ 个整数 $W[]$

输出:最小的最大计算耗时。若 $K > N$ 或无合法方案,输出 $-1$

无解条件
  • $K > N$(无法保证每卡至少 1 层);
  • 任意单层显存 $W[i] > M$
  • 连续层的显存约束使得无法装进 $K$ 个 NPU。
  • Part 1
    样例分析

    样例 1:$K > N$,直接无解

    输入

    3 4 100
    1 1 1
    1 1 1

    输出-1

    解释:3 层网络分给 4 个 NPU,无法保证每个 NPU 至少 1 层,直接返回 $-1$

    样例 2:显存约束导致无解

    输入

    4 2 10
    2 2 2 2
    6 6 6 6

    输出-1

    解释:每层显存 $6$,两层合计 $12 > 10$,所以每层必须独占一个 NPU,需要 4 个。但只有 2 个 NPU,无解。

    样例 3:最优切分

    输入

    5 3 20
    5 1 2 3 4
    10 5 5 5 10

    输出6

    解释:最佳方案 $[\text{层}0,\text{层}1] \mid [\text{层}2,\text{层}3] \mid [\text{层}4]$

    • NPU1:耗时 $5{+}1=6$,显存 $10{+}5=15 \le 20$
    • NPU2:耗时 $2{+}3=5$,显存 $5{+}5=10 \le 20$
    • NPU3:耗时 $4$,显存 $10 \le 20$

    最大耗时 $6$,不存在更优的合法方案。

    Part 2
    核心思路

    这道题的本质是:在额外约束(每段显存不超过 $M$)下,将数组切分成 $K$ 段,最小化最大段和。这是经典的二分答案问题。

    关键观察

    观察 1(单调性):如果"最大耗时不超过 $T$"有解,那么 $T' > T$ 也一定有解。答案关于 $T$ 单调,可以二分。

    观察 2(贪心验证):给定 $T$,用贪心验证可行性:从左到右扫描,尽量往当前段塞层,当塞入下一层会超 $T$ 或超 $M$ 时就断开新起一段。如果最终段数 $\le K$,则 $T$ 可行。

    观察 3(下界):答案至少是 $\max_i C[i]$(最慢的单层耗时),至多是 $\sum_i C[i]$(所有层塞给一个 NPU)。这就是二分的搜索范围。

    算法框架:

    1. 预处理:检查 $K > N$ 或任意 $W[i] > M$,直接返回 $-1$
    2. 二分答案:$L = \max(C)$$R = \sum(C)$,每次取 $mid$ 进行贪心验证。
    3. 贪心验证:从第 1 层开始,累计计算耗时和显存。当加入下一层会超过 $mid$$M$ 时,断开新段。
    4. 若最终段数 $\le K$,说明 $mid$ 可行,尝试更小;否则需要更大。

    为什么贪心是对的?

    在固定 $T$ 的前提下,每段"尽量多塞"不会比"提前断开"更差——提前断开只会增加段数,让后续段更难安排。所以贪心地让每段尽可能长,段数最少,是最有希望满足 $\le K$ 的策略。

    Part 3
    贪心验证的详细逻辑

    给定阈值 $T$,贪心过程如下:

    维护当前段的计算耗时和 $t$、显存和 $s$、已用段数 $cnt$。初始 $t = s = 0$$cnt = 1$

    对每一层 $i$,判断能否加入当前段:

    $$\text{如果 } t + C[i] \le T \;\text{且}\; s + W[i] \le M \text{,则加入当前段}$$
    $$\text{否则,新起一段:}\; cnt \leftarrow cnt + 1,\; t = C[i],\; s = W[i]$$

    如果任何时刻 $C[i] > T$$W[i] > M$,说明单个层就超过了限制,$T$ 不可行。实际上 $C[i] > T$ 在我们的二分下界 $L = \max(C)$ 的设定下不会发生,但 $W[i] > M$ 需要在预处理中捕获。

    最终如果 $cnt \le K$,则 $T$ 是一个可行的最大耗时上限。

    二分边界:使用左闭右开 $[L, R)$ 的标准二分模板。循环条件 $L < R$$mid = \lfloor(L+R)/2\rfloor$。若 $mid$ 可行则 $R = mid$,否则 $L = mid + 1$。最终 $L$ 即为答案。
    Part 4
    代码实现

    算法步骤:

    1. 读入 $N, K, M$ 及数组 $C, W$
    2. $K > N$ 或存在 $W[i] > M$,输出 $-1$
    3. 二分 $L = \max(C)$$R = \sum(C) + 1$
    4. 贪心验证每个 $mid$
    5. 输出 $L$
    
    def solve():
        import sys
        data = sys.stdin.read().split()
        idx = 0
        N, K, M = int(data[idx]), int(data[idx+1]), int(data[idx+2])
        idx += 3
        C = [int(data[idx+i]) for i in range(N)]
        idx += N
        W = [int(data[idx+i]) for i in range(N)]
    
        # 无解判定
        if K > N:
            print(-1)
            return
        for i in range(N):
            if W[i] > M:
                print(-1)
                return
    
        # 贪心验证:给定阈值 T,最少需要几段
        def check(T):
            cnt = 1
            t = 0  # 当前段计算耗时和
            s = 0  # 当前段显存和
            for i in range(N):
                if t + C[i] <= T and s + W[i] <= M:
                    t += C[i]
                    s += W[i]
                else:
                    cnt += 1
                    t = C[i]
                    s = W[i]
            return cnt <= K
    
        # 二分答案:左闭右开 [L, R)
        L = max(C)
        R = sum(C) + 1
        while L < R:
            mid = (L + R) // 2
            if check(mid):
                R = mid
            else:
                L = mid + 1
    
        print(L)
    
    solve()
      
    关键细节$R$ 的初始值是 $\sum(C) + 1$ 而非 $\sum(C)$,因为使用左闭右开区间。如果二分结束时 $L = \sum(C)$,说明所有层放一个 NPU 仍是最优。
    Part 5
    复杂度分析
    维度复杂度说明
    时间$O(N \log(\sum C))$二分 $\log(\sum C)$ 轮,每轮贪心 $O(N)$ 扫描
    空间$O(N)$存储 $C$$W$ 数组,贪心验证只用常数额外空间

    对于本题的数据规模($N$ 可达 $10^5$ 级别,$C[i]$ 可达 $10^4$ 级别),$\log(\sum C) \approx 30$ 左右,总操作量约 $3 \times 10^6$,在 2 秒时限内绰绰有余。

    相关题目
    参考来源