# app.py
# Hugging Face Space: Gradio Docs Chat with a locally loaded Qwen model and MCP integration
# Features:
#   β€’ Qwen3-30B-A3B-Instruct loaded locally via transformers for inference
#   β€’ CPU-only loading to avoid CUDA initialization issues
#   β€’ MCP tool-calling for Gradio docs access
#   β€’ Streaming responses with live tool logs
#   β€’ Optional "Concise / Detailed" answer styles
#   β€’ Citations panel for source tracking
#
# Space secrets needed:
#   - HUGGING_FACE_HUB_TOKEN or HF_TOKEN (for model access)
#   - Optional: CHAT_MODEL, CHAT_PROVIDER, GRADIO_DOCS_MCP_SSE
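#
# Local run (a minimal sketch; the pip pins are assumptions, adjust to your environment):
#   pip install gradio transformers torch "huggingface_hub>=0.34.0" python-dotenv
#   HF_TOKEN=<your token> python app.py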

import os
import asyncio
import json
import time
from typing import Any, Dict, Iterable, List, Optional, Tuple

# Force CPU-only environment to avoid CUDA initialization
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["USE_CUDA"] = "0"
os.environ["USE_GPU"] = "0"

# Load environment variables from .env file if it exists
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass

import gradio as gr

# Try to import MCPClient with fallback
try:
    from huggingface_hub import MCPClient
    MCP_AVAILABLE = True
except ImportError:
    MCPClient = None
    MCP_AVAILABLE = False
    print("Warning: MCPClient not available. Install huggingface_hub>=0.34.0")

# Optional ZeroGPU shim for Hugging Face Spaces
SPACES_ZERO_GPU = bool(os.environ.get("SPACE_ZERO_GPU", ""))
try:
    import spaces  # type: ignore
    if spaces is not None:
        @spaces.GPU
        def _zero_gpu_probe():
            return "ok"
except Exception:
    pass

# ----------------------------
# Configuration
# ----------------------------
GRADIO_DOCS_MCP_SSE = os.environ.get(
    "GRADIO_DOCS_MCP_SSE",
    "https://gradio-docs-mcp.hf.space/gradio_api/mcp/sse",
)

# Model configuration: load a Qwen model locally via transformers (replaces the earlier GPT-OSS-20B setup)
MODEL_ID = os.environ.get("CHAT_MODEL", "Qwen/Qwen3-30B-A3B-Instruct-2507")
PROVIDER = os.environ.get("CHAT_PROVIDER", "auto")
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

# When True, answers are generated by the local transformers model instead of the MCP
# client's remote inference. (Several identifiers keep their legacy "gpt_oss" names.)
USE_GPT_OSS = True

BASE_SYSTEM_PROMPT = (
    "You are a helpful assistant that answers strictly using the Gradio documentation "
    "via the MCP tools provided by the Gradio Docs MCP server. Prefer the latest docs. "
    "Cite relevant class/function names (e.g., gr.Interface) and include short code examples when helpful."
)
CONCISE_SUFFIX = " Keep answers concise (3–6 sentences) unless code is necessary."
DETAILED_SUFFIX = " Provide a detailed, step-by-step answer with short code where helpful."

# ----------------------------
# Model Clients (lazy initialization)
# ----------------------------
mcp_client: Optional[MCPClient] = None
gpt_oss_tokenizer = None
gpt_oss_model = None
_initialized = False
_init_lock = asyncio.Lock()
_model_loading_lock = asyncio.Lock()

def _current_system_prompt(style: str) -> str:
    """Get the system prompt with style suffix."""
    return BASE_SYSTEM_PROMPT + (CONCISE_SUFFIX if style == "Concise" else DETAILED_SUFFIX)

def _reset_clients():
    """Reset all global clients."""
    global mcp_client, gpt_oss_tokenizer, gpt_oss_model, _initialized
    mcp_client = None
    gpt_oss_tokenizer = None
    gpt_oss_model = None
    _initialized = False

def get_mcp_client(model_id: str, provider: str, api_key: Optional[str]) -> MCPClient:
    """Get or create MCP client."""
    global mcp_client
    if mcp_client is None:
        if not MCP_AVAILABLE:
            raise ImportError("MCPClient not available. Install huggingface_hub>=0.34.0")
        mcp_client = MCPClient(model=model_id, provider=provider, api_key=api_key)
    return mcp_client

async def get_gpt_oss_model_and_tokenizer():
    """Get or create GPT-OSS-20B model and tokenizer with strict CPU-only loading."""
    global gpt_oss_tokenizer, gpt_oss_model
    
    # Check if already loaded
    if gpt_oss_tokenizer is not None and gpt_oss_model is not None:
        return gpt_oss_tokenizer, gpt_oss_model
    
    # Use lock to prevent multiple simultaneous loads
    async with _model_loading_lock:
        # Double-check after acquiring lock
        if gpt_oss_tokenizer is not None and gpt_oss_model is not None:
            return gpt_oss_tokenizer, gpt_oss_model
        
        try:
            # Import here to avoid CUDA initialization in main process
            import torch
            from transformers import AutoTokenizer, AutoModelForCausalLM
            
            # Force CPU-only torch configuration
            torch.cuda.is_available = lambda: False
            torch.cuda.device_count = lambda: 0
            
            print("πŸ”„ Loading Qwen tokenizer...")
            gpt_oss_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
            
            print("πŸ”„ Loading Qwen model (CPU-only)...")
            # Clean CPU-only loading configuration (no duplicate parameters)
            gpt_oss_model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float32,  # Use float32 for CPU compatibility
                device_map=None,  # Don't use device mapping
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )
            
            # Explicitly move to CPU
            gpt_oss_model = gpt_oss_model.to("cpu")
            
            # Set model to evaluation mode
            gpt_oss_model.eval()
            
            print("βœ… Qwen model loaded successfully on CPU!")
            return gpt_oss_tokenizer, gpt_oss_model
            
        except Exception as e:
            print(f"❌ Failed to load Qwen model: {e}")
            # Reset globals on error
            gpt_oss_tokenizer = None
            gpt_oss_model = None
            raise e

async def generate_with_gpt_oss(messages: List[Dict[str, Any]]) -> str:
    """Generate response using local GPT-OSS-20B model."""
    try:
        # Lazy load model only when needed
        tokenizer, model = await get_gpt_oss_model_and_tokenizer()
        
        # Convert messages to Qwen format
        qwen_messages = []
        for msg in messages:
            if msg["role"] == "system":
                # Convert system message to user message for Qwen
                qwen_messages.append({
                    "role": "user", 
                    "content": f"System: {msg['content']}"
                })
            else:
                qwen_messages.append(msg)
        
        # Apply chat template and generate
        text = tokenizer.apply_chat_template(
            qwen_messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        
        # Generate with timeout protection
        try:
            import torch
            with torch.no_grad():  # Disable gradients for inference
                generated_ids = model.generate(
                    **model_inputs,
                    max_new_tokens=512,  # Reduced from 16384 for faster response
                    do_sample=True,
                    temperature=0.7,
                    pad_token_id=tokenizer.eos_token_id,
                    max_time=60.0,  # 60 second timeout
                )
        except Exception as gen_error:
            raise Exception(f"Generation Error: {str(gen_error)}")
        
        # Decode the generated text
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
        generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
        
        return generated_text
        
    except Exception as e:
        raise Exception(f"Qwen Model Error: {str(e)}")

async def ensure_mcp_init(model_id: str, provider: str, api_key: Optional[str]):
    """Initialize MCP server connection."""
    global _initialized
    if _initialized:
        return
    
    async with _init_lock:
        if _initialized:
            return
        
        if not MCP_AVAILABLE:
            raise ImportError("MCPClient not available. Install huggingface_hub>=0.34.0")
        
        client = get_mcp_client(model_id, provider, api_key)
        result = client.add_mcp_server(
            type="sse",
            url=GRADIO_DOCS_MCP_SSE,
            timeout=45,
        )
        
        # Handle async/sync versions
        if hasattr(result, '__await__'):
            await result
        _initialized = True

# ----------------------------
# Message Processing
# ----------------------------
def to_llm_messages(history_msgs: List[Dict[str, Any]], user_msg: str, style: str) -> List[Dict[str, Any]]:
    """Convert chat history to LLM format with system prompt."""
    msgs = [{"role": "system", "content": _current_system_prompt(style)}]
    
    for msg in history_msgs or []:
        role = msg.get("role")
        content = msg.get("content")
        if role in ("user", "assistant") and isinstance(content, str):
            msgs.append({"role": role, "content": content})
    
    msgs.append({"role": "user", "content": user_msg})
    return msgs
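
# Example (illustrative): to_llm_messages([], "How do I cache examples?", "Concise") returns
#   [{"role": "system", "content": BASE_SYSTEM_PROMPT + CONCISE_SUFFIX},
#    {"role": "user", "content": "How do I cache examples?"}]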

# ----------------------------
# Response Formatting
# ----------------------------
def _append_log(log_lines: List[str], line: str, max_lines: int = 200) -> None:
    """Append log line with size limit."""
    log_lines.append(line)
    if len(log_lines) > max_lines:
        del log_lines[: len(log_lines) - max_lines]

def _format_tool_log(log_lines: List[str]) -> str:
    """Format tool log for display."""
    return "\n".join(log_lines) if log_lines else "_No tool activity yet._"

def _format_citations(cites: List[Tuple[str, Optional[str]]]) -> str:
    """Format citations for display."""
    if not cites:
        return "_No citations captured yet._"
    
    recent = cites[-12:]  # Show recent citations
    lines = []
    for label, url in recent:
        if url:
            lines.append(f"- **{label}** β€” {url}")
        else:
            lines.append(f"- **{label}**")
    return "\n".join(lines)

# ----------------------------
# Response Streaming
# ----------------------------
async def stream_answer(
    messages_for_llm: List[Dict[str, Any]],
    model_id: str,
    provider: str,
    api_key: Optional[str],
) -> Iterable[Dict[str, Any]]:
    """Stream responses from either GPT-OSS-20B or MCP client."""
    tool_log: List[str] = []
    citations: List[Tuple[str, Optional[str]]] = []

    # Handle generation via the locally loaded transformers model
    if USE_GPT_OSS:
        try:
            # Generate response using local model
            generated_text = await generate_with_gpt_oss(messages_for_llm)
            
            # Stream character by character
            for char in generated_text:
                yield {
                    "delta": char,
                    "tool_log": _format_tool_log(tool_log),
                    "citations": _format_citations(citations),
                }
                
        except Exception as e:
            yield {
                "delta": f"❌ {str(e)}",
                "tool_log": _format_tool_log(tool_log),
                "citations": _format_citations(citations),
            }
        return

    # Handle MCP-based models
    if not MCP_AVAILABLE:
        yield {
            "delta": "❌ MCPClient not available. Install huggingface_hub>=0.34.0",
            "tool_log": _format_tool_log(tool_log),
            "citations": _format_citations(citations),
        }
        return

    if not api_key:
        yield {
            "delta": "⚠️ Missing token: set HUGGING_FACE_HUB_TOKEN or HF_TOKEN",
            "tool_log": _format_tool_log(tool_log),
            "citations": _format_citations(citations),
        }
        return

    try:
        await ensure_mcp_init(model_id, provider, api_key)
        client = get_mcp_client(model_id, provider, api_key)
    except Exception as e:
        yield {
            "delta": f"❌ Failed to initialize MCP client: {str(e)}",
            "tool_log": _format_tool_log(tool_log),
            "citations": _format_citations(citations),
        }
        return

    # Stream MCP responses
    try:
        async for chunk in client.process_single_turn_with_tools(messages_for_llm):
            if isinstance(chunk, dict):
                ctype = chunk.get("type")

                if ctype == "tool_log":
                    name = chunk.get("tool", "tool")
                    status = chunk.get("status", "")
                    _append_log(tool_log, f"- {name} **{status}**")
                    yield {
                        "delta": "", 
                        "tool_log": _format_tool_log(tool_log), 
                        "citations": _format_citations(citations)
                    }

                elif ctype == "text_delta":
                    yield {
                        "delta": chunk.get("delta", ""), 
                        "tool_log": _format_tool_log(tool_log), 
                        "citations": _format_citations(citations)
                    }

                elif ctype == "text":
                    yield {
                        "delta": chunk.get("text", ""), 
                        "tool_log": _format_tool_log(tool_log), 
                        "citations": _format_citations(citations)
                    }

                elif ctype == "tool_result":
                    # Process tool results and citations
                    tool_name = chunk.get("tool", "tool")
                    content = chunk.get("content")
                    
                    # Extract citation info
                    url = None
                    if isinstance(content, dict):
                        url = content.get("url") or content.get("link")
                        title = content.get("title") or content.get("name")
                        label = title or tool_name
                    elif isinstance(content, str):
                        label = tool_name
                        if "http://" in content or "https://" in content:
                            start = content.find("http")
                            url = content[start : start + 200].split("\n")[0].strip()
                    else:
                        label = tool_name

                    citations.append((label, url))
                    _append_log(tool_log, f"  β€’ {tool_name} returned result")
                    
                    # Format content snippet
                    snippet = ""
                    if isinstance(content, str):
                        snippet = content.strip()
                        if len(snippet) > 700:
                            snippet = snippet[:700] + "…"
                        snippet = f"\n\n**Result (from {tool_name}):**\n{snippet}"
                    
                    yield {
                        "delta": snippet,
                        "tool_log": _format_tool_log(tool_log),
                        "citations": _format_citations(citations),
                    }

            else:
                # Fallback for plain string responses
                yield {
                    "delta": str(chunk), 
                    "tool_log": _format_tool_log(tool_log), 
                    "citations": _format_citations(citations)
                }

    except Exception as e:
        msg = str(e)
        if "401" in msg or "Unauthorized" in msg:
            err = f"❌ Unauthorized (401). Check your token permissions.\nModel: `{model_id}`\nProvider: `{provider}`"
        elif "404" in msg or "Not Found" in msg:
            err = f"❌ Model not found (404). Model `{model_id}` may not be available via hf-inference."
        else:
            err = f"❌ Error: {msg}"
        
        yield {
            "delta": err, 
            "tool_log": _format_tool_log(tool_log), 
            "citations": _format_citations(citations)
        }
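
# Illustrative consumption (from another coroutine):
#   async for update in stream_answer(msgs, MODEL_ID, PROVIDER, HF_TOKEN):
#       print(update["delta"], end="")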

# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(
        "# πŸ€– Gradio Docs Chat\n"
        "Ask anything about **Gradio**. Powered by Qwen3-30B with local transformers loading."
    )

    with gr.Row():
        with gr.Column(scale=7):
            chat = gr.Chatbot(
                label="Gradio Docs Assistant",
                height=520,
                type="messages",
            )
            with gr.Row():
                msg = gr.Textbox(
                    placeholder="e.g., How do I use gr.Interface with multiple inputs?",
                    scale=9,
                    autofocus=True,
                )
                send_btn = gr.Button("Send", scale=1, variant="primary")

            clear = gr.ClearButton(components=[chat, msg], value="Clear")

        with gr.Column(scale=5):
            with gr.Accordion("βš™οΈ Settings", open=False):
                style = gr.Radio(
                    label="Answer Style",
                    choices=["Concise", "Detailed"],
                    value="Detailed",
                )
                model_info = gr.Markdown(
                    f"**Model:** `{MODEL_ID}` (Local Loading)  \n"
                    f"**Provider:** `{PROVIDER}`  \n"
                    "_(CPU-only loading for stable inference)_"
                )

            with gr.Accordion("πŸ›  Tool Activity (live)", open=True):
                tool_log_md = gr.Markdown("_No tool activity yet._")

            with gr.Accordion("πŸ“Ž Citations (recent)", open=True):
                citations_md = gr.Markdown("_No citations captured yet._")

    async def on_submit(user_msg: str, history_msgs: List[Dict[str, Any]], style_choice: str):
        """Handle user message submission and stream response."""
        history_msgs = (history_msgs or []) + [{"role": "user", "content": user_msg}]
        history_msgs.append({"role": "assistant", "content": ""})
        
        yield history_msgs, gr.update(value="_No tool activity yet._"), gr.update(value="_No citations captured yet._")

        # history_msgs[:-2] drops the just-appended user and empty assistant turns;
        # to_llm_messages re-appends the current user message itself.
        messages_for_llm = to_llm_messages(history_msgs[:-2], user_msg, style_choice)

        async for chunk in stream_answer(messages_for_llm, MODEL_ID, PROVIDER, HF_TOKEN):
            delta = chunk.get("delta", "")
            if delta:
                history_msgs[-1]["content"] += delta
            yield history_msgs, gr.update(value=chunk.get("tool_log", "")), gr.update(value=chunk.get("citations", ""))

    # Wire up event handlers
    msg.submit(on_submit, inputs=[msg, chat, style], outputs=[chat, tool_log_md, citations_md], queue=True)
    send_btn.click(on_submit, inputs=[msg, chat, style], outputs=[chat, tool_log_md, citations_md], queue=True)

# ----------------------------
# Launch App
# ----------------------------
print(f"πŸš€ Starting Gradio Docs Chat with Qwen3-30B (Local Loading)")
print(f"πŸ“ Model: {MODEL_ID}")
print(f"πŸ”— MCP Server: {GRADIO_DOCS_MCP_SSE}")

demo = demo.queue(max_size=32)
demo.launch(ssr_mode=False)