← 返回首页
🧠

文本生成技术详解

📂 llm ⏱ 2 min 395 words

--- title: "文本生成技术详解" description: "深入介绍大语言模型的文本生成原理和各种解码策略" tags: ["文本生成", "解码策略", "LLM", "NLP"] category: "llm" icon: "🧠"

文本生成技术详解

文本生成原理

大语言模型通过自回归方式生成文本:每次预测下一个token的概率分布,然后采样或选择token。

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

def generate_with_probs(prompt, n_tokens=5):
    """展示生成过程的概率分布"""
    inputs = tokenizer(prompt, return_tensors="pt")
    
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits[:, -1, :]  # 最后一个位置的logits
    
    probs = torch.softmax(logits, dim=-1)
    top_k_probs, top_k_ids = torch.topk(probs, k=10)
    
    print(f"输入: {prompt}")
    print("Top 10 可能的下一个token:")
    for prob, token_id in zip(top_k_probs[0], top_k_ids[0]):
        token = tokenizer.decode(token_id)
        print(f"  '{token}': {prob:.4f}")

# generate_with_probs("The future of AI is")

解码策略

1. 贪心搜索(Greedy Search)

# 贪心搜索:每步选择概率最高的token
output = model.generate(
    **inputs,
    max_length=50,
    do_sample=False  # 贪心模式
)

print(tokenizer.decode(output[0], skip_special_tokens=True))
# 问题:容易生成重复、无趣的文本

2. 束搜索(Beam Search)

# 束搜索:维护k个最优序列
output = model.generate(
    **inputs,
    max_length=50,
    num_beams=5,          # 束宽度
    early_stopping=True,  # 所有束完成时停止
    no_repeat_ngram_size=2  # 避免重复n-gram
)

print(tokenizer.decode(output[0], skip_special_tokens=True))

3. 采样(Sampling)

# 随机采样:从概率分布中随机采样
output = model.generate(
    **inputs,
    max_length=50,
    do_sample=True,
    temperature=1.0  # 温度参数
)

print(tokenizer.decode(output[0], skip_special_tokens=True))

4. Top-K采样

# Top-K:只从概率最高的K个token中采样
output = model.generate(
    **inputs,
    max_length=50,
    do_sample=True,
    top_k=50,       # 只考虑top 50个token
    temperature=0.7
)

5. Top-P (Nucleus) 采样

# Top-P:累积概率达到p的最小token集合中采样
output = model.generate(
    **inputs,
    max_length=50,
    do_sample=True,
    top_p=0.9,      # 累积概率阈值
    temperature=0.7
)

6. 组合策略

# 实际应用中通常组合多种策略
output = model.generate(
    **inputs,
    max_length=100,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.2,  # 重复惩罚
    no_repeat_ngram_size=3   # 避免3-gram重复
)

温度参数详解

import matplotlib.pyplot as plt
import numpy as np

def show_temperature_effect(logits, temperatures=[0.1, 0.5, 1.0, 1.5, 2.0]):
    """展示不同温度对分布的影响"""
    fig, axes = plt.subplots(1, len(temperatures), figsize=(15, 3))
    
    for ax, temp in zip(axes, temperatures):
        probs = torch.softmax(logits / temp, dim=-1).numpy()
        ax.bar(range(len(probs)), probs)
        ax.set_title(f"Temperature={temp}")
        ax.set_xlabel("Token ID")
        ax.set_ylabel("Probability")
    
    plt.tight_layout()
    plt.savefig("temperature_effect.png")
    plt.show()

# 低温度:更确定,更保守
# 高温度:更随机,更有创意

实际应用示例

创意写作

def creative_writing(prompt, style="creative"):
    """创意写作生成"""
    configs = {
        "creative": {"temperature": 0.9, "top_p": 0.95, "top_k": 100},
        "balanced": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
        "conservative": {"temperature": 0.3, "top_p": 0.8, "top_k": 20}
    }
    
    config = configs.get(style, configs["balanced"])
    
    inputs = tokenizer(prompt, return_tensors="pt")
    output = model.generate(
        **inputs,
        max_length=200,
        do_sample=True,
        **config
    )
    
    return tokenizer.decode(output[0], skip_special_tokens=True)

# 生成故事开头
story = creative_writing("Once upon a time in a land far away", "creative")
print(story)

代码生成

def generate_code(prompt):
    """代码生成(需要代码模型)"""
    # 对于代码生成,通常使用较低温度
    inputs = tokenizer(prompt, return_tensors="pt")
    output = model.generate(
        **inputs,
        max_length=300,
        temperature=0.2,  # 代码需要确定性
        top_p=0.95,
        do_sample=True
    )
    
    return tokenizer.decode(output[0], skip_special_tokens=True)

对话生成

def chat_response(history, user_input):
    """对话生成"""
    # 构建对话格式
    prompt = ""
    for msg in history:
        prompt += f"User: {msg['user']}\nAssistant: {msg['assistant']}\n"
    prompt += f"User: {user_input}\nAssistant:"
    
    inputs = tokenizer(prompt, return_tensors="pt")
    output = model.generate(
        **inputs,
        max_length=200,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    # 只取Assistant:后面的部分
    response = response.split("Assistant:")[-1].strip()
    
    return response

生成质量控制

# 避免重复
output = model.generate(
    **inputs,
    max_length=100,
    repetition_penalty=1.5,      # 重复惩罚
    no_repeat_ngram_size=3,       # 禁止重复3-gram
    do_sample=True,
    temperature=0.7
)

# 控制输出长度
output = model.generate(
    **inputs,
    min_length=20,   # 最小长度
    max_length=100,  # 最大长度
    length_penalty=1.0,  # 长度惩罚
    do_sample=True
)

# 使用停止词
output = model.generate(
    **inputs,
    max_length=200,
    stop=["\n\n", "User:"],  # 遇到停止词则停止
    do_sample=True
)

总结

文本生成是大语言模型的核心能力,通过合理选择解码策略和参数,可以在创造性、连贯性和可控性之间找到平衡。理解不同策略的特点,对于构建高质量的生成应用至关重要。