← 返回首页
🧠

联邦学习LLM

📂 llm ⏱ 7 min 1266 words

---
title: "联邦学习LLM"
description: "介绍联邦学习在大语言模型中的应用,包括安全聚合、通信优化、联邦微调等关键技术"
tags: ["联邦学习", "安全聚合", "通信优化", "联邦微调"]
category: "llm"
icon: "🧠"
---

联邦学习LLM

联邦学习概述

联邦学习(Federated Learning)是一种分布式机器学习方法,允许多个参与方在不共享原始数据的情况下协作训练模型:各方只交换模型参数或参数更新,原始数据始终保留在本地。在LLM领域,联邦学习有助于打破数据孤岛,同时满足隐私保护与数据合规要求。

联邦学习的基本流程

import torch
import torch.nn as nn
from typing import List, Dict, Any

class FederatedLearningSystem:
    """Simulate FedAvg rounds: distribute, train locally, aggregate.

    Local training is a placeholder (see ``_local_training``); the class only
    demonstrates the structure of a federated round.
    """

    def __init__(self, num_clients: int, global_model: nn.Module):
        self.num_clients = num_clients
        self.global_model = global_model
        self.client_models = []  # not used by the methods of this class
        # Uniform per-client weights. federated_averaging renormalizes them over
        # the clients that actually report a given key, so partial participation
        # does not shrink the global model toward zero.
        self.aggregation_weights = [1.0 / num_clients] * num_clients

    def federated_averaging(self, client_updates: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Aggregate client state dicts by weighted averaging (FedAvg).

        Client ``i`` is weighted by ``self.aggregation_weights[i]``. For each key
        the weights are renormalized over the clients whose update contains it.

        Args:
            client_updates: one state dict per participating client.

        Returns:
            A new state dict with the same keys and dtypes as the global model.
            Keys that no client reported keep a copy of the current global value.
            Non-floating (integer) buffers are rounded after averaging.

        Raises:
            ValueError: if there are more updates than aggregation weights.
        """
        global_dict = self.global_model.state_dict()
        if len(client_updates) > len(self.aggregation_weights):
            raise ValueError(
                f"got {len(client_updates)} client updates but only "
                f"{len(self.aggregation_weights)} aggregation weights"
            )
        weights = self.aggregation_weights[: len(client_updates)]

        aggregated: Dict[str, torch.Tensor] = {}
        for key, global_value in global_dict.items():
            contributions = [
                (weight, update[key])
                for weight, update in zip(weights, client_updates)
                if key in update
            ]
            total_weight = sum(weight for weight, _ in contributions)
            if total_weight <= 0:
                # No client (with non-zero weight) reported this key: keep it
                # instead of overwriting it with zeros.
                aggregated[key] = global_value.clone()
                continue

            accumulator = torch.zeros_like(global_value, dtype=torch.float32)
            for weight, value in contributions:
                accumulator += (weight / total_weight) * value.to(
                    device=accumulator.device, dtype=torch.float32
                )
            if not global_value.is_floating_point():
                # Integer buffers: round rather than truncate toward zero.
                accumulator = accumulator.round()
            aggregated[key] = accumulator.to(global_value.dtype)

        return aggregated

    def distribute_global_model(self) -> Dict[str, torch.Tensor]:
        """Return an independent copy of the global model's state dict.

        Tensors are cloned: ``dict.copy()`` would only copy the mapping, and
        in-place edits by a client would then leak into the global model.
        """
        return {key: value.clone() for key, value in self.global_model.state_dict().items()}

    def train_round(self, client_data: List[Any], local_epochs: int = 5) -> None:
        """Run one federated round and load the aggregate into the global model.

        Args:
            client_data: one opaque data object per participating client.
            local_epochs: epochs passed to each client's local training.
        """
        # 1. Distribute the global model.
        global_state = self.distribute_global_model()

        # 2. Local training on every participating client.
        client_updates = []
        for data in client_data:
            client_model = self._create_client_model(global_state)
            client_updates.append(self._local_training(client_model, data, local_epochs))

        # 3. Aggregate and install the new global state.
        new_global_state = self.federated_averaging(client_updates)
        self.global_model.load_state_dict(new_global_state)

    def _create_client_model(self, state_dict: Dict[str, torch.Tensor]) -> nn.Module:
        """Deep-copy the global model and load ``state_dict`` into the copy."""
        client_model = self._copy_model(self.global_model)
        client_model.load_state_dict(state_dict)
        return client_model

    def _local_training(self, model: nn.Module, data: Any, epochs: int) -> Dict[str, torch.Tensor]:
        """Run (placeholder) local training and return the client's state dict.

        The full local state dict is returned, not a delta; federated_averaging
        averages full states.
        """
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # for a real training loop
        criterion = nn.CrossEntropyLoss()  # for a real training loop

        model.train()
        for _ in range(epochs):
            # Placeholder: a real loop would iterate over `data`, compute
            # criterion(model(x), y), backpropagate and call optimizer.step().
            pass

        return model.state_dict()

    def _copy_model(self, model: nn.Module) -> nn.Module:
        """Return a deep copy of ``model``."""
        import copy
        return copy.deepcopy(model)

安全聚合

同态加密聚合

import numpy as np
from typing import List

class HomomorphicEncryption:
    """Toy stand-in for an additively homomorphic scheme (e.g. Paillier).

    WARNING: this provides NO confidentiality. ``encrypt`` multiplies by the
    public constant ``g`` and ``decrypt`` divides by it, so anyone who holds
    the public key can invert it. It only mimics the API shape
    (encrypt / add ciphertexts / scalar-multiply / decrypt) for demos.
    Works element-wise on floats and on NumPy arrays.
    """

    def __init__(self):
        self.public_key = None
        self.private_key = None

    def key_generation(self, key_size: int = 2048):
        """Install hard-coded placeholder keys.

        ``key_size`` is accepted for API compatibility and ignored. A real
        implementation would use an additively homomorphic scheme such as
        Paillier (RSA is only multiplicatively homomorphic).
        """
        self.public_key = {"n": 1234567890, "g": 12345}
        self.private_key = {"lambda": 12345, "mu": 12345}

    def _require_keys(self) -> None:
        """Raise a clear error when keys have not been generated yet."""
        if self.public_key is None:
            raise RuntimeError("keys not generated; call key_generation() first")

    def encrypt(self, plaintext: float) -> float:
        """Return the toy "ciphertext" ``plaintext * g``."""
        self._require_keys()
        return plaintext * self.public_key["g"]

    def decrypt(self, ciphertext: float) -> float:
        """Invert ``encrypt`` by dividing by ``g``."""
        self._require_keys()
        return ciphertext / self.public_key["g"]

    def add_encrypted(self, enc1: float, enc2: float) -> float:
        """Add two ciphertexts; decrypts to the sum of the plaintexts."""
        return enc1 + enc2

    def multiply_encrypted(self, enc: float, scalar: float) -> float:
        """Scale a ciphertext by a plaintext scalar."""
        return enc * scalar

class SecureAggregation:
    """Aggregate client updates via an encrypt -> add -> decrypt pipeline.

    NOTE(review): the HomomorphicEncryption helper this relies on is a
    simplified demo scheme; treat this class as an illustration of the data
    flow, not as a secure aggregation protocol.
    """

    def __init__(self, num_clients: int):
        # Kept for API compatibility; averaging divides by the number of
        # updates actually received, not by this configured count.
        self.num_clients = num_clients
        self.he = HomomorphicEncryption()
        self.he.key_generation()

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Convert to a float32 NumPy array; also works for GPU / grad tensors."""
        return tensor.detach().cpu().float().numpy()

    def _decrypted_sum(self, values: List[torch.Tensor], weights=None):
        """Encrypt each value (optionally scaled by ``weights[i]``), sum, decrypt."""
        encrypted_sum = None
        for index, value in enumerate(values):
            encrypted = self.he.encrypt(self._to_numpy(value))
            if weights is not None:
                encrypted = self.he.multiply_encrypted(encrypted, weights[index])
            if encrypted_sum is None:
                encrypted_sum = encrypted
            else:
                encrypted_sum = self.he.add_encrypted(encrypted_sum, encrypted)
        return self.he.decrypt(encrypted_sum)

    def secure_aggregate(self, client_updates: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Return the unweighted mean of ``client_updates``, keyed like the first update.

        Returns an empty dict when there are no updates.
        """
        if not client_updates:
            return {}
        num_updates = len(client_updates)
        aggregated = {}
        for key in client_updates[0].keys():
            total = self._decrypted_sum([update[key] for update in client_updates])
            aggregated[key] = torch.tensor(total / num_updates)
        return aggregated

    def secure_weighted_aggregate(self, client_updates: List[Dict[str, torch.Tensor]],
                                 weights: List[float]) -> Dict[str, torch.Tensor]:
        """Return ``sum_i weights[i] * client_updates[i]`` per key.

        The weights are applied as given (not normalized). Returns an empty
        dict when there are no updates.
        """
        if not client_updates:
            return {}
        aggregated = {}
        for key in client_updates[0].keys():
            total = self._decrypted_sum([update[key] for update in client_updates], weights)
            aggregated[key] = torch.tensor(total)
        return aggregated

安全多方计算聚合

class SecureMultiPartyAggregation:
    """Additive secret-sharing aggregation, simulated in a single process.

    Shares are real-valued Gaussian noise, so reconstruction is exact only up
    to floating-point rounding, and masking is weak for large-magnitude values.
    This is a teaching sketch, not a secure MPC protocol.
    """

    def __init__(self, num_parties: int):
        if num_parties < 1:
            raise ValueError("num_parties must be at least 1")
        self.num_parties = num_parties

    def secret_sharing(self, secret: torch.Tensor) -> List[torch.Tensor]:
        """Split ``secret`` into ``num_parties`` shares whose sum is ``secret``.

        Non-floating tensors (e.g. integer buffers) are converted to float first,
        since ``randn_like`` does not support integer dtypes.
        """
        if not secret.is_floating_point():
            secret = secret.float()
        shares = []
        remaining = secret.clone()
        for _ in range(self.num_parties - 1):
            share = torch.randn_like(secret)
            shares.append(share)
            remaining -= share
        shares.append(remaining)
        return shares

    def secure_addition(self, shares_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """Add shares party-wise: result[p] = sum over clients of client_shares[p].

        Raises:
            ValueError: if ``shares_list`` is empty.
        """
        if not shares_list:
            raise ValueError("shares_list must contain at least one client's shares")
        result_shares = []
        for party in range(self.num_parties):
            share_sum = torch.zeros_like(shares_list[0][party])
            for client_shares in shares_list:
                share_sum += client_shares[party]
            result_shares.append(share_sum)
        return result_shares

    def reveal(self, shares: List[torch.Tensor]) -> torch.Tensor:
        """Reconstruct the secret by summing all shares."""
        return sum(shares)

    def mpc_aggregate(self, client_updates: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Return the mean of ``client_updates`` computed through secret sharing.

        Keys are taken from the first update. Non-floating inputs are averaged
        in float and rounded back to their original dtype. Returns an empty dict
        when there are no updates.
        """
        if not client_updates:
            return {}
        num_updates = len(client_updates)
        aggregated = {}
        for key in client_updates[0].keys():
            template = client_updates[0][key]
            # Each client secret-shares its own update.
            client_shares = [self.secret_sharing(update[key]) for update in client_updates]
            # Parties add the shares they hold, then the sum is revealed.
            result_shares = self.secure_addition(client_shares)
            mean = self.reveal(result_shares) / num_updates
            if not template.is_floating_point():
                mean = mean.round().to(template.dtype)
            aggregated[key] = mean
        return aggregated

通信优化

模型压缩

class CommunicationOptimizer:
    """Reduce update size with top-k sparsification plus simulated quantization."""

    def __init__(self, compression_rate: float = 0.1):
        # Fraction of entries (by absolute value) kept by sparsify_update.
        self.compression_rate = compression_rate

    def quantize_update(self, update: torch.Tensor, bits: int = 8) -> torch.Tensor:
        """Simulate uniform ``bits``-bit quantization (quantize, then dequantize).

        Returns a tensor of the same shape with values snapped to one of
        ``2**bits`` evenly spaced levels between ``update.min()`` and
        ``update.max()``. Empty or constant tensors are returned unchanged
        (a constant tensor would otherwise produce 0/0 = NaN).
        """
        if update.numel() == 0:
            return update.clone()
        min_val = update.min()
        max_val = update.max()
        value_range = max_val - min_val
        if value_range == 0:
            return update.clone()

        scale = value_range / (2 ** bits - 1)
        quantized = torch.round((update - min_val) / scale)
        return quantized * scale + min_val

    def sparsify_update(self, update: torch.Tensor) -> torch.Tensor:
        """Keep the ``compression_rate`` fraction of largest-magnitude entries.

        Exactly ``k = max(1, int(numel * compression_rate))`` entries (capped at
        ``numel``) are kept; all others are zeroed. Shape and dtype are preserved.
        """
        numel = update.numel()
        if numel == 0:
            return update.clone()
        k = min(numel, max(1, int(numel * self.compression_rate)))

        # topk selects the k LARGEST magnitudes (kthvalue would give the k-th
        # smallest, keeping the wrong entries).
        top_indices = torch.topk(update.abs().flatten(), k).indices
        mask = torch.zeros(numel, dtype=torch.bool, device=update.device)
        mask[top_indices] = True

        return update * mask.view_as(update)

    def compress_update(self, update: torch.Tensor) -> Dict[str, Any]:
        """Sparsify and quantize ``update`` into a compact dict.

        Returns:
            ``{"indices": (nnz, ndim) coordinates from torch.nonzero,
            "values": quantized non-zero values in the same row-major order,
            "shape": original shape, "compression_rate": self.compression_rate}``.
        """
        sparse_update = self.sparsify_update(update)

        nonzero_mask = sparse_update != 0
        non_zero_indices = torch.nonzero(nonzero_mask)
        non_zero_values = sparse_update[nonzero_mask]

        quantized_values = self.quantize_update(non_zero_values)

        return {
            "indices": non_zero_indices,
            "values": quantized_values,
            "shape": update.shape,
            "compression_rate": self.compression_rate,
        }

    def decompress_update(self, compressed: Dict[str, Any]) -> torch.Tensor:
        """Rebuild the dense tensor produced by ``compress_update``."""
        values = compressed["values"]
        indices = compressed["indices"]
        dense = torch.zeros(compressed["shape"], dtype=values.dtype, device=values.device)

        if dense.dim() == 0:
            # A scalar has no coordinates; nonzero yields an (nnz, 0) index tensor.
            if values.numel() > 0:
                dense.fill_(values.reshape(-1)[0])
        elif indices.numel() > 0:
            # indices is (nnz, ndim): transpose into one index tensor per dimension.
            dense[tuple(indices.t())] = values

        return dense

    def gradient_accumulation(self, local_updates: List[torch.Tensor],
                            accumulation_steps: int) -> torch.Tensor:
        """Return ``sum(local_updates) / accumulation_steps``."""
        accumulated = torch.zeros_like(local_updates[0])
        for update in local_updates:
            accumulated += update / accumulation_steps
        return accumulated

通信协议优化

class FederatedCommunicationProtocol:
    """Client selection, staleness-weighted async aggregation and delta transfer."""

    def __init__(self, num_clients: int, communication_rounds: int):
        self.num_clients = num_clients
        self.communication_rounds = communication_rounds
        self.bandwidth_limits = [1000] * num_clients  # simulated per-client bandwidth

    def adaptive_client_selection(self, client_performance: List[float]) -> List[int]:
        """Return indices of the best-performing clients, best first.

        Selects the top ``num_clients // 2`` clients, but always at least one
        (the plain ``// 2`` would select nobody when ``num_clients == 1``).
        """
        sorted_indices = np.argsort(client_performance)[::-1]
        num_selected = max(1, self.num_clients // 2)
        return sorted_indices[:num_selected].tolist()

    def asynchronous_update(self, client_updates: Dict[int, torch.Tensor],
                          staleness: Dict[int, int]) -> torch.Tensor:
        """Average updates with weight ``1 / (1 + staleness)``, normalized to sum to 1.

        Staler updates get smaller weights. ``staleness`` must contain every
        client id present in ``client_updates``.

        Raises:
            ValueError: if ``client_updates`` is empty.
        """
        if not client_updates:
            raise ValueError("client_updates must not be empty")

        weights = {
            client_id: 1.0 / (1 + staleness[client_id])
            for client_id in client_updates
        }
        total_weight = sum(weights.values())

        aggregated = torch.zeros_like(next(iter(client_updates.values())))
        for client_id, update in client_updates.items():
            aggregated += weights[client_id] / total_weight * update
        return aggregated

    def differential_compression(self, update: torch.Tensor,
                               previous_update: torch.Tensor) -> torch.Tensor:
        """Transmit only the difference ``update - previous_update``."""
        delta = update - previous_update
        return self.compress_update(delta)

    def compress_update(self, update: torch.Tensor) -> torch.Tensor:
        """Placeholder for a real codec: returns a lossless copy of ``update``.

        The previous stub multiplied by 0.1, which silently scaled the update
        down tenfold instead of reducing its size.
        """
        return update.clone()

联邦微调

联邦LoRA微调

class FederatedLoRA:
    """Federated fine-tuning that only exchanges low-rank (LoRA) factors.

    For each 2-D parameter whose name contains ``'weight'`` two factors are
    kept: ``"<name>_lora_A"`` of shape ``(out, rank)`` and ``"<name>_lora_B"``
    of shape ``(rank, in)``. Their product is the weight delta. Local training
    is a placeholder.
    """

    def __init__(self, base_model: nn.Module, rank: int = 8):
        self.base_model = base_model
        self.rank = rank
        self.lora_weights = self._initialize_lora_weights()

    def _initialize_lora_weights(self) -> Dict[str, torch.Tensor]:
        """Create small random A/B factors for every 2-D ``weight`` parameter."""
        lora_weights = {}
        for name, param in self.base_model.named_parameters():
            if 'weight' in name and param.dim() == 2:
                lora_weights[f"{name}_lora_A"] = torch.randn(param.size(0), self.rank) * 0.01
                lora_weights[f"{name}_lora_B"] = torch.randn(self.rank, param.size(1)) * 0.01
        return lora_weights

    def federated_lora_training(self, client_data: List[Any],
                               local_epochs: int = 3) -> Dict[str, torch.Tensor]:
        """Train LoRA factors on every client and return their average.

        Only LoRA factors are exchanged, which keeps communication small. The
        result is keyed like ``self.lora_weights`` and can be passed directly
        to ``apply_lora_to_model``.
        """
        client_lora_updates = [
            self._local_lora_training(data, local_epochs) for data in client_data
        ]
        return self._aggregate_lora_weights(client_lora_updates)

    @staticmethod
    def _param_key(name: str) -> str:
        """Map a LoRA weight name to a ParameterDict-safe key (no '.')."""
        return name.replace(".", "__")

    def _local_lora_training(self, data: Any, epochs: int) -> Dict[str, torch.Tensor]:
        """Train a private copy of the LoRA factors and return them.

        The returned dict uses the same keys as ``self.lora_weights``.
        """
        lora_adapter = self._create_lora_adapter()

        for _ in range(epochs):
            # Placeholder: a real loop would train lora_adapter on `data`.
            pass

        return {
            name: lora_adapter.lora_weights[self._param_key(name)].detach().clone()
            for name in self.lora_weights
        }

    def _create_lora_adapter(self) -> nn.Module:
        """Build a module holding an independent, trainable copy of the LoRA factors."""
        param_key = self._param_key

        class LoRAAdapter(nn.Module):
            def __init__(self, lora_weights):
                super().__init__()
                # ParameterDict forbids '.' in keys (nested module names contain
                # dots), and cloning keeps clients from mutating shared tensors.
                self.lora_weights = nn.ParameterDict({
                    param_key(name): nn.Parameter(value.detach().clone())
                    for name, value in lora_weights.items()
                })

            def forward(self, x):
                # Simplified: identity forward pass.
                return x

        return LoRAAdapter(self.lora_weights)

    def _aggregate_lora_weights(self, client_updates: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Return the element-wise mean of the clients' LoRA factors (``{}`` if none)."""
        if not client_updates:
            return {}
        aggregated = {}
        for key in client_updates[0].keys():
            aggregated[key] = torch.zeros_like(client_updates[0][key])
            for client_update in client_updates:
                aggregated[key] += client_update[key] / len(client_updates)
        return aggregated

    def apply_lora_to_model(self, lora_weights: Dict[str, torch.Tensor]):
        """Merge ``lora_A @ lora_B`` into each matching base-model weight in place.

        A parameter ``name`` is updated only when both ``"<name>_lora_A"`` and
        ``"<name>_lora_B"`` are present in ``lora_weights``.
        """
        with torch.no_grad():
            for name, param in self.base_model.named_parameters():
                key_a = f"{name}_lora_A"
                key_b = f"{name}_lora_B"
                if key_a not in lora_weights or key_b not in lora_weights:
                    continue
                delta_weight = lora_weights[key_a] @ lora_weights[key_b]
                param.add_(delta_weight.to(device=param.device, dtype=param.dtype))

联邦提示微调

class FederatedPromptTuning:
    """Federated soft-prompt tuning: clients train prompt embeddings, server averages.

    Prompt embeddings have shape ``(num_prompts, prompt_length, embedding_dim)``.
    Local training is a placeholder.
    """

    def __init__(self, base_model: nn.Module, num_prompts: int = 10, prompt_length: int = 20,
                 embedding_dim: int = 768):
        self.base_model = base_model
        self.num_prompts = num_prompts
        self.prompt_length = prompt_length
        # Previously hard-coded to 768; should match the model's hidden size.
        self.embedding_dim = embedding_dim
        self.prompt_embeddings = self._initialize_prompts()

    def _initialize_prompts(self) -> torch.Tensor:
        """Return random prompts of shape (num_prompts, prompt_length, embedding_dim)."""
        return torch.randn(self.num_prompts, self.prompt_length, self.embedding_dim)

    def federated_prompt_training(self, client_data: List[Any],
                                local_epochs: int = 5) -> torch.Tensor:
        """Train prompts on each client and return their average."""
        client_prompt_updates = [
            self._local_prompt_training(data, local_epochs) for data in client_data
        ]
        return self._aggregate_prompts(client_prompt_updates)

    def _local_prompt_training(self, data: Any, epochs: int) -> torch.Tensor:
        """Train a private copy of the prompts and return it detached."""
        trainable_prompts = self.prompt_embeddings.clone().requires_grad_(True)
        optimizer = torch.optim.Adam([trainable_prompts], lr=0.01)

        for _ in range(epochs):
            optimizer.zero_grad()
            # Placeholder loss, not connected to the prompts: a real loop would
            # compute the model's loss on `data` with the prompts prepended.
            loss = torch.tensor(0.0, requires_grad=True)
            loss.backward()
            optimizer.step()

        return trainable_prompts.detach()

    def _aggregate_prompts(self, client_updates: List[torch.Tensor]) -> torch.Tensor:
        """Return the mean of the client prompts.

        With no client updates, a copy of the current prompts is returned
        rather than an all-zero tensor.
        """
        if not client_updates:
            return self.prompt_embeddings.clone()
        aggregated = torch.zeros_like(self.prompt_embeddings)
        for client_update in client_updates:
            aggregated += client_update / len(client_updates)
        return aggregated

    def apply_prompts_to_model(self, prompts: torch.Tensor):
        """Wrap ``base_model.forward`` so that ``prompts`` are prepended to its input.

        All prompts are flattened into one sequence of length
        ``num_prompts * prompt_length`` and concatenated along dim 1.
        NOTE(review): the prompts are continuous embeddings, so the wrapped
        forward must receive input embeddings shaped (batch, seq, embedding_dim),
        not integer token ids. Confirm this against the concrete model.
        """
        # Always wrap the original forward, so repeated calls replace the
        # prompt instead of stacking one wrapper on top of another.
        original_forward = getattr(self, "_original_forward", None)
        if original_forward is None:
            original_forward = self.base_model.forward
            self._original_forward = original_forward

        # (num_prompts, prompt_length, dim) -> (1, num_prompts * prompt_length, dim)
        prompt_sequence = prompts.reshape(1, -1, prompts.size(-1))

        def prompted_forward(input_ids):
            batch_size = input_ids.size(0)
            prompt_tokens = prompt_sequence.to(input_ids.device).expand(batch_size, -1, -1)
            prompted_input = torch.cat([prompt_tokens, input_ids], dim=1)
            return original_forward(prompted_input)

        self.base_model.forward = prompted_forward

联邦学习评估

class FederatedLearningEvaluator:
    """Evaluation helpers for federated training (metric methods are placeholders)."""

    def __init__(self):
        self.metrics = {}

    def evaluate_global_model(self, global_model: nn.Module, test_data: Any) -> Dict[str, float]:
        """Return accuracy, loss and fairness for the global model."""
        return {
            "accuracy": self._calculate_accuracy(global_model, test_data),
            "loss": self._calculate_loss(global_model, test_data),
            "fairness": self._calculate_fairness(global_model, test_data),
        }

    def evaluate_client_contribution(self, client_updates: List[Dict[str, torch.Tensor]],
                                   global_update: Dict[str, torch.Tensor]) -> List[float]:
        """Score each client by its cosine similarity to the global update."""
        return [
            self._calculate_similarity(client_update, global_update)
            for client_update in client_updates
        ]

    def _calculate_accuracy(self, model: nn.Module, data: Any) -> float:
        """Placeholder: returns a constant accuracy."""
        return 0.85

    def _calculate_loss(self, model: nn.Module, data: Any) -> float:
        """Placeholder: returns a constant loss."""
        return 0.35

    def _calculate_fairness(self, model: nn.Module, data: Any) -> float:
        """Placeholder: returns a constant fairness score."""
        return 0.92

    def _calculate_similarity(self, update1: Dict[str, torch.Tensor],
                            update2: Dict[str, torch.Tensor]) -> float:
        """Mean per-key cosine similarity over keys present in both updates.

        Returns 0.0 when the updates share no keys. Tensors are cast to float
        because ``cosine_similarity`` rejects integer tensors (e.g. integer
        buffers in a state dict).
        """
        similarity = 0.0
        count = 0
        for key in update1.keys():
            if key not in update2:
                continue
            cos_sim = torch.nn.functional.cosine_similarity(
                update1[key].float().flatten().unsqueeze(0),
                update2[key].float().flatten().unsqueeze(0),
            ).item()
            similarity += cos_sim
            count += 1
        return similarity / count if count > 0 else 0.0

总结

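联邦学习为LLM的分布式训练提供了一种“数据不出域”的协作方案。通过安全聚合、通信优化和联邦微调等技术,可以在保护数据隐私的同时实现高效的模型训练。需要注意的是,本文代码均为演示流程的简化实现:本地训练与评估均为占位逻辑,示例中的“同态加密”只是乘以一个公开常数,不提供任何实际的安全保障;生产环境应使用经过审计的密码学库(例如 Paillier 实现或成熟的安全聚合协议)。在实际应用中,需要根据具体场景选择合适的联邦学习策略,并持续优化通信效率和模型性能。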