# app.py
# Hugging Face Space: Gradio Docs Chat with Qwen3-30B and MCP Integration
# Features:
#   • Qwen3-30B loaded locally via transformers
#   • CPU-only loading to avoid CUDA initialization issues
#   • MCP tool calling for Gradio docs access
#   • Streaming responses with live tool logs
#   • Optional "Concise / Detailed" answer styles
#   • Citations panel for source tracking
#
# Space secrets / env vars:
#   - HUGGING_FACE_HUB_TOKEN or HF_TOKEN (for model access)
#   - Optional: CHAT_PROVIDER, GRADIO_DOCS_MCP_SSE
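#
# A minimal .env for local runs (values below are placeholders) might look like:
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxx
#   GRADIO_DOCS_MCP_SSE=https://gradio-docs-mcp.hf.space/gradio_api/mcp/sse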
import os
import asyncio
import json
import time
from typing import Any, Dict, Iterable, List, Optional, Tuple
# Force CPU-only environment to avoid CUDA initialization
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["USE_CUDA"] = "0"
os.environ["USE_GPU"] = "0"
# Load environment variables from .env file if it exists
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
pass
import gradio as gr
# Try to import MCPClient with fallback
try:
from huggingface_hub import MCPClient
MCP_AVAILABLE = True
except ImportError:
MCPClient = None
MCP_AVAILABLE = False
print("Warning: MCPClient not available. Install huggingface_hub>=0.34.0")
# Optional ZeroGPU shim for Hugging Face Spaces
SPACES_ZERO_GPU = bool(os.environ.get("SPACE_ZERO_GPU", ""))
try:
import spaces # type: ignore
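    # Defining at least one @spaces.GPU-decorated function appears to be what the
    # ZeroGPU runtime expects at startup; the probe itself is never called here.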
    @spaces.GPU
    def _zero_gpu_probe():
        return "ok"
except Exception:
pass
# ----------------------------
# Configuration
# ----------------------------
GRADIO_DOCS_MCP_SSE = os.environ.get(
"GRADIO_DOCS_MCP_SSE",
"https://gradio-docs-mcp.hf.space/gradio_api/mcp/sse",
)
# Model configuration: load Qwen3-30B-A3B-Instruct locally (replaces the earlier GPT-OSS-20B setup)
MODEL_ID = "Qwen/Qwen3-30B-A3B-Instruct-2507"
PROVIDER = os.environ.get("CHAT_PROVIDER", "auto")
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
# Model-specific configuration
USE_GPT_OSS = True  # legacy flag name: True = generate with the locally loaded model, False = route through the MCP client
BASE_SYSTEM_PROMPT = (
"You are a helpful assistant that answers strictly using the Gradio documentation "
"via the MCP tools provided by the Gradio Docs MCP server. Prefer the latest docs. "
"Cite relevant class/function names (e.g., gr.Interface) and include short code examples when helpful."
)
CONCISE_SUFFIX = " Keep answers concise (3–6 sentences) unless code is necessary."
DETAILED_SUFFIX = " Provide a detailed, step-by-step answer with short code where helpful."
# ----------------------------
# Model Clients (lazy initialization)
# ----------------------------
mcp_client: Optional[MCPClient] = None
gpt_oss_tokenizer = None
gpt_oss_model = None
_initialized = False
_init_lock = asyncio.Lock()
_model_loading_lock = asyncio.Lock()
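# Note: the gpt_oss_* names are kept from the earlier GPT-OSS-20B setup but now hold the
# Qwen model. The asyncio locks guard lazy initialization so concurrent requests don't
# double-load the model or re-register the MCP server.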
def _current_system_prompt(style: str) -> str:
"""Get the system prompt with style suffix."""
return BASE_SYSTEM_PROMPT + (CONCISE_SUFFIX if style == "Concise" else DETAILED_SUFFIX)
def _reset_clients():
"""Reset all global clients."""
global mcp_client, gpt_oss_tokenizer, gpt_oss_model, _initialized
mcp_client = None
gpt_oss_tokenizer = None
gpt_oss_model = None
_initialized = False
def get_mcp_client(model_id: str, provider: str, api_key: Optional[str]) -> MCPClient:
"""Get or create MCP client."""
global mcp_client
if mcp_client is None:
if not MCP_AVAILABLE:
raise ImportError("MCPClient not available. Install huggingface_hub>=0.34.0")
mcp_client = MCPClient(model=model_id, provider=provider, api_key=api_key)
return mcp_client
async def get_gpt_oss_model_and_tokenizer():
"""Get or create GPT-OSS-20B model and tokenizer with strict CPU-only loading."""
global gpt_oss_tokenizer, gpt_oss_model
# Check if already loaded
if gpt_oss_tokenizer is not None and gpt_oss_model is not None:
return gpt_oss_tokenizer, gpt_oss_model
# Use lock to prevent multiple simultaneous loads
async with _model_loading_lock:
# Double-check after acquiring lock
if gpt_oss_tokenizer is not None and gpt_oss_model is not None:
return gpt_oss_tokenizer, gpt_oss_model
try:
# Import here to avoid CUDA initialization in main process
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
            # Monkeypatch torch's CUDA probes so nothing downstream tries to initialize a GPU
torch.cuda.is_available = lambda: False
torch.cuda.device_count = lambda: 0
print("πŸ”„ Loading Qwen tokenizer...")
gpt_oss_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print("πŸ”„ Loading Qwen model (CPU-only)...")
            # CPU-only loading configuration (float32, no device_map)
gpt_oss_model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.float32, # Use float32 for CPU compatibility
device_map=None, # Don't use device mapping
trust_remote_code=True,
low_cpu_mem_usage=True,
)
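            # Note: at float32 (~4 bytes per parameter) a ~30B-parameter model needs on the
            # order of 120 GB of RAM, so this CPU path assumes a very large-memory host.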
# Explicitly move to CPU
gpt_oss_model = gpt_oss_model.to("cpu")
# Set model to evaluation mode
gpt_oss_model.eval()
print("βœ… Qwen model loaded successfully on CPU!")
return gpt_oss_tokenizer, gpt_oss_model
except Exception as e:
print(f"❌ Failed to load Qwen model: {e}")
# Reset globals on error
gpt_oss_tokenizer = None
gpt_oss_model = None
            raise
async def generate_with_gpt_oss(messages: List[Dict[str, Any]]) -> str:
"""Generate response using local GPT-OSS-20B model."""
try:
# Lazy load model only when needed
tokenizer, model = await get_gpt_oss_model_and_tokenizer()
# Convert messages to Qwen format
qwen_messages = []
for msg in messages:
if msg["role"] == "system":
# Convert system message to user message for Qwen
qwen_messages.append({
"role": "user",
"content": f"System: {msg['content']}"
})
else:
qwen_messages.append(msg)
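        # Note: Qwen chat templates generally accept a "system" role directly; folding the
        # system prompt into a user turn above is simply the approach taken here.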
# Apply chat template and generate
text = tokenizer.apply_chat_template(
qwen_messages,
tokenize=False,
add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
# Generate with timeout protection
try:
import torch
with torch.no_grad(): # Disable gradients for inference
generated_ids = model.generate(
**model_inputs,
                    max_new_tokens=512,  # cap output length to keep CPU generation time manageable
do_sample=True,
temperature=0.7,
pad_token_id=tokenizer.eos_token_id,
max_time=60.0, # 60 second timeout
)
except Exception as gen_error:
raise Exception(f"Generation Error: {str(gen_error)}")
# Decode the generated text
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
return generated_text
except Exception as e:
raise Exception(f"Qwen Model Error: {str(e)}")
async def ensure_mcp_init(model_id: str, provider: str, api_key: Optional[str]):
"""Initialize MCP server connection."""
global _initialized
if _initialized:
return
async with _init_lock:
if _initialized:
return
if not MCP_AVAILABLE:
raise ImportError("MCPClient not available. Install huggingface_hub>=0.34.0")
client = get_mcp_client(model_id, provider, api_key)
result = client.add_mcp_server(
type="sse",
url=GRADIO_DOCS_MCP_SSE,
timeout=45,
)
# Handle async/sync versions
if hasattr(result, '__await__'):
await result
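        # Flip the flag only after the server registration succeeds, so a failed
        # attempt is retried on the next request.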
_initialized = True
# ----------------------------
# Message Processing
# ----------------------------
def to_llm_messages(history_msgs: List[Dict[str, Any]], user_msg: str, style: str) -> List[Dict[str, Any]]:
"""Convert chat history to LLM format with system prompt."""
msgs = [{"role": "system", "content": _current_system_prompt(style)}]
for msg in history_msgs or []:
role = msg.get("role")
content = msg.get("content")
if role in ("user", "assistant") and isinstance(content, str):
msgs.append({"role": role, "content": content})
msgs.append({"role": "user", "content": user_msg})
return msgs
# ----------------------------
# Response Formatting
# ----------------------------
def _append_log(log_lines: List[str], line: str, max_lines: int = 200) -> None:
"""Append log line with size limit."""
log_lines.append(line)
if len(log_lines) > max_lines:
del log_lines[: len(log_lines) - max_lines]
def _format_tool_log(log_lines: List[str]) -> str:
"""Format tool log for display."""
return "\n".join(log_lines) if log_lines else "_No tool activity yet._"
def _format_citations(cites: List[Tuple[str, Optional[str]]]) -> str:
"""Format citations for display."""
if not cites:
return "_No citations captured yet._"
recent = cites[-12:] # Show recent citations
lines = []
for label, url in recent:
if url:
lines.append(f"- **{label}** β€” {url}")
else:
lines.append(f"- **{label}**")
return "\n".join(lines)
# ----------------------------
# Response Streaming
# ----------------------------
async def stream_answer(
messages_for_llm: List[Dict[str, Any]],
model_id: str,
provider: str,
api_key: Optional[str],
) -> Iterable[Dict[str, Any]]:
"""Stream responses from either GPT-OSS-20B or MCP client."""
tool_log: List[str] = []
citations: List[Tuple[str, Optional[str]]] = []
    # Local-model path: generate with the locally loaded Qwen model and stream it out
if USE_GPT_OSS:
try:
# Generate response using local model
generated_text = await generate_with_gpt_oss(messages_for_llm)
            # The full text is generated above; replay it character by character so the UI still streams
for char in generated_text:
yield {
"delta": char,
"tool_log": _format_tool_log(tool_log),
"citations": _format_citations(citations),
}
except Exception as e:
yield {
"delta": f"❌ {str(e)}",
"tool_log": _format_tool_log(tool_log),
"citations": _format_citations(citations),
}
return
# Handle MCP-based models
if not MCP_AVAILABLE:
yield {
"delta": "❌ MCPClient not available. Install huggingface_hub>=0.34.0",
"tool_log": _format_tool_log(tool_log),
"citations": _format_citations(citations),
}
return
if not api_key:
yield {
"delta": "⚠️ Missing token: set HUGGING_FACE_HUB_TOKEN or HF_TOKEN",
"tool_log": _format_tool_log(tool_log),
"citations": _format_citations(citations),
}
return
try:
await ensure_mcp_init(model_id, provider, api_key)
client = get_mcp_client(model_id, provider, api_key)
except Exception as e:
yield {
"delta": f"❌ Failed to initialize MCP client: {str(e)}",
"tool_log": _format_tool_log(tool_log),
"citations": _format_citations(citations),
}
return
# Stream MCP responses
try:
async for chunk in client.process_single_turn_with_tools(messages_for_llm):
if isinstance(chunk, dict):
ctype = chunk.get("type")
if ctype == "tool_log":
name = chunk.get("tool", "tool")
status = chunk.get("status", "")
_append_log(tool_log, f"- {name} **{status}**")
yield {
"delta": "",
"tool_log": _format_tool_log(tool_log),
"citations": _format_citations(citations)
}
elif ctype == "text_delta":
yield {
"delta": chunk.get("delta", ""),
"tool_log": _format_tool_log(tool_log),
"citations": _format_citations(citations)
}
elif ctype == "text":
yield {
"delta": chunk.get("text", ""),
"tool_log": _format_tool_log(tool_log),
"citations": _format_citations(citations)
}
elif ctype == "tool_result":
# Process tool results and citations
tool_name = chunk.get("tool", "tool")
content = chunk.get("content")
# Extract citation info
url = None
if isinstance(content, dict):
url = content.get("url") or content.get("link")
title = content.get("title") or content.get("name")
label = title or tool_name
elif isinstance(content, str):
label = tool_name
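                        # Crude heuristic: grab the first URL-looking span and cut it at the first newline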
if "http://" in content or "https://" in content:
start = content.find("http")
url = content[start : start + 200].split("\n")[0].strip()
else:
label = tool_name
citations.append((label, url))
_append_log(tool_log, f" β€’ {tool_name} returned result")
# Format content snippet
snippet = ""
if isinstance(content, str):
snippet = content.strip()
if len(snippet) > 700:
snippet = snippet[:700] + "…"
snippet = f"\n\n**Result (from {tool_name}):**\n{snippet}"
yield {
"delta": snippet,
"tool_log": _format_tool_log(tool_log),
"citations": _format_citations(citations),
}
else:
# Fallback for plain string responses
yield {
"delta": str(chunk),
"tool_log": _format_tool_log(tool_log),
"citations": _format_citations(citations)
}
except Exception as e:
msg = str(e)
if "401" in msg or "Unauthorized" in msg:
err = f"❌ Unauthorized (401). Check your token permissions.\nModel: `{model_id}`\nProvider: `{provider}`"
elif "404" in msg or "Not Found" in msg:
err = f"❌ Model not found (404). Model `{model_id}` may not be available via hf-inference."
else:
err = f"❌ Error: {msg}"
yield {
"delta": err,
"tool_log": _format_tool_log(tool_log),
"citations": _format_citations(citations)
}
# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(fill_height=True) as demo:
gr.Markdown(
"# πŸ€– Gradio Docs Chat\n"
"Ask anything about **Gradio**. Powered by Qwen3-30B with local transformers loading."
)
with gr.Row():
with gr.Column(scale=7):
chat = gr.Chatbot(
label="Gradio Docs Assistant",
height=520,
type="messages",
)
with gr.Row():
msg = gr.Textbox(
placeholder="e.g., How do I use gr.Interface with multiple inputs?",
scale=9,
autofocus=True,
)
send_btn = gr.Button("Send", scale=1, variant="primary")
clear = gr.ClearButton(components=[chat, msg], value="Clear")
with gr.Column(scale=5):
with gr.Accordion("βš™οΈ Settings", open=False):
style = gr.Radio(
label="Answer Style",
choices=["Concise", "Detailed"],
value="Detailed",
)
model_info = gr.Markdown(
f"**Model:** `{MODEL_ID}` (Local Loading) \n"
f"**Provider:** `{PROVIDER}` \n"
"_(CPU-only loading for stable inference)_"
)
with gr.Accordion("πŸ›  Tool Activity (live)", open=True):
tool_log_md = gr.Markdown("_No tool activity yet._")
with gr.Accordion("πŸ“Ž Citations (recent)", open=True):
citations_md = gr.Markdown("_No citations captured yet._")
async def on_submit(user_msg: str, history_msgs: List[Dict[str, Any]], style_choice: str):
"""Handle user message submission and stream response."""
history_msgs = (history_msgs or []) + [{"role": "user", "content": user_msg}]
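        # Append an empty assistant placeholder; streamed deltas are appended to it below.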
history_msgs.append({"role": "assistant", "content": ""})
yield history_msgs, gr.update(value="_No tool activity yet._"), gr.update(value="_No citations captured yet._")
        # history_msgs[:-2] drops the assistant placeholder and the just-added user turn,
        # which to_llm_messages re-appends; slicing only [:-1] would duplicate the user message.
        messages_for_llm = to_llm_messages(history_msgs[:-2], user_msg, style_choice)
async for chunk in stream_answer(messages_for_llm, MODEL_ID, PROVIDER, HF_TOKEN):
delta = chunk.get("delta", "")
if delta:
history_msgs[-1]["content"] += delta
yield history_msgs, gr.update(value=chunk.get("tool_log", "")), gr.update(value=chunk.get("citations", ""))
# Wire up event handlers
msg.submit(on_submit, inputs=[msg, chat, style], outputs=[chat, tool_log_md, citations_md], queue=True)
send_btn.click(on_submit, inputs=[msg, chat, style], outputs=[chat, tool_log_md, citations_md], queue=True)
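    # queue=True runs these events through Gradio's queue, which is needed to stream
    # the async generator's incremental updates to the UI.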
# ----------------------------
# Launch App
# ----------------------------
print(f"πŸš€ Starting Gradio Docs Chat with Qwen3-30B (Local Loading)")
print(f"πŸ“ Model: {MODEL_ID}")
print(f"πŸ”— MCP Server: {GRADIO_DOCS_MCP_SSE}")
demo = demo.queue(max_size=32)
demo.launch(ssr_mode=False)