"""Enhanced learning mechanisms for reasoning strategies."""
import json
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

import numpy as np
@dataclass
class LearningEvent:
    """Event for strategy learning.

    Records one observation from a strategy execution: which strategy
    produced it, what kind of event it was, the raw event payload, and
    (optionally) a scalar outcome used as a learning signal.
    """
    strategy_type: str
    event_type: str
    data: Dict[str, Any]
    outcome: Optional[float]
    # Capture creation time per instance; a plain default would be shared.
    timestamp: datetime = field(default_factory=datetime.now)
class LearningMode(Enum):
    """Types of learning modes.

    The string values are part of the external contract: they are emitted
    in the ``"mode"`` field of learning result dicts.
    """
    SUPERVISED = "supervised"        # learn from labelled outcomes
    REINFORCEMENT = "reinforcement"  # learn from rewards (policy/value updates)
    ACTIVE = "active"                # query selection + feedback loop
    TRANSFER = "transfer"            # reuse knowledge from a source task
    META = "meta"                    # learn how to select/tune strategies
    ENSEMBLE = "ensemble"            # combine multiple member predictors
@dataclass
class LearningState:
    """State for the learning process of a single strategy.

    Holds the currently selected mode, the tunable parameters (e.g.
    learning/exploration rates), the per-strategy event history, and the
    smoothed metric values.
    """
    mode: LearningMode
    parameters: Dict[str, Any]
    history: List[LearningEvent]
    metrics: Dict[str, float]
    # default_factory gives each instance its own dict (never share a mutable default).
    metadata: Dict[str, Any] = field(default_factory=dict)
class EnhancedLearningManager:
    """
    Advanced learning manager that:
    1. Implements multiple learning modes
    2. Tracks learning progress
    3. Adapts learning parameters
    4. Optimizes strategy performance
    5. Transfers knowledge between strategies

    NOTE(review): the mode-specific handlers delegate to private helpers
    (``_select_learning_mode``, ``_extract_features``, ``_update_policy``,
    ``_select_query``, ...) that are not defined in this class — presumably
    supplied by a subclass or mixin elsewhere in the project; confirm.
    """

    def __init__(self,
                 learning_rate: float = 0.1,
                 exploration_rate: float = 0.2,
                 memory_size: int = 1000):
        """Initialize the manager.

        Args:
            learning_rate: Initial per-strategy learning rate.
            exploration_rate: Initial per-strategy exploration rate.
            memory_size: Maximum number of records kept in the performance
                history and in each strategy's event history.
        """
        self.learning_rate = learning_rate
        self.exploration_rate = exploration_rate
        self.memory_size = memory_size

        # Per-strategy learning states, created lazily on first use.
        self.states: Dict[str, LearningState] = {}

        # Performance tracking.
        self.performance_history: List[Dict[str, Any]] = []
        self.strategy_metrics: Dict[str, List[float]] = defaultdict(list)

        # Knowledge transfer bookkeeping.
        self.knowledge_base: Dict[str, Any] = {}
        self.transfer_history: List[Dict[str, Any]] = []

    async def learn(self,
                    strategy_type: str,
                    event: LearningEvent,
                    context: Dict[str, Any]) -> Dict[str, Any]:
        """Learn from a strategy execution event.

        Args:
            strategy_type: Identifier of the strategy that produced the event.
            event: The execution event to learn from.
            context: Contextual data forwarded to the mode handlers.

        Returns:
            The selected handler's result dict on success, or
            ``{"success": False, "error": ..., "mode": ...}`` on failure.
        """
        # Fix: pre-bind mode so the except block can report it without the
        # fragile `'mode' in locals()` check the original used.
        mode = None
        try:
            state = self._get_learning_state(strategy_type)
            mode = await self._select_learning_mode(event, state, context)

            # Dispatch table replaces the long if/elif chain.
            handlers = {
                LearningMode.SUPERVISED: self._supervised_learning,
                LearningMode.REINFORCEMENT: self._reinforcement_learning,
                LearningMode.ACTIVE: self._active_learning,
                LearningMode.TRANSFER: self._transfer_learning,
                LearningMode.META: self._meta_learning,
                LearningMode.ENSEMBLE: self._ensemble_learning,
            }
            handler = handlers.get(mode)
            if handler is None:
                raise ValueError(f"Unsupported learning mode: {mode}")
            result = await handler(event, state, context)

            self._update_learning_state(state, result)
            self._record_performance(strategy_type, result)
            return result
        except Exception as e:
            # Fix: logging.exception keeps the traceback; error() dropped it.
            logging.exception("Error in learning: %s", e)
            return {
                "success": False,
                "error": str(e),
                "mode": mode.value if mode is not None else None
            }

    async def _supervised_learning(self,
                                   event: LearningEvent,
                                   state: LearningState,
                                   context: Dict[str, Any]) -> Dict[str, Any]:
        """Supervised learning: fit the model to the event's labelled outcome."""
        features = await self._extract_features(event.data, context)
        # A missing outcome is treated as label 0.0.
        labels = event.outcome if event.outcome is not None else 0.0
        model_update = await self._update_model(features, labels, state, context)
        validation = await self._validate_model(model_update, state, context)
        return {
            "success": True,
            "mode": LearningMode.SUPERVISED.value,
            "model_update": model_update,
            "validation": validation,
            "metrics": {
                "accuracy": validation.get("accuracy", 0.0),
                "loss": validation.get("loss", 0.0)
            }
        }

    async def _reinforcement_learning(self,
                                      event: LearningEvent,
                                      state: LearningState,
                                      context: Dict[str, Any]) -> Dict[str, Any]:
        """Reinforcement learning: update policy and value function from a reward."""
        current_state = await self._extract_state(event.data, context)
        action = event.data.get("action")
        # A missing outcome is treated as zero reward.
        reward = event.outcome if event.outcome is not None else 0.0
        policy_update = await self._update_policy(
            current_state, action, reward, state, context)
        value_update = await self._update_value_function(
            current_state, reward, state, context)
        return {
            "success": True,
            "mode": LearningMode.REINFORCEMENT.value,
            "policy_update": policy_update,
            "value_update": value_update,
            "metrics": {
                "reward": reward,
                "value_error": value_update.get("error", 0.0)
            }
        }

    async def _active_learning(self,
                               event: LearningEvent,
                               state: LearningState,
                               context: Dict[str, Any]) -> Dict[str, Any]:
        """Active learning: pick an informative query, get feedback, update the model."""
        query = await self._select_query(event.data, state, context)
        feedback = await self._get_feedback(query, context)
        model_update = await self._update_model_active(
            query, feedback, state, context)
        return {
            "success": True,
            "mode": LearningMode.ACTIVE.value,
            "query": query,
            "feedback": feedback,
            "model_update": model_update,
            "metrics": {
                "uncertainty": query.get("uncertainty", 0.0),
                "feedback_quality": feedback.get("quality", 0.0)
            }
        }

    async def _transfer_learning(self,
                                 event: LearningEvent,
                                 state: LearningState,
                                 context: Dict[str, Any]) -> Dict[str, Any]:
        """Transfer learning: extract knowledge from a source task, adapt, and apply it."""
        source_task = await self._select_source_task(event.data, state, context)
        knowledge = await self._extract_knowledge(source_task, context)
        adaptation = await self._adapt_knowledge(
            knowledge, event.data, state, context)
        transfer = await self._apply_transfer(adaptation, state, context)
        return {
            "success": True,
            "mode": LearningMode.TRANSFER.value,
            "source_task": source_task,
            "knowledge": knowledge,
            "adaptation": adaptation,
            "transfer": transfer,
            "metrics": {
                "transfer_efficiency": transfer.get("efficiency", 0.0),
                "adaptation_quality": adaptation.get("quality", 0.0)
            }
        }

    async def _meta_learning(self,
                             event: LearningEvent,
                             state: LearningState,
                             context: Dict[str, Any]) -> Dict[str, Any]:
        """Meta-learning: characterize the task, pick a strategy, optimize its parameters."""
        task_char = await self._characterize_task(event.data, context)
        strategy = await self._select_strategy(task_char, state, context)
        optimization = await self._optimize_parameters(
            strategy, task_char, state, context)
        meta_update = await self._apply_meta_learning(
            optimization, state, context)
        return {
            "success": True,
            "mode": LearningMode.META.value,
            "task_characterization": task_char,
            "strategy": strategy,
            "optimization": optimization,
            "meta_update": meta_update,
            "metrics": {
                "strategy_fit": strategy.get("fit_score", 0.0),
                "optimization_improvement": optimization.get("improvement", 0.0)
            }
        }

    async def _ensemble_learning(self,
                                 event: LearningEvent,
                                 state: LearningState,
                                 context: Dict[str, Any]) -> Dict[str, Any]:
        """Ensemble learning: select members, optimize weights, combine predictions."""
        members = await self._select_members(event.data, state, context)
        weights = await self._optimize_weights(members, state, context)
        combination = await self._combine_predictions(
            members, weights, event.data, context)
        return {
            "success": True,
            "mode": LearningMode.ENSEMBLE.value,
            "members": members,
            "weights": weights,
            "combination": combination,
            "metrics": {
                "ensemble_diversity": weights.get("diversity", 0.0),
                "combination_strength": combination.get("strength", 0.0)
            }
        }

    def _get_learning_state(self, strategy_type: str) -> LearningState:
        """Return the strategy's learning state, creating a default one on first use."""
        if strategy_type not in self.states:
            self.states[strategy_type] = LearningState(
                mode=LearningMode.SUPERVISED,
                parameters={
                    "learning_rate": self.learning_rate,
                    "exploration_rate": self.exploration_rate
                },
                history=[],
                metrics={}
            )
        return self.states[strategy_type]

    def _update_learning_state(self, state: LearningState, result: Dict[str, Any]):
        """Fold a learning result into the strategy's state (history, metrics, parameters)."""
        state.history.append(LearningEvent(
            strategy_type=result.get("strategy_type", "unknown"),
            event_type="learning_update",
            data=result,
            outcome=result.get("metrics", {}).get("accuracy", 0.0),
            timestamp=datetime.now()
        ))
        # Fix: bound the per-strategy history the same way performance_history
        # is bounded; it previously grew without limit.
        if len(state.history) > self.memory_size:
            state.history = state.history[-self.memory_size:]

        # Exponential moving average smooths noisy per-event metrics.
        for metric, value in result.get("metrics", {}).items():
            if metric in state.metrics:
                state.metrics[metric] = (
                    0.9 * state.metrics[metric] + 0.1 * value
                )
            else:
                state.metrics[metric] = value

        self._adapt_parameters(state, result)

    def _record_performance(self, strategy_type: str, result: Dict[str, Any]):
        """Append the result to the global performance history and metric series."""
        self.performance_history.append({
            "timestamp": datetime.now().isoformat(),
            "strategy_type": strategy_type,
            "mode": result.get("mode"),
            "metrics": result.get("metrics", {}),
            "success": result.get("success", False)
        })
        # Per-metric time series keyed as "<strategy>_<metric>".
        for metric, value in result.get("metrics", {}).items():
            self.strategy_metrics[f"{strategy_type}_{metric}"].append(value)
        # Keep only the most recent memory_size records.
        if len(self.performance_history) > self.memory_size:
            self.performance_history = self.performance_history[-self.memory_size:]

    def _adapt_parameters(self, state: LearningState, result: Dict[str, Any]):
        """Adapt learning parameters based on the result's performance metrics."""
        metrics = result.get("metrics", {})

        # Learning rate: decay when accurate, grow when inaccurate.
        if "accuracy" in metrics:
            accuracy = metrics["accuracy"]
            if accuracy > 0.8:
                state.parameters["learning_rate"] *= 0.95
            elif accuracy < 0.6:
                state.parameters["learning_rate"] *= 1.05

        # Exploration rate: exploit when rewarded, explore otherwise.
        if "reward" in metrics:
            if metrics["reward"] > 0:
                state.parameters["exploration_rate"] *= 0.95
            else:
                state.parameters["exploration_rate"] *= 1.05

        # Clip parameters to reasonable ranges.
        state.parameters["learning_rate"] = np.clip(
            state.parameters["learning_rate"], 0.001, 0.5)
        state.parameters["exploration_rate"] = np.clip(
            state.parameters["exploration_rate"], 0.01, 0.5)

    def get_performance_metrics(self) -> Dict[str, Any]:
        """Return a snapshot of learning states, per-metric stats, and transfer stats."""
        return {
            "learning_states": {
                strategy_type: {
                    "mode": state.mode.value,
                    "parameters": state.parameters,
                    "metrics": state.metrics
                }
                for strategy_type, state in self.states.items()
            },
            "strategy_performance": {
                metric: {
                    "mean": np.mean(values) if values else 0.0,
                    "std": np.std(values) if values else 0.0,
                    "min": min(values) if values else 0.0,
                    "max": max(values) if values else 0.0
                }
                for metric, values in self.strategy_metrics.items()
            },
            "transfer_metrics": {
                "total_transfers": len(self.transfer_history),
                "success_rate": sum(1 for t in self.transfer_history if t.get("success", False)) / len(self.transfer_history) if self.transfer_history else 0
            }
        }

    def clear_history(self):
        """Clear learning history and reset all states."""
        self.states.clear()
        self.performance_history.clear()
        self.strategy_metrics.clear()
        self.transfer_history.clear()