文本生成技术详解
--- title: "文本生成技术详解" description: "深入介绍大语言模型的文本生成原理和各种解码策略" tags: ["文本生成", "解码策略", "LLM", "NLP"] category: "llm" icon: "🧠"
文本生成技术详解
文本生成原理
大语言模型通过自回归方式生成文本:每次预测下一个token的概率分布,然后采样或选择token。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
def generate_with_probs(prompt, n_tokens=5):
"""展示生成过程的概率分布"""
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits[:, -1, :] # 最后一个位置的logits
probs = torch.softmax(logits, dim=-1)
top_k_probs, top_k_ids = torch.topk(probs, k=10)
print(f"输入: {prompt}")
print("Top 10 可能的下一个token:")
for prob, token_id in zip(top_k_probs[0], top_k_ids[0]):
token = tokenizer.decode(token_id)
print(f" '{token}': {prob:.4f}")
# generate_with_probs("The future of AI is")
解码策略
1. 贪心搜索(Greedy Search)
# 贪心搜索:每步选择概率最高的token
output = model.generate(
**inputs,
max_length=50,
do_sample=False # 贪心模式
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
# 问题:容易生成重复、无趣的文本
2. 束搜索(Beam Search)
# 束搜索:维护k个最优序列
output = model.generate(
**inputs,
max_length=50,
num_beams=5, # 束宽度
early_stopping=True, # 所有束完成时停止
no_repeat_ngram_size=2 # 避免重复n-gram
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
3. 采样(Sampling)
# 随机采样:从概率分布中随机采样
output = model.generate(
**inputs,
max_length=50,
do_sample=True,
temperature=1.0 # 温度参数
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
4. Top-K采样
# Top-K:只从概率最高的K个token中采样
output = model.generate(
**inputs,
max_length=50,
do_sample=True,
top_k=50, # 只考虑top 50个token
temperature=0.7
)
5. Top-P (Nucleus) 采样
# Top-P:累积概率达到p的最小token集合中采样
output = model.generate(
**inputs,
max_length=50,
do_sample=True,
top_p=0.9, # 累积概率阈值
temperature=0.7
)
6. 组合策略
# 实际应用中通常组合多种策略
output = model.generate(
**inputs,
max_length=100,
do_sample=True,
temperature=0.7,
top_k=50,
top_p=0.9,
repetition_penalty=1.2, # 重复惩罚
no_repeat_ngram_size=3 # 避免3-gram重复
)
温度参数详解
import matplotlib.pyplot as plt
import numpy as np
def show_temperature_effect(logits, temperatures=[0.1, 0.5, 1.0, 1.5, 2.0]):
"""展示不同温度对分布的影响"""
fig, axes = plt.subplots(1, len(temperatures), figsize=(15, 3))
for ax, temp in zip(axes, temperatures):
probs = torch.softmax(logits / temp, dim=-1).numpy()
ax.bar(range(len(probs)), probs)
ax.set_title(f"Temperature={temp}")
ax.set_xlabel("Token ID")
ax.set_ylabel("Probability")
plt.tight_layout()
plt.savefig("temperature_effect.png")
plt.show()
# 低温度:更确定,更保守
# 高温度:更随机,更有创意
实际应用示例
创意写作
def creative_writing(prompt, style="creative"):
"""创意写作生成"""
configs = {
"creative": {"temperature": 0.9, "top_p": 0.95, "top_k": 100},
"balanced": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
"conservative": {"temperature": 0.3, "top_p": 0.8, "top_k": 20}
}
config = configs.get(style, configs["balanced"])
inputs = tokenizer(prompt, return_tensors="pt")
output = model.generate(
**inputs,
max_length=200,
do_sample=True,
**config
)
return tokenizer.decode(output[0], skip_special_tokens=True)
# 生成故事开头
story = creative_writing("Once upon a time in a land far away", "creative")
print(story)
代码生成
def generate_code(prompt):
"""代码生成(需要代码模型)"""
# 对于代码生成,通常使用较低温度
inputs = tokenizer(prompt, return_tensors="pt")
output = model.generate(
**inputs,
max_length=300,
temperature=0.2, # 代码需要确定性
top_p=0.95,
do_sample=True
)
return tokenizer.decode(output[0], skip_special_tokens=True)
对话生成
def chat_response(history, user_input):
"""对话生成"""
# 构建对话格式
prompt = ""
for msg in history:
prompt += f"User: {msg['user']}\nAssistant: {msg['assistant']}\n"
prompt += f"User: {user_input}\nAssistant:"
inputs = tokenizer(prompt, return_tensors="pt")
output = model.generate(
**inputs,
max_length=200,
temperature=0.7,
top_p=0.9,
do_sample=True,
pad_token_id=tokenizer.eos_token_id
)
response = tokenizer.decode(output[0], skip_special_tokens=True)
# 只取Assistant:后面的部分
response = response.split("Assistant:")[-1].strip()
return response
生成质量控制
# 避免重复
output = model.generate(
**inputs,
max_length=100,
repetition_penalty=1.5, # 重复惩罚
no_repeat_ngram_size=3, # 禁止重复3-gram
do_sample=True,
temperature=0.7
)
# 控制输出长度
output = model.generate(
**inputs,
min_length=20, # 最小长度
max_length=100, # 最大长度
length_penalty=1.0, # 长度惩罚
do_sample=True
)
# 使用停止词
output = model.generate(
**inputs,
max_length=200,
stop=["\n\n", "User:"], # 遇到停止词则停止
do_sample=True
)
总结
文本生成是大语言模型的核心能力,通过合理选择解码策略和参数,可以在创造性、连贯性和可控性之间找到平衡。理解不同策略的特点,对于构建高质量的生成应用至关重要。