ESC
输入关键词搜索文章
目录

模型层

池化层

激活函数

ReLU

ReLU(Rectified Linear Unit)是一种常用的激活函数,定义为:

$$f(x) = \max(0, x)$$

特点

变体

适用场景

Tanh

Tanh(双曲正切)激活函数:

$$\tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}$$

/特点/:

/导数/:

$$\frac{d}{dx}\tanh(x) = 1 - \tanh^2(x)$$

/适用场景/:

import numpy as np
import matplotlib.pyplot as plt

# 生成数据
x = np.linspace(-5, 5, 1000)
y = np.tanh(x)

# 绘制图像
plt.figure(figsize=(8, 6))
plt.plot(x, y, 'b-', linewidth=2, label='tanh(x)')
plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
plt.axvline(x=0, color='k', linestyle='-', alpha=0.3)
plt.grid(True, alpha=0.3)
plt.xlabel('x')
plt.ylabel('tanh(x)')
plt.title('Hyperbolic Tangent Function (tanh)')
plt.legend()
plt.axis([-5, 5, -1.2, 1.2])
fname = "images/tanh.png"
plt.savefig(fname)
fname

输出说明:

sigmoid

激活函数,将输入压缩到 (0,1):

$$\sigma(x) = \frac{1}{1 + e^{-x}}$$

sigmoid 函数的 python 表示

import numpy as np
import matplotlib.pyplot as plt

def sigmoid(x):
    """Sigmoid激活函数"""
    return 1 / (1 + np.exp(-x))

# 示例使用
x = np.linspace(-10, 10, 1000)
y = sigmoid(x)

# 绘制图像
plt.figure(figsize=(8, 6))
plt.plot(x, y, 'r-', linewidth=2, label='sigmoid(x)')
plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
plt.axhline(y=1, color='k', linestyle='-', alpha=0.3)
plt.axvline(x=0, color='k', linestyle='-', alpha=0.3)
plt.grid(True, alpha=0.3)
plt.xlabel('x')
plt.ylabel('sigmoid(x)')
plt.title('Sigmoid Activation Function')
plt.legend()
plt.axis([-10, 10, -0.1, 1.1])
plt.show()

# 输出特性
print(f"sigmoid(0) = {sigmoid(0):.3f}")  # 输出: 0.5
print(f"sigmoid(5) = {sigmoid(5):.3f}")  # 输出: 0.993
print(f"sigmoid(-5) = {sigmoid(-5):.3f}")  # 输出: 0.007

关键特性:

Softmax

多分类激活函数,将向量转换为概率分布:

$$\text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^K e^{z_j}}$$

softmax python 实现

import numpy as np
import matplotlib.pyplot as plt

def softmax(x):
    """Softmax activation function with numerical stability"""
    exp_x = np.exp(x - np.max(x))
    return exp_x / np.sum(exp_x)

# Create sample data
x = np.linspace(-5, 5, 100)
scores = np.array([x, np.zeros_like(x), -x]).T  # Three classes

# Apply softmax
probabilities = np.array([softmax(score) for score in scores])

# Plot results
plt.figure(figsize=(10, 6))
for i in range(3):
    plt.plot(x, probabilities[:, i], label=f'Class {i+1}', linewidth=2)

plt.xlabel('Input Score Difference')
plt.ylabel('Probability')
plt.title('Softmax Function Visualization')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
fname = "images/softmax.png"
plt.savefig(fname)
plt.close()
fname
# Additional visualization: Single input case
single_scores = np.array([3.0, 1.0, 0.2])
single_probs = softmax(single_scores)

plt.figure(figsize=(8, 5))
plt.bar(range(len(single_scores)), single_probs, color=['skyblue', 'lightcoral', 'lightgreen'])
plt.xticks(range(len(single_scores)), [f'Class {i+1}' for i in range(len(single_scores))])
plt.ylabel('Probability')
plt.title('Softmax Output for Single Input Example')
plt.ylim(0, 1)
for i, prob in enumerate(single_probs):
    plt.text(i, prob + 0.02, f'{prob:.3f}', ha='center', va='bottom')
plt.tight_layout()
fname = "images/softmax_output_for_single_input_example.png"
plt.savefig(fname)
plt.close()
fname

每行概率和: [1. 1. 1.]

关键特性:

SiLU

SiLU(Sigmoid Linear Unit)激活函数,也称为Swish激活函数,是一种平滑、非单调的激活 函数,在深度学习中表现出优越的性能。

数学定义

SiLU激活函数定义为:

$$\text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$

其中 \( \sigma(x) \) 是sigmoid函数。

性质分析

  1. /平滑性/:SiLU是无限可微的平滑函数
  2. /非单调性/:当 \( x < 0 \) 时,函数值可能为负
  3. /下界无界/:当 \( x \to -\infty \)\( \text{SiLU}(x) \to 0 \)
  4. /上界无界/:当 \( x \to +\infty \)\( \text{SiLU}(x) \sim x \)

导数计算

SiLU的导数为:

$$\frac{d}{dx}\text{SiLU}(x) = \sigma(x) + x \cdot \sigma(x)(1 - \sigma(x)) = \text{SiLU}(x) + \sigma(x)(1 - \sigma(x))$$

实现


import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

class SiLU(nn.Module):
    """SiLU激活函数实现"""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """前向传播"""
        return x * torch.sigmoid(x)

    def derivative(self, x):
        """计算导数"""
        sigmoid_x = torch.sigmoid(x)
        return sigmoid_x * (1 + x * (1 - sigmoid_x))

class SiLUWithBeta(nn.Module):
    """带可学习参数的SiLU变体"""

    def __init__(self, beta=1.0, learnable=False):
        """
        参数:
            beta: 缩放参数
            learnable: 是否可学习
        """
        super().__init__()
        if learnable:
            self.beta = nn.Parameter(torch.tensor(float(beta)))
        else:
            self.register_buffer('beta', torch.tensor(float(beta)))

    def forward(self, x):
        """前向传播"""
        return x * torch.sigmoid(self.beta * x)

    def extra_repr(self):
        return f"beta={self.beta.item():.2f}, learnable={self.beta.requires_grad}"

def demonstrate_silu():
    """演示SiLU激活函数"""

    print("SiLU激活函数演示")
    print("=" * 50)

    # 创建测试数据
    x = torch.linspace(-4, 4,
 100)

    # 计算SiLU函数值
    silu = SiLU()
    y = silu(x)
    dy = silu.derivative(x)

    print(f"输入范围: [{x.min():.1
f}, {x.max():.1f}]")
    print(f"输出范围: [{y.min():.3f}, {y.max():.3f}]")
    print(f"导数范围: [{dy.min():.3f}, {dy.max():.3f}]")

    # 关键点分析
    critical_points = [-4, -2, 0, 2, 4]
    print(f"\n关键点分析:")
    for point in critical_points:
        x_point = torch.tensor(float(point))
        y_point = silu(x_point)
        dy_point = silu.derivative(x_point)
        print(f"  x={point:2.0f}: SiLU={y_point:.3f}, 导数={dy_point:.3f}")

    # 与其它激活函数比较
    relu = F.relu(x)
    leaky_relu = F.leaky_relu(x, 0.01)
    gelu = F.gelu(x)
    elu = F.elu(x)

    print(f"\n激活函数比较 (
x=0):")
    print(f"  SiLU: {silu(torch.tensor(0.0)):.3f}")
    print(f"  ReLU: {F.relu(torch.tensor(0.0)):.3f}")
    print(f"  GELU: {F.gelu(torch.tensor(0.0)):.3f}")
    print(f"  ELU: {F.elu(torch.tensor(0.0)):.3
f}")

    return {
        'x': x,
        'silu': y,
        'silu_derivative': dy,
        'comparison': {
            'relu': relu,
            'leaky_relu': leaky_relu,
            'gelu': gelu,
            'elu': elu
        }
    }

silu_demo = demonstrate_silu()

SwiGLU

Back to Basics: Let Denoising Generative Models Denoise (99+ 封私信 / 99+ 条消息) 详解SwiGLU激活函数 - 知乎

SwiGLU(x) = SiLU(W1x) ⊙ (W2x)

$$\text{SwiGLU}(a, b) = \text{Swith}(a)\otimes b$$

其中,

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, input_dim, bias=False)

    def forward(self, x):
        # 分割输入为两部分(假设输入维度为偶数)
        assert x.shape[-1] % 2 == 0, "输入维度需为偶数"
        a, b = x.chunk(2, dim=-1)

        # Swish门控计算
        gate = F.silu(self.w1(a))  # F.silu等价于Swish(β=1)
        filtered = gate * self.w2(b)
        return self.w3(filtered)

作用:

GeLU

GELU(Gaussian Error Linear Unit)是一种基于高斯误差函数的激活函数,结合了ReLU和dropout的思想,在Transformer等现代架构中广泛使用。

数学定义

GELU激活函数定义为:

$$\text{GELU}(x) = x \cdot \Phi (x) = x \cdot \frac{1}{2}\left[1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right]$$

其中 \( \Phi(x) \) 是标准正态分布的累积分布函数,\( \text{ erf} \) 是误差函数。

近似表达式

由于误差函数计算较复杂,常用以下近似:

$$\text{GELU}(x) \approx 0.5x\left(1 + \tanh\left[\sqrt{\frac{2}{\pi}}\left(x + 0.044715x^3\right)\right]\right)$$

性质分析

  1. /平滑性/:GELU是无限可微的平滑函数
  2. /非单调性/:当 \( x < 0 \) 时,函数值为负
  3. /概率解释/:基于输入的概率分布进行门控
  4. /性能优势/:在自然语言处理任务中表现优异

导数计算

GELU的导数为:

$$\frac{d}{dx}\text{GELU}(x) = \Phi(x) + x \cdot \phi(x)$$

其中 \( \phi(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2} \) 是标准正态分布的概率密度函数。

实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import erf

def gelu_exact(x):

 """精确GELU实现"""
    return 0.5 * x * (1 + erf(x / np.sqrt(2)))

def gelu_approximate(x):
    """近似GELU实现"""
    return 0.5 * x * (1 + np.tanh(np.sqrt(2/np.pi) * (x + 0.
044715 * x**3)))

class GELUAnalysis:
    """GELU激活函数分析"""

    def __init__(self):
        pass

    def analyze_gelu(self, x_range=(-4, 4), num_points=1000):
        """分析GELU函数特性"""
        x = np.linspace(x_range[0], x_range[1], num_points)

        # 计算精确和近似值
        y_exact = gelu_exact(x)
        y_approx = gelu_approximate(x)

        # 计算导数
        phi_x = np.exp(-x**2/2) / np.sqrt(2*np.pi)  # 标准正态PDF
        Phi_x = 0.5 * (1 + erf(x/np.sqrt(2)))       # 标准正态CDF
        derivative = Phi_x + x * phi_x

        # 误差分析
        error = np.abs(y_exact - y_approx)

        return {
            'x': x,
            'y_exact': y_exact,
            'y_approx': y_approx,
            'derivative': derivative,
            'error': error,
            'phi_x': phi_x,
            'Phi_x': Phi_x
        }

    def compare_with_other_activations(self, x):
        """与其他激活函数比较"""
        x_tensor = torch.tensor(x, dtype=torch.float32)

        activations = {
            'GELU': F.gelu(x_tensor).numpy(),
            'ReLU': F.relu(x_tensor).numpy(),
            'SiLU': x_tensor * torch.sigmoid(x_tensor).numpy(),
            'ELU': F.elu(x_tensor).numpy(),
            'Tanh': torch.tanh(x_tensor).numpy()
        }

        return activations

def demonstrate_gelu():
    """演示GELU激活函数"""

    print("GELU激活函数演示")
    print("=" * 50)

    analyzer = GELUAnalysis()
    results = analyzer.analyze_gelu()

    x = results['x']
    y_exact = results['y_exact']
    y_approx = results['y_approx']
    derivative = results['derivative']
    error = results['error']

    print(f"输入范围: [{x.min():.1f}, {x.max():.1f}]")
    print(f"输出范围: [{y_exact.min():.3f}, {y_exact.max():.3f}]")
    print(f"导数范围: [{derivative.min():.3f}, {derivative.max():.3f}]")
    print(f"近似最大误差: {error.max():.6f}")

    # 关键点分析
    critical_points = [-3, -1, 0, 1, 3]
    print(f"\n关键点分析:")
    for point in critical_points:
        idx = np.argmin(np.abs(x - point))
        y_exact_val = y_exact[idx]
        y_approx_val = y_approx[idx]
        deriv_val = derivative[idx]
        print(f"  x={point:2.0f}: GELU={y_exact_val:.3f}, 近似={y_approx_val:.3f}, 导数={deriv_val:.3f}")

    # 与其他激活函数比较
    activations = analyzer.compare_with_other_activations(x)
    print(f"\n激活函数比较 (x=0):")
    for name, values in activations.items():
        zero_idx = np.argmin(np.abs(x))
        print(f"  {name}: {values[zero_idx]:.3f}")

    return results, activations

gelu_results, gelu_comparison = demonstrate_gelu()

*/ 可视化

plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'Hiragino Sans GB', 'STHeiti']
plt.rcParams['axes.unicode_minus'] = False

# 创建可视化
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# GELU函数与近似
ax1 = axes[0, 0]
ax1.plot(gelu_results['x'], gelu_results['y_exact'], 'b-', linewidth=2, label='GELU (精确)')
ax1.plot(gelu_results['x'], gelu_results['y_approx'], 'r--', linewidth=2, label='GELU (近似)')
ax1.set_xlabel('x')
ax1.set_ylabel('GELU(x)')
ax1.set_title('GELU激活函数:精确 vs 近似')
ax1.grid(True, alpha=0.3)
ax1.legend()

# 导数
ax2 = axes[0, 1]
ax2.plot(gelu_results['x'], gelu_results['derivative'], 'g-', linewidth=2, label='GELU导数')
ax2.set_xlabel('x')
ax2.set_ylabel('dGELU/dx')
ax2.set_title('GELU导数')
ax2.grid(True, alpha=0.3)
ax2.legend()

# 近似误差
ax3 = axes[1, 0]
ax3.plot(gelu_results['x'], gelu_results['error'], 'm-', linewidth=2)
ax3.set_xlabel('x')
ax3.set_ylabel('绝对误差')
ax3.set_title('GELU近似误差')
ax3.grid(True, alpha=0.3)

# 与其他激活函数比较
ax4 = axes[1, 1]
for name, values in gelu_comparison.items():
    ax4.plot(gelu_results['x'], values, linewidth=2, label=name)
ax4.set_xlabel('x')
ax4.set_ylabel('激活值')
ax4.set_title('GELU与其他激活函数比较')
ax4.grid(True, alpha=0.3)
ax4.legend()

plt.tight_layout()
fname = "images/gelu_analysis.png"
plt.savefig(fname, dpi=300, bbox_inches='tight')
plt.close()
fname

*/ 应用场景

  1. /Transformer架构/:BERT、GPT等模型的标准激活函数
  2. /自然语言处理/:在NLP任务中表现优异
  3. /计算机视觉/:在ViT等视觉Transformer中使用
  4. /需要平滑激活的场景/:相比ReLU提供更好的梯度流

*/ 优势与局限

优势:

局限:

归一化层

LayerNorm

层归一化(Layer Normalization)是 Transformer 架构中的关键组件,用于稳定训练过程和提高模型性能。

为什么用LayerNorm 而非 BatchNorm

在 Transformer 架构中选择 LayerNorm 而非 BatchNorm 主要基于以下几个原因:

在 Transformer 中的位置

数学公式

对于一个输入向量 \( x \in \mathbb{R}^d \),层归一化的计算如下:

$$\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

其中:

与 BatchNorm 的区别

PyTorch 实现

import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-5):
        super(LayerNorm, self).__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(d_model))  # 缩放参数
        self.beta = nn.Parameter(torch.zeros(d_model))  # 偏移参数

    def forward(self, x):
        """
        Args:
            x: 输入张量 [batch_size, seq_len, d_model] 或 [batch_size, d_model]
        Returns:
            归一化后的张量
        """
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)

        # 归一化
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)

        # 缩放和偏移
        return self.gamma * x_normalized + self.beta

# 测试代码
if __name__ == "__main__":
    batch_size, seq_len, d_model = 2, 5, 512

    # 测试 LayerNorm
    x = torch.randn(batch_size, seq_len, d_model)
    layer_norm = LayerNorm(d_model)

    output = layer_norm(x)
    print(f"输入形状: {x.shape}")
    print(f"输出形状: {output.shape}")
    print(f"输出均值: {output.mean().item():.6f}")
    print(f"输出标准差: {output.std().item():.6f}")

    # 与 PyTorch 官方实现对比
    official_ln = nn.LayerNorm(d_model)
    official_output = official_ln(x)

    print(f"自定义与官方实现差异: {torch.abs(output - official_output).max().item():.6f}")

在 Transformer 中的应用

在 Transformer 中,LayerNorm 通常应用于:

  1. 多头注意力后的残差连接
  2. 前馈网络后的残差连接

具体形式:

$$\text{Output} = \text{LayerNorm}(x + \text{Sublayer}(x))$$
其中 \(\text{Sublayer}\) 可以是多头注意力或前馈网络。

GroupNorm

组归一化(Group Normalization)是一种深度学习归一化技术,特别适用于小批量训练和 计算机视觉任务。它将通道维度分组进行归一化,不依赖于批量大小。

数学定义

给定输入张量 \( \mathbf{x} \in \mathbb{R}^{B \times C \times H \times W} \)(批量大小 \( B \)、通道数 \( C \)、高度 \( H \)、宽度 \( W \)),组归一化计算:

$$\mu_{bg} = \frac{1}{G \cdot H \cdot W} \sum_{c=gK}^{(g+1)K} \sum_{h=1}^H \sum_{w=1}^W x_{bchw}$$
$$\sigma_{bg}^2 = \frac{1}{G \cdot H \cdot W} \sum_{c=gK}^{(g+1)K} \sum_{h=1}^H \sum_{w=1}^W (x_{bchw} - \mu_{bg})^2$$
$$\hat{x}_{bchw} = \frac{x_{bchw} - \mu_{bg}}{\sqrt{\sigma_{bg}^2 + \epsilon}}$$
$$y_{bchw} = \gamma_c \hat{x}_{bchw} + \beta_c$$

其中:

实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class GroupNorm(nn.Module):
    """组归一化实现"""

    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
        """
        参数:
            num_groups: 分组数
            num_channels: 通道总数
            eps: 数值稳定性常数
            affine: 是否使用可学习的仿射参数
        """
        super().__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine

        # 参数检查
        assert num_channels % num_groups == 0, \
            f"通道数{num_channels}必须能被分组数{num_groups}整除"

        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_channels))
            self.bias = nn.Parameter(torch.zeros(num_channels))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

    def forward(self, x):
        """前向传播"""
        B, C, H, W = x.shape

        # 重塑为 (B, G, C//G, H, W)
        x_reshaped = x.view(B, self.num_groups, C // self.num_groups, H, W)

        # 计算均值和方差
        mean = x_reshaped.mean(dim=[2, 3, 4], keepdim=True)
        var = x_reshaped.var(dim=[2, 3, 4], unbiased=False, keepdim=True)

        # 归一化
        x_normalized = (x_reshaped - mean) / torch.sqrt(var + self.eps)

        # 恢复原始形状
        x_normalized = x_normalized.view(B, C, H, W)

        # 仿射变换
        if self.affine:
            x_normalized = self.weight.view(1, C, 1, 1) * x_normalized + self.bias.view(1, C, 1, 1)

        return x_normalized

    def extra_repr(self):
        return f"groups={self.num_groups}, channels={self.num_channels}, " \
               f"eps={self.eps}, affine={self.affine}"

class GroupNorm1D(nn.Module):
    """一维组归一化"""

    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
        super().__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine

        assert num_channels % num_groups == 0

        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_channels))
            self.bias = nn.Parameter(torch.zeros(num_channels))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

    def forward(self, x):
        """前向传播"""
        B, C, L = x.shape

        # 重塑为 (B, G, C//G, L)
        x_reshaped = x.view(B, self.num_groups, C // self.num_groups, L)

        # 计算均值和方差
        mean = x_reshaped.mean(dim=[2, 3], keepdim=True)
        var = x_reshaped.var(dim=[2, 3], unbiased=False, keepdim=True)

        # 归一化
        x_normalized = (x_reshaped - mean) / torch.sqrt(var + self.eps)

        # 恢复原始形状
        x_normalized = x_normalized.view(B, C, L)

        # 仿射变换
        if self.affine:
            x_normalized = self.weight.view(1, C, 1) * x_normalized + self.bias.view(1, C, 1)

        return x_normalized

def demonstrate_group_norm():
    """演示组归一化"""

    print("组归一化演示")
    print("=" * 50)

    # 测试数据 - 图像特征图
    batch_size, channels, height, width = 4, 16, 8, 8
    x = torch.randn(batch_size, channels, height, width)

    print(f"输入张量形状: {x.shape}")
    print(f"输入统计 - 均值: {x.mean():.4f}, 标准差: {x.std():.4f}")

    # 不同分组数的组归一化
    group_configs = [
        (1, "LayerNorm-like"),   # 1组 = 层归一化
        (2, "2 Groups"),
        (4, "4 Groups"),
        (8, "8 Groups"),
        (16, "InstanceNorm-like")  # 每组1通道 = 实例归一化
    ]

    results = {}

    for num_groups, description in group_configs:
        print(f"\n{description} (分组数: {num_groups}):")

        group_norm = GroupNorm(num_groups, channels)
        y = group_norm(x)

        print(f"  输出统计 - 均值: {y.mean():.4f}, 标准差: {y.std():.4f}")

        # 验证组内归一化
        for b in range(min(2, batch_size)):  # 只检查前2个样本
            for g in range(num_groups):
                group_channels = channels // num_groups
                start_ch = g * group_channels
                end_ch = (g + 1) * group_channels

                group_data = y[b, start_ch:end_ch]
                group_mean = group_data.mean().item()
                group_std = group_data.std().item()

                if g == 0:  # 只打印第一组作为示例
                    print(f"  样本{b}组{g} - 均值: {group_mean:.6f}, 标准差: {group_std:.6f}")

        results[description] = y

    # 一维组归一化演示
    print(f"\n一维组归一化演示:")
    batch_size_1d, channels_1d, length = 2, 12, 20
    x_1d = torch.randn(batch_size_1d, channels_1d, length)

    group_norm_1d = GroupNorm1D(3, channels_1d)  # 3组
    y_1d = group_norm_1d(x_1d)

    print(f"  输入形状: {x_1d.shape}")
    print(f"  输出统计 - 均值: {y_1d.mean():.4f}, 标准差: {y_1d.std():.4f}")

    return {
        'configs': group_configs,
        'results': results,
        'group_norm_1d': group_norm_1d,
        'output_1d': y_1d
    }

group_norm_demo = demonstrate_group_norm()

与其它归一化方法的比较

class BatchNorm2d(nn.Module):
    """批归一化实现(用于比较)"""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

        self.training = True

    def forward(self, x):
        """前向传播"""
        B, C, H, W = x.shape

        if self.training:
            # 沿批量、空间维度计算统计量
            mean = x.mean(dim=[0, 2, 3])
            var = x.var(dim=[0, 2, 3], unbiased=False)

            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var

        # 归一化
        x_normalized = (x - mean.view(1, C, 1, 1)) / torch.sqrt(var.view(1, C, 1, 1) + self.eps)

        return self.weight.view(1, C, 1, 1) * x_normalized + self.bias.view(1, C, 1, 1)

class InstanceNorm2d(nn.Module):
    """实例归一化实现"""

    def __init__(self, num_features, eps=1e-5, affine=True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine

        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

    def forward(self, x):
        """前向传播"""
        B, C, H, W = x.shape

        # 沿空间维度计算统计量
        mean = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], unbiased=False, keepdim=True)

        # 归一化
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)

        # 仿射变换
        if self.affine:
            x_normalized = self.weight.view(1, C, 1, 1) * x_normalized + self.bias.view(1, C, 1, 1)

        return x_normalized

def compare_normalization_methods():
    """比较不同归一化方法"""

    print("归一化方法比较")
    print("=" * 50)

    # 测试不同批量大小
    batch_sizes = [1, 4, 16]
    channels, height, width = 8, 16, 16

    results = {}

    for batch_size in batch_sizes:
        print(f"\n批量大小: {batch_size}")
        print("-" * 30)

        x = torch.randn(batch_size, channels, height, width)

        # 批归一化
        if batch_size > 1:  # 批归一化需要批量大小 > 1
            batch_norm = BatchNorm2d(channels)
            y_batch = batch_norm(x)
            batch_mean = y_batch.mean().item()
            batch_std = y_batch.std().item()
            print(f"  批归一化 - 均值: {batch_mean:.4f}, 标准差: {batch_std:.4f}")
        else:
            print(f"  批归一化 - 不适用(批量大小=1)")

        # 组归一化 (4组)
        group_norm = GroupNorm(4, channels)
        y_group = group_norm(x)
        group_mean = y_group.mean().item()
        group_std = y_group.std().item()
        print(f"  组归一化 - 均值: {group_mean:.4f}, 标准差: {group_std:.4f}")

        # 实例归一化
        instance_norm = InstanceNorm2d(channels)
        y_instance = instance_norm(x)
        instance_mean = y_instance.mean().item()
        instance_std = y_instance.std().item()
        print(f"  实例归一化 - 均值: {instance_mean:.4f}, 标准差: {instance_std:.4f}")

        # 层归一化 (重塑为2D)
        layer_norm = nn.LayerNorm([channels, height, width])
        y_layer = layer_norm(x)
        layer_mean = y_layer.mean().item()
        layer_std = y_layer.std().item()
        print(f"  层归一化 - 均值: {layer_mean:.4f}, 标准差: {layer_std:.4f}")

        results[batch_size] = {
            'group': y_group,
            'instance': y_instance,
            'layer': y_layer
        }
        if batch_size > 1:
            results[batch_size]['batch'] = y_batch

    # 计算效率比较
    print(f"\n计算效率比较:")
    large_tensor = torch.randn(32, 64, 32, 32)

    import time

    # 组归一化
    gn = GroupNorm(8, 64)
    start = time.time()
    _ = gn(large_tensor)
    gn_time = time.time() - start

    # 批归一化
    bn = BatchNorm2d(64)
    start = time.time()
    _ = bn(large_tensor)
    bn_time = time.time() - start

    print(f"  组归一化时间: {gn_time:.6f}s")
    print(f"  批归一化时间: {bn_time:.6f}s")
    print(f"  相对效率: {gn_time/bn_time:.2f}x")

    return results

normalization_comparison = compare_normalization_methods()

在ResNet中的应用

class BasicBlock(nn.Module):
    """基础残差块(使用组归一化)"""

    def __init__(self, in_channels, out_channels, stride=1, groups=8):
        super().__init__()

        # 第一个卷积层
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                              stride=stride, padding=1, bias=False)
        self.gn1 = GroupNorm(groups, out_channels)

        # 第二个卷积层
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                              stride=1, padding=1, bias=False)
        self.gn2 = GroupNorm(groups, out_channels)

        # 激活函数
        self.relu = nn.ReLU(inplace=True)

        # 下采样
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                         stride=stride, bias=False),
                GroupNorm(groups, out_channels)
            )

    def forward(self, x):
        """前向传播"""
        identity = x

        # 主路径
        out = self.conv1(x)
        out = self.gn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.gn2(out)

        # 捷径
        if self.downsample is not None:
            identity = self.downsample(x)

        # 残差连接
        out += identity
        out = self.relu(out)

        return out

class ResNetGN(nn.Module):
    """使用组归一化的ResNet"""

    def __init__(self, block, layers, num_classes=1000, groups=8):
        super().__init__()
        self.in_channels = 64
        self.groups = groups

        # 初始卷积层
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.gn1 = GroupNorm(groups, 64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # 残差层
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # 分类头
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride):
        """构建残差层"""
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, self.groups))
        self.in_channels = out_channels

        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels, 1, self.groups))

        return nn.Sequential(*layers)

    def forward(self, x):
        """前向传播"""
        x = self.conv1(x)
        x = self.gn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

def demonstrate_resnet_gn():
    """演示ResNet中的组归一化"""

    print("ResNet中的组归一化")
    print("=" * 50)

    # 创建小型ResNet
    model = ResNetGN(BasicBlock, [2, 2, 2, 2], num_classes=10, groups=8)

    print(f"模型结构:")
    print(f"  总参数量: {sum(p.numel() for p in model.parameters()):,}")
    print(f"  分组数: {model.groups}")

    # 测试前向传播
    batch_size = 2
    x = torch.randn(batch_size, 3, 32, 32)

    with torch.no_grad():
        output = model(x)

    print(f"\n前向传播测试:")
    print(f"  输入形状: {x.shape}")
    print(f"  输出形状: {output.shape}")
    print(f"  输出范围: [{output.min():.3f}, {output.max():.3f}]")

    # 不同批量大小测试
    print(f"\n不同批量大小测试:")
    for bs in [1, 2, 4, 8]:
        x_test = torch.randn(bs, 3, 32, 32)
        with torch.no_grad():
            output_test = model(x_test)

        print(f"  批量大小{bs:2d} - 输出均值: {output_test.mean():.4f}, 标准差: {output_test.std():.4f}")

    # 梯度测试
    print(f"\n梯度稳定性测试:")
    x_grad = torch.randn(2, 3, 32, 32, requires_grad=True)
    output_grad = model(x_grad)
    loss = output_grad.sum()
    loss.backward()

    grad_norm = x_grad.grad.norm()
    print(f"  输入梯度范数: {grad_norm:.4f}")

    return model

resnet_gn_demo = demonstrate_resnet_gn()

BatchNorm

批归一化(Batch Normalization)是深度学习中广泛使用的归一化技术,通过对每个特征通道在批次维度上进行归一化来加速训练并提高模型稳定性。

数学定义

给定输入张量 \( \mathbf{x} \in \mathbb{R}^{B \times C \times H \times W} \)(批量大小 \( B \)、通道数 \( C \)、高度 \( H \)、宽度 \( W \)),批归一化计算:

$$\mu_c = \frac{1}{B \cdot H \cdot W} \sum_{b=1}^B \sum_{h=1}^H \sum_{w=1}^W x_{bchw}$$
$$\sigma_c^2 = \frac{1}{ B \cdot H \cdot W} \sum_{b=1}^B \sum_{h=1}^H \sum_{w=1}^W (x_{bchw} - \mu_c)^2$$
$$\hat{x}_{bchw} = \frac{x_{bchw} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}}$$
$$y_{bchw} = \gamma_c \hat{x}_{bchw} + \beta_c$$

其中:

训练与推理模式

实现

import torch
import torch.nn as nn
import numpy as np

class BatchNorm2d(nn.Module):
    """批归一化2D实现"""

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        """
        参数:
            num_features: 特征通道数
            eps: 数值稳定性常数
            momentum: 运行统计量动量
            affine: 是否使用可学习参数
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine

        # 可学习参数
        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        # 运行统计量
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

        # 重置参数
        self.reset_parameters()

    def reset_parameters(self):
        """重置参数"""
        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.num_batches_tracked.zero_()

    def forward(self, x):
        """前向传播"""
        if self.training:
            return self._forward_train(x)
        else:
            return self._forward_eval(x)

    def _forward_train(self, x):
        """训练模式前向传播"""
        B, C, H, W = x.shape

        # 计算批次统计量
        mean = x.mean(dim=[0, 2, 3])  # 沿批次和空间维度
        var = x.var(dim=[0, 2, 3], unbiased=False)

        # 更新运行统计量
        with torch.no_grad():
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
            self.num_batches_tracked += 1

        # 归一化
        x_normalized = (x - mean.view(1, C, 1, 1)) / torch.sqrt(var.view(1, C, 1, 1) + self.eps)

        # 仿射变换
        if self.affine:
            x_normalized = self.weight.view(1, C, 1, 1) * x_normalized + self.bias.view(1, C, 1, 1)

        return x_normalized

    def _forward_eval(self, x):
        """评估模式前向传播"""
        B, C, H, W = x.shape

        # 使用运行统计量
        mean = self.running_mean
        var = self.running_var

        # 归一化
        x_normalized = (x - mean.view(1, C, 1, 1)) / torch.sqrt(var.view(1, C, 1, 1) + self.eps)

        # 仿射变换
        if self.affine:
            x_normalized = self.weight.view(1, C, 1, 1) * x_normalized + self.bias.view(1, C, 1, 1)

        return x_normalized

    def extra_repr(self):
        return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={self.affine}"

class BatchNorm1d(nn.Module):
    """批归一化1D实现"""

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine

        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

        self.reset_parameters()

    def reset_parameters(self):
        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.num_batches_tracked.zero_()

    def forward(self, x):
        if self.training:
            return self._forward_train(x)
        else:
            return self._forward_eval(x)

    def _forward_train(self, x):
        if x.dim() == 2:  # (B, C)
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)
        elif x.dim() == 3:  # (B, C, L)
            mean = x.mean(dim=[0, 2])
            var = x.var(dim=[0, 2], unbiased=False)
        else:
            raise ValueError(f"期望2D或3D输入,得到 {x.dim()}D")

        # 更新运行统计量
        with torch.no_grad():
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
            self.num_batches_tracked += 1

        # 归一化
        if x.dim() == 2:
            x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        else:
            x_normalized = (x - mean.view(1, -1, 1)) / torch.sqrt(var.view(1, -1, 1) + self.eps)

        # 仿射变换
        if self.affine:
            if x.dim() == 2:
                x_normalized = self.weight * x_normalized + self.bias
            else:
                x_normalized = self.weight.view(1, -1, 1) * x_normalized + self.bias.view(1, -1, 1)

        return x_normalized

    def _forward_eval(self, x):
        mean = self.running_mean
        var = self.running_var

        if x.dim() == 2:
            x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        else:
            x_normalized = (x - mean.view(1, -1, 1)) / torch.sqrt(var.view(1, -1, 1) + self.eps)

        if self.affine:
            if x.dim() == 2:
                x_normalized = self.weight * x_normalized + self.bias
            else:
                x_normalized = self.weight.view(1, -1, 1) * x_normalized + self.bias.view(1, -1, 1)

        return x_normalized

def demonstrate_batch_norm():
    """演示批归一化"""

    print("批归一化演示")
    print("=" * 50)

    # 2D批归一化测试
    print("2D批归一化:")
    batch_size, channels, height, width = 4, 8, 16, 16
    x_2d = torch.randn(batch_size, channels, height, width)

    bn_2d = BatchNorm2d(channels)
    y_2d = bn_2d(x_2d)

    print(f"  输入形状: {x_2d.shape}")
    print(f"  输出统计 - 均值: {y_2d.mean():.6f}, 标准差: {y_2d.std():.6f}")

    # 验证通道统计量
    print(f"  通道统计量验证:")
    for c in range(min(3, channels)):  # 只检查前3个通道
        channel_mean = y_2d[:, c].mean().item()
        channel_std = y_2d[:, c].std().item()
        print(f"    通道{c}: 均值={channel_mean:.6f}, 标准差={channel_std:.6f}")

    # 1D批归一化测试
    print(f"\n1D批归一化:")
    batch_size_1d, features, seq_len = 8, 16, 32
    x_1d = torch.randn(batch_size_1d, features, seq_len)

    bn_1d = BatchNorm1d(features)
    y_1d = bn_1d(x_1d)

    print(f"  输入形状: {x_1d.shape}")
    print(f"  输出统计 - 均值: {y_1d.mean():.6f}, 标准差: {y_1d.std():.6f}")

    # 训练与推理模式对比
    print(f"\n训练与推理模式对比:")
    bn_2d.eval()  # 切换到推理模式
    y_eval = bn_2d(x_2d)

    print(f"  训练模式输出均值: {y_2d.mean():.6f}")
    print(f"  推理模式输出均值: {y_eval.mean():.6f}")
    print(f"  运行统计量 - 均值: {bn_2d.running_mean[:3]}")
    print(f"  运行统计量 - 方差: {bn_2d.running_var[:3]}")

    return {
        'bn_2d': bn_2d,
        'bn_1d': bn_1d,
        'output_2d': y_2d,
        'output_1d': y_1d,
        'output_eval': y_eval
    }

batch_norm_demo = demonstrate_batch_norm()

优势与局限

RMSNorm

RMSNorm(Root Mean Square Normalization)是一种基于均方根的归一化方法,在 Transformer 架构中作为 LayerNorm 的轻量级替代方案,去除了均值中心化操作。

数学定义

对于输入向量 \( x \in \mathbb{R}^d \),RMSNorm 计算如下:

$$\text{RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^d x_i^2}$$
$$\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x) + \epsilon} \odot g$$

其中:

与 LayerNorm 的区别

LayerNorm:

$$\text{LayerNorm}(x) = \frac{x - \mu}{\sigma} \odot g + b$$

RMSNorm:

$$\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \odot g$$

关键区别:

优势

  1. 计算效率高
  2. 适用于小批量或者是单样本
  3. 稳定:使用均方根进行归一化,可以在一定程度上避免梯度爆炸和梯度消失,提高训练稳定性

RMSNorm 实现

import torch
import torch.nn as nn
import numpy as np

class RMSNorm(nn.Module):
    """RMSNorm 实现"""

    def __init__(self, d_model, eps=1e-8):
        """
        参数:
            d_model: 特征维度
            eps: 数值稳定性常数
        """
        super().__init__()
        self.d_model = d_model
        self.eps = eps

        # 可学习的缩放参数
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """
        前向传播

        Args:
            x: 输入张量 [batch_size, seq_len, d_model] 或 [batch_size, d_model]
        Returns:
            归一化后的张量
        """
        # 计算均方根 (RMS)
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)

        # 归一化并缩放
        x_normalized = x / rms
        return x_normalized * self.weight

    def extra_repr(self):
        return f"d_model={self.d_model}, eps={self.eps}"

class RMSNorm1D(nn.Module):
    """1D RMSNorm 实现"""

    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.d_model = d_model
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """处理 1D 输入 [batch_size, d_model]"""
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        x_normalized = x / rms
        return x_normalized * self.weight

def demonstrate_rms_norm():
    """演示 RMSNorm"""

    print("RMSNorm 演示")
    print("=" * 50)

    # 测试数据
    batch_size, seq_len, d_model = 4, 10, 512
    x = torch.randn(batch_size, seq_len, d_model)

    print(f"输入形状: {x.shape}")
    print(f"输入统计 - 均值: {x.mean():.4f}, 标准差: {x.std():.4f}")

    # RMSNorm

 rms_norm = RMSNorm(d_model)
    y_rms = rms_norm(x)

    print(f"\nRMSNorm 输出:")
    print(f"  输出统计 - 均值: {y_rms.mean():.6f}, 标准差: {y_rms.std():.6f}")

    # 验证 RMS 归一化
    print(f"\nRMS 归一化验证:")
    for b in range(min(2, batch_size)):
        for s in range(min(2, seq_len)):
            sample = x[b, s]
            rms_manual = torch.sqrt(torch.mean(sample**2))
            normalized_manual = sample / rms_manual
            normalized_actual = y_rms[b, s] / rms_norm.weight

            diff = torch.abs(normalized_manual - normalized_actual).max()
            print(f"  样本[{b},{s}] - RMS手动: {rms_manual:.4f}, 最大差异: {diff:.6f}")
            break  # 只检查第一个位置
        break

    # 与 LayerNorm 比较
    print(f"\n与 LayerNorm 比较:")
    layer_norm = nn.LayerNorm(d_model)
    y_layer = layer_norm(x)

    print(f"  LayerNorm 输出 - 均值: {y_layer.mean():.6f}, 标准差: {y_layer.std():.6f}")

    # 计算效率比较
    import time

    # 预热
    _ = rms_norm(x)
    _ = layer_norm(x)

    # RMSNorm 时间
    start = time.time()
    for _ in range(1000):
        _ = rms_norm(x)
    rms_time = time.time() - start

    # LayerNorm 时间
    start = time.time()
    for _ in range(1000):
        _ = layer_norm(x)
    layer_time = time.time() - start

    print(f"\n计算效率比较 (1000次前向传播):")
    print(f"  RMSNorm 时间: {rms_time:.4f}s")
    print(f"  LayerNorm 时间: {layer_time:.4f}s")
    print(f"  RMSNorm 相对速度: {layer_time/rms_time:.2f}x")

    # 内存使用比较
    print(f"\n参数数量比较:")
    print(f"  RMSNorm 参数量: {sum(p.numel() for p in rms_norm.parameters())}")
    print(f"  LayerNorm 参数量: {sum(p.numel() for p in layer_norm.parameters())}")

    return {
        'rms_norm': rms_norm,
        'layer_norm': layer_norm,
        'output_rms': y_rms,
        'output_layer': y_layer,
        'timing': {'rms': rms_time, 'layer': layer_time}
    }

rms_norm_demo = demonstrate_rms_norm()

在 Transformer 中的应用

class RMSNormTransformerBlock(nn.Module):
    """使用 RMSNorm 的 Transformer 块"""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # 自注意力机制
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)

        # 前馈网络
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
            nn.Dropout(dropout)
        )

        # 使用 RMSNorm 替代 LayerNorm
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        """前向传播"""
        # 自注意力 + 残差连接
        attn_output, _ = self.self_attn(
            self.norm1(x),
            self.norm1(x),
            self.norm1(x),
            attn_mask=attn_mask
        )
        x = x + self.dropout(attn_output)

        # 前馈网络 + 残差连接
        ffn_output = self.ffn(self.norm2(x))
        x = x + self.dropout(ffn_output)

        return x

def demonstrate_rms_norm_transformer():
    """演示 RMSNorm 在 Transformer 中的应用"""

    print("RMSNorm Transformer 演示")
    print("=" * 50)

    # 创建模型
    d_model, nhead, seq_len, batch_size = 512, 8, 20, 4
    transformer_block = RMSNormTransformerBlock(d_model, nhead)

    print(f"模型配置:")
    print(f"  d_model: {d_model}")
    print(f"  nhead: {nhead}")
    print(f"  总参数量: {sum(p.numel() for p in transformer_block.parameters()):,}")

    # 测试前向传播
    x = torch.randn(batch_size, seq_len, d_model)
    output = transformer_block(x)

    print(f"\n前向传播测试:")
    print(f"  输入形状: {x.shape}")
    print(f"  输出形状: {output.shape}")
    print(f"  输入统计 - 均值: {x.mean():.4f}, 标准差: {x.std():.4f}")
    print(f"  输出统计 - 均值: {output.mean():.4f}, 标准差: {output.std():.4f}")

    # 梯度测试
    print(f"\n梯度稳定性测试:")
    x_grad = torch.randn(2, seq_len, d_model, requires_grad=True)
    output_grad = transformer_block(x_grad)
    loss = output_grad.sum()
    loss.backward()

    grad_norm = x_grad.grad.norm()
    print(f"  输入梯度范数: {grad_norm:.4f}")

    # 不同序列长度测试
    print(f"\n不同序列长度测试:")
    for test_seq_len in [10, 20, 50, 100]:
        x_test = torch.randn(batch_size, test_seq_len, d_model)
        with torch.no_grad():
            output_test = transformer_block(x_test)

        print(f"  序列长度{test_seq_len:3d} - 输出均值: {output_test.mean():.4f}, 标准差: {output_test.std():.4f}")

    return transformer_block

rms_transformer_demo = demonstrate_rms_norm_transformer()

性能分析

def analyze_rms_norm_performance():
    """分析 RMSNorm 性能"""

    print("RMSNorm 性能分析")
    print("=" * 50)

    # 不同批量大小测试
    batch_sizes = [1, 4, 16, 64]
    seq_len, d_model = 128, 512

    print("不同批量大小下的性能:")
    print(f"{'批量大小':<10} {'RMSNorm时间(s)':<15} {'LayerNorm时间(s)':<15} {'速度比':<10}")
    print("-" * 50)

    for batch_size in batch_sizes:
        x = torch.randn(batch_size, seq_len, d_model)

        rms_norm = RMSNorm(d_model)
        layer_norm = nn.LayerNorm(d_model)

        # 预热
        _ = rms_norm(x)
        _ = layer_norm(x)

        # RMSNorm 时间
        start = time.time()
        for _ in range(100):
            _ = rms_norm(x)
        rms_time = time.time() - start

        # LayerNorm 时间
        start = time.time()
        for _ in range(100):
            _ = layer_norm(x)
        layer_time = time.time() - start

        speed_ratio = layer_time / rms_time
        print(f"{batch_size:<10} {rms_time:<15.4f} {layer_time:<15.4f} {speed_ratio:<10.2f}")

    # 不同特征维度测试
    print(f"\n不同特征维度下的性能:")
    batch_size, seq_len = 8, 64
    d_models = [128, 256, 512, 1024, 2048]

    print(f"{'特征维度':<10} {'RMSNorm时间(s)':<15} {'LayerNorm时间(s)':<15} {'速度比':<10}")
    print("-" * 50)

    for d_model in d_models:
        x = torch.randn(batch_size, seq_len, d_model)

        rms_norm = RMSNorm(d_model)
        layer_norm = nn.LayerNorm(d_model)

        # 预热
        _ = rms_norm(x)
        _ = layer_norm(x)

        # RMSNorm 时间
        start = time.time()
        for _ in range(100):
            _ = rms_norm(x)
        rms_time = time.time() - start

        # LayerNorm 时间
        start = time.time()
        for _ in range(100):
            _ = layer_norm(x)
        layer_time = time.time() - start

        speed_ratio = layer_time / rms_time
        print(f"{d_model:<10} {rms_time:<15.4f} {layer_time:<15.4f} {speed_ratio:<10.2f}")

    # 内存使用分析
    print(f"\n内存使用分析:")
    d_model = 512
    rms_norm = RMSNorm(d_model)
    layer_norm = nn.LayerNorm(d_model)

    rms_params = sum(p.numel() for p in rms_norm.parameters())
    layer_params = sum(p.numel() for p in layer_norm.parameters())

    print(f"  RMSNorm 参数量: {rms_params}")
    print(f"  LayerNorm 参数量: {layer_params}")
    print(f"  参数减少比例: {(layer_params - rms_params) / layer_params * 100:.1f}%")

    return {
        'batch_size_results': batch_sizes,
        'dimension_results': d_models
    }

performance_analysis = analyze_rms_norm_performance()

适用场景总结

  1. /大规模语言模型/:如 LLaMA、GPT-NeoX 等现代架构
  2. /计算资源受限环境/:需要轻量级归一化的场景
  3. /长序列处理/:在长序列任务中表现稳定
  4. /小批量训练/:对批量大小不敏感

优势与局限总结

优势:

局限:

AdaLN

在 Layer Normalization(LN)的基础上进行了优化,用来增强AI模型在处理不同输入条件时的适应能力

传统的 LN 没有条件信息,就纯归一化:

$$LN(x)=\frac{x-\mu}{\sigma}$$

AdaLN 则引入了 Modulation

例如:

$$AdaLN(x,c)=\alpha( c)\frac{x-\mu}{\sigma}+\beta( c)$$

更为一般的表达是:

$$AdaLN(x,c)=f_\theta( c)\odot LN(x)$$
# 输入条件的特征提取网络
self.adaLN_modulation = nn.Sequential(
    nn.SiLU(),
    nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
# c代表输入的条件信息
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp ,这6个变量分别对应了多头自注意力机制的归一化参数与缩放参数

AdaLN-Zero

DiT中具体的初始化设置如下所示:

def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x

卷积层

基本概念

卷积层是深度学习中用于处理网格状数据(如图像、序列)的核心组件,通过滑动窗口的方式在输入数据上提取局部特征。

数学定义

对于二维卷积,给定输入张量 \( X \in \mathbb{R}^{C_{in} \times H \times W} \) 和卷积核 \( K \in \mathbb{R}^{C_{out} \times C_{in} \times k_h \times k_w} \),输出计算为:

$$Y_{c_{out},i,j} = \sum_{c_{in}=0}^{C_{in}-1} \sum_{m=0}^{k_h-1} \sum_{n=0}^{k_w-1} K_{c_{out},c_{in},m,n} \cdot X_{c_{in}, i+m, j+n} + b_{c_{out}}$$

其中:

关键参数

  1. /步长(Stride)/:卷积核移动的步长
  2. /填充(Padding)/:在输入边界添加的零值区域
  3. /膨胀率(Dilation)/:卷积核元素间的间距
  4. /分组(Groups)/:输入输出通道的分组方式

PyTorch 实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicConv2d(nn.Module):
    """基础二维卷积层"""

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__()

        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=bias
        )

        # 初始化权重
        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')
        if bias:
            nn.init.constant_(self.conv.bias, 0)

    def forward(self, x):
        return self.conv(x)

def demonstrate_convolution():
    """演示卷积操作"""

    print("卷积层演示")
    print("=" * 50)

    # 基础卷积
    batch_size, in_channels, height, width = 4, 3, 32, 32
    x = torch.randn(batch_size, in_channels, height, width)

    conv = BasicConv2d(in_channels, 64, kernel_size=3, padding=1)
    output = conv(x)

    print(f"输入形状: {x.shape}")
    print(f"卷积核形状: {conv.conv.weight.shape}")
    print(f"输出形状: {output.shape}")

    # 不同参数的影响
    print(f"\n不同卷积参数的影响:")

    # 不同步长
    conv_stride2 = BasicConv2d(in_channels, 64, kernel_size=3, stride=2, padding=1)
    output_stride2 = conv_stride2(x)
    print(f"  步长=2: 输出形状 {output_stride2.shape}")

    # 不同填充
    conv_pad2 = BasicConv2d(in_channels, 64, kernel_size=5, padding=2)
    output_pad2 = conv_pad2(x)
    print(f"  核大小=5, 填充=2: 输出形状 {output_pad2.shape}")

    # 分组卷积
    conv_group = BasicConv2d(in_channels, 64, kernel_size=3, groups=in_channels)
    output_group = conv_group(x)
    print(f"  分组卷积: 输出形状 {output_group.shape}")

    # 膨胀卷积
    conv_dilation = BasicConv2d(in_channels, 64, kernel_size=3, dilation=2, padding=2)
    output_dilation = conv_dilation(x)
    print(f"  膨胀率=2: 输出形状 {output_dilation.shape}")

    return {
        'basic_conv': conv,
        'outputs': {
            'basic': output,
            'stride2': output_stride2,
            'pad2': output_pad2,
            'group': output_group,
            'dilation': output_dilation
        }
    }

conv_demo = demonstrate_convolution()

深度可分离卷积

将标准卷积分解为深度卷积和逐点卷积:

$$\text{DepthwiseConv}(X)_{c,i,j} = \sum_{m=0}^{k_h-1} \sum_{n=0}^{k_w-1} K_{c,m,n} \cdot X_{c, i+m, j+n}$$
$$\text{PointwiseConv}(X)_{c_{out},i,j} = \sum_{c_{in}=0}^{C_{in}-1} W_{c_{out},c_{in}} \cdot X_{c_{in},i,j}$$

class DepthwiseSeparableConv(nn.Module):
    """深度可分离卷积"""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()

        # 深度卷积
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size,
            stride=stride, padding=padding, groups=in_channels
        )

        # 逐点卷积
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x

def compare_convolution_types():
    """比较不同类型卷积"""

    print("卷积类型比较")
    print("=" * 50)

    batch_size, in_channels, height, width = 4, 32, 16, 16
    out_channels = 64
    x = torch.randn(batch_size, in_channels, height, width)

    # 标准卷积
    standard_conv = BasicConv2d(in_channels, out_channels, kernel_size=3, padding=1)

    # 深度可分离卷积
    separable_conv = DepthwiseSeparableConv(in_channels, out_channels, kernel_size=3, padding=1)

    # 计算参数量
    standard_params = sum(p.numel() for p in standard_conv.parameters())
    separable_params = sum(p.numel() for p in separable_conv.parameters())

    print(f"标准卷积参数量: {standard_params:,}")
    print(f"深度可分离卷积参数量: {separable_params:,}")
    print(f"参数减少比例: {(1 - separable_params/standard_params)*100:.1f}%")

    # 前向传播测试
    with torch.no_grad():
        standard_output = standard_conv(x)
        separable_output = separable_conv(x)

    print(f"\n输出形状:")
    print(f"  标准卷积: {standard_output.shape}")
    print(f"  深度可分离卷积: {separable_output.shape}")

    return {
        'standard_conv': standard_conv,
        'separable_conv': separable_conv,
        'parameter_ratio': separable_params/standard_params
    }

conv_comparison = compare_convolution_types()

转置卷积

用于上采样操作,通过插入零值实现尺寸扩大:

$$Y_{i,j} = \sum_{m,n} K_{m,n} \cdot X_{\lfloor i/s \rfloor - m + p, \lfloor j/s \rfloor - n + p}$$
class TransposedConvDemo:
    """转置卷积演示"""

    def __init__(self):
        pass

    def demonstrate_transposed_conv(self):
        """演示转置卷积"""

        print("转置卷积演示")
        print("=" * 50)

        # 输入特征图
        batch_size, channels, height, width = 4, 32, 8, 8
        x = torch.randn(batch_size, channels, height, width)

        print(f"输入形状: {x.shape}")

        # 不同上采样倍数的转置卷积
        upscale_factors = [2, 4]

        for factor in upscale_factors:
            # 转置卷积
            conv_transpose = nn.ConvTranspose2d(
                channels, channels, kernel_size=3,
                stride=factor, padding=1, output_padding=factor-1
            )

            output = conv_transpose(x)
            print(f"上采样倍数 {factor}: 输出形状 {output.shape}")

        # 与插值方法比较
        print(f"\n与插值方法比较:")

        # 双线性插值
        bilinear_upsample = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        print(f"双线性插值: 输出形状 {bilinear_upsample.shape}")

        # 转置卷积
        transposed_output = nn.ConvTranspose2d(channels, channels, kernel_size=4, stride=2, padding=1)(x)
        print(f"转置卷积: 输出形状 {transposed_output.shape}")

        return {
            'input': x,
            'bilinear': bilinear_upsample,
            'transposed': transposed_output
        }

transposed_demo = TransposedConvDemo()
transposed_results = transposed_demo.demonstrate_transposed_conv()

*/ 应用场景

  1. /计算机视觉/:图像分类、目标检测、语义分割
  2. /自然语言处理/:文本分类、序列建模
  3. /语音处理/:语音识别、音频生成
  4. /医学影像/:病灶检测、图像分割

ZeroConvolution

零卷积(Zero Convolution)是一种特殊的卷积初始化技术,在控制网络、条件生成和适配 器模块中广泛应用,通过零初始化确保初始状态下不改变输入特征。

基本概念

零卷积的核心思想是将卷积层的权重和偏置初始化为零,使得网络在训练初期表现为恒等映射:

$$\text{ZeroConv}(x) = 0 \cdot x + 0 = 0}$$
初始状态

随着训练进行,卷积层逐渐学习到有意义的变换。

数学定义

对于标准卷积:

$$y = W * x + b$$

零卷积初始化:

$$W^{(0)} = 0, \quad b^{(0)} = 0$$
因此:
$$y^{(0)} = 0 * x + 0 = 0$$

实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class ZeroConv1d(nn.Module):
    """一维零卷积"""

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                             stride=stride, padding=padding)
        self._initialize_weights()

    def _initialize_weights(self):
        """零初始化权重和偏置"""
        nn.init.zeros_(self.conv.weight)
        if self.conv.bias is not None:
            nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        return self.conv(x)

class ZeroConv2d(nn.Module):
    """二维零卷积"""

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                             stride=stride, padding=padding)
        self._initialize_weights()

    def _initialize_weights(self):
        """零初始化权重和偏置"""
        nn.init.zeros_(self.conv.weight)
        if self.conv.bias is not None:
            nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        return self.conv(x)

class ZeroConv3d(nn.Module):
    """三维零卷积"""

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
                             stride=stride, padding=padding)
        self._initialize_weights()

    def _initialize_weights(self):
        """零初始化权重和偏置"""
        nn.init.zeros_(self.conv.weight)
        if self.conv.bias is not None:
            nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        return self.conv(x)

class ZeroConvAdapter(nn.Module):
    """零卷积适配器模块"""

    def __init__(self, in_channels, out_channels, hidden_ratio=4):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        hidden_channels = in_channels * hidden_ratio

        # 下投影
        self.down_proj = ZeroConv2d(in_channels, hidden_channels, 1)

        # 非线性激活
        self.activation = nn.GELU()

        # 上投影
        self.up_proj = ZeroConv2d(hidden_channels, out_channels, 1)

        # 缩放因子
        self.scale = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """
        Args:
            x: 输入张量 [B, C, H, W]
        Returns:
            适配器输出
        """
        identity = x

        # 适配器路径
        h = self.down_proj(x)
        h = self.activation(h)
        h = self.up_proj(h)

        # 缩放并残差连接
        return identity + self.scale * h

def demonstrate_zero_conv():
    """演示零卷积"""

    print("零卷积演示")
    print("=" * 50)

    # 1D 零卷积测试
    print("1D 零卷积:")
    batch_size, in_channels, seq_len = 4, 32
, 100
    x_1d = torch.randn(batch_size, in_channels, seq_len)

    zero_conv_1d = ZeroConv1d(in_channels, 64, kernel_size=3, padding=1)
    output_1d = zero_conv_1d(x_1d)

    print(f"  输入形状: {x_1d.shape}")
    print(f"  输出形状: {output_1d.shape}")
    print(f"  初始输出均值: {output_1d.mean().item():.6f}")
    print(f"  初始输出标准差: {output_1d.std().item():.6f}")

    # 2D 零卷积测试
    print(f"\n2D 零卷积:")
    batch_size, in_channels, height, width = 4, 3, 32, 32
    x_2d = torch.randn(batch_size, in_channels, height, width)

    zero_conv_2d = ZeroConv2d(in_channels, 64, kernel_size=3, padding=1)
    output_2d = zero_conv_2d(x_2d)

    print(f"  输入形状: {x_2d.shape}")
    print(f"  输出形状: {output_2d.shape}")
    print(f"  初始输出均值: {output_2d.mean().item():.6f}")
    print(f"  初始输出标准差: {output_2d.std().item():.6f}")

    # 验证零初始化
    print(f"\n零初始化验证:")
    weight_norm = zero_conv_2d.conv.weight.norm().item()
    bias_norm = zero_conv_2d.conv.bias.norm().item() if zero_conv_2d.conv.bias is not None else 0

    print(f"  权重范数: {weight_norm:.6f}")
    print(f"  偏置范数: {bias_norm:.6f}")
    print(f"  输出与零张量的差异: {torch.abs(output_2d).max().item():.6f}")

    # 零卷积适配器演示
    print(f"\n零卷积适配器
:")
    adapter = ZeroConvAdapter(64, 64)
    adapter_output = adapter(x_2d)

    print(f"  适配器输入形状: {x_2d.shape}")
    print(f"  适配器输出形状: {adapter_output.shape}")
    print(f"  初始适配器输出均值: {adapter_output.mean().item():.6f}")
    print(f"  缩放因子: {adapter.scale.item():.6f}")

    # 训练过程中的行为
    print(f"\n训练行为模拟:")

    # 模拟一步训练
    optimizer = torch.optim.SGD(zero_conv_2d.parameters(), lr=0.1)
    loss = output_2d.sum()  # 简单的损失函数
    loss.backward()
    optimizer.step()

    # 检查训练后的输出
    output_after_train = zero_conv_2d(x_2d)
    print(f"  训练一步后输出均值: {output_after_train.mean().item():.6f}")
    print(f"  训练一步后权重范数: {zero_conv_2d.conv.weight.norm().item():.6f}")

    return {
        'zero_conv_1d': zero_conv_1d,
        'zero_conv_2d': zero_conv_2d,
        'adapter': adapter,
        'outputs': {
            '1d': output_1d,
            '2d': output_2d,
            'adapter': adapter_output,
            'after_train': output_after_train
        }
    }

zero_conv_demo = demonstrate_zero_conv()

*/ 在 ControlNet 中的应用

class ZeroConvControlNetBlock(nn.Module):
    """ControlNet 风格的零卷积块"""

    def __init__(self, in_channels, out_channels, condition_channels):
        super().__init__()

        # 主路径卷积
        self.main_conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        # 条件路径 - 使用零卷积
        self.condition_conv = ZeroConv2d(condition_channels, out_channels, 1)

        # 融合门控
        self.gate = nn.Parameter(torch.zeros(1))

        # 输出零卷积
        self.output_conv = ZeroConv2d(out_channels, out_channels, 1)

    def forward(self, x, condition):
        """
        Args:
            x: 主输入 [B, C, H, W]
            condition: 条件输入 [B, C_cond, H, W]
        """
        # 主路径
        main_out = self.main_conv(x)

        # 条件路径
        condition_out = self.condition_conv(condition)

        # 融合
        fused = main_out + self.gate * condition_out

        # 输出变换
        output = self.output_conv(fused)

        return output

class ZeroConvResBlock(nn.Module):
    """零卷积残差块"""

    def __init__(self, channels, condition_channels=None):
        super().__init__()

        # 第一个卷积层
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(32, channels)

        # 条件卷积(可选)
        if condition_channels is not None:
            self.condition_conv = ZeroConv2d(condition_channels, channels, 1)
        else:
            self.condition_conv = None

        # 第二个卷积层
        self.conv2 = ZeroConv2d(channels, channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)

        self.activation = nn.SiLU()

        # 门控参数
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, x, condition=None):
        identity = x

        # 第一个卷积
        h = self.conv1(x)
        h = self.norm1(h)

        # 条件融合
        if condition is not None and self.condition_conv is not None:
            cond_feat = self.condition_conv(condition)
            h = h + self.gate * cond_feat

        h = self.activation(h)

        # 第二个卷积
        h = self.conv2(h)
        h = self.norm2(h)

        # 残差连接
        return identity + h

def demonstrate_controlnet_application():
    """演示 ControlNet 中的应用"""

    print("ControlNet 中的零卷积应用")
    print("=" * 50)

    # 创建测试数据
    batch_size, channels, height, width = 4, 64, 32, 32
    condition_channels = 128

    x = torch.randn(batch_size, channels, height, width)
    condition = torch.randn(batch_size, condition_channels, height, width)

    # ControlNet 块
    controlnet_block = ZeroConvControlNetBlock(channels, channels, condition_channels)
    output = controlnet_block(x, condition)

    print(f"主输入形状: {x.shape}")
    print(f"条件输入形状: {condition.shape}")
    print(f"输出形状: {output.shape}")

    # 验证初始输出
    print(f"\n初始状态验证:")
    print(f"条件路径输出均值: {controlnet_block.condition_conv(condition).mean().item():.6f}")
    print(f"输出卷积输出均值: {controlnet_block.output_conv(torch.zeros_like(output)).mean().item():.6f}")
    print(f"门控参数: {controlnet_block.gate.item():.6f}")

    # 零卷积残差块
    print(f"\n零卷积残差块:")
    zero_res_block = ZeroConvResBlock(channels, condition_channels)
    res_output = zero_res_block(x, condition)

    print(f"残差块输入形状: {x.shape}")
    print(f"残差块输出形状: {res_output.shape}")
    print(f"初始残差输出均值: {res_output.mean().item():.6f}")

    # 训练稳定性测试
    print(f"\n训练稳定性测试:")

    # 多次前向传播
    with torch.no_grad():
        outputs = []
        for i in range(5):
            output_i = controlnet_block(x, condition)
            outputs.append(output_i.mean().item())

        output_variance = torch.var(torch.tensor(outputs))
        print(f"多次前向传播输出方差: {output_v
ariance.item():.6f}")

    return {
        'controlnet_block': controlnet_block,
        'zero_res_block': zero_res_block,
        'outputs': {
            'controlnet': output,
            'residual': res_output
        }
    }

controlnet_demo = demonstrate
_controlnet_application()

优势与特点

  1. /稳定初始化/:确保网络从恒等映射开始,训练过程更稳定
  2. /渐进式学习/:模型逐渐学习条件信息的影响,避免剧烈变化
  3. /模块化设计/:便于插入到现有架构中作为适配器
  4. /训练友好/:减少训练初期的梯度爆炸风险

应用场景

  1. /ControlNet/:在稳定扩散模型中添加空间条件控制
  2. /模型微调/:作为适配器模块进行参数高效微调
  3. /多模态融合/:融合不同模态的特征表示
  4. /渐进式训练/:从简单任务逐渐过渡到复杂任务

注意事项

  1. /学习率调整/:零卷积层

可能需要不同的学习率调度

  1. /梯度流/:确保零初始化不会阻碍梯度传播
  2. /收敛速度/:初始阶段学习较慢,需要适当训练轮数
  3. /参数初始化/:与其他层的

初始化策略协调

反卷积

反卷积(Deconvolution),也称为转置卷积(Transposed Convolution)或分数步长卷积(Fractionally-strided Convolution),是一种用于特征图上采样的操作。

数学原理

对于输入特征图 \( X \in \mathbb{R}^{C_{in} \times H_{in} \times W_{in}} \) 和卷积核 \( K \in \mathbb{R}^{C_{out} \times C_{in} \times k_h \times k_w} \),反卷积的输出计算为:

$$Y_{c,i,j} = \sum_{m=0}^{k_h-1} \sum_{n=0}^{k_w-1} \sum_{c_{in}=0}^{C_{in}-1} K_{c,c_{in},m,n} \cdot X_{c_{in}, \lfloor i/s \rfloor - m + p, \lfloor j/s \rfloor - n + p}$$

其中:

实现细节

class DeconvolutionDemo:
    """反卷积演示类"""

    def __init__(self):
        pass

    def demonstrate_deconvolution(self):
        """演示反卷积操作"""

        print("反卷积演示")
        print("=" * 50)

        # 输入特征图
        batch_size, in_channels, height, width = 4, 32, 8, 8
        out_channels = 64
        x = torch.randn(batch_size, in_channels, height, width)

        print(f"输入形状: {x.shape}")

        # 不同配置的反卷积
        configurations = [
            {'kernel_size': 3, 'stride': 2, 'padding': 1, 'output_padding': 1},
            {'kernel_size': 4, 'stride': 2, 'padding': 1},
            {'kernel_size': 3, 'stride': 3, 'padding': 1, 'output_padding': 2}
        ]

        for i, config in enumerate(configurations):
            deconv = nn.ConvTranspose2d(
                in_channels, out_channels,
                kernel_size=config['kernel_size'],
                stride=config['stride'],
                padding=config['padding'],
                output_padding=config.get('output_padding', 0)
            )

            output = deconv(x)
            print(f"配置 {i+1}: 核大小 {config['kernel_size']}, 步长 {config['stride']}")
            print(f"  输出形状: {output.shape}")

            # 计算参数量
            params = sum(p.numel() for p in deconv.parameters())
            print(f"  参数量: {params:,}")

        return deconv, output

    def compare_upsampling_methods(self):
        """比较不同上采样方法"""

        print(f"\n上采样方法比较")
        print("=" * 50)

        # 测试输入
        x = torch.randn(2, 32, 16, 16)
        target_size = (32, 32)

        print(f"输入形状: {x.shape}")
        print(f"目标输出形状: (2, 32, {target_size[0]}, {target_size[1]})")

        # 1. 最近邻插值
        nearest = F.interpolate(x, size=target_size, mode='nearest')

        # 2. 双线性插值
        bilinear = F.interpolate(x, size=target_size, mode='bilinear', align_corners=False)

        # 3. 双三次插值
        bicubic = F.interpolate(x, size=target_size, mode='bicubic', align_corners=False)

        # 4. 反卷积
        deconv = nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1)
        deconv_output = deconv(x)

        print(f"\n不同方法输出形状:")
        print(f"  最近邻插值: {nearest.shape}")
        print(f"  双线性插值: {bilinear.shape}")
        print(f"  双三次插值: {bicubic.shape}")
        print(f"  反卷积: {deconv_output.shape}")

        # 计算计算量(FLOPs)
        def calculate_flops_2d(input_shape, output_shape, kernel_size, groups=1):
            """计算2D卷积的FLOPs"""
            batch, in_c, h, w = input_shape
            _, out_c, out_h, out_w = output_shape

            # 每个输出位置的计算量
            flops_per_position = in_c * kernel_size * kernel_size / groups
            total_flops = batch * out_c * out_h * out_w * flops_per_position * 2  # 乘加操作

            return total_flops

        deconv_flops = calculate_flops_2d(x.shape, deconv_output.shape, 4)
        print(f"\n计算量比较:")
        print(f"  反卷积 FLOPs: {deconv_flops:,.0f}")

        return {
            'nearest': nearest,
            'bilinear': bilinear,
            'bicubic': bicubic,
            'deconv': deconv_output
        }

# 演示反卷积
deconv_demo = DeconvolutionDemo()
deconv_results = deconv_demo.demonstrate_deconvolution()
upsampling_comparison = deconv_demo.compare_upsampling_methods()

*/ 棋盘格效应

反卷积在生成图像时可能产生棋盘格效应(checkerboard artifacts),这是由于不均匀的重叠模式造成的。

class CheckerboardAnalysis:
    """棋盘格效应分析"""

    def __init__(self):
        pass

    def analyze_checkerboard_artifacts(self):
        """分析棋盘格效应"""

        print("棋盘格效应分析")
        print("=" * 50)

        # 创建均匀输入
        x = torch.ones(1, 1, 4, 4)

        # 不同核大小的反卷积
        kernel_sizes = [2, 3, 4]

        for kernel_size in kernel_sizes:
            # 创建反卷积层
            deconv = nn.ConvTranspose2d(
                1, 1,
                kernel_size=kernel_size,
                stride=2,
                padding=0 if kernel_size % 2 == 0 else 1
            )

            # 使用均匀权重
            nn.init.constant_(deconv.weight, 1.0)
            if deconv.bias is not None:
                nn.init.constant_(deconv.bias, 0.0)

            output = deconv(x)

            print(f"核大小 {kernel_size}:")
            print(f"  输入形状: {x.shape}")
            print(f"  输出形状: {output.shape}")
            print(f"  输出值范围: [{output.min().item():.3f}, {output.max().item():.3f}]")

            # 计算重叠模式
            overlap_pattern = self.calculate_overlap_pattern(output)
            print(f"  重叠模式方差: {overlap_pattern.var().item():.3f}")

    def calculate_overlap_pattern(self, output):
        """计算输出中的重叠模式"""
        # 简化的重叠模式分析
        return output

    def demonstrate_solutions(self):
        """演示解决棋盘格效应的方法"""

        print(f"\n棋盘格效应解决方案")
        print("=" * 50)

        x = torch.ones(1, 32, 8, 8)

        # 解决方案1: 使用最近邻插值 + 卷积
        print("方案1: 插值 + 卷积")
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        conv_after_upsample = nn.Conv2d(32, 32, kernel_size=3, padding=1)(upsampled)
        print(f"  输出形状: {conv_after_upsample.shape}")

        # 解决方案2: 调整核大小和步长
        print("方案2: 调整核参数")
        deconv_adjusted = nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1)
        adjusted_output = deconv_adjusted(x)
        print(f"  输出形状: {adjusted_output.shape}")

        # 解决方案3: 使用PixelShuffle
        print("方案3: PixelShuffle")
        pixelshuffle = nn.PixelShuffle(2)
        # 首先将通道数增加到 32 * 4
        conv_before_shuffle = nn.Conv2d(32, 32 * 4, kernel_size=3, padding=1)(x)
        shuffle_output = pixelshuffle(conv_before_shuffle)
        print(f"  输出形状: {shuffle_output.shape}")

checkerboard_analysis = CheckerboardAnalysis()
checkerboard_analysis.analyze_checkerboard_artifacts()
checkerboard_analysis.demonstrate_solutions()

*/ 应用场景

  1. /图像超分辨率/:将低分辨率图像上采样到高分辨率
  2. /语义分割/:在编码器-解码器架构中恢复空间分辨率
  3. /生成对抗网络/:从潜在向量生成图像
  4. /自编码器/:在解码器中重建输入尺寸

*/ 与其他上采样方法的比较

方法 可学习参数 计算成本 棋盘格效应 适用场景
最近邻插值 实时应用
双线性插值 一般上采样
反卷积 可能 需要学习的上采样
PixelShuffle 高质量上采样

*/ 最佳实践

  1. /核大小选择/:使用能被步长整除的核大小以减少棋盘格效应
  2. /初始化策略/:使用双线性插值初始化反卷积权重
  3. /结合其他方法/:可先插值再卷积以获得更好效果
  4. /监控输出/:训练过程中检查是否出现棋盘格模式

Modulation

调制(Modulation)是一种在生成模型和条件生成任务中广泛使用的技术,通过外部条件信息来调整网络的行为和特征表示。

基本概念

调制通过引入条件信息 \( c \) 来调整网络权重或激活值,使模型能够根据输入条件生成不同的输出:

$$y = f(x; \theta( c))$$

其中 \( \theta( c) \) 是根据条件 \( c \) 动态生成的网络参数。

条件批归一化(Conditional Batch Normalization)

在批归一化中引入条件信息:

$$\text{ConditionalBN}(x|c) = \gamma( c) \cdot \frac{x - \mu}{\sigma} + \beta( c)$$

其中 \( \gamma( c), \beta( c) \) 是根据条件 \( c \) 生成的缩放和偏移参数。

class ConditionalBatchNorm2d(nn.Module):
    """条件批归一化"""

    def __init__(self, num_features, condition_dim, hidden_dim=128):
        super().__init__()
        self.num_features = num_features
        self.condition_dim = condition_dim

        # 主批归一化层
        self.bn = nn.BatchNorm2d(num_features, affine=False)

        # 条件映射网络
        self.condition_net = nn.Sequential(
            nn.Linear(condition_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, 2 * num_features)
        )

    def forward(self, x, condition):
        """
        Args:
            x: 输入张量 [B, C, H, W]
            condition: 条件向量 [B, condition_dim]
        """
        # 标准批归一化
        x_normalized = self.bn(x)

        # 根据条件生成参数
        condition_params = self.condition_net(condition)  # [B, 2*C]
        gamma = condition_params[:, :self.num_features].unsqueeze(-1).unsqueeze(-1)  # [B, C, 1, 1]
        beta = condition_params[:, self.num_features:].unsqueeze(-1).unsqueeze(-1)   # [B, C, 1, 1]

        # 应用调制
        return gamma * x_normalized + beta

def demonstrate_conditional_bn():
    """演示条件批归一化"""

    print("条件批归一化演示")
    print("=" * 50)

    batch_size, channels, height, width = 4, 32, 16, 16
    condition_dim = 64

    x = torch.randn(batch_size, channels, height, width)
    condition = torch.randn(batch_size, condition_dim)

    conditional_bn = ConditionalBatchNorm2d(channels, condition_dim)
    output = conditional_bn(x, condition)

    print(f"输入形状: {x.shape}")
    print(f"条件向量形状: {condition.shape}")
    print(f"输出形状: {output.shape}")

    # 验证不同条件产生不同输出
    print(f"\n不同条件的影响:")

    condition1 = torch.randn(batch_size, condition_dim)
    condition2 = torch.randn(batch_size, condition_dim)

    output1 = conditional_bn(x, condition1)
    output2 = conditional_bn(x, condition2)

    diff = torch.abs(output1 - output2).mean()
    print(f"  不同条件输出的平均差异: {diff:.6f}")

    return {
        'conditional_bn': conditional_bn,
        'outputs': {
            'condition1': output1,
            'condition2': output2
        }
    }

conditional_bn_demo = demonstrate_conditional_bn()

特征调制(Feature-wise Modulation)

通过条件信息直接调制特征图,如 FiLM(Feature-wise Linear Modulation):

$$\text{FiLM}(x|c) = \gamma( c) \odot x + \beta( c)$$

其中 \( \gamma( c), \beta( c) \) 是根据条件 \( c \) 生成的调制参数。

class FiLMLayer(nn.Module):
    """FiLM(Feature-wise Linear Modulation)层"""

    def __init__(self, feature_dim, condition_dim, hidden_dim=128):
        super().__init__()
        self.feature_dim = feature_dim
        self.condition_dim = condition_dim

        # 条件映射网络
        self.condition_net = nn.Sequential(
            nn.Linear(condition_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, 2 * feature_dim)
        )

    def forward(self, x, condition):
        """
        Args:
            x: 输入特征 [..., feature_dim]
            condition: 条件向量 [..., condition_dim]
        """
        original_shape = x.shape

        # 重塑为 [*, feature_dim]
        x_flat = x.view(-1, self.feature_dim)
        condition_flat = condition.view(-1, self.condition_dim)

        # 生成调制参数
        condition_params = self.condition_net(condition_flat)  # [*, 2*feature_dim]
        gamma = condition_params[:, :self.feature_dim]
        beta = condition_params[:, self.feature_dim:]

        # 应用调制
        x_modulated = gamma * x_flat + beta

        # 恢复原始形状
        return x_modulated.view(original_shape)

class FiLMResBlock(nn.Module):
    """使用 FiLM 的残差块"""

    def __init__(self, in_channels, out_channels, condition_dim, stride=1):
        super().__init__()

        # 第一个卷积层
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                              stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # FiLM 调制
        self.film1 = FiLMLayer(out_channels, condition_dim)

        # 第二个卷积层
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                              stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.film2 = FiLMLayer(out_channels, condition_dim)

        # 激活函数
        self.relu = nn.ReLU(inplace=True)

        # 下采样
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                         stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x, condition):
        identity = x

        # 主路径
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.film1(out, condition)  # FiLM 调制
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.film2(out, condition)  # FiLM 调制

        # 捷径
        if self.downsample is not None:
            identity = self.downsample(x)

        # 残差连接
        out += identity
        out = self.relu(out)

        return out

def demonstrate_film():
    """演示 FiLM 调制"""

    print("FiLM 调制演示")
    print("=" * 50)

    # 基础 FiLM 测试
    batch_size, feature_dim, condition_dim = 4, 64, 32
    x = torch.randn(batch_size, feature_dim)
    condition = torch.randn(batch_size, condition_dim)

    film_layer = FiLMLayer(feature_dim, condition_dim)
    output = film_layer(x, condition)

    print(f"输入特征形状: {x.shape}")
    print(f"条件向量形状: {condition.shape}")
    print(f"输出特征形状: {output.shape}")

    # FiLM 残差块测试
    print(f"\nFiLM 残差块演示:")

    batch_size, in_channels, height, width = 4, 32, 16, 16
    condition_dim = 64

    x_conv = torch.randn(batch_size, in_channels, height, width)
    condition_conv = torch.randn(batch_size, condition_dim)

    film_resblock = FiLMResBlock(in_channels, 64, condition_dim, stride=1)
    output_conv = film_resblock(x_conv, condition_conv)

    print(f"卷积输入形状: {x_conv.shape}")
    print(f"卷积输出形状: {output_conv.shape}")

    # 条件敏感性测试
    print(f"\n条件敏感性测试:")

    # 相同输入,不同条件
    condition_a = torch.randn(batch_size, condition_dim)
    condition_b = torch.randn(batch_size, condition_dim)

    output_a = film_resblock(x_conv, condition_a)
    output_b = film_resblock(x_conv, condition_b)

    diff = torch.abs(output_a - output_b).mean()
    print(f"  不同条件输出的平均差异: {diff:.6f}")

    return {
        'film_layer': film_layer,
        'film_resblock': film_resblock,
        'outputs': {
            'basic': output,
            'conv': output_conv,
            'condition_a': output_a,
            'condition_b': output_b
        }
    }

film_demo = demonstrate_film()

SFT

SFT(Scale and Shift Transform)是一种轻量级的特征调制方法,通过缩放(Scale)和 偏移(Shift)操作将条件信息融入特征表示中,广泛应用于图像超分辨率、风格迁移等任 务。

数学定义

对于输入特征 \( x \in \mathbb{R}^{C \times H \times W} \) 和条件信息 \( c \),SFT 计算:

$$\text{SFT}(x|c) = \gamma( c) \odot x + \beta( c)$$

其中:

网络结构

SFT 通常包含:

  1. 条件编码网络:将条件信息映射到调制参数
  2. 特征提取网络:处理输入特征
  3. SFT 层:应用缩放和偏移调制

实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class SFTLayer(nn.Module):
    """SFT(Scale and Shift Transform)层"""

    def __init__(self, channels, condition_dim, hidden_dim=64):
        """
        参数:
            channels: 输入特征通道数
            condition_dim: 条件向量维度
            hidden_dim: 隐藏层维度
        """
        super().__init__()
        self.channels = channels

        # 条件映射网络,生成缩放和偏移参数
        self.condition_net = nn.Sequential(
            nn.Linear(condition_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, 2 * channels)
        )

    def forward(self, x, condition):
        """
        Args:
            x: 输入特征 [B, C, H, W]
            condition: 条件向量 [B, condition_dim]
        Returns:
            调制后的特征 [B, C, H, W]
        """
        B, C, H, W = x.shape

        # 生成调制参数
        condition_params = self.condition_net(condition)  # [B, 2*C]
        gamma = condition_params[:, :self.channels].view(B, C, 1, 1)  # [B, C, 1, 1]
        beta = condition_params[:, self.channels:].view(B, C, 1, 1)   # [B, C, 1, 1]

        # 应用缩放和偏移
        return gamma * x + beta

class SFTResidualBlock(nn.Module):
    """使用 SFT 的残差块"""

    def __init__(self, channels, condition_dim, hidden_dim=64):
        super().__init__()

        # 第一个卷积层
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)

        # SFT 调制层
        self.sft1 = SFTLayer(channels, condition_dim, hidden_dim)

        # 第二个卷积层
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

        # 第二个 SFT 调制层
        self.sft2 = SFTLayer(channels, condition_dim, hidden_dim)

        # 激活函数
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, condition):
        identity = x

        # 主路径
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.sft1(out, condition)  # SFT 调制
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.sft2(out, condition)  # SFT 调制

        # 残差连接
        out += identity
        out = self.relu(out)

        return out

class SFTNet(nn.Module):
    """完整的 SFT 网络示例(用于图像超分辨率)"""

    def __init__(self, in_channels=3, out_channels=3, base_channels=64,
                 condition_dim=32, num_blocks=16):
        super().__init__()

        # 特征提取
        self.feature_extract = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        # SFT 残差块
        self.sft_blocks = nn.ModuleList([
            SFTResidualBlock(base_channels, condition_dim)
            for _ in range(num_blocks)
        ])

        # 重建层
        self.reconstruct = nn.Sequential(
            nn.Conv2d(base_channels, base_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channels, out_channels, kernel_size=3, padding=1)
        )

    def forward(self, x, condition):
        # 特征提取
        features = self.feature_extract(x)

        # 通过 SFT 块
        for sft_block in self.sft_blocks:
            features = sft_block(features, condition)

        # 重建
        output = self.reconstruct(features)

        return output

def demonstrate_sft():
    """演示 SFT 调制"""

    print("SFT(Scale and Shift Transform)演示")
    print("=" * 50)

    # 基础 SFT 测试
    batch_size, channels, height, width = 4, 32, 16, 16
    condition_dim = 64

    x = torch.randn(batch_size, channels, height, width)
    condition = torch.randn(batch_size, condition_dim)

    sft_layer = SFTLayer(channels, condition_dim)
    output = sft_layer(x, condition)

    print(f"输入特征形状: {x.shape}")
    print(f"条件向量形状: {condition.shape}")
    print(f"SFT 输出形状: {output.shape}")

    # 验证调制效果
    print(f"\n调制效果验证:")
    print(f"  输入均值: {x.mean().item():.4f}, 标准差: {x.std().item():.4f}")
    print(f"  输出均值: {output.mean().item():.4f}, 标准差: {output.std().item():.4f}")

    # SFT 残差块测试
    print(f"\nSFT 残差块演示:")
    sft_resblock = SFTResidualBlock(channels, condition_dim)
    output_res = sft_resblock(x, condition)

    print(f"  残差块输入形状: {x.shape}")
    print(f"  残差块输出形状: {output_res.shape}")

    # 条件敏感性测试
    print(f"\n条件敏感性测试:")

    # 相同输入,不同条件
    condition1 = torch.randn(batch_size, condition_dim)
    condition2 = torch.randn(batch_size, condition_dim)

    output1 = sft_layer(x, condition1)
    output2 = sft_layer(x, condition2)

    diff = torch.abs(output1 - output2).mean()
    print(f"  不同条件输出的平均差异: {diff:.6f}")

    # 完整 SFT 网络测试
    print(f"\n完整 SFT 网络演示:")
    sft_net = SFTNet(in_channels=3, out_channels=3, condition_dim=condition_dim)

    rgb_input = torch.randn(batch_size, 3, 32, 32)
    rgb_output = sft_net(rgb_input, condition)

    print(f"  网络输入形状: {rgb_input.shape}")
    print(f"  网络输出形状: {rgb_output.shape}")

    return {
        'sft_layer': sft_layer,
        'sft_resblock': sft_resblock,
        'sft_net': sft_net,
        'outputs': {
            'basic': output,
            'residual': output_res,
            'network': rgb_output,
            'condition1': output1,
            'condition2': output2
        }
    }

sft_demo = demonstrate_sft()

与其它调制方法的比较

def compare_modulation_methods():
    """比较不同调制方法"""

    print("调制方法比较")
    print("=" * 50)

    batch_size, channels, height, width = 4, 32, 16, 16
    condition_dim = 64

    x = torch.randn(batch_size, channels, height, width)
    condition = torch.randn(batch_size, condition_dim)

    # 不同调制方法
    methods = {
        'SFT': SFTLayer(channels, condition_dim),
        'FiLM': FiLMLayer(channels, condition_dim),
        'ConditionalBN': ConditionalBatchNorm2d(channels, condition_dim)
    }

    results = {}

    print(f"输入形状: {x.shape}")
    print(f"条件向量形状: {condition.shape}")
    print(f"\n不同调制方法输出:")

    for name, method in methods.items():
        if name == 'ConditionalBN':
            output = method(x, condition)
        else:
            output = method(x, condition)

        print(f"  {name:<15} - 形状: {output.shape}, 均值: {output.mean():.4f}, 标准差: {output.std():.4f}")

        results[name] = {
            'method': method,
            'output': output,
            'params': sum(p.numel() for p in method.parameters())
        }

    # 参数量比较
    print(f"\n参数量比较:")
    for name, result in results.items():
        print(f"  {name:<15}: {result['params']:,} 参数")

    # 计算效率比较
    print(f"\n计算效率比较 (100次前向传播):")
    import time

    for name, result in results.items():
        method = result['method']

        # 预热
        if name == 'ConditionalBN':
            _ = method(x, condition)
        else:
            _ = method(x, condition)

        # 计时
        start = time.time()
        for _ in range(100):
            if name == 'ConditionalBN':
                _ = method(x, condition)
            else:
                _ = method(x, condition)
        elapsed = time.time() - start

        print(f"  {name:<15}: {elapsed:.4f}s")

    return results

modulation_comparison = compare_modulation_methods()

*/ 应用场景

  1. /图像超分辨率/:根据退化信息调制特征
  2. /风格迁移/:根据风格条件调整内容特征
  3. /条件图像生成/:根据语义条件生成图像
  4. /多模态学习/:融合不同模态的信息

*/ 优势与特点

优势:

特点:

Gate

门控机制(Gating Mechanism)是深度学习中的关键技术,通过可学习的开关控制信息流动, 在循环神经网络、注意力机制和现代Transformer架构中发挥重要作用。

基本概念

门控机制通过sigmoid激活函数生成0到1之间的门控值,控制信息的保留与遗忘:

$$g = \sigma(Wx + b)$$
$$y = g \odot x$$

其中 \( \sigma \) 是sigmoid函数,\( \odot \) 表示逐元素相乘。

常见门控类型

LSTM 门控

长短期记忆网络(LSTM)包含三个门控单元:

数学表达式:

\begin{aligned} i_t &= \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \\ f_t &= \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \\ o_t &= \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \end{aligned}

GRU 门控

门控循环单元(GRU)包含两个门控:

数学表达式:

$$</p> <p>\begin{aligned} r_t &= \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) \\ z_t &= \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) \end{aligned}</p> <p>$$

门控线性单元(GLU)

GLU通过sigmoid门控控制线性变换的输出:

$$\text{GLU}(x) = (W_1x + b_1) \odot \sigma(W_2x + b_2)$$

实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class GLU(nn.Module):
    """门控线性单元(Gated Linear Unit)"""

    def __init__(self, input_dim, output_dim=None):
        super().__init__()
        if output_dim is None:
            output_dim = input_dim

        self.linear = nn.Linear(input_dim, 2 * output_dim)
        self.gate_activation = nn.Sigmoid()

    def forward(self, x):
        """
        Args:
            x: 输入张量 [..., input_dim]
        Returns:
            门控输出 [..., output_dim]
        """
        # 线性变换
        projected = self.linear(x)

        # 分割为值和门控
        value, gate = torch.chunk(projected, 2, dim=-1)

        # 应用门控
        return value * self.gate_activation(gate)

class GatedResidualBlock(nn.Module):
    """门控残差块"""

    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim

        # 前馈网络
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * hidden_dim, hidden_dim),
            nn.Dropout(dropout)
        )

        # 门控权重
        self.gate_weight = nn.Parameter(torch.zeros(1, hidden_dim))

        # 层归一化
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """
        Args:
            x: 输入张量 [..., hidden_dim]
        Returns:
            门控残差输出
        """
        residual = x

        # 前馈网络
        ffn_output = self.ffn(self.norm(x))

        # 计算门控值
        gate = torch.sigmoid(self.gate_weight)

        # 门控残差连接
        output = gate * ffn_output + (1 - gate) * residual

        return output

class MultiHeadGatedAttention(nn.Module):
    """多头门控注意力"""

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        assert self.head_dim * n_heads == d_model, "d_model必须能被n_heads整除"

        # QKV投影
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)

        # 输出投影
        self.w_o = nn.Linear(d_model, d_model)

        # 门控参数
        self.gate = nn.Parameter(torch.zeros(1, 1, d_model))

        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(self, x, mask=None):
        """
        Args:
            x: 输入序列 [batch_size, seq_len, d_model]
            mask: 注意力掩码 [batch_size, seq_len, seq_len]
        """
        batch_size, seq_len, d_model = x.shape

        # QKV投影
        Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        # 注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # 注意力权重
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # 注意力输出
        attn_output = torch.matmul(attn_weights, V)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)

        # 输出投影
        output = self.w_o(attn_output)

        # 门控残差连接
        gate = torch.sigmoid(self.gate)
        output = gate * output + (1 - gate) * x

        return output

def demonstrate_gating():
    """演示门控机制"""

    print("门控机制演示")
    print("=" * 50)

    # GLU 演示
    print("GLU(门控线性单元):")
    batch_size, seq_len, hidden_dim = 4, 10, 64
    x = torch.randn(batch_size, seq_len, hidden_dim)

    glu = GLU(hidden_dim)
    glu_output = glu(x)

    print(f"  输入形状: {x.shape}")
    print(f"  输出形状: {glu_output.shape}")
    print(f"  输出范围: [{glu_output.min():.3f}, {glu_output.max():.3f}]")

    # 门控残差块演示
    print(f"\n门控残差块:")
    gated_residual = GatedResidualBlock(hidden_dim)
    residual_output = gated_residual(x)

    print(f"  输入形状: {x.shape}")
    print(f"  输出形状: {residual_output.shape}")
    print(f"  门控权重均值: {torch.sigmoid(gated_residual.gate_weight).mean().item():.3f}")

    # 多头门控注意力演示
    print(f"\n多头门控注意力:")
    gated_attention = MultiHeadGatedAttention(hidden_dim, n_heads=8)
    attention_output = gated_attention(x)

    print(f"  输入形状: {x.shape}")
    print(f"  输出形状: {attention_output.shape}")
    print(f"  注意力门控均值: {torch.sigmoid(gated_attention.gate).mean().item():.3f}")

    # 门控效果分析
    print(f"\n门控效果分析:")

    # 测试不同输入的门控行为
    test_inputs = [
        torch.randn(1, hidden_dim),  # 随机输入
        torch.zeros(1, hidden_dim),  # 零输入
        torch.ones(1, hidden_dim)    # 单位输入
    ]

    input_names = ["随机输入", "零输入", "单位输入"]

    for input_tensor, name in zip(test_inputs, input_names):
        glu_out = glu(input_tensor)
        gate_values = torch.sigmoid(glu.linear(input_tensor).chunk(2, dim=-1)[1])

        print(f"  {name}:")
        print(f"    门控值范围: [{gate_values.min():.3f}, {gate_values.max():.3f}]")
        print(f"    门控值均值: {gate_values.mean():.3f}")

    return {
        'glu': glu,
        'gated_residual': gated_residual,
        'gated_attention': gated_attention,
        'outputs': {
            'glu': glu_output,
            'residual': residual_output,
            'attention': attention_output
        }
    }

gating_demo = demonstrate_gating()

*/ 门控机制的优势

  1. /梯度流控制/:缓解梯度消失问题,改善深层网络训练
  2. /信息选择/:自适应地选择重要信息,抑制噪声
  3. /长期依赖/:在序列模型中更好地捕捉长期依赖关系
  4. /模型容量/:增加模型表达能力而不显著增加参数

*/ 应用场景

  1. /语言建模/:LSTM、GRU中的门控机制
  2. /机器翻译/:Transformer中的门控前馈网络
  3. /图像生成/:GAN中的门控卷积
  4. /推荐系统/:门控注意力网络
  5. /多模态学习/:跨模态信息门控融合

PixelShuffle

PixelShuffle(像素重排)是一种上采样操作,通过重新排列特征图的通道维度来增加空间 分辨率,广泛应用于图像超分辨率、图像生成等任务。

基本概念

PixelShuffle 通过周期性的重排操作将通道维度中的信息重新组织到空间维度,实现高效 的上采样:

$$\text{PixelShuffle}(X): \mathbb{R}^{C \times H \times W} \rightarrow \mathbb{R}^{\frac{C}{r^2} \times rH \times rW}$$

其中 \( r \) 是上采样因子。

数学定义

给定输入张量 \( X \in \mathbb{R}^{C \times H \times W} \) 和上采样因子 \( r \),输出计算为:

$$Y_{c,i,j} = X_{c \cdot r^2 + (i \mod r) \cdot r + (j \mod r), \lfloor i/r \rfloor, \lfloor j/r \rfloor}$$

更直观地,将输入通道维度视为 \( r \times r \) 个块,每个块对应输出特征图的一个空间位置。

实现原理

import torch
import torch.nn as nn
import torch.nn.functional as F

def manual_pixel_shuffle(x, upscale_factor):
    """
    手动实现 PixelShuffle
    Args:
        x: 输入张量 [B, C, H, W]
        upscale_factor: 上采样因子
    Returns:
        重排后的张量 [B, C/(r^2), r*H, r*W]
    """
    batch_size, channels, height, width = x.shape
    r = upscale_factor

    # 检查通道数是否可被 r^2 整除
    assert channels % (r * r) == 0, f"通道数 {channels} 必须能被 {r*r} 整除"

    # 重塑为 [B, C/(r^2), r, r, H, W]
    out_channels = channels // (r * r)
    x_reshaped = x.view(batch_size, out_channels, r, r, height, width)

    # 置换维度为 [B, C/(r^2), H, r, W, r]
    x_permuted = x_reshaped.permute(0, 1, 4, 2, 5, 3)

    # 重塑为 [B, C/(r^2), r*H, r*W]
    output = x_permuted.contiguous().view(batch_size, out_channels, height * r, width * r)

    return output

class PixelShuffleBlock(nn.Module):
    """PixelShuffle 上采样块"""

    def __init__(self, in_channels, out_channels, upscale_factor):
        super().__init__()
        self.upscale_factor = upscale_factor

        # 卷积层将通道数扩展到 r^2 * out_channels
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * (upscale_factor ** 2),
            kernel_size=3,
            padding=1
        )

        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.activation(x)
        x = self.pixel_shuffle(x)
        return x

class ESPCN(nn.Module):
    """ESPCN(Efficient Sub-Pixel Convolutional Network)示例"""

    def __init__(self, in_channels=3, out_channels=3, upscale_factor=2, base_channels=64):
        super().__init__()
        self.upscale_factor = upscale_factor

        # 特征提取层
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=5, padding=2),
            nn.Tanh(),
            nn.Conv2d(base_channels, base_channels // 2, kernel_size=3, padding=1),
            nn.Tanh(),
        )

        # 子像素卷积层
        self.subpixel_conv = nn.Conv2d(
            base_channels // 2,
            out_channels * (upscale_factor ** 2),
            kernel_size=3,
            padding=1
        )

        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        # 特征提取
        features = self.feature_extraction(x)

        # 子像素卷积和像素重排
        output = self.subpixel_conv(features)
        output = self.pixel_shuffle(output)

        return output

def
 demonstrate_pixel_shuffle():
    """演示 PixelShuffle 操作"""

    print("PixelShuffle 演示")
    print("=" * 50)

    # 基础 PixelShuffle 测试
    batch_size, channels, height, width = 4, 16, 8, 8
    upscale_factor = 2

    x = torch.randn(batch_size, channels, height, width)

    print(f"输入张量形状: {x.shape}")
    print(f"上采样因子: {upscale_factor}")

    # 使用 PyTorch 内置函数
    pixel_shuffle = nn.PixelShuffle(upscale_factor)
    output_pt = pixel_shuffle(x)

    print(f"PyTorch PixelShuffle 输出形状: {output_pt.shape}")

    # 手动实现验证
    output_manual = manual_pixel_shuffle(x, upscale_factor)

    print(f"手动实现输出形状: {output_manual.shape}")

    # 验证两种实现的一致性
    diff = torch.abs(output_pt - output_manual).max()
    print(f"两种实现的最大差异: {diff.item():.6f}")

    # PixelShuffle 块演示
    print(f"\nPixelShuffle 块演示:")
    shuffle_block = PixelShuffleBlock(16, 8, upscale_factor=2)
    block_output = shuffle_block(x)

    print(f"块输入形状: {x.shape}")
    print(f"块输出形状: {block_output.shape}")
    print(f"卷积层权重形状: {shuffle_block.conv.weight.shape}")

    # ESPCN 网络演示
    print(f"\nESPCN 网络演示:")
    espcn = ESPCN(in_channels=3, out_channels=3, upscale_factor=2)

    rgb_input = torch.randn(batch_size, 3, 16, 16)
    rgb_output = espcn(rgb_input)

    print(f"网络输入形状: {rgb_input.shape}")
    print(f"网络输出形状: {rgb_output.shape}")

    # 不同上采样因子的效果
    print(f"\n不同上采样因子的效果:")
    upscale_factors = [2, 3, 4]

    for factor in upscale_factors:
        # 调整输入通道数以匹配要求
        adjusted_channels = 9 * factor * factor  # 确保可被 r^2 整除
        test_input = torch.randn(batch_size, adjusted_channels, 8, 8)

        shuffle_layer = nn.PixelShuffle(factor)
        test_output = shuffle_layer(test_input)

        print(f"  上采样因子 {factor}: 输入 {test_input.shape} -> 输出 {test_output.shape}")

    return {
        'pixel_shuffle': pixel_shuffle,
        'shuffle_block': shuffle_block,
        'espcn': espcn,
        'outputs': {
            'pytorch': output_pt,
            'manual': output_manual,
            'block': block_output,
            'espcn': rgb_output
        }
    }

pixel_shuffle_demo = demonstrate_pixel_shuffle()

与其它上采样方法比较

def compare_upsampling_methods():
    """比较不同上采样方法"""

    print("上采样方法比较")
    print("=" * 50)

    batch_size, channels, height, width = 4, 32, 16, 16
    upscale_factor = 2

    x = torch.randn(batch_size, channels, height, width)

    print(f"输入形状: {x.shape}")
    print(f"目标上采样因子: {upscale_factor}")

    # 不同上采样方法
    methods = {
        '最近邻插值': lambda x: F.interpolate(x, scale_factor=upscale_factor, mode='nearest'),
        '双线性插值': lambda x: F.interpolate(x, scale_factor=upscale_factor, mode='bilinear', align_corners=False),
        '双三次插值': lambda x: F.interpolate(x, scale_factor=upscale_factor, mode='bicubic', align_corners=False),
        '转置卷积': nn.ConvTranspose2d(channels, channels, kernel_size=4, stride=upscale_factor, padding=1),
        'PixelShuffle': nn.PixelShuffle(upscale_factor)
    }

    results = {}

    print(f"\n不同上采样方法输出:")

    for name, method in methods.items():
        if name == 'PixelShuffle':
            # 调整输入通道数
            adjusted_channels = channels // (upscale_factor ** 2) * (upscale_factor ** 2)
            if adjusted_channels == 0:
                adjusted_channels = upscale_factor ** 2
            test_input = x[:, :adjusted_channels, :, :]
            output = method(test_input)
        else:
            output = method(x)

        print(f"  {name:<15} - 输出形状: {output.shape}")

        # 计算计算量(近似)
        if hasattr(method, 'weight'):
            params = method.weight.numel() + (method.bias.numel() if method.bias is not None else 0)
        else:
            params = 0

        results[name] = {
            'output': output,
            'params': params,
            'method': method
        }

    #
 计算效率比较
    print(f"\n计算效率比较 (100次前向传播):")
    import time

    for name, result in results.items():
        method = result['method']

        if name == 'PixelShuffle':
            test_input = x[:, :adjusted_channels, :, :]
        else:
            test_input = x

        # 预热
        _ = method(test_input)

        # 计时
        start = time.time()
        for _ in range(100):
            _ = method(test_input)
        elapsed = time.time() - start

        print(f"  {name:<15}: {elapsed:.4f}s, 参数量: {result['params']:,}")

    # 输出质量比较(使用简单的图像相似度指标)
    print(f"\n输出质量比较 (使用PSNR):")

    # 创建简单的测试图像
    test_image = torch.zeros(1, 1, 8, 8)
    test_image[0, 0, 3:5, 3:5] = 1.0  # 中心方块

    target_size = (16, 16)

    for name, method in methods.items():
        if name == 'PixelShuffle':
            # 为 PixelShuffle 准备合适的输入
            ps_input = torch.randn(1, 4, 8, 8)  # 4 = 1 * 2^2
            output = method(ps_input)
        else:
            output = method(test_image)

        # 调整到目标尺寸(如果需要)
        if output.shape[-2:] != target_size:
            output = F.interpolate(output, size=target_size, mode='bilinear')

        # 计算与双线性插值的差异(作为
质量参考)
        reference = F.interpolate(test_image, size=target_size, mode='bilinear')
        mse = F.mse_loss(output, reference)
        psnr = 10 * torch.log10(1.0 / mse) if mse > 0 else float('inf')

        print(f"  {name:<15}: PSNR = {psnr:.2f} dB")

    return results

upsampling_comparison = compare_upsampling_methods()

优势与特点

*/ 优势

  1. /计算效率/:相比转置卷积,计算量更小
  2. /无棋盘效应/:减少上采样过程中的棋盘状伪影
  3. /端到端可训练/:支持梯度反向传播
  4. /内存友好/:不需要存储大的卷积核

*/ 特点

  1. /通道要求/:输入通道数必须能被 \( r^2 \) 整除
  2. /信息重排/:通过通道维度的重排实现空间上采样
  3. /结合卷积/:通常与卷积层配合使用(子像素卷积)

应用场景

  1. /图像超分辨率/:ESPCN、SRCNN 等网络
  2. /图像生成/:GAN 中的上采样层
  3. /语义分割/:解码器中的上采样操作
  4. /风格迁移/:特征图的空间分辨率提升
  5. /视频超分/:视频帧的时空分辨率提升

变体与扩展

*/ PixelUnshuffle

PixelShuffle 的逆操作,用于下采样:

$$\text{PixelUnshuffle}(X): \mathbb{R}^{C \times H \times W} \rightarrow \mathbb{R}^{C \times r^2 \times H/r \times W/r}$$
def demonstrate_pixel_unshuffle():
    """演示 PixelUnshuffle 操作"""

    print("PixelUnshuffle 演示")
    print("=" * 50)

    batch_size, channels, height, width = 4, 8, 16, 16
    downscale_factor = 2

    x = torch.randn(batch_size, channels, height, width)

    print(f"输入张量形状: {x.shape}")
    print(f"下采样因子: {downscale_factor}")

    # PixelUnshuffle
    pixel_unshuffle = nn.PixelUnshuffle(downscale_factor)
    output = pixel_unshuffle(x)

    print(f"PixelUnshuffle 输出形状: {output.shape}")

    # 验证可逆性
    pixel_shuffle = nn.PixelShuffle(downscale_factor)
    reconstructed = pixel_shuffle(output)

    print(f"重建后形状: {reconstructed.shape}")

    # 验证重建精度
    diff = torch.abs(x - reconstructed).max()
    print(f"重建最大误差: {diff.item():.6f}")

    return {
        'pixel_unshuffle': pixel_unshuffle,
        'reconstructed': reconstructed,
        'max_error': diff.item()
    }

unshuffle_demo = demonstrate_pixel_unshuffle()

*/ 深度可分离 PixelShuffle

结合深度可分离卷积的变体,进一步减少计算量:

class DepthwiseSeparablePixelShuffle(nn.Module):
    """深度可分离 PixelShuffle"""

    def __init__(self, in_channels, out_channels, upscale_factor):
        super().__init__()
        self.upscale_factor = upscale_factor

        # 深度可分离卷积
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size=3,
            padding=1, groups=in_channels
        )
        self.pointwise = nn.Conv2d(
            in_channels, out_channels * (upscale_factor ** 2),
            kernel_size=1
        )

        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        x = self.activation(x)
        x = self.pixel_shuffle(x)
        return x

def compare_pixel_shuffle_variants():
    """比较 PixelShuffle 变体"""

    print("PixelShuffle 变体比较")
    print("=" * 50)

    batch_size, in_channels, height, width = 4, 32, 16, 16
    out_channels = 16
    upscale_factor = 2

    x = torch.randn(batch_size, in_channels, height, width)

    variants = {
        '标准 PixelShuffle': PixelShuffleBlock(in_channels, out_channels, upscale_factor),
        '深度可分离 PixelShuffle': DepthwiseSeparablePixelShuffle(in_channels, out_channels, upscale_factor)
    }

    print(f"输入形状: {x.shape}")
    print(f"目标输出通道: {out_channels}")
    print(f"上采样因子: {upscale_factor}")

    results = {}

    for name, variant in variants.items():
        output = variant(x)

        # 计算参数量
        params = sum(p.numel() for p in variant.parameters())

        print(f"\n{name}:")
        print(f"  输出形状: {output.shape}")
        print(f"  参数量: {params:,}")

        results[name] = {
            'variant': variant,
            'output': output,
            'params': params
        }

    # 参数量减少比例
    standard_params = results['标准 PixelShuffle']['params']
    separable_params = results['深度可分离 PixelShuffle']['params']
    reduction = (1 - separable_params / standard_params) * 100

    print(f"\n参数量减少: {reduction:.1f}%")

    return results

variants_comparison = compare_pixel_shuffle_variants()