| """ | |
| GNN-LLM Intelligent Selection System - Standalone Application | |
| This is a standalone demo application that integrates actual API calls to NVIDIA's model serving platform. | |
| All dependencies are self-contained within this file - no external imports required. | |
| Features: | |
| - Real API calls to NVIDIA's model serving platform | |
| - Self-contained model_prompting function implementation | |
| - Model mapping for different LLM types | |
| - Error handling with fallback mechanisms | |
| - Progress tracking and status updates | |
| - Thought template integration with similarity search | |
| - GNN-based LLM selection system | |
| - Interactive Gradio web interface | |
| Dependencies: | |
| - Standard Python packages (torch, gradio, transformers, etc.) | |
| - NVIDIA API access (configured in the client) | |
| - No local model files or external scripts required | |
| """ | |
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data, Batch
import numpy as np
from transformers import pipeline, LongformerModel, LongformerTokenizer
import requests
import json
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from typing import List, Tuple, Dict, Optional, Union
import os
from datasets import load_dataset
from openai import OpenAI

# Graph Router Integration Imports
import sys
import yaml
from transformers import LongformerTokenizer as RouterTokenizer, LongformerModel as RouterModel
# Load environment variables from .env file (for local development only)
try:
    from dotenv import load_dotenv
    load_dotenv()
    print("✅ .env file loaded successfully (local development)")
except ImportError:
    print("Warning: python-dotenv not installed. Install with: pip install python-dotenv")
    print("Or set NVIDIA_API_KEY environment variable manually")
except FileNotFoundError:
    print("ℹ️ No .env file found - using environment variables directly")

# Check for API key
if os.getenv("NVIDIA_API_KEY") is None:
    print("❌ NVIDIA_API_KEY not found in environment variables")
    print("For local development: Create a .env file with: NVIDIA_API_KEY=your_api_key_here")
    print("For Hugging Face Spaces: Set NVIDIA_API_KEY in Repository Secrets")
    print("⚠️ Some features will be limited without API access")
else:
    print("✅ NVIDIA API key loaded from environment")
NVIDIA_BASE_URL = "https://integrate.api.nvidia.com/v1"

# Add GraphRouter_eval to path
sys.path.append(os.path.join(os.path.dirname(__file__), 'GraphRouter_eval/model'))
sys.path.append(os.path.join(os.path.dirname(__file__), 'GraphRouter_eval/data_processing'))
sys.path.append(os.path.join(os.path.dirname(__file__), 'GraphRouter_eval'))

try:
    # Import the GraphRouter_eval package
    import sys
    import os

    # Add the parent directory to the Python path so we can import GraphRouter_eval as a package
    current_dir = os.path.dirname(__file__)
    if current_dir not in sys.path:
        sys.path.insert(0, current_dir)

    # Import the required modules
    from GraphRouter_eval.model.multi_task_graph_router import graph_router_prediction
    GRAPH_ROUTER_AVAILABLE = True
    print("✅ Graph router successfully imported")
except ImportError as e:
    print(f"Warning: Graph router not available: {e}")
    GRAPH_ROUTER_AVAILABLE = False

# Set up CUDA device for faster embedding calculations
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
print(f"CUDA device set to: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"Current CUDA device: {torch.cuda.current_device()}")
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")

# Initialize OpenAI client for NVIDIA API
if os.getenv("NVIDIA_API_KEY") is None:
    print("❌ NVIDIA API key not found. Please create a .env file with your API key")
    client = None
else:
    client = OpenAI(
        base_url=NVIDIA_BASE_URL,
        api_key=os.getenv("NVIDIA_API_KEY"),
        timeout=60,
        max_retries=2
    )
    print("✅ NVIDIA API client initialized successfully")
def model_prompting(
    llm_model: str,
    prompt: str,
    max_token_num: Optional[int] = 1024,
    temperature: Optional[float] = 0.2,
    top_p: Optional[float] = 0.7,
    stream: Optional[bool] = True,
) -> Union[str, None]:
    """
    Get a response from an LLM model using the OpenAI-compatible NVIDIA API.

    Args:
        llm_model: Name of the model to use (e.g., "meta/llama-3.1-8b-instruct")
        prompt: Input prompt text
        max_token_num: Maximum number of tokens to generate
        temperature: Sampling temperature
        top_p: Top-p sampling parameter
        stream: Whether to stream the response

    Returns:
        Generated text response
    """
    if client is None:
        raise Exception("NVIDIA API client not initialized. Please check your .env file contains NVIDIA_API_KEY")
    try:
        completion = client.chat.completions.create(
            model=llm_model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_token_num,
            temperature=temperature,
            top_p=top_p,
            stream=stream
        )
        if not stream:
            # Non-streaming calls return the full message in a single object
            return completion.choices[0].message.content
        response_text = ""
        for chunk in completion:
            if chunk.choices[0].delta.content is not None:
                response_text += chunk.choices[0].delta.content
        return response_text
    except Exception as e:
        raise Exception(f"API call failed: {str(e)}")
# Initialize the Longformer model for embeddings (same as enhance_query_with_templates.py)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device set to use: {device}")

MODEL_NAME = "allenai/longformer-base-4096"
try:
    tokenizer = LongformerTokenizer.from_pretrained(MODEL_NAME)
    model_long = LongformerModel.from_pretrained(MODEL_NAME)
    # Ensure model is on the correct device
    model_long = model_long.to(device)
    print(f"Successfully loaded Longformer model: {MODEL_NAME} on {device}")
except Exception as e:
    print(f"Warning: Could not load Longformer model: {e}")
    tokenizer = None
    model_long = None

def get_longformer_representation(text):
    """
    Get representations of long text using Longformer on the CUDA:0 device.
    """
    if model_long is None or tokenizer is None:
        raise Exception("Longformer model not available")
    # Set model to evaluation mode for faster inference
    model_long.eval()
    inputs = tokenizer(text, return_tensors="pt", max_length=4096, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    global_attention_mask = torch.zeros(
        inputs["input_ids"].shape,
        dtype=torch.long,
        device=device
    )
    global_attention_mask[:, 0] = 1
    # Use torch.no_grad() for faster inference and less memory usage
    with torch.no_grad():
        outputs = model_long(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            global_attention_mask=global_attention_mask,
            output_hidden_states=True
        )
    # Move result to CPU for faster downstream processing
    return outputs.last_hidden_state[0, 0, :].cpu()

def get_embedding(instructions: List[str]) -> np.ndarray:
    """
    Get embeddings for a list of texts using Longformer.
    """
    if model_long is None:
        raise Exception("Longformer model not available")
    try:
        embeddings = []
        # Process in batches for better GPU utilization
        batch_size = 4  # Adjust based on GPU memory
        for i in range(0, len(instructions), batch_size):
            batch_texts = instructions[i:i + batch_size]
            batch_embeddings = []
            for text in batch_texts:
                representation = get_longformer_representation(text)
                batch_embeddings.append(representation.numpy())
            embeddings.extend(batch_embeddings)
        return np.array(embeddings)
    except Exception as e:
        raise Exception(f"Error generating embeddings: {str(e)}")
def parse_embedding(embedding_str):
    """Parse embedding string to numpy array, handling different formats."""
    if embedding_str is None:
        return None
    if isinstance(embedding_str, np.ndarray):
        return embedding_str
    try:
        if isinstance(embedding_str, str) and 'tensor' in embedding_str:
            clean_str = embedding_str.replace('tensor(', '').replace(')', '')
            if 'device=' in clean_str:
                clean_str = clean_str.split('device=')[0].strip()
            clean_str = clean_str.replace('\n', '').replace(' ', '')
            embedding = np.array(eval(clean_str))
            if embedding.ndim == 2 and embedding.shape[0] == 1:
                embedding = embedding.squeeze(0)
            return embedding
        elif isinstance(embedding_str, str):
            clean_str = embedding_str.replace('[', '').replace(']', '')
            return np.array([float(x) for x in clean_str.split(',') if x.strip()])
        elif isinstance(embedding_str, (int, float)):
            return np.array([embedding_str])
        else:
            return None
    except Exception as e:
        print(f"Error parsing embedding: {str(e)}")
        return None
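
# Illustrative inputs parse_embedding is meant to handle (comment-only sketch):
#
#     parse_embedding("[0.12, -0.34, 0.56]")          # bracketed comma-separated string
#     parse_embedding("tensor([0.12, -0.34, 0.56])")  # stringified torch tensor
#     parse_embedding(np.array([0.12, -0.34, 0.56]))  # already an ndarray (returned as-is)
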
def get_template_subset_name(model_size: str, template_size: str) -> str:
    """
    Get the HuggingFace dataset subset name based on model size and template size.
    """
    return f"thought_template_{model_size}_{template_size}"
def load_template_dataset(model_size: str, template_size: str) -> pd.DataFrame:
    """
    Load thought templates from HuggingFace dataset with robust error handling for Spaces deployment.
    """
    subset_name = get_template_subset_name(model_size, template_size)

    # Try multiple approaches to load the dataset
    approaches = [
        # Approach 1: Direct load with timeout
        lambda: load_dataset("ulab-ai/FusionBench", subset_name, trust_remote_code=True),
        # Approach 2: Load with cache_dir specification
        lambda: load_dataset("ulab-ai/FusionBench", subset_name, cache_dir="./cache", trust_remote_code=True),
        # Approach 3: Load with streaming (for large datasets)
        lambda: load_dataset("ulab-ai/FusionBench", subset_name, streaming=True, trust_remote_code=True),
    ]
    for i, approach in enumerate(approaches, 1):
        try:
            print(f"Attempting to load templates (approach {i}): ulab-ai/FusionBench, subset: {subset_name}")
            dataset = approach()
            # Handle streaming dataset
            if hasattr(dataset, 'iter') and callable(dataset.iter):
                # Convert streaming dataset to list
                data_list = list(dataset['data'])
                template_df = pd.DataFrame(data_list)
            else:
                # Regular dataset
                template_df = pd.DataFrame(dataset['data'])
            print(f"✅ Successfully loaded {len(template_df)} templates from {subset_name}")
            return template_df
        except Exception as e:
            print(f"❌ Approach {i} failed: {str(e)}")
            if i == len(approaches):
                # All approaches failed, provide a detailed error
                error_msg = f"Failed to load template dataset {subset_name} after trying {len(approaches)} approaches. Last error: {str(e)}"
                print(error_msg)
                # Return empty DataFrame with warning
                print("⚠️ Returning empty template DataFrame - functionality will be limited")
                return pd.DataFrame(columns=['query', 'thought_template', 'task_description', 'query_embedding'])

    # This should never be reached, but just in case
    return pd.DataFrame(columns=['query', 'thought_template', 'task_description', 'query_embedding'])
def enhance_query_with_templates(
    model_size: str,
    template_size: str,
    query: str,
    query_embedding: Optional[np.ndarray] = None,
    task_description: Optional[str] = None,
    top_k: int = 3
) -> Tuple[str, List[Dict]]:
    """
    Enhance a query with thought templates by finding similar templates and creating an enhanced prompt.
    """
    if model_size not in ["70b", "8b"]:
        raise ValueError("model_size must be either '70b' or '8b'")
    if template_size not in ["full", "small"]:
        raise ValueError("template_size must be either 'full' or 'small'")

    # Load template data from HuggingFace dataset
    template_df = load_template_dataset(model_size, template_size)

    # Check if dataset is empty (failed to load)
    if template_df.empty:
        print("⚠️ Template dataset is empty - returning original query")
        return query, []

    # Generate embedding for the query if not provided
    if query_embedding is None:
        try:
            query_embedding = get_embedding([query])[0]
            print(f"Generated embedding for query: {query[:50]}...")
        except Exception as e:
            print(f"Failed to generate embedding for query: {str(e)}")
            return query, []

    # Filter templates by task description if provided
    if task_description is None or not task_description.strip():
        matching_templates = template_df
        print(f"Using all {len(matching_templates)} templates (no task filter)")
    else:
        matching_templates = template_df[template_df['task_description'] == task_description]
        if matching_templates.empty:
            task_desc_lower = task_description.lower()
            partial_matches = template_df[template_df['task_description'].str.lower().str.contains(task_desc_lower.split()[0], na=False)]
            if not partial_matches.empty:
                matching_templates = partial_matches
                print(f"Found partial matches for task: {task_description[:50]}... ({len(matching_templates)} templates)")
            else:
                print(f"No matching templates found for task: {task_description[:50]}... - using all templates")
                matching_templates = template_df

    if matching_templates.empty:
        print("No matching templates found. Returning original query.")
        return query, []

    print(f"Processing {len(matching_templates)} templates for similarity calculation...")

    # Calculate similarities with template embeddings
    similarities = []
    for t_idx, t_row in matching_templates.iterrows():
        template_embedding = None
        # Try to parse existing template embedding
        if 'query_embedding' in t_row and not pd.isna(t_row['query_embedding']):
            try:
                template_embedding = parse_embedding(t_row['query_embedding'])
            except Exception as e:
                print(f"Failed to parse template embedding: {str(e)}")
                template_embedding = None
        # If no valid embedding found, generate one for the template query
        if template_embedding is None and 'query' in t_row:
            try:
                template_embedding = get_embedding([t_row['query']])[0]
                print(f"Generated embedding for template query: {t_row['query'][:50]}...")
            except Exception as e:
                print(f"Failed to generate embedding for template query: {str(e)}")
                continue
        if template_embedding is not None:
            try:
                q_emb = query_embedding.reshape(1, -1)
                t_emb = template_embedding.reshape(1, -1)
                if q_emb.shape[1] != t_emb.shape[1]:
                    print(f"Dimension mismatch: query={q_emb.shape[1]}, template={t_emb.shape[1]}")
                    continue
                sim = cosine_similarity(q_emb, t_emb)[0][0]
                similarities.append((t_idx, sim))
            except Exception as e:
                print(f"Error calculating similarity: {str(e)}")
                continue

    if not similarities:
        print("No valid similarities found. Returning original query.")
        return query, []

    # Sort by similarity (highest first) and get top k
    similarities.sort(key=lambda x: x[1], reverse=True)
    top_n = min(top_k, len(similarities))
    top_indices = [idx for idx, _ in similarities[:top_n]]
    top_templates = matching_templates.loc[top_indices]
    print(f"Found {len(similarities)} similar templates, selected top {top_n}")
    print(f"Top similarity scores: {[sim for _, sim in similarities[:top_n]]}")
    # Create enhanced query
    enhanced_query = "Here are some similar questions and guidelines on how to solve them:\n\n"
    retrieved_templates = []
    for i, (t_idx, t_row) in enumerate(top_templates.iterrows(), 1):
        enhanced_query += f"Question {i}: {t_row['query']}\n\n"
        enhanced_query += f"Thought Template {i}: {t_row['thought_template']}\n\n"
        retrieved_templates.append({
            'index': i,
            'query': t_row['query'],
            'thought_template': t_row['thought_template'],
            'task_description': t_row.get('task_description', ''),
            'similarity_score': similarities[i-1][1] if i-1 < len(similarities) else None
        })
    enhanced_query += "Now, please solve the following question:\n\n"
    enhanced_query += query
    enhanced_query += "\n\nUse the thought templates above as guidance, reason step by step, and provide the final answer. The final answer should be enclosed in <answer> and </answer> tags."
    return enhanced_query, retrieved_templates
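
# Illustrative usage of enhance_query_with_templates (comment-only; a real call downloads the
# template dataset and runs Longformer, so it is not executed at import time; the query text
# below is only an example):
#
#     enhanced, used = enhance_query_with_templates(
#         model_size="8b",
#         template_size="small",
#         query="A store sells pencils in packs of 12. How many packs are needed for 60 pencils?",
#         task_description="GSM8K",
#         top_k=3,
#     )
#     print(len(used), enhanced[:200])
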
def load_thought_templates(template_style):
    """
    Load thought templates based on the selected style using HuggingFace datasets.
    """
    # Map template style to model_size and template_size
    style_mapping = {
        "8b_full": ("8b", "full"),
        "8b_small": ("8b", "small"),
        "70b_full": ("70b", "full"),
        "70b_small": ("70b", "small")
    }
    if template_style not in style_mapping:
        return None, f"Template style '{template_style}' not found"
    model_size, template_size = style_mapping[template_style]
    try:
        template_df = load_template_dataset(model_size, template_size)
        return template_df, f"Successfully loaded {len(template_df)} templates from {template_style}"
    except Exception as e:
        return None, f"Error loading templates: {str(e)}"
# GNN network for LLM selection
class LLMSelectorGNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_llms):
        super(LLMSelectorGNN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_llms)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x, edge_index, batch):
        # GNN forward pass
        x = F.relu(self.conv1(x, edge_index))
        x = self.dropout(x)
        x = F.relu(self.conv2(x, edge_index))
        # Graph-level pooling
        x = global_mean_pool(x, batch)
        # Classifier output for LLM selection probabilities
        logits = self.classifier(x)
        return F.softmax(logits, dim=1)

# Model name mapping dictionary
MODEL_MAPPING = {
    "granite-3.0-8b-instruct": "ibm/granite-3.0-8b-instruct",
    "qwen2.5-7b-instruct": "qwen/qwen2.5-7b-instruct",
    "llama-3.1-8b-instruct": "meta/llama-3.1-8b-instruct",
    "mistral-nemo-12b-instruct": "nv-mistralai/mistral-nemo-12b-instruct"
}

def get_mapped_model_name(model_name: str) -> str:
    """Map the input model name to the correct API model name"""
    return MODEL_MAPPING.get(model_name, model_name)
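
# For example, get_mapped_model_name("qwen2.5-7b-instruct") returns "qwen/qwen2.5-7b-instruct",
# while a name missing from MODEL_MAPPING is passed through unchanged.
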
# LLM configurations
LLM_CONFIGS = {
    0: {
        "name": "GPT-3.5 (General Tasks)",
        "description": "Suitable for daily conversations and general text generation",
        "model_type": "openai",
        "api_model": "granite-3.0-8b-instruct"
    },
    1: {
        "name": "Claude (Analysis & Reasoning)",
        "description": "Excels at logical analysis and complex reasoning tasks",
        "model_type": "anthropic",
        "api_model": "qwen2.5-7b-instruct"
    },
    2: {
        "name": "LLaMA (Code Generation)",
        "description": "Specialized model optimized for code generation",
        "model_type": "meta",
        "api_model": "llama-3.1-8b-instruct"
    },
    3: {
        "name": "Gemini (Multimodal)",
        "description": "Supports text, image and other multimodal tasks",
        "model_type": "google",
        "api_model": "mistral-nemo-12b-instruct"
    }
}
# Prompt Templates
PROMPT_TEMPLATES = {
    "code_assistant": "You are an expert programming assistant. Please help with the following coding task:\n\nTask: {query}\n\nRequirements:\n- Provide clean, well-commented code\n- Explain the logic and approach\n- Include error handling where appropriate\n- Suggest best practices\n\nResponse:",
    "academic_tutor": "You are a knowledgeable academic tutor. Please help explain the following topic:\n\nTopic: {query}\n\nPlease provide:\n- Clear, structured explanation\n- Key concepts and definitions\n- Real-world examples or applications\n- Practice questions or exercises if relevant\n\nExplanation:",
    "business_consultant": "You are a strategic business consultant. Please analyze the following business scenario:\n\nScenario: {query}\n\nPlease provide:\n- Situation analysis\n- Key challenges and opportunities\n- Strategic recommendations\n- Implementation considerations\n- Risk assessment\n\nAnalysis:",
    "creative_writer": "You are a creative writing assistant. Please help with the following creative task:\n\nCreative Request: {query}\n\nPlease provide:\n- Original and engaging content\n- Rich descriptions and imagery\n- Appropriate tone and style\n- Creative elements and storytelling techniques\n\nCreative Response:",
    "research_analyst": "You are a thorough research analyst. Please investigate the following topic:\n\nResearch Topic: {query}\n\nPlease provide:\n- Comprehensive overview\n- Key findings and insights\n- Data analysis and trends\n- Reliable sources and references\n- Conclusions and implications\n\nResearch Report:",
    "custom": "{template}\n\nQuery: {query}\n\nResponse:"
}
class GNNLLMSystem:
    def __init__(self):
        # Initialize GNN model
        self.gnn_model = LLMSelectorGNN(input_dim=768, hidden_dim=256, num_llms=4)
        self.load_pretrained_model()
        # Initialize local LLM pipeline (as backup)
        try:
            self.local_llm = pipeline("text-generation",
                                      model="microsoft/DialoGPT-medium",
                                      tokenizer="microsoft/DialoGPT-medium")
        except Exception:
            self.local_llm = None

    def load_pretrained_model(self):
        """Load pretrained GNN model weights"""
        # Load your trained model weights here, e.g.:
        # self.gnn_model.load_state_dict(torch.load('gnn_selector.pth'))
        # For demonstration purposes, we use randomly initialized weights
        pass

    def query_to_graph(self, query):
        """Convert query to graph structure"""
        # This is a simplified implementation; adapt it to your specific requirements
        words = query.lower().split()
        # Create node features (simulated with simple word embeddings)
        vocab_size = 1000
        node_features = []
        for word in words:
            # Simple hash mapping to feature vector
            hash_val = hash(word) % vocab_size
            feature = np.random.randn(768)  # Simulate 768-dim word embedding
            feature[hash_val % 768] += 1.0  # Add some structural information
            node_features.append(feature)
        if len(node_features) == 0:
            node_features = [np.random.randn(768)]
        # Create edge connections (fully connected graph as example)
        num_nodes = len(node_features)
        edge_index = []
        for i in range(num_nodes):
            for j in range(i + 1, num_nodes):
                edge_index.extend([[i, j], [j, i]])
        if len(edge_index) == 0:
            edge_index = [[0, 0]]
        # Convert to PyTorch tensors
        x = torch.FloatTensor(node_features)
        edge_index = torch.LongTensor(edge_index).t().contiguous()
        return Data(x=x, edge_index=edge_index)

    def select_llm(self, query):
        """Use GNN to select the most suitable LLM"""
        # Convert query to graph
        graph_data = self.query_to_graph(query)
        batch = torch.zeros(graph_data.x.size(0), dtype=torch.long)
        # GNN inference
        with torch.no_grad():
            self.gnn_model.eval()
            probabilities = self.gnn_model(graph_data.x, graph_data.edge_index, batch)
            selected_llm_idx = torch.argmax(probabilities, dim=1).item()
            confidence = probabilities[0][selected_llm_idx].item()
        return selected_llm_idx, confidence, probabilities[0].tolist()
    def generate_response(self, query, selected_llm_idx, use_template=False, template_key=None, custom_template=None):
        """Generate response using selected LLM and optional template"""
        llm_config = LLM_CONFIGS[selected_llm_idx]
        # Apply template if requested
        if use_template:
            if template_key == "custom" and custom_template:
                formatted_query = PROMPT_TEMPLATES["custom"].format(template=custom_template, query=query)
            elif template_key in PROMPT_TEMPLATES:
                formatted_query = PROMPT_TEMPLATES[template_key].format(query=query)
            else:
                formatted_query = query
        else:
            formatted_query = query
        try:
            # Get the API model name
            api_model = llm_config.get("api_model", "llama-3.1-8b-instruct")
            mapped_model_name = get_mapped_model_name(api_model)
            # Call the actual API
            response = model_prompting(
                llm_model=mapped_model_name,
                prompt=formatted_query,
                max_token_num=4096,
                temperature=0.0,
                top_p=0.9,
                stream=True
            )
            return response
        except Exception as e:
            # Fallback to local LLM or error message
            error_msg = f"API Error: {str(e)}"
            if self.local_llm:
                try:
                    result = self.local_llm(formatted_query, max_length=100, num_return_sequences=1)
                    return result[0]['generated_text']
                except Exception:
                    return f"Sorry, unable to generate response. Error: {error_msg}"
            else:
                return f"Sorry, unable to generate response. Error: {error_msg}"

# Create system instance
gnn_llm_system = GNNLLMSystem()
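
# Illustrative use of the demo GNN selector (comment-only; weights are randomly initialized,
# so the chosen index is only meaningful once trained weights are loaded):
#
#     idx, conf, probs = gnn_llm_system.select_llm("Write a quicksort function in Python")
#     print(LLM_CONFIGS[idx]["name"], f"{conf:.2%}", probs)
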
# LLM Name Mapping from Graph Router to API Models
LLM_NAME_MAPPING = {
    "qwen2-7b-instruct": "qwen/qwen2-7b-instruct",
    "qwen2.5-7b-instruct": "qwen/qwen2.5-7b-instruct",
    "gemma-7b": "google/gemma-7b",
    "codegemma-7b": "google/codegemma-7b",
    "gemma-2-9b-it": "google/gemma-2-9b-it",
    "llama-3.1-8b-instruct": "meta/llama-3.1-8b-instruct",
    "granite-3.0-8b-instruct": "ibm/granite-3.0-8b-instruct",
    "llama3-chatqa-1.5-8b": "nvidia/llama3-chatqa-1.5-8b",
    "mistral-nemo-12b-instruct": "nv-mistralai/mistral-nemo-12b-instruct",
    "mistral-7b-instruct-v0.3": "mistralai/mistral-7b-instruct-v0.3",
    "llama-3.3-nemotron-super-49b-v1": "nvidia/llama-3.3-nemotron-super-49b-v1",
    "llama-3.1-nemotron-51b-instruct": "nvidia/llama-3.1-nemotron-51b-instruct",
    "llama3-chatqa-1.5-70b": "nvidia/llama3-chatqa-1.5-70b",
    "llama-3.1-70b-instruct": "meta/llama-3.1-70b-instruct",
    "llama3-70b-instruct": "meta/llama3-70b-instruct",
    "granite-34b-code-instruct": "ibm/granite-34b-code-instruct",
    "mixtral-8x7b-instruct-v0.1": "mistralai/mixtral-8x7b-instruct-v0.1",
    "deepseek-r1": "deepseek-ai/deepseek-r1",
    "mixtral-8x22b-instruct-v0.1": "mistralai/mixtral-8x22b-instruct-v0.1",
    "palmyra-creative-122b": "writer/palmyra-creative-122b"
}
def map_llm_to_api(llm_name: str) -> str:
    """Map graph router LLM name to API model name"""
    return LLM_NAME_MAPPING.get(llm_name, "meta/llama-3.1-8b-instruct")  # Default fallback
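
# For example, map_llm_to_api("deepseek-r1") returns "deepseek-ai/deepseek-r1", and any name
# missing from LLM_NAME_MAPPING falls back to "meta/llama-3.1-8b-instruct".
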
def get_cls_embedding_for_router(text, model_name="allenai/longformer-base-4096", device=None):
    """
    Extracts the [CLS] embedding from a given text using Longformer for the router.
    This is a separate function to avoid conflicts with the existing one.
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Load tokenizer and model
    tokenizer = RouterTokenizer.from_pretrained(model_name)
    model = RouterModel.from_pretrained(model_name).to(device)
    model.eval()
    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=4096).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        cls_embedding = outputs.last_hidden_state[:, 0, :]  # (1, hidden_size)
    return cls_embedding
def generate_task_description_for_router(query: str) -> str:
    """
    Generate a concise task description using the LLM API for the router.
    """
    prompt = f"""Analyze the following query and provide a concise task description that identifies the type of task and domain it belongs to. Focus on the core problem type and relevant domain areas.

Query: {query}

Please provide a brief, focused task description that captures:
1. The primary task type (e.g., mathematical calculation, text analysis, coding, reasoning, etc.)
2. The relevant domain or subject area
3. The complexity level or approach needed

Keep the description concise and informative. Respond with just the task description, no additional formatting."""
    try:
        task_description = model_prompting(
            llm_model="meta/llama-3.1-8b-instruct",
            prompt=prompt,
            max_token_num=256,
            temperature=0.1,
            top_p=0.9,
            stream=True
        )
        return task_description.strip()
    except Exception as e:
        print(f"Warning: Failed to generate task description via API: {str(e)}")
        return "General query processing task requiring analysis and response generation."
def get_routed_llm(query: str, config_path: str = None) -> Tuple[str, str, str]:
    """
    Use the graph router to get the best LLM for a query.

    Returns:
        Tuple of (routed_llm_name, task_description, selection_info)
    """
    if not GRAPH_ROUTER_AVAILABLE:
        print("Graph router not available, using fallback")
        selection_info = f"""
🔄 **Fallback Mode**: Graph router not available
🤖 **Selected LLM**: llama-3.1-8b-instruct (Default)
📝 **Task Description**: General query processing
⚠️ **Note**: Using fallback system due to missing graph router components
"""
        return "llama-3.1-8b-instruct", "General query processing", selection_info

    try:
        print(f"Starting graph router analysis for query: {query[:50]}...")
        # Store current working directory
        original_cwd = os.getcwd()
        # Change to GraphRouter_eval directory for relative paths to work
        graph_router_dir = os.path.join(os.path.dirname(__file__), 'GraphRouter_eval')
        os.chdir(graph_router_dir)
        try:
            # Use default config path if none provided
            if config_path is None:
                config_path = 'configs/config.yaml'
            # Load configuration
            with open(config_path, 'r', encoding='utf-8') as file:
                config = yaml.safe_load(file)
            # Load training data
            train_df = pd.read_csv(config['train_data_path'])
            train_df = train_df[train_df["task_name"] != 'quac']
            print(f"Loaded {len(train_df)} training samples")
            # Generate embeddings for the query
            print("Generating query embeddings...")
            user_query_embedding = get_cls_embedding_for_router(query).squeeze(0)
            # Generate task description
            print("Generating task description...")
            user_task_description = generate_task_description_for_router(query)
            print(f"Task description: {user_task_description}")
            # Generate embeddings for the task description
            print("Generating task description embeddings...")
            user_task_embedding = get_cls_embedding_for_router(user_task_description).squeeze(0)
            # Prepare test dataframe
            test_df = train_df.head(config['llm_num']).copy()
            test_df['query'] = query
            test_df['task_description'] = user_task_description
            test_df.loc[0, 'query_embedding'] = str(user_query_embedding)
            test_df.loc[0, 'task_description'] = str(user_task_embedding)
            # Run graph router prediction
            print("Running graph router prediction...")
            router = graph_router_prediction(
                router_data_train=train_df,
                router_data_test=test_df,
                llm_path=config['llm_description_path'],
                llm_embedding_path=config['llm_embedding_path'],
                config=config
            )
            # Get the routed LLM name
            routed_llm_name = router.test_GNN()
            print(f"Graph router selected: {routed_llm_name}")
            # Create detailed selection info
            api_model = map_llm_to_api(routed_llm_name)
            selection_info = f"""
🎯 **Graph Router Analysis Complete**
🤖 **Selected LLM**: {routed_llm_name}
📝 **Task Description**: {user_task_description}
✅ **Routing Method**: Advanced Graph Neural Network
🔍 **Analysis**: Query analyzed for optimal model selection
⚡ **Performance**: Cost-performance optimized routing
"""
            return routed_llm_name, user_task_description, selection_info
        finally:
            # Restore original working directory
            os.chdir(original_cwd)
    except FileNotFoundError as e:
        print(f"Configuration file not found: {e}")
        selection_info = f"""
❌ **Configuration Error**: {str(e)}
🔄 **Fallback**: Using default LLM
🤖 **Selected LLM**: llama-3.1-8b-instruct (Default)
📝 **Task Description**: General query processing
📊 **API Model**: meta/llama-3.1-8b-instruct
"""
        return "llama-3.1-8b-instruct", "General query processing", selection_info
    except Exception as e:
        print(f"Error in graph router: {str(e)}")
        selection_info = f"""
❌ **Graph Router Error**: {str(e)}
🔄 **Fallback**: Using default LLM
🤖 **Selected LLM**: llama-3.1-8b-instruct (Default)
📝 **Task Description**: General query processing
📊 **API Model**: meta/llama-3.1-8b-instruct
⚠️ **Note**: Advanced routing failed, using fallback system
"""
        return "llama-3.1-8b-instruct", "General query processing", selection_info
def process_query(query):
    """Main function to process user queries using the Graph Router"""
    if not query.strip():
        return "Please enter your question", ""
    try:
        print(f"Processing query: {query[:50]}...")
        # Use Graph Router to select the best LLM
        routed_llm_name, task_description, selection_info = get_routed_llm(query)
        print(f"Graph router selected: {routed_llm_name}")
        # Check if the routed LLM name has a "_think" suffix
        think_mode = False
        actual_llm_name = routed_llm_name
        if routed_llm_name.endswith("_think"):
            think_mode = True
            actual_llm_name = routed_llm_name[:-6]  # Remove "_think" suffix
            print(f"Think mode detected. Actual model: {actual_llm_name}")
        # Map the actual LLM name to API model name
        api_model = map_llm_to_api(actual_llm_name)
        print(f"Mapped to API model: {api_model}")
        # Prepare the prompt - append "please think step by step" if in think mode
        final_prompt = query
        if think_mode:
            final_prompt = query + "\n\nPlease think step by step."
            print("Added 'please think step by step' to the prompt")
        # Generate response using the routed LLM
        print("Generating response...")
        response = model_prompting(
            llm_model=api_model,
            prompt=final_prompt,
            max_token_num=4096,
            temperature=0.0,
            top_p=0.9,
            stream=True
        )
        print("Response generated successfully")
        # Update selection info to show think mode if applicable
        if think_mode:
            selection_info = selection_info.replace(
                f"🤖 **Selected LLM**: {routed_llm_name}",
                f"🤖 **Selected LLM**: {actual_llm_name} (Think Mode)"
            )
            selection_info = selection_info.replace(
                f"📊 **API Model**: {api_model}",
                f"📊 **API Model**: {api_model}\n🧠 **Mode**: Step-by-step reasoning enabled"
            )
    except Exception as e:
        print(f"Error in process_query: {str(e)}")
        response = f"Error generating response: {str(e)}"
        # Update selection info to show error
        selection_info = f"""
❌ **Processing Error**: {str(e)}
🔄 **Fallback**: Using default response
⚠️ **Note**: An error occurred during processing
"""
    return response, selection_info
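
# Illustrative end-to-end call (comment-only; it runs the graph router plus one LLM API call):
#
#     answer, info = process_query("What is the capital of France?")
#     print(info)    # markdown summary of the routing decision
#     print(answer)  # the routed model's response
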
def process_template_query(query, template_type, custom_template):
    """Process query using a prompt template"""
    if not query.strip():
        return "Please enter your question", "", ""
    # Use GNN to select LLM
    selected_llm_idx, confidence, all_probabilities = gnn_llm_system.select_llm(query)
    # Generate selection information
    selected_llm_info = LLM_CONFIGS[selected_llm_idx]
    template_names = {
        "code_assistant": "💻 Code Assistant",
        "academic_tutor": "📚 Academic Tutor",
        "business_consultant": "💼 Business Consultant",
        "creative_writer": "✍️ Creative Writer",
        "research_analyst": "🔬 Research Analyst",
        "custom": "🎨 Custom Template"
    }
    selection_info = f"""
🎯 **Template Used**: {template_names.get(template_type, template_type)}
🤖 **Selected LLM**: {selected_llm_info['name']}
📝 **Reason**: {selected_llm_info['description']}
🎯 **Confidence**: {confidence:.2%}
📊 **API Model**: {selected_llm_info.get('api_model', 'Unknown')}

**Selection Probabilities for All LLMs**:
"""
    for i, prob in enumerate(all_probabilities):
        llm_name = LLM_CONFIGS[i]['name']
        selection_info += f"- {llm_name}: {prob:.2%}\n"
    # Generate response using template
    try:
        response = gnn_llm_system.generate_response(
            query, selected_llm_idx, use_template=True,
            template_key=template_type, custom_template=custom_template
        )
        status_message = '<div class="status-success">✅ Template query processed successfully with API</div>'
    except Exception as e:
        response = f"Error generating response: {str(e)}"
        status_message = '<div class="status-info">⚠️ API call failed, using fallback</div>'
    return response, selection_info, status_message
def process_thought_template_query(query, template_style, task_description, top_n):
    """Process query using thought templates with similarity search - no routing"""
    if not query.strip():
        return "Please enter your question", "", ""
    # Process query with thought templates using the new function
    try:
        # Map template style to model_size and template_size
        style_mapping = {
            "8b_full": ("8b", "full"),
            "8b_small": ("8b", "small"),
            "70b_full": ("70b", "full"),
            "70b_small": ("70b", "small")
        }
        if template_style not in style_mapping:
            error_msg = f"Invalid template style: {template_style}"
            return error_msg, "", ""
        model_size, template_size = style_mapping[template_style]
        # Use the enhance_query_with_templates function
        enhanced_query, retrieved_templates = enhance_query_with_templates(
            model_size=model_size,
            template_size=template_size,
            query=query,
            task_description=task_description if task_description.strip() else None,
            top_k=top_n
        )
        # Generate response using the Llama 3.1 8B model (actual API call)
        try:
            llama_response = model_prompting(
                llm_model="meta/llama-3.1-8b-instruct",
                prompt=enhanced_query,
                max_token_num=4096,
                temperature=0.0,
                top_p=0.9,
                stream=True
            )
        except Exception as e:
            llama_response = f"[API Error] Unable to generate response: {str(e)}\n\nEnhanced Query: {enhanced_query}"
        # Create template information display
        template_info = f"""
## 🧠 Used Thought Templates

**Template Style**: {template_style}
**Number of Templates**: {len(retrieved_templates)}
**Benchmark Task**: {task_description if task_description.strip() else 'All Tasks'}
**API Model**: meta/llama-3.1-8b-instruct
**Status**: {'✅ API call successful' if 'API Error' not in llama_response else '⚠️ API call failed'}

### Retrieved Templates:
"""
        for template in retrieved_templates:
            template_info += f"""
**Template {template['index']}** (Similarity: {template['similarity_score']:.4f}):
- **Query**: {template['query']}
- **Task**: {template['task_description']}
- **Template**: {template['thought_template']}
"""
        return enhanced_query, template_info, llama_response
    except Exception as e:
        error_msg = f"Error processing thought template query: {str(e)}"
        return error_msg, "", ""
# Test function to verify dropdown functionality
def test_dropdown_functionality():
    """Test function to verify dropdown components are working"""
    print("Testing dropdown functionality...")
    # Test template style mapping
    style_mapping = {
        "8b_full": ("8b", "full"),
        "8b_small": ("8b", "small"),
        "70b_full": ("70b", "full"),
        "70b_small": ("70b", "small")
    }
    for style, (model_size, template_size) in style_mapping.items():
        print(f"✅ Template style '{style}' maps to model_size='{model_size}', template_size='{template_size}'")
    # Test benchmark task options
    benchmark_tasks = [
        ("All Tasks", ""),
        ("ARC-Challenge", "ARC-Challenge"),
        ("BoolQ", "BoolQ"),
        ("CommonsenseQA", "CommonsenseQA"),
        ("GPQA", "GPQA"),
        ("GSM8K", "GSM8K"),
        ("HellaSwag", "HellaSwag"),
        ("HumanEval", "HumanEval"),
        ("MATH", "MATH"),
        ("MBPP", "MBPP"),
        ("MMLU", "MMLU"),
        ("Natural Questions", "Natural Questions"),
        ("OpenBookQA", "OpenBookQA"),
        ("SQuAD", "SQuAD"),
        ("TriviaQA", "TriviaQA")
    ]
    print(f"✅ {len(benchmark_tasks)} benchmark task options available")
    return True

# Run test on import
if __name__ == "__main__":
    test_dropdown_functionality()
else:
    # Run test when module is imported
    try:
        test_dropdown_functionality()
    except Exception as e:
        print(f"Warning: Dropdown functionality test failed: {e}")
# Create Gradio interface
def create_interface():
    with gr.Blocks(
        title="GNN-LLM System with Prompt Templates",
        theme=gr.themes.Soft(),
        css="""
        /* Theme-robust CSS with CSS variables */
        :root {
            --primary-color: #4CAF50;
            --secondary-color: #ff6b6b;
            --success-color: #28a745;
            --info-color: #17a2b8;
            --warning-color: #ffc107;
            --danger-color: #dc3545;

            /* Light theme colors */
            --bg-primary: #ffffff;
            --bg-secondary: #f8f9fa;
            --bg-info: #f0f8ff;
            --bg-template: #fff5f5;
            --text-primary: #212529;
            --text-secondary: #6c757d;
            --border-color: #dee2e6;
            --shadow-color: rgba(0, 0, 0, 0.1);
        }

        /* Dark theme colors */
        [data-theme="dark"] {
            --bg-primary: #1a1a1a;
            --bg-secondary: #2d2d2d;
            --bg-info: #1a2332;
            --bg-template: #2d1a1a;
            --text-primary: #ffffff;
            --text-secondary: #b0b0b0;
            --border-color: #404040;
            --shadow-color: rgba(255, 255, 255, 0.1);
        }

        /* Auto-detect system theme */
        @media (prefers-color-scheme: dark) {
            :root {
                --bg-primary: #1a1a1a;
                --bg-secondary: #2d2d2d;
                --bg-info: #1a2332;
                --bg-template: #2d1a1a;
                --text-primary: #ffffff;
                --text-secondary: #b0b0b0;
                --border-color: #404040;
                --shadow-color: rgba(255, 255, 255, 0.1);
            }
        }

        /* Manual theme toggle support */
        .theme-light {
            --bg-primary: #ffffff;
            --bg-secondary: #f8f9fa;
            --bg-info: #f0f8ff;
            --bg-template: #fff5f5;
            --text-primary: #212529;
            --text-secondary: #6c757d;
            --border-color: #dee2e6;
            --shadow-color: rgba(0, 0, 0, 0.1);
        }
        .theme-dark {
            --bg-primary: #1a1a1a;
            --bg-secondary: #2d2d2d;
            --bg-info: #1a2332;
            --bg-template: #2d1a1a;
            --text-primary: #ffffff;
            --text-secondary: #b0b0b0;
            --border-color: #404040;
            --shadow-color: rgba(255, 255, 255, 0.1);
        }

        /* Theme toggle button styling */
        .theme-toggle {
            position: fixed;
            top: 20px;
            right: 20px;
            z-index: 1000;
            background: var(--bg-secondary);
            border: 2px solid var(--border-color);
            border-radius: 50%;
            width: 50px;
            height: 50px;
            display: flex;
            align-items: center;
            justify-content: center;
            cursor: pointer;
            transition: all 0.3s ease;
            box-shadow: 0 2px 8px var(--shadow-color);
        }
        .theme-toggle:hover {
            transform: scale(1.1);
            box-shadow: 0 4px 16px var(--shadow-color);
        }
        .theme-toggle:active {
            transform: scale(0.95);
        }
        .gradio-container {
            max-width: 1200px !important;
        }

        /* Theme-robust selection info box */
        .selection-info {
            background-color: var(--bg-info);
            color: var(--text-primary);
            padding: 15px;
            border-radius: 10px;
            border-left: 4px solid var(--primary-color);
            box-shadow: 0 2px 4px var(--shadow-color);
            transition: all 0.3s ease;
        }
        .selection-info:hover {
            box-shadow: 0 4px 8px var(--shadow-color);
            transform: translateY(-1px);
        }

        /* Theme-robust template info box */
        .template-info {
            background-color: var(--bg-template);
            color: var(--text-primary);
            padding: 15px;
            border-radius: 10px;
            border-left: 4px solid var(--secondary-color);
            box-shadow: 0 2px 4px var(--shadow-color);
            transition: all 0.3s ease;
        }
        .template-info:hover {
            box-shadow: 0 4px 8px var(--shadow-color);
            transform: translateY(-1px);
        }

        /* Enhanced button styling */
        .enhanced-button {
            transition: all 0.3s ease;
            border-radius: 8px;
            font-weight: 500;
        }
        .enhanced-button:hover {
            transform: translateY(-2px);
            box-shadow: 0 4px 12px var(--shadow-color);
        }

        /* Card-like containers */
        .card-container {
            background-color: var(--bg-secondary);
            border: 1px solid var(--border-color);
            border-radius: 12px;
            padding: 20px;
            margin: 10px 0;
            box-shadow: 0 2px 8px var(--shadow-color);
            transition: all 0.3s ease;
        }
        .card-container:hover {
            box-shadow: 0 4px 16px var(--shadow-color);
            transform: translateY(-2px);
        }

        /* Status indicators */
        .status-success {
            color: var(--success-color);
            font-weight: 500;
        }
        .status-info {
            color: var(--info-color);
            font-weight: 500;
        }

        /* Responsive design improvements */
        @media (max-width: 768px) {
            .gradio-container {
                max-width: 100% !important;
                padding: 10px;
            }
            .card-container {
                padding: 15px;
                margin: 5px 0;
            }
        }

        /* Accessibility improvements */
        .sr-only {
            position: absolute;
            width: 1px;
            height: 1px;
            padding: 0;
            margin: -1px;
            overflow: hidden;
            clip: rect(0, 0, 0, 0);
            white-space: nowrap;
            border: 0;
        }

        /* Focus indicators for better accessibility */
        button:focus,
        input:focus,
        textarea:focus,
        select:focus {
            outline: 2px solid var(--primary-color);
            outline-offset: 2px;
        }

        /* Theme-robust Markdown content */
        .markdown-content {
            color: var(--text-primary);
        }
        .markdown-content h1,
        .markdown-content h2,
        .markdown-content h3,
        .markdown-content h4,
        .markdown-content h5,
        .markdown-content h6 {
            color: var(--text-primary);
            border-bottom: 1px solid var(--border-color);
            padding-bottom: 8px;
            margin-top: 20px;
            margin-bottom: 15px;
        }
        .markdown-content p {
            color: var(--text-secondary);
            line-height: 1.6;
            margin-bottom: 12px;
        }
        .markdown-content ul,
        .markdown-content ol {
            color: var(--text-secondary);
            padding-left: 20px;
        }
        .markdown-content li {
            margin-bottom: 8px;
            color: var(--text-secondary);
        }
        .markdown-content strong,
        .markdown-content b {
            color: var(--text-primary);
            font-weight: 600;
        }
        .markdown-content code {
            background-color: var(--bg-secondary);
            color: var(--text-primary);
            padding: 2px 6px;
            border-radius: 4px;
            border: 1px solid var(--border-color);
            font-family: 'Courier New', monospace;
        }
        .markdown-content pre {
            background-color: var(--bg-secondary);
            border: 1px solid var(--border-color);
            border-radius: 8px;
            padding: 15px;
            overflow-x: auto;
            margin: 15px 0;
        }
        .markdown-content pre code {
            background: none;
            border: none;
            padding: 0;
        }
        /* Enhanced template info styling */
        .template-info {
            background-color: var(--bg-template);
            color: var(--text-primary);
            padding: 20px;
            border-radius: 12px;
            border-left: 4px solid var(--secondary-color);
            box-shadow: 0 2px 8px var(--shadow-color);
            transition: all 0.3s ease;
            margin: 15px 0;
        }
        .template-info:hover {
            box-shadow: 0 4px 16px var(--shadow-color);
            transform: translateY(-2px);
        }
        .template-info h3 {
            color: var(--text-primary);
            margin-top: 0;
            margin-bottom: 15px;
            font-size: 1.3em;
        }
        .template-info p {
            color: var(--text-secondary);
            margin-bottom: 0;
            line-height: 1.5;
        }

        /* Accordion styling for theme support */
        .accordion-content {
            background-color: var(--bg-secondary);
            border: 1px solid var(--border-color);
            border-radius: 8px;
            padding: 20px;
            margin: 10px 0;
        }

        /* Tab styling improvements */
        .tab-nav {
            border-bottom: 2px solid var(--border-color);
            margin-bottom: 20px;
        }
        .tab-nav button {
            background-color: var(--bg-secondary);
            color: var(--text-secondary);
            border: none;
            padding: 12px 20px;
            margin-right: 5px;
            border-radius: 8px 8px 0 0;
            transition: all 0.3s ease;
        }
        .tab-nav button.active {
            background-color: var(--primary-color);
            color: white;
        }
        .tab-nav button:hover {
            background-color: var(--bg-info);
            color: var(--text-primary);
        }

        /* Equal height columns and consistent UI design */
        .equal-height-columns {
            display: flex;
            align-items: stretch;
        }
        .equal-height-columns > .column {
            display: flex;
            flex-direction: column;
        }
        .equal-height-columns .card-container {
            height: 100%;
            display: flex;
            flex-direction: column;
        }
        .equal-height-columns .card-container > * {
            flex: 1;
        }
        .equal-height-columns .card-container textarea,
        .equal-height-columns .card-container .textbox {
            flex: 1;
            min-height: 200px;
        }
        .equal-height-columns .card-container .textbox textarea {
            height: 100% !important;
            min-height: 200px !important;
            resize: vertical;
            overflow-y: auto !important;
            word-wrap: break-word !important;
            white-space: pre-wrap !important;
        }

        /* Force textbox to show content properly */
        .equal-height-columns .card-container .textbox {
            min-height: 250px;
            display: flex;
            flex-direction: column;
        }
        .equal-height-columns .card-container .textbox > div {
            flex: 1;
            display: flex;
            flex-direction: column;
        }
        .equal-height-columns .card-container .textbox > div > textarea {
            flex: 1;
            height: auto !important;
            min-height: 200px !important;
        }

        /* Ensure Enhanced Query textbox fills available height */
        .equal-height-columns .card-container .textbox[data-testid*="enhanced"] {
            height: 100%;
        }
        .equal-height-columns .card-container .textbox[data-testid*="enhanced"] textarea {
            height: 100% !important;
            min-height: 300px !important;
            resize: vertical;
        }

        /* Consistent section styling */
        .content-section {
            background-color: var(--bg-secondary);
            border: 1px solid var(--border-color);
            border-radius: 12px;
            padding: 20px;
            margin: 10px 0;
            box-shadow: 0 2px 8px var(--shadow-color);
            transition: all 0.3s ease;
        }
        .content-section:hover {
            box-shadow: 0 4px 16px var(--shadow-color);
            transform: translateY(-2px);
        }
        .content-section h3 {
            color: var(--text-primary);
            margin-top: 0;
            margin-bottom: 15px;
            font-size: 1.2em;
            border-bottom: 1px solid var(--border-color);
            padding-bottom: 8px;
        }
        """
    ) as demo:
        gr.Markdown("""
        # 🚀 LLM RoutePilot

        This system uses an advanced Graph Neural Network (GNN) router to analyze your query and automatically select the most suitable Large Language Model (LLM) from a pool of 10+ models to answer your questions.

        ## 📋 System Features:
        - 🧠 **Advanced Graph Router**: Sophisticated GNN-based routing system with 10+ LLM options
        - 🎯 **Intelligent Selection**: Analyzes query content, task type, and domain to choose the optimal LLM
        - 📊 **Cost-Performance Optimization**: Routes based on cost and performance trade-offs
        - 🎨 **Prompt Templates**: Use structured templates for specialized responses
        - ⚡ **Real-time Processing**: Fast responses to user queries
        - 🌓 **Theme Support**: Automatically adapts to light and dark themes
        - 🔄 **Fallback System**: Graceful degradation if advanced routing fails
        """, elem_classes=["markdown-content"])
        # Theme toggle button
        gr.HTML("""
        <div class="theme-toggle" onclick="toggleTheme()" title="Toggle theme">
            <span id="theme-icon">🌓</span>
        </div>
        <script>
        // Theme management
        let currentTheme = localStorage.getItem('theme') || 'auto';

        function setTheme(theme) {
            const root = document.documentElement;
            const icon = document.getElementById('theme-icon');
            // Remove existing theme classes
            root.classList.remove('theme-light', 'theme-dark');
            if (theme === 'light') {
                root.classList.add('theme-light');
                icon.textContent = '🌙';
            } else if (theme === 'dark') {
                root.classList.add('theme-dark');
                icon.textContent = '☀️';
            } else {
                // Auto theme - use system preference
                if (window.matchMedia && window.matchMedia('(prefers-color-scheme: dark)').matches) {
                    root.classList.add('theme-dark');
                    icon.textContent = '☀️';
                } else {
                    root.classList.add('theme-light');
                    icon.textContent = '🌙';
                }
            }
            localStorage.setItem('theme', theme);
            currentTheme = theme;
        }

        function toggleTheme() {
            if (currentTheme === 'auto') {
                // If auto, switch to light
                setTheme('light');
            } else if (currentTheme === 'light') {
                // If light, switch to dark
                setTheme('dark');
            } else {
                // If dark, switch to auto
                setTheme('auto');
            }
        }

        // Initialize theme on page load
        document.addEventListener('DOMContentLoaded', function() {
            setTheme(currentTheme);
        });

        // Listen for system theme changes
        window.matchMedia('(prefers-color-scheme: dark)').addEventListener('change', function(e) {
            if (currentTheme === 'auto') {
                setTheme('auto');
            }
        });
        </script>
        """)
        with gr.Tabs():
            # Original Tab - GNN-LLM System
            with gr.TabItem("🤖 Advanced Graph Router"):
                with gr.Row():
                    with gr.Column(scale=2):
                        with gr.Group(elem_classes=["card-container"]):
                            query_input = gr.Textbox(
                                label="💬 Enter Your Question",
                                placeholder="Please enter the question you want to ask...",
                                lines=3,
                                max_lines=5
                            )
                            submit_btn = gr.Button(
                                "🚀 Submit Query",
                                variant="primary",
                                scale=1,
                                elem_classes=["enhanced-button"]
                            )
                    with gr.Column(scale=3):
                        with gr.Group(elem_classes=["card-container"]):
                            selection_output = gr.Textbox(
                                label="🎯 Graph Router Analysis",
                                lines=3,
                                max_lines=5,
                                interactive=False
                            )
                with gr.Row():
                    with gr.Group(elem_classes=["card-container"]):
                        response_output = gr.Textbox(
                            label="📝 AI Response",
                            lines=8,
                            max_lines=15,
                            interactive=False
                        )
                # Event handling
                submit_btn.click(
                    fn=process_query,
                    inputs=[query_input],
                    outputs=[response_output, selection_output],
                    show_progress=True
                )
                query_input.submit(
                    fn=process_query,
                    inputs=[query_input],
                    outputs=[response_output, selection_output],
                    show_progress=True
                )
| # New Tab - Thought Template Assistant | |
| with gr.TabItem("π§ Thought Template Assistant"): | |
| gr.Markdown(""" | |
| ### π§ Thought Template System with Similarity Search | |
| This system uses embedding-based similarity search to find the most relevant thought templates for your query. | |
| It then generates a structured thought prompt and produces a response using the Llama3.1 8B model. | |
| """, elem_classes=["template-info"]) | |
| with gr.Row(elem_classes=["equal-height-columns"]): | |
| with gr.Column(scale=1, elem_classes=["column"]): | |
| with gr.Group(elem_classes=["card-container"]): | |
| thought_query_input = gr.Textbox( | |
| label="π¬ Enter Your Question", | |
| placeholder="Please enter the question you want to analyze with thought templates...", | |
| lines=3, | |
| max_lines=5 | |
| ) | |
| thought_template_style = gr.Dropdown( | |
| label="π Select Template Style", | |
| choices=[ | |
| ("8B Full Templates", "8b_full"), | |
| ("8B Small Templates", "8b_small"), | |
| ("70B Full Templates", "70b_full"), | |
| ("70B Small Templates", "70b_small") | |
| ], | |
| value="8b_full" | |
| ) | |
| thought_task_description = gr.Dropdown( | |
| label="π Benchmark Task (Optional)", | |
| choices=[ | |
| ("All Tasks", ""), | |
| ("ARC-Challenge", "ARC-Challenge"), | |
| ("BoolQ", "BoolQ"), | |
| ("CommonsenseQA", "CommonsenseQA"), | |
| ("GPQA", "GPQA"), | |
| ("GSM8K", "GSM8K"), | |
| ("HellaSwag", "HellaSwag"), | |
| ("HumanEval", "HumanEval"), | |
| ("MATH", "MATH"), | |
| ("MBPP", "MBPP"), | |
| ("MMLU", "MMLU"), | |
| ("Natural Questions", "Natural Questions"), | |
| ("OpenBookQA", "OpenBookQA"), | |
| ("SQuAD", "SQuAD"), | |
| ("TriviaQA", "TriviaQA") | |
| ], | |
| value="", | |
| info="Select a specific benchmark task to filter templates, or leave as 'All Tasks' to search across all tasks" | |
| ) | |
| thought_top_n = gr.Slider( | |
| label="π Number of Similar Templates", | |
| minimum=1, | |
| maximum=10, | |
| value=3, | |
| step=1, | |
| info="Number of most similar templates to retrieve" | |
| ) | |
| thought_submit_btn = gr.Button( | |
| "π§ Generate Thought Template", | |
| variant="primary", | |
| elem_classes=["enhanced-button"] | |
| ) | |
| with gr.Row(): | |
| with gr.Group(elem_classes=["content-section"]): | |
| enhanced_query_output = gr.Textbox( | |
| label="π Enhanced Query", | |
| lines=15, | |
| max_lines=25, | |
| interactive=False | |
| ) | |
| with gr.Row(): | |
| with gr.Group(elem_classes=["content-section"]): | |
| llama_response_output = gr.Textbox( | |
| label="π€ Llama3.1 8B Response", | |
| lines=15, | |
| max_lines=25, | |
| interactive=False | |
| ) | |
| with gr.Row(): | |
| with gr.Group(elem_classes=["content-section"]): | |
| thought_templates_output = gr.Textbox( | |
| label="π§ Used Thought Templates", | |
| lines=15, | |
| max_lines=25, | |
| interactive=False | |
| ) | |
| # Event handling for thought template | |
| thought_submit_btn.click( | |
| fn=process_thought_template_query, | |
| inputs=[thought_query_input, thought_template_style, thought_task_description, thought_top_n], | |
| outputs=[enhanced_query_output, thought_templates_output, llama_response_output], | |
| show_progress=True | |
| ) | |
| thought_query_input.submit( | |
| fn=process_thought_template_query, | |
| inputs=[thought_query_input, thought_template_style, thought_task_description, thought_top_n], | |
| outputs=[enhanced_query_output, thought_templates_output, llama_response_output], | |
| show_progress=True | |
| ) | |
| # Add system information | |
| with gr.Accordion("System Information", open=False): | |
| gr.Markdown(""" | |
| ### Technical Architecture: | |
| - **Advanced Graph Router**: Sophisticated Graph Neural Network built with PyTorch Geometric | |
| - **Multi-Model Pool**: Access to 10+ different LLM models with varying capabilities | |
| - **Intelligent Routing**: Analyzes query embeddings, task descriptions, and performance metrics | |
| - **Cost-Performance Optimization**: Routes based on cost and performance trade-offs | |
| - **Feature Extraction**: Converts query text to graph structure for advanced analysis | |
| - **LLM Integration**: Supports API calls to various large language models via NVIDIA API | |
| - **Prompt Templates**: Structured templates for specialized responses | |
| - **Thought Templates**: Embedding-based similarity search for reasoning guidance | |
| - **Interface Framework**: Interactive web interface built with Gradio | |
| - **Theme Support**: Automatically adapts to light and dark themes | |
| ### Available LLM Models (10+ Models): | |
| - **Small Models (7B-12B)**: Fast, cost-effective for simple tasks | |
| - Llama-3.1-8B-Instruct, Qwen2.5-7B-Instruct, Granite-3.0-8B-Instruct | |
| - Gemma-7B, CodeGemma-7B, Mistral-7B-Instruct-v0.3 | |
| - **Medium Models (8B-47B)**: Balanced performance and cost | |
| - Mistral-Nemo-12B-Instruct, Llama3-ChatQA-1.5-8B | |
| - Granite-34B-Code-Instruct, Mixtral-8x7B-Instruct-v0.1 | |
| - **Large Models (49B+)**: High performance for complex tasks | |
| - Llama-3.3-Nemotron-Super-49B-v1, Llama-3.1-Nemotron-51B-Instruct | |
| - Llama3-ChatQA-1.5-70B, Llama-3.1-70B-Instruct | |
| - DeepSeek-R1 (671B), Mixtral-8x22B-Instruct-v0.1, Palmyra-Creative-122B | |
| ### Routing Scenarios: | |
| - **Performance First**: Prioritizes model performance over cost | |
| - **Balance**: Balances performance and cost considerations | |
| - **Cost First**: Prioritizes cost-effectiveness over performance | |
| ### Available Templates: | |
| - **π» Code Assistant**: Programming and development tasks | |
| - **π Academic Tutor**: Educational content and learning assistance | |
| - **πΌ Business Consultant**: Strategic business analysis | |
| - **βοΈ Creative Writer**: Creative writing and content creation | |
| - **π¬ Research Analyst**: Research and analysis tasks | |
| - **π¨ Custom Template**: Define your own prompt structure | |
| ### Thought Template Styles: | |
| - **8B Full Templates**: Comprehensive templates for 8B model reasoning | |
| - **8B Small Templates**: Condensed templates for 8B model reasoning | |
| - **70B Full Templates**: Comprehensive templates for 70B model reasoning | |
| - **70B Small Templates**: Condensed templates for 70B model reasoning | |
| ### Available Benchmark Tasks: | |
| - **ARC-Challenge**: AI2 Reasoning Challenge | |
| - **BoolQ**: Boolean Questions | |
| - **CommonsenseQA**: Commonsense Question Answering | |
| - **GPQA**: Graduate-Level Google-Proof Q&A (biology, chemistry, physics) | |
| - **GSM8K**: Grade School Math 8K | |
| - **HellaSwag**: Commonsense sentence-completion benchmark | |
| - **HumanEval**: Hand-written Python problems for evaluating code generation | |
| - **MATH**: Competition-level mathematics problems | |
| - **MBPP**: Mostly Basic Python Problems | |
| - **MMLU**: Massive Multitask Language Understanding | |
| - **Natural Questions**: Real search queries answered from Wikipedia | |
| - **OpenBookQA**: Open Book Question Answering | |
| - **SQuAD**: Stanford Question Answering Dataset | |
| - **TriviaQA**: Trivia Question Answering | |
| ### Usage Instructions: | |
| 1. **Advanced Graph Router**: Use the first tab for queries with sophisticated GNN-based routing across 10+ LLMs | |
| 2. **Thought Template Assistant**: Use the second tab for embedding-based similarity search with the Llama3.1 8B model (no routing) | |
| 3. The system automatically analyzes your query and selects the optimal LLM based on content, task type, and cost-performance trade-offs | |
| 4. View detailed routing information including selected model, task description, and routing method | |
| 5. Get enhanced responses with thought templates (tab 2) | |
| 6. **Theme Support**: The interface automatically adapts to your system's theme preference | |
| 7. **Fallback System**: If advanced routing fails, the system gracefully falls back to a default model | |
| """, elem_classes=["markdown-content"]) | |
| return demo | |
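| # --- Illustrative sketch (module level, never called by the app) ------------ | |
| # The "Routing Scenarios" in the System Information accordion trade predicted | |
| # performance against cost. The real router learns this trade-off with a GNN; | |
| # a toy scoring rule under assumed scenario weights could look like this: | |
| def _sketch_cost_performance_score(predicted_performance: float, cost_per_query: float, scenario: str = "balance") -> float: | |
|     """Combine performance and cost into a single routing score (toy example).""" | |
|     weights = { | |
|         "performance_first": 1.0,  # ignore cost entirely | |
|         "balance": 0.5,            # weight performance and cost equally | |
|         "cost_first": 0.1,         # strongly prefer cheaper models | |
|     } | |
|     alpha = weights.get(scenario, 0.5) | |
|     return alpha * predicted_performance - (1.0 - alpha) * cost_per_query | |
| # Usage sketch: pick the candidate model with the highest score, e.g. | |
| # best = max(candidates, key=lambda m: _sketch_cost_performance_score(m["perf"], m["cost"], "balance")) | |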
| # Launch application | |
| if __name__ == "__main__": | |
| demo = create_interface() | |
| demo.launch( | |
| server_name="0.0.0.0",  # listen on all interfaces (required in containers/Spaces) | |
| server_port=7860,  # default Gradio port expected by Hugging Face Spaces | |
| share=True,  # create a public share link when running locally | |
| show_error=True,  # surface Python exceptions in the web UI | |
| debug=True  # verbose logging while developing | |
| ) |