模型层
池化层
激活函数
- 引入非线性变换
- 常用函数:ReLU、Sigmoid、Tanh、GELU
- 影响模型学习复杂模式的能力 一文搞懂激活函数(Sigmoid/ReLU/LeakyReLU/PReLU/ELU) - 知乎
ReLU
ReLU(Rectified Linear Unit)是一种常用的激活函数,定义为:
特点:
- 计算简单:仅需比较和取最大值,无指数运算
- 缓解梯度消失:正区间梯度恒为1,避免深层网络梯度消失
- 稀疏激活:负值输出为零,使网络更高效
变体:
- Leaky ReLU:\( f(x) = \max(\alpha x, x) \)(\(\alpha\) 为小斜率,如0.01),解决“神经元死亡”问题
- Parametric ReLU (PReLU):将 \(\alpha\) 作为可学习参数
- ELU:\( f(x) = \begin{cases} x & x \geq 0 \\ \alpha(e^x - 1) & x < 0 \end{cases} \),改善负值区表现
适用场景:
- 大多数前馈神经网络和卷积网络的隐藏层
- 需注意初始化和学习率,避免大量负输入导致神经元失活
Tanh
Tanh(双曲正切)激活函数:
/特点/:
- 输出范围 \((-1, 1)\),均值为0,有助于中心化数据
- 相比Sigmoid,梯度更强(导数范围 \(0 \to 1\))
- 仍存在梯度消失问题(饱和区梯度接近0)
/导数/:
/适用场景/:
RNN、LSTM等循环网络的隐藏层
需要输出有正负区分的场景
逐渐被ReLU及其变体替代,但在特定架构中仍有价值
用 python 画出 tanh 的双曲正切图
import numpy as np
import matplotlib.pyplot as plt
# 生成数据
x = np.linspace(-5, 5, 1000)
y = np.tanh(x)
# 绘制图像
plt.figure(figsize=(8, 6))
plt.plot(x, y, 'b-', linewidth=2, label='tanh(x)')
plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
plt.axvline(x=0, color='k', linestyle='-', alpha=0.3)
plt.grid(True, alpha=0.3)
plt.xlabel('x')
plt.ylabel('tanh(x)')
plt.title('Hyperbolic Tangent Function (tanh)')
plt.legend()
plt.axis([-5, 5, -1.2, 1.2])
fname = "images/tanh.png"
plt.savefig(fname)
fname
输出说明:
- 曲线呈S形,在(-∞, ∞)上连续可微
- 当x→-∞时,tanh(x)→-1
- 当x→+∞时,tanh(x)→1
- 在x=0处,tanh(0)=0,导数为1
sigmoid
激活函数,将输入压缩到 (0,1):
- 输出可解释为概率
- 梯度平滑,但易饱和导致梯度消失
- 常用于二分类输出层
sigmoid 函数的 python 表示
import numpy as np
import matplotlib.pyplot as plt
def sigmoid(x):
"""Sigmoid激活函数"""
return 1 / (1 + np.exp(-x))
# 示例使用
x = np.linspace(-10, 10, 1000)
y = sigmoid(x)
# 绘制图像
plt.figure(figsize=(8, 6))
plt.plot(x, y, 'r-', linewidth=2, label='sigmoid(x)')
plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
plt.axhline(y=1, color='k', linestyle='-', alpha=0.3)
plt.axvline(x=0, color='k', linestyle='-', alpha=0.3)
plt.grid(True, alpha=0.3)
plt.xlabel('x')
plt.ylabel('sigmoid(x)')
plt.title('Sigmoid Activation Function')
plt.legend()
plt.axis([-10, 10, -0.1, 1.1])
plt.show()
# 输出特性
print(f"sigmoid(0) = {sigmoid(0):.3f}") # 输出: 0.5
print(f"sigmoid(5) = {sigmoid(5):.3f}") # 输出: 0.993
print(f"sigmoid(-5) = {sigmoid(-5):.3f}") # 输出: 0.007
关键特性:
- 输出范围:(0, 1)
- 中心点:sigmoid(0) = 0.5
- 单调递增
- 导数:σ'(x) = σ(x)(1-σ(x))
Softmax
多分类激活函数,将向量转换为概率分布:
- 输出和为 1,适用于多分类
- 与交叉熵损失配合使用
- 对输入尺度敏感,需注意数值稳定性
softmax python 实现
import numpy as np
import matplotlib.pyplot as plt
def softmax(x):
"""Softmax activation function with numerical stability"""
exp_x = np.exp(x - np.max(x))
return exp_x / np.sum(exp_x)
# Create sample data
x = np.linspace(-5, 5, 100)
scores = np.array([x, np.zeros_like(x), -x]).T # Three classes
# Apply softmax
probabilities = np.array([softmax(score) for score in scores])
# Plot results
plt.figure(figsize=(10, 6))
for i in range(3):
plt.plot(x, probabilities[:, i], label=f'Class {i+1}', linewidth=2)
plt.xlabel('Input Score Difference')
plt.ylabel('Probability')
plt.title('Softmax Function Visualization')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
fname = "images/softmax.png"
plt.savefig(fname)
plt.close()
fname
# Additional visualization: Single input case
single_scores = np.array([3.0, 1.0, 0.2])
single_probs = softmax(single_scores)
plt.figure(figsize=(8, 5))
plt.bar(range(len(single_scores)), single_probs, color=['skyblue', 'lightcoral', 'lightgreen'])
plt.xticks(range(len(single_scores)), [f'Class {i+1}' for i in range(len(single_scores))])
plt.ylabel('Probability')
plt.title('Softmax Output for Single Input Example')
plt.ylim(0, 1)
for i, prob in enumerate(single_probs):
plt.text(i, prob + 0.02, f'{prob:.3f}', ha='center', va='bottom')
plt.tight_layout()
fname = "images/softmax_output_for_single_input_example.png"
plt.savefig(fname)
plt.close()
fname
每行概率和: [1. 1. 1.]
关键特性:
- 输出范围:(0, 1),总和为1
- 保持原始排序:较大输入对应较大概率
- 数值稳定:减去最大值避免指数溢出
- 常用于多分类最后一层,配合交叉熵损失
SiLU
SiLU(Sigmoid Linear Unit)激活函数,也称为Swish激活函数,是一种平滑、非单调的激活 函数,在深度学习中表现出优越的性能。
数学定义
SiLU激活函数定义为:
其中 \( \sigma(x) \) 是sigmoid函数。
性质分析
- /平滑性/:SiLU是无限可微的平滑函数
- /非单调性/:当 \( x < 0 \) 时,函数值可能为负
- /下界无界/:当 \( x \to -\infty \),\( \text{SiLU}(x) \to 0 \)
- /上界无界/:当 \( x \to +\infty \),\( \text{SiLU}(x) \sim x \)
导数计算
SiLU的导数为:
实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
class SiLU(nn.Module):
"""SiLU激活函数实现"""
def __init__(self):
super().__init__()
def forward(self, x):
"""前向传播"""
return x * torch.sigmoid(x)
def derivative(self, x):
"""计算导数"""
sigmoid_x = torch.sigmoid(x)
return sigmoid_x * (1 + x * (1 - sigmoid_x))
class SiLUWithBeta(nn.Module):
"""带可学习参数的SiLU变体"""
def __init__(self, beta=1.0, learnable=False):
"""
参数:
beta: 缩放参数
learnable: 是否可学习
"""
super().__init__()
if learnable:
self.beta = nn.Parameter(torch.tensor(float(beta)))
else:
self.register_buffer('beta', torch.tensor(float(beta)))
def forward(self, x):
"""前向传播"""
return x * torch.sigmoid(self.beta * x)
def extra_repr(self):
return f"beta={self.beta.item():.2f}, learnable={self.beta.requires_grad}"
def demonstrate_silu():
"""演示SiLU激活函数"""
print("SiLU激活函数演示")
print("=" * 50)
# 创建测试数据
x = torch.linspace(-4, 4,
100)
# 计算SiLU函数值
silu = SiLU()
y = silu(x)
dy = silu.derivative(x)
print(f"输入范围: [{x.min():.1
f}, {x.max():.1f}]")
print(f"输出范围: [{y.min():.3f}, {y.max():.3f}]")
print(f"导数范围: [{dy.min():.3f}, {dy.max():.3f}]")
# 关键点分析
critical_points = [-4, -2, 0, 2, 4]
print(f"\n关键点分析:")
for point in critical_points:
x_point = torch.tensor(float(point))
y_point = silu(x_point)
dy_point = silu.derivative(x_point)
print(f" x={point:2.0f}: SiLU={y_point:.3f}, 导数={dy_point:.3f}")
# 与其它激活函数比较
relu = F.relu(x)
leaky_relu = F.leaky_relu(x, 0.01)
gelu = F.gelu(x)
elu = F.elu(x)
print(f"\n激活函数比较 (
x=0):")
print(f" SiLU: {silu(torch.tensor(0.0)):.3f}")
print(f" ReLU: {F.relu(torch.tensor(0.0)):.3f}")
print(f" GELU: {F.gelu(torch.tensor(0.0)):.3f}")
print(f" ELU: {F.elu(torch.tensor(0.0)):.3
f}")
return {
'x': x,
'silu': y,
'silu_derivative': dy,
'comparison': {
'relu': relu,
'leaky_relu': leaky_relu,
'gelu': gelu,
'elu': elu
}
}
silu_demo = demonstrate_silu()
SwiGLU
Back to Basics: Let Denoising Generative Models Denoise (99+ 封私信 / 99+ 条消息) 详解SwiGLU激活函数 - 知乎
SwiGLU(x) = SiLU(W1x) ⊙ (W2x)
其中,
- \(a=xW_1+b_1\) 与 \(b=xW_2+b_2\) 是两个线性变换的输出。
- \(\text{Swish(x)=x\cdot \text{sigmoid}(\beta x)\) ,通常取 \(\beta=1\) 此时等价于 SiLu 激活函数。
也就是 $$\text{SwiGLU(x)}=\text{SiLU}(\text{Linear}_1(x))\times \text{Linear}_2(x)$$
import torch
import torch.nn as nn
import torch.nn.functional as F
class SwiGLU(nn.Module):
def __init__(self, input_dim, hidden_dim):
super().__init__()
self.w1 = nn.Linear(input_dim, hidden_dim, bias=False)
self.w2 = nn.Linear(input_dim, hidden_dim, bias=False)
self.w3 = nn.Linear(hidden_dim, input_dim, bias=False)
def forward(self, x):
# 分割输入为两部分(假设输入维度为偶数)
assert x.shape[-1] % 2 == 0, "输入维度需为偶数"
a, b = x.chunk(2, dim=-1)
# Swish门控计算
gate = F.silu(self.w1(a)) # F.silu等价于Swish(β=1)
filtered = gate * self.w2(b)
return self.w3(filtered)
作用:
- 大规模语言模型:如LLaMA、PALM等,因其高效的门控机制降低计算冗余,平衡非线性表达与梯度稳定性;
- 长序列建模:平滑梯度缓解了Transformer中的梯度消失问题;
- 低资源训练:相比GELU,Swish的计算效率更高。
GeLU
GELU(Gaussian Error Linear Unit)是一种基于高斯误差函数的激活函数,结合了ReLU和dropout的思想,在Transformer等现代架构中广泛使用。
数学定义
GELU激活函数定义为:
其中 \( \Phi(x) \) 是标准正态分布的累积分布函数,\( \text{ erf} \) 是误差函数。
近似表达式
由于误差函数计算较复杂,常用以下近似:
性质分析
- /平滑性/:GELU是无限可微的平滑函数
- /非单调性/:当 \( x < 0 \) 时,函数值为负
- /概率解释/:基于输入的概率分布进行门控
- /性能优势/:在自然语言处理任务中表现优异
导数计算
GELU的导数为:
其中 \( \phi(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2} \) 是标准正态分布的概率密度函数。
实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import erf
def gelu_exact(x):
"""精确GELU实现"""
return 0.5 * x * (1 + erf(x / np.sqrt(2)))
def gelu_approximate(x):
"""近似GELU实现"""
return 0.5 * x * (1 + np.tanh(np.sqrt(2/np.pi) * (x + 0.
044715 * x**3)))
class GELUAnalysis:
"""GELU激活函数分析"""
def __init__(self):
pass
def analyze_gelu(self, x_range=(-4, 4), num_points=1000):
"""分析GELU函数特性"""
x = np.linspace(x_range[0], x_range[1], num_points)
# 计算精确和近似值
y_exact = gelu_exact(x)
y_approx = gelu_approximate(x)
# 计算导数
phi_x = np.exp(-x**2/2) / np.sqrt(2*np.pi) # 标准正态PDF
Phi_x = 0.5 * (1 + erf(x/np.sqrt(2))) # 标准正态CDF
derivative = Phi_x + x * phi_x
# 误差分析
error = np.abs(y_exact - y_approx)
return {
'x': x,
'y_exact': y_exact,
'y_approx': y_approx,
'derivative': derivative,
'error': error,
'phi_x': phi_x,
'Phi_x': Phi_x
}
def compare_with_other_activations(self, x):
"""与其他激活函数比较"""
x_tensor = torch.tensor(x, dtype=torch.float32)
activations = {
'GELU': F.gelu(x_tensor).numpy(),
'ReLU': F.relu(x_tensor).numpy(),
'SiLU': x_tensor * torch.sigmoid(x_tensor).numpy(),
'ELU': F.elu(x_tensor).numpy(),
'Tanh': torch.tanh(x_tensor).numpy()
}
return activations
def demonstrate_gelu():
"""演示GELU激活函数"""
print("GELU激活函数演示")
print("=" * 50)
analyzer = GELUAnalysis()
results = analyzer.analyze_gelu()
x = results['x']
y_exact = results['y_exact']
y_approx = results['y_approx']
derivative = results['derivative']
error = results['error']
print(f"输入范围: [{x.min():.1f}, {x.max():.1f}]")
print(f"输出范围: [{y_exact.min():.3f}, {y_exact.max():.3f}]")
print(f"导数范围: [{derivative.min():.3f}, {derivative.max():.3f}]")
print(f"近似最大误差: {error.max():.6f}")
# 关键点分析
critical_points = [-3, -1, 0, 1, 3]
print(f"\n关键点分析:")
for point in critical_points:
idx = np.argmin(np.abs(x - point))
y_exact_val = y_exact[idx]
y_approx_val = y_approx[idx]
deriv_val = derivative[idx]
print(f" x={point:2.0f}: GELU={y_exact_val:.3f}, 近似={y_approx_val:.3f}, 导数={deriv_val:.3f}")
# 与其他激活函数比较
activations = analyzer.compare_with_other_activations(x)
print(f"\n激活函数比较 (x=0):")
for name, values in activations.items():
zero_idx = np.argmin(np.abs(x))
print(f" {name}: {values[zero_idx]:.3f}")
return results, activations
gelu_results, gelu_comparison = demonstrate_gelu()
*/ 可视化
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'Hiragino Sans GB', 'STHeiti']
plt.rcParams['axes.unicode_minus'] = False
# 创建可视化
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
# GELU函数与近似
ax1 = axes[0, 0]
ax1.plot(gelu_results['x'], gelu_results['y_exact'], 'b-', linewidth=2, label='GELU (精确)')
ax1.plot(gelu_results['x'], gelu_results['y_approx'], 'r--', linewidth=2, label='GELU (近似)')
ax1.set_xlabel('x')
ax1.set_ylabel('GELU(x)')
ax1.set_title('GELU激活函数:精确 vs 近似')
ax1.grid(True, alpha=0.3)
ax1.legend()
# 导数
ax2 = axes[0, 1]
ax2.plot(gelu_results['x'], gelu_results['derivative'], 'g-', linewidth=2, label='GELU导数')
ax2.set_xlabel('x')
ax2.set_ylabel('dGELU/dx')
ax2.set_title('GELU导数')
ax2.grid(True, alpha=0.3)
ax2.legend()
# 近似误差
ax3 = axes[1, 0]
ax3.plot(gelu_results['x'], gelu_results['error'], 'm-', linewidth=2)
ax3.set_xlabel('x')
ax3.set_ylabel('绝对误差')
ax3.set_title('GELU近似误差')
ax3.grid(True, alpha=0.3)
# 与其他激活函数比较
ax4 = axes[1, 1]
for name, values in gelu_comparison.items():
ax4.plot(gelu_results['x'], values, linewidth=2, label=name)
ax4.set_xlabel('x')
ax4.set_ylabel('激活值')
ax4.set_title('GELU与其他激活函数比较')
ax4.grid(True, alpha=0.3)
ax4.legend()
plt.tight_layout()
fname = "images/gelu_analysis.png"
plt.savefig(fname, dpi=300, bbox_inches='tight')
plt.close()
fname
*/ 应用场景
- /Transformer架构/:BERT、GPT等模型的标准激活函数
- /自然语言处理/:在NLP任务中表现优异
- /计算机视觉/:在ViT等视觉Transformer中使用
- /需要平滑激活的场景/:相比ReLU提供更好的梯度流
*/ 优势与局限
优势:
- 平滑且处处可微
- 结合了ReLU和dropout的思想
- 在预训练模型中表现稳定
- 提供概率解释
局限:
- 计算复杂度较高
- 需要近似实现以提高效率
- 在某些硬件上可能不如ReLU高效
归一化层
LayerNorm
层归一化(Layer Normalization)是 Transformer 架构中的关键组件,用于稳定训练过程和提高模型性能。
为什么用LayerNorm 而非 BatchNorm
在 Transformer 架构中选择 LayerNorm 而非 BatchNorm 主要基于以下几个原因:
序列数据特性
- BatchNorm 的问题:在 NLP 任务中,序列长度可变,不同 batch 中的序列长度可能不同,导致统计量计算不稳定
- LayerNorm 的优势:对每个样本独立归一化,不受 batch size 和序列长度影响
计算公式对比
BatchNorm 公式:
$$\text{BatchNorm}(x) = \gamma \frac{x - \mu_{\text{batch}}}{\sigma_{\text{batch}}} + \beta$$其中 \(\mu_{\text{batch}}, \sigma_{\text{batch}}\) 在整个 batch 上计算。LayerNorm 公式:
$$\text{LayerNorm}(x) = \gamma \frac{x - \mu_{\text{layer}}}{\sigma_{\text{layer}}} + \beta$$其中 \(\mu_{\text{layer}}, \sigma_{\text{layer}}\) 在单个样本的特征维度上计算。
训练稳定性
- BatchNorm:依赖大的 batch size 来获得准确的统计量估计,在小 batch 时性能下降
- LayerNorm:对 batch size 不敏感,更适合 NLP 任务中常见的小 batch 训练
推理一致性
- BatchNorm:训练和推理时统计量计算方式不同(移动平均 vs 当前 batch)
- LayerNorm:训练和推理时计算方式完全一致,无需特殊处理
在 Transformer 中的位置
- 原始 Transformer 使用后归一化(Post-LN):注意力/FFN → Add → LayerNorm
- 现代变体常用前归一化(Pre-LN):LayerNorm → 注意力/FFN → Add
- Pre-LN 通常训练更稳定,但可能牺牲一些性能
数学公式
对于一个输入向量 \( x \in \mathbb{R}^d \),层归一化的计算如下:
其中:
- \( \mu = \frac{1}{d} \sum_{i=1}^d x_i \) 是均值
- \( \sigma^2 = \frac{1}{d} \sum_{i=1}^d (x_i - \mu)^2 \) 是方差
- \( \gamma, \beta \in \mathbb{R}^d \) 是可学习的缩放和偏移参数
- \( \epsilon \) 是小的常数,用于数值稳定性
- \( \odot \) 表示逐元素相乘
与 BatchNorm 的区别
- BatchNorm:在批次维度上归一化,对批次大小敏感
- LayerNorm:在特征维度上归一化,更适合变长序列和不同批次大小
PyTorch 实现
import torch
import torch.nn as nn
class LayerNorm(nn.Module):
def __init__(self, d_model, eps=1e-5):
super(LayerNorm, self).__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(d_model)) # 缩放参数
self.beta = nn.Parameter(torch.zeros(d_model)) # 偏移参数
def forward(self, x):
"""
Args:
x: 输入张量 [batch_size, seq_len, d_model] 或 [batch_size, d_model]
Returns:
归一化后的张量
"""
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
# 归一化
x_normalized = (x - mean) / torch.sqrt(var + self.eps)
# 缩放和偏移
return self.gamma * x_normalized + self.beta
# 测试代码
if __name__ == "__main__":
batch_size, seq_len, d_model = 2, 5, 512
# 测试 LayerNorm
x = torch.randn(batch_size, seq_len, d_model)
layer_norm = LayerNorm(d_model)
output = layer_norm(x)
print(f"输入形状: {x.shape}")
print(f"输出形状: {output.shape}")
print(f"输出均值: {output.mean().item():.6f}")
print(f"输出标准差: {output.std().item():.6f}")
# 与 PyTorch 官方实现对比
official_ln = nn.LayerNorm(d_model)
official_output = official_ln(x)
print(f"自定义与官方实现差异: {torch.abs(output - official_output).max().item():.6f}")
在 Transformer 中的应用
在 Transformer 中,LayerNorm 通常应用于:
- 多头注意力后的残差连接
- 前馈网络后的残差连接
具体形式:
GroupNorm
组归一化(Group Normalization)是一种深度学习归一化技术,特别适用于小批量训练和 计算机视觉任务。它将通道维度分组进行归一化,不依赖于批量大小。
数学定义
给定输入张量 \( \mathbf{x} \in \mathbb{R}^{B \times C \times H \times W} \)(批量大小 \( B \)、通道数 \( C \)、高度 \( H \)、宽度 \( W \)),组归一化计算:
其中:
- \( G \) 是分组数
- \( K = C/G \) 是每组通道数
- \( \gamma, \beta \in \mathbb{R}^C \) 是可学习参数
实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class GroupNorm(nn.Module):
"""组归一化实现"""
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
"""
参数:
num_groups: 分组数
num_channels: 通道总数
eps: 数值稳定性常数
affine: 是否使用可学习的仿射参数
"""
super().__init__()
self.num_groups = num_groups
self.num_channels = num_channels
self.eps = eps
self.affine = affine
# 参数检查
assert num_channels % num_groups == 0, \
f"通道数{num_channels}必须能被分组数{num_groups}整除"
if self.affine:
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
def forward(self, x):
"""前向传播"""
B, C, H, W = x.shape
# 重塑为 (B, G, C//G, H, W)
x_reshaped = x.view(B, self.num_groups, C // self.num_groups, H, W)
# 计算均值和方差
mean = x_reshaped.mean(dim=[2, 3, 4], keepdim=True)
var = x_reshaped.var(dim=[2, 3, 4], unbiased=False, keepdim=True)
# 归一化
x_normalized = (x_reshaped - mean) / torch.sqrt(var + self.eps)
# 恢复原始形状
x_normalized = x_normalized.view(B, C, H, W)
# 仿射变换
if self.affine:
x_normalized = self.weight.view(1, C, 1, 1) * x_normalized + self.bias.view(1, C, 1, 1)
return x_normalized
def extra_repr(self):
return f"groups={self.num_groups}, channels={self.num_channels}, " \
f"eps={self.eps}, affine={self.affine}"
class GroupNorm1D(nn.Module):
"""一维组归一化"""
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
super().__init__()
self.num_groups = num_groups
self.num_channels = num_channels
self.eps = eps
self.affine = affine
assert num_channels % num_groups == 0
if self.affine:
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
def forward(self, x):
"""前向传播"""
B, C, L = x.shape
# 重塑为 (B, G, C//G, L)
x_reshaped = x.view(B, self.num_groups, C // self.num_groups, L)
# 计算均值和方差
mean = x_reshaped.mean(dim=[2, 3], keepdim=True)
var = x_reshaped.var(dim=[2, 3], unbiased=False, keepdim=True)
# 归一化
x_normalized = (x_reshaped - mean) / torch.sqrt(var + self.eps)
# 恢复原始形状
x_normalized = x_normalized.view(B, C, L)
# 仿射变换
if self.affine:
x_normalized = self.weight.view(1, C, 1) * x_normalized + self.bias.view(1, C, 1)
return x_normalized
def demonstrate_group_norm():
"""演示组归一化"""
print("组归一化演示")
print("=" * 50)
# 测试数据 - 图像特征图
batch_size, channels, height, width = 4, 16, 8, 8
x = torch.randn(batch_size, channels, height, width)
print(f"输入张量形状: {x.shape}")
print(f"输入统计 - 均值: {x.mean():.4f}, 标准差: {x.std():.4f}")
# 不同分组数的组归一化
group_configs = [
(1, "LayerNorm-like"), # 1组 = 层归一化
(2, "2 Groups"),
(4, "4 Groups"),
(8, "8 Groups"),
(16, "InstanceNorm-like") # 每组1通道 = 实例归一化
]
results = {}
for num_groups, description in group_configs:
print(f"\n{description} (分组数: {num_groups}):")
group_norm = GroupNorm(num_groups, channels)
y = group_norm(x)
print(f" 输出统计 - 均值: {y.mean():.4f}, 标准差: {y.std():.4f}")
# 验证组内归一化
for b in range(min(2, batch_size)): # 只检查前2个样本
for g in range(num_groups):
group_channels = channels // num_groups
start_ch = g * group_channels
end_ch = (g + 1) * group_channels
group_data = y[b, start_ch:end_ch]
group_mean = group_data.mean().item()
group_std = group_data.std().item()
if g == 0: # 只打印第一组作为示例
print(f" 样本{b}组{g} - 均值: {group_mean:.6f}, 标准差: {group_std:.6f}")
results[description] = y
# 一维组归一化演示
print(f"\n一维组归一化演示:")
batch_size_1d, channels_1d, length = 2, 12, 20
x_1d = torch.randn(batch_size_1d, channels_1d, length)
group_norm_1d = GroupNorm1D(3, channels_1d) # 3组
y_1d = group_norm_1d(x_1d)
print(f" 输入形状: {x_1d.shape}")
print(f" 输出统计 - 均值: {y_1d.mean():.4f}, 标准差: {y_1d.std():.4f}")
return {
'configs': group_configs,
'results': results,
'group_norm_1d': group_norm_1d,
'output_1d': y_1d
}
group_norm_demo = demonstrate_group_norm()
与其它归一化方法的比较
class BatchNorm2d(nn.Module):
"""批归一化实现(用于比较)"""
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.training = True
def forward(self, x):
"""前向传播"""
B, C, H, W = x.shape
if self.training:
# 沿批量、空间维度计算统计量
mean = x.mean(dim=[0, 2, 3])
var = x.var(dim=[0, 2, 3], unbiased=False)
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
else:
mean = self.running_mean
var = self.running_var
# 归一化
x_normalized = (x - mean.view(1, C, 1, 1)) / torch.sqrt(var.view(1, C, 1, 1) + self.eps)
return self.weight.view(1, C, 1, 1) * x_normalized + self.bias.view(1, C, 1, 1)
class InstanceNorm2d(nn.Module):
"""实例归一化实现"""
def __init__(self, num_features, eps=1e-5, affine=True):
super().__init__()
self.num_features = num_features
self.eps = eps
self.affine = affine
if self.affine:
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
def forward(self, x):
"""前向传播"""
B, C, H, W = x.shape
# 沿空间维度计算统计量
mean = x.mean(dim=[2, 3], keepdim=True)
var = x.var(dim=[2, 3], unbiased=False, keepdim=True)
# 归一化
x_normalized = (x - mean) / torch.sqrt(var + self.eps)
# 仿射变换
if self.affine:
x_normalized = self.weight.view(1, C, 1, 1) * x_normalized + self.bias.view(1, C, 1, 1)
return x_normalized
def compare_normalization_methods():
"""比较不同归一化方法"""
print("归一化方法比较")
print("=" * 50)
# 测试不同批量大小
batch_sizes = [1, 4, 16]
channels, height, width = 8, 16, 16
results = {}
for batch_size in batch_sizes:
print(f"\n批量大小: {batch_size}")
print("-" * 30)
x = torch.randn(batch_size, channels, height, width)
# 批归一化
if batch_size > 1: # 批归一化需要批量大小 > 1
batch_norm = BatchNorm2d(channels)
y_batch = batch_norm(x)
batch_mean = y_batch.mean().item()
batch_std = y_batch.std().item()
print(f" 批归一化 - 均值: {batch_mean:.4f}, 标准差: {batch_std:.4f}")
else:
print(f" 批归一化 - 不适用(批量大小=1)")
# 组归一化 (4组)
group_norm = GroupNorm(4, channels)
y_group = group_norm(x)
group_mean = y_group.mean().item()
group_std = y_group.std().item()
print(f" 组归一化 - 均值: {group_mean:.4f}, 标准差: {group_std:.4f}")
# 实例归一化
instance_norm = InstanceNorm2d(channels)
y_instance = instance_norm(x)
instance_mean = y_instance.mean().item()
instance_std = y_instance.std().item()
print(f" 实例归一化 - 均值: {instance_mean:.4f}, 标准差: {instance_std:.4f}")
# 层归一化 (重塑为2D)
layer_norm = nn.LayerNorm([channels, height, width])
y_layer = layer_norm(x)
layer_mean = y_layer.mean().item()
layer_std = y_layer.std().item()
print(f" 层归一化 - 均值: {layer_mean:.4f}, 标准差: {layer_std:.4f}")
results[batch_size] = {
'group': y_group,
'instance': y_instance,
'layer': y_layer
}
if batch_size > 1:
results[batch_size]['batch'] = y_batch
# 计算效率比较
print(f"\n计算效率比较:")
large_tensor = torch.randn(32, 64, 32, 32)
import time
# 组归一化
gn = GroupNorm(8, 64)
start = time.time()
_ = gn(large_tensor)
gn_time = time.time() - start
# 批归一化
bn = BatchNorm2d(64)
start = time.time()
_ = bn(large_tensor)
bn_time = time.time() - start
print(f" 组归一化时间: {gn_time:.6f}s")
print(f" 批归一化时间: {bn_time:.6f}s")
print(f" 相对效率: {gn_time/bn_time:.2f}x")
return results
normalization_comparison = compare_normalization_methods()
在ResNet中的应用
class BasicBlock(nn.Module):
"""基础残差块(使用组归一化)"""
def __init__(self, in_channels, out_channels, stride=1, groups=8):
super().__init__()
# 第一个卷积层
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
stride=stride, padding=1, bias=False)
self.gn1 = GroupNorm(groups, out_channels)
# 第二个卷积层
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
stride=1, padding=1, bias=False)
self.gn2 = GroupNorm(groups, out_channels)
# 激活函数
self.relu = nn.ReLU(inplace=True)
# 下采样
self.downsample = None
if stride != 1 or in_channels != out_channels:
self.downsample = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1,
stride=stride, bias=False),
GroupNorm(groups, out_channels)
)
def forward(self, x):
"""前向传播"""
identity = x
# 主路径
out = self.conv1(x)
out = self.gn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.gn2(out)
# 捷径
if self.downsample is not None:
identity = self.downsample(x)
# 残差连接
out += identity
out = self.relu(out)
return out
class ResNetGN(nn.Module):
"""使用组归一化的ResNet"""
def __init__(self, block, layers, num_classes=1000, groups=8):
super().__init__()
self.in_channels = 64
self.groups = groups
# 初始卷积层
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.gn1 = GroupNorm(groups, 64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# 残差层
self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
# 分类头
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512, num_classes)
def _make_layer(self, block, out_channels, blocks, stride):
"""构建残差层"""
layers = []
layers.append(block(self.in_channels, out_channels, stride, self.groups))
self.in_channels = out_channels
for _ in range(1, blocks):
layers.append(block(out_channels, out_channels, 1, self.groups))
return nn.Sequential(*layers)
def forward(self, x):
"""前向传播"""
x = self.conv1(x)
x = self.gn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def demonstrate_resnet_gn():
"""演示ResNet中的组归一化"""
print("ResNet中的组归一化")
print("=" * 50)
# 创建小型ResNet
model = ResNetGN(BasicBlock, [2, 2, 2, 2], num_classes=10, groups=8)
print(f"模型结构:")
print(f" 总参数量: {sum(p.numel() for p in model.parameters()):,}")
print(f" 分组数: {model.groups}")
# 测试前向传播
batch_size = 2
x = torch.randn(batch_size, 3, 32, 32)
with torch.no_grad():
output = model(x)
print(f"\n前向传播测试:")
print(f" 输入形状: {x.shape}")
print(f" 输出形状: {output.shape}")
print(f" 输出范围: [{output.min():.3f}, {output.max():.3f}]")
# 不同批量大小测试
print(f"\n不同批量大小测试:")
for bs in [1, 2, 4, 8]:
x_test = torch.randn(bs, 3, 32, 32)
with torch.no_grad():
output_test = model(x_test)
print(f" 批量大小{bs:2d} - 输出均值: {output_test.mean():.4f}, 标准差: {output_test.std():.4f}")
# 梯度测试
print(f"\n梯度稳定性测试:")
x_grad = torch.randn(2, 3, 32, 32, requires_grad=True)
output_grad = model(x_grad)
loss = output_grad.sum()
loss.backward()
grad_norm = x_grad.grad.norm()
print(f" 输入梯度范数: {grad_norm:.4f}")
return model
resnet_gn_demo = demonstrate_resnet_gn()
BatchNorm
批归一化(Batch Normalization)是深度学习中广泛使用的归一化技术,通过对每个特征通道在批次维度上进行归一化来加速训练并提高模型稳定性。
数学定义
给定输入张量 \( \mathbf{x} \in \mathbb{R}^{B \times C \times H \times W} \)(批量大小 \( B \)、通道数 \( C \)、高度 \( H \)、宽度 \( W \)),批归一化计算:
其中:
- \( \mu_c, \sigma_c^2 \) 是第 \( c \) 个通道在批次和空间维度上的均值和方差
- \( \gamma_c, \beta_c \) 是可学习的缩放和偏移参数
- \( \epsilon \) 是数值稳定性常数
训练与推理模式
训练模式
在训练时,使用当前批次的统计量:
$$\mu_{\text{train}} = \mu_{\text{batch}}, \quad \sigma^2_{\text{train}} = \sigma^2_{\text{batch}}$$同时更新运行统计量:$$\mu_{\text{running}} = (1 - \text{momentum}) \cdot \mu_{\text{running}} + \text{momentum} \cdot \mu_{\text{batch}}$$$$\sigma^2_{\text{running}} = (1 - \text{momentum}) \cdot \sigma^2_{\text{running}} + \text{momentum} \cdot \sigma^2_{\text{batch}}$$
推理模式
在推理时,使用训练期间累积的运行统计量:
$$\mu_{\text{eval}} = \mu_{\text{running}}, \quad \sigma^2_{\text{eval}} = \sigma^2_{\text{running}}$$
实现
import torch
import torch.nn as nn
import numpy as np
class BatchNorm2d(nn.Module):
"""批归一化2D实现"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
"""
参数:
num_features: 特征通道数
eps: 数值稳定性常数
momentum: 运行统计量动量
affine: 是否使用可学习参数
"""
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
# 可学习参数
if self.affine:
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
# 运行统计量
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
# 重置参数
self.reset_parameters()
def reset_parameters(self):
"""重置参数"""
if self.affine:
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
self.running_mean.zero_()
self.running_var.fill_(1)
self.num_batches_tracked.zero_()
def forward(self, x):
"""前向传播"""
if self.training:
return self._forward_train(x)
else:
return self._forward_eval(x)
def _forward_train(self, x):
"""训练模式前向传播"""
B, C, H, W = x.shape
# 计算批次统计量
mean = x.mean(dim=[0, 2, 3]) # 沿批次和空间维度
var = x.var(dim=[0, 2, 3], unbiased=False)
# 更新运行统计量
with torch.no_grad():
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
self.num_batches_tracked += 1
# 归一化
x_normalized = (x - mean.view(1, C, 1, 1)) / torch.sqrt(var.view(1, C, 1, 1) + self.eps)
# 仿射变换
if self.affine:
x_normalized = self.weight.view(1, C, 1, 1) * x_normalized + self.bias.view(1, C, 1, 1)
return x_normalized
def _forward_eval(self, x):
"""评估模式前向传播"""
B, C, H, W = x.shape
# 使用运行统计量
mean = self.running_mean
var = self.running_var
# 归一化
x_normalized = (x - mean.view(1, C, 1, 1)) / torch.sqrt(var.view(1, C, 1, 1) + self.eps)
# 仿射变换
if self.affine:
x_normalized = self.weight.view(1, C, 1, 1) * x_normalized + self.bias.view(1, C, 1, 1)
return x_normalized
def extra_repr(self):
return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={self.affine}"
class BatchNorm1d(nn.Module):
"""批归一化1D实现"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
if self.affine:
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
self.reset_parameters()
def reset_parameters(self):
if self.affine:
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
self.running_mean.zero_()
self.running_var.fill_(1)
self.num_batches_tracked.zero_()
def forward(self, x):
if self.training:
return self._forward_train(x)
else:
return self._forward_eval(x)
def _forward_train(self, x):
if x.dim() == 2: # (B, C)
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
elif x.dim() == 3: # (B, C, L)
mean = x.mean(dim=[0, 2])
var = x.var(dim=[0, 2], unbiased=False)
else:
raise ValueError(f"期望2D或3D输入,得到 {x.dim()}D")
# 更新运行统计量
with torch.no_grad():
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
self.num_batches_tracked += 1
# 归一化
if x.dim() == 2:
x_normalized = (x - mean) / torch.sqrt(var + self.eps)
else:
x_normalized = (x - mean.view(1, -1, 1)) / torch.sqrt(var.view(1, -1, 1) + self.eps)
# 仿射变换
if self.affine:
if x.dim() == 2:
x_normalized = self.weight * x_normalized + self.bias
else:
x_normalized = self.weight.view(1, -1, 1) * x_normalized + self.bias.view(1, -1, 1)
return x_normalized
def _forward_eval(self, x):
mean = self.running_mean
var = self.running_var
if x.dim() == 2:
x_normalized = (x - mean) / torch.sqrt(var + self.eps)
else:
x_normalized = (x - mean.view(1, -1, 1)) / torch.sqrt(var.view(1, -1, 1) + self.eps)
if self.affine:
if x.dim() == 2:
x_normalized = self.weight * x_normalized + self.bias
else:
x_normalized = self.weight.view(1, -1, 1) * x_normalized + self.bias.view(1, -1, 1)
return x_normalized
def demonstrate_batch_norm():
"""演示批归一化"""
print("批归一化演示")
print("=" * 50)
# 2D批归一化测试
print("2D批归一化:")
batch_size, channels, height, width = 4, 8, 16, 16
x_2d = torch.randn(batch_size, channels, height, width)
bn_2d = BatchNorm2d(channels)
y_2d = bn_2d(x_2d)
print(f" 输入形状: {x_2d.shape}")
print(f" 输出统计 - 均值: {y_2d.mean():.6f}, 标准差: {y_2d.std():.6f}")
# 验证通道统计量
print(f" 通道统计量验证:")
for c in range(min(3, channels)): # 只检查前3个通道
channel_mean = y_2d[:, c].mean().item()
channel_std = y_2d[:, c].std().item()
print(f" 通道{c}: 均值={channel_mean:.6f}, 标准差={channel_std:.6f}")
# 1D批归一化测试
print(f"\n1D批归一化:")
batch_size_1d, features, seq_len = 8, 16, 32
x_1d = torch.randn(batch_size_1d, features, seq_len)
bn_1d = BatchNorm1d(features)
y_1d = bn_1d(x_1d)
print(f" 输入形状: {x_1d.shape}")
print(f" 输出统计 - 均值: {y_1d.mean():.6f}, 标准差: {y_1d.std():.6f}")
# 训练与推理模式对比
print(f"\n训练与推理模式对比:")
bn_2d.eval() # 切换到推理模式
y_eval = bn_2d(x_2d)
print(f" 训练模式输出均值: {y_2d.mean():.6f}")
print(f" 推理模式输出均值: {y_eval.mean():.6f}")
print(f" 运行统计量 - 均值: {bn_2d.running_mean[:3]}")
print(f" 运行统计量 - 方差: {bn_2d.running_var[:3]}")
return {
'bn_2d': bn_2d,
'bn_1d': bn_1d,
'output_2d': y_2d,
'output_1d': y_1d,
'output_eval': y_eval
}
batch_norm_demo = demonstrate_batch_norm()
优势与局限
优势
- /加速训练/:减少内部协变量偏移,允许使用更高的学习率
- /正则化效果/:减少对Dropout的依赖
- /梯度流改善/:缓解梯度消失/爆炸问题
- /对初始化不敏感/:降低对权重初始化的依赖
局限
- /批量大小依赖/:小批量时统计量估计不准确
- /序列数据不适用/:在RNN/LSTM中效果有限
- /推理不一致/:训练和推理时计算方式不同
- /分布式训练复杂/:需要同步批次统计量
RMSNorm
RMSNorm(Root Mean Square Normalization)是一种基于均方根的归一化方法,在 Transformer 架构中作为 LayerNorm 的轻量级替代方案,去除了均值中心化操作。
数学定义
对于输入向量 \( x \in \mathbb{R}^d \),RMSNorm 计算如下:
其中:
- \( \text{RMS}(x) \) 是输入的均方根值
- \( g \in \mathbb{R}^d \) 是可学习的缩放参数
- \( \epsilon \) 是数值稳定性常数
- \( \odot \) 表示逐元素相乘
与 LayerNorm 的区别
LayerNorm:
RMSNorm:
关键区别:
- RMSNorm 移除了均值中心化(\( x - \mu \))
- RMSNorm 移除了偏置参数 \( b \)
- RMSNorm 计算量相对较小
- RMSNorm 使用 RMS 而非标准差
优势
- 计算效率高
- 适用于小批量或者是单样本
- 稳定:使用均方根进行归一化,可以在一定程度上避免梯度爆炸和梯度消失,提高训练稳定性
RMSNorm 实现
import torch
import torch.nn as nn
import numpy as np
class RMSNorm(nn.Module):
"""RMSNorm 实现"""
def __init__(self, d_model, eps=1e-8):
"""
参数:
d_model: 特征维度
eps: 数值稳定性常数
"""
super().__init__()
self.d_model = d_model
self.eps = eps
# 可学习的缩放参数
self.weight = nn.Parameter(torch.ones(d_model))
def forward(self, x):
"""
前向传播
Args:
x: 输入张量 [batch_size, seq_len, d_model] 或 [batch_size, d_model]
Returns:
归一化后的张量
"""
# 计算均方根 (RMS)
rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
# 归一化并缩放
x_normalized = x / rms
return x_normalized * self.weight
def extra_repr(self):
return f"d_model={self.d_model}, eps={self.eps}"
class RMSNorm1D(nn.Module):
"""1D RMSNorm 实现"""
def __init__(self, d_model, eps=1e-8):
super().__init__()
self.d_model = d_model
self.eps = eps
self.weight = nn.Parameter(torch.ones(d_model))
def forward(self, x):
"""处理 1D 输入 [batch_size, d_model]"""
rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
x_normalized = x / rms
return x_normalized * self.weight
def demonstrate_rms_norm():
"""演示 RMSNorm"""
print("RMSNorm 演示")
print("=" * 50)
# 测试数据
batch_size, seq_len, d_model = 4, 10, 512
x = torch.randn(batch_size, seq_len, d_model)
print(f"输入形状: {x.shape}")
print(f"输入统计 - 均值: {x.mean():.4f}, 标准差: {x.std():.4f}")
# RMSNorm
rms_norm = RMSNorm(d_model)
y_rms = rms_norm(x)
print(f"\nRMSNorm 输出:")
print(f" 输出统计 - 均值: {y_rms.mean():.6f}, 标准差: {y_rms.std():.6f}")
# 验证 RMS 归一化
print(f"\nRMS 归一化验证:")
for b in range(min(2, batch_size)):
for s in range(min(2, seq_len)):
sample = x[b, s]
rms_manual = torch.sqrt(torch.mean(sample**2))
normalized_manual = sample / rms_manual
normalized_actual = y_rms[b, s] / rms_norm.weight
diff = torch.abs(normalized_manual - normalized_actual).max()
print(f" 样本[{b},{s}] - RMS手动: {rms_manual:.4f}, 最大差异: {diff:.6f}")
break # 只检查第一个位置
break
# 与 LayerNorm 比较
print(f"\n与 LayerNorm 比较:")
layer_norm = nn.LayerNorm(d_model)
y_layer = layer_norm(x)
print(f" LayerNorm 输出 - 均值: {y_layer.mean():.6f}, 标准差: {y_layer.std():.6f}")
# 计算效率比较
import time
# 预热
_ = rms_norm(x)
_ = layer_norm(x)
# RMSNorm 时间
start = time.time()
for _ in range(1000):
_ = rms_norm(x)
rms_time = time.time() - start
# LayerNorm 时间
start = time.time()
for _ in range(1000):
_ = layer_norm(x)
layer_time = time.time() - start
print(f"\n计算效率比较 (1000次前向传播):")
print(f" RMSNorm 时间: {rms_time:.4f}s")
print(f" LayerNorm 时间: {layer_time:.4f}s")
print(f" RMSNorm 相对速度: {layer_time/rms_time:.2f}x")
# 内存使用比较
print(f"\n参数数量比较:")
print(f" RMSNorm 参数量: {sum(p.numel() for p in rms_norm.parameters())}")
print(f" LayerNorm 参数量: {sum(p.numel() for p in layer_norm.parameters())}")
return {
'rms_norm': rms_norm,
'layer_norm': layer_norm,
'output_rms': y_rms,
'output_layer': y_layer,
'timing': {'rms': rms_time, 'layer': layer_time}
}
rms_norm_demo = demonstrate_rms_norm()
在 Transformer 中的应用
class RMSNormTransformerBlock(nn.Module):
"""使用 RMSNorm 的 Transformer 块"""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
super().__init__()
self.d_model = d_model
# 自注意力机制
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
# 前馈网络
self.ffn = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
nn.Dropout(dropout)
)
# 使用 RMSNorm 替代 LayerNorm
self.norm1 = RMSNorm(d_model)
self.norm2 = RMSNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, attn_mask=None):
"""前向传播"""
# 自注意力 + 残差连接
attn_output, _ = self.self_attn(
self.norm1(x),
self.norm1(x),
self.norm1(x),
attn_mask=attn_mask
)
x = x + self.dropout(attn_output)
# 前馈网络 + 残差连接
ffn_output = self.ffn(self.norm2(x))
x = x + self.dropout(ffn_output)
return x
def demonstrate_rms_norm_transformer():
"""演示 RMSNorm 在 Transformer 中的应用"""
print("RMSNorm Transformer 演示")
print("=" * 50)
# 创建模型
d_model, nhead, seq_len, batch_size = 512, 8, 20, 4
transformer_block = RMSNormTransformerBlock(d_model, nhead)
print(f"模型配置:")
print(f" d_model: {d_model}")
print(f" nhead: {nhead}")
print(f" 总参数量: {sum(p.numel() for p in transformer_block.parameters()):,}")
# 测试前向传播
x = torch.randn(batch_size, seq_len, d_model)
output = transformer_block(x)
print(f"\n前向传播测试:")
print(f" 输入形状: {x.shape}")
print(f" 输出形状: {output.shape}")
print(f" 输入统计 - 均值: {x.mean():.4f}, 标准差: {x.std():.4f}")
print(f" 输出统计 - 均值: {output.mean():.4f}, 标准差: {output.std():.4f}")
# 梯度测试
print(f"\n梯度稳定性测试:")
x_grad = torch.randn(2, seq_len, d_model, requires_grad=True)
output_grad = transformer_block(x_grad)
loss = output_grad.sum()
loss.backward()
grad_norm = x_grad.grad.norm()
print(f" 输入梯度范数: {grad_norm:.4f}")
# 不同序列长度测试
print(f"\n不同序列长度测试:")
for test_seq_len in [10, 20, 50, 100]:
x_test = torch.randn(batch_size, test_seq_len, d_model)
with torch.no_grad():
output_test = transformer_block(x_test)
print(f" 序列长度{test_seq_len:3d} - 输出均值: {output_test.mean():.4f}, 标准差: {output_test.std():.4f}")
return transformer_block
rms_transformer_demo = demonstrate_rms_norm_transformer()
性能分析
def analyze_rms_norm_performance():
"""分析 RMSNorm 性能"""
print("RMSNorm 性能分析")
print("=" * 50)
# 不同批量大小测试
batch_sizes = [1, 4, 16, 64]
seq_len, d_model = 128, 512
print("不同批量大小下的性能:")
print(f"{'批量大小':<10} {'RMSNorm时间(s)':<15} {'LayerNorm时间(s)':<15} {'速度比':<10}")
print("-" * 50)
for batch_size in batch_sizes:
x = torch.randn(batch_size, seq_len, d_model)
rms_norm = RMSNorm(d_model)
layer_norm = nn.LayerNorm(d_model)
# 预热
_ = rms_norm(x)
_ = layer_norm(x)
# RMSNorm 时间
start = time.time()
for _ in range(100):
_ = rms_norm(x)
rms_time = time.time() - start
# LayerNorm 时间
start = time.time()
for _ in range(100):
_ = layer_norm(x)
layer_time = time.time() - start
speed_ratio = layer_time / rms_time
print(f"{batch_size:<10} {rms_time:<15.4f} {layer_time:<15.4f} {speed_ratio:<10.2f}")
# 不同特征维度测试
print(f"\n不同特征维度下的性能:")
batch_size, seq_len = 8, 64
d_models = [128, 256, 512, 1024, 2048]
print(f"{'特征维度':<10} {'RMSNorm时间(s)':<15} {'LayerNorm时间(s)':<15} {'速度比':<10}")
print("-" * 50)
for d_model in d_models:
x = torch.randn(batch_size, seq_len, d_model)
rms_norm = RMSNorm(d_model)
layer_norm = nn.LayerNorm(d_model)
# 预热
_ = rms_norm(x)
_ = layer_norm(x)
# RMSNorm 时间
start = time.time()
for _ in range(100):
_ = rms_norm(x)
rms_time = time.time() - start
# LayerNorm 时间
start = time.time()
for _ in range(100):
_ = layer_norm(x)
layer_time = time.time() - start
speed_ratio = layer_time / rms_time
print(f"{d_model:<10} {rms_time:<15.4f} {layer_time:<15.4f} {speed_ratio:<10.2f}")
# 内存使用分析
print(f"\n内存使用分析:")
d_model = 512
rms_norm = RMSNorm(d_model)
layer_norm = nn.LayerNorm(d_model)
rms_params = sum(p.numel() for p in rms_norm.parameters())
layer_params = sum(p.numel() for p in layer_norm.parameters())
print(f" RMSNorm 参数量: {rms_params}")
print(f" LayerNorm 参数量: {layer_params}")
print(f" 参数减少比例: {(layer_params - rms_params) / layer_params * 100:.1f}%")
return {
'batch_size_results': batch_sizes,
'dimension_results': d_models
}
performance_analysis = analyze_rms_norm_performance()
适用场景总结
- /大规模语言模型/:如 LLaMA、GPT-NeoX 等现代架构
- /计算资源受限环境/:需要轻量级归一化的场景
- /长序列处理/:在长序列任务中表现稳定
- /小批量训练/:对批量大小不敏感
优势与局限总结
优势:
- 计算效率高,比 LayerNorm 快约 20-40%
- 参数数量少,减少约 50% 的参数
- 对小批量大小不敏感
- 在多种任务中表现与 LayerNorm 相当
局限:
- 在某些任务中可能略逊于 LayerNorm
- 缺乏均值中心化可能影响某些分布的学习
- 相对较新的方法,实践经验较少
AdaLN
在 Layer Normalization(LN)的基础上进行了优化,用来增强AI模型在处理不同输入条件时的适应能力
传统的 LN 没有条件信息,就纯归一化:
AdaLN 则引入了 Modulation
例如:
更为一般的表达是:
# 输入条件的特征提取网络
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
# c代表输入的条件信息
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp ,这6个变量分别对应了多头自注意力机制的归一化参数与缩放参数
AdaLN-Zero
DiT中具体的初始化设置如下所示:
- 对DiT Block中的AdaLN和Linear层均采用参数0初始化。
- 对于其它网络层参数,使用正态分布初始化和xavier初始化。
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class DiTBlock(nn.Module):
"""
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
卷积层
基本概念
卷积层是深度学习中用于处理网格状数据(如图像、序列)的核心组件,通过滑动窗口的方式在输入数据上提取局部特征。
数学定义
对于二维卷积,给定输入张量 \( X \in \mathbb{R}^{C_{in} \times H \times W} \) 和卷积核 \( K \in \mathbb{R}^{C_{out} \times C_{in} \times k_h \times k_w} \),输出计算为:
其中:
- \( C_{in}, C_{out} \):输入/输出通道数
- \( k_h, k_w \):卷积核高度和宽度
- \( b \):偏置项
关键参数
- /步长(Stride)/:卷积核移动的步长
- /填充(Padding)/:在输入边界添加的零值区域
- /膨胀率(Dilation)/:卷积核元素间的间距
- /分组(Groups)/:输入输出通道的分组方式
PyTorch 实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicConv2d(nn.Module):
"""基础二维卷积层"""
def __init__(self, in_channels, out_channels, kernel_size,
stride=1, padding=0, dilation=1, groups=1, bias=True):
super().__init__()
self.conv = nn.Conv2d(
in_channels, out_channels, kernel_size,
stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias
)
# 初始化权重
nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')
if bias:
nn.init.constant_(self.conv.bias, 0)
def forward(self, x):
return self.conv(x)
def demonstrate_convolution():
"""演示卷积操作"""
print("卷积层演示")
print("=" * 50)
# 基础卷积
batch_size, in_channels, height, width = 4, 3, 32, 32
x = torch.randn(batch_size, in_channels, height, width)
conv = BasicConv2d(in_channels, 64, kernel_size=3, padding=1)
output = conv(x)
print(f"输入形状: {x.shape}")
print(f"卷积核形状: {conv.conv.weight.shape}")
print(f"输出形状: {output.shape}")
# 不同参数的影响
print(f"\n不同卷积参数的影响:")
# 不同步长
conv_stride2 = BasicConv2d(in_channels, 64, kernel_size=3, stride=2, padding=1)
output_stride2 = conv_stride2(x)
print(f" 步长=2: 输出形状 {output_stride2.shape}")
# 不同填充
conv_pad2 = BasicConv2d(in_channels, 64, kernel_size=5, padding=2)
output_pad2 = conv_pad2(x)
print(f" 核大小=5, 填充=2: 输出形状 {output_pad2.shape}")
# 分组卷积
conv_group = BasicConv2d(in_channels, 64, kernel_size=3, groups=in_channels)
output_group = conv_group(x)
print(f" 分组卷积: 输出形状 {output_group.shape}")
# 膨胀卷积
conv_dilation = BasicConv2d(in_channels, 64, kernel_size=3, dilation=2, padding=2)
output_dilation = conv_dilation(x)
print(f" 膨胀率=2: 输出形状 {output_dilation.shape}")
return {
'basic_conv': conv,
'outputs': {
'basic': output,
'stride2': output_stride2,
'pad2': output_pad2,
'group': output_group,
'dilation': output_dilation
}
}
conv_demo = demonstrate_convolution()
深度可分离卷积
将标准卷积分解为深度卷积和逐点卷积:
class DepthwiseSeparableConv(nn.Module):
"""深度可分离卷积"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
super().__init__()
# 深度卷积
self.depthwise = nn.Conv2d(
in_channels, in_channels, kernel_size,
stride=stride, padding=padding, groups=in_channels
)
# 逐点卷积
self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
x = self.depthwise(x)
x = self.pointwise(x)
return x
def compare_convolution_types():
"""比较不同类型卷积"""
print("卷积类型比较")
print("=" * 50)
batch_size, in_channels, height, width = 4, 32, 16, 16
out_channels = 64
x = torch.randn(batch_size, in_channels, height, width)
# 标准卷积
standard_conv = BasicConv2d(in_channels, out_channels, kernel_size=3, padding=1)
# 深度可分离卷积
separable_conv = DepthwiseSeparableConv(in_channels, out_channels, kernel_size=3, padding=1)
# 计算参数量
standard_params = sum(p.numel() for p in standard_conv.parameters())
separable_params = sum(p.numel() for p in separable_conv.parameters())
print(f"标准卷积参数量: {standard_params:,}")
print(f"深度可分离卷积参数量: {separable_params:,}")
print(f"参数减少比例: {(1 - separable_params/standard_params)*100:.1f}%")
# 前向传播测试
with torch.no_grad():
standard_output = standard_conv(x)
separable_output = separable_conv(x)
print(f"\n输出形状:")
print(f" 标准卷积: {standard_output.shape}")
print(f" 深度可分离卷积: {separable_output.shape}")
return {
'standard_conv': standard_conv,
'separable_conv': separable_conv,
'parameter_ratio': separable_params/standard_params
}
conv_comparison = compare_convolution_types()
转置卷积
用于上采样操作,通过插入零值实现尺寸扩大:
class TransposedConvDemo:
"""转置卷积演示"""
def __init__(self):
pass
def demonstrate_transposed_conv(self):
"""演示转置卷积"""
print("转置卷积演示")
print("=" * 50)
# 输入特征图
batch_size, channels, height, width = 4, 32, 8, 8
x = torch.randn(batch_size, channels, height, width)
print(f"输入形状: {x.shape}")
# 不同上采样倍数的转置卷积
upscale_factors = [2, 4]
for factor in upscale_factors:
# 转置卷积
conv_transpose = nn.ConvTranspose2d(
channels, channels, kernel_size=3,
stride=factor, padding=1, output_padding=factor-1
)
output = conv_transpose(x)
print(f"上采样倍数 {factor}: 输出形状 {output.shape}")
# 与插值方法比较
print(f"\n与插值方法比较:")
# 双线性插值
bilinear_upsample = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
print(f"双线性插值: 输出形状 {bilinear_upsample.shape}")
# 转置卷积
transposed_output = nn.ConvTranspose2d(channels, channels, kernel_size=4, stride=2, padding=1)(x)
print(f"转置卷积: 输出形状 {transposed_output.shape}")
return {
'input': x,
'bilinear': bilinear_upsample,
'transposed': transposed_output
}
transposed_demo = TransposedConvDemo()
transposed_results = transposed_demo.demonstrate_transposed_conv()
*/ 应用场景
- /计算机视觉/:图像分类、目标检测、语义分割
- /自然语言处理/:文本分类、序列建模
- /语音处理/:语音识别、音频生成
- /医学影像/:病灶检测、图像分割
ZeroConvolution
零卷积(Zero Convolution)是一种特殊的卷积初始化技术,在控制网络、条件生成和适配 器模块中广泛应用,通过零初始化确保初始状态下不改变输入特征。
基本概念
零卷积的核心思想是将卷积层的权重和偏置初始化为零,使得网络在训练初期表现为恒等映射:
随着训练进行,卷积层逐渐学习到有意义的变换。
数学定义
对于标准卷积:
零卷积初始化:
实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class ZeroConv1d(nn.Module):
"""一维零卷积"""
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
super().__init__()
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
stride=stride, padding=padding)
self._initialize_weights()
def _initialize_weights(self):
"""零初始化权重和偏置"""
nn.init.zeros_(self.conv.weight)
if self.conv.bias is not None:
nn.init.zeros_(self.conv.bias)
def forward(self, x):
return self.conv(x)
class ZeroConv2d(nn.Module):
"""二维零卷积"""
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
stride=stride, padding=padding)
self._initialize_weights()
def _initialize_weights(self):
"""零初始化权重和偏置"""
nn.init.zeros_(self.conv.weight)
if self.conv.bias is not None:
nn.init.zeros_(self.conv.bias)
def forward(self, x):
return self.conv(x)
class ZeroConv3d(nn.Module):
"""三维零卷积"""
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
super().__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
stride=stride, padding=padding)
self._initialize_weights()
def _initialize_weights(self):
"""零初始化权重和偏置"""
nn.init.zeros_(self.conv.weight)
if self.conv.bias is not None:
nn.init.zeros_(self.conv.bias)
def forward(self, x):
return self.conv(x)
class ZeroConvAdapter(nn.Module):
"""零卷积适配器模块"""
def __init__(self, in_channels, out_channels, hidden_ratio=4):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
hidden_channels = in_channels * hidden_ratio
# 下投影
self.down_proj = ZeroConv2d(in_channels, hidden_channels, 1)
# 非线性激活
self.activation = nn.GELU()
# 上投影
self.up_proj = ZeroConv2d(hidden_channels, out_channels, 1)
# 缩放因子
self.scale = nn.Parameter(torch.zeros(1))
def forward(self, x):
"""
Args:
x: 输入张量 [B, C, H, W]
Returns:
适配器输出
"""
identity = x
# 适配器路径
h = self.down_proj(x)
h = self.activation(h)
h = self.up_proj(h)
# 缩放并残差连接
return identity + self.scale * h
def demonstrate_zero_conv():
"""演示零卷积"""
print("零卷积演示")
print("=" * 50)
# 1D 零卷积测试
print("1D 零卷积:")
batch_size, in_channels, seq_len = 4, 32
, 100
x_1d = torch.randn(batch_size, in_channels, seq_len)
zero_conv_1d = ZeroConv1d(in_channels, 64, kernel_size=3, padding=1)
output_1d = zero_conv_1d(x_1d)
print(f" 输入形状: {x_1d.shape}")
print(f" 输出形状: {output_1d.shape}")
print(f" 初始输出均值: {output_1d.mean().item():.6f}")
print(f" 初始输出标准差: {output_1d.std().item():.6f}")
# 2D 零卷积测试
print(f"\n2D 零卷积:")
batch_size, in_channels, height, width = 4, 3, 32, 32
x_2d = torch.randn(batch_size, in_channels, height, width)
zero_conv_2d = ZeroConv2d(in_channels, 64, kernel_size=3, padding=1)
output_2d = zero_conv_2d(x_2d)
print(f" 输入形状: {x_2d.shape}")
print(f" 输出形状: {output_2d.shape}")
print(f" 初始输出均值: {output_2d.mean().item():.6f}")
print(f" 初始输出标准差: {output_2d.std().item():.6f}")
# 验证零初始化
print(f"\n零初始化验证:")
weight_norm = zero_conv_2d.conv.weight.norm().item()
bias_norm = zero_conv_2d.conv.bias.norm().item() if zero_conv_2d.conv.bias is not None else 0
print(f" 权重范数: {weight_norm:.6f}")
print(f" 偏置范数: {bias_norm:.6f}")
print(f" 输出与零张量的差异: {torch.abs(output_2d).max().item():.6f}")
# 零卷积适配器演示
print(f"\n零卷积适配器
:")
adapter = ZeroConvAdapter(64, 64)
adapter_output = adapter(x_2d)
print(f" 适配器输入形状: {x_2d.shape}")
print(f" 适配器输出形状: {adapter_output.shape}")
print(f" 初始适配器输出均值: {adapter_output.mean().item():.6f}")
print(f" 缩放因子: {adapter.scale.item():.6f}")
# 训练过程中的行为
print(f"\n训练行为模拟:")
# 模拟一步训练
optimizer = torch.optim.SGD(zero_conv_2d.parameters(), lr=0.1)
loss = output_2d.sum() # 简单的损失函数
loss.backward()
optimizer.step()
# 检查训练后的输出
output_after_train = zero_conv_2d(x_2d)
print(f" 训练一步后输出均值: {output_after_train.mean().item():.6f}")
print(f" 训练一步后权重范数: {zero_conv_2d.conv.weight.norm().item():.6f}")
return {
'zero_conv_1d': zero_conv_1d,
'zero_conv_2d': zero_conv_2d,
'adapter': adapter,
'outputs': {
'1d': output_1d,
'2d': output_2d,
'adapter': adapter_output,
'after_train': output_after_train
}
}
zero_conv_demo = demonstrate_zero_conv()
*/ 在 ControlNet 中的应用
class ZeroConvControlNetBlock(nn.Module):
"""ControlNet 风格的零卷积块"""
def __init__(self, in_channels, out_channels, condition_channels):
super().__init__()
# 主路径卷积
self.main_conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
# 条件路径 - 使用零卷积
self.condition_conv = ZeroConv2d(condition_channels, out_channels, 1)
# 融合门控
self.gate = nn.Parameter(torch.zeros(1))
# 输出零卷积
self.output_conv = ZeroConv2d(out_channels, out_channels, 1)
def forward(self, x, condition):
"""
Args:
x: 主输入 [B, C, H, W]
condition: 条件输入 [B, C_cond, H, W]
"""
# 主路径
main_out = self.main_conv(x)
# 条件路径
condition_out = self.condition_conv(condition)
# 融合
fused = main_out + self.gate * condition_out
# 输出变换
output = self.output_conv(fused)
return output
class ZeroConvResBlock(nn.Module):
"""零卷积残差块"""
def __init__(self, channels, condition_channels=None):
super().__init__()
# 第一个卷积层
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
self.norm1 = nn.GroupNorm(32, channels)
# 条件卷积(可选)
if condition_channels is not None:
self.condition_conv = ZeroConv2d(condition_channels, channels, 1)
else:
self.condition_conv = None
# 第二个卷积层
self.conv2 = ZeroConv2d(channels, channels, 3, padding=1)
self.norm2 = nn.GroupNorm(32, channels)
self.activation = nn.SiLU()
# 门控参数
self.gate = nn.Parameter(torch.zeros(1))
def forward(self, x, condition=None):
identity = x
# 第一个卷积
h = self.conv1(x)
h = self.norm1(h)
# 条件融合
if condition is not None and self.condition_conv is not None:
cond_feat = self.condition_conv(condition)
h = h + self.gate * cond_feat
h = self.activation(h)
# 第二个卷积
h = self.conv2(h)
h = self.norm2(h)
# 残差连接
return identity + h
def demonstrate_controlnet_application():
"""演示 ControlNet 中的应用"""
print("ControlNet 中的零卷积应用")
print("=" * 50)
# 创建测试数据
batch_size, channels, height, width = 4, 64, 32, 32
condition_channels = 128
x = torch.randn(batch_size, channels, height, width)
condition = torch.randn(batch_size, condition_channels, height, width)
# ControlNet 块
controlnet_block = ZeroConvControlNetBlock(channels, channels, condition_channels)
output = controlnet_block(x, condition)
print(f"主输入形状: {x.shape}")
print(f"条件输入形状: {condition.shape}")
print(f"输出形状: {output.shape}")
# 验证初始输出
print(f"\n初始状态验证:")
print(f"条件路径输出均值: {controlnet_block.condition_conv(condition).mean().item():.6f}")
print(f"输出卷积输出均值: {controlnet_block.output_conv(torch.zeros_like(output)).mean().item():.6f}")
print(f"门控参数: {controlnet_block.gate.item():.6f}")
# 零卷积残差块
print(f"\n零卷积残差块:")
zero_res_block = ZeroConvResBlock(channels, condition_channels)
res_output = zero_res_block(x, condition)
print(f"残差块输入形状: {x.shape}")
print(f"残差块输出形状: {res_output.shape}")
print(f"初始残差输出均值: {res_output.mean().item():.6f}")
# 训练稳定性测试
print(f"\n训练稳定性测试:")
# 多次前向传播
with torch.no_grad():
outputs = []
for i in range(5):
output_i = controlnet_block(x, condition)
outputs.append(output_i.mean().item())
output_variance = torch.var(torch.tensor(outputs))
print(f"多次前向传播输出方差: {output_v
ariance.item():.6f}")
return {
'controlnet_block': controlnet_block,
'zero_res_block': zero_res_block,
'outputs': {
'controlnet': output,
'residual': res_output
}
}
controlnet_demo = demonstrate
_controlnet_application()
优势与特点
- /稳定初始化/:确保网络从恒等映射开始,训练过程更稳定
- /渐进式学习/:模型逐渐学习条件信息的影响,避免剧烈变化
- /模块化设计/:便于插入到现有架构中作为适配器
- /训练友好/:减少训练初期的梯度爆炸风险
应用场景
- /ControlNet/:在稳定扩散模型中添加空间条件控制
- /模型微调/:作为适配器模块进行参数高效微调
- /多模态融合/:融合不同模态的特征表示
- /渐进式训练/:从简单任务逐渐过渡到复杂任务
注意事项
- /学习率调整/:零卷积层
可能需要不同的学习率调度
- /梯度流/:确保零初始化不会阻碍梯度传播
- /收敛速度/:初始阶段学习较慢,需要适当训练轮数
- /参数初始化/:与其他层的
初始化策略协调
反卷积
反卷积(Deconvolution),也称为转置卷积(Transposed Convolution)或分数步长卷积(Fractionally-strided Convolution),是一种用于特征图上采样的操作。
数学原理
对于输入特征图 \( X \in \mathbb{R}^{C_{in} \times H_{in} \times W_{in}} \) 和卷积核 \( K \in \mathbb{R}^{C_{out} \times C_{in} \times k_h \times k_w} \),反卷积的输出计算为:
其中:
- \( s \) 为步长(stride)
- \( p \) 为填充(padding)
- 输出尺寸:\( H_{out} = s \cdot (H_{in} - 1) + k_h - 2p \)
实现细节
class DeconvolutionDemo:
"""反卷积演示类"""
def __init__(self):
pass
def demonstrate_deconvolution(self):
"""演示反卷积操作"""
print("反卷积演示")
print("=" * 50)
# 输入特征图
batch_size, in_channels, height, width = 4, 32, 8, 8
out_channels = 64
x = torch.randn(batch_size, in_channels, height, width)
print(f"输入形状: {x.shape}")
# 不同配置的反卷积
configurations = [
{'kernel_size': 3, 'stride': 2, 'padding': 1, 'output_padding': 1},
{'kernel_size': 4, 'stride': 2, 'padding': 1},
{'kernel_size': 3, 'stride': 3, 'padding': 1, 'output_padding': 2}
]
for i, config in enumerate(configurations):
deconv = nn.ConvTranspose2d(
in_channels, out_channels,
kernel_size=config['kernel_size'],
stride=config['stride'],
padding=config['padding'],
output_padding=config.get('output_padding', 0)
)
output = deconv(x)
print(f"配置 {i+1}: 核大小 {config['kernel_size']}, 步长 {config['stride']}")
print(f" 输出形状: {output.shape}")
# 计算参数量
params = sum(p.numel() for p in deconv.parameters())
print(f" 参数量: {params:,}")
return deconv, output
def compare_upsampling_methods(self):
"""比较不同上采样方法"""
print(f"\n上采样方法比较")
print("=" * 50)
# 测试输入
x = torch.randn(2, 32, 16, 16)
target_size = (32, 32)
print(f"输入形状: {x.shape}")
print(f"目标输出形状: (2, 32, {target_size[0]}, {target_size[1]})")
# 1. 最近邻插值
nearest = F.interpolate(x, size=target_size, mode='nearest')
# 2. 双线性插值
bilinear = F.interpolate(x, size=target_size, mode='bilinear', align_corners=False)
# 3. 双三次插值
bicubic = F.interpolate(x, size=target_size, mode='bicubic', align_corners=False)
# 4. 反卷积
deconv = nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1)
deconv_output = deconv(x)
print(f"\n不同方法输出形状:")
print(f" 最近邻插值: {nearest.shape}")
print(f" 双线性插值: {bilinear.shape}")
print(f" 双三次插值: {bicubic.shape}")
print(f" 反卷积: {deconv_output.shape}")
# 计算计算量(FLOPs)
def calculate_flops_2d(input_shape, output_shape, kernel_size, groups=1):
"""计算2D卷积的FLOPs"""
batch, in_c, h, w = input_shape
_, out_c, out_h, out_w = output_shape
# 每个输出位置的计算量
flops_per_position = in_c * kernel_size * kernel_size / groups
total_flops = batch * out_c * out_h * out_w * flops_per_position * 2 # 乘加操作
return total_flops
deconv_flops = calculate_flops_2d(x.shape, deconv_output.shape, 4)
print(f"\n计算量比较:")
print(f" 反卷积 FLOPs: {deconv_flops:,.0f}")
return {
'nearest': nearest,
'bilinear': bilinear,
'bicubic': bicubic,
'deconv': deconv_output
}
# 演示反卷积
deconv_demo = DeconvolutionDemo()
deconv_results = deconv_demo.demonstrate_deconvolution()
upsampling_comparison = deconv_demo.compare_upsampling_methods()
*/ 棋盘格效应
反卷积在生成图像时可能产生棋盘格效应(checkerboard artifacts),这是由于不均匀的重叠模式造成的。
class CheckerboardAnalysis:
"""棋盘格效应分析"""
def __init__(self):
pass
def analyze_checkerboard_artifacts(self):
"""分析棋盘格效应"""
print("棋盘格效应分析")
print("=" * 50)
# 创建均匀输入
x = torch.ones(1, 1, 4, 4)
# 不同核大小的反卷积
kernel_sizes = [2, 3, 4]
for kernel_size in kernel_sizes:
# 创建反卷积层
deconv = nn.ConvTranspose2d(
1, 1,
kernel_size=kernel_size,
stride=2,
padding=0 if kernel_size % 2 == 0 else 1
)
# 使用均匀权重
nn.init.constant_(deconv.weight, 1.0)
if deconv.bias is not None:
nn.init.constant_(deconv.bias, 0.0)
output = deconv(x)
print(f"核大小 {kernel_size}:")
print(f" 输入形状: {x.shape}")
print(f" 输出形状: {output.shape}")
print(f" 输出值范围: [{output.min().item():.3f}, {output.max().item():.3f}]")
# 计算重叠模式
overlap_pattern = self.calculate_overlap_pattern(output)
print(f" 重叠模式方差: {overlap_pattern.var().item():.3f}")
def calculate_overlap_pattern(self, output):
"""计算输出中的重叠模式"""
# 简化的重叠模式分析
return output
def demonstrate_solutions(self):
"""演示解决棋盘格效应的方法"""
print(f"\n棋盘格效应解决方案")
print("=" * 50)
x = torch.ones(1, 32, 8, 8)
# 解决方案1: 使用最近邻插值 + 卷积
print("方案1: 插值 + 卷积")
upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
conv_after_upsample = nn.Conv2d(32, 32, kernel_size=3, padding=1)(upsampled)
print(f" 输出形状: {conv_after_upsample.shape}")
# 解决方案2: 调整核大小和步长
print("方案2: 调整核参数")
deconv_adjusted = nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1)
adjusted_output = deconv_adjusted(x)
print(f" 输出形状: {adjusted_output.shape}")
# 解决方案3: 使用PixelShuffle
print("方案3: PixelShuffle")
pixelshuffle = nn.PixelShuffle(2)
# 首先将通道数增加到 32 * 4
conv_before_shuffle = nn.Conv2d(32, 32 * 4, kernel_size=3, padding=1)(x)
shuffle_output = pixelshuffle(conv_before_shuffle)
print(f" 输出形状: {shuffle_output.shape}")
checkerboard_analysis = CheckerboardAnalysis()
checkerboard_analysis.analyze_checkerboard_artifacts()
checkerboard_analysis.demonstrate_solutions()
*/ 应用场景
- /图像超分辨率/:将低分辨率图像上采样到高分辨率
- /语义分割/:在编码器-解码器架构中恢复空间分辨率
- /生成对抗网络/:从潜在向量生成图像
- /自编码器/:在解码器中重建输入尺寸
*/ 与其他上采样方法的比较
| 方法 | 可学习参数 | 计算成本 | 棋盘格效应 | 适用场景 |
|---|---|---|---|---|
| 最近邻插值 | 无 | 低 | 无 | 实时应用 |
| 双线性插值 | 无 | 低 | 无 | 一般上采样 |
| 反卷积 | 有 | 高 | 可能 | 需要学习的上采样 |
| PixelShuffle | 有 | 中 | 无 | 高质量上采样 |
*/ 最佳实践
- /核大小选择/:使用能被步长整除的核大小以减少棋盘格效应
- /初始化策略/:使用双线性插值初始化反卷积权重
- /结合其他方法/:可先插值再卷积以获得更好效果
- /监控输出/:训练过程中检查是否出现棋盘格模式
Modulation
调制(Modulation)是一种在生成模型和条件生成任务中广泛使用的技术,通过外部条件信息来调整网络的行为和特征表示。
基本概念
调制通过引入条件信息 \( c \) 来调整网络权重或激活值,使模型能够根据输入条件生成不同的输出:
其中 \( \theta( c) \) 是根据条件 \( c \) 动态生成的网络参数。
条件批归一化(Conditional Batch Normalization)
在批归一化中引入条件信息:
其中 \( \gamma( c), \beta( c) \) 是根据条件 \( c \) 生成的缩放和偏移参数。
class ConditionalBatchNorm2d(nn.Module):
"""条件批归一化"""
def __init__(self, num_features, condition_dim, hidden_dim=128):
super().__init__()
self.num_features = num_features
self.condition_dim = condition_dim
# 主批归一化层
self.bn = nn.BatchNorm2d(num_features, affine=False)
# 条件映射网络
self.condition_net = nn.Sequential(
nn.Linear(condition_dim, hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, 2 * num_features)
)
def forward(self, x, condition):
"""
Args:
x: 输入张量 [B, C, H, W]
condition: 条件向量 [B, condition_dim]
"""
# 标准批归一化
x_normalized = self.bn(x)
# 根据条件生成参数
condition_params = self.condition_net(condition) # [B, 2*C]
gamma = condition_params[:, :self.num_features].unsqueeze(-1).unsqueeze(-1) # [B, C, 1, 1]
beta = condition_params[:, self.num_features:].unsqueeze(-1).unsqueeze(-1) # [B, C, 1, 1]
# 应用调制
return gamma * x_normalized + beta
def demonstrate_conditional_bn():
"""演示条件批归一化"""
print("条件批归一化演示")
print("=" * 50)
batch_size, channels, height, width = 4, 32, 16, 16
condition_dim = 64
x = torch.randn(batch_size, channels, height, width)
condition = torch.randn(batch_size, condition_dim)
conditional_bn = ConditionalBatchNorm2d(channels, condition_dim)
output = conditional_bn(x, condition)
print(f"输入形状: {x.shape}")
print(f"条件向量形状: {condition.shape}")
print(f"输出形状: {output.shape}")
# 验证不同条件产生不同输出
print(f"\n不同条件的影响:")
condition1 = torch.randn(batch_size, condition_dim)
condition2 = torch.randn(batch_size, condition_dim)
output1 = conditional_bn(x, condition1)
output2 = conditional_bn(x, condition2)
diff = torch.abs(output1 - output2).mean()
print(f" 不同条件输出的平均差异: {diff:.6f}")
return {
'conditional_bn': conditional_bn,
'outputs': {
'condition1': output1,
'condition2': output2
}
}
conditional_bn_demo = demonstrate_conditional_bn()
特征调制(Feature-wise Modulation)
通过条件信息直接调制特征图,如 FiLM(Feature-wise Linear Modulation):
其中 \( \gamma( c), \beta( c) \) 是根据条件 \( c \) 生成的调制参数。
class FiLMLayer(nn.Module):
"""FiLM(Feature-wise Linear Modulation)层"""
def __init__(self, feature_dim, condition_dim, hidden_dim=128):
super().__init__()
self.feature_dim = feature_dim
self.condition_dim = condition_dim
# 条件映射网络
self.condition_net = nn.Sequential(
nn.Linear(condition_dim, hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, 2 * feature_dim)
)
def forward(self, x, condition):
"""
Args:
x: 输入特征 [..., feature_dim]
condition: 条件向量 [..., condition_dim]
"""
original_shape = x.shape
# 重塑为 [*, feature_dim]
x_flat = x.view(-1, self.feature_dim)
condition_flat = condition.view(-1, self.condition_dim)
# 生成调制参数
condition_params = self.condition_net(condition_flat) # [*, 2*feature_dim]
gamma = condition_params[:, :self.feature_dim]
beta = condition_params[:, self.feature_dim:]
# 应用调制
x_modulated = gamma * x_flat + beta
# 恢复原始形状
return x_modulated.view(original_shape)
class FiLMResBlock(nn.Module):
"""使用 FiLM 的残差块"""
def __init__(self, in_channels, out_channels, condition_dim, stride=1):
super().__init__()
# 第一个卷积层
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
# FiLM 调制
self.film1 = FiLMLayer(out_channels, condition_dim)
# 第二个卷积层
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.film2 = FiLMLayer(out_channels, condition_dim)
# 激活函数
self.relu = nn.ReLU(inplace=True)
# 下采样
self.downsample = None
if stride != 1 or in_channels != out_channels:
self.downsample = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1,
stride=stride, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x, condition):
identity = x
# 主路径
out = self.conv1(x)
out = self.bn1(out)
out = self.film1(out, condition) # FiLM 调制
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.film2(out, condition) # FiLM 调制
# 捷径
if self.downsample is not None:
identity = self.downsample(x)
# 残差连接
out += identity
out = self.relu(out)
return out
def demonstrate_film():
"""演示 FiLM 调制"""
print("FiLM 调制演示")
print("=" * 50)
# 基础 FiLM 测试
batch_size, feature_dim, condition_dim = 4, 64, 32
x = torch.randn(batch_size, feature_dim)
condition = torch.randn(batch_size, condition_dim)
film_layer = FiLMLayer(feature_dim, condition_dim)
output = film_layer(x, condition)
print(f"输入特征形状: {x.shape}")
print(f"条件向量形状: {condition.shape}")
print(f"输出特征形状: {output.shape}")
# FiLM 残差块测试
print(f"\nFiLM 残差块演示:")
batch_size, in_channels, height, width = 4, 32, 16, 16
condition_dim = 64
x_conv = torch.randn(batch_size, in_channels, height, width)
condition_conv = torch.randn(batch_size, condition_dim)
film_resblock = FiLMResBlock(in_channels, 64, condition_dim, stride=1)
output_conv = film_resblock(x_conv, condition_conv)
print(f"卷积输入形状: {x_conv.shape}")
print(f"卷积输出形状: {output_conv.shape}")
# 条件敏感性测试
print(f"\n条件敏感性测试:")
# 相同输入,不同条件
condition_a = torch.randn(batch_size, condition_dim)
condition_b = torch.randn(batch_size, condition_dim)
output_a = film_resblock(x_conv, condition_a)
output_b = film_resblock(x_conv, condition_b)
diff = torch.abs(output_a - output_b).mean()
print(f" 不同条件输出的平均差异: {diff:.6f}")
return {
'film_layer': film_layer,
'film_resblock': film_resblock,
'outputs': {
'basic': output,
'conv': output_conv,
'condition_a': output_a,
'condition_b': output_b
}
}
film_demo = demonstrate_film()
SFT
SFT(Scale and Shift Transform)是一种轻量级的特征调制方法,通过缩放(Scale)和 偏移(Shift)操作将条件信息融入特征表示中,广泛应用于图像超分辨率、风格迁移等任 务。
数学定义
对于输入特征 \( x \in \mathbb{R}^{C \times H \times W} \) 和条件信息 \( c \),SFT 计算:
其中:
- \( \gamma( c), \beta( c) \in \mathbb{R}^{C \times H \times W} \) 是根据条件 \( c \) 生成的缩放和偏移参数
- \( \odot \) 表示逐元素相乘
网络结构
SFT 通常包含:
- 条件编码网络:将条件信息映射到调制参数
- 特征提取网络:处理输入特征
- SFT 层:应用缩放和偏移调制
实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class SFTLayer(nn.Module):
"""SFT(Scale and Shift Transform)层"""
def __init__(self, channels, condition_dim, hidden_dim=64):
"""
参数:
channels: 输入特征通道数
condition_dim: 条件向量维度
hidden_dim: 隐藏层维度
"""
super().__init__()
self.channels = channels
# 条件映射网络,生成缩放和偏移参数
self.condition_net = nn.Sequential(
nn.Linear(condition_dim, hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, 2 * channels)
)
def forward(self, x, condition):
"""
Args:
x: 输入特征 [B, C, H, W]
condition: 条件向量 [B, condition_dim]
Returns:
调制后的特征 [B, C, H, W]
"""
B, C, H, W = x.shape
# 生成调制参数
condition_params = self.condition_net(condition) # [B, 2*C]
gamma = condition_params[:, :self.channels].view(B, C, 1, 1) # [B, C, 1, 1]
beta = condition_params[:, self.channels:].view(B, C, 1, 1) # [B, C, 1, 1]
# 应用缩放和偏移
return gamma * x + beta
class SFTResidualBlock(nn.Module):
"""使用 SFT 的残差块"""
def __init__(self, channels, condition_dim, hidden_dim=64):
super().__init__()
# 第一个卷积层
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(channels)
# SFT 调制层
self.sft1 = SFTLayer(channels, condition_dim, hidden_dim)
# 第二个卷积层
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(channels)
# 第二个 SFT 调制层
self.sft2 = SFTLayer(channels, condition_dim, hidden_dim)
# 激活函数
self.relu = nn.ReLU(inplace=True)
def forward(self, x, condition):
identity = x
# 主路径
out = self.conv1(x)
out = self.bn1(out)
out = self.sft1(out, condition) # SFT 调制
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.sft2(out, condition) # SFT 调制
# 残差连接
out += identity
out = self.relu(out)
return out
class SFTNet(nn.Module):
"""完整的 SFT 网络示例(用于图像超分辨率)"""
def __init__(self, in_channels=3, out_channels=3, base_channels=64,
condition_dim=32, num_blocks=16):
super().__init__()
# 特征提取
self.feature_extract = nn.Sequential(
nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True)
)
# SFT 残差块
self.sft_blocks = nn.ModuleList([
SFTResidualBlock(base_channels, condition_dim)
for _ in range(num_blocks)
])
# 重建层
self.reconstruct = nn.Sequential(
nn.Conv2d(base_channels, base_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(base_channels, out_channels, kernel_size=3, padding=1)
)
def forward(self, x, condition):
# 特征提取
features = self.feature_extract(x)
# 通过 SFT 块
for sft_block in self.sft_blocks:
features = sft_block(features, condition)
# 重建
output = self.reconstruct(features)
return output
def demonstrate_sft():
"""演示 SFT 调制"""
print("SFT(Scale and Shift Transform)演示")
print("=" * 50)
# 基础 SFT 测试
batch_size, channels, height, width = 4, 32, 16, 16
condition_dim = 64
x = torch.randn(batch_size, channels, height, width)
condition = torch.randn(batch_size, condition_dim)
sft_layer = SFTLayer(channels, condition_dim)
output = sft_layer(x, condition)
print(f"输入特征形状: {x.shape}")
print(f"条件向量形状: {condition.shape}")
print(f"SFT 输出形状: {output.shape}")
# 验证调制效果
print(f"\n调制效果验证:")
print(f" 输入均值: {x.mean().item():.4f}, 标准差: {x.std().item():.4f}")
print(f" 输出均值: {output.mean().item():.4f}, 标准差: {output.std().item():.4f}")
# SFT 残差块测试
print(f"\nSFT 残差块演示:")
sft_resblock = SFTResidualBlock(channels, condition_dim)
output_res = sft_resblock(x, condition)
print(f" 残差块输入形状: {x.shape}")
print(f" 残差块输出形状: {output_res.shape}")
# 条件敏感性测试
print(f"\n条件敏感性测试:")
# 相同输入,不同条件
condition1 = torch.randn(batch_size, condition_dim)
condition2 = torch.randn(batch_size, condition_dim)
output1 = sft_layer(x, condition1)
output2 = sft_layer(x, condition2)
diff = torch.abs(output1 - output2).mean()
print(f" 不同条件输出的平均差异: {diff:.6f}")
# 完整 SFT 网络测试
print(f"\n完整 SFT 网络演示:")
sft_net = SFTNet(in_channels=3, out_channels=3, condition_dim=condition_dim)
rgb_input = torch.randn(batch_size, 3, 32, 32)
rgb_output = sft_net(rgb_input, condition)
print(f" 网络输入形状: {rgb_input.shape}")
print(f" 网络输出形状: {rgb_output.shape}")
return {
'sft_layer': sft_layer,
'sft_resblock': sft_resblock,
'sft_net': sft_net,
'outputs': {
'basic': output,
'residual': output_res,
'network': rgb_output,
'condition1': output1,
'condition2': output2
}
}
sft_demo = demonstrate_sft()
与其它调制方法的比较
def compare_modulation_methods():
"""比较不同调制方法"""
print("调制方法比较")
print("=" * 50)
batch_size, channels, height, width = 4, 32, 16, 16
condition_dim = 64
x = torch.randn(batch_size, channels, height, width)
condition = torch.randn(batch_size, condition_dim)
# 不同调制方法
methods = {
'SFT': SFTLayer(channels, condition_dim),
'FiLM': FiLMLayer(channels, condition_dim),
'ConditionalBN': ConditionalBatchNorm2d(channels, condition_dim)
}
results = {}
print(f"输入形状: {x.shape}")
print(f"条件向量形状: {condition.shape}")
print(f"\n不同调制方法输出:")
for name, method in methods.items():
if name == 'ConditionalBN':
output = method(x, condition)
else:
output = method(x, condition)
print(f" {name:<15} - 形状: {output.shape}, 均值: {output.mean():.4f}, 标准差: {output.std():.4f}")
results[name] = {
'method': method,
'output': output,
'params': sum(p.numel() for p in method.parameters())
}
# 参数量比较
print(f"\n参数量比较:")
for name, result in results.items():
print(f" {name:<15}: {result['params']:,} 参数")
# 计算效率比较
print(f"\n计算效率比较 (100次前向传播):")
import time
for name, result in results.items():
method = result['method']
# 预热
if name == 'ConditionalBN':
_ = method(x, condition)
else:
_ = method(x, condition)
# 计时
start = time.time()
for _ in range(100):
if name == 'ConditionalBN':
_ = method(x, condition)
else:
_ = method(x, condition)
elapsed = time.time() - start
print(f" {name:<15}: {elapsed:.4f}s")
return results
modulation_comparison = compare_modulation_methods()
*/ 应用场景
- /图像超分辨率/:根据退化信息调制特征
- /风格迁移/:根据风格条件调整内容特征
- /条件图像生成/:根据语义条件生成图像
- /多模态学习/:融合不同模态的信息
*/ 优势与特点
优势:
- 计算轻量,仅增加少量参数
- 灵活性强,可适应各种条件信息
- 易于集成到现有架构中
- 在低层特征和高层特征中都有效
特点:
- 保持特征空间结构
- 允许细粒度的特征控制
- 支持端到端训练
- 对条件信息敏感
Gate
门控机制(Gating Mechanism)是深度学习中的关键技术,通过可学习的开关控制信息流动, 在循环神经网络、注意力机制和现代Transformer架构中发挥重要作用。
基本概念
门控机制通过sigmoid激活函数生成0到1之间的门控值,控制信息的保留与遗忘:
其中 \( \sigma \) 是sigmoid函数,\( \odot \) 表示逐元素相乘。
常见门控类型
LSTM 门控
长短期记忆网络(LSTM)包含三个门控单元:
- /输入门/:控制新信息的流入
- /遗忘门/:控制旧信息的遗忘
- /输出门/:控制信息的输出
数学表达式:
\begin{aligned} i_t &= \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \\ f_t &= \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \\ o_t &= \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \end{aligned}
GRU 门控
门控循环单元(GRU)包含两个门控:
- /重置门/:控制历史信息的重置
- /更新门/:控制信息的更新比例
数学表达式:
门控线性单元(GLU)
GLU通过sigmoid门控控制线性变换的输出:
实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class GLU(nn.Module):
"""门控线性单元(Gated Linear Unit)"""
def __init__(self, input_dim, output_dim=None):
super().__init__()
if output_dim is None:
output_dim = input_dim
self.linear = nn.Linear(input_dim, 2 * output_dim)
self.gate_activation = nn.Sigmoid()
def forward(self, x):
"""
Args:
x: 输入张量 [..., input_dim]
Returns:
门控输出 [..., output_dim]
"""
# 线性变换
projected = self.linear(x)
# 分割为值和门控
value, gate = torch.chunk(projected, 2, dim=-1)
# 应用门控
return value * self.gate_activation(gate)
class GatedResidualBlock(nn.Module):
"""门控残差块"""
def __init__(self, hidden_dim, dropout=0.1):
super().__init__()
self.hidden_dim = hidden_dim
# 前馈网络
self.ffn = nn.Sequential(
nn.Linear(hidden_dim, 4 * hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(4 * hidden_dim, hidden_dim),
nn.Dropout(dropout)
)
# 门控权重
self.gate_weight = nn.Parameter(torch.zeros(1, hidden_dim))
# 层归一化
self.norm = nn.LayerNorm(hidden_dim)
def forward(self, x):
"""
Args:
x: 输入张量 [..., hidden_dim]
Returns:
门控残差输出
"""
residual = x
# 前馈网络
ffn_output = self.ffn(self.norm(x))
# 计算门控值
gate = torch.sigmoid(self.gate_weight)
# 门控残差连接
output = gate * ffn_output + (1 - gate) * residual
return output
class MultiHeadGatedAttention(nn.Module):
"""多头门控注意力"""
def __init__(self, d_model, n_heads, dropout=0.1):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
assert self.head_dim * n_heads == d_model, "d_model必须能被n_heads整除"
# QKV投影
self.w_q = nn.Linear(d_model, d_model, bias=False)
self.w_k = nn.Linear(d_model, d_model, bias=False)
self.w_v = nn.Linear(d_model, d_model, bias=False)
# 输出投影
self.w_o = nn.Linear(d_model, d_model)
# 门控参数
self.gate = nn.Parameter(torch.zeros(1, 1, d_model))
self.dropout = nn.Dropout(dropout)
self.scale = self.head_dim ** -0.5
def forward(self, x, mask=None):
"""
Args:
x: 输入序列 [batch_size, seq_len, d_model]
mask: 注意力掩码 [batch_size, seq_len, seq_len]
"""
batch_size, seq_len, d_model = x.shape
# QKV投影
Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
# 注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 注意力权重
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# 注意力输出
attn_output = torch.matmul(attn_weights, V)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
# 输出投影
output = self.w_o(attn_output)
# 门控残差连接
gate = torch.sigmoid(self.gate)
output = gate * output + (1 - gate) * x
return output
def demonstrate_gating():
"""演示门控机制"""
print("门控机制演示")
print("=" * 50)
# GLU 演示
print("GLU(门控线性单元):")
batch_size, seq_len, hidden_dim = 4, 10, 64
x = torch.randn(batch_size, seq_len, hidden_dim)
glu = GLU(hidden_dim)
glu_output = glu(x)
print(f" 输入形状: {x.shape}")
print(f" 输出形状: {glu_output.shape}")
print(f" 输出范围: [{glu_output.min():.3f}, {glu_output.max():.3f}]")
# 门控残差块演示
print(f"\n门控残差块:")
gated_residual = GatedResidualBlock(hidden_dim)
residual_output = gated_residual(x)
print(f" 输入形状: {x.shape}")
print(f" 输出形状: {residual_output.shape}")
print(f" 门控权重均值: {torch.sigmoid(gated_residual.gate_weight).mean().item():.3f}")
# 多头门控注意力演示
print(f"\n多头门控注意力:")
gated_attention = MultiHeadGatedAttention(hidden_dim, n_heads=8)
attention_output = gated_attention(x)
print(f" 输入形状: {x.shape}")
print(f" 输出形状: {attention_output.shape}")
print(f" 注意力门控均值: {torch.sigmoid(gated_attention.gate).mean().item():.3f}")
# 门控效果分析
print(f"\n门控效果分析:")
# 测试不同输入的门控行为
test_inputs = [
torch.randn(1, hidden_dim), # 随机输入
torch.zeros(1, hidden_dim), # 零输入
torch.ones(1, hidden_dim) # 单位输入
]
input_names = ["随机输入", "零输入", "单位输入"]
for input_tensor, name in zip(test_inputs, input_names):
glu_out = glu(input_tensor)
gate_values = torch.sigmoid(glu.linear(input_tensor).chunk(2, dim=-1)[1])
print(f" {name}:")
print(f" 门控值范围: [{gate_values.min():.3f}, {gate_values.max():.3f}]")
print(f" 门控值均值: {gate_values.mean():.3f}")
return {
'glu': glu,
'gated_residual': gated_residual,
'gated_attention': gated_attention,
'outputs': {
'glu': glu_output,
'residual': residual_output,
'attention': attention_output
}
}
gating_demo = demonstrate_gating()
*/ 门控机制的优势
- /梯度流控制/:缓解梯度消失问题,改善深层网络训练
- /信息选择/:自适应地选择重要信息,抑制噪声
- /长期依赖/:在序列模型中更好地捕捉长期依赖关系
- /模型容量/:增加模型表达能力而不显著增加参数
*/ 应用场景
- /语言建模/:LSTM、GRU中的门控机制
- /机器翻译/:Transformer中的门控前馈网络
- /图像生成/:GAN中的门控卷积
- /推荐系统/:门控注意力网络
- /多模态学习/:跨模态信息门控融合
PixelShuffle
PixelShuffle(像素重排)是一种上采样操作,通过重新排列特征图的通道维度来增加空间 分辨率,广泛应用于图像超分辨率、图像生成等任务。
基本概念
PixelShuffle 通过周期性的重排操作将通道维度中的信息重新组织到空间维度,实现高效 的上采样:
其中 \( r \) 是上采样因子。
数学定义
给定输入张量 \( X \in \mathbb{R}^{C \times H \times W} \) 和上采样因子 \( r \),输出计算为:
更直观地,将输入通道维度视为 \( r \times r \) 个块,每个块对应输出特征图的一个空间位置。
实现原理
import torch
import torch.nn as nn
import torch.nn.functional as F
def manual_pixel_shuffle(x, upscale_factor):
"""
手动实现 PixelShuffle
Args:
x: 输入张量 [B, C, H, W]
upscale_factor: 上采样因子
Returns:
重排后的张量 [B, C/(r^2), r*H, r*W]
"""
batch_size, channels, height, width = x.shape
r = upscale_factor
# 检查通道数是否可被 r^2 整除
assert channels % (r * r) == 0, f"通道数 {channels} 必须能被 {r*r} 整除"
# 重塑为 [B, C/(r^2), r, r, H, W]
out_channels = channels // (r * r)
x_reshaped = x.view(batch_size, out_channels, r, r, height, width)
# 置换维度为 [B, C/(r^2), H, r, W, r]
x_permuted = x_reshaped.permute(0, 1, 4, 2, 5, 3)
# 重塑为 [B, C/(r^2), r*H, r*W]
output = x_permuted.contiguous().view(batch_size, out_channels, height * r, width * r)
return output
class PixelShuffleBlock(nn.Module):
"""PixelShuffle 上采样块"""
def __init__(self, in_channels, out_channels, upscale_factor):
super().__init__()
self.upscale_factor = upscale_factor
# 卷积层将通道数扩展到 r^2 * out_channels
self.conv = nn.Conv2d(
in_channels,
out_channels * (upscale_factor ** 2),
kernel_size=3,
padding=1
)
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
self.activation = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.activation(x)
x = self.pixel_shuffle(x)
return x
class ESPCN(nn.Module):
"""ESPCN(Efficient Sub-Pixel Convolutional Network)示例"""
def __init__(self, in_channels=3, out_channels=3, upscale_factor=2, base_channels=64):
super().__init__()
self.upscale_factor = upscale_factor
# 特征提取层
self.feature_extraction = nn.Sequential(
nn.Conv2d(in_channels, base_channels, kernel_size=5, padding=2),
nn.Tanh(),
nn.Conv2d(base_channels, base_channels // 2, kernel_size=3, padding=1),
nn.Tanh(),
)
# 子像素卷积层
self.subpixel_conv = nn.Conv2d(
base_channels // 2,
out_channels * (upscale_factor ** 2),
kernel_size=3,
padding=1
)
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
def forward(self, x):
# 特征提取
features = self.feature_extraction(x)
# 子像素卷积和像素重排
output = self.subpixel_conv(features)
output = self.pixel_shuffle(output)
return output
def
demonstrate_pixel_shuffle():
"""演示 PixelShuffle 操作"""
print("PixelShuffle 演示")
print("=" * 50)
# 基础 PixelShuffle 测试
batch_size, channels, height, width = 4, 16, 8, 8
upscale_factor = 2
x = torch.randn(batch_size, channels, height, width)
print(f"输入张量形状: {x.shape}")
print(f"上采样因子: {upscale_factor}")
# 使用 PyTorch 内置函数
pixel_shuffle = nn.PixelShuffle(upscale_factor)
output_pt = pixel_shuffle(x)
print(f"PyTorch PixelShuffle 输出形状: {output_pt.shape}")
# 手动实现验证
output_manual = manual_pixel_shuffle(x, upscale_factor)
print(f"手动实现输出形状: {output_manual.shape}")
# 验证两种实现的一致性
diff = torch.abs(output_pt - output_manual).max()
print(f"两种实现的最大差异: {diff.item():.6f}")
# PixelShuffle 块演示
print(f"\nPixelShuffle 块演示:")
shuffle_block = PixelShuffleBlock(16, 8, upscale_factor=2)
block_output = shuffle_block(x)
print(f"块输入形状: {x.shape}")
print(f"块输出形状: {block_output.shape}")
print(f"卷积层权重形状: {shuffle_block.conv.weight.shape}")
# ESPCN 网络演示
print(f"\nESPCN 网络演示:")
espcn = ESPCN(in_channels=3, out_channels=3, upscale_factor=2)
rgb_input = torch.randn(batch_size, 3, 16, 16)
rgb_output = espcn(rgb_input)
print(f"网络输入形状: {rgb_input.shape}")
print(f"网络输出形状: {rgb_output.shape}")
# 不同上采样因子的效果
print(f"\n不同上采样因子的效果:")
upscale_factors = [2, 3, 4]
for factor in upscale_factors:
# 调整输入通道数以匹配要求
adjusted_channels = 9 * factor * factor # 确保可被 r^2 整除
test_input = torch.randn(batch_size, adjusted_channels, 8, 8)
shuffle_layer = nn.PixelShuffle(factor)
test_output = shuffle_layer(test_input)
print(f" 上采样因子 {factor}: 输入 {test_input.shape} -> 输出 {test_output.shape}")
return {
'pixel_shuffle': pixel_shuffle,
'shuffle_block': shuffle_block,
'espcn': espcn,
'outputs': {
'pytorch': output_pt,
'manual': output_manual,
'block': block_output,
'espcn': rgb_output
}
}
pixel_shuffle_demo = demonstrate_pixel_shuffle()
与其它上采样方法比较
def compare_upsampling_methods():
"""比较不同上采样方法"""
print("上采样方法比较")
print("=" * 50)
batch_size, channels, height, width = 4, 32, 16, 16
upscale_factor = 2
x = torch.randn(batch_size, channels, height, width)
print(f"输入形状: {x.shape}")
print(f"目标上采样因子: {upscale_factor}")
# 不同上采样方法
methods = {
'最近邻插值': lambda x: F.interpolate(x, scale_factor=upscale_factor, mode='nearest'),
'双线性插值': lambda x: F.interpolate(x, scale_factor=upscale_factor, mode='bilinear', align_corners=False),
'双三次插值': lambda x: F.interpolate(x, scale_factor=upscale_factor, mode='bicubic', align_corners=False),
'转置卷积': nn.ConvTranspose2d(channels, channels, kernel_size=4, stride=upscale_factor, padding=1),
'PixelShuffle': nn.PixelShuffle(upscale_factor)
}
results = {}
print(f"\n不同上采样方法输出:")
for name, method in methods.items():
if name == 'PixelShuffle':
# 调整输入通道数
adjusted_channels = channels // (upscale_factor ** 2) * (upscale_factor ** 2)
if adjusted_channels == 0:
adjusted_channels = upscale_factor ** 2
test_input = x[:, :adjusted_channels, :, :]
output = method(test_input)
else:
output = method(x)
print(f" {name:<15} - 输出形状: {output.shape}")
# 计算计算量(近似)
if hasattr(method, 'weight'):
params = method.weight.numel() + (method.bias.numel() if method.bias is not None else 0)
else:
params = 0
results[name] = {
'output': output,
'params': params,
'method': method
}
#
计算效率比较
print(f"\n计算效率比较 (100次前向传播):")
import time
for name, result in results.items():
method = result['method']
if name == 'PixelShuffle':
test_input = x[:, :adjusted_channels, :, :]
else:
test_input = x
# 预热
_ = method(test_input)
# 计时
start = time.time()
for _ in range(100):
_ = method(test_input)
elapsed = time.time() - start
print(f" {name:<15}: {elapsed:.4f}s, 参数量: {result['params']:,}")
# 输出质量比较(使用简单的图像相似度指标)
print(f"\n输出质量比较 (使用PSNR):")
# 创建简单的测试图像
test_image = torch.zeros(1, 1, 8, 8)
test_image[0, 0, 3:5, 3:5] = 1.0 # 中心方块
target_size = (16, 16)
for name, method in methods.items():
if name == 'PixelShuffle':
# 为 PixelShuffle 准备合适的输入
ps_input = torch.randn(1, 4, 8, 8) # 4 = 1 * 2^2
output = method(ps_input)
else:
output = method(test_image)
# 调整到目标尺寸(如果需要)
if output.shape[-2:] != target_size:
output = F.interpolate(output, size=target_size, mode='bilinear')
# 计算与双线性插值的差异(作为
质量参考)
reference = F.interpolate(test_image, size=target_size, mode='bilinear')
mse = F.mse_loss(output, reference)
psnr = 10 * torch.log10(1.0 / mse) if mse > 0 else float('inf')
print(f" {name:<15}: PSNR = {psnr:.2f} dB")
return results
upsampling_comparison = compare_upsampling_methods()
优势与特点
*/ 优势
- /计算效率/:相比转置卷积,计算量更小
- /无棋盘效应/:减少上采样过程中的棋盘状伪影
- /端到端可训练/:支持梯度反向传播
- /内存友好/:不需要存储大的卷积核
*/ 特点
- /通道要求/:输入通道数必须能被 \( r^2 \) 整除
- /信息重排/:通过通道维度的重排实现空间上采样
- /结合卷积/:通常与卷积层配合使用(子像素卷积)
应用场景
- /图像超分辨率/:ESPCN、SRCNN 等网络
- /图像生成/:GAN 中的上采样层
- /语义分割/:解码器中的上采样操作
- /风格迁移/:特征图的空间分辨率提升
- /视频超分/:视频帧的时空分辨率提升
变体与扩展
*/ PixelUnshuffle
PixelShuffle 的逆操作,用于下采样:
def demonstrate_pixel_unshuffle():
"""演示 PixelUnshuffle 操作"""
print("PixelUnshuffle 演示")
print("=" * 50)
batch_size, channels, height, width = 4, 8, 16, 16
downscale_factor = 2
x = torch.randn(batch_size, channels, height, width)
print(f"输入张量形状: {x.shape}")
print(f"下采样因子: {downscale_factor}")
# PixelUnshuffle
pixel_unshuffle = nn.PixelUnshuffle(downscale_factor)
output = pixel_unshuffle(x)
print(f"PixelUnshuffle 输出形状: {output.shape}")
# 验证可逆性
pixel_shuffle = nn.PixelShuffle(downscale_factor)
reconstructed = pixel_shuffle(output)
print(f"重建后形状: {reconstructed.shape}")
# 验证重建精度
diff = torch.abs(x - reconstructed).max()
print(f"重建最大误差: {diff.item():.6f}")
return {
'pixel_unshuffle': pixel_unshuffle,
'reconstructed': reconstructed,
'max_error': diff.item()
}
unshuffle_demo = demonstrate_pixel_unshuffle()
*/ 深度可分离 PixelShuffle
结合深度可分离卷积的变体,进一步减少计算量:
class DepthwiseSeparablePixelShuffle(nn.Module):
"""深度可分离 PixelShuffle"""
def __init__(self, in_channels, out_channels, upscale_factor):
super().__init__()
self.upscale_factor = upscale_factor
# 深度可分离卷积
self.depthwise = nn.Conv2d(
in_channels, in_channels, kernel_size=3,
padding=1, groups=in_channels
)
self.pointwise = nn.Conv2d(
in_channels, out_channels * (upscale_factor ** 2),
kernel_size=1
)
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
self.activation = nn.ReLU(inplace=True)
def forward(self, x):
x = self.depthwise(x)
x = self.pointwise(x)
x = self.activation(x)
x = self.pixel_shuffle(x)
return x
def compare_pixel_shuffle_variants():
"""比较 PixelShuffle 变体"""
print("PixelShuffle 变体比较")
print("=" * 50)
batch_size, in_channels, height, width = 4, 32, 16, 16
out_channels = 16
upscale_factor = 2
x = torch.randn(batch_size, in_channels, height, width)
variants = {
'标准 PixelShuffle': PixelShuffleBlock(in_channels, out_channels, upscale_factor),
'深度可分离 PixelShuffle': DepthwiseSeparablePixelShuffle(in_channels, out_channels, upscale_factor)
}
print(f"输入形状: {x.shape}")
print(f"目标输出通道: {out_channels}")
print(f"上采样因子: {upscale_factor}")
results = {}
for name, variant in variants.items():
output = variant(x)
# 计算参数量
params = sum(p.numel() for p in variant.parameters())
print(f"\n{name}:")
print(f" 输出形状: {output.shape}")
print(f" 参数量: {params:,}")
results[name] = {
'variant': variant,
'output': output,
'params': params
}
# 参数量减少比例
standard_params = results['标准 PixelShuffle']['params']
separable_params = results['深度可分离 PixelShuffle']['params']
reduction = (1 - separable_params / standard_params) * 100
print(f"\n参数量减少: {reduction:.1f}%")
return results
variants_comparison = compare_pixel_shuffle_variants()