ESC
输入关键词搜索文章
目录

大模型 Attention 模块开发

CodeFun2000 · P3712
模拟简化版 Attention 矩阵运算,从 O(n²h) 到 O(m·h) 的极致优化
4难度
51%通过率
华为来源
模拟算法标签
Part 0
题目描述

已知大模型常用的 Attention 模块 定义如下:

$$Y = \text{softmax}\Bigl(\frac{QK^T}{\sqrt{h}}\Bigr)V$$

此处考虑二维情况,其中:

$$Q, K, V = XW_1,\; XW_2,\; XW_3 \in \mathbb{R}^{n \times h}$$
$$X \in \mathbb{R}^{n \times m}, \quad W_1, W_2, W_3 \in \mathbb{R}^{m \times h}$$
简化约定
  • 所有输入矩阵 $X$ 初始化为 全 1 矩阵
  • 所有权重矩阵 $W_1, W_2, W_3$ 初始化为 上三角全 1 矩阵
  • softmax 简化为按行归一化:$\text{softmax}(M)_{ij} = \dfrac{M_{ij}}{\sum_j M_{ij}}$

输入 / 输出

输入:三个正整数 $n, m, h$(空格隔开,均小于 100)

输出:结果矩阵 $Y$ 所有元素之和,四舍五入保留整数

Part 1
样例验证

样例 1:$n=3,\; m=3,\; h=3$

输入3 3 3

输出18

$X$$3 \times 3$ 全 1 矩阵,$W$$3 \times 3$ 上三角全 1 矩阵:

$$W = \begin{pmatrix} 1 & 1 & 1 \\ 0 & 1 & 1 \\ 0 & 0 & 1 \end{pmatrix}$$

计算 $Q = XW$(每行相同):

$$Q = K = V = \begin{pmatrix} 1 & 2 & 3 \\ 1 & 2 & 3 \\ 1 & 2 & 3 \end{pmatrix}$$

由于 $Q, K, V$ 三者相同,经过 softmax 归一化后 $Y$ 每行仍为 $(1, 2, 3)$,总和为 $3 \times (1+2+3) = 18$

样例 2:$n=2,\; m=3,\; h=1$

输入2 3 1

输出2

$W$ 退化为 $3 \times 1$ 列向量 $(1, 0, 0)^T$$Q = XW$ 每行只有一个元素 $1$

$Y$ 每行为 $(1)$,总和为 $2 \times 1 = 2$

Part 2
思路分析

这道题表面上是矩阵运算模拟,但因为有「全 1 输入 + 上三角权重」的特殊初始化,实际蕴含一条极其优雅的化简路径

关键观察

观察 1$X$ 全 1 意味着 $Q = XW$ 的每一行都相同。因此 $Q$ 的信息可以用一个长度为 $h$ 的向量 $\mathbf{q}$ 完全描述,其中 $q_j = \min(j+1,\; m)$$W$$j$ 列中 1 的个数,因为 $W_{ij}=1$ 当且仅当 $i \leq j$)。

观察 2$W_1 = W_2 = W_3$,所以 $Q = K = V$,三者完全一致。

观察 3$S = QK^T$ 的所有元素相同(因为 $Q$ 每行相同),导致 $\text{softmax}(S)$ 每个元素恰好为 $1/h$

观察 4:最终 $Y = \text{softmax}(S) \cdot V$ 的每一行恰好就是 $\mathbf{q}$ 本身。

推导过程:

第一步:计算 $\mathbf{q}$ 向量

$W$$m \times h$ 的上三角全 1 矩阵,$W_{ij} = 1$ 当且仅当 $i \leq j$。因此第 $j$ 列的 1 的个数等于满足 $0 \leq i \leq \min(j, m-1)$ 的行数:

$$q_j = \min(j+1,\; m)$$

$m=4,\; h=5$ 为例,$W$$4 \times 5$ 的上三角全 1 矩阵:

$$W = \begin{pmatrix} 1 & 1 & 1 & 1 & 1 \\ 0 & 1 & 1 & 1 & 1 \\ 0 & 0 & 1 & 1 & 1 \\ 0 & 0 & 0 & 1 & 1 \end{pmatrix}$$

$Q = XW$,而 $X$ 全 1,所以 $q_j$ 就是 $W$$j$ 列的列和。逐列数 1 的个数:

$j$哪些行为 11 的个数 $q_j$
0行 01
1行 0, 12
2行 0, 1, 23
3行 0, 1, 2, 34
4行 0, 1, 2, 34(被 $m$ 截断)

列索引越大,从上往下覆盖的行越多($j+1$),但不会超过总行数 $m$,所以 $q_j = \min(j+1,\; m)$

第二步:分析 $QK^T$

由于 $Q$$n$ 行完全相同(都是 $\mathbf{q}$),$K$ 也是,所以 $QK^T$ 的每个元素都是 $\|\mathbf{q}\|^2 = \sum_{k=0}^{h-1} q_k^2$

第三步:简化 softmax

$QK^T / \sqrt{h}$ 的所有元素相同,设为 $c$。softmax 按行归一化:每行有 $h$$c$,所以每个元素变为 $c / (h \cdot c) = 1/h$

第四步:计算 $Y$

$$Y_{ij} = \sum_{k=0}^{h-1} \frac{1}{h} \cdot q_j = q_j$$

因此 $Y$ 每行就是 $\mathbf{q}$,总和为 $n \times \sum_{j=0}^{h-1} q_j$

结论:最终答案就是 $n \times \sum_{j=0}^{h-1} \min(j+1,\; m)$,四舍五入取整。
Part 3
代码实现

基于上述分析,最终实现只需一行核心计算:


n, m, h = map(int, input().split())
ans = n * sum(min(j + 1, m) for j in range(h))
print(round(ans))
  

暴力模拟版(用于理解过程)

如果不想走捷径,直接按公式模拟也完全可行($n, m, h < 100$):


import math

n, m, h = map(int, input().split())

# 构造全 1 矩阵 X
X = [[1] * m for _ in range(n)]
# 构造上三角全 1 矩阵 W
W = [[1 if i <= j else 0 for j in range(h)] for i in range(m)]

# 矩阵乘法
def matmul(A, B):
    r1, c1 = len(A), len(A[0])
    c2 = len(B[0])
    C = [[0.0] * c2 for _ in range(r1)]
    for i in range(r1):
        for k in range(c1):
            for j in range(c2):
                C[i][j] += A[i][k] * B[k][j]
    return C

Q = matmul(X, W)
K = matmul(X, W)
V = matmul(X, W)

# QK^T
S = matmul(Q, [[K[j][i] for j in range(n)] for i in range(h)])
# 除以 sqrt(h)
for i in range(n):
    for j in range(n):
        S[i][j] /= math.sqrt(h)

# softmax
A = [[0.0] * n for _ in range(n)]
for i in range(n):
    row_sum = sum(S[i])
    for j in range(n):
        A[i][j] = S[i][j] / row_sum

# Y = A @ V
Y = matmul(A, V)
print(round(sum(sum(row) for row in Y)))
      
Part 4
复杂度分析
方法时间复杂度空间复杂度
暴力模拟$O(n^2 h + n m h)$$O(n^2)$
数学化简$O(h)$$O(1)$

在本题的数据范围($n, m, h < 100$)下,暴力模拟完全可以 AC。但数学化简的思路揭示了一个有趣的事实:当输入具有特殊结构时,Attention 的矩阵运算可以退化到近乎平凡的计算

解题要点速查

  • 核心公式$\text{ans} = n \cdot \sum_{j=0}^{h-1} \min(j+1, m)$
  • 关键洞察:全 1 输入 + 相同权重 → $Q = K = V$ → softmax 均匀分布 → $Y$ 每行等于 $\mathbf{q}$
  • 注意:四舍五入用 round(),而非截断