"""Local LLM integration for the reasoning system.""" | |
import os | |
from typing import Dict, Any, Optional | |
from datetime import datetime | |
import logging | |
from llama_cpp import Llama | |
import huggingface_hub | |
from .base import ReasoningStrategy | |
class LocalLLMStrategy(ReasoningStrategy):
    """Implements reasoning using a local LLM."""

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Initialize the local LLM strategy."""
        super().__init__()
        self.config = config or {}

        # Model download/location parameters with defaults
        self.repo_id = self.config.get('repo_id', "gpt-omni/mini-omni2")
        self.filename = self.config.get('filename', "mini-omni2.gguf")
        self.model_dir = self.config.get('model_dir', "models")

        # Standard reasoning parameters
        self.min_confidence = self.config.get('min_confidence', 0.7)
        self.parallel_threshold = self.config.get('parallel_threshold', 3)
        self.learning_rate = self.config.get('learning_rate', 0.1)
        self.strategy_weights = self.config.get('strategy_weights', {
            "LOCAL_LLM": 0.8,
            "CHAIN_OF_THOUGHT": 0.6,
            "TREE_OF_THOUGHTS": 0.5,
            "META_LEARNING": 0.4
        })

        self.logger = logging.getLogger(__name__)
        self.model = None
    async def initialize(self):
        """Download the model if needed and load it."""
        try:
            # Create the models directory if it doesn't exist
            os.makedirs(self.model_dir, exist_ok=True)
            model_path = os.path.join(self.model_dir, self.filename)

            # Download the model file if it isn't cached locally yet
            if not os.path.exists(model_path):
                self.logger.info(f"Downloading model to {model_path}...")
                model_path = huggingface_hub.hf_hub_download(
                    repo_id=self.repo_id,
                    filename=self.filename,
                    repo_type="model",
                    local_dir=self.model_dir,
                    local_dir_use_symlinks=False
                )
                self.logger.info("Model downloaded successfully!")
            else:
                self.logger.info("Using existing model file...")

            # Try to load with GPU offload, fall back to CPU if that fails
            try:
                self.model = Llama(
                    model_path=model_path,
                    n_ctx=4096,
                    n_batch=512,
                    n_threads=8,
                    n_gpu_layers=35
                )
                self.logger.info("Model loaded with GPU acceleration!")
            except Exception as e:
                self.logger.warning(f"GPU loading failed: {e}, falling back to CPU...")
                self.model = Llama(
                    model_path=model_path,
                    n_ctx=2048,
                    n_batch=512,
                    n_threads=4,
                    n_gpu_layers=0
                )
                self.logger.info("Model loaded in CPU-only mode")
        except Exception as e:
            self.logger.error(f"Error initializing model: {e}")
            raise
    async def reason(self, query: str, context: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a reasoning response using the local LLM."""
        try:
            if not self.model:
                await self.initialize()

            # Format prompt with context
            prompt = self._format_prompt(query, context)

            # Generate response; allow longer outputs when the larger context window is loaded
            response = self.model(
                prompt,
                max_tokens=1024 if self.model.n_ctx() >= 4096 else 512,
                temperature=0.7,
                top_p=0.95,
                repeat_penalty=1.1,
                echo=False
            )

            # Extract and structure the response
            result = self._parse_response(response['choices'][0]['text'])

            return {
                'success': True,
                'answer': result['answer'],
                'reasoning': result['reasoning'],
                'confidence': result['confidence'],
                'timestamp': datetime.now(),
                'metadata': {
                    'model': self.repo_id,
                    'strategy': 'local_llm',
                    'context_length': len(prompt),
                    'response_length': len(response['choices'][0]['text'])
                }
            }
        except Exception as e:
            self.logger.error(f"Error in reasoning: {e}")
            return {
                'success': False,
                'error': str(e),
                'timestamp': datetime.now()
            }
    def _format_prompt(self, query: str, context: Dict[str, Any]) -> str:
        """Format the prompt with the query and relevant context."""
        # Include only the context keys the prompt cares about
        context_str = "\n".join([
            f"{k}: {v}" for k, v in context.items()
            if k in ['objective', 'constraints', 'background']
        ])

        return f"""Let's solve this problem step by step.
Context:
{context_str}
Question: {query}
Let me break this down:
1."""
    def _parse_response(self, text: str) -> Dict[str, Any]:
        """Parse the raw completion into a structured result."""
        # Simple parsing for now: treat the last line as the answer and the rest as reasoning
        lines = text.strip().split('\n')
        return {
            'answer': lines[-1] if lines else '',
            'reasoning': '\n'.join(lines[:-1]) if len(lines) > 1 else '',
            'confidence': 0.8  # Default confidence
        }
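

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal, hedged example of driving LocalLLMStrategy end to end. The config
# keys are the ones read in __init__ above; the query and context values are
# invented for demonstration. Because this module uses a relative import
# (from .base import ReasoningStrategy), run it as part of its package, e.g.
# `python -m <package>.<this_module>`, rather than as a standalone script.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        strategy = LocalLLMStrategy({
            "model_dir": "models",  # where the GGUF file is cached/downloaded
            "min_confidence": 0.7,
        })
        result = await strategy.reason(
            "Summarize the trade-offs of running inference on CPU vs GPU.",
            {
                "objective": "hardware planning",   # picked up by _format_prompt
                "constraints": "single workstation",
            },
        )
        if result["success"]:
            print("Answer:", result["answer"])
            print("Confidence:", result["confidence"])
        else:
            print("Reasoning failed:", result["error"])

    asyncio.run(_demo())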