Spaces:
Runtime error
Runtime error
"""Advanced strategy coordination patterns for the unified reasoning engine.""" | |
import logging | |
from typing import Dict, Any, List, Optional, Set, Union, Type, Callable | |
import json | |
from dataclasses import dataclass, field | |
from enum import Enum | |
from datetime import datetime | |
import asyncio | |
from collections import defaultdict | |
from .base import ReasoningStrategy | |
from .unified_engine import StrategyType, StrategyResult, UnifiedResult | |
class CoordinationPattern(Enum):
    """Ways in which the coordinator can combine reasoning strategies.

    Each member's string value is what gets reported under the ``"pattern"``
    key of a coordination result.
    """

    # Strategies run one after another; each result is fed into the next context.
    PIPELINE = "pipeline"
    # All enabled strategies run concurrently and their results are synthesized.
    PARALLEL = "parallel"
    # Strategies are grouped into levels that are executed in order.
    HIERARCHICAL = "hierarchical"
    # Strategies are re-run over several iterations with generated feedback.
    FEEDBACK = "feedback"
    # The next strategy is chosen after each result is known.
    ADAPTIVE = "adaptive"
    # Every enabled strategy runs and the results are combined into one answer.
    ENSEMBLE = "ensemble"
class CoordinationPhase(Enum):
    """Named phases a coordination run can be in.

    Values are the lowercase form of the member names.
    """

    INITIALIZATION = "initialization"
    EXECUTION = "execution"
    SYNCHRONIZATION = "synchronization"
    ADAPTATION = "adaptation"
    COMPLETION = "completion"
@dataclass
class CoordinationState:
    """State of strategy coordination.

    Declared as a dataclass so that the annotated fields become constructor
    arguments and ``field(default_factory=dict)`` gives every instance its
    own ``metadata`` dict (without the decorator the class has no generated
    ``__init__`` and ``metadata`` is a raw ``dataclasses.Field`` object).
    """

    # Coordination pattern this state belongs to.
    pattern: CoordinationPattern
    # Strategy type -> enabled flag; the coordinator only runs strategies
    # whose flag is truthy.
    active_strategies: Dict[StrategyType, bool]
    # Current phase of the run.
    phase: CoordinationPhase
    # NOTE(review): the remaining fields are not read by the coordinator code
    # visible in this file; their exact contents should be confirmed against
    # the code that builds the state.
    shared_context: Dict[str, Any]
    synchronization_points: List[str]
    adaptation_history: List[Dict[str, Any]]
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class StrategyInteraction:
    """Interaction between strategies.

    Declared as a dataclass so it can be built with keyword arguments (as the
    coordinator does when recording an interaction) and so ``timestamp`` is
    filled in per instance at construction time.
    """

    # Strategy that produced the data.
    source: StrategyType
    # Strategy receiving the data; ``None`` when there is no receiver
    # (e.g. the last step of a pipeline).
    target: Optional[StrategyType]
    # Free-form label for the kind of interaction, e.g. "pipeline_transfer".
    interaction_type: str
    # Payload attached to the interaction.
    data: Dict[str, Any]
    # Creation time (naive local time from ``datetime.now``).
    timestamp: datetime = field(default_factory=datetime.now)
class StrategyCoordinator: | |
""" | |
Advanced strategy coordinator that: | |
1. Manages strategy interactions | |
2. Implements coordination patterns | |
3. Handles state synchronization | |
4. Adapts coordination dynamically | |
5. Optimizes strategy combinations | |
""" | |
def __init__(self,
             strategies: Dict[StrategyType, ReasoningStrategy],
             learning_rate: float = 0.1):
    """Set up the coordinator.

    Args:
        strategies: Strategy implementations keyed by strategy type.
        learning_rate: Step size used when pattern weights are updated
            from observed performance.
    """
    self.strategies = strategies
    self.learning_rate = learning_rate

    # Bookkeeping for coordination runs and recorded interactions.
    self.states: Dict[str, CoordinationState] = {}
    self.interactions: List[StrategyInteraction] = []

    # Per-pattern score history; every pattern starts with weight 1.0.
    self.pattern_performance: Dict[CoordinationPattern, List[float]] = defaultdict(list)
    self.pattern_weights: Dict[CoordinationPattern, float] = dict.fromkeys(
        CoordinationPattern, 1.0
    )
async def coordinate(self,
                     query: str,
                     context: Dict[str, Any],
                     pattern: Optional[CoordinationPattern] = None) -> Dict[str, Any]:
    """Coordinate strategy execution using the specified pattern.

    Args:
        query: The query handed to every strategy.
        context: Shared context passed to pattern selection, initialization
            and the pattern handler.
        pattern: Pattern to use; when falsy, one is chosen via
            ``_select_pattern``.

    Returns:
        The result dict of the pattern handler, or on any failure a dict
        ``{"success": False, "error": <message>, "pattern": <value or None>}``.
        This method does not raise for ordinary exceptions.
    """
    try:
        # Select pattern if not specified
        if not pattern:
            pattern = await self._select_pattern(query, context)

        # Initialize coordination
        state = await self._initialize_coordination(pattern, context)

        # Dispatch table instead of a long if/elif chain.
        handlers = {
            CoordinationPattern.PIPELINE: self._coordinate_pipeline,
            CoordinationPattern.PARALLEL: self._coordinate_parallel,
            CoordinationPattern.HIERARCHICAL: self._coordinate_hierarchical,
            CoordinationPattern.FEEDBACK: self._coordinate_feedback,
            CoordinationPattern.ADAPTIVE: self._coordinate_adaptive,
            CoordinationPattern.ENSEMBLE: self._coordinate_ensemble,
        }
        handler = handlers.get(pattern)
        if handler is None:
            raise ValueError(f"Unsupported coordination pattern: {pattern}")
        result = await handler(query, context, state)

        # Update performance metrics
        self._update_pattern_performance(pattern, result)
        return result
    except Exception as e:
        # logging.exception keeps the traceback alongside the message.
        logging.exception(f"Error in strategy coordination: {str(e)}")
        return {
            "success": False,
            "error": str(e),
            # ``pattern`` may be something other than an enum member (that is
            # exactly the "unsupported pattern" case), so do not assume it
            # has ``.value`` -- otherwise the handler itself would raise.
            "pattern": getattr(pattern, "value", pattern) if pattern else None,
        }
async def _select_pattern(self, query: str, context: Dict[str, Any]) -> CoordinationPattern:
    """Select an appropriate coordination pattern for the query.

    Builds a prompt, sends it to ``context["groq_api"].predict`` and parses
    the ``"answer"`` of the response with ``_parse_pattern_selection``. The
    parsed scores (looked up by pattern value, defaulting to 0.0) are
    multiplied by the learned pattern weights and the highest-scoring
    pattern is returned; ties go to the first pattern in enum order.

    Args:
        query: The query being coordinated.
        context: Shared context; must contain the API client under
            ``"groq_api"``.

    Returns:
        The pattern with the highest weighted score.
    """
    # The context holds the live API client under "groq_api", which
    # json.dumps cannot serialize; leave it out of the prompt and stringify
    # any other non-JSON value instead of raising TypeError.
    printable_context = {
        key: value for key, value in context.items() if key != "groq_api"
    }
    context_json = json.dumps(printable_context, default=str)

    prompt = (
        "\n"
        "Select coordination pattern:\n"
        f"Query: {query}\n"
        f"Context: {context_json}\n"
        "Consider:\n"
        "1. Task complexity and type\n"
        "2. Strategy dependencies\n"
        "3. Resource constraints\n"
        "4. Performance history\n"
        "5. Adaptation needs\n"
        "Format as:\n"
        "[Selection]\n"
        "Pattern: ...\n"
        "Rationale: ...\n"
        "Confidence: ...\n"
    )
    response = await context["groq_api"].predict(prompt)
    selection = self._parse_pattern_selection(response["answer"])

    # Weight by performance history
    weighted_patterns = {
        pattern: self.pattern_weights[pattern] * selection.get(pattern.value, 0.0)
        for pattern in CoordinationPattern
    }
    return max(weighted_patterns, key=weighted_patterns.get)
async def _coordinate_pipeline(self,
                               query: str,
                               context: Dict[str, Any],
                               state: CoordinationState) -> Dict[str, Any]:
    """Coordinate strategies in pipeline pattern.

    Strategies run sequentially in the order returned by
    ``_determine_pipeline_order``. Each raw result is written into a copy of
    the context (``"previous_result"``, ``"pipeline_position"``) before the
    next strategy runs. A strategy that raises is logged and skipped; it
    contributes no entry to ``results``.

    Args:
        query: The query handed to every strategy.
        context: Base context; copied, never mutated.
        state: Coordination state (not read by this pattern).

    Returns:
        Dict with ``success`` (any step succeeded), ``results``, ``pattern``
        and ``metrics`` (``total_steps``, ``success_rate``).
    """
    results = []
    current_context = context.copy()

    # Determine optimal order
    strategy_order = await self._determine_pipeline_order(query, context)
    last_position = len(strategy_order) - 1

    for position, strategy_type in enumerate(strategy_order):
        try:
            # Execute strategy
            strategy = self.strategies[strategy_type]
            result = await strategy.reason(query, current_context)

            # Update context with result
            current_context.update({
                "previous_result": result,
                "pipeline_position": len(results)
            })
            results.append(StrategyResult(
                strategy_type=strategy_type,
                success=result.get("success", False),
                answer=result.get("answer"),
                confidence=result.get("confidence", 0.0),
                reasoning_trace=result.get("reasoning_trace", []),
                metadata=result.get("metadata", {}),
                performance_metrics=result.get("performance_metrics", {})
            ))

            # Record interaction. The receiver is the next entry in the
            # pipeline order; this must be derived from the loop position,
            # not from len(results), which falls behind whenever an earlier
            # strategy failed and was skipped.
            self._record_interaction(
                source=strategy_type,
                target=strategy_order[position + 1] if position < last_position else None,
                interaction_type="pipeline_transfer",
                data={"result": result}
            )
        except Exception as e:
            logging.exception(f"Error in pipeline strategy {strategy_type}: {str(e)}")

    return {
        "success": any(r.success for r in results),
        "results": results,
        "pattern": CoordinationPattern.PIPELINE.value,
        "metrics": {
            "total_steps": len(results),
            "success_rate": sum(1 for r in results if r.success) / len(results) if results else 0
        }
    }
async def _coordinate_parallel(self,
                               query: str,
                               context: Dict[str, Any],
                               state: CoordinationState) -> Dict[str, Any]:
    """Run every enabled strategy concurrently and synthesize the outcomes.

    Args:
        query: The query handed to every strategy.
        context: Context passed unchanged to each strategy and to the
            synthesis step.
        state: Coordination state; only strategies whose flag in
            ``state.active_strategies`` is truthy are executed.

    Returns:
        Dict with ``success`` (taken from the synthesis), ``results``,
        ``synthesis``, ``pattern`` and ``metrics`` (``total_strategies``,
        ``success_rate``).
    """
    async def run_one(kind: StrategyType) -> StrategyResult:
        # A failing strategy is turned into an unsuccessful result so it
        # cannot abort the other concurrent runs.
        try:
            raw = await self.strategies[kind].reason(query, context)
            return StrategyResult(
                strategy_type=kind,
                success=raw.get("success", False),
                answer=raw.get("answer"),
                confidence=raw.get("confidence", 0.0),
                reasoning_trace=raw.get("reasoning_trace", []),
                metadata=raw.get("metadata", {}),
                performance_metrics=raw.get("performance_metrics", {})
            )
        except Exception as e:
            logging.error(f"Error in parallel strategy {kind}: {str(e)}")
            return StrategyResult(
                strategy_type=kind,
                success=False,
                answer=None,
                confidence=0.0,
                reasoning_trace=[{"error": str(e)}],
                metadata={},
                performance_metrics={}
            )

    # Fan out over the enabled strategies and wait for all of them.
    pending = [
        run_one(kind)
        for kind, enabled in state.active_strategies.items()
        if enabled
    ]
    results = await asyncio.gather(*pending)

    synthesis = await self._synthesize_parallel_results(results, context)

    if results:
        succeeded = len([r for r in results if r.success])
        success_rate = succeeded / len(results)
    else:
        success_rate = 0

    return {
        "success": synthesis.get("success", False),
        "results": results,
        "synthesis": synthesis,
        "pattern": CoordinationPattern.PARALLEL.value,
        "metrics": {
            "total_strategies": len(results),
            "success_rate": success_rate
        }
    }
async def _coordinate_hierarchical(self,
                                   query: str,
                                   context: Dict[str, Any],
                                   state: CoordinationState) -> Dict[str, Any]:
    """Coordinate strategies in hierarchical pattern.

    The hierarchy returned by ``_build_strategy_hierarchy`` is iterated as a
    sequence of levels. Within a level, the enabled strategies run
    concurrently; the level's results are then written into the context
    (``"previous_level_results"``, ``"hierarchy_level"``) for the next level.

    Args:
        query: The query handed to every strategy.
        context: Base context; copied, never mutated.
        state: Coordination state; only strategies whose flag in
            ``state.active_strategies`` is truthy are executed.

    Returns:
        Dict with ``success`` (any strategy on any level succeeded),
        ``results`` (level index -> list of results), ``hierarchy``,
        ``pattern`` and ``metrics`` (``total_levels``,
        ``level_success_rates`` for non-empty levels).
    """
    # Build strategy hierarchy
    hierarchy = await self._build_strategy_hierarchy(query, context)
    results = {}

    async def run_strategy(strategy_type: StrategyType,
                           level_context: Dict[str, Any]) -> StrategyResult:
        # Same failure handling as the parallel pattern: a raising strategy
        # becomes an unsuccessful result instead of aborting the whole run.
        try:
            strategy = self.strategies[strategy_type]
            result = await strategy.reason(query, level_context)
            return StrategyResult(
                strategy_type=strategy_type,
                success=result.get("success", False),
                answer=result.get("answer"),
                confidence=result.get("confidence", 0.0),
                reasoning_trace=result.get("reasoning_trace", []),
                metadata=result.get("metadata", {}),
                performance_metrics=result.get("performance_metrics", {})
            )
        except Exception as e:
            logging.exception(f"Error in hierarchical strategy {strategy_type}: {str(e)}")
            return StrategyResult(
                strategy_type=strategy_type,
                success=False,
                answer=None,
                confidence=0.0,
                reasoning_trace=[{"error": str(e)}],
                metadata={},
                performance_metrics={}
            )

    async def execute_level(level_strategies: List[StrategyType],
                            level_context: Dict[str, Any]) -> List[StrategyResult]:
        # Filter first, then run exactly the filtered list, so each result
        # stays paired with the strategy that produced it even when some
        # strategies of the level are disabled.
        runnable = [
            strategy_type for strategy_type in level_strategies
            if state.active_strategies.get(strategy_type)
        ]
        level_results = await asyncio.gather(
            *(run_strategy(strategy_type, level_context) for strategy_type in runnable)
        )
        return list(level_results)

    # Execute hierarchy levels
    current_context = context.copy()
    for level, level_strategies in enumerate(hierarchy):
        results[level] = await execute_level(level_strategies, current_context)

        # Update context for next level
        current_context.update({
            "previous_level_results": results[level],
            "hierarchy_level": level
        })

    return {
        "success": any(any(r.success for r in level_results)
                       for level_results in results.values()),
        "results": results,
        "hierarchy": hierarchy,
        "pattern": CoordinationPattern.HIERARCHICAL.value,
        "metrics": {
            "total_levels": len(hierarchy),
            "level_success_rates": {
                level: sum(1 for r in results[level] if r.success) / len(results[level])
                for level in results if results[level]
            }
        }
    }
async def _coordinate_feedback(self,
                               query: str,
                               context: Dict[str, Any],
                               state: CoordinationState) -> Dict[str, Any]:
    """Re-run the enabled strategies in rounds, feeding feedback forward.

    Each round executes every enabled strategy sequentially, asks
    ``_generate_feedback`` for feedback on that round, and stops early when
    ``_should_terminate_feedback`` says so; otherwise the round's results and
    feedback are written into the context for the next round. At most five
    rounds are run. A strategy that raises is logged and skipped.

    Args:
        query: The query handed to every strategy.
        context: Base context; copied, never mutated.
        state: Coordination state; only strategies whose flag in
            ``state.active_strategies`` is truthy are executed.

    Returns:
        Dict with ``success`` (any strategy succeeded in any round),
        ``results`` (one list per round), ``feedback_history``, ``pattern``
        and ``metrics`` (``total_iterations``, ``feedback_impact``).
    """
    max_iterations = 5  # Prevent infinite loops
    working_context = context.copy()
    rounds = []
    feedback_history = []

    iteration = 0
    for iteration in range(1, max_iterations + 1):
        # Execute strategies
        round_results = []
        for strategy_type, enabled in state.active_strategies.items():
            if not enabled:
                continue
            try:
                raw = await self.strategies[strategy_type].reason(query, working_context)
                round_results.append(StrategyResult(
                    strategy_type=strategy_type,
                    success=raw.get("success", False),
                    answer=raw.get("answer"),
                    confidence=raw.get("confidence", 0.0),
                    reasoning_trace=raw.get("reasoning_trace", []),
                    metadata=raw.get("metadata", {}),
                    performance_metrics=raw.get("performance_metrics", {})
                ))
            except Exception as e:
                logging.error(f"Error in feedback strategy {strategy_type}: {str(e)}")
        rounds.append(round_results)

        # Generate feedback
        feedback = await self._generate_feedback(round_results, working_context)
        feedback_history.append(feedback)

        # Check termination condition
        if self._should_terminate_feedback(feedback, round_results):
            break

        # Carry this round's outcome into the next one.
        working_context.update({
            "previous_results": round_results,
            "feedback": feedback,
            "iteration": iteration
        })

    any_success = any(
        any(r.success for r in one_round) for one_round in rounds
    )
    return {
        "success": any_success,
        "results": rounds,
        "feedback_history": feedback_history,
        "pattern": CoordinationPattern.FEEDBACK.value,
        "metrics": {
            "total_iterations": iteration,
            "feedback_impact": self._calculate_feedback_impact(rounds, feedback_history)
        }
    }
async def _coordinate_adaptive(self,
                               query: str,
                               context: Dict[str, Any],
                               state: CoordinationState) -> Dict[str, Any]:
    """Coordinate strategies with adaptive selection.

    Repeatedly asks ``_select_next_strategy`` which strategy to run next,
    runs it, records an adaptation via ``_adapt_strategy_selection`` and
    writes the progress into the context. Stops when the selector returns a
    falsy value or after as many attempts as there are entries in
    ``state.active_strategies``. A strategy that raises is logged and
    contributes no result.

    Args:
        query: The query handed to every strategy.
        context: Base context; copied, never mutated.
        state: Coordination state supplying ``active_strategies``.

    Returns:
        Dict with ``success`` (any strategy succeeded), ``results``,
        ``adaptations``, ``pattern`` and ``metrics`` (``total_strategies``,
        ``adaptation_impact``).
    """
    results = []
    adaptations = []
    current_context = context.copy()

    # Bound the number of ATTEMPTS, not the number of successful results:
    # a failing strategy adds nothing to ``results``, so a loop conditioned
    # on len(results) never terminates if the selector keeps returning a
    # strategy that raises.
    max_attempts = len(state.active_strategies)

    for _ in range(max_attempts):
        # Select next strategy
        next_strategy = await self._select_next_strategy(
            results, state.active_strategies, current_context)
        if not next_strategy:
            break
        try:
            # Execute strategy
            strategy = self.strategies[next_strategy]
            result = await strategy.reason(query, current_context)
            strategy_result = StrategyResult(
                strategy_type=next_strategy,
                success=result.get("success", False),
                answer=result.get("answer"),
                confidence=result.get("confidence", 0.0),
                reasoning_trace=result.get("reasoning_trace", []),
                metadata=result.get("metadata", {}),
                performance_metrics=result.get("performance_metrics", {})
            )
            results.append(strategy_result)

            # Adapt strategy selection
            adaptation = await self._adapt_strategy_selection(
                strategy_result, current_context)
            adaptations.append(adaptation)

            # Update context
            current_context.update({
                "previous_results": results,
                "adaptations": adaptations,
                "current_strategy": next_strategy
            })
        except Exception as e:
            logging.exception(f"Error in adaptive strategy {next_strategy}: {str(e)}")

    return {
        "success": any(r.success for r in results),
        "results": results,
        "adaptations": adaptations,
        "pattern": CoordinationPattern.ADAPTIVE.value,
        "metrics": {
            "total_strategies": len(results),
            "adaptation_impact": self._calculate_adaptation_impact(results, adaptations)
        }
    }
async def _coordinate_ensemble(self,
                               query: str,
                               context: Dict[str, Any],
                               state: CoordinationState) -> Dict[str, Any]:
    """Run every enabled strategy in turn and combine the results.

    Strategies are executed sequentially with the same context. A strategy
    that raises is logged and skipped. The collected results are handed to
    ``_combine_ensemble_results``.

    Args:
        query: The query handed to every strategy.
        context: Context passed unchanged to each strategy and to the
            combination step.
        state: Coordination state; only strategies whose flag in
            ``state.active_strategies`` is truthy are executed.

    Returns:
        Dict with ``success`` (taken from the combined result), ``results``,
        ``ensemble_result``, ``pattern`` and ``metrics`` (``total_members``,
        ``ensemble_confidence``).
    """
    members = []
    for strategy_type, enabled in state.active_strategies.items():
        if not enabled:
            continue
        try:
            raw = await self.strategies[strategy_type].reason(query, context)
            members.append(StrategyResult(
                strategy_type=strategy_type,
                success=raw.get("success", False),
                answer=raw.get("answer"),
                confidence=raw.get("confidence", 0.0),
                reasoning_trace=raw.get("reasoning_trace", []),
                metadata=raw.get("metadata", {}),
                performance_metrics=raw.get("performance_metrics", {})
            ))
        except Exception as e:
            logging.error(f"Error in ensemble strategy {strategy_type}: {str(e)}")

    # Combine results using ensemble methods
    combined = await self._combine_ensemble_results(members, context)

    metrics = {
        "total_members": len(members),
        "ensemble_confidence": combined.get("confidence", 0.0)
    }
    return {
        "success": combined.get("success", False),
        "results": members,
        "ensemble_result": combined,
        "pattern": CoordinationPattern.ENSEMBLE.value,
        "metrics": metrics
    }
def _record_interaction(self,
                        source: StrategyType,
                        target: Optional[StrategyType],
                        interaction_type: str,
                        data: Dict[str, Any]):
    """Append one strategy interaction to ``self.interactions``.

    Args:
        source: Strategy that produced the data.
        target: Strategy receiving the data, or ``None`` if there is none.
        interaction_type: Label for the kind of interaction.
        data: Payload attached to the interaction.
    """
    entry = StrategyInteraction(
        source=source,
        target=target,
        interaction_type=interaction_type,
        data=data,
    )
    self.interactions.append(entry)
def _update_pattern_performance(self, pattern: CoordinationPattern, result: Dict[str, Any]):
    """Update pattern performance metrics.

    Appends the run's score to ``self.pattern_performance[pattern]`` and
    moves ``self.pattern_weights[pattern]`` towards that score with an
    exponential moving average controlled by ``self.learning_rate``.

    Args:
        pattern: The pattern that produced ``result``.
        result: Result dict of a coordination run. ``metrics.success_rate``
            is used as the score when present; otherwise the score is 1.0
            if ``result["success"]`` is truthy and 0.0 if not.
    """
    metrics = result.get("metrics") or {}
    success_rate = metrics.get("success_rate")
    if success_rate is None:
        # Only the pipeline and parallel patterns report "success_rate".
        # Scoring every other pattern as 0.0 would shrink its weight even
        # when the run succeeded, so fall back to the overall success flag.
        success_rate = 1.0 if result.get("success") else 0.0
    self.pattern_performance[pattern].append(success_rate)

    # Update weights using exponential moving average. 1.0 is the initial
    # weight every pattern starts with.
    current_weight = self.pattern_weights.get(pattern, 1.0)
    self.pattern_weights[pattern] = (
        (1 - self.learning_rate) * current_weight +
        self.learning_rate * success_rate
    )
def get_performance_metrics(self) -> Dict[str, Any]:
    """Get comprehensive performance metrics.

    Returns:
        Dict with:
        - ``pattern_weights``: copy of the current weight per pattern.
        - ``average_performance``: mean recorded score per pattern value
          (0 when a pattern has no scores).
        - ``interaction_counts``: number of recorded interactions per
          interaction type, as a ``defaultdict(int)``.
        - ``active_patterns``: values of patterns whose weight exceeds 0.5.
    """
    # Count occurrences per type; a dict comprehension mapping each type to
    # 1 would collapse repeated types to a count of 1.
    interaction_counts = defaultdict(int)
    for interaction in self.interactions:
        interaction_counts[interaction.interaction_type] += 1

    return {
        "pattern_weights": dict(self.pattern_weights),
        "average_performance": {
            pattern.value: sum(scores) / len(scores) if scores else 0
            for pattern, scores in self.pattern_performance.items()
        },
        "interaction_counts": interaction_counts,
        "active_patterns": [
            pattern.value for pattern, weight in self.pattern_weights.items()
            if weight > 0.5
        ]
    }