← 返回首页
🧠

思维树

📂 llm ⏱ 3 min 481 words

---
title: "思维树"
description: "Tree-of-Thought推理框架详解,支持多路径搜索与规划"
tags: ["Tree-of-Thought", "思维树", "束搜索", "规划"]
category: "llm"
icon: "🧠"
---

思维树

思维树(Tree-of-Thought, ToT)将线性思维链扩展为树形搜索结构,允许模型在多个推理分支中探索并选择最优路径。该方法由Yao等人于2023年提出,特别适合需要规划和前瞻的复杂任务。

核心思想

传统CoT是线性的:问题 → 思考1 → 思考2 → 答案

ToT是树形的:问题 → {思考1a, 思考1b} → {思考2a, 思考2b, 思考2c} → ... → 选择最优答案

实现框架

import heapq
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class ThoughtNode:
    """One node of the reasoning tree.

    Equality (generated by ``@dataclass``) compares ``state``, ``thought``,
    ``score`` and the ``children`` subtree.  ``parent`` is deliberately
    excluded from comparison: a child references its parent and the parent
    lists that child, so comparing through both links makes two distinct but
    structurally equal trees recurse until ``RecursionError``.

    Ordering is inverted on purpose (see ``__lt__``) so that the node with the
    highest score sorts first and is popped first from a ``heapq`` min-heap.
    """

    state: str  # reasoning state text held by this node
    thought: str  # thought text attached to this node
    score: float = 0.0  # evaluation score; larger is better
    # Back-reference to the parent; kept out of repr and equality to avoid
    # walking the parent <-> children cycle.
    parent: Optional['ThoughtNode'] = field(
        default=None, repr=False, compare=False
    )
    children: List['ThoughtNode'] = field(default_factory=list, repr=False)

    def __lt__(self, other):
        """Return True when this node has the *higher* score.

        The comparison is reversed so "smaller" means "better": ``sorted()``
        and ``heapq`` then yield the best-scoring node first.
        """
        return self.score > other.score

class TreeOfThought:
    """Tree-of-Thought solver that explores the tree level by level.

    At every depth each frontier node is expanded into candidate thoughts,
    every candidate is scored by the model, and only the ``num_branches``
    best candidates are kept for the next level (beam search).

    ``model`` and ``tokenizer`` are used through the Hugging Face style calls
    ``tokenizer(prompt, return_tensors="pt")``, ``model.generate(...)`` and
    ``tokenizer.decode(...)``.
    """

    def __init__(self, model, tokenizer, num_branches=3, max_depth=5):
        """Store the model, the tokenizer and the search limits.

        Args:
            model: Object exposing ``generate(**inputs, ...)``.
            tokenizer: Callable tokenizer that also exposes ``decode``.
            num_branches: Thoughts requested per node; also the beam width.
            max_depth: Maximum number of expansion levels.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.num_branches = num_branches
        self.max_depth = max_depth

    def solve(self, problem):
        """Run the search for ``problem`` and return the extracted answer.

        Args:
            problem: Problem statement used as the root state.

        Returns:
            Whatever ``_extract_answer`` returns for the best node of the
            last non-empty frontier.
        """
        root = ThoughtNode(state=problem, thought="")
        frontier = [root]

        for _ in range(self.max_depth):
            candidates = []
            for node in frontier:
                for thought in self._generate_thoughts(node):
                    child = ThoughtNode(
                        state=self._update_state(node.state, thought),
                        thought=thought,
                        parent=node,
                    )
                    child.score = self._evaluate_thought(child)
                    node.children.append(child)
                    candidates.append(child)

            # The model may produce no usable thought at all.  Keep the last
            # non-empty frontier instead of letting max() fail on [].
            if not candidates:
                break

            # Beam step: keep only the best-scoring candidates.
            candidates.sort(key=lambda n: n.score, reverse=True)
            frontier = candidates[:self.num_branches]

        best_leaf = max(frontier, key=lambda n: n.score)
        return self._extract_answer(best_leaf)

    def _generate_text(self, prompt, **generate_kwargs):
        """Return only the text the model generated for ``prompt``.

        Decoder-only Hugging Face models return the prompt tokens followed by
        the continuation, so the prompt prefix is cut off before decoding.
        Models whose ``config.is_encoder_decoder`` is true return only the
        generated sequence and are decoded unchanged.
        NOTE(review): assumes Hugging Face ``generate`` semantics - confirm
        for custom model wrappers.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        outputs = self.model.generate(**inputs, **generate_kwargs)
        sequence = outputs[0]
        config = getattr(self.model, "config", None)
        if not getattr(config, "is_encoder_decoder", False):
            prompt_length = len(inputs["input_ids"][0])
            sequence = sequence[prompt_length:]
        return self.tokenizer.decode(sequence, skip_special_tokens=True)

    def _generate_thoughts(self, node):
        """Ask the model for up to ``num_branches`` next thoughts.

        Args:
            node: Node whose ``state`` is shown to the model.

        Returns:
            The non-empty, stripped lines of the model reply, at most
            ``num_branches`` of them.
        """
        prompt = (
            f"基于当前推理状态,生成{self.num_branches}个不同的下一步思考方向。\n"
            "\n"
            "当前状态:\n"
            f"{node.state}\n"
            "\n"
            f"请生成{self.num_branches}个候选思考(每个一行):"
        )
        response = self._generate_text(
            prompt, max_new_tokens=256, temperature=0.8, do_sample=True
        )

        # One candidate per non-empty line.
        thoughts = [line.strip() for line in response.split("\n") if line.strip()]
        return thoughts[:self.num_branches]

    def _evaluate_thought(self, node):
        """Score ``node.thought`` with the model and normalise to [0, 1].

        The model is asked for a 1-10 rating; the first number in its reply
        is divided by 10 and clamped.

        Args:
            node: Node to rate; the parent's state is shown as the problem,
                or the node's own state when it has no parent.

        Returns:
            A float in [0, 1]; 0.5 when the reply contains no number.
        """
        import re

        prompt = (
            "评估以下推理步骤的质量(1-10分)。\n"
            "\n"
            f"问题:{node.parent.state if node.parent else node.state}\n"
            f"当前思考:{node.thought}\n"
            "\n"
            "评分(1-10):"
        )
        response = self._generate_text(prompt, max_new_tokens=32)

        # Take the first number only: joining every digit of the reply would
        # read "8/10" as 81 and "7.5" as 75.
        match = re.search(r"\d+(?:\.\d+)?", response)
        if match is None:
            return 0.5
        score = float(match.group()) / 10
        return min(max(score, 0.0), 1.0)

    def _update_state(self, state, thought):
        """Return the state reached by appending ``thought`` to ``state``.

        Default implementation: the thought is added on a new line.
        Subclasses may override it with a task-specific state update.
        """
        return f"{state}\n{thought}"

    def _extract_answer(self, node):
        """Return the thoughts on the path from the root to ``node``.

        Default implementation: the non-empty thoughts are joined root-first,
        one per line, so a node without any thought yields ``""``.
        Subclasses may override it with task-specific answer extraction.
        """
        thoughts = []
        current = node
        while current is not None:
            if current.thought:
                thoughts.append(current.thought)
            current = current.parent
        return "\n".join(reversed(thoughts))

束搜索实现

class BeamSearchToT:
    """Beam search over chains of thoughts, keeping the top-k paths per step.

    ``model`` and ``tokenizer`` are used through the Hugging Face style calls
    ``tokenizer(prompt, return_tensors="pt")``, ``model.generate(...)`` and
    ``tokenizer.decode(...)``.
    """

    def __init__(self, model, tokenizer, beam_width=3, max_steps=5):
        """Store the model, the tokenizer and the search limits.

        Args:
            model: Object exposing ``generate(**inputs, ...)``.
            tokenizer: Callable tokenizer that also exposes ``decode``.
            beam_width: Number of paths kept after every step.
            max_steps: Number of extension steps after the initial thoughts.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.beam_width = beam_width
        self.max_steps = max_steps

    def search(self, problem):
        """Search for the best-scoring chain of thoughts for ``problem``.

        Args:
            problem: Problem statement shown to the model.

        Returns:
            A dict with ``"path"`` (list of thoughts), ``"score"`` (sum of the
            scores of the thoughts added after the initial one) and
            ``"answer"``.  When the model yields no initial thought the
            result is an empty path with score 0.0 and answer ``""``.
        """
        # Each beam is (latest thought, accumulated score, full path).
        initial_thoughts = self._generate_thoughts(problem)
        beams = [(thought, 0.0, [thought]) for thought in initial_thoughts]
        if not beams:
            return {"path": [], "score": 0.0, "answer": ""}

        for _ in range(self.max_steps):
            candidates = []
            for _, score, path in beams:
                context = f"{problem}\n思考历史:{' → '.join(path)}"
                for next_thought in self._generate_thoughts(context):
                    candidates.append((
                        next_thought,
                        score + self._score_thought(next_thought),
                        path + [next_thought],
                    ))

            # Nothing could be extended: keep the beams found so far rather
            # than replacing them with an empty list.
            if not candidates:
                break

            candidates.sort(key=lambda item: item[1], reverse=True)
            beams = candidates[:self.beam_width]

        best = max(beams, key=lambda item: item[1])
        return {
            "path": best[2],
            "score": best[1],
            "answer": self._extract_final(best[2]),
        }

    def _generate_text(self, prompt, **generate_kwargs):
        """Return only the text the model generated for ``prompt``.

        Decoder-only Hugging Face models return the prompt tokens followed by
        the continuation, so the prompt prefix is cut off before decoding.
        Models whose ``config.is_encoder_decoder`` is true return only the
        generated sequence and are decoded unchanged.
        NOTE(review): assumes Hugging Face ``generate`` semantics - confirm
        for custom model wrappers.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        outputs = self.model.generate(**inputs, **generate_kwargs)
        sequence = outputs[0]
        config = getattr(self.model, "config", None)
        if not getattr(config, "is_encoder_decoder", False):
            prompt_length = len(inputs["input_ids"][0])
            sequence = sequence[prompt_length:]
        return self.tokenizer.decode(sequence, skip_special_tokens=True)

    def _generate_thoughts(self, context):
        """Ask the model for up to ``beam_width`` next thoughts.

        Args:
            context: Text describing the problem and the thoughts so far.

        Returns:
            The non-empty, stripped lines of the model reply, at most
            ``beam_width`` of them.
        """
        prompt = (
            f"基于当前推理状态,生成{self.beam_width}个不同的下一步思考方向。\n"
            "\n"
            "当前状态:\n"
            f"{context}\n"
            "\n"
            f"请生成{self.beam_width}个候选思考(每个一行):"
        )
        response = self._generate_text(
            prompt, max_new_tokens=256, temperature=0.8, do_sample=True
        )
        thoughts = [line.strip() for line in response.split("\n") if line.strip()]
        return thoughts[:self.beam_width]

    def _score_thought(self, thought):
        """Ask the model to rate ``thought`` and return a score in [0, 1].

        Returns:
            The first number found in the model reply, clamped to [0, 1];
            0.5 when the reply contains no number.
        """
        import re

        prompt = f"评估推理质量(0-1):{thought}"
        response = self._generate_text(prompt, max_new_tokens=16)

        # Search for a number instead of converting a fixed-width prefix,
        # which breaks on leading whitespace or any preamble.
        match = re.search(r"\d+(?:\.\d+)?", response)
        if match is None:
            return 0.5
        return min(max(float(match.group()), 0.0), 1.0)

    def _extract_final(self, path):
        """Return the last thought of ``path`` (``""`` for an empty path).

        Default implementation; subclasses may override it with
        task-specific answer extraction.
        """
        return path[-1] if path else ""

搜索策略对比

class ToTStrategies:
    """Search strategies over an already-built tree of thought nodes.

    Every strategy only reads ``node.children`` (and, for MCTS, ``node.score``);
    none of them generates new nodes.
    """

    @staticmethod
    def bfs_search(root, goal_test, max_depth=5):
        """Breadth-first search from ``root``.

        Args:
            root: Node to start from.
            goal_test: Callable returning a truthy value for a goal node.
            max_depth: Depth at which a node is returned without expansion.

        Returns:
            The first node in breadth-first order that satisfies
            ``goal_test`` or lies at depth ``max_depth`` or deeper; ``None``
            when the tree is exhausted before either happens.
        """
        from collections import deque

        queue = deque([(root, 0)])
        while queue:
            node, depth = queue.popleft()
            if goal_test(node) or depth >= max_depth:
                return node
            queue.extend((child, depth + 1) for child in node.children)
        return None

    @staticmethod
    def dfs_search(node, goal_test, max_depth=5, depth=0):
        """Depth-first (pre-order) search from ``node``.

        Args:
            node: Node to start from.
            goal_test: Callable returning a truthy value for a goal node.
            max_depth: Depth at which a node is returned without expansion.
            depth: Depth of ``node``; used by the recursion.

        Returns:
            The first node in pre-order that satisfies ``goal_test`` or lies
            at depth ``max_depth`` or deeper; ``None`` when no such node
            exists below ``node``.
        """
        if goal_test(node) or depth >= max_depth:
            return node
        for child in node.children:
            result = ToTStrategies.dfs_search(child, goal_test, max_depth, depth + 1)
            # Compare against None explicitly: a found node may itself be
            # falsy (for example a node type that defines __len__).
            if result is not None:
                return result
        return None

    @staticmethod
    def mcts_search(root, num_simulations=100, exploration=1.4):
        """Monte Carlo tree search over the existing tree using UCB1.

        Each simulation walks from ``root`` to a leaf, choosing at every level
        the child with the highest UCB1 value (unvisited children first),
        takes the leaf's ``score`` as the reward and adds it to the statistics
        of every node on the walked path.  No new nodes are created.

        Args:
            root: Root of the tree to search.
            num_simulations: Number of root-to-leaf simulations.
            exploration: UCB1 exploration constant; larger values favour
                rarely visited children.

        Returns:
            The node reached by repeatedly following the most-visited child
            (ties broken by mean reward) from ``root``; ``root`` itself when
            nothing below it was visited.
        """
        import math

        # Statistics live outside the nodes so any node type can be used.
        visits = {}  # id(node) -> number of simulations through the node
        totals = {}  # id(node) -> sum of rewards backed up through the node

        def ucb1(parent, child):
            count = visits.get(id(child), 0)
            if count == 0:
                return math.inf  # always try an unvisited child first
            mean = totals[id(child)] / count
            bonus = exploration * math.sqrt(math.log(visits[id(parent)]) / count)
            return mean + bonus

        for _ in range(num_simulations):
            # Selection: descend to a leaf.
            path = [root]
            while path[-1].children:
                parent = path[-1]
                path.append(max(parent.children, key=lambda c: ucb1(parent, c)))

            # Evaluation: the leaf's stored score is the reward.
            reward = path[-1].score

            # Backpropagation: update every node on the walked path.
            for visited in path:
                key = id(visited)
                visits[key] = visits.get(key, 0) + 1
                totals[key] = totals.get(key, 0.0) + reward

        # Final choice: follow the most-visited children.
        node = root
        while True:
            explored = [c for c in node.children if visits.get(id(c), 0) > 0]
            if not explored:
                return node
            node = max(
                explored,
                key=lambda c: (visits[id(c)], totals[id(c)] / visits[id(c)]),
            )

适用场景

| 场景 | 说明 | 推荐策略 |
| --- | --- | --- |
| 数学证明 | 需要多步推导和回溯 | DFS + 剪枝 |
| 文本创作 | 需要探索不同叙事方向 | BFS + 多样性 |
| 规划任务 | 需要全局最优解 | MCTS |
| 代码生成 | 需要验证中间结果 | 束搜索 + 验证 |

ToT通过系统性地探索多条推理路径,显著提升了模型在复杂规划和搜索任务中的表现。