Spaces:
Runtime error
Runtime error
"""Local LLM integration for the reasoning system.""" | |
import os | |
from typing import Dict, Any, Optional, AsyncGenerator, Generator | |
from datetime import datetime | |
import logging | |
from llama_cpp import Llama | |
import huggingface_hub | |
from .base import ReasoningStrategy | |
from .model_manager import ModelManager, ModelType | |
class LocalLLMStrategy(ReasoningStrategy):
    """Reasoning strategy backed by a locally hosted LLM.

    Model selection and loading are delegated to a ``ModelManager``. This
    class builds the prompt, invokes the selected model, and wraps the
    completion in a result dictionary.
    """

    # Context entries copied into the prompt; all other context keys are ignored.
    _PROMPT_CONTEXT_KEYS = ('objective', 'constraints', 'background')

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Initialize the local LLM strategy.

        Args:
            config: Optional settings. Keys read here: ``model_dir`` (passed
                to ``ModelManager``; ``None`` when absent), ``min_confidence``,
                ``parallel_threshold``, ``learning_rate`` and
                ``strategy_weights``. ``reason`` additionally reads
                ``temperature``, ``top_p`` and ``repeat_penalty``.
        """
        super().__init__()
        self.config = config or {}

        # Initialize model manager with model_dir from config
        model_dir = self.config.get('model_dir')
        self.model_manager = ModelManager(model_dir)

        # Standard reasoning parameters
        self.min_confidence = self.config.get('min_confidence', 0.7)
        self.parallel_threshold = self.config.get('parallel_threshold', 3)
        self.learning_rate = self.config.get('learning_rate', 0.1)
        self.strategy_weights = self.config.get('strategy_weights', {
            "LOCAL_LLM": 0.8,
            "CHAIN_OF_THOUGHT": 0.6,
            "TREE_OF_THOUGHTS": 0.5,
            "META_LEARNING": 0.4,
        })

        self.logger = logging.getLogger(__name__)

    async def initialize(self):
        """Initialize all models via the model manager."""
        await self.model_manager.initialize_all_models()

    async def reason(self, query: str, context: Dict[str, Any]) -> Dict[str, Any]:
        """Answer ``query`` using the local model best suited to the task.

        Args:
            query: The question to reason about.
            context: Extra information. ``task_type`` (default ``'general'``)
                selects the model. ``objective``, ``constraints`` and
                ``background`` are included in the prompt. ``None`` is
                treated as an empty dict.

        Returns:
            On success, a dict with ``success=True``, ``answer``,
            ``reasoning``, ``confidence``, ``timestamp`` and ``metadata``.
            On any failure, a dict with ``success=False``, ``error`` (the
            exception message) and ``timestamp``. Exceptions are logged and
            never propagate to the caller.
        """
        context = context or {}
        try:
            # Determine best model for the task
            task_type = context.get('task_type', 'general')
            model_key = self.model_manager.get_best_model_for_task(task_type)

            # Get or initialize the model
            model = await self.model_manager.get_model(model_key)
            if not model:
                raise RuntimeError(f"Failed to initialize {model_key} model")

            prompt = self._format_prompt(query, context)

            # NOTE(review): this call is not awaited, so generation runs
            # synchronously on the event-loop thread and blocks other tasks
            # until it finishes.
            response = model(
                prompt,
                max_tokens=self._max_tokens_for(model),
                temperature=self.config.get('temperature', 0.7),
                top_p=self.config.get('top_p', 0.95),
                repeat_penalty=self.config.get('repeat_penalty', 1.1),
                echo=False,
            )

            # Extract and structure the response
            text = response['choices'][0]['text']
            result = self._parse_response(text)

            return {
                'success': True,
                'answer': result['answer'],
                'reasoning': result['reasoning'],
                'confidence': result['confidence'],
                'timestamp': datetime.now(),
                'metadata': {
                    'model': model_key,
                    'strategy': 'local_llm',
                    # Character counts, not token counts.
                    'context_length': len(prompt),
                    'response_length': len(text),
                },
            }
        except Exception as e:
            # Boundary handler: report failure as data, but keep the traceback in the log.
            self.logger.exception("Error in reasoning: %s", e)
            return {
                'success': False,
                'error': str(e),
                'timestamp': datetime.now(),
            }

    @staticmethod
    def _max_tokens_for(model: Any) -> int:
        """Return the generation budget for ``model`` based on its context size.

        ``llama_cpp.Llama`` exposes ``n_ctx`` as a method, so comparing the
        attribute directly to an int raises ``TypeError``. Both the method
        form and a plain integer attribute are accepted here. A missing or
        ``None`` value falls back to the smaller budget.
        """
        n_ctx = getattr(model, 'n_ctx', None)
        if callable(n_ctx):
            n_ctx = n_ctx()
        return 1024 if (n_ctx or 0) >= 4096 else 512

    def _format_prompt(self, query: str, context: Dict[str, Any]) -> str:
        """Build the step-by-step prompt for ``query``.

        Only the ``objective``, ``constraints`` and ``background`` entries of
        ``context`` are included, one ``key: value`` line each, in the order
        they appear in ``context``. The prompt ends with ``1.`` to lead the
        model into a numbered first step.
        """
        context_str = "\n".join(
            f"{k}: {v}" for k, v in context.items()
            if k in self._PROMPT_CONTEXT_KEYS
        )
        # Explicit pieces keep the prompt text independent of source indentation.
        return (
            "Let's solve this problem step by step.\n"
            "Context:\n"
            f"{context_str}\n"
            f"Question: {query}\n"
            "Let me break this down:\n"
            "1."
        )

    def _parse_response(self, text: str) -> Dict[str, Any]:
        """Split raw model output into an answer and its reasoning.

        The last line of the stripped text is the answer; all preceding
        lines form the reasoning (``''`` when there is only one line).
        ``confidence`` is a fixed placeholder of 0.8, not derived from the
        output.
        """
        # str.split always yields at least one element, so this unpacking cannot fail.
        *reasoning_lines, answer = text.strip().split('\n')
        return {
            'answer': answer,
            'reasoning': '\n'.join(reasoning_lines),
            'confidence': 0.8,  # Default confidence
        }