# app.py
# Hugging Face Space: Gradio Docs Chat with Qwen3-30B and MCP Integration
# Features:
#   • Qwen3-30B-A3B-Instruct with local transformers loading
#   • CPU-only loading to avoid CUDA initialization issues
#   • MCP tool-calling for Gradio docs access
#   • Streaming responses with live tool logs
#   • Optional "Concise / Detailed" answer styles
#   • Citations panel for source tracking
#
# Space secrets needed:
#   - HUGGING_FACE_HUB_TOKEN or HF_TOKEN (for model access)
#   - Optional: CHAT_MODEL, CHAT_PROVIDER, GRADIO_DOCS_MCP_SSE (CHAT_MODEL is currently unused; MODEL_ID is hardcoded below)
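#
# Example local setup (illustrative only; the token value is a placeholder, not a real
# secret): a .env file next to app.py could contain
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxx
#   CHAT_PROVIDER=auto
#   GRADIO_DOCS_MCP_SSE=https://gradio-docs-mcp.hf.space/gradio_api/mcp/sse
# load_dotenv() below picks these up when the app runs outside of Spaces.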
import os
import asyncio
import json
import time
from typing import Any, Dict, Iterable, List, Optional, Tuple

# Force CPU-only environment to avoid CUDA initialization
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["USE_CUDA"] = "0"
os.environ["USE_GPU"] = "0"

# Load environment variables from .env file if it exists
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass
import gradio as gr

# Try to import MCPClient with fallback
try:
    from huggingface_hub import MCPClient
    MCP_AVAILABLE = True
except ImportError:
    MCPClient = None
    MCP_AVAILABLE = False
    print("Warning: MCPClient not available. Install huggingface_hub>=0.34.0")
# Optional ZeroGPU shim for Hugging Face Spaces
SPACES_ZERO_GPU = bool(os.environ.get("SPACE_ZERO_GPU", ""))
try:
    import spaces  # type: ignore
    if spaces is not None:
        def _zero_gpu_probe():
            return "ok"
except Exception:
    pass
# ----------------------------
# Configuration
# ----------------------------
GRADIO_DOCS_MCP_SSE = os.environ.get(
    "GRADIO_DOCS_MCP_SSE",
    "https://gradio-docs-mcp.hf.space/gradio_api/mcp/sse",
)

# Model configuration - local Qwen model loading (more efficient than GPT-OSS-20B)
MODEL_ID = "Qwen/Qwen3-30B-A3B-Instruct-2507"
PROVIDER = os.environ.get("CHAT_PROVIDER", "auto")
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

# Route requests through the locally loaded model (flag keeps its legacy GPT-OSS name)
USE_GPT_OSS = True
BASE_SYSTEM_PROMPT = (
    "You are a helpful assistant that answers strictly using the Gradio documentation "
    "via the MCP tools provided by the Gradio Docs MCP server. Prefer the latest docs. "
    "Cite relevant class/function names (e.g., gr.Interface) and include short code examples when helpful."
)
CONCISE_SUFFIX = " Keep answers concise (3-6 sentences) unless code is necessary."
DETAILED_SUFFIX = " Provide a detailed, step-by-step answer with short code where helpful."
# ----------------------------
# Model Clients (lazy initialization)
# ----------------------------
mcp_client: Optional[MCPClient] = None
gpt_oss_tokenizer = None
gpt_oss_model = None
_initialized = False
_init_lock = asyncio.Lock()
_model_loading_lock = asyncio.Lock()


def _current_system_prompt(style: str) -> str:
    """Get the system prompt with style suffix."""
    return BASE_SYSTEM_PROMPT + (CONCISE_SUFFIX if style == "Concise" else DETAILED_SUFFIX)
def _reset_clients():
    """Reset all global clients."""
    global mcp_client, gpt_oss_tokenizer, gpt_oss_model, _initialized
    mcp_client = None
    gpt_oss_tokenizer = None
    gpt_oss_model = None
    _initialized = False


def get_mcp_client(model_id: str, provider: str, api_key: Optional[str]) -> MCPClient:
    """Get or create MCP client."""
    global mcp_client
    if mcp_client is None:
        if not MCP_AVAILABLE:
            raise ImportError("MCPClient not available. Install huggingface_hub>=0.34.0")
        mcp_client = MCPClient(model=model_id, provider=provider, api_key=api_key)
    return mcp_client
async def get_gpt_oss_model_and_tokenizer():
    """Get or create the local Qwen tokenizer and model with strict CPU-only loading.

    The globals keep their legacy gpt_oss_* names from the earlier GPT-OSS-20B setup.
    """
    global gpt_oss_tokenizer, gpt_oss_model
    # Check if already loaded
    if gpt_oss_tokenizer is not None and gpt_oss_model is not None:
        return gpt_oss_tokenizer, gpt_oss_model
    # Use a lock to prevent multiple simultaneous loads
    async with _model_loading_lock:
        # Double-check after acquiring the lock
        if gpt_oss_tokenizer is not None and gpt_oss_model is not None:
            return gpt_oss_tokenizer, gpt_oss_model
        try:
            # Import here to avoid CUDA initialization in the main process
            import torch
            from transformers import AutoTokenizer, AutoModelForCausalLM

            # Force CPU-only torch configuration
            torch.cuda.is_available = lambda: False
            torch.cuda.device_count = lambda: 0

            print("Loading Qwen tokenizer...")
            gpt_oss_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

            print("Loading Qwen model (CPU-only)...")
            # Clean CPU-only loading configuration (no duplicate parameters)
            gpt_oss_model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float32,  # Use float32 for CPU compatibility
                device_map=None,  # Don't use device mapping
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )
            # Explicitly move to CPU
            gpt_oss_model = gpt_oss_model.to("cpu")
            # Set model to evaluation mode
            gpt_oss_model.eval()
            print("Qwen model loaded successfully on CPU!")
            return gpt_oss_tokenizer, gpt_oss_model
        except Exception as e:
            print(f"Failed to load Qwen model: {e}")
            # Reset globals on error
            gpt_oss_tokenizer = None
            gpt_oss_model = None
            raise
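
# Illustrative, not wired into the app: the model loads lazily on the first chat request,
# which makes that first request slow. A deployment could optionally warm the model at
# import time with something like the commented call below (a sketch only; the Space as
# written keeps lazy loading):
#   asyncio.run(get_gpt_oss_model_and_tokenizer())
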
async def generate_with_gpt_oss(messages: List[Dict[str, Any]]) -> str:
    """Generate a response using the locally loaded Qwen model."""
    try:
        # Lazy-load the model only when needed
        tokenizer, model = await get_gpt_oss_model_and_tokenizer()

        # Convert messages to Qwen format
        qwen_messages = []
        for msg in messages:
            if msg["role"] == "system":
                # Fold the system message into a user message for Qwen
                qwen_messages.append({
                    "role": "user",
                    "content": f"System: {msg['content']}"
                })
            else:
                qwen_messages.append(msg)

        # Apply chat template and generate
        text = tokenizer.apply_chat_template(
            qwen_messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

        # Generate with timeout protection
        try:
            import torch
            with torch.no_grad():  # Disable gradients for inference
                generated_ids = model.generate(
                    **model_inputs,
                    max_new_tokens=512,  # Reduced from 16384 for faster responses
                    do_sample=True,
                    temperature=0.7,
                    pad_token_id=tokenizer.eos_token_id,
                    max_time=60.0,  # 60-second generation timeout
                )
        except Exception as gen_error:
            raise Exception(f"Generation Error: {str(gen_error)}")

        # Decode only the newly generated tokens
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
        generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
        return generated_text
    except Exception as e:
        raise Exception(f"Qwen Model Error: {str(e)}")
async def ensure_mcp_init(model_id: str, provider: str, api_key: Optional[str]):
    """Initialize MCP server connection."""
    global _initialized
    if _initialized:
        return
    async with _init_lock:
        if _initialized:
            return
        if not MCP_AVAILABLE:
            raise ImportError("MCPClient not available. Install huggingface_hub>=0.34.0")
        client = get_mcp_client(model_id, provider, api_key)
        result = client.add_mcp_server(
            type="sse",
            url=GRADIO_DOCS_MCP_SSE,
            timeout=45,
        )
        # Handle async/sync versions
        if hasattr(result, '__await__'):
            await result
        _initialized = True
# ----------------------------
# Message Processing
# ----------------------------
def to_llm_messages(history_msgs: List[Dict[str, Any]], user_msg: str, style: str) -> List[Dict[str, Any]]:
    """Convert chat history to LLM format with system prompt."""
    msgs = [{"role": "system", "content": _current_system_prompt(style)}]
    for msg in history_msgs or []:
        role = msg.get("role")
        content = msg.get("content")
        if role in ("user", "assistant") and isinstance(content, str):
            msgs.append({"role": role, "content": content})
    msgs.append({"role": "user", "content": user_msg})
    return msgs
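
# For reference, a call such as to_llm_messages([], "How do I add examples?", "Concise")
# yields a list shaped like:
#   [{"role": "system", "content": BASE_SYSTEM_PROMPT + CONCISE_SUFFIX},
#    {"role": "user", "content": "How do I add examples?"}]
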
# ----------------------------
# Response Formatting
# ----------------------------
def _append_log(log_lines: List[str], line: str, max_lines: int = 200) -> None:
    """Append log line with size limit."""
    log_lines.append(line)
    if len(log_lines) > max_lines:
        del log_lines[: len(log_lines) - max_lines]


def _format_tool_log(log_lines: List[str]) -> str:
    """Format tool log for display."""
    return "\n".join(log_lines) if log_lines else "_No tool activity yet._"


def _format_citations(cites: List[Tuple[str, Optional[str]]]) -> str:
    """Format citations for display."""
    if not cites:
        return "_No citations captured yet._"
    recent = cites[-12:]  # Show recent citations
    lines = []
    for label, url in recent:
        if url:
            lines.append(f"- **{label}** → {url}")
        else:
            lines.append(f"- **{label}**")
    return "\n".join(lines)
# ----------------------------
# Response Streaming
# ----------------------------
async def stream_answer(
    messages_for_llm: List[Dict[str, Any]],
    model_id: str,
    provider: str,
    api_key: Optional[str],
) -> Iterable[Dict[str, Any]]:
    """Stream responses from either the local Qwen model or the MCP client."""
    tool_log: List[str] = []
    citations: List[Tuple[str, Optional[str]]] = []

    # Handle the local Qwen model path
    if USE_GPT_OSS:
        try:
            # Generate the full response using the local model
            generated_text = await generate_with_gpt_oss(messages_for_llm)
            # Stream it back character by character
            for char in generated_text:
                yield {
                    "delta": char,
                    "tool_log": _format_tool_log(tool_log),
                    "citations": _format_citations(citations),
                }
        except Exception as e:
            yield {
                "delta": str(e),
                "tool_log": _format_tool_log(tool_log),
                "citations": _format_citations(citations),
            }
        return
    # Handle MCP-based models
    if not MCP_AVAILABLE:
        yield {
            "delta": "MCPClient not available. Install huggingface_hub>=0.34.0",
            "tool_log": _format_tool_log(tool_log),
            "citations": _format_citations(citations),
        }
        return
    if not api_key:
        yield {
            "delta": "⚠️ Missing token: set HUGGING_FACE_HUB_TOKEN or HF_TOKEN",
            "tool_log": _format_tool_log(tool_log),
            "citations": _format_citations(citations),
        }
        return

    try:
        await ensure_mcp_init(model_id, provider, api_key)
        client = get_mcp_client(model_id, provider, api_key)
    except Exception as e:
        yield {
            "delta": f"Failed to initialize MCP client: {str(e)}",
            "tool_log": _format_tool_log(tool_log),
            "citations": _format_citations(citations),
        }
        return
    # Stream MCP responses
    try:
        async for chunk in client.process_single_turn_with_tools(messages_for_llm):
            if isinstance(chunk, dict):
                ctype = chunk.get("type")
                if ctype == "tool_log":
                    name = chunk.get("tool", "tool")
                    status = chunk.get("status", "")
                    _append_log(tool_log, f"- {name} **{status}**")
                    yield {
                        "delta": "",
                        "tool_log": _format_tool_log(tool_log),
                        "citations": _format_citations(citations),
                    }
                elif ctype == "text_delta":
                    yield {
                        "delta": chunk.get("delta", ""),
                        "tool_log": _format_tool_log(tool_log),
                        "citations": _format_citations(citations),
                    }
                elif ctype == "text":
                    yield {
                        "delta": chunk.get("text", ""),
                        "tool_log": _format_tool_log(tool_log),
                        "citations": _format_citations(citations),
                    }
                elif ctype == "tool_result":
                    # Process tool results and citations
                    tool_name = chunk.get("tool", "tool")
                    content = chunk.get("content")
                    # Extract citation info
                    url = None
                    if isinstance(content, dict):
                        url = content.get("url") or content.get("link")
                        title = content.get("title") or content.get("name")
                        label = title or tool_name
                    elif isinstance(content, str):
                        label = tool_name
                        if "http://" in content or "https://" in content:
                            start = content.find("http")
                            url = content[start : start + 200].split("\n")[0].strip()
                    else:
                        label = tool_name
                    citations.append((label, url))
                    _append_log(tool_log, f"  • {tool_name} returned result")
                    # Format content snippet
                    snippet = ""
                    if isinstance(content, str):
                        snippet = content.strip()
                        if len(snippet) > 700:
                            snippet = snippet[:700] + "…"
                        snippet = f"\n\n**Result (from {tool_name}):**\n{snippet}"
                    yield {
                        "delta": snippet,
                        "tool_log": _format_tool_log(tool_log),
                        "citations": _format_citations(citations),
                    }
            else:
                # Fallback for plain string responses
                yield {
                    "delta": str(chunk),
                    "tool_log": _format_tool_log(tool_log),
                    "citations": _format_citations(citations),
                }
    except Exception as e:
        msg = str(e)
        if "401" in msg or "Unauthorized" in msg:
            err = f"Unauthorized (401). Check your token permissions.\nModel: `{model_id}`\nProvider: `{provider}`"
        elif "404" in msg or "Not Found" in msg:
            err = f"Model not found (404). Model `{model_id}` may not be available via hf-inference."
        else:
            err = f"Error: {msg}"
        yield {
            "delta": err,
            "tool_log": _format_tool_log(tool_log),
            "citations": _format_citations(citations),
        }
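
# Example (illustrative, not executed): collecting a full answer outside the UI, assuming
# the local model path (USE_GPT_OSS = True), so no MCP server or token is required.
#   async def _demo():
#       msgs = to_llm_messages([], "How do I launch a gr.Interface?", "Concise")
#       parts = []
#       async for chunk in stream_answer(msgs, MODEL_ID, PROVIDER, HF_TOKEN):
#           parts.append(chunk["delta"])
#       return "".join(parts)
#   print(asyncio.run(_demo()))
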
# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(
        "# Gradio Docs Chat\n"
        "Ask anything about **Gradio**. Powered by Qwen3-30B with local transformers loading."
    )
    with gr.Row():
        with gr.Column(scale=7):
            chat = gr.Chatbot(
                label="Gradio Docs Assistant",
                height=520,
                type="messages",
            )
            with gr.Row():
                msg = gr.Textbox(
                    placeholder="e.g., How do I use gr.Interface with multiple inputs?",
                    scale=9,
                    autofocus=True,
                )
                send_btn = gr.Button("Send", scale=1, variant="primary")
            clear = gr.ClearButton(components=[chat, msg], value="Clear")
        with gr.Column(scale=5):
            with gr.Accordion("Settings", open=False):
                style = gr.Radio(
                    label="Answer Style",
                    choices=["Concise", "Detailed"],
                    value="Detailed",
                )
                model_info = gr.Markdown(
                    f"**Model:** `{MODEL_ID}` (Local Loading)  \n"
                    f"**Provider:** `{PROVIDER}`  \n"
                    "_(CPU-only loading for stable inference)_"
                )
            with gr.Accordion("Tool Activity (live)", open=True):
                tool_log_md = gr.Markdown("_No tool activity yet._")
            with gr.Accordion("Citations (recent)", open=True):
                citations_md = gr.Markdown("_No citations captured yet._")
    async def on_submit(user_msg: str, history_msgs: List[Dict[str, Any]], style_choice: str):
        """Handle user message submission and stream response."""
        history_msgs = (history_msgs or []) + [{"role": "user", "content": user_msg}]
        history_msgs.append({"role": "assistant", "content": ""})
        yield history_msgs, gr.update(value="_No tool activity yet._"), gr.update(value="_No citations captured yet._")

        # Exclude the just-appended user turn and assistant placeholder; to_llm_messages
        # re-appends user_msg itself, so passing history_msgs[:-1] would duplicate it.
        messages_for_llm = to_llm_messages(history_msgs[:-2], user_msg, style_choice)
        async for chunk in stream_answer(messages_for_llm, MODEL_ID, PROVIDER, HF_TOKEN):
            delta = chunk.get("delta", "")
            if delta:
                history_msgs[-1]["content"] += delta
            yield history_msgs, gr.update(value=chunk.get("tool_log", "")), gr.update(value=chunk.get("citations", ""))

    # Wire up event handlers
    msg.submit(on_submit, inputs=[msg, chat, style], outputs=[chat, tool_log_md, citations_md], queue=True)
    send_btn.click(on_submit, inputs=[msg, chat, style], outputs=[chat, tool_log_md, citations_md], queue=True)
# ----------------------------
# Launch App
# ----------------------------
print("Starting Gradio Docs Chat with Qwen3-30B (Local Loading)")
print(f"Model: {MODEL_ID}")
print(f"MCP Server: {GRADIO_DOCS_MCP_SSE}")

demo = demo.queue(max_size=32)
demo.launch(ssr_mode=False)