思维树
---
title: "思维树"
description: "Tree-of-Thought推理框架详解,支持多路径搜索与规划"
tags: ["Tree-of-Thought", "思维树", "束搜索", "规划"]
category: "llm"
icon: "🧠"
---
思维树
思维树(Tree-of-Thought, ToT)将线性思维链扩展为树形搜索结构,允许模型在多个推理分支中探索并选择最优路径。该方法由Yao等人于2023年提出,特别适合需要规划和前瞻的复杂任务。
核心思想
传统CoT是线性的:问题 → 思考1 → 思考2 → 答案
ToT是树形的:问题 → {思考1a, 思考1b} → {思考2a, 思考2b, 思考2c} → ... → 选择最优答案
实现框架
import heapq
import math
import re
from collections import deque
from dataclasses import dataclass, field
from typing import List, Optional
@dataclass
class ThoughtNode:
    """One node of a thought tree: a reasoning step plus its tree links.

    Equality compares ``state``, ``thought`` and ``score`` only.  The
    ``parent``/``children`` links are excluded because they form reference
    cycles (a child points at its parent, whose ``children`` list contains
    the child), and a field-by-field ``==`` over them recurses without end
    when two distinct but structurally equal nodes are compared.

    Ordering is inverted on purpose: ``a < b`` is true when ``a`` has the
    *higher* score, so ``sorted()`` and ``heapq`` yield the best node first.
    """

    state: str  # text of the reasoning state at this node
    thought: str  # the reasoning step attached to this node
    score: float = 0.0  # evaluation score; higher is better
    # Tree links: hidden from repr() and excluded from == (see class docstring).
    parent: Optional['ThoughtNode'] = field(
        default=None, repr=False, compare=False
    )
    children: List['ThoughtNode'] = field(
        default_factory=list, repr=False, compare=False
    )

    def __lt__(self, other):
        """Return True when this node scores higher than ``other``."""
        if not isinstance(other, ThoughtNode):
            return NotImplemented
        return self.score > other.score  # higher score sorts first
class TreeOfThought:
    """Tree-of-Thought solver: expand several thoughts per node, keep the best.

    ``tokenizer(prompt, return_tensors="pt")`` must return a mapping with an
    ``"input_ids"`` entry, ``model.generate(**inputs, ...)`` must return a
    batch of token-id sequences, and ``tokenizer.decode`` must turn one
    sequence back into text (the Hugging Face ``transformers`` calling
    convention).
    """

    # First integer or decimal number in the evaluator's reply.
    _NUMBER_RE = re.compile(r"\d+(?:\.\d+)?")

    def __init__(self, model, tokenizer, num_branches=3, max_depth=5):
        """Store the model, tokenizer and search limits.

        Args:
            model: object exposing ``generate``.
            tokenizer: callable tokenizer exposing ``decode``.
            num_branches: thoughts generated per node, and also the number
                of nodes kept per level (the beam width).
            max_depth: maximum number of expansion rounds.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.num_branches = num_branches
        self.max_depth = max_depth

    def solve(self, problem):
        """Run the level-by-level beam search and return the best answer.

        Args:
            problem: problem statement used as the root state.

        Returns:
            Whatever ``_extract_answer`` returns for the best-scoring node of
            the last non-empty level (the root if nothing could be expanded).
        """
        root = ThoughtNode(state=problem, thought="")
        current_nodes = [root]
        for _ in range(self.max_depth):
            next_nodes = []
            for node in current_nodes:
                # Branch: several candidate next steps per frontier node.
                for thought in self._generate_thoughts(node):
                    child = ThoughtNode(
                        state=self._update_state(node.state, thought),
                        thought=thought,
                        parent=node,
                    )
                    child.score = self._evaluate_thought(child)
                    node.children.append(child)
                    next_nodes.append(child)
            if not next_nodes:
                # Nothing was generated: keep the previous frontier instead
                # of ending with an empty one (max() below would raise).
                break
            # Beam step: only the top-scoring nodes are expanded next round.
            next_nodes.sort(key=lambda n: n.score, reverse=True)
            current_nodes = next_nodes[:self.num_branches]
        best_leaf = max(current_nodes, key=lambda n: n.score)
        return self._extract_answer(best_leaf)

    def _generate_text(self, prompt, **generate_kwargs):
        """Generate a continuation of ``prompt`` and return only the new text.

        Assumes a decoder-only (causal) model whose ``generate`` output starts
        with the prompt tokens; those are sliced off so callers never parse
        their own prompt.  NOTE(review): for an encoder-decoder model the
        output does not echo the prompt and this slice must be removed.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        outputs = self.model.generate(**inputs, **generate_kwargs)
        prompt_len = len(inputs["input_ids"][0])
        return self.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        )

    def _generate_thoughts(self, node):
        """Return up to ``num_branches`` candidate next thoughts for ``node``.

        The model is asked for one thought per line; blank lines are dropped.
        """
        prompt = f"""基于当前推理状态,生成{self.num_branches}个不同的下一步思考方向。
当前状态:
{node.state}
请生成{self.num_branches}个候选思考(每个一行):"""
        response = self._generate_text(
            prompt, max_new_tokens=256, temperature=0.8, do_sample=True
        )
        thoughts = [line.strip() for line in response.split("\n") if line.strip()]
        return thoughts[:self.num_branches]

    def _evaluate_thought(self, node):
        """Score ``node.thought`` in [0, 1] by asking the model for a 1-10 rating.

        The first number in the reply is divided by 10 and clamped.  Returns
        the neutral value 0.5 when the reply contains no number.
        """
        question = node.parent.state if node.parent is not None else node.state
        prompt = f"""评估以下推理步骤的质量(1-10分)。
问题:{question}
当前思考:{node.thought}
评分(1-10):"""
        response = self._generate_text(prompt, max_new_tokens=32)
        match = self._NUMBER_RE.search(response)
        if match is None:
            return 0.5
        # Only the first number counts, so "8/10" reads as 8 rather than 81.
        return min(max(float(match.group()) / 10, 0.0), 1.0)

    def _update_state(self, state, thought):
        """Return the state reached by appending ``thought`` to ``state``.

        Minimal default (newline-joined history); override for tasks that
        need a structured state.
        """
        return f"{state}\n{thought}"

    def _extract_answer(self, leaf):
        """Return the answer for the chosen node: its final thought.

        Minimal default; override to parse a task-specific answer.
        """
        return leaf.thought
束搜索实现
class BeamSearchToT:
    """Beam search over chains of thoughts, ranked by accumulated score.

    ``tokenizer(prompt, return_tensors="pt")`` must return a mapping with an
    ``"input_ids"`` entry, ``model.generate(**inputs, ...)`` must return a
    batch of token-id sequences, and ``tokenizer.decode`` must turn one
    sequence back into text (the Hugging Face ``transformers`` calling
    convention).
    """

    # First integer or decimal number in the scorer's reply.
    _NUMBER_RE = re.compile(r"\d+(?:\.\d+)?")

    def __init__(self, model, tokenizer, beam_width=3):
        """Store the model, tokenizer and the number of beams kept per step."""
        self.model = model
        self.tokenizer = tokenizer
        self.beam_width = beam_width

    def search(self, problem, num_steps=5):
        """Run beam search and return the best reasoning path.

        Args:
            problem: problem statement.
            num_steps: number of extension rounds after the initial thoughts.

        Returns:
            Dict with ``"path"`` (list of thoughts), ``"score"`` (sum of the
            per-step scores) and ``"answer"``.  When the model yields no
            initial thought the result is an empty path with score 0.0.
        """
        initial_thoughts = self._generate_thoughts(problem)
        # Each beam is (last thought, accumulated score, full path).
        beams = [(thought, 0.0, [thought]) for thought in initial_thoughts]

        for _ in range(num_steps):
            all_candidates = []
            for _, score, path in beams:
                next_thoughts = self._generate_thoughts(
                    f"{problem}\n思考历史:{' → '.join(path)}"
                )
                for next_thought in next_thoughts:
                    new_score = score + self._score_thought(next_thought)
                    all_candidates.append(
                        (next_thought, new_score, path + [next_thought])
                    )
            if not all_candidates:
                # No beam could be extended: keep what we have.
                break
            all_candidates.sort(key=lambda cand: cand[1], reverse=True)
            beams = all_candidates[:self.beam_width]

        if not beams:
            return {"path": [], "score": 0.0, "answer": ""}
        best = max(beams, key=lambda beam: beam[1])
        return {
            "path": best[2],
            "score": best[1],
            "answer": self._extract_final(best[2]),
        }

    def _generate_text(self, prompt, **generate_kwargs):
        """Generate a continuation of ``prompt`` and return only the new text.

        Assumes a decoder-only (causal) model whose ``generate`` output starts
        with the prompt tokens; those are sliced off so callers never parse
        their own prompt.  NOTE(review): for an encoder-decoder model the
        output does not echo the prompt and this slice must be removed.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        outputs = self.model.generate(**inputs, **generate_kwargs)
        prompt_len = len(inputs["input_ids"][0])
        return self.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        )

    def _generate_thoughts(self, context):
        """Return up to ``beam_width`` candidate next thoughts for ``context``.

        The model is asked for one thought per line; blank lines are dropped.
        """
        prompt = f"""基于当前推理状态,生成{self.beam_width}个不同的下一步思考方向。
当前状态:
{context}
请生成{self.beam_width}个候选思考(每个一行):"""
        response = self._generate_text(
            prompt, max_new_tokens=256, temperature=0.8, do_sample=True
        )
        thoughts = [line.strip() for line in response.split("\n") if line.strip()]
        return thoughts[:self.beam_width]

    def _score_thought(self, thought):
        """Return the model's 0-1 quality rating for ``thought``.

        The first number in the reply is used, clamped to [0, 1]; 0.5 is
        returned when the reply contains no number.
        """
        prompt = f"评估推理质量(0-1):{thought}"
        response = self._generate_text(prompt, max_new_tokens=16)
        match = self._NUMBER_RE.search(response)
        if match is None:
            return 0.5
        return min(max(float(match.group()), 0.0), 1.0)

    def _extract_final(self, path):
        """Return the last thought of ``path`` as the answer ('' if empty).

        Minimal default; override to parse a task-specific answer.
        """
        return path[-1] if path else ""
搜索策略对比
class ToTStrategies:
    """Search strategies over an already-built thought tree.

    Nodes only need a ``children`` list; ``mcts_search`` additionally reads
    each leaf's ``score``.
    """

    @staticmethod
    def bfs_search(root, goal_test, max_depth=5):
        """Breadth-first search.

        Returns the first node, in level order, that satisfies ``goal_test``
        or sits at ``max_depth``; ``None`` if the tree is exhausted first.
        """
        queue = deque([(root, 0)])
        while queue:
            node, depth = queue.popleft()
            if goal_test(node) or depth >= max_depth:
                return node
            queue.extend((child, depth + 1) for child in node.children)
        return None

    @staticmethod
    def dfs_search(node, goal_test, max_depth=5, depth=0):
        """Depth-first search.

        Returns the first node, in pre-order, that satisfies ``goal_test`` or
        sits at ``max_depth``; ``None`` if no such node exists below ``node``.
        """
        if goal_test(node) or depth >= max_depth:
            return node
        for child in node.children:
            result = ToTStrategies.dfs_search(
                child, goal_test, max_depth, depth + 1
            )
            # Compare with None explicitly: a found node may itself be falsy
            # (for example a node type that defines __len__ or __bool__).
            if result is not None:
                return result
        return None

    @staticmethod
    def mcts_search(root, num_simulations=100, exploration=1.4):
        """Monte-Carlo tree search (UCB1 selection) over the existing tree.

        Each simulation descends from ``root`` by UCB1, takes the reached
        leaf's ``score`` as the reward and adds it to every node on the
        path.  No new nodes are created.

        Args:
            root: root of the tree to search.
            num_simulations: number of select/back-up rounds.
            exploration: UCB1 exploration constant; larger values favour
                rarely visited children.

        Returns:
            The node reached by following the most-visited child from the
            root until a leaf or an unvisited subtree (``root`` itself when
            no simulation ran).
        """
        # Keyed by id(): nodes may be unhashable (e.g. eq-enabled dataclasses).
        visits = {}
        totals = {}

        def ucb1(parent, child):
            count = visits.get(id(child), 0)
            if count == 0:
                return math.inf  # try every child once before exploiting
            mean = totals[id(child)] / count
            bonus = math.sqrt(math.log(visits[id(parent)]) / count)
            return mean + exploration * bonus

        for _ in range(num_simulations):
            # Selection: walk down to a leaf.
            path = [root]
            while path[-1].children:
                parent = path[-1]
                path.append(
                    max(parent.children, key=lambda c: ucb1(parent, c))
                )
            # Simulation: the leaf's stored score stands in for a rollout.
            reward = path[-1].score
            # Back-up: credit every node on the path.
            for visited in path:
                visits[id(visited)] = visits.get(id(visited), 0) + 1
                totals[id(visited)] = totals.get(id(visited), 0.0) + reward

        node = root
        while node.children:
            favourite = max(node.children, key=lambda c: visits.get(id(c), 0))
            if visits.get(id(favourite), 0) == 0:
                break
            node = favourite
        return node
适用场景
| 场景 | 说明 | 推荐策略 |
|---|---|---|
| 数学证明 | 需要多步推导和回溯 | DFS + 剪枝 |
| 文本创作 | 需要探索不同叙事方向 | BFS + 多样性 |
| 规划任务 | 需要全局最优解 | MCTS |
| 代码生成 | 需要验证中间结果 | 束搜索 + 验证 |
ToT通过系统性地探索多条推理路径,显著提升了模型在复杂规划和搜索任务中的表现。