联邦学习LLM
---
title: "联邦学习LLM"
description: "介绍联邦学习在大语言模型中的应用,包括安全聚合、通信优化、联邦微调等关键技术"
tags: ["联邦学习", "安全聚合", "通信优化", "联邦微调"]
category: "llm"
icon: "🧠"
---
联邦学习LLM
联邦学习概述
联邦学习(Federated Learning)是一种分布式机器学习方法,允许多个参与方在不共享原始数据的情况下协作训练模型。在LLM领域,联邦学习可以解决数据孤岛、隐私保护和合规性等问题。
联邦学习的基本流程
import torch
import torch.nn as nn
from typing import List, Dict, Any
class FederatedLearningSystem:
    """Minimal FedAvg coordinator: broadcast, local training, weighted averaging.

    Attributes:
        num_clients: Number of clients the system was configured for.
        global_model: The shared model that every round starts from.
        client_models: Unused placeholder list kept for callers that read it.
        aggregation_weights: Per-client averaging weights, uniform by default.
    """

    def __init__(self, num_clients: int, global_model: nn.Module):
        self.num_clients = num_clients
        self.global_model = global_model
        self.client_models = []
        self.aggregation_weights = [1.0 / num_clients] * num_clients

    def _participant_weights(self, num_updates: int) -> List[float]:
        """Return one weight per received update.

        The configured weights are used when exactly one update per
        configured client arrived. Otherwise the updates cannot be matched
        to configured clients, so uniform weights are used instead.
        """
        if num_updates == len(self.aggregation_weights):
            return list(self.aggregation_weights)
        if num_updates == 0:
            return []
        return [1.0 / num_updates] * num_updates

    def federated_averaging(self, client_updates: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Aggregate client state dicts with weighted averaging (FedAvg).

        Weights are renormalized per key over the clients that actually
        supplied that key, so the result is always a true weighted mean.
        A key that no client supplied keeps the current global value
        instead of being overwritten with zeros.

        Args:
            client_updates: One state dict per participating client.

        Returns:
            A state dict with the same keys as the global model.
        """
        global_dict = self.global_model.state_dict()
        weights = self._participant_weights(len(client_updates))

        for key in list(global_dict):
            current = global_dict[key]
            contributions = [
                (weight, update[key])
                for weight, update in zip(weights, client_updates)
                if key in update
            ]
            total_weight = 0.0
            for weight, _ in contributions:
                total_weight += weight

            if not contributions or total_weight == 0:
                # Nothing to average: keep the global value rather than zeroing it.
                global_dict[key] = current.detach().clone()
                continue

            # Accumulate in float32, then restore the entry's own dtype.
            accumulator = torch.zeros_like(current, dtype=torch.float32)
            for weight, tensor in contributions:
                accumulator += (weight / total_weight) * tensor.to(
                    device=accumulator.device, dtype=torch.float32
                )
            global_dict[key] = accumulator.to(current.dtype)

        return global_dict

    def distribute_global_model(self) -> Dict[str, torch.Tensor]:
        """Return a snapshot of the global weights for the clients.

        Every tensor is cloned, so mutating the snapshot cannot change the
        global model (a shallow ``dict.copy()`` would share tensor storage).
        """
        return {
            key: tensor.detach().clone()
            for key, tensor in self.global_model.state_dict().items()
        }

    def train_round(self, client_data: List[Any], local_epochs: int = 5) -> None:
        """Run one federated round: distribute, train locally, aggregate.

        Args:
            client_data: One dataset per participating client.
            local_epochs: Number of local epochs each client runs.
        """
        # 1. Distribute the global model.
        global_state = self.distribute_global_model()

        # 2. Local training on every client.
        client_updates = []
        for data in client_data:
            client_model = self._create_client_model(global_state)
            client_updates.append(self._local_training(client_model, data, local_epochs))

        # 3. Aggregate and install the new global weights.
        new_global_state = self.federated_averaging(client_updates)
        self.global_model.load_state_dict(new_global_state)

    def _create_client_model(self, state_dict: Dict[str, torch.Tensor]) -> nn.Module:
        """Build a client model: a deep copy of the global model loaded with ``state_dict``."""
        client_model = self._copy_model(self.global_model)
        client_model.load_state_dict(state_dict)
        return client_model

    def _local_training(self, model: nn.Module, data: Any, epochs: int) -> Dict[str, torch.Tensor]:
        """Run the (simulated) local training loop on one client.

        Returns:
            The client's full state dict after training. This is the whole
            set of weights, not a delta against the global model.
        """
        # Created for the real training loop; the simulated loop below does not use them.
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()

        model.train()
        for epoch in range(epochs):
            # Simulated local training: a real loop would iterate over `data` here.
            pass

        return model.state_dict()

    def _copy_model(self, model: nn.Module) -> nn.Module:
        """Return an independent deep copy of ``model``."""
        import copy
        return copy.deepcopy(model)
安全聚合
同态加密聚合
import numpy as np
from typing import List
class HomomorphicEncryption:
    """Toy stand-in for an additively homomorphic scheme.

    WARNING: this is NOT encryption. ``encrypt`` multiplies by the constant
    ``public_key["g"]`` and ``decrypt`` divides by it. That is enough to
    demonstrate the add-then-decrypt data flow and nothing more. A real
    system would use a scheme such as Paillier.

    Values may be anything that supports ``*``, ``/`` and ``+`` with a
    number, for example a float or a NumPy array.
    """

    def __init__(self):
        # Keys stay unset until key_generation() is called.
        self.public_key = None
        self.private_key = None

    def key_generation(self, key_size: int = 2048):
        """Install the fixed demo key pair.

        Args:
            key_size: Accepted for API compatibility; this toy version ignores it.
        """
        self.public_key = {"n": 1234567890, "g": 12345}
        self.private_key = {"lambda": 12345, "mu": 12345}

    def _scale(self):
        """Return the scaling factor, failing clearly if keys are missing."""
        if self.public_key is None:
            raise RuntimeError(
                "key_generation() must be called before encrypt() or decrypt()"
            )
        return self.public_key["g"]

    def encrypt(self, plaintext: float) -> float:
        """Return the "ciphertext": ``plaintext`` scaled by ``g``.

        Raises:
            RuntimeError: If ``key_generation`` has not been called.
        """
        return plaintext * self._scale()

    def decrypt(self, ciphertext: float) -> float:
        """Invert ``encrypt`` by dividing by ``g``.

        Raises:
            RuntimeError: If ``key_generation`` has not been called.
        """
        return ciphertext / self._scale()

    def add_encrypted(self, enc1: float, enc2: float) -> float:
        """Add two ciphertexts; decrypting the result gives the sum of the plaintexts."""
        return enc1 + enc2

    def multiply_encrypted(self, enc: float, scalar: float) -> float:
        """Multiply a ciphertext by a plaintext scalar."""
        return enc * scalar
class SecureAggregation:
    """Aggregate client updates through the ``HomomorphicEncryption`` helper.

    Each tensor is converted to a NumPy array and passed through ``encrypt``.
    The ciphertexts are combined with ``add_encrypted`` (and
    ``multiply_encrypted`` for weights), and only the combined value is
    decrypted.

    Attributes:
        num_clients: Number of clients the aggregator was configured for.
        he: The encryption helper, with keys already generated.
    """

    def __init__(self, num_clients: int):
        self.num_clients = num_clients
        self.he = HomomorphicEncryption()
        self.he.key_generation()

    def _encrypt_tensor(self, tensor: torch.Tensor):
        """Encrypt one tensor as a float32 NumPy array."""
        # detach()/cpu() lets tensors that track gradients or live on an
        # accelerator be converted; a bare .numpy() raises for those.
        return self.he.encrypt(tensor.detach().cpu().float().numpy())

    def _encrypted_sum(self, encrypted_updates: List[Any]):
        """Fold a non-empty list of ciphertexts with ``add_encrypted``."""
        total = encrypted_updates[0]
        for item in encrypted_updates[1:]:
            total = self.he.add_encrypted(total, item)
        return total

    def secure_aggregate(self, client_updates: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Return the unweighted mean of the client updates.

        The sum is divided by the number of updates actually received, so
        the result is a true mean even when fewer than ``num_clients``
        clients report.

        Args:
            client_updates: One update dict per client, all with the same keys.

        Returns:
            Mapping from key to averaged tensor; empty if no updates were given.
        """
        if not client_updates:
            return {}

        participants = len(client_updates)
        aggregated = {}
        for key in client_updates[0]:
            encrypted_updates = [
                self._encrypt_tensor(update[key]) for update in client_updates
            ]
            decrypted_sum = self.he.decrypt(self._encrypted_sum(encrypted_updates))
            aggregated[key] = torch.tensor(decrypted_sum / participants)
        return aggregated

    def secure_weighted_aggregate(self, client_updates: List[Dict[str, torch.Tensor]],
                                  weights: List[float]) -> Dict[str, torch.Tensor]:
        """Return the weighted sum of the client updates.

        Weights are applied as given and are not normalized here.

        Args:
            client_updates: One update dict per client, all with the same keys.
            weights: One weight per entry of ``client_updates``.

        Returns:
            Mapping from key to weighted sum; empty if no updates were given.

        Raises:
            ValueError: If ``weights`` and ``client_updates`` differ in length.
        """
        if len(weights) != len(client_updates):
            raise ValueError(
                f"expected {len(client_updates)} weights, got {len(weights)}"
            )
        if not client_updates:
            return {}

        aggregated = {}
        for key in client_updates[0]:
            encrypted_updates = [
                self.he.multiply_encrypted(self._encrypt_tensor(update[key]), weight)
                for update, weight in zip(client_updates, weights)
            ]
            decrypted_sum = self.he.decrypt(self._encrypted_sum(encrypted_updates))
            aggregated[key] = torch.tensor(decrypted_sum)
        return aggregated
安全多方计算聚合
class SecureMultiPartyAggregation:
    """Average client updates using additive secret sharing.

    Every secret is split into ``num_parties`` random shares that sum back
    to the secret. Shares are added party by party, and only the combined
    shares are revealed.

    NOTE: shares are drawn with ``torch.randn_like`` over floating-point
    values, so reconstruction is exact only up to rounding error.
    """

    def __init__(self, num_parties: int):
        """
        Args:
            num_parties: Number of share holders; must be at least 1.

        Raises:
            ValueError: If ``num_parties`` is smaller than 1.
        """
        if num_parties < 1:
            raise ValueError(f"num_parties must be >= 1, got {num_parties}")
        self.num_parties = num_parties

    def secret_sharing(self, secret: torch.Tensor) -> List[torch.Tensor]:
        """Split ``secret`` into ``num_parties`` additive shares.

        The first ``num_parties - 1`` shares are random. The last is chosen
        so that all shares sum to ``secret``. The input is not modified.
        """
        remaining = secret.clone()
        shares = []
        for _ in range(self.num_parties - 1):
            share = torch.randn_like(secret)
            shares.append(share)
            remaining -= share
        shares.append(remaining)
        return shares

    def secure_addition(self, shares_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """Add the clients' shares party by party.

        Args:
            shares_list: One list of ``num_parties`` shares per client.

        Returns:
            ``num_parties`` shares whose sum equals the sum of all secrets.
        """
        result_shares = []
        for party in range(self.num_parties):
            party_total = torch.zeros_like(shares_list[0][party])
            for client_shares in shares_list:
                party_total += client_shares[party]
            result_shares.append(party_total)
        return result_shares

    def reveal(self, shares: List[torch.Tensor]) -> torch.Tensor:
        """Reconstruct the secret by summing its shares."""
        return sum(shares)

    def mpc_aggregate(self, client_updates: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Return the mean of the client updates, computed on secret shares.

        Args:
            client_updates: One update dict per client, all with the same keys.

        Returns:
            Mapping from key to averaged tensor; empty if no updates were given.
        """
        if not client_updates:
            return {}

        aggregated = {}
        for key in client_updates[0]:
            # Each client secret-shares its own update.
            client_shares = [
                self.secret_sharing(update[key]) for update in client_updates
            ]
            # Only the share-wise sums are combined and revealed.
            result_shares = self.secure_addition(client_shares)
            aggregated[key] = self.reveal(result_shares) / len(client_updates)
        return aggregated
通信优化
模型压缩
class CommunicationOptimizer:
    """Shrink model updates before transmission: sparsify, then quantize.

    Attributes:
        compression_rate: Fraction of entries (by magnitude) kept when
            sparsifying, e.g. ``0.1`` keeps the largest 10%.
    """

    def __init__(self, compression_rate: float = 0.1):
        self.compression_rate = compression_rate

    def quantize_update(self, update: torch.Tensor, bits: int = 8) -> torch.Tensor:
        """Simulate uniform quantization: snap values to ``2**bits`` levels.

        The tensor is quantized over its own ``[min, max]`` range and
        dequantized immediately, so the result has the input's shape and
        floating-point values.

        Args:
            update: Tensor to quantize.
            bits: Bit width of the quantization grid; must be at least 1.

        Returns:
            The dequantized tensor. An empty or constant tensor is returned
            as an unchanged copy.

        Raises:
            ValueError: If ``bits`` is smaller than 1.
        """
        if bits < 1:
            raise ValueError(f"bits must be >= 1, got {bits}")
        if update.numel() == 0:
            return update.clone()

        min_val = update.min()
        max_val = update.max()
        scale = (max_val - min_val) / (2 ** bits - 1)
        if scale == 0:
            # Constant tensor: dividing by a zero scale would produce NaN.
            return update.clone()

        quantized = torch.round((update - min_val) / scale)
        return quantized * scale + min_val

    def sparsify_update(self, update: torch.Tensor) -> torch.Tensor:
        """Keep the largest-magnitude entries and zero out the rest.

        ``k = int(numel * compression_rate)`` entries are kept, clamped to
        ``[1, numel]`` so a small tensor is never dropped entirely.

        Returns:
            A tensor shaped like ``update`` with only the top-k entries kept.
        """
        flat = update.flatten()
        total = flat.numel()
        if total == 0:
            return update.clone()

        k = min(total, max(1, int(total * self.compression_rate)))
        # topk selects the k LARGEST magnitudes. torch.kthvalue returns the
        # k-th smallest, which would keep roughly (1 - rate) of the entries.
        _, keep_idx = torch.topk(flat.abs(), k)
        mask = torch.zeros(total, dtype=torch.bool, device=update.device)
        mask[keep_idx] = True
        return update * mask.reshape(update.shape)

    def compress_update(self, update: torch.Tensor) -> Dict[str, Any]:
        """Sparsify ``update`` and quantize the surviving values.

        Returns:
            A dict with ``"indices"`` (an ``(N, ndim)`` coordinate tensor from
            ``torch.nonzero``), ``"values"`` (the N quantized values in the
            same order), ``"shape"`` and ``"compression_rate"``.
        """
        sparse_update = self.sparsify_update(update)

        non_zero_indices = torch.nonzero(sparse_update)
        non_zero_values = sparse_update[sparse_update != 0]
        quantized_values = self.quantize_update(non_zero_values)

        return {
            "indices": non_zero_indices,
            "values": quantized_values,
            "shape": update.shape,
            "compression_rate": self.compression_rate,
        }

    def decompress_update(self, compressed: Dict[str, Any]) -> torch.Tensor:
        """Rebuild the dense tensor described by ``compress_update``'s output."""
        indices = compressed["indices"]
        values = compressed["values"]
        dense = torch.zeros(compressed["shape"], dtype=values.dtype, device=values.device)
        if indices.numel() > 0:
            # An (N, ndim) coordinate matrix must be split into one index
            # tensor per dimension. Indexing with the matrix itself would
            # select along dim 0 only.
            dense[tuple(indices.t())] = values
        return dense

    def gradient_accumulation(self, local_updates: List[torch.Tensor],
                              accumulation_steps: int) -> torch.Tensor:
        """Sum the updates, each divided by ``accumulation_steps``.

        Args:
            local_updates: Non-empty list of same-shaped tensors.
            accumulation_steps: Divisor applied to every update.
        """
        accumulated = torch.zeros_like(local_updates[0])
        for update in local_updates:
            accumulated += update / accumulation_steps
        return accumulated
通信协议优化
class FederatedCommunicationProtocol:
    """Client selection, staleness-aware aggregation and delta transmission.

    Attributes:
        num_clients: Number of clients taking part.
        communication_rounds: Number of planned communication rounds.
        bandwidth_limits: Simulated per-client bandwidth limit (1000 each).
    """

    def __init__(self, num_clients: int, communication_rounds: int):
        self.num_clients = num_clients
        self.communication_rounds = communication_rounds
        # Simulated bandwidth limit, identical for every client.
        self.bandwidth_limits = [1000] * num_clients

    def adaptive_client_selection(self, client_performance: List[float]) -> List[int]:
        """Pick the best-performing half of the clients.

        Args:
            client_performance: One score per client; higher is better.

        Returns:
            Indices of the top ``num_clients // 2`` scores, best first.
        """
        ranking = np.argsort(client_performance)[::-1]
        quota = self.num_clients // 2
        return ranking[:quota].tolist()

    def asynchronous_update(self, client_updates: Dict[int, torch.Tensor],
                            staleness: Dict[int, int]) -> torch.Tensor:
        """Combine asynchronous updates, down-weighting stale ones.

        Each client gets the raw weight ``1 / (1 + staleness)``; weights are
        normalized to sum to one before the updates are combined.

        Args:
            client_updates: Update tensor keyed by client id.
            staleness: Staleness value keyed by client id.

        Returns:
            The normalized weighted sum of the updates.
        """
        weights = {}
        total_weight = 0
        for client_id in client_updates:
            # The staler the update, the smaller its weight.
            raw_weight = 1.0 / (1 + staleness[client_id])
            weights[client_id] = raw_weight
            total_weight += raw_weight

        template = list(client_updates.values())[0]
        aggregated = torch.zeros_like(template)
        for client_id, update in client_updates.items():
            normalized = weights[client_id] / total_weight
            aggregated += normalized * update
        return aggregated

    def differential_compression(self, update: torch.Tensor,
                                 previous_update: torch.Tensor) -> torch.Tensor:
        """Compress only the change relative to the previous update."""
        return self.compress_update(update - previous_update)

    def compress_update(self, update: torch.Tensor) -> torch.Tensor:
        """Placeholder compression: scale the update by 0.1."""
        compressed = update * 0.1  # stands in for a real compressor
        return compressed
联邦微调
联邦LoRA微调
class FederatedLoRA:
    """Federated fine-tuning that exchanges only low-rank (LoRA) factors.

    For every 2-D parameter whose name contains ``"weight"``, two factors
    are kept under the keys ``"<name>_lora_A"`` (rows x rank) and
    ``"<name>_lora_B"`` (rank x cols). Their product has the parameter's shape.

    Attributes:
        base_model: The model being adapted.
        rank: Inner dimension of the low-rank factors.
        lora_weights: Current factors, keyed as described above.
    """

    def __init__(self, base_model: nn.Module, rank: int = 8):
        self.base_model = base_model
        self.rank = rank
        self.lora_weights = self._initialize_lora_weights()

    def _initialize_lora_weights(self) -> Dict[str, torch.Tensor]:
        """Create small random factors for every eligible parameter.

        NOTE(review): both factors start random, so the initial product is
        not zero. Confirm whether a zero-initialized factor was intended.
        """
        lora_weights = {}
        for name, param in self.base_model.named_parameters():
            if 'weight' in name and param.dim() == 2:
                rows, cols = param.size(0), param.size(1)
                lora_weights[f"{name}_lora_A"] = torch.randn(rows, self.rank) * 0.01
                lora_weights[f"{name}_lora_B"] = torch.randn(self.rank, cols) * 0.01
        return lora_weights

    def federated_lora_training(self, client_data: List[Any],
                                local_epochs: int = 3) -> Dict[str, torch.Tensor]:
        """Run one federated round that exchanges only the LoRA factors.

        Args:
            client_data: One dataset per participating client.
            local_epochs: Number of local epochs each client runs.

        Returns:
            The averaged factors, keyed like ``lora_weights``. Empty if
            ``client_data`` is empty.
        """
        client_lora_updates = [
            self._local_lora_training(data, local_epochs) for data in client_data
        ]
        return self._aggregate_lora_weights(client_lora_updates)

    def _local_lora_training(self, data: Any, epochs: int) -> Dict[str, torch.Tensor]:
        """Train a private copy of the factors on one client (simulated).

        Returns:
            The client's factors under the same keys as ``lora_weights``.
        """
        lora_adapter = self._create_lora_adapter()

        for epoch in range(epochs):
            # Simulated training: a real loop would optimize the adapter on `data`.
            pass

        # Export under the original keys. state_dict() would prefix every key
        # with the name of the adapter's parameter container.
        return lora_adapter.export_weights()

    def _create_lora_adapter(self) -> nn.Module:
        """Wrap a private copy of the current factors in a trainable module."""

        class LoRAAdapter(nn.Module):
            def __init__(self, lora_weights):
                super().__init__()
                # Keys may contain "." (e.g. "0.weight_lora_A"), which module
                # parameter names forbid. Keep the names in a plain list and
                # the tensors in a ParameterList, aligned by position.
                self.lora_names = list(lora_weights)
                self.lora_params = nn.ParameterList(
                    [nn.Parameter(tensor.detach().clone()) for tensor in lora_weights.values()]
                )

            def forward(self, x):
                # Simplified LoRA forward pass: identity.
                return x

            def export_weights(self):
                """Return detached copies of the factors under their original keys."""
                return {
                    name: param.detach().clone()
                    for name, param in zip(self.lora_names, self.lora_params)
                }

        return LoRAAdapter(self.lora_weights)

    def _aggregate_lora_weights(self, client_updates: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Average the clients' factors key by key (uniform weights)."""
        if not client_updates:
            return {}

        num_updates = len(client_updates)
        aggregated = {}
        for key, reference in client_updates[0].items():
            total = torch.zeros_like(reference)
            for client_update in client_updates:
                total += client_update[key] / num_updates
            aggregated[key] = total
        return aggregated

    def apply_lora_to_model(self, lora_weights: Dict[str, torch.Tensor]):
        """Add each low-rank product ``A @ B`` to its base-model parameter.

        Parameters without both factors in ``lora_weights`` are left untouched.
        """
        with torch.no_grad():
            for name, param in self.base_model.named_parameters():
                key_a = f"{name}_lora_A"
                key_b = f"{name}_lora_B"
                # The factors are stored under suffixed keys; the bare
                # parameter name is never a key of `lora_weights`.
                if key_a not in lora_weights or key_b not in lora_weights:
                    continue
                delta_weight = lora_weights[key_a] @ lora_weights[key_b]
                param.add_(delta_weight.to(device=param.device, dtype=param.dtype))
联邦提示微调
class FederatedPromptTuning:
    """Federated tuning of soft-prompt embeddings with a fixed base model.

    Attributes:
        base_model: The model whose forward pass receives the prompts.
        num_prompts: Size of the first dimension of the prompt tensor.
        prompt_length: Number of positions per prompt.
        embedding_dim: Size of each prompt embedding vector.
        prompt_embeddings: Tensor of shape
            ``(num_prompts, prompt_length, embedding_dim)``.
    """

    def __init__(self, base_model: nn.Module, num_prompts: int = 10, prompt_length: int = 20,
                 embedding_dim: int = 768):
        """
        Args:
            base_model: Model to be prompted.
            num_prompts: Number of prompts to create.
            prompt_length: Length of each prompt.
            embedding_dim: Embedding width; defaults to 768 and should match
                the base model's embedding size.
        """
        self.base_model = base_model
        self.num_prompts = num_prompts
        self.prompt_length = prompt_length
        self.embedding_dim = embedding_dim
        self.prompt_embeddings = self._initialize_prompts()

    def _initialize_prompts(self) -> torch.Tensor:
        """Return randomly initialized prompt embeddings."""
        return torch.randn(self.num_prompts, self.prompt_length, self.embedding_dim)

    def federated_prompt_training(self, client_data: List[Any],
                                  local_epochs: int = 5) -> torch.Tensor:
        """Run one federated round over the prompt embeddings.

        Args:
            client_data: One dataset per participating client.
            local_epochs: Number of local epochs each client runs.

        Returns:
            The average of the clients' prompt embeddings.
        """
        client_prompt_updates = [
            self._local_prompt_training(data, local_epochs) for data in client_data
        ]
        return self._aggregate_prompts(client_prompt_updates)

    def _local_prompt_training(self, data: Any, epochs: int) -> torch.Tensor:
        """Tune a private copy of the prompts on one client (simulated).

        Returns:
            The client's prompt embeddings, detached from autograd.
        """
        trainable_prompts = self.prompt_embeddings.clone().requires_grad_(True)
        optimizer = torch.optim.Adam([trainable_prompts], lr=0.01)

        for epoch in range(epochs):
            optimizer.zero_grad()
            # Placeholder loss that does not depend on the prompts: they get
            # no gradient, so the optimizer step leaves them unchanged.
            loss = torch.tensor(0.0, requires_grad=True)
            loss.backward()
            optimizer.step()

        return trainable_prompts.detach()

    def _aggregate_prompts(self, client_updates: List[torch.Tensor]) -> torch.Tensor:
        """Average the clients' prompt embeddings (uniform weights).

        With no client updates, a copy of the current prompts is returned
        rather than an all-zero tensor.
        """
        if not client_updates:
            return self.prompt_embeddings.clone()

        aggregated = torch.zeros_like(self.prompt_embeddings)
        for client_update in client_updates:
            aggregated += client_update / len(client_updates)
        return aggregated

    def apply_prompts_to_model(self, prompts: torch.Tensor):
        """Patch ``base_model.forward`` so ``prompts`` are prepended to the input.

        The unpatched forward is remembered on first use. Calling this method
        again replaces the prompts instead of stacking a second wrapper that
        would prepend them twice.

        NOTE(review): ``prompts.expand(batch_size, -1, -1)`` only succeeds
        when the first dimension of ``prompts`` is 1 or equals the batch size,
        and the concatenation needs a 3-D input despite the ``input_ids``
        name. Confirm the intended shapes with the caller.
        """
        if getattr(self, "_original_forward", None) is None:
            self._original_forward = self.base_model.forward
        original_forward = self._original_forward

        def prompted_forward(input_ids, *args, **kwargs):
            batch_size = input_ids.size(0)
            prompt_tokens = prompts.expand(batch_size, -1, -1)
            # Prepend the prompts along the sequence dimension.
            prompted_input = torch.cat([prompt_tokens, input_ids], dim=1)
            return original_forward(prompted_input, *args, **kwargs)

        self.base_model.forward = prompted_forward
联邦学习评估
class FederatedLearningEvaluator:
    """Report global-model metrics and per-client contribution scores.

    The accuracy, loss and fairness helpers are placeholders that return
    fixed values; only the similarity helper computes anything.
    """

    def __init__(self):
        self.metrics = {}

    def evaluate_global_model(self, global_model: nn.Module, test_data: Any) -> Dict[str, float]:
        """Return accuracy, loss and fairness for ``global_model``.

        Returns:
            A dict with the keys ``"accuracy"``, ``"loss"`` and ``"fairness"``.
        """
        report = {
            "accuracy": self._calculate_accuracy(global_model, test_data),
            "loss": self._calculate_loss(global_model, test_data),
            "fairness": self._calculate_fairness(global_model, test_data),
        }
        return report

    def evaluate_client_contribution(self, client_updates: List[Dict[str, torch.Tensor]],
                                     global_update: Dict[str, torch.Tensor]) -> List[float]:
        """Score each client by how closely its update matches the global one.

        Returns:
            One similarity score per client, in input order.
        """
        return [
            self._calculate_similarity(client_update, global_update)
            for client_update in client_updates
        ]

    def _calculate_accuracy(self, model: nn.Module, data: Any) -> float:
        """Placeholder accuracy: always 0.85."""
        return 0.85

    def _calculate_loss(self, model: nn.Module, data: Any) -> float:
        """Placeholder loss: always 0.35."""
        return 0.35

    def _calculate_fairness(self, model: nn.Module, data: Any) -> float:
        """Placeholder fairness score: always 0.92."""
        return 0.92

    def _calculate_similarity(self, update1: Dict[str, torch.Tensor],
                              update2: Dict[str, torch.Tensor]) -> float:
        """Mean cosine similarity over the keys present in both updates.

        Each pair of tensors is flattened to a vector before comparison.

        Returns:
            The mean similarity, or ``0.0`` when the updates share no key.
        """
        shared_keys = [key for key in update1.keys() if key in update2]
        if not shared_keys:
            return 0.0

        running_total = 0.0
        for key in shared_keys:
            left = update1[key].flatten().unsqueeze(0)
            right = update2[key].flatten().unsqueeze(0)
            running_total += torch.nn.functional.cosine_similarity(left, right).item()
        return running_total / len(shared_keys)
总结
联邦学习为LLM的分布式训练提供了强大的解决方案。通过安全聚合、通信优化和联邦微调等技术,可以在保护数据隐私的同时,实现高效的模型训练。在实际应用中,需要根据具体场景选择合适的联邦学习策略,并持续优化通信效率和模型性能。