Spaces:
Runtime error
Runtime error
""" | |
Agentic Orchestrator for Advanced AI System | |
----------------------------------------- | |
Manages and coordinates multiple agentic components: | |
1. Task Planning & Decomposition | |
2. Resource Management | |
3. Agent Communication | |
4. State Management | |
5. Error Recovery | |
6. Performance Monitoring | |
""" | |
import logging | |
from typing import Dict, Any, List, Optional, Union, TypeVar, Generic | |
from dataclasses import dataclass, field | |
from enum import Enum | |
import json | |
import asyncio | |
from datetime import datetime | |
import uuid | |
from concurrent.futures import ThreadPoolExecutor | |
import networkx as nx | |
from collections import defaultdict | |
import numpy as np | |
from reasoning import UnifiedReasoningEngine as ReasoningEngine, StrategyType as ReasoningMode | |
from reasoning.meta_learning import MetaLearningStrategy | |
# Module-level generic type variable.
T = TypeVar("T")
class AgentRole(Enum):
    """Role an agent is registered with.

    Every member's value is its lowercase name, so ``AgentRole("planner")``
    looks a member up by value.
    """

    PLANNER = "planner"
    EXECUTOR = "executor"
    MONITOR = "monitor"
    COORDINATOR = "coordinator"
    LEARNER = "learner"
class AgentState(Enum):
    """Lifecycle state of a registered agent.

    Member values are lowercase strings matching the member names.
    """

    IDLE = "idle"
    BUSY = "busy"
    ERROR = "error"
    LEARNING = "learning"
    TERMINATED = "terminated"
class TaskPriority(Enum):
    """Priority of a task, from LOW (0) up to CRITICAL (3).

    This is a plain ``Enum``, not an ``IntEnum``: members do not compare
    equal to integers, so use ``.value`` for numeric comparisons.
    """

    LOW = 0
    MEDIUM = 1
    HIGH = 2
    CRITICAL = 3
@dataclass
class AgentMetadata:
    """Bookkeeping record kept by the orchestrator for each registered agent.

    ``@dataclass`` is required here: the orchestrator builds instances with
    keyword arguments (``AgentMetadata(id=..., role=..., ...)``), which a
    plain annotated class without an ``__init__`` rejects with TypeError.
    """

    # Agent id; the orchestrator generates it with str(uuid.uuid4()).
    id: str
    role: AgentRole
    # Capability names matched against a task's metadata["required_capabilities"].
    capabilities: List[str]
    state: AgentState
    # Incremented by 1 per assigned task and decremented on reassignment.
    load: float
    last_active: datetime
    # Per-agent numeric metrics; their mean feeds the agent suitability score.
    metrics: Dict[str, float] = field(default_factory=dict)
@dataclass
class Task:
    """A unit of work tracked by the orchestrator.

    ``@dataclass`` is required: tasks are created with keyword arguments
    (``Task(id=..., description=..., ...)``), which a plain annotated class
    cannot accept. ``deadline`` and ``metadata`` have defaults so callers
    may omit them; existing keyword calls are unaffected.
    """

    # Task id; the orchestrator generates it with str(uuid.uuid4()).
    id: str
    description: str
    priority: TaskPriority
    # Ids of tasks that must be in state "completed" before this one is assigned.
    dependencies: List[str]
    # Id of the agent currently holding the task, or None.
    assigned_to: Optional[str]
    # Values used by the orchestrator include "pending", "assigned", "completed".
    state: str
    created_at: datetime
    deadline: Optional[datetime] = None
    # Free-form data; the orchestrator reads "required_capabilities" and
    # "completion_time" from it.
    metadata: Dict[str, Any] = field(default_factory=dict)
class AgentOrchestrator:
    """Advanced orchestrator for managing an agentic system.

    Keeps registries of agents and tasks, a task-dependency graph, queues for
    agent-bound messages and events, a bounded history of system-state
    snapshots, and tables of error handlers and recovery strategies.
    Submitted tasks are handed to the best-scoring IDLE agent and announced
    on ``message_queue``.

    Optional ``config`` keys read by this class:
        agent_selection_weights: mapping with any of "capabilities", "load",
            "performance"; missing keys default to 0.5 / 0.3 / 0.2.
        monitor_interval: seconds between monitoring passes (default 1).
        state_history_limit: number of snapshots kept (default 1000).
        overload_threshold: agents whose ``load`` exceeds this are treated
            as overloaded (default 1.0, i.e. more than one assigned task).
        stall_timeout: seconds a task may stay "assigned" before a status
            request is sent to its agent (default 3600).
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Create an orchestrator and its helper components.

        Args:
            config: Optional settings dictionary (see the class docstring);
                a falsy value is replaced by a new empty dict.
        """
        self.config = config or {}

        # Core registries; task_graph edges point from a dependency to the
        # task that depends on it.
        self.agents: Dict[str, AgentMetadata] = {}
        self.tasks: Dict[str, Task] = {}
        self.task_graph = nx.DiGraph()

        # State management: latest snapshot plus a bounded history.
        self.state_history: List[Dict[str, Any]] = []
        self.global_state: Dict[str, Any] = {}

        # Resource management.
        self.resource_pool: Dict[str, Any] = {}
        self.resource_locks: Dict[str, asyncio.Lock] = {}

        # Communication: agent-bound messages (assignments, status and
        # expedite requests) and a separate event queue.
        self.message_queue = asyncio.Queue()
        self.event_bus = asyncio.Queue()

        # Performance monitoring.
        self.metrics = defaultdict(list)
        self.performance_log = []

        # Error handling tables, filled by _register_error_handlers().
        self.error_handlers: Dict[str, callable] = {}
        self.recovery_strategies: Dict[str, callable] = {}

        # Async support.
        # NOTE(review): the executor is never shut down in this section;
        # confirm a shutdown path exists elsewhere.
        self.executor = ThreadPoolExecutor(max_workers=4)
        # Held while registering agents/tasks and while assigning tasks.
        self.lock = asyncio.Lock()

        self.logger = logging.getLogger(__name__)

        self._init_components()

    def _init_components(self):
        """Create the reasoning engine and meta-learner, then register handlers."""
        # NOTE(review): tuning values are hard-coded; they could be read from
        # self.config if per-deployment tuning is needed.
        self.reasoning_engine = ReasoningEngine(
            min_confidence=0.7,
            parallel_threshold=5,
            learning_rate=0.1,
            strategy_weights={
                "LOCAL_LLM": 2.0,
                "CHAIN_OF_THOUGHT": 1.0,
                "TREE_OF_THOUGHTS": 1.0,
                "META_LEARNING": 1.5,
            },
        )
        self.meta_learning = MetaLearningStrategy()
        self._register_error_handlers()

    async def register_agent(
        self,
        role: AgentRole,
        capabilities: List[str]
    ) -> str:
        """Register a new agent in the IDLE state with zero load.

        Args:
            role: Role the agent takes.
            capabilities: Capability names matched against a task's
                ``metadata["required_capabilities"]`` when scoring agents.

        Returns:
            The generated agent id (a UUID4 string).
        """
        agent_id = str(uuid.uuid4())
        agent = AgentMetadata(
            id=agent_id,
            role=role,
            capabilities=capabilities,
            state=AgentState.IDLE,
            load=0.0,
            last_active=datetime.now(),
            metrics={},
        )
        async with self.lock:
            self.agents[agent_id] = agent
        # NOTE(review): pending tasks are not re-planned when a new agent
        # becomes available.
        self.logger.info(f"Registered new agent: {agent_id} with role {role}")
        return agent_id

    async def submit_task(
        self,
        description: str,
        priority: TaskPriority = TaskPriority.MEDIUM,
        dependencies: Optional[List[str]] = None,
        deadline: Optional[datetime] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """Submit a new task and try to schedule it immediately.

        Args:
            description: Human-readable task description.
            priority: Task priority.
            dependencies: Ids of tasks that must reach state "completed"
                before this one is assigned.
            deadline: Optional deadline checked by the monitor.
            metadata: Free-form data; "required_capabilities" is read when
                scoring agents.

        Returns:
            The new task id (a UUID4 string).
        """
        task_id = str(uuid.uuid4())
        task = Task(
            id=task_id,
            description=description,
            priority=priority,
            dependencies=dependencies or [],
            assigned_to=None,
            state="pending",
            created_at=datetime.now(),
            deadline=deadline,
            metadata=metadata or {},
        )
        async with self.lock:
            self.tasks[task_id] = task
            self._update_task_graph(task)
        # Planning runs outside the lock: _assign_task acquires self.lock
        # itself and asyncio.Lock is not reentrant.
        await self._plan_task_execution(task_id)
        return task_id

    async def _plan_task_execution(self, task_id: str) -> None:
        """Assign a task to the best available agent if it is ready.

        Returns without assigning when a dependency is not yet completed or
        no IDLE agent exists; the task then keeps its current state.
        NOTE(review): in this section only the missed-deadline path in
        _check_anomalies re-plans such tasks later.
        """
        task = self.tasks[task_id]

        if not await self._check_dependencies(task):
            self.logger.info(f"Task {task_id} waiting for dependencies")
            return

        agent_id = await self._find_suitable_agent(task)
        if not agent_id:
            self.logger.warning(f"No suitable agent found for task {task_id}")
            return

        await self._assign_task(task_id, agent_id)

    async def _check_dependencies(self, task: Task) -> bool:
        """Return True when every dependency exists and is "completed".

        An unknown dependency id counts as unsatisfied.
        """
        return all(
            dep_id in self.tasks and self.tasks[dep_id].state == "completed"
            for dep_id in task.dependencies
        )

    async def _find_suitable_agent(self, task: Task) -> Optional[str]:
        """Return the id of the highest-scoring IDLE agent, or None.

        Ties keep the first agent encountered.
        NOTE(review): agents with none of the task's required capabilities
        are still eligible; they only score lower on the capability term.
        """
        best_agent = None
        best_score = float('-inf')
        # Snapshot the registry because scoring is awaited.
        for agent_id, agent in list(self.agents.items()):
            if agent.state != AgentState.IDLE:
                continue
            score = await self._calculate_agent_suitability(agent, task)
            if score > best_score:
                best_score = score
                best_agent = agent_id
        return best_agent

    async def _calculate_agent_suitability(
        self,
        agent: AgentMetadata,
        task: Task
    ) -> float:
        """Score how well ``agent`` fits ``task``; higher is better.

        The score is a weighted sum of three terms:
          * capabilities: how many of ``task.metadata["required_capabilities"]``
            the agent has (a raw count, not a fraction);
          * load: ``1 - agent.load``;
          * performance: mean of ``agent.metrics`` values, or 0.5 if empty.

        Weights come from ``config["agent_selection_weights"]``. Keys the
        config omits fall back to their defaults instead of raising KeyError.
        """
        required = task.metadata.get("required_capabilities", [])
        capability_score = sum(1 for cap in required if cap in agent.capabilities)

        load_score = 1 - agent.load

        # NOTE(review): every metric is averaged, including bookkeeping
        # counters such as "recovery_attempts" written by _recover_agent.
        performance_score = (
            sum(agent.metrics.values()) / len(agent.metrics)
            if agent.metrics else 0.5
        )

        # Merge user weights over the defaults so a partial config works.
        weights = {
            "capabilities": 0.5,
            "load": 0.3,
            "performance": 0.2,
            **(self.config.get("agent_selection_weights") or {}),
        }
        return (
            weights["capabilities"] * capability_score
            + weights["load"] * load_score
            + weights["performance"] * performance_score
        )

    async def _assign_task(self, task_id: str, agent_id: str) -> None:
        """Mark ``task_id`` as assigned to ``agent_id`` and notify the agent.

        Under ``self.lock`` the following happens:
          * the task becomes "assigned";
          * the agent becomes BUSY and its load is incremented by one;
          * the assignment time is stored in ``task.metadata["assigned_at"]``.

        A "task_assignment" message is then put on ``message_queue``.
        """
        async with self.lock:
            task = self.tasks[task_id]
            agent = self.agents[agent_id]
            now = datetime.now()

            task.assigned_to = agent_id
            task.state = "assigned"
            # Lets stall detection measure time since assignment rather than
            # time since submission.
            task.metadata["assigned_at"] = now

            agent.state = AgentState.BUSY
            agent.load += 1
            agent.last_active = now

            self.logger.info(f"Assigned task {task_id} to agent {agent_id}")

        await self.message_queue.put({
            "type": "task_assignment",
            "task_id": task_id,
            "agent_id": agent_id,
            "timestamp": datetime.now(),
        })

    def _update_task_graph(self, task: Task) -> None:
        """Add ``task`` as a node, plus an edge from each dependency to it.

        The node stores the Task object under the ``task`` attribute.
        networkx creates dependency nodes that do not exist yet when the
        edge is inserted; those nodes have no ``task`` attribute.
        """
        self.task_graph.add_node(task.id, task=task)
        for dep_id in task.dependencies:
            self.task_graph.add_edge(dep_id, task.id)

    async def _monitor_system_state(self):
        """Periodically snapshot agent/task state and check for anomalies.

        Runs forever. Each pass:
          * rebuilds ``self.global_state``;
          * appends a shallow copy of it to ``self.state_history``, trimmed
            to ``state_history_limit`` entries;
          * calls ``_check_anomalies``.

        Exceptions are logged and routed to the "monitoring_error" handler,
        and the loop continues after its normal sleep.
        ``asyncio.CancelledError`` is not caught (it is a BaseException on
        Python 3.8+), so cancelling the task stops the loop.
        """
        interval = self.config.get("monitor_interval", 1)
        history_limit = self.config.get("state_history_limit", 1000)
        while True:
            try:
                agent_states = {
                    agent_id: {
                        "state": agent.state,
                        "load": agent.load,
                        "metrics": agent.metrics,
                    }
                    for agent_id, agent in self.agents.items()
                }
                task_states = {
                    task_id: {
                        "state": task.state,
                        "assigned_to": task.assigned_to,
                        "deadline": task.deadline,
                    }
                    for task_id, task in self.tasks.items()
                }
                self.global_state = {
                    "timestamp": datetime.now(),
                    "agents": agent_states,
                    "tasks": task_states,
                    "resource_usage": self._get_resource_usage(),
                    "performance_metrics": self._calculate_performance_metrics(),
                }

                self.state_history.append(self.global_state.copy())
                excess = len(self.state_history) - history_limit
                if excess > 0:
                    del self.state_history[:excess]

                await self._check_anomalies()
            except Exception as e:
                self.logger.error(f"Error in system monitoring: {e}")
                await self._handle_error("monitoring_error", e)
            # Sleep after failures too. Previously the sleep sat inside the
            # try, so a persistent error could re-run immediately without
            # ever yielding to the event loop.
            await asyncio.sleep(interval)

    def _get_resource_usage(self) -> Dict[str, float]:
        """Return rough resource-usage figures.

        Returns:
            A dict with these keys:
              * ``cpu_usage``: mean agent load (0.0 when no agents exist);
              * ``memory_usage``: a crude estimate of 1000 per stored snapshot;
              * ``queue_size``: the number of queued messages.
        """
        total_load = sum(agent.load for agent in self.agents.values())
        return {
            # max(1, ...) avoids ZeroDivisionError before any agent registers.
            "cpu_usage": total_load / max(1, len(self.agents)),
            "memory_usage": len(self.state_history) * 1000,  # Rough estimate
            "queue_size": self.message_queue.qsize(),
        }

    def _calculate_performance_metrics(self) -> Dict[str, float]:
        """Compute completion rate, mean task duration and agent utilization.

        Only completed tasks that carry ``metadata["completion_time"]``
        count towards ``avg_task_duration`` (in seconds since creation).
        Ratios are 0 when there are no tasks or no agents.
        """
        metrics = {}

        completed_tasks = sum(
            1 for task in self.tasks.values() if task.state == "completed"
        )
        metrics["task_completion_rate"] = completed_tasks / max(1, len(self.tasks))

        durations = [
            (task.metadata["completion_time"] - task.created_at).total_seconds()
            for task in self.tasks.values()
            if task.state == "completed" and "completion_time" in task.metadata
        ]
        metrics["avg_task_duration"] = (
            sum(durations) / len(durations) if durations else 0
        )

        # max(1, ...) avoids ZeroDivisionError before any agent registers.
        metrics["agent_utilization"] = (
            sum(agent.load for agent in self.agents.values())
            / max(1, len(self.agents))
        )
        return metrics

    async def _check_anomalies(self):
        """Detect overloaded agents, stalled tasks and missed deadlines.

        Thresholds come from ``config["overload_threshold"]`` (default 1.0)
        and ``config["stall_timeout"]`` (default 3600 seconds).
        """
        overload_threshold = self.config.get("overload_threshold", 1.0)
        stall_timeout = self.config.get("stall_timeout", 3600)

        # ``load`` counts assigned tasks (+1 per assignment). The former 0.9
        # threshold therefore flagged every busy agent, and its task was
        # bounced to another idle agent on every monitoring pass.
        # Iterate snapshots: the handlers await, and other coroutines
        # (e.g. submit_task) may modify the registries meanwhile.
        for agent_id, agent in list(self.agents.items()):
            if agent.load > overload_threshold:
                await self._handle_overload(agent_id)

        now = datetime.now()
        for task_id, task in list(self.tasks.items()):
            if task.state == "assigned":
                # Measure from assignment when known; created_at is the
                # fallback for tasks assigned without that record.
                started = task.metadata.get("assigned_at", task.created_at)
                if (now - started).total_seconds() > stall_timeout:
                    await self._handle_stalled_task(task_id)

        for task_id, task in list(self.tasks.items()):
            if task.deadline and now > task.deadline and task.state != "completed":
                # NOTE(review): re-triggered on every pass until the task
                # completes, re-logging and re-sending expedite requests.
                await self._handle_missed_deadline(task_id)

    async def _handle_overload(self, agent_id: str):
        """Move an overloaded agent's "assigned" tasks to other idle agents.

        A task is left where it is when no other suitable agent is found.
        """
        assigned_tasks = [
            task_id for task_id, task in self.tasks.items()
            if task.assigned_to == agent_id and task.state == "assigned"
        ]
        for task_id in assigned_tasks:
            new_agent_id = await self._find_suitable_agent(self.tasks[task_id])
            # Skip a no-op "move" back onto the same agent.
            if new_agent_id and new_agent_id != agent_id:
                await self._reassign_task(task_id, new_agent_id)

    async def _handle_stalled_task(self, task_id: str):
        """React to a task that has stayed "assigned" for too long.

        If the assigned agent is in the ERROR state, the task is reassigned.
        Otherwise a "status_request" message is queued for that agent.
        """
        task = self.tasks[task_id]
        if task.assigned_to:
            agent = self.agents[task.assigned_to]
            if agent.state == AgentState.ERROR:
                await self._reassign_task(task_id, None)
            else:
                await self.message_queue.put({
                    "type": "status_request",
                    "task_id": task_id,
                    "agent_id": task.assigned_to,
                    "timestamp": datetime.now(),
                })

    async def _handle_missed_deadline(self, task_id: str):
        """Escalate a task whose deadline has passed.

        The task is raised to CRITICAL. An assigned task gets an
        "expedite_request" message; an unassigned one is re-planned.
        """
        task = self.tasks[task_id]
        self.logger.warning(f"Task {task_id} missed deadline: {task.deadline}")
        task.priority = TaskPriority.CRITICAL

        if task.assigned_to:
            await self.message_queue.put({
                "type": "expedite_request",
                "task_id": task_id,
                "agent_id": task.assigned_to,
                "timestamp": datetime.now(),
            })
        else:
            await self._plan_task_execution(task_id)

    async def _reassign_task(self, task_id: str, new_agent_id: Optional[str] = None):
        """Release a task from its current agent and hand it to another.

        Args:
            task_id: Task to move.
            new_agent_id: Target agent. When None, the best IDLE agent is
                chosen after the old agent has been released.

        If no target agent is available, the task goes back to "pending"
        with no agent.
        """
        task = self.tasks[task_id]
        old_agent_id = task.assigned_to

        if old_agent_id:
            old_agent = self.agents[old_agent_id]
            # Clamp at zero so repeated releases cannot drive load negative.
            old_agent.load = max(0.0, old_agent.load - 1)
            # Only a BUSY agent becomes free again. ERROR/TERMINATED/LEARNING
            # agents keep their state, so a failing agent is not re-selected.
            if old_agent.load <= 0 and old_agent.state == AgentState.BUSY:
                old_agent.state = AgentState.IDLE

        if new_agent_id is None:
            new_agent_id = await self._find_suitable_agent(task)

        if new_agent_id:
            await self._assign_task(task_id, new_agent_id)
        else:
            task.state = "pending"
            task.assigned_to = None

    def _register_error_handlers(self):
        """Populate the error-handler and recovery-strategy tables."""
        self.error_handlers.update({
            "monitoring_error": self._handle_monitoring_error,
            "agent_error": self._handle_agent_error,
            "task_error": self._handle_task_error,
            "resource_error": self._handle_resource_error,
        })
        self.recovery_strategies.update({
            "agent_recovery": self._recover_agent,
            "task_recovery": self._recover_task,
            "resource_recovery": self._recover_resource,
        })

    async def _handle_error(self, error_type: str, error: Exception):
        """Dispatch ``error`` to the handler registered for ``error_type``.

        A failing handler is logged rather than propagated. An unknown
        error type is logged together with the error itself.
        """
        handler = self.error_handlers.get(error_type)
        if handler:
            try:
                await handler(error)
            except Exception as e:
                self.logger.error(f"Error in error handler: {e}")
        else:
            self.logger.error(f"No handler for error type: {error_type}")
            self.logger.error(f"Error: {error}")

    async def _handle_monitoring_error(self, error: Exception):
        """Log a monitoring-loop error (no recovery action yet)."""
        self.logger.error(f"Monitoring error: {error}")
        # TODO: implement recovery logic.

    async def _handle_agent_error(self, error: Exception):
        """Log an agent-related error (no recovery action yet)."""
        self.logger.error(f"Agent error: {error}")
        # TODO: implement recovery logic.

    async def _handle_task_error(self, error: Exception):
        """Log a task-related error (no recovery action yet)."""
        self.logger.error(f"Task error: {error}")
        # TODO: implement recovery logic.

    async def _handle_resource_error(self, error: Exception):
        """Log a resource-related error (no recovery action yet)."""
        self.logger.error(f"Resource error: {error}")
        # TODO: implement recovery logic.

    async def _recover_agent(self, agent_id: str):
        """Reset an agent to IDLE and re-dispatch its unfinished tasks.

        Returns:
            True on success, False if any exception was raised (it is logged).
        """
        try:
            agent = self.agents[agent_id]
            self.logger.info(f"Attempting to recover agent {agent_id}")

            agent.state = AgentState.IDLE
            agent.load = 0.0
            agent.last_active = datetime.now()

            # Completed tasks keep their assigned_to and must not be re-run.
            # Snapshot because reassignment awaits.
            for task_id, task in list(self.tasks.items()):
                if task.assigned_to == agent_id and task.state != "completed":
                    await self._reassign_task(task_id)

            # NOTE(review): this counter lives in agent.metrics, which is
            # averaged into the agent's suitability score.
            agent.metrics["recovery_attempts"] = agent.metrics.get("recovery_attempts", 0) + 1

            self.logger.info(f"Successfully recovered agent {agent_id}")
            return True
        except Exception as e:
            self.logger.error(f"Failed to recover agent {agent_id}: {e}")
            return False

    async def _recover_task(self, task_id: str):
        """Reset a task to "pending" and try to hand it to an agent again.

        Returns:
            True on success (the task may still be left pending if no agent
            is free), False if an exception was raised (it is logged).
        """
        try:
            task = self.tasks[task_id]
            self.logger.info(f"Attempting to recover task {task_id}")

            task.state = "pending"
            # assigned_to is deliberately left in place: _reassign_task reads
            # it to release the previous agent. Clearing it first left that
            # agent BUSY with its load never decremented.
            await self._reassign_task(task_id)

            self.logger.info(f"Successfully recovered task {task_id}")
            return True
        except Exception as e:
            self.logger.error(f"Failed to recover task {task_id}: {e}")
            return False

    async def _recover_resource(self, resource_id: str):
        """Force a resource back to the "available" state.

        Releases the resource's lock if it is held. If the id is in
        ``resource_pool``, its entry is replaced with a fresh status dict.

        Returns:
            True on success, False if an exception was raised (it is logged).
        """
        try:
            self.logger.info(f"Attempting to recover resource {resource_id}")

            if resource_id in self.resource_locks:
                lock = self.resource_locks[resource_id]
                if lock.locked():
                    # NOTE(review): asyncio.Lock has no owner, so this releases
                    # a lock another coroutine may still believe it holds.
                    lock.release()

            if resource_id in self.resource_pool:
                # NOTE(review): the previous entry is replaced wholesale,
                # dropping any other data stored for this resource.
                self.resource_pool[resource_id] = {
                    "state": "available",
                    "last_error": None,
                    "last_recovery": datetime.now(),
                }

            self.logger.info(f"Successfully recovered resource {resource_id}")
            return True
        except Exception as e:
            self.logger.error(f"Failed to recover resource {resource_id}: {e}")
            return False