← 返回首页
🧠

MoE路由策略

📂 llm ⏱ 3 min 439 words

---
title: "MoE路由策略"
description: "深入解析MoE模型中的路由策略,包括Top-K路由、路由坍缩问题和动态路由机制,帮助优化专家选择"
tags: ["路由策略", "Top-K路由", "路由坍缩", "动态路由"]
category: "llm"
icon: "🧠"
---

MoE路由策略

路由机制概述

路由策略是MoE模型的核心组件,负责决定每个输入token应该被分配到哪些专家。路由的质量直接影响模型的性能、训练稳定性和计算效率。一个好的路由策略需要在专家专业化和负载均衡之间找到平衡。

Top-K路由

基础Top-K路由

Top-K路由是最常用的策略,每个token选择得分最高的K个专家。

import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKRouter(nn.Module):
    """Token-choice router: each token is sent to its ``top_k`` best experts.

    A bias-free linear gate produces one score per expert. For every token
    the ``top_k`` largest scores are kept and renormalised with a softmax
    taken over those kept scores only.
    """

    def __init__(self, hidden_dim, num_experts, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)

    def forward(self, x):
        """Route ``x`` and return ``(weights, expert_indices, gate_logits)``.

        ``weights`` and ``expert_indices`` have ``top_k`` entries on the last
        axis; ``gate_logits`` are the raw gate scores for all experts.
        """
        gate_scores = self.gate(x)

        # torch.topk returns a (values, indices) named tuple.
        best = gate_scores.topk(self.top_k, dim=-1)

        # Normalise only across the selected experts.
        mixing_weights = torch.softmax(best.values, dim=-1)

        return mixing_weights, best.indices, gate_scores

带噪声的Top-K路由

在门控logits上添加高斯噪声可以增加训练时的探索性，从而缓解（但不能完全防止）路由坍缩；噪声只应在训练阶段启用，推理时应关闭。

class NoisyTopKRouter(nn.Module):
    """Top-k router that perturbs the gate logits with Gaussian noise.

    The noise encourages exploration of under-used experts during training.
    It is only applied in training mode so that inference stays deterministic.
    """

    def __init__(self, hidden_dim, num_experts, top_k=2, noise_std=0.1):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
        self.top_k = top_k
        self.noise_std = noise_std

    def forward(self, x, training=None):
        """Route ``x`` and return ``(weights, expert_indices, logits)``.

        Args:
            x: Input activations whose last axis has size ``hidden_dim``.
            training: Whether to add noise. ``None`` (the default) follows the
                module's own ``self.training`` flag, so ``model.eval()``
                disables the noise; pass ``True``/``False`` to override.

        Returns:
            A tuple of the softmax weights over the selected experts, the
            selected expert indices, and the logits that were used for the
            selection (noisy when noise was applied).
        """
        if training is None:
            # Respect nn.Module.train()/eval() instead of always adding noise.
            training = self.training

        logits = self.gate(x)

        if training:
            # Zero-mean Gaussian noise with standard deviation ``noise_std``.
            noise = torch.randn_like(logits) * self.noise_std
            logits = logits + noise

        topk_values, topk_indices = torch.topk(logits, self.top_k, dim=-1)
        topk_weights = F.softmax(topk_values, dim=-1)

        return topk_weights, topk_indices, logits

路由坍缩问题

什么是路由坍缩

路由坍缩是指所有或大部分token都被路由到少数几个专家,导致其他专家得不到训练。

def detect_routing_collapse(gate_probs, threshold=0.8):
    """Detect routing collapse from per-token gate probabilities.

    Every token is attributed to its highest-probability expert; the share of
    tokens each expert receives is then compared against ``threshold``.

    Args:
        gate_probs: Gate probabilities with experts on the last axis, e.g.
            ``[batch, seq_len, num_experts]``.
        threshold: Collapse is reported when a single expert receives more
            than this fraction of all tokens.

    Returns:
        A dict with ``collapse_detected`` (0-dim bool tensor),
        ``max_expert_usage`` / ``min_expert_usage`` (floats, fraction of
        tokens) and ``usage_ratio`` (max / min, with a small epsilon guarding
        against division by zero).
    """
    num_experts = gate_probs.shape[-1]

    # Top-1 expert per token, flattened over all leading axes.
    top1_expert = gate_probs.argmax(dim=-1).reshape(-1)

    # One-hot the assignments and average to get each expert's token share.
    expert_usage = F.one_hot(top1_expert, num_classes=num_experts).float().mean(dim=0)

    max_usage = expert_usage.max()
    min_usage = expert_usage.min()

    collapse_detected = max_usage > threshold

    return {
        "collapse_detected": collapse_detected,
        "max_expert_usage": max_usage.item(),
        "min_expert_usage": min_usage.item(),
        "usage_ratio": (max_usage / (min_usage + 1e-8)).item()
    }

路由坍缩的成因

# Common causes of routing collapse, mapping a short cause name to a
# one-sentence explanation (both are user-facing Chinese strings).
causes = {
    "初始化问题": "路由器权重初始化不当,导致某些专家获得初始优势",
    "梯度不平衡": "被频繁选择的专家获得更多梯度更新,形成正反馈",
    "数据偏差": "训练数据分布不均,某些模式过于常见",
    "学习率过高": "路由器参数更新过快,过早锁定在次优解"
}

动态路由策略

自适应Top-K

根据输入的复杂度动态调整选择的专家数量。

class AdaptiveTopKRouter(nn.Module):
    """Router that picks a per-token number of experts from a complexity score.

    A second linear head estimates a complexity value in (0, 1) for each
    token; that value is mapped to an expert count ``k`` between ``min_k``
    and ``max_k``, and the token is routed to its ``k`` best experts.
    """

    def __init__(self, hidden_dim, num_experts, min_k=1, max_k=4):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
        self.complexity_gate = nn.Linear(hidden_dim, 1)
        self.min_k = min_k
        self.max_k = max_k

    def forward(self, x):
        """Route ``x`` with a token-dependent number of experts.

        Args:
            x: Input activations whose last axis has size ``hidden_dim``; any
                number of leading axes is accepted.

        Returns:
            ``(output_weights, output_indices, logits)`` where

            * ``output_weights`` has the shape of ``logits``; for each token
              the first ``k`` slots hold the softmax weights of its selected
              experts (ordered by descending score) and the rest are zero.
              Slot ``j`` pairs with ``output_indices[..., j]``; it is NOT
              indexed by expert id.
            * ``output_indices`` has ``max_k`` slots on the last axis, holding
              the selected expert ids in the first ``k`` slots and 0 padding
              afterwards (use the non-zero weights to tell padding apart
              from expert 0).
            * ``logits`` are the raw gate scores.
        """
        logits = self.gate(x)
        num_experts = logits.shape[-1]

        # Complexity in (0, 1) -> integer expert count. The cast truncates,
        # so max_k is only reached when the sigmoid saturates at 1.0.
        complexity = torch.sigmoid(self.complexity_gate(x)).squeeze(-1)
        k = self.min_k + ((self.max_k - self.min_k) * complexity).int()

        # Take the widest top-k once, then mask per token. This replaces a
        # Python loop that forced one host sync (.item()) per token.
        width = min(self.max_k, num_experts)
        topk_values, topk_indices = torch.topk(logits, width, dim=-1)

        slots = torch.arange(width, device=logits.device)
        active = slots < k.unsqueeze(-1)

        # A large finite negative (rather than -inf) keeps the softmax free of
        # NaNs even for a token with no active slot (possible when min_k=0).
        fill = torch.finfo(topk_values.dtype).min
        weights = F.softmax(topk_values.masked_fill(~active, fill), dim=-1)
        weights = weights.masked_fill(~active, 0.0)
        indices = topk_indices.masked_fill(~active, 0)

        # Pad to the documented widths; both outputs live on logits' device.
        output_weights = F.pad(weights, (0, num_experts - width))
        output_indices = F.pad(indices, (0, self.max_k - width))

        return output_weights, output_indices, logits

专家选择路由

反转路由方向,让专家选择token而非token选择专家。

class ExpertChoiceRouter(nn.Module):
    """Expert-choice router: each expert picks its own top-scoring tokens.

    Because every expert selects the same number of tokens, the load is
    balanced by construction.
    """

    def __init__(self, hidden_dim, num_experts, capacity_factor=1.25):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
        self.capacity_factor = capacity_factor

    def forward(self, x):
        """Let every expert select its tokens from ``x``.

        Args:
            x: Input of shape ``[batch, seq_len, hidden_dim]``.

        Returns:
            A list with one dict per expert containing

            * ``"tokens"``: the selected token vectors,
              ``[batch, tokens_per_expert, hidden_dim]``;
            * ``"weights"``: their routing probabilities,
              ``[batch, tokens_per_expert]``;
            * ``"indices"``: their positions along the sequence axis,
              ``[batch, tokens_per_expert]`` (needed to scatter expert
              outputs back to the original token order).
        """
        logits = self.gate(x)  # [batch, seq_len, num_experts]
        probs = F.softmax(logits, dim=1)  # softmax over the token axis

        seq_len = x.shape[1]
        hidden_dim = x.shape[-1]
        num_experts = logits.shape[-1]

        # Per-expert capacity: seq_len * capacity_factor token slots are
        # shared among all experts. Without the division every expert would
        # ask for more tokens than exist and torch.topk would raise.
        tokens_per_expert = int(seq_len * self.capacity_factor / num_experts)
        tokens_per_expert = min(seq_len, max(1, tokens_per_expert))

        # One top-k along the token axis serves all experts at once.
        top_probs, top_tokens = torch.topk(probs, tokens_per_expert, dim=1)

        expert_outputs = []
        for e in range(num_experts):
            token_idx = top_tokens[:, :, e]
            gather_idx = token_idx.unsqueeze(-1).expand(-1, -1, hidden_dim)

            expert_outputs.append({
                "tokens": torch.gather(x, 1, gather_idx),
                "weights": top_probs[:, :, e],
                "indices": token_idx
            })

        return expert_outputs

负载均衡路由

辅助负载均衡损失

def load_balancing_loss(gate_probs, num_experts):
    """Auxiliary load-balancing loss for a top-1 style router.

    Computes ``num_experts * sum_i(f_i * P_i)`` where ``f_i`` is the fraction
    of tokens whose highest-probability expert is ``i`` and ``P_i`` is the
    mean gate probability assigned to expert ``i``. The value is 1.0 for
    perfectly uniform routing and grows as routing concentrates.

    Args:
        gate_probs: Gate probabilities with experts on the last axis, e.g.
            ``[batch, seq_len, num_experts]``.
        num_experts: Number of experts, used as the scaling factor.

    Returns:
        A 0-dim tensor. Gradients flow through the mean probabilities only;
        the dispatch fractions come from ``argmax`` and are not differentiable.
    """
    # Flatten every leading axis into a single token axis.
    probs = gate_probs.reshape(-1, gate_probs.shape[-1])

    # P_i: mean routing probability per expert.
    router_probs = probs.mean(dim=0)  # [num_experts]

    # f_i: fraction of tokens dispatched (top-1) to each expert.
    top1_expert = probs.argmax(dim=-1)
    dispatch = F.one_hot(top1_expert, num_classes=probs.shape[-1]).to(probs.dtype)
    expert_usage = dispatch.mean(dim=0)  # [num_experts]

    aux_loss = num_experts * (router_probs * expert_usage).sum()

    return aux_loss

Z-Loss稳定训练

def z_loss(gate_logits, z_loss_coeff=0.001):
    """Router z-loss: penalise large gate logits to stabilise training.

    The loss is the mean over tokens of the squared log-partition function,
    ``mean(logsumexp(logits, dim=-1) ** 2)``, scaled by ``z_loss_coeff``.

    Args:
        gate_logits: Raw router logits with experts on the last axis, e.g.
            ``[batch, seq_len, num_experts]``.
        z_loss_coeff: Weight of the z-loss term.

    Returns:
        A 0-dim float tensor.
    """
    # Compute in float32 so the exponentials stay stable under mixed precision.
    log_z = torch.logsumexp(gate_logits.float(), dim=-1)

    z = log_z.square().mean()

    return z_loss_coeff * z

路由策略对比

| 策略 | 优点 | 缺点 | 适用场景 |
| --- | --- | --- | --- |
| Top-K | 简单高效 | 易坍缩 | 通用 |
| 带噪声Top-K | 改善探索 | 训练不稳定 | 大规模训练 |
| 自适应Top-K | 灵活 | 计算开销大 | 复杂任务 |
| 专家选择 | 负载均衡好 | 实现复杂 | 负载敏感场景 |

总结

路由策略是MoE模型设计的关键。Top-K路由是基础,但需要配合负载均衡损失和噪声注入来防止路由坍缩。动态路由和专家选择路由提供了更灵活的方案,但实现复杂度更高。选择合适的路由策略需要根据具体任务和资源约束进行权衡。