"""Tree of Thoughts reasoning implementation with advanced tree exploration.""" import logging from typing import Dict, Any, List, Optional, Set, Tuple import json from dataclasses import dataclass from enum import Enum import heapq from collections import defaultdict from .base import ReasoningStrategy class NodeType(Enum): """Types of nodes in the thought tree.""" ROOT = "root" HYPOTHESIS = "hypothesis" EVIDENCE = "evidence" ANALYSIS = "analysis" SYNTHESIS = "synthesis" EVALUATION = "evaluation" CONCLUSION = "conclusion" @dataclass class TreeNode: """Represents a node in the thought tree.""" id: str type: NodeType content: str confidence: float children: List['TreeNode'] parent: Optional['TreeNode'] metadata: Dict[str, Any] depth: int evaluation_score: float = 0.0 class TreeOfThoughtsStrategy(ReasoningStrategy): """ Advanced Tree of Thoughts reasoning implementation with: - Beam search for path exploration - Dynamic node evaluation - Pruning strategies - Path optimization - Meta-learning from tree patterns """ def __init__(self, min_confidence: float = 0.7, parallel_threshold: int = 3, learning_rate: float = 0.1, strategy_weights: Optional[Dict[str, float]] = None): self.min_confidence = min_confidence self.parallel_threshold = parallel_threshold self.learning_rate = learning_rate self.strategy_weights = strategy_weights or { "LOCAL_LLM": 0.8, "CHAIN_OF_THOUGHT": 0.6, "TREE_OF_THOUGHTS": 0.5, "META_LEARNING": 0.4 } self.node_history: Dict[str, TreeNode] = {} self.path_patterns: Dict[str, float] = defaultdict(float) async def reason(self, query: str, context: Dict[str, Any]) -> Dict[str, Any]: """Main reasoning method implementing tree of thoughts.""" try: # Initialize root node root = await self._create_root_node(query, context) # Build and explore tree tree = await self._build_tree(root, context) # Find best paths paths = await self._find_best_paths(tree, context) # Synthesize conclusion conclusion = await self._synthesize_conclusion(paths, context) # Update history and patterns self._update_history(tree) self._update_patterns(paths) return { "success": True, "answer": conclusion["answer"], "confidence": conclusion["confidence"], "tree": self._tree_to_dict(tree), "best_paths": [self._path_to_dict(p) for p in paths], "reasoning_trace": conclusion["trace"], "meta_insights": conclusion["meta_insights"] } except Exception as e: logging.error(f"Error in tree of thoughts reasoning: {str(e)}") return {"success": False, "error": str(e)} async def _create_root_node(self, query: str, context: Dict[str, Any]) -> TreeNode: """Create the root node of the thought tree.""" prompt = f""" Initialize root thought node for query: Query: {query} Context: {json.dumps(context)} Provide: 1. Initial problem decomposition 2. Key aspects to explore 3. Evaluation criteria 4. Success metrics Format as: [Root] Decomposition: ... Aspects: ... Criteria: ... Metrics: ... """ response = await context["groq_api"].predict(prompt) return self._parse_root_node(response["answer"], query) async def _build_tree(self, root: TreeNode, context: Dict[str, Any]) -> TreeNode: """Build and explore the thought tree.""" # Initialize beam with root beam = [(root.evaluation_score, root)] visited: Set[str] = set() for depth in range(5): next_beam = [] for _, node in beam: if node.id in visited: continue visited.add(node.id) # Generate child nodes children = await self._generate_children(node, context) # Evaluate and filter children evaluated_children = await self._evaluate_nodes(children, context) # Add to beam for child in evaluated_children: if child.evaluation_score > 0.4: next_beam.append((child.evaluation_score, child)) node.children.append(child) # Select best nodes for next iteration beam = heapq.nlargest(3, next_beam, key=lambda x: x[0]) if not beam: break return root async def _generate_children(self, parent: TreeNode, context: Dict[str, Any]) -> List[TreeNode]: """Generate child nodes for a given parent.""" prompt = f""" Generate child thoughts for node: Parent: {json.dumps(self._node_to_dict(parent))} Context: {json.dumps(context)} For each child provide: 1. [Type]: {" | ".join([t.value for t in NodeType if t != NodeType.ROOT])} 2. [Content]: Main thought 3. [Confidence]: 0-1 score 4. [Rationale]: Why this follows from parent 5. [Potential]: Future exploration potential Format as: [C1] Type: ... Content: ... Confidence: ... Rationale: ... Potential: ... """ response = await context["groq_api"].predict(prompt) return self._parse_child_nodes(response["answer"], parent) async def _evaluate_nodes(self, nodes: List[TreeNode], context: Dict[str, Any]) -> List[TreeNode]: """Evaluate a list of nodes.""" prompt = f""" Evaluate thought nodes: Nodes: {json.dumps([self._node_to_dict(n) for n in nodes])} Context: {json.dumps(context)} For each node evaluate: 1. Logical coherence 2. Evidence support 3. Novelty value 4. Exploration potential Format as: [N1] Coherence: 0-1 Evidence: 0-1 Novelty: 0-1 Potential: 0-1 Overall: 0-1 """ response = await context["groq_api"].predict(prompt) return self._apply_evaluations(nodes, response["answer"]) async def _find_best_paths(self, root: TreeNode, context: Dict[str, Any]) -> List[List[TreeNode]]: """Find the best paths through the tree.""" paths = [] current_path = [root] def dfs(node: TreeNode, path: List[TreeNode]): if not node.children: paths.append(path[:]) return # Sort children by score sorted_children = sorted(node.children, key=lambda x: x.evaluation_score, reverse=True) # Explore top paths for child in sorted_children[:3]: path.append(child) dfs(child, path) path.pop() dfs(root, current_path) # Evaluate complete paths evaluated_paths = await self._evaluate_paths(paths, context) # Return top paths return sorted(evaluated_paths, key=lambda p: sum(n.evaluation_score for n in p), reverse=True)[:3] async def _synthesize_conclusion(self, paths: List[List[TreeNode]], context: Dict[str, Any]) -> Dict[str, Any]: """Synthesize final conclusion from best paths.""" prompt = f""" Synthesize conclusion from thought paths: Paths: {json.dumps([[self._node_to_dict(n) for n in path] for path in paths])} Context: {json.dumps(context)} Provide: 1. Main conclusion 2. Confidence level 3. Reasoning trace 4. Supporting evidence 5. Alternative perspectives 6. Meta-insights Format as: [Conclusion] Answer: ... Confidence: ... Trace: ... Evidence: ... Alternatives: ... [Meta] Insights: ... Patterns: ... """ response = await context["groq_api"].predict(prompt) return self._parse_conclusion(response["answer"]) def _parse_root_node(self, response: str, query: str) -> TreeNode: """Parse root node from response.""" root = TreeNode( id="root", type=NodeType.ROOT, content=query, confidence=1.0, children=[], parent=None, metadata={}, depth=0 ) for line in response.split('\n'): line = line.strip() if line.startswith('Decomposition:'): root.metadata["decomposition"] = line[14:].strip() elif line.startswith('Aspects:'): root.metadata["aspects"] = [a.strip() for a in line[8:].split(',')] elif line.startswith('Criteria:'): root.metadata["criteria"] = [c.strip() for c in line[9:].split(',')] elif line.startswith('Metrics:'): root.metadata["metrics"] = [m.strip() for m in line[8:].split(',')] return root def _parse_child_nodes(self, response: str, parent: TreeNode) -> List[TreeNode]: """Parse child nodes from response.""" children = [] current = None for line in response.split('\n'): line = line.strip() if not line: continue if line.startswith('[C'): if current: children.append(current) current = None elif line.startswith('Type:'): type_str = line[5:].strip() try: node_type = NodeType(type_str.lower()) current = TreeNode( id=f"{parent.id}_{len(children)}", type=node_type, content="", confidence=0.0, children=[], parent=parent, metadata={}, depth=parent.depth + 1 ) except ValueError: logging.warning(f"Invalid node type: {type_str}") elif current: if line.startswith('Content:'): current.content = line[8:].strip() elif line.startswith('Confidence:'): try: current.confidence = float(line[11:].strip()) except: current.confidence = 0.5 elif line.startswith('Rationale:'): current.metadata["rationale"] = line[10:].strip() elif line.startswith('Potential:'): current.metadata["potential"] = line[10:].strip() if current: children.append(current) return children def _apply_evaluations(self, nodes: List[TreeNode], response: str) -> List[TreeNode]: """Apply evaluation scores to nodes.""" current_node_idx = 0 current_scores = {} for line in response.split('\n'): line = line.strip() if not line: continue if line.startswith('[N'): if current_scores and current_node_idx < len(nodes): nodes[current_node_idx].evaluation_score = current_scores.get("Overall", 0.0) nodes[current_node_idx].metadata.update(current_scores) current_node_idx += 1 current_scores = {} elif ':' in line: key, value = line.split(':') try: current_scores[key.strip()] = float(value.strip()) except: pass if current_scores and current_node_idx < len(nodes): nodes[current_node_idx].evaluation_score = current_scores.get("Overall", 0.0) nodes[current_node_idx].metadata.update(current_scores) return nodes async def _evaluate_paths(self, paths: List[List[TreeNode]], context: Dict[str, Any]) -> List[List[TreeNode]]: """Evaluate complete reasoning paths.""" prompt = f""" Evaluate complete reasoning paths: Paths: {json.dumps([[self._node_to_dict(n) for n in path] for path in paths])} Context: {json.dumps(context)} For each path evaluate: 1. Coherence of progression 2. Evidence support 3. Conclusion strength 4. Novel insights Format as: [P1] Coherence: 0-1 Evidence: 0-1 Conclusion: 0-1 Insights: 0-1 Overall: 0-1 """ response = await context["groq_api"].predict(prompt) scores = self._parse_path_scores(response["answer"]) # Apply scores to paths for i, path in enumerate(paths): if i < len(scores): for node in path: node.evaluation_score *= scores[i] return paths def _parse_path_scores(self, response: str) -> List[float]: """Parse path evaluation scores.""" scores = [] current_score = None for line in response.split('\n'): line = line.strip() if not line: continue if line.startswith('[P'): if current_score is not None: scores.append(current_score) current_score = None elif line.startswith('Overall:'): try: current_score = float(line[8:].strip()) except: current_score = 0.5 if current_score is not None: scores.append(current_score) return scores def _parse_conclusion(self, response: str) -> Dict[str, Any]: """Parse final conclusion.""" conclusion = { "answer": "", "confidence": 0.0, "trace": [], "evidence": [], "alternatives": [], "meta_insights": [] } section = None for line in response.split('\n'): line = line.strip() if not line: continue if line.startswith('[Conclusion]'): section = "conclusion" elif line.startswith('[Meta]'): section = "meta" elif section == "conclusion": if line.startswith('Answer:'): conclusion["answer"] = line[7:].strip() elif line.startswith('Confidence:'): try: conclusion["confidence"] = float(line[11:].strip()) except: conclusion["confidence"] = 0.5 elif line.startswith('Trace:'): conclusion["trace"] = [t.strip() for t in line[6:].split(',')] elif line.startswith('Evidence:'): conclusion["evidence"] = [e.strip() for e in line[9:].split(',')] elif line.startswith('Alternatives:'): conclusion["alternatives"] = [a.strip() for a in line[13:].split(',')] elif section == "meta": if line.startswith('Insights:'): conclusion["meta_insights"].extend([i.strip() for i in line[9:].split(',')]) return conclusion def _node_to_dict(self, node: TreeNode) -> Dict[str, Any]: """Convert node to dictionary for serialization.""" return { "id": node.id, "type": node.type.value, "content": node.content, "confidence": node.confidence, "evaluation_score": node.evaluation_score, "metadata": node.metadata, "depth": node.depth } def _tree_to_dict(self, root: TreeNode) -> Dict[str, Any]: """Convert entire tree to dictionary.""" def convert_node(node: TreeNode) -> Dict[str, Any]: node_dict = self._node_to_dict(node) node_dict["children"] = [convert_node(c) for c in node.children] return node_dict return convert_node(root) def _path_to_dict(self, path: List[TreeNode]) -> List[Dict[str, Any]]: """Convert path to dictionary.""" return [self._node_to_dict(n) for n in path] def _update_history(self, root: TreeNode): """Update node history.""" def add_to_history(node: TreeNode): self.node_history[node.id] = node for child in node.children: add_to_history(child) add_to_history(root) def _update_patterns(self, paths: List[List[TreeNode]]): """Update path patterns.""" for path in paths: pattern = "->".join(n.type.value for n in path) self.path_patterns[pattern] += path[-1].evaluation_score def get_node_history(self) -> Dict[str, Dict[str, Any]]: """Get history of all nodes.""" return {k: self._node_to_dict(v) for k, v in self.node_history.items()} def get_successful_patterns(self) -> Dict[str, float]: """Get successful reasoning patterns.""" return dict(sorted(self.path_patterns.items(), key=lambda x: x[1], reverse=True)) def clear_history(self): """Clear node history and patterns.""" self.node_history.clear() self.path_patterns.clear()