# app.py
# Hugging Face Space: Gradio Docs Chat with a locally loaded Qwen3 model and MCP integration
# Features:
#   • Qwen3-30B-A3B-Instruct-2507 loaded locally with transformers (no Inference API round-trip)
#   • CPU-only loading to avoid CUDA initialization issues
#   • MCP tool-calling for Gradio docs access
#   • Streaming responses with live tool logs
#   • Optional "Concise / Detailed" answer styles
#   • Citations panel for source tracking
#
# Space secrets:
#   - HUGGING_FACE_HUB_TOKEN or HF_TOKEN (required for the MCP/Inference path; optional for the public local model)
#   - Optional: CHAT_PROVIDER, GRADIO_DOCS_MCP_SSE

import os
import asyncio
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple

# Force a CPU-only environment before torch is imported anywhere
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["USE_CUDA"] = "0"
os.environ["USE_GPU"] = "0"

# Load environment variables from a .env file if python-dotenv is installed
try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    pass

import gradio as gr

# Try to import MCPClient with a fallback for older huggingface_hub versions
try:
    from huggingface_hub import MCPClient

    MCP_AVAILABLE = True
except ImportError:
    MCPClient = None
    MCP_AVAILABLE = False
    print("Warning: MCPClient not available. Install huggingface_hub>=0.34.0")

# Optional ZeroGPU shim for Hugging Face Spaces
SPACES_ZERO_GPU = bool(os.environ.get("SPACE_ZERO_GPU", ""))
try:
    import spaces  # type: ignore

    @spaces.GPU
    def _zero_gpu_probe():
        return "ok"
except Exception:
    pass

# ----------------------------
# Configuration
# ----------------------------
GRADIO_DOCS_MCP_SSE = os.environ.get(
    "GRADIO_DOCS_MCP_SSE",
    "https://gradio-docs-mcp.hf.space/gradio_api/mcp/sse",
)

# Model configuration: the chat model is loaded locally with transformers
MODEL_ID = "Qwen/Qwen3-30B-A3B-Instruct-2507"
PROVIDER = os.environ.get("CHAT_PROVIDER", "auto")
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

# True: answer with the locally loaded model; False: use the MCP tool-calling path
USE_GPT_OSS = True

BASE_SYSTEM_PROMPT = (
    "You are a helpful assistant that answers strictly using the Gradio documentation "
    "via the MCP tools provided by the Gradio Docs MCP server. Prefer the latest docs. "
    "Cite relevant class/function names (e.g., gr.Interface) and include short code examples when helpful."
)
CONCISE_SUFFIX = " Keep answers concise (3–6 sentences) unless code is necessary."
DETAILED_SUFFIX = " Provide a detailed, step-by-step answer with short code where helpful."

# ----------------------------
# Model Clients (lazy initialization)
# ----------------------------
mcp_client: Optional[MCPClient] = None
gpt_oss_tokenizer = None
gpt_oss_model = None
_initialized = False
_init_lock = asyncio.Lock()
_model_loading_lock = asyncio.Lock()


def _current_system_prompt(style: str) -> str:
    """Return the system prompt with the style-specific suffix."""
    return BASE_SYSTEM_PROMPT + (CONCISE_SUFFIX if style == "Concise" else DETAILED_SUFFIX)


def _reset_clients():
    """Reset all global clients so they are re-created on next use."""
    global mcp_client, gpt_oss_tokenizer, gpt_oss_model, _initialized
    mcp_client = None
    gpt_oss_tokenizer = None
    gpt_oss_model = None
    _initialized = False


def get_mcp_client(model_id: str, provider: str, api_key: Optional[str]) -> MCPClient:
    """Get or create the MCP client."""
    global mcp_client
    if mcp_client is None:
        if not MCP_AVAILABLE:
            raise ImportError("MCPClient not available. Install huggingface_hub>=0.34.0")
        mcp_client = MCPClient(model=model_id, provider=provider, api_key=api_key)
    return mcp_client
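
# The loader below follows a lazy-singleton pattern: nothing is downloaded at
# import time, the first request pays the full load cost, and an asyncio.Lock
# with a double-check prevents concurrent requests from loading the model twice.
# The weights are kept in float32 on the CPU, which trades speed for stability
# on Spaces without GPU access.
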
async def get_gpt_oss_model_and_tokenizer():
    """Get or create the local chat model and tokenizer with strict CPU-only loading."""
    global gpt_oss_tokenizer, gpt_oss_model

    # Fast path: already loaded
    if gpt_oss_tokenizer is not None and gpt_oss_model is not None:
        return gpt_oss_tokenizer, gpt_oss_model

    # Use a lock to prevent multiple simultaneous loads
    async with _model_loading_lock:
        # Double-check after acquiring the lock
        if gpt_oss_tokenizer is not None and gpt_oss_model is not None:
            return gpt_oss_tokenizer, gpt_oss_model

        try:
            # Import here to avoid CUDA initialization in the main process
            import torch
            from transformers import AutoTokenizer, AutoModelForCausalLM

            # Force CPU-only torch configuration
            torch.cuda.is_available = lambda: False
            torch.cuda.device_count = lambda: 0

            print(f"🔄 Loading tokenizer for {MODEL_ID}...")
            gpt_oss_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

            print(f"🔄 Loading {MODEL_ID} (CPU-only)...")
            gpt_oss_model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float32,  # float32 for CPU compatibility
                device_map=None,  # no device mapping
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )

            # Explicitly move to CPU and switch to evaluation mode
            gpt_oss_model = gpt_oss_model.to("cpu")
            gpt_oss_model.eval()

            print(f"✅ {MODEL_ID} loaded successfully on CPU!")
            return gpt_oss_tokenizer, gpt_oss_model

        except Exception as e:
            print(f"❌ Failed to load {MODEL_ID}: {e}")
            # Reset globals on error so a later request can retry
            gpt_oss_tokenizer = None
            gpt_oss_model = None
            raise


async def generate_with_gpt_oss(messages: List[Dict[str, Any]]) -> str:
    """Generate a response with the locally loaded model."""
    try:
        # Lazily load the model only when it is first needed
        tokenizer, model = await get_gpt_oss_model_and_tokenizer()

        # Fold the system prompt into a user turn so it survives chat templates
        # that ignore or restrict system messages
        chat_messages = []
        for msg in messages:
            if msg["role"] == "system":
                chat_messages.append({"role": "user", "content": f"System: {msg['content']}"})
            else:
                chat_messages.append(msg)

        # Apply the chat template and tokenize
        text = tokenizer.apply_chat_template(
            chat_messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

        # Generate with timeout protection
        try:
            import torch

            with torch.no_grad():  # disable gradients for inference
                generated_ids = model.generate(
                    **model_inputs,
                    max_new_tokens=512,  # kept small for faster CPU responses
                    do_sample=True,
                    temperature=0.7,
                    pad_token_id=tokenizer.eos_token_id,
                    max_time=60.0,  # 60-second generation timeout
                )
        except Exception as gen_error:
            raise RuntimeError(f"Generation error: {gen_error}") from gen_error

        # Decode only the newly generated tokens
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
        return tokenizer.decode(output_ids, skip_special_tokens=True)

    except Exception as e:
        raise RuntimeError(f"Local model error: {e}") from e


async def ensure_mcp_init(model_id: str, provider: str, api_key: Optional[str]):
    """Initialize the MCP server connection exactly once."""
    global _initialized
    if _initialized:
        return
    async with _init_lock:
        if _initialized:
            return
        if not MCP_AVAILABLE:
            raise ImportError("MCPClient not available. Install huggingface_hub>=0.34.0")
        client = get_mcp_client(model_id, provider, api_key)
        result = client.add_mcp_server(
            type="sse",
            url=GRADIO_DOCS_MCP_SSE,
            timeout=45,
        )
        # add_mcp_server is async in recent huggingface_hub versions; await if needed
        if hasattr(result, "__await__"):
            await result
        _initialized = True
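
# Chat turns are passed around as plain OpenAI-style dicts, which is also the
# format gr.Chatbot(type="messages") uses. Illustrative example only:
#   {"role": "user", "content": "How do I use gr.Interface?"}
#   {"role": "assistant", "content": "Use gr.Interface(fn=..., inputs=..., outputs=...)."}
# to_llm_messages() below prepends the style-dependent system prompt to such a list.
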
# ----------------------------
# Message Processing
# ----------------------------
def to_llm_messages(history_msgs: List[Dict[str, Any]], user_msg: str, style: str) -> List[Dict[str, Any]]:
    """Convert chat history to LLM format with system prompt."""
    msgs = [{"role": "system", "content": _current_system_prompt(style)}]
    for msg in history_msgs or []:
        role = msg.get("role")
        content = msg.get("content")
        if role in ("user", "assistant") and isinstance(content, str):
            msgs.append({"role": role, "content": content})
    msgs.append({"role": "user", "content": user_msg})
    return msgs


# ----------------------------
# Response Formatting
# ----------------------------
def _append_log(log_lines: List[str], line: str, max_lines: int = 200) -> None:
    """Append log line with size limit."""
    log_lines.append(line)
    if len(log_lines) > max_lines:
        del log_lines[: len(log_lines) - max_lines]


def _format_tool_log(log_lines: List[str]) -> str:
    """Format tool log for display."""
    return "\n".join(log_lines) if log_lines else "_No tool activity yet._"


def _format_citations(cites: List[Tuple[str, Optional[str]]]) -> str:
    """Format citations for display."""
    if not cites:
        return "_No citations captured yet._"
    recent = cites[-12:]  # show the 12 most recent citations
    lines = []
    for label, url in recent:
        if url:
            lines.append(f"- **{label}** — {url}")
        else:
            lines.append(f"- **{label}**")
    return "\n".join(lines)
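
# stream_answer() below is an async generator. Every chunk it yields is a dict
# with the same three keys, so the UI callback can treat both back-ends alike:
#   {"delta": "<text to append>", "tool_log": "<markdown>", "citations": "<markdown>"}
# (shape taken from the code below; reproduced here only as a reading aid).
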
# ----------------------------
# Response Streaming
# ----------------------------
async def stream_answer(
    messages_for_llm: List[Dict[str, Any]],
    model_id: str,
    provider: str,
    api_key: Optional[str],
) -> AsyncIterator[Dict[str, Any]]:
    """Stream responses from either the local model or the MCP tool-calling client."""
    tool_log: List[str] = []
    citations: List[Tuple[str, Optional[str]]] = []

    # Handle the local-model path
    if USE_GPT_OSS:
        try:
            # Generate the full response with the local model
            generated_text = await generate_with_gpt_oss(messages_for_llm)

            # Stream it character by character
            for char in generated_text:
                yield {
                    "delta": char,
                    "tool_log": _format_tool_log(tool_log),
                    "citations": _format_citations(citations),
                }
        except Exception as e:
            yield {
                "delta": f"❌ {e}",
                "tool_log": _format_tool_log(tool_log),
                "citations": _format_citations(citations),
            }
        return

    # Handle MCP-based models
    if not MCP_AVAILABLE:
        yield {
            "delta": "❌ MCPClient not available. Install huggingface_hub>=0.34.0",
            "tool_log": _format_tool_log(tool_log),
            "citations": _format_citations(citations),
        }
        return

    if not api_key:
        yield {
            "delta": "⚠️ Missing token: set HUGGING_FACE_HUB_TOKEN or HF_TOKEN",
            "tool_log": _format_tool_log(tool_log),
            "citations": _format_citations(citations),
        }
        return

    try:
        await ensure_mcp_init(model_id, provider, api_key)
        client = get_mcp_client(model_id, provider, api_key)
    except Exception as e:
        yield {
            "delta": f"❌ Failed to initialize MCP client: {e}",
            "tool_log": _format_tool_log(tool_log),
            "citations": _format_citations(citations),
        }
        return

    # Stream MCP responses
    try:
        async for chunk in client.process_single_turn_with_tools(messages_for_llm):
            if isinstance(chunk, dict):
                ctype = chunk.get("type")
                if ctype == "tool_log":
                    name = chunk.get("tool", "tool")
                    status = chunk.get("status", "")
                    _append_log(tool_log, f"- {name} **{status}**")
                    yield {
                        "delta": "",
                        "tool_log": _format_tool_log(tool_log),
                        "citations": _format_citations(citations),
                    }
                elif ctype == "text_delta":
                    yield {
                        "delta": chunk.get("delta", ""),
                        "tool_log": _format_tool_log(tool_log),
                        "citations": _format_citations(citations),
                    }
                elif ctype == "text":
                    yield {
                        "delta": chunk.get("text", ""),
                        "tool_log": _format_tool_log(tool_log),
                        "citations": _format_citations(citations),
                    }
                elif ctype == "tool_result":
                    # Record the tool result and try to extract a citation
                    tool_name = chunk.get("tool", "tool")
                    content = chunk.get("content")

                    url = None
                    if isinstance(content, dict):
                        url = content.get("url") or content.get("link")
                        title = content.get("title") or content.get("name")
                        label = title or tool_name
                    elif isinstance(content, str):
                        label = tool_name
                        if "http://" in content or "https://" in content:
                            start = content.find("http")
                            url = content[start : start + 200].split("\n")[0].strip()
                    else:
                        label = tool_name

                    citations.append((label, url))
                    _append_log(tool_log, f"  • {tool_name} returned result")

                    # Attach a short snippet of the tool output to the answer
                    snippet = ""
                    if isinstance(content, str):
                        snippet = content.strip()
                        if len(snippet) > 700:
                            snippet = snippet[:700] + "…"
                        snippet = f"\n\n**Result (from {tool_name}):**\n{snippet}"

                    yield {
                        "delta": snippet,
                        "tool_log": _format_tool_log(tool_log),
                        "citations": _format_citations(citations),
                    }
            else:
                # Fallback for plain string chunks
                yield {
                    "delta": str(chunk),
                    "tool_log": _format_tool_log(tool_log),
                    "citations": _format_citations(citations),
                }
    except Exception as e:
        msg = str(e)
        if "401" in msg or "Unauthorized" in msg:
            err = f"❌ Unauthorized (401). Check your token permissions.\nModel: `{model_id}`\nProvider: `{provider}`"
        elif "404" in msg or "Not Found" in msg:
            err = f"❌ Model not found (404). Model `{model_id}` may not be available via hf-inference."
        else:
            err = f"❌ Error: {msg}"
        yield {
            "delta": err,
            "tool_log": _format_tool_log(tool_log),
            "citations": _format_citations(citations),
        }
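
# Note: the chunk types handled above ("tool_log", "text_delta", "text",
# "tool_result") mirror what MCPClient.process_single_turn_with_tools() has
# been observed to emit; the exact schema may differ between huggingface_hub
# versions, which is why the handler falls back to str(chunk) for anything
# it does not recognize.
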
# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(
        "# 🤖 Gradio Docs Chat\n"
        "Ask anything about **Gradio**. Powered by Qwen3-30B, loaded locally with transformers."
    )

    with gr.Row():
        with gr.Column(scale=7):
            chat = gr.Chatbot(
                label="Gradio Docs Assistant",
                height=520,
                type="messages",
            )
            with gr.Row():
                msg = gr.Textbox(
                    placeholder="e.g., How do I use gr.Interface with multiple inputs?",
                    scale=9,
                    autofocus=True,
                )
                send_btn = gr.Button("Send", scale=1, variant="primary")
            clear = gr.ClearButton(components=[chat, msg], value="Clear")

        with gr.Column(scale=5):
            with gr.Accordion("⚙️ Settings", open=False):
                style = gr.Radio(
                    label="Answer Style",
                    choices=["Concise", "Detailed"],
                    value="Detailed",
                )
                model_info = gr.Markdown(
                    f"**Model:** `{MODEL_ID}` (Local Loading)  \n"
                    f"**Provider:** `{PROVIDER}`  \n"
                    "_(CPU-only loading for stable inference)_"
                )
            with gr.Accordion("🛠 Tool Activity (live)", open=True):
                tool_log_md = gr.Markdown("_No tool activity yet._")
            with gr.Accordion("📎 Citations (recent)", open=True):
                citations_md = gr.Markdown("_No citations captured yet._")

    async def on_submit(user_msg: str, history_msgs: List[Dict[str, Any]], style_choice: str):
        """Handle user message submission and stream the response."""
        history_msgs = (history_msgs or []) + [{"role": "user", "content": user_msg}]
        history_msgs.append({"role": "assistant", "content": ""})
        yield history_msgs, gr.update(value="_No tool activity yet._"), gr.update(value="_No citations captured yet._")

        # Pass only the prior turns; to_llm_messages() appends the new user message itself
        messages_for_llm = to_llm_messages(history_msgs[:-2], user_msg, style_choice)

        async for chunk in stream_answer(messages_for_llm, MODEL_ID, PROVIDER, HF_TOKEN):
            delta = chunk.get("delta", "")
            if delta:
                history_msgs[-1]["content"] += delta
            yield history_msgs, gr.update(value=chunk.get("tool_log", "")), gr.update(value=chunk.get("citations", ""))

    # Wire up event handlers
    msg.submit(on_submit, inputs=[msg, chat, style], outputs=[chat, tool_log_md, citations_md], queue=True)
    send_btn.click(on_submit, inputs=[msg, chat, style], outputs=[chat, tool_log_md, citations_md], queue=True)

# ----------------------------
# Launch App
# ----------------------------
print("🚀 Starting Gradio Docs Chat (local model loading)")
print(f"📁 Model: {MODEL_ID}")
print(f"🔗 MCP Server: {GRADIO_DOCS_MCP_SSE}")

demo = demo.queue(max_size=32)
demo.launch(ssr_mode=False)
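
# Rough dependency sketch for running this Space locally (assumed, not pinned):
#   gradio, torch, transformers, accelerate (for low_cpu_mem_usage),
#   huggingface_hub>=0.34.0, python-dotenv
# plus an HF token in HF_TOKEN / HUGGING_FACE_HUB_TOKEN if the MCP path is used.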