← 返回首页
🧠

PPO算法:近端策略优化

📂 llm ⏱ 3 min 579 words

---
title: "PPO算法:近端策略优化"
description: "深入理解PPO算法原理、实现细节和在RLHF中的应用"
tags: ["PPO", "强化学习", "策略优化", "RLHF"]
category: "llm"
icon: "🧠"
---

PPO算法:近端策略优化

PPO简介

PPO(Proximal Policy Optimization)是由OpenAI提出的策略梯度算法,是目前最流行的强化学习算法之一。它通过限制策略更新的幅度,实现了稳定高效的训练。PPO是RLHF中用于优化LLM的核心算法。

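PPO的核心优势:

- 训练稳定:通过裁剪新旧策略的概率比(clip),限制每次策略更新的幅度,避免策略崩溃;
- 样本效率较高:同一批采样数据可以进行多轮(多个epoch)梯度更新;
- 实现简单:只需普通的一阶优化器(如Adam),无需像TRPO那样求解带约束的二阶优化问题。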

算法原理

策略梯度基础

import torch
import torch.nn as nn

class PolicyNetwork(nn.Module):
    """Linear policy mapping a state to a probability distribution over actions."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        # One linear layer producing one logit per action.
        self.fc = nn.Linear(state_dim, action_dim)

    def forward(self, state):
        """Return action probabilities: softmax of the logits over the last dim."""
        return self.fc(state).softmax(dim=-1)

def compute_policy_gradient(log_probs, rewards):
    """REINFORCE-style surrogate loss: the negated mean of log_prob * reward.

    Minimizing this loss with gradient descent increases the reward-weighted
    log-likelihood of the taken actions.
    """
    weighted = log_probs * rewards
    return weighted.mean().neg()

PPO-Clip目标函数

def ppo_clip_loss(old_log_probs, new_log_probs, advantages, clip_epsilon=0.2):
    """PPO-Clip surrogate loss (to be minimized).

    ratio = exp(new_log_probs - old_log_probs)
    loss  = -mean(min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A))
    """
    ratio = (new_log_probs - old_log_probs).exp()
    lower, upper = 1 - clip_epsilon, 1 + clip_epsilon

    unclipped = ratio * advantages
    clipped = ratio.clamp(lower, upper) * advantages
    # Element-wise pessimistic bound, negated so that descent maximizes it.
    return -torch.min(unclipped, clipped).mean()

# Advantage computation
def compute_advantages(rewards, values, gamma=0.99, lam=0.95, last_value=0.0):
    """Generalized Advantage Estimation (GAE) for a single trajectory.

    Evaluated backwards in time:
        delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        A_t     = delta_t + gamma * lam * A_{t+1}

    Args:
        rewards: Per-step rewards r_0 ... r_{T-1}.
        values: Value estimates V(s_0) ... V(s_{T-1}); only the first
            ``len(rewards)`` entries are read.
        gamma: Discount factor.
        lam: GAE lambda.
        last_value: Bootstrap value V(s_T) used after the final step. The
            default 0 treats the trajectory as terminated; pass the critic's
            estimate when the rollout was truncated mid-episode.

    Returns:
        A 1-D tensor with one advantage per reward (empty for empty input).
    """
    n = len(rewards)
    # Fill by index instead of list.insert(0, ...), which is O(n) per call.
    advantages = [0] * n
    advantage = 0
    next_value = last_value
    for i in reversed(range(n)):
        delta = rewards[i] + gamma * next_value - values[i]
        advantage = delta + gamma * lam * advantage
        advantages[i] = advantage
        next_value = values[i]

    return torch.tensor(advantages)

完整PPO实现

import torch
import torch.nn as nn
from torch.distributions import Categorical

class ActorCritic(nn.Module):
    """Actor-critic model with two separate (non-shared) MLP trunks.

    ``actor`` maps a state to action probabilities (its last layer is a
    softmax); ``critic`` maps a state to a single value estimate.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()

        def mlp(out_dim):
            # state -> hidden -> hidden -> out_dim with ReLU in between.
            return [
                nn.Linear(state_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, out_dim),
            ]

        # Actor (policy network): the softmax turns logits into probabilities.
        self.actor = nn.Sequential(*mlp(action_dim), nn.Softmax(dim=-1))
        # Critic (value network): one output per state.
        self.critic = nn.Sequential(*mlp(1))

    def _dist(self, state):
        """Categorical action distribution parameterized by the actor's probabilities."""
        return Categorical(self.actor(state))

    def forward(self, state):
        """Return ``(action_probs, value)`` for ``state``."""
        return self.actor(state), self.critic(state)

    def get_action(self, state):
        """Sample an action; return ``(action, log_prob)``."""
        dist = self._dist(state)
        sampled = dist.sample()
        return sampled, dist.log_prob(sampled)

    def evaluate(self, state, action):
        """Score given actions; return ``(log_prob, value, entropy)``."""
        dist = self._dist(state)
        return dist.log_prob(action), self.critic(state), dist.entropy()

class PPO:
    """Minimal PPO-Clip agent built on ``ActorCritic``.

    Each call to ``update`` consumes one batch of rollout data and runs
    ``epochs`` full-batch gradient steps on
    ``actor_loss + value_coef * critic_loss + entropy_coef * (-entropy)``.
    """

    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99,
                 gae_lambda=0.95, clip_epsilon=0.2, epochs=10,
                 value_coef=0.5, entropy_coef=0.01, max_grad_norm=0.5):
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.epochs = epochs
        # Loss weights and gradient-clipping threshold (defaults match the
        # previously hard-coded 0.5 / 0.01 / 0.5).
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm

        self.policy = ActorCritic(state_dim, action_dim)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)

    def compute_gae(self, rewards, values, dones, last_value=0.0):
        """Compute generalized advantage estimates for one rollout.

        Args:
            rewards: Per-step rewards (list or tensor).
            values: Critic estimates V(s_t) collected during the rollout;
                treated as constants (detached).
            dones: Per-step termination flags; 1/True cuts both the bootstrap
                term and the advantage recursion at that step.
            last_value: Bootstrap value after the final step (0 = terminal).

        Returns:
            A float32 tensor of advantages, one per step.
        """
        rewards = torch.as_tensor(rewards, dtype=torch.float32).detach().reshape(-1)
        values = torch.as_tensor(values, dtype=torch.float32).detach().reshape(-1)
        dones = torch.as_tensor(dones, dtype=torch.float32).detach().reshape(-1)

        advantages = torch.zeros_like(rewards)
        advantage = 0.0
        next_value = last_value
        for t in reversed(range(len(rewards))):
            not_done = 1.0 - dones[t]
            delta = rewards[t] + self.gamma * next_value * not_done - values[t]
            advantage = delta + self.gamma * self.gae_lambda * not_done * advantage
            advantages[t] = advantage
            next_value = values[t]

        return advantages

    def update(self, states, actions, old_log_probs, rewards, dones, values,
               last_value=0.0):
        """Run ``self.epochs`` PPO-Clip gradient steps on one rollout batch.

        Args:
            states: Batch of states accepted by ``ActorCritic.evaluate``.
            actions: Actions taken in ``states``.
            old_log_probs: Log-probabilities of ``actions`` under the policy
                that collected the data. They are detached here, so tensors
                straight from ``ActorCritic.get_action`` are accepted.
            rewards: Per-step rewards.
            dones: Per-step termination flags.
            values: Critic estimates for ``states`` collected at rollout time.
            last_value: Bootstrap value after the final step (see ``compute_gae``).
        """
        # The behaviour policy is a constant target: without detach, the second
        # epoch would backpropagate through the already-freed rollout graph.
        old_log_probs = torch.as_tensor(old_log_probs).detach()

        advantages = self.compute_gae(rewards, values, dones, last_value)
        values = torch.as_tensor(
            values, dtype=advantages.dtype, device=advantages.device
        ).detach().reshape(-1)
        # Value targets are built from the un-normalized advantages.
        returns = advantages + values

        # Normalize advantages; the unbiased std is NaN for a single sample.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        for _ in range(self.epochs):
            new_log_probs, new_values, entropy = self.policy.evaluate(states, actions)

            ratio = torch.exp(new_log_probs - old_log_probs)
            clipped_ratio = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon)

            actor_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()
            # squeeze(-1) rather than squeeze(): keeps shape (1,) for a
            # single-sample batch so it matches `returns`.
            critic_loss = nn.functional.mse_loss(new_values.squeeze(-1), returns)
            entropy_loss = -entropy.mean()

            loss = (actor_loss
                    + self.value_coef * critic_loss
                    + self.entropy_coef * entropy_loss)

            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.optimizer.step()

在RLHF中的应用

PPO训练器

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer

# Configure PPO.
# NOTE(review): these keyword names (`ppo_epochs`, `kl_penalty`,
# `init_kl_coef`, `target_kl`) match the older trl `PPOConfig` API; newer trl
# releases renamed or removed several of them — confirm against the installed
# trl version.
ppo_config = PPOConfig(
    learning_rate=1.41e-5,
    batch_size=64,          # batch size per PPO step
    mini_batch_size=16,     # mini-batch size inside each PPO epoch
    ppo_epochs=4,           # update epochs per batch of data
    kl_penalty="kl",        # trl option selecting the KL estimate — confirm semantics
    init_kl_coef=0.2,       # initial KL penalty coefficient
    target_kl=6.0,          # target KL divergence
    seed=0
)

# Load the models.
# NOTE(review): "sft_model" is presumably the path / hub id of the supervised
# fine-tuned checkpoint — confirm.
model = AutoModelForCausalLMWithValueHead.from_pretrained("sft_model")
# A second, independent copy of the same checkpoint, passed below as
# `ref_model`; presumably kept frozen by trl as the KL reference.
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("sft_model")
tokenizer = AutoTokenizer.from_pretrained("sft_model")

# Create the PPO trainer.
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer
)

# Training loop
def train_step(query):
    """Run one PPO step on a prompt or a batch of prompts.

    Args:
        query: A prompt string, or a list of prompt strings. trl's
            ``PPOTrainer.step`` checks that it receives exactly
            ``ppo_config.batch_size`` samples, so with ``batch_size=64`` a
            list of 64 prompts must be passed per call.

    Returns:
        The statistics returned by ``ppo_trainer.step``.
    """
    import torch

    queries = [query] if isinstance(query, str) else list(query)

    # 1. Tokenize each prompt into a 1-D tensor of token ids.
    query_tensors = [
        tokenizer(q, return_tensors="pt").input_ids[0] for q in queries
    ]

    # 2. Sample responses. `return_prompt=False`: by default trl prepends the
    #    prompt to each generated sequence, but `step` expects queries and
    #    responses as separate tensors. `do_sample=True` is required for
    #    `temperature` to have any effect (greedy decoding ignores it).
    response_tensors = ppo_trainer.generate(
        query_tensors,
        return_prompt=False,
        do_sample=True,
        max_new_tokens=256,
        temperature=0.7,
    )

    # 3. Score each response.
    # NOTE(review): `reward_model` is not defined in this snippet; it is
    # assumed to map one response tensor to a scalar score. trl expects each
    # score to be a torch tensor, hence as_tensor.
    rewards = [
        torch.as_tensor(reward_model(response), dtype=torch.float32)
        for response in response_tensors
    ]

    # 4. PPO update on the (query, response, reward) triples.
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

    return stats

超参数调优

# Key PPO / RLHF hyperparameters
hyperparams = dict(
    learning_rate=1.41e-5,   # learning rate
    batch_size=64,           # batch size
    mini_batch_size=16,      # mini-batch size
    ppo_epochs=4,            # update epochs per batch of data
    clip_epsilon=0.2,        # clipping range
    gamma=0.99,              # discount factor
    gae_lambda=0.95,         # GAE lambda
    init_kl_coef=0.2,        # initial KL penalty coefficient
    target_kl=6.0,           # target KL divergence
    max_grad_norm=0.5,       # gradient clipping threshold
)

训练稳定性技巧

# 1. Adaptive KL control
def adaptive_kl_control(kl, target_kl, kl_coef, kl_coef_high=1.5, kl_coef_low=0.5,
                        tolerance=1.5):
    """Adjust the KL penalty coefficient so the measured KL tracks a target.

    The coefficient is left unchanged while ``kl`` lies inside the band
    ``[target_kl / tolerance, target_kl * tolerance]``. Below the band it is
    multiplied by ``kl_coef_low``, which weakens the penalty. Above the band it
    is multiplied by ``kl_coef_high``, which strengthens the penalty.

    Args:
        kl: Measured KL divergence (e.g. policy vs. reference model).
        target_kl: Desired KL divergence.
        kl_coef: Current penalty coefficient.
        kl_coef_high: Multiplier applied when ``kl`` is above the band.
        kl_coef_low: Multiplier applied when ``kl`` is below the band.
        tolerance: Multiplicative half-width of the dead band (previously
            hard-coded to 1.5).

    Returns:
        The updated coefficient.
    """
    if kl < target_kl / tolerance:
        return kl_coef * kl_coef_low
    if kl > target_kl * tolerance:
        return kl_coef * kl_coef_high
    return kl_coef

# 2. Reward normalization
def normalize_rewards(rewards):
    """Standardize rewards to zero mean and (approximately) unit variance.

    Args:
        rewards: Tensor (or array) of rewards supporting ``.mean()`` and
            ``.std()``.

    Returns:
        ``(rewards - mean) / (std + 1e-8)``. When the standard deviation is
        undefined (e.g. a single reward, where torch's unbiased std is NaN),
        only the mean is subtracted instead of returning all-NaN values.
    """
    import math

    mean_reward = rewards.mean()
    std_reward = rewards.std()
    if math.isnan(float(std_reward)):
        return rewards - mean_reward
    normalized = (rewards - mean_reward) / (std_reward + 1e-8)
    return normalized

# 3. Gradient clipping
# Rescales the gradients of all `model` parameters so their combined norm is
# at most 0.5. It only acts on gradients that already exist, so in a training
# loop it belongs between `loss.backward()` and `optimizer.step()` (as in
# `PPO.update` above); called on its own like this it has nothing to clip.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

PPO凭借其训练稳定性和较高的样本效率,长期以来是RLHF中训练对齐LLM的主流算法之一。