"""Advanced meta-learning strategy for adaptive reasoning."""

import json
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

import numpy as np

from .base import ReasoningStrategy

logger = logging.getLogger(__name__)


@dataclass
class MetaTask:
    """Meta-learning task with parameters and performance metrics."""

    # Text the task was recorded for (this module stores the query here).
    name: str
    # Parameter values that were in effect when the task was processed.
    parameters: Dict[str, Any]
    # Scores recorded for the task; their mean is used as its "performance".
    metrics: Dict[str, float]
    # Timestamped records of context and results for this task.
    history: List[Dict[str, Any]] = field(default_factory=list)


def _is_numeric(value: Any) -> bool:
    """Return True for real numbers that can be blended arithmetically.

    ``bool`` is a subclass of ``int`` but is deliberately excluded so flags in
    the config are never averaged into fractional values.
    """
    return (
        isinstance(value, (int, float, np.integer, np.floating))
        and not isinstance(value, (bool, np.bool_))
    )


def _task_performance(task: MetaTask) -> float:
    """Return the mean of a task's recorded metrics (0.0 when it has none)."""
    values = list(task.metrics.values())
    # np.mean([]) would return NaN (with a RuntimeWarning), which poisons sorts.
    return float(np.mean(values)) if values else 0.0


class MetaLearningStrategy(ReasoningStrategy):
    """
    Advanced meta-learning strategy that:
    1. Adapts to new tasks
    2. Learns from experience
    3. Optimizes parameters
    4. Transfers knowledge
    5. Improves over time

    Each call to :meth:`reason` looks up previously recorded queries whose
    word sets are similar to the new query, blends their stored parameters
    weighted by their recorded performance, nudges the result back toward the
    configured values by ``learning_rate``, and records the new query as a
    task in a bounded in-memory store.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Initialize meta-learning strategy.

        Args:
            config: Optional parameter mapping. Its entries are the parameters
                that get adapted. ``learning_rate`` (default 0.01),
                ``memory_size`` (default 100) and ``adaptation_threshold``
                (default 0.7) are also read from it.
        """
        super().__init__()
        self.config = config or {}

        # Configure parameters
        self.learning_rate = self.config.get('learning_rate', 0.01)
        self.memory_size = self.config.get('memory_size', 100)
        self.adaptation_threshold = self.config.get('adaptation_threshold', 0.7)

        # Initialize task memory
        self.task_memory: List[MetaTask] = []
        # Lifetime number of tasks recorded. Unlike len(task_memory) it is not
        # reduced by eviction or by clear_memory().
        self._tasks_seen = 0

    async def reason(self, query: str, context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply meta-learning to adapt and optimize reasoning.

        Args:
            query: The input query to reason about
            context: Additional context and parameters

        Returns:
            Dict containing reasoning results and confidence scores. On
            success the keys are 'answer', 'confidence', 'similar_tasks',
            'adapted_params', 'results' and 'analysis'. On any failure the
            dict has only 'error' and 'confidence' (0.0).
        """
        try:
            # Identify similar tasks
            similar_tasks = await self._find_similar_tasks(query, context)

            # Adapt parameters
            adapted_params = await self._adapt_parameters(similar_tasks, context)

            # Apply meta-learning
            results = await self._apply_meta_learning(
                query,
                adapted_params,
                context
            )

            # Update memory
            await self._update_memory(query, results, context)

            # Generate analysis
            analysis = await self._generate_analysis(results, context)

            return {
                'answer': self._format_analysis(analysis),
                'confidence': self._calculate_confidence(results),
                'similar_tasks': similar_tasks,
                'adapted_params': adapted_params,
                'results': results,
                'analysis': analysis
            }

        except Exception as e:
            # Top-level boundary: report failure to the caller, but keep the
            # traceback in the log.
            logger.exception("Meta-learning failed: %s", e)
            return {
                'error': f"Meta-learning failed: {str(e)}",
                'confidence': 0.0
            }

    async def _find_similar_tasks(
        self,
        query: str,
        context: Dict[str, Any]
    ) -> List[MetaTask]:
        """Return remembered tasks whose similarity to *query* exceeds the threshold.

        The result is ordered by recorded performance (best first), not by
        similarity.
        """
        query_features = self._extract_features(query)

        similar_tasks = [
            task
            for task in self.task_memory
            if self._calculate_similarity(
                query_features,
                self._extract_features(task.name)
            ) > self.adaptation_threshold
        ]

        # Best-performing tasks first.
        similar_tasks.sort(key=_task_performance, reverse=True)

        return similar_tasks

    def _extract_features(self, text: str) -> np.ndarray:
        """Extract a bag-of-words feature set from text.

        Returns the unique, lower-cased, whitespace-separated tokens as a
        sorted string array. Tokens are kept verbatim rather than hashed into
        a small number of buckets, so distinct words never collide and the
        result does not depend on per-process string-hash randomization.
        """
        words = sorted(set(text.lower().split()))
        # dtype=str keeps empty input a string array so set operations
        # against non-empty feature arrays never mix dtypes.
        return np.array(words, dtype=str)

    def _calculate_similarity(
        self,
        features1: np.ndarray,
        features2: np.ndarray
    ) -> float:
        """Return the Jaccard similarity of two feature sets (0.0 if both are empty)."""
        intersection = np.intersect1d(features1, features2)
        union = np.union1d(features1, features2)
        if len(union) == 0:
            return 0.0
        return len(intersection) / len(union)

    async def _adapt_parameters(
        self,
        similar_tasks: List[MetaTask],
        context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Adapt parameters based on similar tasks.

        Each numeric config parameter becomes the performance-weighted
        average of the values stored on *similar_tasks*. A task missing the
        parameter, or storing a non-numeric value for it, contributes the
        config value instead. Non-numeric parameters are passed through
        unchanged. Returns a copy of the config when there are no similar
        tasks or their total performance is not positive.
        """
        if not similar_tasks:
            return self.config.copy()

        performances = [_task_performance(task) for task in similar_tasks]
        total_performance = sum(performances)

        # `not > 0` also rejects NaN totals.
        if not total_performance > 0:
            return self.config.copy()

        adapted_params = self.config.copy()
        for param_name, default in self.config.items():
            if not _is_numeric(default):
                # Strings, bools, etc. cannot be averaged; keep the config value.
                continue
            weighted_sum = 0.0
            for task, performance in zip(similar_tasks, performances):
                value = task.parameters.get(param_name, default)
                if not _is_numeric(value):
                    value = default
                weighted_sum += value * (performance / total_performance)
            adapted_params[param_name] = weighted_sum

        return adapted_params

    async def _apply_meta_learning(
        self,
        query: str,
        parameters: Dict[str, Any],
        context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Apply meta-learning with adapted parameters.

        Numeric parameters are blended back toward their config values by
        ``learning_rate``. The metrics computed are:

        * ``adaptation_score``: the mean ratio of blended value to config value
          over numeric, non-zero config entries. It is 1.0 ("unchanged") when
          there are none.
        * ``novelty_score``: 1 minus the best similarity to any remembered
          task. It is 1.0 when memory is empty.

        *parameters* is not modified; the blended values go into a new dict.
        """
        # Copy so the caller's dict (returned by reason() as 'adapted_params')
        # is not overwritten with the post-learning-rate values.
        blended = dict(parameters)
        for param_name, value in parameters.items():
            baseline = self.config.get(param_name)
            if _is_numeric(value) and _is_numeric(baseline):
                blended[param_name] = (
                    value * (1 - self.learning_rate) +
                    baseline * self.learning_rate
                )

        ratios = [
            value / self.config[name]
            for name, value in blended.items()
            if _is_numeric(value)
            and _is_numeric(self.config.get(name))
            and self.config[name] != 0
        ]
        adaptation_score = float(np.mean(ratios)) if ratios else 1.0

        if self.task_memory:
            query_features = self._extract_features(query)
            novelty_score = 1 - max(
                self._calculate_similarity(
                    query_features,
                    self._extract_features(task.name)
                )
                for task in self.task_memory
            )
        else:
            novelty_score = 1.0

        return {
            'query': query,
            'parameters': blended,
            'metrics': {
                'adaptation_score': adaptation_score,
                'novelty_score': novelty_score
            }
        }

    async def _update_memory(
        self,
        query: str,
        results: Dict[str, Any],
        context: Dict[str, Any]
    ) -> None:
        """Record *query* as a new task and evict the worst tasks when over capacity.

        The task just added may itself be evicted if it scores lowest.
        Eviction keeps the remaining tasks in their existing order.
        """
        task = MetaTask(
            name=query,
            parameters=results['parameters'],
            metrics=results['metrics'],
            history=[{
                'timestamp': datetime.now().isoformat(),
                'context': context,
                'results': results
            }]
        )

        self.task_memory.append(task)
        self._tasks_seen += 1

        # `while` also trims correctly if memory_size was lowered at runtime.
        # The emptiness check prevents looping forever when memory_size <= 0.
        while self.task_memory and len(self.task_memory) > self.memory_size:
            worst_index = min(
                range(len(self.task_memory)),
                key=lambda i: _task_performance(self.task_memory[i])
            )
            del self.task_memory[worst_index]

    async def _generate_analysis(
        self,
        results: Dict[str, Any],
        context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Generate meta-learning analysis.

        For each parameter, reports its value and its ratio to the config
        value (1.0 when the ratio is undefined). Also reports the run metrics,
        memory usage, and counts of current and lifetime tasks.
        """
        param_stats = {}
        for name, value in results['parameters'].items():
            baseline = self.config.get(name)
            if _is_numeric(value) and _is_numeric(baseline) and baseline != 0:
                adaptation = value / baseline
            else:
                adaptation = 1.0
            param_stats[name] = {'value': value, 'adaptation': adaptation}

        # A zero-capacity store is treated as permanently full (and avoids
        # dividing by zero).
        if self.memory_size > 0:
            memory_usage = len(self.task_memory) / self.memory_size
        else:
            memory_usage = 1.0

        metrics = {
            'adaptation_score': results['metrics']['adaptation_score'],
            'novelty_score': results['metrics']['novelty_score'],
            'memory_usage': memory_usage
        }

        return {
            'parameters': param_stats,
            'metrics': metrics,
            'memory_size': len(self.task_memory),
            'total_tasks_seen': self._tasks_seen
        }

    def _format_analysis(self, analysis: Dict[str, Any]) -> str:
        """Format analysis into readable text."""
        sections = []

        # Parameter adaptations
        sections.append("Parameter adaptations:")
        for name, stats in analysis['parameters'].items():
            value = stats['value']
            # `:.2f` would raise ValueError for non-numeric parameters.
            shown = f"{value:.2f}" if _is_numeric(value) else str(value)
            sections.append(
                f"- {name}: {shown} "
                f"({stats['adaptation']:.1%} of original)"
            )

        # Performance metrics
        sections.append("\nPerformance metrics:")
        metrics = analysis['metrics']
        sections.append(f"- Adaptation score: {metrics['adaptation_score']:.1%}")
        sections.append(f"- Novelty score: {metrics['novelty_score']:.1%}")
        sections.append(f"- Memory usage: {metrics['memory_usage']:.1%}")

        # Memory statistics
        sections.append("\nMemory statistics:")
        sections.append(f"- Current tasks in memory: {analysis['memory_size']}")
        sections.append(f"- Total tasks seen: {analysis['total_tasks_seen']}")

        return "\n".join(sections)

    def _calculate_confidence(self, results: Dict[str, Any]) -> float:
        """Calculate overall confidence score.

        Starts at 0.5 and adds:

        * up to 0.3 for a high adaptation score;
        * up to 0.2 for a low novelty score, i.e. closeness to known tasks.

        The result is capped at 1.0. Returns 0.0 when *results* has no
        metrics.
        """
        metrics = results.get('metrics')
        if not metrics:
            return 0.0

        confidence = 0.5

        adaptation_score = metrics['adaptation_score']
        if adaptation_score > 0.8:
            confidence += 0.3
        elif adaptation_score > 0.6:
            confidence += 0.2
        elif adaptation_score > 0.4:
            confidence += 0.1

        novelty_score = metrics['novelty_score']
        if novelty_score < 0.2:  # Very similar to known tasks
            confidence += 0.2
        elif novelty_score < 0.4:
            confidence += 0.1

        return min(confidence, 1.0)

    def get_performance_metrics(self) -> Dict[str, Any]:
        """Get current performance metrics.

        Only 'episode_count' (tasks currently in memory) and 'learning_rate'
        carry live values. The other fields are fixed placeholders (0 / 0.0)
        that this strategy does not track.
        """
        return {
            "success_rate": 0.0,
            "adaptation_rate": 0.0,
            "exploration_count": 0,
            "episode_count": len(self.task_memory),
            "pattern_count": 0,
            "learning_rate": self.learning_rate,
            "exploration_rate": 0.0
        }

    def get_top_patterns(self, n: int = 10) -> List[Tuple[str, float]]:
        """Get top performing patterns (this strategy tracks none; always empty)."""
        return []

    def clear_memory(self):
        """Clear learning memory (the lifetime task counter is kept)."""
        self.task_memory.clear()