大模型 Attention 模块开发
已知大模型常用的 Attention 模块 定义如下:
此处考虑二维情况,其中:
- 所有输入矩阵 $X$ 初始化为 全 1 矩阵
- 所有权重矩阵 $W_1, W_2, W_3$ 初始化为 上三角全 1 矩阵
- softmax 简化为按行归一化:$\text{softmax}(M)_{ij} = \dfrac{M_{ij}}{\sum_j M_{ij}}$
输入 / 输出
输入:三个正整数 $n, m, h$(空格隔开,均小于 100)
输出:结果矩阵 $Y$ 所有元素之和,四舍五入保留整数
样例 1:$n=3,\; m=3,\; h=3$
输入:3 3 3
输出:18
$X$ 为 $3 \times 3$ 全 1 矩阵,$W$ 为 $3 \times 3$ 上三角全 1 矩阵:
计算 $Q = XW$(每行相同):
由于 $Q, K, V$ 三者相同,经过 softmax 归一化后 $Y$ 每行仍为 $(1, 2, 3)$,总和为 $3 \times (1+2+3) = 18$。
样例 2:$n=2,\; m=3,\; h=1$
输入:2 3 1
输出:2
$W$ 退化为 $3 \times 1$ 列向量 $(1, 0, 0)^T$,$Q = XW$ 每行只有一个元素 $1$。
$Y$ 每行为 $(1)$,总和为 $2 \times 1 = 2$。
这道题表面上是矩阵运算模拟,但因为有「全 1 输入 + 上三角权重」的特殊初始化,实际蕴含一条极其优雅的化简路径。
关键观察
观察 1:$X$ 全 1 意味着 $Q = XW$ 的每一行都相同。因此 $Q$ 的信息可以用一个长度为 $h$ 的向量 $\mathbf{q}$ 完全描述,其中 $q_j = \min(j+1,\; m)$($W$ 第 $j$ 列中 1 的个数,因为 $W_{ij}=1$ 当且仅当 $i \leq j$)。
观察 2:$W_1 = W_2 = W_3$,所以 $Q = K = V$,三者完全一致。
观察 3:$S = QK^T$ 的所有元素相同(因为 $Q$ 每行相同),导致 $\text{softmax}(S)$ 每个元素恰好为 $1/h$。
观察 4:最终 $Y = \text{softmax}(S) \cdot V$ 的每一行恰好就是 $\mathbf{q}$ 本身。
推导过程:
第一步:计算 $\mathbf{q}$ 向量
$W$ 是 $m \times h$ 的上三角全 1 矩阵,$W_{ij} = 1$ 当且仅当 $i \leq j$。因此第 $j$ 列的 1 的个数等于满足 $0 \leq i \leq \min(j, m-1)$ 的行数:
以 $m=4,\; h=5$ 为例,$W$ 是 $4 \times 5$ 的上三角全 1 矩阵:
$Q = XW$,而 $X$ 全 1,所以 $q_j$ 就是 $W$ 第 $j$ 列的列和。逐列数 1 的个数:
| 列 $j$ | 哪些行为 1 | 1 的个数 $q_j$ |
|---|---|---|
| 0 | 行 0 | 1 |
| 1 | 行 0, 1 | 2 |
| 2 | 行 0, 1, 2 | 3 |
| 3 | 行 0, 1, 2, 3 | 4 |
| 4 | 行 0, 1, 2, 3 | 4(被 $m$ 截断) |
列索引越大,从上往下覆盖的行越多($j+1$),但不会超过总行数 $m$,所以 $q_j = \min(j+1,\; m)$。
第二步:分析 $QK^T$
由于 $Q$ 的 $n$ 行完全相同(都是 $\mathbf{q}$),$K$ 也是,所以 $QK^T$ 的每个元素都是 $\|\mathbf{q}\|^2 = \sum_{k=0}^{h-1} q_k^2$。
第三步:简化 softmax
$QK^T / \sqrt{h}$ 的所有元素相同,设为 $c$。softmax 按行归一化:每行有 $h$ 个 $c$,所以每个元素变为 $c / (h \cdot c) = 1/h$。
第四步:计算 $Y$
因此 $Y$ 每行就是 $\mathbf{q}$,总和为 $n \times \sum_{j=0}^{h-1} q_j$。
基于上述分析,最终实现只需一行核心计算:
n, m, h = map(int, input().split())
ans = n * sum(min(j + 1, m) for j in range(h))
print(round(ans))
暴力模拟版(用于理解过程)
如果不想走捷径,直接按公式模拟也完全可行($n, m, h < 100$):
import math
n, m, h = map(int, input().split())
# 构造全 1 矩阵 X
X = [[1] * m for _ in range(n)]
# 构造上三角全 1 矩阵 W
W = [[1 if i <= j else 0 for j in range(h)] for i in range(m)]
# 矩阵乘法
def matmul(A, B):
r1, c1 = len(A), len(A[0])
c2 = len(B[0])
C = [[0.0] * c2 for _ in range(r1)]
for i in range(r1):
for k in range(c1):
for j in range(c2):
C[i][j] += A[i][k] * B[k][j]
return C
Q = matmul(X, W)
K = matmul(X, W)
V = matmul(X, W)
# QK^T
S = matmul(Q, [[K[j][i] for j in range(n)] for i in range(h)])
# 除以 sqrt(h)
for i in range(n):
for j in range(n):
S[i][j] /= math.sqrt(h)
# softmax
A = [[0.0] * n for _ in range(n)]
for i in range(n):
row_sum = sum(S[i])
for j in range(n):
A[i][j] = S[i][j] / row_sum
# Y = A @ V
Y = matmul(A, V)
print(round(sum(sum(row) for row in Y)))
| 方法 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| 暴力模拟 | $O(n^2 h + n m h)$ | $O(n^2)$ |
| 数学化简 | $O(h)$ | $O(1)$ |
在本题的数据范围($n, m, h < 100$)下,暴力模拟完全可以 AC。但数学化简的思路揭示了一个有趣的事实:当输入具有特殊结构时,Attention 的矩阵运算可以退化到近乎平凡的计算。
解题要点速查
- 核心公式:$\text{ans} = n \cdot \sum_{j=0}^{h-1} \min(j+1, m)$
- 关键洞察:全 1 输入 + 相同权重 → $Q = K = V$ → softmax 均匀分布 → $Y$ 每行等于 $\mathbf{q}$
- 注意:四舍五入用
round(),而非截断