PPO算法:近端策略优化
---
title: "PPO算法:近端策略优化"
description: "深入理解PPO算法原理、实现细节和在RLHF中的应用"
tags: ["PPO", "强化学习", "策略优化", "RLHF"]
category: "llm"
icon: "🧠"
---
PPO算法:近端策略优化
PPO简介
PPO(Proximal Policy Optimization)是由OpenAI提出的策略梯度算法,是目前最流行的强化学习算法之一。它通过限制策略更新的幅度,实现了稳定高效的训练。PPO是RLHF中用于优化LLM的核心算法。
PPO的核心优势:
- 稳定性:限制策略更新幅度,避免训练崩溃
- 简单性:实现简单,超参数较少
- 样本效率:同一批采样数据可重复用于多轮梯度更新,而不是用一次就丢弃
- 通用性:适用于连续和离散动作空间
算法原理
策略梯度基础
import torch
import torch.nn as nn
class PolicyNetwork(nn.Module):
    """Single-layer policy that maps a state to action probabilities."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        # One linear layer produces an unnormalised score per action.
        self.fc = nn.Linear(state_dim, action_dim)

    def forward(self, state):
        """Return the softmax (over the last dimension) of the action scores."""
        return torch.softmax(self.fc(state), dim=-1)
def compute_policy_gradient(log_probs, rewards):
    """Return the policy-gradient surrogate loss.

    The loss is the negated mean of ``log_probs * rewards``, so minimising it
    raises the log-probability of actions paired with larger rewards.
    """
    weighted_log_probs = log_probs * rewards
    return -weighted_log_probs.mean()
PPO-Clip目标函数
def ppo_clip_loss(old_log_probs, new_log_probs, advantages, clip_epsilon=0.2):
    """Return the PPO clipped-surrogate loss (to be minimised).

    Args:
        old_log_probs: Log-probabilities of the taken actions under the policy
            that collected the data.
        new_log_probs: Log-probabilities of the same actions under the policy
            being optimised.
        advantages: Advantage estimate for each action.
        clip_epsilon: The probability ratio is clipped to
            ``[1 - clip_epsilon, 1 + clip_epsilon]``.

    Returns:
        Scalar tensor ``-mean(min(ratio * A, clip(ratio) * A))``.
    """
    # The rollout log-probs and the advantages are constants of the objective.
    # If they still carry an autograd graph, gradients would flow into the old
    # policy / critic and (at ratio == 1) cancel the policy update entirely.
    if torch.is_tensor(old_log_probs):
        old_log_probs = old_log_probs.detach()
    if torch.is_tensor(advantages):
        advantages = advantages.detach()

    ratio = torch.exp(new_log_probs - old_log_probs)
    clipped_ratio = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon)

    # Element-wise pessimistic bound: take the smaller of the two surrogates.
    surrogate = torch.min(ratio * advantages, clipped_ratio * advantages)
    return -surrogate.mean()
# Advantage estimation
def compute_advantages(rewards, values, gamma=0.99, lam=0.95, last_value=0.0):
    """Compute Generalized Advantage Estimation (GAE) for one trajectory.

    Args:
        rewards: Per-step rewards, indexable, length ``T``.
        values: Per-step value estimates; indices ``0..T-1`` are read.
        gamma: Discount factor.
        lam: GAE lambda (bias/variance trade-off).
        last_value: Value estimate for the state after the final step. The
            default ``0.0`` treats the trajectory as ending in a terminal
            state; pass a bootstrap value for a truncated rollout.

    Returns:
        A tensor of ``T`` advantages, one per step.
    """
    num_steps = len(rewards)
    advantages = [0.0] * num_steps
    running = 0.0
    # Walk backwards so each step can reuse the accumulated advantage of the
    # step after it; writing by index keeps the loop O(T).
    for t in reversed(range(num_steps)):
        next_value = values[t + 1] if t + 1 < num_steps else last_value
        delta = rewards[t] + gamma * next_value - values[t]
        running = delta + gamma * lam * running
        advantages[t] = running
    return torch.tensor(advantages)
完整PPO实现
import torch
import torch.nn as nn
from torch.distributions import Categorical
class ActorCritic(nn.Module):
    """Actor-critic model with separate MLPs for the policy and the value.

    The actor ends in a softmax and yields a probability per action; the
    critic ends in a single linear unit and yields one value per state.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()

        def hidden_stack():
            # Two ReLU hidden layers; each head gets its own fresh copy.
            return [
                nn.Linear(state_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
            ]

        # Actor (policy network)
        self.actor = nn.Sequential(
            *hidden_stack(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1),
        )
        # Critic (value network)
        self.critic = nn.Sequential(
            *hidden_stack(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, state):
        """Return ``(action_probs, value)`` for ``state``."""
        return self.actor(state), self.critic(state)

    def get_action(self, state):
        """Sample an action from the policy; return it with its log-probability."""
        policy = Categorical(self.actor(state))
        sampled = policy.sample()
        return sampled, policy.log_prob(sampled)

    def evaluate(self, state, action):
        """Return ``(log_prob, value, entropy)`` for the given state/action pair."""
        policy = Categorical(self.actor(state))
        return policy.log_prob(action), self.critic(state), policy.entropy()
class PPO:
    """Clipped-surrogate PPO trainer built around an ``ActorCritic`` policy.

    Args:
        state_dim: State size forwarded to ``ActorCritic``.
        action_dim: Number of actions forwarded to ``ActorCritic``.
        lr: Learning rate of the Adam optimizer.
        gamma: Discount factor.
        gae_lambda: GAE lambda.
        clip_epsilon: The probability ratio is clipped to
            ``[1 - clip_epsilon, 1 + clip_epsilon]``.
        epochs: Number of gradient steps taken on each batch in ``update``.
        value_coef: Weight of the critic loss in the total loss.
        entropy_coef: Weight of the entropy bonus in the total loss.
        max_grad_norm: Threshold for gradient-norm clipping.
    """

    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99,
                 gae_lambda=0.95, clip_epsilon=0.2, epochs=10,
                 value_coef=0.5, entropy_coef=0.01, max_grad_norm=0.5):
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.epochs = epochs
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.policy = ActorCritic(state_dim, action_dim)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)

    @staticmethod
    def _to_floats(seq):
        """Convert a tensor or a sequence of scalars to a flat list of numbers."""
        if torch.is_tensor(seq):
            return seq.detach().reshape(-1).tolist()
        return [float(item) for item in seq]

    def compute_gae(self, rewards, values, dones, last_value=0.0):
        """Compute Generalized Advantage Estimation for one rollout.

        Args:
            rewards: Per-step rewards (sequence or tensor), length ``T``.
            values: Per-step value estimates; indices ``0..T-1`` are read.
            dones: Per-step episode-end flags (0/1 or bool).
            last_value: Value estimate of the state following the final step.
                The default ``0.0`` means no bootstrapping; pass the critic's
                estimate when the rollout was cut off mid-episode.

        Returns:
            A float32 tensor of ``T`` advantages that carries no gradient.
        """
        rewards = self._to_floats(rewards)
        values = self._to_floats(values)
        dones = self._to_floats(dones)

        num_steps = len(rewards)
        advantages = [0.0] * num_steps
        running = 0.0
        # Backward pass; writing by index keeps the loop O(T).
        for t in reversed(range(num_steps)):
            next_value = values[t + 1] if t + 1 < num_steps else last_value
            # A finished episode contributes neither a bootstrap value nor
            # any advantage carried over from later steps.
            not_done = 1 - dones[t]
            delta = rewards[t] + self.gamma * next_value * not_done - values[t]
            running = delta + self.gamma * self.gae_lambda * not_done * running
            advantages[t] = running
        return torch.tensor(advantages, dtype=torch.float32)

    def update(self, states, actions, old_log_probs, rewards, dones, values,
               last_value=0.0):
        """Run ``self.epochs`` PPO gradient steps on one batch of rollout data.

        Args:
            states: Batch of states, passed to ``self.policy.evaluate``.
            actions: Actions taken, passed to ``self.policy.evaluate``.
            old_log_probs: Tensor of log-probabilities of ``actions`` recorded
                when the data was collected.
            rewards: Per-step rewards.
            dones: Per-step episode-end flags.
            values: Per-step value estimates recorded at collection time.
            last_value: Bootstrap value forwarded to ``compute_gae``.
        """
        # Rollout statistics are constants of the objective. Detaching also
        # keeps later epochs from backpropagating through the rollout graph,
        # which is freed after the first backward pass.
        old_log_probs = old_log_probs.detach()
        device = old_log_probs.device

        advantages = self.compute_gae(rewards, values, dones, last_value).to(device)
        if advantages.numel() == 0:
            # Nothing to learn from; an empty mean would produce NaN losses.
            return
        baseline = torch.tensor(self._to_floats(values), dtype=torch.float32, device=device)
        returns = advantages + baseline

        # Normalise advantages. The unbiased std of a single sample is NaN,
        # so a one-step batch is left unnormalised.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        for _ in range(self.epochs):
            # Re-evaluate the batch under the current policy.
            new_log_probs, new_values, entropy = self.policy.evaluate(states, actions)

            # Clipped surrogate objective.
            ratio = torch.exp(new_log_probs - old_log_probs)
            clipped_ratio = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon)
            actor_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()

            # Reshape explicitly: a bare squeeze() would collapse a batch of
            # one to a 0-d tensor and broadcast against the targets.
            critic_loss = nn.MSELoss()(new_values.reshape(returns.shape), returns)
            entropy_loss = -entropy.mean()

            # Total loss.
            loss = (actor_loss
                    + self.value_coef * critic_loss
                    + self.entropy_coef * entropy_loss)

            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.optimizer.step()
在RLHF中的应用
PPO训练器
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer
# Configure PPO
SFT_CHECKPOINT = "sft_model"

ppo_config = PPOConfig(
    seed=0,
    learning_rate=1.41e-5,
    batch_size=64,
    mini_batch_size=16,
    ppo_epochs=4,
    kl_penalty="kl",
    init_kl_coef=0.2,
    target_kl=6.0,
)

# Load the trainable model, a second copy used as the reference model, and
# the tokenizer, all from the same checkpoint.
model = AutoModelForCausalLMWithValueHead.from_pretrained(SFT_CHECKPOINT)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(SFT_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(SFT_CHECKPOINT)

# Create the PPO trainer
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
)
# Training loop
def train_step(query):
    """Run one PPO optimisation step for a single text query.

    Uses the module-level ``tokenizer``, ``ppo_trainer`` and ``reward_model``.

    Args:
        query: Prompt text to generate a response for.

    Returns:
        Whatever ``ppo_trainer.step`` returns for this sample.
    """
    # 1. Encode the query; [0] drops the batch dimension of the token ids.
    query_tensor = tokenizer(query, return_tensors="pt").input_ids[0]

    # 2. Generate a response.
    #    - return_prompt=False: the trainer's generate() returns prompt +
    #      response by default, which would feed the prompt to step() twice.
    #    - do_sample=True: temperature only takes effect when sampling.
    response_tensors = ppo_trainer.generate(
        [query_tensor],
        return_prompt=False,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
    )
    response_tensor = response_tensors[0]

    # 3. Score the response.
    # NOTE(review): reward_model is not defined in this file; it is assumed to
    # map response token ids to one scalar score - confirm its input format.
    reward = reward_model(response_tensor)
    if not torch.is_tensor(reward):
        # The trainer expects each reward as a tensor.
        reward = torch.tensor(float(reward))

    # 4. PPO update on this single (query, response, reward) triple.
    return ppo_trainer.step([query_tensor], [response_tensor], [reward])
超参数调优
# Key hyperparameters
hyperparams = dict(
    learning_rate=1.41e-5,  # optimizer learning rate
    batch_size=64,          # samples per PPO batch
    mini_batch_size=16,     # samples per gradient step
    ppo_epochs=4,           # optimisation passes over each batch
    clip_epsilon=0.2,       # clipping range of the probability ratio
    gamma=0.99,             # discount factor
    gae_lambda=0.95,        # GAE lambda
    init_kl_coef=0.2,       # initial KL penalty coefficient
    target_kl=6.0,          # target KL divergence
    max_grad_norm=0.5,      # gradient clipping threshold
)
训练稳定性技巧
# 1. Adaptive KL control
def adaptive_kl_control(kl, target_kl, kl_coef, kl_coef_high=1.5, kl_coef_low=0.5):
    """Adjust the KL penalty coefficient based on the measured KL.

    The coefficient is multiplied by ``kl_coef_low`` when ``kl`` is below
    ``target_kl / 1.5``, by ``kl_coef_high`` when it is above
    ``target_kl * 1.5``, and returned unchanged in between.
    """
    # KL well under target: relax the penalty.
    if kl < target_kl / 1.5:
        kl_coef *= kl_coef_low
        return kl_coef
    # KL well over target: strengthen the penalty.
    if kl > target_kl * 1.5:
        kl_coef *= kl_coef_high
    return kl_coef
# 2. Reward normalisation
def normalize_rewards(rewards):
    """Standardise rewards to zero mean and (approximately) unit spread.

    Args:
        rewards: Array-like exposing ``mean()`` and ``std()`` (e.g. a tensor).

    Returns:
        ``(rewards - mean) / (std + 1e-8)``. When the standard deviation is
        undefined (NaN, e.g. a tensor with a single reward) it is treated as
        zero, so the result stays finite instead of becoming NaN.
    """
    centered = rewards - rewards.mean()
    spread = rewards.std()
    # NaN is the only value that differs from itself; the sample std of fewer
    # than two tensor elements is NaN.
    if spread != spread:
        spread = 0.0
    return centered / (spread + 1e-8)
# 3. Gradient clipping: cap the total gradient norm of the model's parameters at 0.5.
# NOTE(review): presumably meant to run after loss.backward() and before
# optimizer.step() in the training loop - confirm at the call site.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
PPO作为RLHF的核心算法,凭借训练稳定、实现简单且样本利用率高等特点,成为训练对齐LLM的主流方法之一。