Spaces:
Runtime error
Runtime error
"""Tree of Thoughts reasoning implementation with advanced tree exploration.""" | |
import heapq
import json
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Set, Tuple

from .base import ReasoningStrategy, StrategyResult
class NodeType(Enum):
    """Kinds of node that can appear in the thought tree.

    Each member's value is the lower-case string used when a node is
    serialised into a reasoning trace and when looking up per-type weights.
    """

    # Single top-level node holding the query.
    ROOT = "root"
    # Candidate explanation attached to the root.
    HYPOTHESIS = "hypothesis"
    # Supporting material attached to a hypothesis.
    EVIDENCE = "evidence"
    # Commentary attached to an evidence node.
    ANALYSIS = "analysis"
    # Per-hypothesis summary node.
    SYNTHESIS = "synthesis"
    # Per-hypothesis assessment node.
    EVALUATION = "evaluation"
    # Final answer node attached to the root.
    CONCLUSION = "conclusion"
@dataclass(eq=False)
class TreeNode:
    """Represents a node in the thought tree.

    Nodes are created with keyword arguments, so the class must be a
    dataclass to get a generated ``__init__``.

    ``eq=False`` keeps identity-based equality and hashing (what a plain
    class provides). A field-by-field ``__eq__`` would recurse through the
    ``parent``/``children`` cycle when comparing two distinct nodes.

    Attributes:
        id: Identifier of the node within its tree.
        type: The ``NodeType`` of this node.
        content: Text carried by the node.
        confidence: Confidence score assigned to the content.
        children: Child nodes, in insertion order.
        parent: Parent node, or ``None`` for the root.
        metadata: Free-form extra information about the node.
        depth: Distance from the root (the root has depth 0).
        evaluation_score: Optional evaluation score; defaults to ``0.0``.
        timestamp: ISO-8601 creation time, generated per instance.
    """

    id: str
    type: 'NodeType'
    content: str
    confidence: float
    children: List['TreeNode']
    # Left out of repr so printing a node does not dump all its ancestors.
    parent: Optional['TreeNode'] = field(repr=False)
    metadata: Dict[str, Any]
    depth: int
    evaluation_score: float = 0.0
    # A factory is required: a plain default would be evaluated once at class
    # definition time and every node would share the same timestamp.
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
class TreeOfThoughtsStrategy(ReasoningStrategy):
    """
    Tree of Thoughts reasoning strategy.

    For each query the strategy builds a thought tree in fixed phases:

    - hypotheses are attached to the root,
    - evidence is attached to each hypothesis,
    - an analysis node is attached to each evidence node,
    - one synthesis and one evaluation node are attached to each hypothesis.

    The best-scoring root-to-leaf path is then turned into a conclusion node.

    NOTE(review): ``min_confidence``, ``parallel_threshold`` and
    ``learning_rate`` are stored on the instance but no method in this class
    reads them; beam search, pruning and meta-learning are not implemented
    here.
    """

    def __init__(self,
                 min_confidence: float = 0.7,
                 parallel_threshold: int = 3,
                 learning_rate: float = 0.1,
                 strategy_weights: Optional[Dict[str, float]] = None):
        """Initialize Tree of Thoughts reasoning.

        Args:
            min_confidence: Stored as ``self.min_confidence``.
            parallel_threshold: Stored as ``self.parallel_threshold``.
            learning_rate: Stored as ``self.learning_rate``.
            strategy_weights: Weight per node-type value, used when averaging
                confidences along a path. Node types missing from the mapping
                get a weight of 0.1. Defaults to a built-in table.
        """
        super().__init__()
        self.min_confidence = min_confidence
        self.parallel_threshold = parallel_threshold
        self.learning_rate = learning_rate
        self.strategy_weights = strategy_weights or {
            'hypothesis': 0.3,
            'evidence': 0.2,
            'analysis': 0.2,
            'synthesis': 0.15,
            'evaluation': 0.15
        }

        # Tree state; rebuilt on every call to reason().
        self.root: Optional[TreeNode] = None
        self.current_node: Optional[TreeNode] = None

        # Performance tracking
        self.performance_metrics = {
            'tree_depth': 0,
            'num_nodes': 0,
            'branching_factor': 0.0,
            'avg_confidence': 0.0,
            'pruned_nodes': 0
        }

    async def reason(
        self,
        query: str,
        context: Dict[str, Any]
    ) -> StrategyResult:
        """
        Apply Tree of Thoughts reasoning to analyze the query.

        Args:
            query: The input query to reason about.
            context: Additional context and parameters. The keys read by this
                class are ``'keywords'`` and ``'evidence'``; a falsy value is
                treated as an empty mapping.

        Returns:
            StrategyResult containing the reasoning trace and confidence. Any
            exception raised while reasoning is logged and reported as a
            result with ``success=False`` instead of propagating.
        """
        try:
            context = context or {}

            # A fresh tree per call: nothing is carried over between queries.
            self.root = TreeNode(
                id="root",
                type=NodeType.ROOT,
                content=query,
                confidence=1.0,
                children=[],
                parent=None,
                metadata={"query": query},
                depth=0
            )
            self.current_node = self.root

            # Build the tree phase by phase; each phase relies on the nodes
            # added by the previous one.
            await self._generate_hypotheses(context)
            await self._gather_evidence(context)
            await self._analyze_evidence(context)
            await self._synthesize_findings(context)
            await self._evaluate_paths(context)

            best_path = self._find_best_path()
            conclusion = await self._generate_conclusion(best_path, context)
            self._update_metrics()

            return StrategyResult(
                strategy_type="tree_of_thoughts",
                success=True,
                answer=conclusion.content,
                confidence=conclusion.confidence,
                reasoning_trace=[{
                    "step": str(node.type.value),
                    "content": node.content,
                    "confidence": node.confidence,
                    "depth": node.depth,
                    "score": node.evaluation_score,
                    "metadata": node.metadata,
                    "timestamp": node.timestamp
                } for node in self._traverse_tree()],
                metadata={
                    "tree_depth": self.performance_metrics['tree_depth'],
                    "num_nodes": self.performance_metrics['num_nodes'],
                    "branching_factor": self.performance_metrics['branching_factor']
                },
                # Snapshot, so later calls do not mutate a returned result.
                performance_metrics=dict(self.performance_metrics)
            )
        except Exception as e:
            # Top-level boundary: keep the traceback in the log and convert
            # the failure into an unsuccessful result.
            logging.exception("Tree of Thoughts reasoning error: %s", e)
            return StrategyResult(
                strategy_type="tree_of_thoughts",
                success=False,
                answer=None,
                confidence=0.0,
                reasoning_trace=[{
                    "step": "error",
                    "error": str(e),
                    "timestamp": datetime.now().isoformat()
                }],
                metadata={"error": str(e)},
                performance_metrics=dict(self.performance_metrics)
            )

    def _add_child(
        self,
        parent: TreeNode,
        node_id: str,
        node_type: NodeType,
        content: str,
        context: Dict[str, Any]
    ) -> TreeNode:
        """Create a node one level below ``parent``, attach it and return it."""
        node = TreeNode(
            id=node_id,
            type=node_type,
            content=content,
            confidence=self._calculate_confidence(content, context),
            children=[],
            parent=parent,
            metadata={"type": node_type.value},
            depth=parent.depth + 1
        )
        parent.children.append(node)
        return node

    async def _generate_hypotheses(self, context: Dict[str, Any]) -> None:
        """Generate initial hypotheses as child nodes of the root."""
        for h_content in self._extract_hypotheses(self.root.content, context):
            self._add_child(
                self.root,
                f"h{len(self.root.children)}",
                NodeType.HYPOTHESIS,
                h_content,
                context
            )

    async def _gather_evidence(self, context: Dict[str, Any]) -> None:
        """Attach one evidence node per matching evidence item to each hypothesis."""
        for hypothesis in self.root.children:
            for e_content in self._find_evidence(hypothesis.content, context):
                self._add_child(
                    hypothesis,
                    f"{hypothesis.id}_e{len(hypothesis.children)}",
                    NodeType.EVIDENCE,
                    e_content,
                    context
                )

    async def _analyze_evidence(self, context: Dict[str, Any]) -> None:
        """Attach an analysis node to every child of every hypothesis."""
        for hypothesis in self.root.children:
            for evidence in hypothesis.children:
                self._add_child(
                    evidence,
                    f"{evidence.id}_a",
                    NodeType.ANALYSIS,
                    self._analyze_node(evidence, context),
                    context
                )

    async def _synthesize_findings(self, context: Dict[str, Any]) -> None:
        """Attach one synthesis node to each hypothesis."""
        for hypothesis in self.root.children:
            self._add_child(
                hypothesis,
                f"{hypothesis.id}_s",
                NodeType.SYNTHESIS,
                self._synthesize_branch(hypothesis, context),
                context
            )

    async def _evaluate_paths(self, context: Dict[str, Any]) -> None:
        """Attach one evaluation node to each hypothesis."""
        for hypothesis in self.root.children:
            self._add_child(
                hypothesis,
                f"{hypothesis.id}_e",
                NodeType.EVALUATION,
                self._evaluate_branch(hypothesis, context),
                context
            )

    def _find_best_path(self) -> List[TreeNode]:
        """Find the root-to-leaf path with the highest path score.

        At every level the child with the largest ``_calculate_path_score``
        is followed (the first one wins on ties), so the returned path is the
        one the score was actually computed over.

        Returns:
            The nodes from the root down to a leaf, or an empty list when
            there is no root or the root has no children.
        """
        if self.root is None or not self.root.children:
            return []
        path = [self.root]
        current = self.root
        while current.children:
            current = max(current.children, key=self._calculate_path_score)
            path.append(current)
        return path

    async def _generate_conclusion(
        self,
        path: List[TreeNode],
        context: Dict[str, Any]
    ) -> TreeNode:
        """Generate the final conclusion node from the best path.

        The node is attached to the root and returned. Its depth is one more
        than the deepest node on ``path`` (1 for an empty path).
        """
        conclusion_content = self._synthesize_path(path, context)
        node = TreeNode(
            id="conclusion",
            type=NodeType.CONCLUSION,
            content=conclusion_content,
            confidence=self._calculate_path_confidence(path),
            children=[],
            parent=self.root,
            metadata={"type": "conclusion", "path_length": len(path)},
            # default=0 keeps an empty path from raising ValueError.
            depth=max((n.depth for n in path), default=0) + 1
        )
        self.root.children.append(node)
        return node

    def _calculate_confidence(
        self,
        content: str,
        context: Dict[str, Any]
    ) -> float:
        """Calculate a heuristic confidence score for content.

        Starts at 0.5, adds 0.1 when the content has more than 50 words and
        another 0.1 above 100 words, then adds 0.1 per context keyword found
        in the content (case-insensitive, at most 0.3). Capped at 1.0.
        """
        confidence = 0.5

        word_count = len(content.split())
        if word_count > 50:
            confidence += 0.1
        if word_count > 100:
            confidence += 0.1

        keywords = context.get('keywords')
        if keywords:
            lowered = content.lower()
            # Lower-case both sides; otherwise a keyword containing an
            # uppercase letter can never match.
            matches = sum(1 for k in keywords if str(k).lower() in lowered)
            confidence += min(0.3, matches * 0.1)

        return min(1.0, confidence)

    def _calculate_path_score(self, node: TreeNode) -> float:
        """Calculate the score of the best path starting at ``node``.

        The score is the node's confidence plus 0.8 times the best score
        among its children, applied recursively.
        """
        score = node.confidence
        if node.children:
            best_child = max(self._calculate_path_score(c) for c in node.children)
            score += best_child * 0.8  # Decay factor
        return score

    def _calculate_path_confidence(self, path: List[TreeNode]) -> float:
        """Calculate the weighted average confidence of a path.

        Each node's confidence is weighted by ``strategy_weights`` for its
        node-type value (0.1 when the type has no entry). Returns 0.0 for an
        empty path or when the weights do not sum to a positive number.
        """
        if not path:
            return 0.0
        weights = [
            self.strategy_weights.get(node.type.value, 0.1)
            for node in path
        ]
        total_weight = sum(weights)
        if total_weight <= 0:
            return 0.0
        weighted_sum = sum(
            node.confidence * weight
            for node, weight in zip(path, weights)
        )
        return weighted_sum / total_weight

    def _get_path(self, node: TreeNode) -> List[TreeNode]:
        """Get the path from the root down to ``node`` (inclusive)."""
        path = []
        current = node
        while current is not None:
            path.append(current)
            current = current.parent
        path.reverse()
        return path

    def _traverse_tree(self) -> List[TreeNode]:
        """Return all nodes of the tree in pre-order (empty without a root)."""
        if self.root is None:
            return []
        nodes = []
        stack = [self.root]
        while stack:
            node = stack.pop()
            nodes.append(node)
            # Push children reversed so the first child is visited first.
            stack.extend(reversed(node.children))
        return nodes

    def _extract_hypotheses(
        self,
        content: str,
        context: Dict[str, Any]
    ) -> List[str]:
        """Extract potential hypotheses from content.

        The content is split on ``'.'`` and every non-blank sentence that
        contains at least one context keyword (case-insensitive) becomes a
        hypothesis. Falls back to a single default hypothesis.
        """
        # Simple extraction based on keywords; could be enhanced with NLP.
        keywords = [str(k).lower() for k in context.get('keywords') or []]
        hypotheses = []
        for sentence in content.split('.'):
            text = sentence.strip()
            if not text:
                continue
            lowered = sentence.lower()
            if any(k in lowered for k in keywords):
                hypotheses.append(text)
        return hypotheses or ["Default hypothesis"]

    def _find_evidence(
        self,
        hypothesis: str,
        context: Dict[str, Any]
    ) -> List[str]:
        """Find evidence supporting a hypothesis.

        An item of ``context['evidence']`` is selected when it contains, as a
        case-insensitive substring, any whitespace-separated term of the
        hypothesis. Falls back to a single placeholder entry.
        """
        terms = hypothesis.lower().split()
        evidence = []
        for item in context.get('evidence') or []:
            lowered = item.lower()
            if any(term in lowered for term in terms):
                evidence.append(item)
        return evidence or ["No direct evidence found"]

    def _analyze_node(
        self,
        node: TreeNode,
        context: Dict[str, Any]
    ) -> str:
        """Return the analysis text for a node (``context`` is unused)."""
        return f"Analysis of {node.content}"

    def _synthesize_branch(
        self,
        node: TreeNode,
        context: Dict[str, Any]
    ) -> str:
        """Return the synthesis text for a branch (``context`` is unused)."""
        return f"Synthesis of branch {node.id}"

    def _evaluate_branch(
        self,
        node: TreeNode,
        context: Dict[str, Any]
    ) -> str:
        """Return the evaluation text for a branch (``context`` is unused)."""
        return f"Evaluation of branch {node.id}"

    def _synthesize_path(
        self,
        path: List[TreeNode],
        context: Dict[str, Any]
    ) -> str:
        """Join the contents of the path's nodes into the conclusion text."""
        return "Conclusion: " + " -> ".join(n.content for n in path)

    def _update_metrics(self) -> None:
        """Recompute the tree statistics in ``performance_metrics``.

        Does nothing when no tree has been built. ``pruned_nodes`` is left
        unchanged.
        """
        if self.root is None:
            return
        nodes = self._traverse_tree()
        internal_nodes = sum(1 for n in nodes if n.children)
        # Average branching factor = edges per node that has children.
        edges = sum(len(n.children) for n in nodes)
        self.performance_metrics.update({
            'tree_depth': max(n.depth for n in nodes),
            'num_nodes': len(nodes),
            'branching_factor': edges / internal_nodes if internal_nodes else 0.0,
            'avg_confidence': sum(n.confidence for n in nodes) / len(nodes)
        })